In [1]:
from musdataset import model, optimizer, scheduler, criterion, last_epoch, num_epochs, PATH, dataloaders, val_patience
import torch.nn as nn
import numpy as np
import torchaudio
import torch


# import torch_xla
# import torch_xla.core.xla_model as xm
# device_count = 1
# device = xm.xla_device()
# torch.set_num_threads(24)
# checkpoint = torch.load(PATH, map_location=torch.device('cpu'))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_count = torch.cuda.device_count()
model = model.to(device)
if device_count > 1:
    model = nn.DataParallel(model)
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(device))


def adjust_state_dict_for_device(state_dict):
    global device_count
    
    # Check if 'module.' prefix exists in any key
    has_module_prefix = any(k.startswith('module.') for k in state_dict.keys())

    if has_module_prefix and device_count <= 1:
        # Remove 'module.' prefix
        new_state_dict = {k.replace('module.', '', 1): v for k, v in state_dict.items()}
        
    elif not has_module_prefix and device_count > 1:
        # Add 'module.' prefix
        new_state_dict = {f'module.{k}': v for k, v in state_dict.items()}
        
    else:
        # No change needed
        new_state_dict = state_dict

    return new_state_dict


checkpoint = torch.load(PATH)
adjusted_state_dict = adjust_state_dict_for_device(checkpoint['model'])
model.load_state_dict(adjusted_state_dict)
optimizer.load_state_dict(checkpoint['optimizer']) 
scheduler.load_state_dict(checkpoint['scheduler'])
last_epoch = checkpoint['last_epoch']


def train(model, criterion, optimizer, scheduler, num_epochs, patience=val_patience):
    global last_epoch
    best_valid_acc = 0
    bad_epochs = 0
    for epoch in range(last_epoch, num_epochs):       
        print('-' * 10)
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            correct = 0  
            total = 0

            for inputs, labels in dataloaders[phase]:
                # with torch_xla.step():
                    # inputs, labels = inputs.to('xla'), labels.to('xla')
                    # print("TPU works")
                inputs = inputs.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                    
                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == 'train'):
                        
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    _, predicted = torch.max(outputs, 1)
                        
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
    
                correct += torch.sum(predicted == labels.data).item()
                total += labels.size(0)
            
            # torch_xla.sync()
            acc = correct / total * 100
            
            print(f"{phase.capitalize()} Epoch {epoch+1}: {correct}/{total} ({acc:.2f}%)")
            
            if phase == 'train':
                scheduler.step()
                
            else:  # phase == 'valid'
                last_epoch = epoch
                
                if acc > best_valid_acc:
                    print("New best validation accuracy!")
                    best_valid_acc = acc
                    bad_epochs = 0
                    checkpoint = {
                     'model': model.state_dict(),
                     'optimizer': optimizer.state_dict(),
                     'scheduler': scheduler.state_dict(),
                     'last_epoch': last_epoch + 1
                    }
                    torch.save(checkpoint, PATH)
                    
                else:
                    bad_epochs += 1
                    print(f"Validation accuracy did not improve. Bad epochs: {bad_epochs}/{patience}")
                    if bad_epochs >= patience:
                        print("Early stopping triggered")
                        return
    print("Training complete")


def test(model):
    model.eval()
    with torch.no_grad():
        correct = 0  
        total = 0
        for inputs, labels in dataloaders['test']:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            outputs = model(inputs)
            _ , predicted = torch.max(outputs, 1)
        
            correct += torch.sum(predicted == labels.data).item()
            total += labels.size(0)
        print("Correct/Total: ", correct, "/",total )
        print("Accuracy: ", (correct/total) * 100, "%")
    
    print("Testing done")
    
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Number of parameters:", total_params)

def prod(model = None): #for real demo purposes
    pass


train(model, criterion, optimizer, scheduler, num_epochs)

# test(model)

# prod(model)


False


RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx