In [None]:
import datetime
import os
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader as DataLoader
from torchvision.transforms import Compose, Resize, ToTensor
from matplotlib import pyplot as plt
from melSpecDataset import MelSpecDataset
import basic_model as net0
import christ as net1

from torchvision import transforms
from torchvision.datasets import ImageFolder

In [None]:
# --- run switches -------------------------------------------------------------
verbose = True       # print the per-epoch training / validation loss
loss_plot = True     # display the loss figure after it has been written to disk
epoch_save = False   # also write a checkpoint after every epoch

# --- dataset locations --------------------------------------------------------
train_dir = './splitdata/training'
val_dir = './splitdata/testing'   # the "testing" split doubles as validation data
#test_dir = './images/testing/'

# --- training schedule --------------------------------------------------------
num_epochs = 100
batch_size = 64

# --- optimizer settings (passed to torch.optim.Adadelta) ----------------------
learning_rate = 0.0001
rho = 0.9
eps = 1e-06
weight_decay = 1e-4

# --- not referenced by the cells shown in this notebook -----------------------
gamma = 1

In [None]:
# Use the GPU whenever CUDA is usable, otherwise fall back to the CPU.
# To force CPU execution while debugging, set instead: device = "cpu"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:


################################################################################
# Model selection
# Instantiate exactly one of the imported architectures; swap the active line
# to change which network this run trains.
################################################################################

#model = net0.GenreClassificationANN()
model = net1.MusicClassNet()

# Move the model's parameters onto the selected device.
model.to(device)

################################################################################
# Data loading
################################################################################

# Preprocessing shared by both splits: resize every image to 250x250 and
# convert it to a tensor.
transform = transforms.Compose([
    transforms.Resize((250, 250)),
    transforms.ToTensor(),
])

# training split (reshuffled every epoch)
train_dataset = MelSpecDataset(train_dir, transform=transform)
data_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# validation split (fixed order)
# `transform` is passed by keyword, exactly like the training dataset above, so
# both splits are guaranteed to receive the same preprocessing regardless of
# the positional parameter order of MelSpecDataset.
val_dataset = MelSpecDataset(val_dir, transform=transform)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

################################################################################
# Optimisation
################################################################################

# model optimizer
optimizer = torch.optim.Adadelta(
    model.parameters(),
    lr=learning_rate,
    rho=rho,
    eps=eps,
    weight_decay=weight_decay,
)

# model scheduler: halve the learning rate once the monitored value has not
# improved for 6 consecutive scheduler steps.
# NOTE(review): the `verbose` argument is deprecated in recent PyTorch
# releases -- confirm against the installed version before upgrading.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=6, verbose=True)

# loss function
loss_function = nn.CrossEntropyLoss()

