In [1]:
import datetime
import os
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader as DataLoader
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from matplotlib import pyplot as plt
from melSpecDataset import MelSpecDataset
import basic_model as net0
import ModMusicRedNet as net1

In [2]:
# --- Run switches ---
loss_plot = True     # also display the loss plot after saving it under ./plots
verbose = True       # print per-epoch training / validation loss
epoch_save = False   # checkpoint the model state_dict after every epoch

# --- Data locations ---
train_dir = './splitdata/training'
val_dir = './splitdata/testing'   # NOTE(review): the "testing" split doubles as the validation set

# --- Training hyper-parameters ---
num_epochs = 10
batch_size = 32
learning_rate = 0.001
momentum = 0.6        # SGD momentum
weight_decay = 1e-4   # currently only referenced by the commented-out Adadelta optimizer
rho = 0.9             # Adadelta-only
eps = 1e-06           # Adadelta-only

# --- Reproducibility ---
# Seed torch's global RNG before the model is constructed and the DataLoaders
# shuffle, so weight init and batch order repeat across Restart & Run All.
seed = 42
torch.manual_seed(seed)

In [3]:
# Compute device. CPU is forced by default (the original hard-coded choice);
# set force_cpu = False to train on CUDA whenever it is available.
force_cpu = True
device = torch.device("cpu" if force_cpu or not torch.cuda.is_available() else "cuda")

In [4]:
def calcMeanStd(data_dir=None, resize_size=(258, 128), batch_size=64, dataset=None):
    """Compute the global pixel mean and standard deviation of a spectrogram dataset.

    The statistics are taken over every pixel of every (resized) image, so they
    can be passed directly to a single-channel ``Normalize(mean=[m], std=[s])``.

    Args:
        data_dir: directory handed to ``MelSpecDataset``; defaults to the
            notebook-level ``train_dir`` when None.
        resize_size: size passed to ``Resize`` before ``ToTensor``.
        batch_size: DataLoader batch size used while scanning the data.
        dataset: optional ready-made dataset yielding ``(image_tensor, label)``
            pairs; when given, ``data_dir`` and ``resize_size`` are ignored.

    Returns:
        ``(mean, std)`` as Python floats.

    Raises:
        ValueError: if the dataset contains no samples.
    """
    if dataset is None:
        if data_dir is None:
            data_dir = train_dir
        dataset = MelSpecDataset(data_dir, transform=Compose([Resize(resize_size), ToTensor()]))

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Accumulate in float64: E[x^2] - E[x]^2 is sensitive to round-off, and
    # float32 sums over a whole dataset lose precision.
    mean_sum = torch.zeros((), dtype=torch.float64)
    sq_sum = torch.zeros((), dtype=torch.float64)
    nb_samples = 0

    for images, _ in loader:
        # Flatten each image to one row of pixels: (batch, pixels).
        # reshape (not view) also copes with non-contiguous batches.
        flat = images.reshape(images.size(0), -1).double()
        mean_sum += flat.mean(dim=1).sum()
        sq_sum += flat.pow(2).mean(dim=1).sum()
        nb_samples += flat.size(0)

    if nb_samples == 0:
        raise ValueError("calcMeanStd: dataset is empty, cannot compute mean/std")

    # Averaging per-image means equals the global pixel mean provided every
    # image has the same number of pixels (true after a fixed-size Resize).
    mean = mean_sum / nb_samples
    # Var = E[x^2] - E[x]^2; clamp tiny negative round-off so sqrt never yields NaN.
    std = (sq_sum / nb_samples - mean ** 2).clamp(min=0.0).sqrt()

    return mean.item(), std.item()

In [5]:


# Model selection: swap the constructor below to change which architecture this run trains.
################################################################################
#model = net0.GenreClassificationANN()
model = net1.MusicClassNet()
model.to(device)
################################################################################

# Normalization statistics come from the training split only, so nothing from
# the validation split leaks into preprocessing.
mean, std = calcMeanStd()
resize_size = (258, 128)  # NOTE(review): must match the resize used by calcMeanStd
transform = Compose([
    Resize(resize_size),
    ToTensor(),
    Normalize(mean=[mean], std=[std]),
])

# Training data: reshuffled every epoch.
train_dataset = MelSpecDataset(train_dir, transform)
data_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Validation data: fixed order. The per-epoch validation loss is averaged per
# batch, so shuffling would change which samples fall in the short last batch
# and add run-to-run noise to the metric.
val_dataset = MelSpecDataset(val_dir, transform)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

print("data loaded successfully")

# Optimizer: SGD with momentum (Adadelta kept as a commented alternative).
#optimizer = torch.optim.Adadelta(model.parameters(), lr=learning_rate, rho=rho, eps=eps, weight_decay=weight_decay)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)

# Halve the learning rate after 6 epochs without improvement of the monitored loss.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases — confirm against the installed version.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=6, verbose=True)
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1)

# Multi-class cross-entropy; expects unnormalized logits from the model (confirm no final softmax).
loss_function = nn.CrossEntropyLoss()

data loaded sucessfully


In [6]:
def _train_one_epoch():
    """Run one optimisation pass over ``data_loader``; return the summed batch loss."""
    model.train()
    loss_sum = 0.0
    for melspecs, labels in data_loader:
        audios = melspecs.to(device)
        labels = labels.to(device)

        output = model(audios)
        loss = loss_function(output, labels)
        loss_sum += loss.item()

        # Backpropagation.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return loss_sum


