In [None]:
from model import *
from hparams import *
from dataset import  get_data_loader, get_data_loader_michigan

%load_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings("ignore")

In [None]:
#new jupyter notebook
from tqdm import tqdm

# Smoke test: build the training loader and look at the first batch only.
train_loader = get_data_loader(split='train', args=Hparams.args)
for data in tqdm(train_loader):
    # A batch unpacks into three tensors; they are used as-is (no device move).
    mel_spectrogram, yin, pyin = data
    for feature in (mel_spectrogram, yin, pyin):
        print(feature.shape)
    # NOTE(review): earlier notes describe batches as [B, T, D] = [batch size,
    # frames per sample, feature dimension]; shapes are only printed here, not
    # asserted -- confirm the expected sizes before adding a check.
    break
print('Congrats!')

In [None]:
# Michigan dataset: build the train/test datasets together with their loaders.
# test_size=0.2 is forwarded unchanged (presumably the held-out fraction --
# confirm against get_data_loader_michigan).
(
    train_ds,
    test_ds,
    data_loader_train,
    data_loader_test,
) = get_data_loader_michigan(
    args=Hparams_michigan.args,
    test_size=0.2,
)

In [None]:
from tqdm import tqdm
# to plot
# train_ds.plot_item(0)

# Inspect the first Michigan training batch: print shapes and container types,
# then stop after that single batch.
for data in tqdm(data_loader_train):
    # Each batch carries five fields, in this order.
    (
        mel_spectrogram_normalised_log_scale_torch,
        yin_normalised_torch,
        pyin_normalised_torch,
        word,
        toneclass,
    ) = data

    print(f"(Batch, feature)")
    print(f"Spectrogram: {mel_spectrogram_normalised_log_scale_torch.shape} {type(mel_spectrogram_normalised_log_scale_torch)}")
    print(f"Yin: {yin_normalised_torch.shape} {type(yin_normalised_torch)}")
    print(f"Pyin: {pyin_normalised_torch.shape} {type(pyin_normalised_torch)}")
    # `word` is only measured with len(), so it need not be a tensor.
    print(f"Word: {len(word)} {type(word)}")
    # The tone labels are also printed in full for a quick sanity check.
    print(f"Toneclass: {toneclass.shape} {type(toneclass)} {toneclass}")
    break
print('Congrats!')

In [None]:
import torch

from sklearn.metrics import accuracy_score

class Metrics:
    """Accumulate per-batch loss and accuracy and reduce them to averages.

    Call :meth:`update` once per batch and :meth:`get_value` at the end of a
    pass (e.g. an epoch). The averages are weighted by the number of target
    elements in each batch, so a smaller final batch does not skew the result.
    """

    def __init__(self):
        # Per-batch values keyed by metric name ('loss', 'accuracy').
        self.buffer = {}
        # Number of target elements of every recorded batch, aligned
        # index-by-index with the lists stored in ``buffer``.
        self.batch_sizes = []

    def update(self, out, tgt, loss):
        """Record the loss and accuracy of one batch.

        Args:
            out: Model scores with the class axis at dim 1; the prediction is
                the argmax over that axis.
            tgt: Integer class targets with as many elements as there are
                predictions.
            loss: Scalar tensor holding the batch loss. It is assumed to be
                the mean over the batch (as produced by ``nn.CrossEntropyLoss``
                with its default reduction), which is what makes weighting by
                batch size give a per-sample average.

        Raises:
            ValueError: If predictions and targets differ in element count.
        """
        with torch.no_grad():
            pred = torch.flatten(out.argmax(dim=1))
            tgt = torch.flatten(tgt).to(pred.device)

            if pred.shape != tgt.shape:
                raise ValueError(
                    'Predictions and targets differ in size: {} vs {}'.format(
                        tuple(pred.shape), tuple(tgt.shape)
                    )
                )

            n_items = tgt.numel()
            if n_items == 0:
                # Nothing to score; an empty batch must not contribute a
                # NaN loss or a division by zero.
                return

            # Count matches on the tensors' own device; only the resulting
            # scalar is moved to the host.
            n_correct = (pred == tgt).sum().item()

            batch_metric = {
                'loss': loss.item(),
                'accuracy': n_correct / n_items,
            }

            for k, v in batch_metric.items():
                self.buffer.setdefault(k, []).append(v)
            self.batch_sizes.append(n_items)

    def get_value(self):
        """Return the size-weighted averages and reset the accumulator.

        Returns:
            Dict mapping each recorded metric name to its average over all
            batches seen since the last call; an empty dict if no batch was
            recorded.
        """
        total = sum(self.batch_sizes)
        ret = {
            k: sum(v * n for v, n in zip(values, self.batch_sizes)) / total
            for k, values in self.buffer.items()
        }
        self.buffer = {}
        self.batch_sizes = []
        return ret

In [None]:
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim

from tqdm import tqdm
from dataset import get_data_loader, move_data_to_device


