In [1]:
import musdataset

In [2]:
import torch.nn as nn
import numpy as np
import torchaudio
import torch
from musdataset import model, optimizer, scheduler, criterion, last_epoch, num_epochs, PATH, dataloaders


# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(torch.cuda.is_available())
# torch.cuda.get_device_name() raises when no CUDA device exists, so only query it on GPU.
print(torch.cuda.get_device_name(device) if device.type == "cuda" else "CPU")

# Move the model first so that optimizer state restored below lands on the same
# device as the parameters it belongs to.
model = model.to(device)

# Set RESUME = True to continue training from the checkpoint at PATH instead of
# starting fresh with the objects built in musdataset.
RESUME = False
if RESUME:
    # map_location allows a checkpoint written on a GPU to be restored on a CPU-only host.
    checkpoint = torch.load(PATH, map_location=device)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    last_epoch = checkpoint['last_epoch']

# Replicate across GPUs when more than one is visible. This happens after loading so
# that state-dict keys match the bare (unwrapped) model.
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)



def train(model, criterion, optimizer, scheduler, last_epoch, num_epochs,
          loaders=None, dev=None):
    """Train `model` for epochs [last_epoch, num_epochs), validating after each one.

    Each epoch runs a 'train' phase followed by a 'valid' phase.
    - 'train': gradients enabled, one optimizer step per batch, and the scheduler
      stepped once at the end of the phase.
    - 'valid': gradients disabled, no parameter updates.
    For each phase, the number of correct top-1 predictions (argmax over dim 1 of
    the model output) and the accuracy in percent are printed.

    Args:
        model: network to train; switched between train/eval mode per phase.
        criterion: loss function, called as ``criterion(outputs, labels)``.
        optimizer: optimizer over the model's parameters.
        scheduler: learning-rate scheduler, stepped once per training phase.
        last_epoch: index of the first epoch to run (0 for a fresh run).
        num_epochs: exclusive upper bound on the epoch index.
        loaders: mapping with 'train' and 'valid' iterables of (inputs, labels)
            batches; defaults to the module-level ``dataloaders``.
        dev: device that batches are moved to; defaults to the module-level ``device``.

    Returns:
        The epoch index to resume from, ``max(last_epoch, num_epochs)``. This is
        suitable for storing as a checkpoint's 'last_epoch'. (Assigning to the
        ``last_epoch`` parameter inside this function could never update the caller.)
    """
    if loaders is None:
        loaders = dataloaders
    if dev is None:
        dev = device

    for epoch in range(last_epoch, num_epochs):
        print('-' * 10)
        for phase in ('train', 'valid'):
            is_train = phase == 'train'
            model.train(is_train)  # same as model.train() / model.eval()
            correct = 0
            total = 0
            for inputs, labels in loaders[phase]:
                inputs = inputs.to(dev, non_blocking=True)
                labels = labels.to(dev, non_blocking=True)
                if is_train:
                    optimizer.zero_grad()
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    if is_train:
                        loss.backward()
                        optimizer.step()
                predicted = outputs.argmax(dim=1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

            if is_train:
                scheduler.step()

            # An empty loader would otherwise raise ZeroDivisionError.
            accuracy = 100.0 * correct / total if total else 0.0
            # Label the phase: train and valid lines were previously indistinguishable.
            print("Epoch: ", epoch + 1, " | Phase: ", phase,
                  " | Correct/Total: ", correct, "/", total)
            print("Accuracy: ", accuracy, "%")
    print("Training done")
    return max(last_epoch, num_epochs)


def test(model):
    model.eval()
    with torch.no_grad():
        correct = 0  
        total = 0
        for inputs, labels in dataloaders['test']:
            correct = 0  
            total = 0
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            outputs = model(inputs)
            _ , predicted = torch.max(outputs, 1)
        
            correct += torch.sum(predicted == labels.data)
            total += labels.size(0)
        print("Correct/Total: ", correct, "/",total )
        print("Accuracy: ", (correct/total) * 100, "%")
    
    print("Testing done")
    
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Number of parameters: ", total_params)

def prod(model=None):
    """Placeholder entry point for a live demo; does nothing yet.

    Args:
        model: the model to demo; accepted but currently unused.

    Returns:
        None.
    """
    return None


# Train from `last_epoch` (0 for a fresh run, or the checkpoint value when resuming)
# up to `num_epochs`.
train(model, criterion, optimizer, scheduler, last_epoch, num_epochs)

# train() iterates range(last_epoch, num_epochs), so the next epoch to run on resume
# is num_epochs. Record it here explicitly: train() cannot rebind this module-level
# name. The previous `last_epoch + 1` saved the pre-training value, so a resumed run
# repeated epochs that had already been completed.
last_epoch = max(last_epoch, num_epochs)

# Save the bare module's weights. Without unwrapping, nn.DataParallel adds a "module."
# prefix to every key, and those keys do not load into the unwrapped model.
model_to_save = model.module if isinstance(model, nn.DataParallel) else model
checkpoint = {
    'model': model_to_save.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict(),
    'last_epoch': last_epoch,
}
torch.save(checkpoint, PATH)

# test(model)

# prod(model)


True
Tesla T4
----------
Epoch:  1  | Correct/Total:  190657 / 289205
Accuracy:  65.92451721097491 %
Epoch:  1  | Correct/Total:  8653 / 12678
Accuracy:  68.25209023505285 %
----------
Epoch:  2  | Correct/Total:  242480 / 289205
Accuracy:  83.84364032433741 %
Epoch:  2  | Correct/Total:  8465 / 12678
Accuracy:  66.76920649944786 %
----------
Epoch:  3  | Correct/Total:  257997 / 289205
Accuracy:  89.20903857125569 %
Epoch:  3  | Correct/Total:  9050 / 12678
Accuracy:  71.38349897460166 %
----------
Epoch:  4  | Correct/Total:  267206 / 289205
Accuracy:  92.39328504002351 %
Epoch:  4  | Correct/Total:  9175 / 12678
Accuracy:  72.3694589051901 %
----------
Epoch:  5  | Correct/Total:  273328 / 289205
Accuracy:  94.51012257741048 %
Epoch:  5  | Correct/Total:  9286 / 12678
Accuracy:  73.24499132355261 %
----------
Epoch:  6  | Correct/Total:  277781 / 289205
Accuracy:  96.04986082536608 %
Epoch:  6  | Correct/Total:  9368 / 12678
Accuracy:  73.8917810380186 %
----------
Epoch:  7  | Corr