In [12]:
import os
import time

import torch 
import torch.nn as nn
from torch.optim import Adam, SGD
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torchaudio

import matplotlib.pyplot as plt
import librosa
import numpy as np

from model import SingleExtractor
from loss import TripletLoss
from dataPrep import MTATDataset

In [2]:
# Build the training dataset and wrap it in a shuffling loader.
# NOTE(review): neg_dir points at the same folder as pos_dir. Confirm this is
# intentional (negatives drawn from the same spectrogram pool) and not a
# copy-paste slip.
train_dataset = MTATDataset(
    pos_dir='../Data/spectrogram_pos',
    neg_dir='../Data/spectrogram_pos',
    negative_sample_size=4,
)
# batch_size stays at 1: Trainer.train() removes the leading batch axis from
# the negatives with squeeze(0), which only works for single-item batches.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=1, shuffle=True)

In [13]:
class Trainer(object):
    """Train a ``SingleExtractor`` model with a triplet loss.

    The trainer owns the model, the loss criterion and the optimizer. It
    starts with Adam and, driven by an epoch counter, switches to SGD and then
    lowers the SGD learning rate twice (see ``optimizerScheduler``).

    Attributes:
        model: the ``SingleExtractor`` network being trained.
        device: device the model and each batch are moved to.
        negative_sample_size: stored from the constructor; not otherwise used
            by this class.
        n_epochs: number of passes over ``dataloader`` performed by ``train``.
        criterion: ``TripletLoss`` instance called as
            ``criterion(anchor_emb, positive_emb, negative_emb)``.
        optimizer: the currently active optimizer.
        current_optimizer: schedule stage, one of ``'adam'``, ``'sgd_1'``,
            ``'sgd_2'``, ``'sgd_3'``.
        drop_counter: epochs elapsed since the last schedule change.
        trianing_loss: list with the mean loss of every finished epoch (the
            misspelled name is kept for backward compatibility; use the
            ``training_loss`` property in new code).
        best_train_loss: lowest single-batch loss seen so far.
        model_save_path: directory that receives checkpoint files.
        dataloader: iterable yielding ``(anchor, pos, negs)`` batches.
    """

    def __init__(self, dataloader, negative_sample_size=4, n_epochs=500, loss_mode='cosine', device='cpu'):
        """Build the model, loss and initial Adam optimizer.

        Args:
            dataloader: iterable yielding ``(anchor, pos, negs)`` tensors.
            negative_sample_size: number of negatives per anchor (stored only).
            n_epochs: number of training epochs.
            loss_mode: forwarded to ``TripletLoss(mode=...)``.
            device: device string or ``torch.device`` for model and batches.
        """
        self.model = SingleExtractor(conv_channels=128,
                                     sample_rate=16000,
                                     n_fft=513,
                                     n_harmonic=6,
                                     semitone_scale=2,
                                     learn_bw='only_Q').to(device)
        self.device = device
        self.negative_sample_size = negative_sample_size
        self.n_epochs = n_epochs
        self.criterion = TripletLoss(mode=loss_mode)
        self.optimizer = Adam(self.model.parameters(), lr=1e-3, weight_decay=1e-4)
        self.current_optimizer = 'adam'
        self.drop_counter = 0
        self.trianing_loss = []
        # Start at +inf so the very first batch always becomes the best one,
        # whatever the scale of the loss is.
        self.best_train_loss = float('inf')
        self.model_save_path = 'checkpoints'
        self.dataloader = dataloader

    @property
    def training_loss(self):
        """Correctly spelled read-only alias for ``trianing_loss``."""
        return self.trianing_loss

    def _set_learning_rate(self, lr):
        """Assign ``lr`` to every parameter group of the active optimizer."""
        for pg in self.optimizer.param_groups:
            pg['lr'] = lr

    def optimizerScheduler(self):
        """Advance the optimizer schedule based on ``drop_counter``.

        Stages: Adam for 60 epochs -> SGD (lr 1e-3) for 20 epochs ->
        SGD (lr 1e-4) for 20 epochs -> SGD (lr 1e-5). Does nothing when the
        counter has not reached the threshold of the current stage.
        """
        # Adam to sgd
        if self.current_optimizer == 'adam' and self.drop_counter == 60:
            self.optimizer = SGD(self.model.parameters(), 1e-3, momentum=0.9, weight_decay=0.0001, nesterov=True)
            self.current_optimizer = 'sgd_1'
            self.drop_counter = 0
            print('sgd 1e-3')
        # First drop
        elif self.current_optimizer == 'sgd_1' and self.drop_counter == 20:
            self._set_learning_rate(1e-4)
            self.current_optimizer = 'sgd_2'
            self.drop_counter = 0
            print('sgd 1e-4')
        # Second drop (final stage, so the counter is no longer reset)
        elif self.current_optimizer == 'sgd_2' and self.drop_counter == 20:
            self._set_learning_rate(1e-5)
            self.current_optimizer = 'sgd_3'
            print('sgd 1e-5')

    def train(self):
        """Run ``n_epochs`` epochs of triplet training.

        For every batch the anchor, positive and negatives are embedded by the
        same model, the triplet loss is back-propagated and the optimizer is
        stepped. Whenever a batch loss beats ``best_train_loss`` the model's
        ``state_dict`` is written to ``model_save_path``; note that the saved
        weights are the ones *after* the optimizer step for that batch.
        The mean loss of each epoch is appended to ``trianing_loss`` and a
        one-line summary is printed per epoch.
        """
        # torch.save does not create missing directories.
        os.makedirs(self.model_save_path, exist_ok=True)

        t0 = time.time()
        for epoch in range(self.n_epochs):
            self.drop_counter += 1
            self.model.train()
            epoch_loss = []
            for i, (anchor, pos, negs) in enumerate(self.dataloader):
                self.optimizer.zero_grad()

                # Plain tensors carry autograd state; the Variable wrapper is
                # no longer needed.
                anchor = anchor.to(self.device)
                pos = pos.to(self.device)
                # shape: (1, negative_sample_size, x, x) => (negative_sample_size, x, x)
                # cannot handle when batch_size is not 1
                negs = negs.squeeze(0).to(self.device)

                # Feed tensors into the Siamese harmonic network
                ha = self.model(anchor)
                hp = self.model(pos)
                hn = self.model(negs)

                # Compute triplet loss
                loss = self.criterion(ha, hp, hn)
                batch_loss = loss.item()
                epoch_loss.append(batch_loss)

                print (batch_loss)
                loss.backward()
                self.optimizer.step()

                if batch_loss < self.best_train_loss:
                    self.best_train_loss = batch_loss
                    checkpoint_name = f'best_training_model_epoch{epoch}_iter{i}.pth'
                    torch.save(self.model.state_dict(),
                               os.path.join(self.model_save_path, checkpoint_name))

            # An empty dataloader yields no losses; record NaN instead of
            # letting np.mean warn on an empty list.
            mean_loss = np.mean(epoch_loss) if epoch_loss else float('nan')
            self.trianing_loss.append(mean_loss)
            self.optimizerScheduler()

            # Elapsed time is a float; the 'd' format code only accepts ints.
            print ("Epoch: {:3d} | Train loss: {:.3f} | Time: {:4d}s".format(epoch, mean_loss, int(time.time() - t0)))


In [17]:
# Instantiate the trainer on CPU; every other constructor argument keeps its default.
styEncTrain = Trainer(
    dataloader=train_dataloader,
    device=torch.device('cpu'),
)

In [None]:
# Start training. Trainer.train() prints the loss of every batch, which is
# what the numbers in this cell's output are.
styEncTrain.train()

0.3998095393180847
0.4006498456001282
0.3996548652648926


In [7]:
# Scratch check of the checkpoint filename pattern used in Trainer.train(),
# with example epoch/iteration numbers filled in.
os.path.join('checkpoints', f'best_training_model_epoch{2}_iter{30}.pth')

'checkpoints/best_training_model_epoch2_iter30.pth'