In [None]:
def validate(model, val_loader, loss_function, device):
    """Evaluate ``model`` on every batch of ``val_loader``.

    The model is switched to evaluation mode and is left in that mode when the
    function returns; gradients are disabled while the batches are processed.

    Args:
        model: module called as ``model(inputs)``; its output is reduced with
            ``argmax`` over dimension 1 to obtain the predicted class.
        val_loader: any iterable of ``(inputs, labels)`` batches. It does not
            need to implement ``__len__``; batches are counted while iterating.
        loss_function: callable ``loss_function(output, labels)`` returning a
            scalar tensor.
        device: device the inputs and labels are moved to before the forward
            pass.

    Returns:
        tuple: ``(avg_val_loss, accuracy)`` where ``avg_val_loss`` is the mean
        of the per-batch losses and ``accuracy`` is the fraction of samples
        whose predicted class equals the label.

    Raises:
        ValueError: if ``val_loader`` yields no batches or no samples, in which
            case neither average is defined.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    n_batches = 0

    with torch.no_grad():
        for melspecs, labels in val_loader:
            audios = melspecs.to(device)
            labels = labels.to(device)

            output = model(audios)

            # accumulate the per-batch loss
            total_loss += loss_function(output, labels).item()

            # predicted class = index of the largest score along dim 1
            predictions = torch.argmax(output, dim=1)
            total_correct += (predictions == labels).sum().item()
            total_samples += labels.size(0)
            n_batches += 1

    # Fail with a clear message instead of an opaque ZeroDivisionError.
    if n_batches == 0 or total_samples == 0:
        raise ValueError('val_loader produced no samples; cannot compute validation metrics')

    avg_val_loss = total_loss / n_batches
    accuracy = total_correct / total_samples

    return avg_val_loss, accuracy

In [None]:
def train():
    """Train the notebook-level ``model`` for ``num_epochs`` epochs.

    Reads these module-level names: ``model``, ``device``, ``data_loader``,
    ``val_loader``, ``optimizer``, ``scheduler``, ``loss_function``,
    ``num_epochs``, ``verbose``, ``epoch_save`` and ``loss_plot``.

    Per epoch it:
      1. runs one optimisation pass over ``data_loader`` and records the mean
         per-batch training loss,
      2. steps ``scheduler`` with that training loss,
      3. evaluates on ``val_loader`` through ``validate`` and records the mean
         per-batch validation loss,
      4. optionally (``epoch_save``) writes a checkpoint to ``./temp_models``.

    After the last epoch the final weights are written to ``model.pth`` and a
    training-vs-validation loss plot is saved under ``./plots``.

    Returns:
        None
    """
    model.to(device)

    epoch_losses = []
    epoch_losses_val = []
    n_batches_train = len(data_loader)

    for epoch in range(1, num_epochs + 1):
        running_loss = 0.0
        print(f'Epoch #{epoch}, Start Time: {datetime.datetime.now()}')

        # ---- training pass ----
        model.train()
        for melspecs, labels in data_loader:
            audios = melspecs.to(device)
            labels = labels.to(device)

            # forward pass and batch loss
            output = model(audios)
            loss = loss_function(output, labels)
            running_loss += loss.item()

            # backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # save epoch loss
        epoch_losses.append(running_loss / n_batches_train)
        # NOTE(review): the scheduler is stepped with the *training* loss;
        # confirm this is intended rather than the validation loss.
        scheduler.step(epoch_losses[-1])

        # ---- validation pass ----
        # Reuse the shared helper instead of duplicating its loop here; it puts
        # the model in eval mode and disables gradients. model.train() above
        # restores training mode at the start of the next epoch.
        val_loss, val_accuracy = validate(model, val_loader, loss_function, device)
        epoch_losses_val.append(val_loss)

        if verbose:
            print(f'Epoch: #{epoch}, Loss: {epoch_losses[-1]}')
            print(f'Epoch: #{epoch}, Val_Loss: {epoch_losses_val[-1]}')
            print(f'Epoch: #{epoch}, Val_Acc: {val_accuracy}')

        if epoch_save:
            model_folder_dir = './temp_models'
            # Checkpointing is best-effort: a failed save must not abort the
            # run, but the reason is reported instead of being swallowed.
            try:
                os.makedirs(model_folder_dir, exist_ok=True)
                # Same filesystem-safe stamp format as the loss plot below
                # (str(datetime.now()) contains spaces and colons).
                stamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
                temp_model_path = os.path.join(model_folder_dir, f'{stamp}_epoch{epoch}.pth')
                torch.save(model.state_dict(), temp_model_path)
                if verbose:
                    print(f'Saved model for epoch {epoch} @{temp_model_path}')
            except Exception as exc:
                print(f'Epoch model save failed: {exc}')

    # save final model parameters
    torch.save(model.state_dict(), 'model.pth')

    # save final loss plot
    os.makedirs("./plots", exist_ok=True)
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    generate_loss_plot_with_val(epoch_losses, epoch_losses_val, f'./plots/loss_plot_{timestamp}.png', show_plot=loss_plot)

def generate_loss_plot(loss, file_loc, show_plot=False):
    """Plot a single loss curve against the epoch number and save it.

    A dedicated figure is created for the plot, so nothing is drawn onto (or
    closed from) a figure that happened to be current before the call, and the
    figure is always released, even when saving fails.

    Args:
        loss: sequence of per-epoch loss values; entry ``i`` is plotted at
            epoch ``i + 1``.
        file_loc: path the figure is written to via ``savefig``.
        show_plot: when true, ``plt.show()`` is called after saving.

    Returns:
        None
    """
    fig, ax = plt.subplots()
    try:
        ax.plot(range(1, len(loss) + 1), loss)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title('Epoch vs Loss')
        fig.savefig(file_loc)
        if show_plot:
            plt.show()
    finally:
        # close only the figure created here
        plt.close(fig)

def generate_loss_plot_with_val(train_loss, val_loss, file_loc, show_plot=False):
    """Plot training and validation loss curves on one figure and save it.

    A dedicated figure is created for the plot, so nothing is drawn onto (or
    closed from) a figure that happened to be current before the call, and the
    figure is always released, even when saving fails. Each curve gets its own
    epoch axis, so the two sequences may differ in length.

    Args:
        train_loss: sequence of per-epoch training losses; entry ``i`` is
            plotted at epoch ``i + 1``.
        val_loss: sequence of per-epoch validation losses, plotted the same way.
        file_loc: path the figure is written to via ``savefig``.
        show_plot: when true, ``plt.show()`` is called after saving.

    Returns:
        None
    """
    fig, ax = plt.subplots()
    try:
        ax.plot(range(1, len(train_loss) + 1), train_loss, label="Training Loss")
        ax.plot(range(1, len(val_loss) + 1), val_loss, label="Validation Loss")
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title('Epoch vs Loss')
        ax.legend()
        fig.savefig(file_loc)
        if show_plot:
            plt.show()
    finally:
        # close only the figure created here
        plt.close(fig)


In [None]:
# Entry point: run the full training loop (training, per-epoch validation,
# final weight save and loss plot) when this code is executed as the main
# module rather than imported.
if __name__ == "__main__":
    train()