def _validate_one_epoch():
    """Evaluate ``model`` on ``val_loader`` without gradients; return the summed batch loss."""
    model.eval()
    loss_sum = 0.0
    with torch.no_grad():
        for melspecs, labels in val_loader:
            output = model(melspecs.to(device))
            loss_sum += loss_function(output, labels.to(device)).item()
    return loss_sum


def _save_epoch_checkpoint(epoch):
    """Best-effort save of the current ``state_dict`` into ``./temp_models``.

    A failure is reported but does not stop training. The timestamp uses
    ``strftime`` because ``str(datetime.now())`` contains ':' characters,
    which are invalid in Windows file names.
    """
    model_folder_dir = './temp_models'
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    temp_model_path = os.path.join(model_folder_dir, f'{timestamp}_epoch{epoch}.pth')
    try:
        os.makedirs(model_folder_dir, exist_ok=True)
        torch.save(model.state_dict(), temp_model_path)
    except (OSError, RuntimeError) as err:
        print(f'Epoch model save failed: {err}')
        return
    if verbose:
        print(f'Saved model for epoch {epoch} @{temp_model_path}')


def train():
    """Train the global ``model`` for ``num_epochs`` epochs, validating after each.

    Uses the notebook-level ``data_loader`` / ``val_loader``, ``optimizer``,
    ``scheduler`` and ``loss_function``. When training finishes, the final
    ``state_dict`` is saved to ``model_mod_<timestamp>.pth`` in the working
    directory and a train/validation loss plot to
    ``./plots/loss_plot_<timestamp>.png``.

    Returns:
        ``(epoch_losses, epoch_losses_val)``: per-epoch mean batch loss for the
        training and validation sets.

    Raises:
        ValueError: if either DataLoader yields no batches (the per-epoch mean
            would otherwise divide by zero).
    """
    n_batches_train = len(data_loader)
    n_batches_val = len(val_loader)
    if n_batches_train == 0 or n_batches_val == 0:
        raise ValueError(
            f"train: empty DataLoader (train batches={n_batches_train}, "
            f"val batches={n_batches_val})"
        )

    model.to(device)
    epoch_losses = []
    epoch_losses_val = []

    for epoch in range(1, num_epochs + 1):
        print(f'Epoch #{epoch}, Start Time: {datetime.datetime.now()}')

        epoch_losses.append(_train_one_epoch() / n_batches_train)
        # ReduceLROnPlateau monitors the *training* loss here, as in the original run.
        # NOTE(review): plateau scheduling is more commonly driven by validation loss.
        scheduler.step(epoch_losses[-1])

        epoch_losses_val.append(_validate_one_epoch() / n_batches_val)

        if verbose:
            print(f'Epoch: #{epoch}, Loss: {epoch_losses[-1]}')
            print(f'Epoch: #{epoch}, Val_Loss: {epoch_losses_val[-1]}')

        if epoch_save:
            _save_epoch_checkpoint(epoch)

    # One timestamp so the saved weights and their loss plot can be matched up.
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    torch.save(model.state_dict(), f'model_mod_{timestamp}.pth')

    os.makedirs("./plots", exist_ok=True)
    generate_loss_plot_with_val(epoch_losses, epoch_losses_val,
                                f'./plots/loss_plot_{timestamp}.png', show_plot=loss_plot)

    return epoch_losses, epoch_losses_val

def generate_loss_plot(loss, file_loc, show_plot=False):
    """Plot per-epoch training loss and save the figure to ``file_loc``.

    Args:
        loss: sequence of loss values, one per epoch (epochs numbered from 1).
        file_loc: output image path passed to ``savefig``.
        show_plot: if True, also display the figure before closing it.
    """
    # Use a dedicated figure: the implicit pyplot "current figure" could still
    # hold an earlier plot, and the new curve would be overlaid on it.
    fig, ax = plt.subplots()
    ax.plot(range(1, len(loss) + 1), loss)
    ax.set(xlabel='Epoch', ylabel='Loss', title='Epoch vs Loss')
    fig.savefig(file_loc)
    if show_plot:
        plt.show()
    plt.close(fig)

def generate_loss_plot_with_val(train_loss, val_loss, file_loc, show_plot=False): # loss plot with validation
    """Plot per-epoch training and validation loss on one axis and save to ``file_loc``.

    Args:
        train_loss: training loss per epoch (epochs numbered from 1).
        val_loss: validation loss per epoch. Each series is plotted against its
            own epoch axis, so unequal lengths (e.g. an interrupted run) do not
            raise a shape-mismatch error.
        file_loc: output image path passed to ``savefig``.
        show_plot: if True, also display the figure before closing it.
    """
    # Dedicated figure so curves are never overlaid on a stale current figure.
    fig, ax = plt.subplots()
    ax.plot(range(1, len(train_loss) + 1), train_loss, label="Training Loss")
    ax.plot(range(1, len(val_loss) + 1), val_loss, label="Validation Loss")
    ax.set(xlabel='Epoch', ylabel='Loss', title='Epoch vs Loss')
    ax.legend()
    fig.savefig(file_loc)
    if show_plot:
        plt.show()
    plt.close(fig)


In [7]:
# Kick off training when this code runs as the top-level program
# (a script, or a notebook kernel, where __name__ is "__main__").
if "__main__" == __name__:
    train()

Epoch #1, Start Time: 2023-12-05 15:58:29.027354
before model


RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x512 and 1024x256)