def _forward_batch(model, loss_func, batch, device):
    """Run one batch through ``model`` and compute its loss.

    Args:
        model: Network called on the spectrogram input.
        loss_func: Callable ``loss_func(out, tgt)`` returning the batch loss.
        batch: 5-tuple ``(mel_spectrogram, yin, pyin, word, tone_class)``;
            only the spectrogram and the tone class are used.
        device: Device the input and target tensors are moved to.

    Returns:
        Tuple ``(out, tgt, loss)``.
    """
    mel_spectrogram, _yin, _pyin, _word, tone_class = batch

    # Insert a singleton axis at dim 1 (presumably the channel axis expected
    # by the model -- confirm against the model definition).
    x = mel_spectrogram.to(device)[:, None, :, :]
    # Shift labels to 0-index. Done out-of-place so the tensor handed out by
    # the loader is never modified.
    tgt = (tone_class - 1).to(device)

    out = model(x)
    return out, tgt, loss_func(out, tgt)


def fit(model, args, learning_params, train_loader=None, valid_loader=None):
    """Train ``model`` and checkpoint the weights with the lowest validation loss.

    Args:
        model: Network to train; must expose ``feat_dim`` (used in the
            checkpoint directory name).
        args: Dict providing ``'save_model_dir'`` (path prefix of the
            checkpoint directory) and ``'device'``.
        learning_params: Dict providing ``'lr'`` and ``'epoch'`` (maximum
            number of epochs).
        train_loader: Iterable of training batches. Defaults to the
            notebook-level ``data_loader_train``.
        valid_loader: Iterable of validation batches. Defaults to the
            notebook-level ``data_loader_test``.

    Returns:
        The 1-based epoch whose weights were saved last as the best model,
        or -1 if no checkpoint was written.
    """
    # Fall back to the loaders created earlier in the notebook so existing
    # three-argument calls keep working.
    if train_loader is None:
        train_loader = data_loader_train
    if valid_loader is None:
        valid_loader = data_loader_test

    # Set paths. makedirs also creates missing parent directories and does
    # not fail when the directory already exists.
    save_model_dir = f"{args['save_model_dir']}{model.feat_dim}_lr-{learning_params['lr']}"
    os.makedirs(save_model_dir, exist_ok=True)
    best_model_path = os.path.join(save_model_dir, 'best_model.pth')

    device = args['device']
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=learning_params['lr'])
    loss_func = nn.CrossEntropyLoss()
    metric = Metrics()

    # Start training
    print('Start training...')
    start_time = time.time()
    best_model_id = -1
    # Infinity (rather than a large finite number) guarantees the first
    # finite validation loss is checkpointed, however large it is.
    min_valid_loss = float('inf')
    prev_loss = float('inf')
    # Stop once the validation loss changes by less than this between epochs.
    threshold = 1e-6

    for epoch in range(1, learning_params['epoch'] + 1):
        # Train
        model.train()
        pbar = tqdm(train_loader)
        for batch in pbar:
            out, tgt, loss = _forward_batch(model, loss_func, batch, device)
            metric.update(out, tgt, loss)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_description('Epoch {}, Loss: {:.4f}'.format(epoch, loss.item()))
        metric_train = metric.get_value()

        # Validation
        model.eval()
        with torch.no_grad():
            for batch in valid_loader:
                out, tgt, loss = _forward_batch(model, loss_func, batch, device)
                metric.update(out, tgt, loss)
        metric_test = metric.get_value()

        # Logging
        print('[Epoch {:02d}], Train Loss: {:.5f}, Valid Loss {:.5f}, Time {:.2f}s'.format(
            epoch, metric_train['loss'], metric_test['loss'], time.time() - start_time,
        ))
        print('Split Train Loss, Accuracy: Loss {:.4f} | Accuracy {:.4f}'.format(
            metric_train['loss'],
            metric_train['accuracy']
        ))
        print('Split Test Loss, Accuracy: Loss {:.4f} | Accuracy {:.4f}'.format(
            metric_test['loss'],
            metric_test['accuracy']
        ))

        # Save the best model
        if metric_test['loss'] < min_valid_loss:
            min_valid_loss = metric_test['loss']
            best_model_id = epoch
            torch.save(model.state_dict(), best_model_path)

        # Early stop when the validation loss has plateaued.
        if abs(metric_test['loss'] - prev_loss) < threshold:
            break

        prev_loss = metric_test['loss']

    print('Training done in {:.1f} minutes.'.format((time.time() - start_time) / 60))
    return best_model_id

In [None]:
# Training hyper-parameters: maximum number of epochs and optimiser learning rate.
learning_params = {'epoch': 10, 'lr': 1e-3}

# Build the baseline tone model and launch training.
# NOTE(review): the Michigan loaders above were built from Hparams_michigan.args,
# while fit receives Hparams.args here -- confirm this mix is intended.
model = ToneEval_Base(input_shape=(1, 128, 75))
fit(
    model,
    args=Hparams.args,
    learning_params=learning_params,
)