# 𝄆  Melogen Training Script  𝄇

## Initialization

In [76]:
# Load from USB
import json
import os
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim

import pretty_midi
from model import Net
from midi_to_piano_roll import midi_to_piano_roll, postprocess
from loss import blur_loss

import time
import pickle

import copy

In [77]:
### CONFIGS ###

# Test code on small dataset, half dataset, or on full dataset
MODE = "full"

# Set the batch size
BATCH_SIZE = 1

# Set loader style
LOADER = "single"
LOAD_BATCH_SIZE = 250

# Train, val, test split
VAL_SAMPLES = 200
TEST_SAMPLES = 400

# If we already saved the npy's no need to resave
ALREADY_LOADED = True

CKPT_PATH = "/media/allentao/One Touch/APS360/ckpts/aug1lr0.0001overfit"
# CKPT_PATH = None

# Set to train mode or evaluation mode
SETTING = "eval" # "train", "eval"


In [78]:
if MODE == "full":
        IN_FOLDER = '/media/allentao/One Touch/APS360/data/clean_data/'
        file_path = "../data/songs.json"
elif MODE == "half":
        IN_FOLDER = '/media/allentao/One Touch/APS360/data/clean_data/'
        file_path = "../data/songs_med.json" # half dataset
elif MODE == "small":
        IN_FOLDER = '../data/clean_data/'
        file_path = "../data/songs_small.json"
else:
        raise NotImplementedError

with open(file_path, "r") as json_file:
        songs_file = json.load(json_file)

In [79]:
# def extend(data, max_length):
#     new_data = []
#     for i in range(len(data)):
#         rows_needed = max_length - data[i][0].shape[0]
#         zeros_to_add = torch.zeros((rows_needed, 128), dtype=data[i][0].dtype)
#         new_song= torch.concatenate((data[i][0], zeros_to_add), axis=0)

#         rows_needed = max_length - data[i][1].shape[0]
#         zeros_to_add = torch.zeros((rows_needed, 128), dtype=data[i][1].dtype)
#         new_cover= torch.concatenate((data[i][1], zeros_to_add), axis=0)
        
#         new_data.append((new_song, new_cover))
#     return new_data

## Load the data

In [80]:
def get_data():
    training_data = []
    validation_data = []
    testing_data = []

    count = 0
    max_length = 0
    for song in songs_file["songs"]:
        song_file = song["filename"]
        song_num = int(song_file.split("_")[0])
        for piano_file in song["piano covers"]["filename"]:
            
            name = os.path.splitext(piano_file)[0].split('_')[0] + "_" + os.path.splitext(piano_file)[0].split('_')[1]
            song_file_path = IN_FOLDER + name + "_song.midi"
            cover_file_path = IN_FOLDER + name + "_cover.midi"
            print("Parsing", song_file_path, cover_file_path)

            song_piano_roll = midi_to_piano_roll(song_file_path)
            cover_piano_roll = midi_to_piano_roll(cover_file_path)

            if song_piano_roll == None or cover_piano_roll == None:
                continue

            song_piano_roll_val = song_piano_roll[:song_piano_roll.shape[-1]//2, :]
            cover_piano_roll_val = cover_piano_roll[:cover_piano_roll.shape[-1]//2, :]
            
            song_length = song_piano_roll.shape[0]
            cover_length = cover_piano_roll.shape[0]

            if song_length > max_length:
                max_length = song_length
            if cover_length > max_length:
                max_length = cover_length
            if count < VAL_SAMPLES:
                validation_data.append((song_piano_roll_val, cover_piano_roll_val))
            elif VAL_SAMPLES <= count < TEST_SAMPLES:
                testing_data.append((song_piano_roll_val, cover_piano_roll_val))
            else:
                training_data.append((song_piano_roll, cover_piano_roll))
                
            print("Processed", count, "songs")
            count += 1

    # training_data = extend(training_data, max_length)
    # validation_data = extend(validation_data, max_length)
    # testing_data = extend(testing_data, max_length)
    
    return training_data, validation_data, testing_data

In [81]:
def get_data_paths(already_loaded=False):
    # data will now consist of file paths, which can be loaded individually
    # UPDATE: massive speedup, pickle the files on hdd, can load into memory in batches
    #   save the file paths of the pkl's instead

    training_data = []
    validation_data = []
    testing_data = []

    count = 0
    max_length = 0
    for song in songs_file["songs"]:
        song_file = song["filename"]
        song_num = int(song_file.split("_")[0])
        for piano_file in song["piano covers"]["filename"]:
            
            name = os.path.splitext(piano_file)[0].split('_')[0] + "_" + os.path.splitext(piano_file)[0].split('_')[1]
            song_file_path = IN_FOLDER + name + "_song.midi"
            cover_file_path = IN_FOLDER + name + "_cover.midi"
            print("Parsing", song_file_path, cover_file_path)

            train_song_path = os.path.join("/media/allentao/One Touch/APS360/pkls/train", name + ".npy")
            train_cover_path = os.path.join("/media/allentao/One Touch/APS360/pkls/train", name + ".npy")

            val_song_path = os.path.join("/media/allentao/One Touch/APS360/pkls/val", name + ".npy")
            val_cover_path = os.path.join("/media/allentao/One Touch/APS360/pkls/val", name + ".npy")

            if not already_loaded:
                song_piano_roll = midi_to_piano_roll(song_file_path)
                cover_piano_roll = midi_to_piano_roll(cover_file_path)

                if song_piano_roll == None or cover_piano_roll == None:
                    continue

                song_piano_roll_val = song_piano_roll[:song_piano_roll.shape[-1]//2, :]
                cover_piano_roll_val = cover_piano_roll[:cover_piano_roll.shape[-1]//2, :]
                
                song_length = song_piano_roll.shape[0]
                cover_length = cover_piano_roll.shape[0]

                if song_length > max_length:
                    max_length = song_length
                if cover_length > max_length:
                    max_length = cover_length

                # pickle the data onto hdd
                
                np.save(train_song_path, song_piano_roll)
                np.save(train_cover_path, cover_piano_roll)

                # save some memory these npy files are like 20MB each
                
                if count < max(TEST_SAMPLES, VAL_SAMPLES):
                    np.save(val_song_path, song_piano_roll_val)
                    np.save(val_cover_path, cover_piano_roll_val)
            else:
                # append if file path exists
                if not os.path.exists(train_song_path) or not os.path.exists(train_cover_path):
                    continue

                # save the file paths of the pkl's
                training_data.append((train_song_path, train_cover_path))
                if count < VAL_SAMPLES:
                    # validation_data.append((val_song_path, val_cover_path))
                    validation_data.append((train_song_path, train_cover_path))
                elif count < TEST_SAMPLES:
                    # testing_data.append((val_song_path, val_cover_path))
                    testing_data.append((train_song_path, train_cover_path))
            
            print("Processed", count, "songs")
            count += 1

    
    return training_data, validation_data, testing_data

In [82]:
def model_train(model, lr, batch_size, training_data, validation_data, num_epochs, device, loss_func="mse"):
    torch.cuda.empty_cache()

    model = model.to(device)
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    batch_size = batch_size
    train_loader = DataLoader(training_data, batch_size=batch_size, shuffle=True)

    validation_loader = DataLoader(validation_data, batch_size=batch_size, shuffle = True)

    train_loss = np.zeros(num_epochs)
    val_loss = np.zeros(num_epochs)


    for epoch in range(num_epochs):
        train_loss_total = 0.0
        val_loss_total = 0.0

        # Training
        model.train()
        count = 0
        for data in train_loader:
            count += 1
            songs = data[0].to(device)

            covers = data[1].to(device)
            optimizer.zero_grad()

            outputs = model(songs)
            outputs = outputs.to(device)

            # pad tensors to same length
            if outputs.shape[1] > covers.shape[1]:
                covers = F.pad(covers, (0, 0, 0, outputs.shape[1] - covers.shape[1]))
            elif covers.shape[1] > outputs.shape[1]:
                outputs = F.pad(outputs, (0, 0, 0, covers.shape[1] - outputs.shape[1]))
            assert(outputs.shape == covers.shape)
            
            if loss_func == "custom":
                loss = blur_loss(outputs, covers, device) + criterion(outputs, covers) # warning: memory intensive
            elif loss_func == "mse":
                loss = criterion(outputs, covers)
            else:
                raise NotImplementedError

            loss.backward(retain_graph = True)
            optimizer.step()

            train_loss_total += loss.item()
            
            torch.cuda.empty_cache()

        checkpoint = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            # Add any other information you want to save (e.g., training loss, validation loss, etc.)
        }

        torch.save(checkpoint, f'/media/allentao/One Touch/APS360/ckpts/checkpoint_epoch{epoch + 1}.pt')
        # Validation
        model.eval()
        with torch.no_grad():
            for data in validation_loader:
                # loss = criterion(outputs, labels)
                songs = data[0].to(device)

                covers = data[1].to(device)

                outputs = model(songs)
                outputs = outputs.to(device)

                # pad tensors to same length
                if outputs.shape[1] > covers.shape[1]:
                    covers = F.pad(covers, (0, 0, 0, outputs.shape[1] - covers.shape[1]))
                elif covers.shape[1] > outputs.shape[1]:
                    outputs = F.pad(outputs, (0, 0, 0, covers.shape[1] - outputs.shape[1]))
                assert(outputs.shape == covers.shape)
                
                if loss_func == "custom":
                    loss = blur_loss(outputs, covers, device) + criterion(outputs, covers) # warning: memory intensive
                elif loss_func == "mse":
                    loss = criterion(outputs, covers)
                else:
                    raise NotImplementedError
                
                val_loss_total += loss.item()
                
        train_loss[epoch] = train_loss_total
        val_loss[epoch] = val_loss_total

        print(f'Epoch [{epoch+1}/{num_epochs}], '
                f'Train Loss: {train_loss_total:.7f}, Train Loss: {train_loss_total:.7f}, '
                f'Val Loss: {val_loss_total:.7f}, Val Loss: {val_loss_total:.7f}')
        torch.cuda.empty_cache()


    model_path = str(lr) + '_' + str(batch_size) + '_' + str(num_epochs)
    torch.save(model.state_dict(), 'model' + model_path)
    np.savetxt("{}_train_loss.csv".format(model_path), train_loss)
    np.savetxt("{}_val_loss.csv".format(model_path), val_loss)

In [83]:
def prepare_batch_sample(data, load_batch_size=1, dataset_type="train"):

    # should only be called in single data loading mode (to save memory)
    # parses the midi file and returns the piano roll representation of the song and cover
    # saves it in the correct format to be data loaded into pytorch

    out_data = []

    # print("TYPE:", dataset_type)
    # print("Building batch of size", load_batch_size, "...")
    # print("DEBUG:", len(data), load_batch_size)

    for i in range(load_batch_size):

        song_file_path = data[i][0]
        cover_file_path = data[i][1]

        # print("Loading: ", song_file_path, cover_file_path)
        # input()

        # song_piano_roll = midi_to_piano_roll(song_file_path)
        # cover_piano_roll = midi_to_piano_roll(cover_file_path)

        # if song_piano_roll == None or cover_piano_roll == None:
        #     continue

        # song_piano_roll_val = song_piano_roll[:song_piano_roll.shape[-1]//2, :]
        # cover_piano_roll_val = cover_piano_roll[:cover_piano_roll.shape[-1]//2, :]

        song_piano_roll = np.load(song_file_path)
        cover_piano_roll = np.load(cover_file_path)
        
        song_length = song_piano_roll.shape[0]
        cover_length = cover_piano_roll.shape[0]
        
        out_data.append((song_piano_roll, cover_piano_roll))


    return out_data

In [84]:
def model_train_paths(model, lr, load_batch_size, training_data, validation_data, 
                      num_epochs, device, loss_func="mse", load_ckpt=None):

    # get current time
    start = time.time()

    torch.cuda.empty_cache()

    model = model.to(device)
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    # optimizer = optim.Adam(model.parameters(), lr=lr)

    if load_ckpt is not None and CKPT_PATH is not None:
        print("LOADING FROM CKPT:", load_ckpt)
        checkpoint = torch.load(os.path.join(CKPT_PATH, load_ckpt))
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
    else:
        start_epoch = 0

    train_loss = np.zeros(num_epochs)
    val_loss = np.zeros(num_epochs)

    batch_size = 1

    # postprocess is broken
    torch.autograd.set_detect_anomaly(True)


    for epoch in range(num_epochs):
        train_loss_total = 0.0
        val_loss_total = 0.0

        # Training
        model.train()

        data_path_batched = []

        for i, data_path in enumerate(training_data):

            data_path_batched.append(data_path)

            if ((i + 1) % load_batch_size == 0) or i == len(training_data) - 1: # only load when we accumulated enough
                prep_load_batch_size = load_batch_size if (i + 1) % load_batch_size == 0 else len(training_data) % load_batch_size
                batch_data = prepare_batch_sample(data_path_batched, prep_load_batch_size, "train")
                train_loader = DataLoader(batch_data, batch_size=batch_size, shuffle=True) # this still has to be 1 due to vram constraints

                for data in train_loader:
                    songs = data[0].to(device)

                    covers = data[1].to(device)
                    optimizer.zero_grad()

                    outputs = model(songs)
                    outputs = outputs.to(device)
                    outputs = postprocess(outputs, covers)

                    # pad tensors to same length
                    if outputs.shape[1] > covers.shape[1]:
                        covers = F.pad(covers, (0, 0, 0, outputs.shape[1] - covers.shape[1]))
                    elif covers.shape[1] > outputs.shape[1]:
                        outputs = F.pad(outputs, (0, 0, 0, covers.shape[1] - outputs.shape[1]))
                    assert(outputs.shape == covers.shape)
                    
                    if loss_func == "custom":
                        loss = blur_loss(outputs, covers, device) + criterion(outputs, covers) # warning: memory intensive
                    elif loss_func == "mse":
                        loss = criterion(outputs, covers)
                    else:
                        raise NotImplementedError

                    loss.backward(retain_graph = True)
                    optimizer.step()

                    train_loss_total += loss.item()
                    
                    torch.cuda.empty_cache()

                checkpoint = {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    # Add any other information you want to save (e.g., training loss, validation loss, etc.)
                }

                torch.save(checkpoint, f'/media/allentao/One Touch/APS360/ckpts/checkpoint_epoch{epoch + 1}.pt')

                # reset
                data_path_batched = []

        # Validation
        model.eval()
        data_path_batched = []
        with torch.no_grad():
            for i, data_path in enumerate(validation_data):
                data_path_batched.append(data_path)

                if ((i + 1) % load_batch_size == 0) or i == len(validation_data) - 1: # only load when we accumulated enough
                    prep_load_batch_size = load_batch_size if (i + 1) % load_batch_size == 0 else len(validation_data) % load_batch_size
                    batch_data = prepare_batch_sample(data_path_batched, prep_load_batch_size, "val")
                    validation_loader = DataLoader(batch_data, batch_size=batch_size, shuffle=True) 

                    for data in validation_loader:

                        songs = data[0].to(device)

                        covers = data[1].to(device)

                        outputs = model(songs)
                        outputs = outputs.to(device)
                        outputs = postprocess(outputs, covers)

                        # pad tensors to same length
                        if outputs.shape[1] > covers.shape[1]:
                            covers = F.pad(covers, (0, 0, 0, outputs.shape[1] - covers.shape[1]))
                        elif covers.shape[1] > outputs.shape[1]:
                            outputs = F.pad(outputs, (0, 0, 0, covers.shape[1] - outputs.shape[1]))
                        assert(outputs.shape == covers.shape)
                        
                        if loss_func == "custom":
                            loss = blur_loss(outputs, covers, device) + criterion(outputs, covers) # warning: memory intensive
                        elif loss_func == "mse":
                            loss = criterion(outputs, covers)
                        else:
                            raise NotImplementedError
                        
                        val_loss_total += loss.item()
                        
                    data_path_batched = []
                    
        train_loss[epoch] = train_loss_total
        val_loss[epoch] = val_loss_total

        print(f'Epoch [{epoch+1}/{num_epochs}], '
                f'Train Loss: {train_loss_total:.7f}, Train Loss: {train_loss_total:.7f}, '
                f'Val Loss: {val_loss_total:.7f}, Val Loss: {val_loss_total:.7f}')
        print("Time elapsed:", time.time() - start)
        
        torch.cuda.empty_cache()


    model_path = str(lr) + '_' + str(batch_size) + '_' + str(num_epochs)
    torch.save(model.state_dict(), 'model' + model_path)
    np.savetxt("{}_train_loss.csv".format(model_path), train_loss)
    np.savetxt("{}_val_loss.csv".format(model_path), val_loss)

In [85]:
def model_train_paths_test(model, load_batch_size, test_data, 
                      device, loss_func="mse", load_ckpt=None):

    # get current time
    start = time.time()

    torch.cuda.empty_cache()

    model = model.to(device)
    criterion = nn.MSELoss()
    # optimizer = optim.Adam(model.parameters(), lr=lr)

    if load_ckpt is not None and CKPT_PATH is not None:
        print("LOADING FROM CKPT:", load_ckpt)
        checkpoint = torch.load(os.path.join(CKPT_PATH, load_ckpt))
        model.load_state_dict(checkpoint['state_dict'])
        start_epoch = checkpoint['epoch']
    else:
        start_epoch = 0

    test_loss = 0

    batch_size = 1

    test_loss_total = 0.0

    # Testing
    model.eval()
    data_path_batched = []
    with torch.no_grad():
        for i, data_path in enumerate(test_data):
            data_path_batched.append(data_path)

            if ((i + 1) % load_batch_size == 0) or i == len(test_data) - 1: # only load when we accumulated enough
                prep_load_batch_size = load_batch_size if (i + 1) % load_batch_size == 0 else len(test_data) % load_batch_size
                batch_data = prepare_batch_sample(data_path_batched, prep_load_batch_size, "val")
                validation_loader = DataLoader(batch_data, batch_size=batch_size, shuffle=True) 

                for data in validation_loader:

                    songs = data[0].to(device)

                    covers = data[1].to(device)

                    outputs = model(songs)
                    outputs = outputs.to(device)
                    outputs = postprocess(outputs, covers)

                    # pad tensors to same length
                    if outputs.shape[1] > covers.shape[1]:
                        covers = F.pad(covers, (0, 0, 0, outputs.shape[1] - covers.shape[1]))
                    elif covers.shape[1] > outputs.shape[1]:
                        outputs = F.pad(outputs, (0, 0, 0, covers.shape[1] - outputs.shape[1]))
                    assert(outputs.shape == covers.shape)
                    
                    if loss_func == "custom":
                        loss = blur_loss(outputs, covers, device) + criterion(outputs, covers) # warning: memory intensive
                    elif loss_func == "mse":
                        loss = criterion(outputs, covers)
                    else:
                        raise NotImplementedError
                    
                    test_loss_total += loss.item()
                    
                data_path_batched = []

                print("Test loss:", test_loss_total)
                print("Time elapsed:", time.time() - start)

        torch.cuda.empty_cache()

In [86]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = "cpu" # force cpu
device

device(type='cuda')

In [87]:
model = Net(width = 3, batch_size = BATCH_SIZE)

In [88]:
# WARNING: on single you only need to run this ONCE!
if LOADER == "batch":
    training_data, validation_data, testing_data = get_data()
elif LOADER == "single":
    training_data, validation_data, testing_data = get_data_paths(ALREADY_LOADED)
else:
    raise NotImplementedError

Parsing ../data/clean_data/0_0_song.midi ../data/clean_data/0_0_cover.midi
Processed 0 songs
Parsing ../data/clean_data/0_1_song.midi ../data/clean_data/0_1_cover.midi
Processed 1 songs
Parsing ../data/clean_data/0_2_song.midi ../data/clean_data/0_2_cover.midi
Processed 2 songs
Parsing ../data/clean_data/1_0_song.midi ../data/clean_data/1_0_cover.midi
Processed 3 songs
Parsing ../data/clean_data/1_1_song.midi ../data/clean_data/1_1_cover.midi
Processed 4 songs
Parsing ../data/clean_data/1_2_song.midi ../data/clean_data/1_2_cover.midi
Processed 5 songs
Parsing ../data/clean_data/2_0_song.midi ../data/clean_data/2_0_cover.midi
Processed 6 songs
Parsing ../data/clean_data/2_1_song.midi ../data/clean_data/2_1_cover.midi
Processed 7 songs
Parsing ../data/clean_data/2_2_song.midi ../data/clean_data/2_2_cover.midi
Processed 8 songs
Parsing ../data/clean_data/3_0_song.midi ../data/clean_data/3_0_cover.midi
Parsing ../data/clean_data/3_1_song.midi ../data/clean_data/3_1_cover.midi
Processed 9 s

## Train the Data

In [89]:
if LOADER == "batch" and SETTING == "train":
    model_train(model, 0.1, BATCH_SIZE, training_data, validation_data, 200, device, loss_func="mse")
elif LOADER == "single" and SETTING == "train":
    model_train_paths(model, 0.0001, LOAD_BATCH_SIZE, training_data, validation_data, 200, \
                      device, loss_func="mse", load_ckpt="checkpoint_epoch1.pt")

In [90]:
# Get testing loss using blur loss
model_train_paths_test(model, LOAD_BATCH_SIZE, testing_data, device, loss_func="custom", load_ckpt="checkpoint_epoch121.pt")

LOADING FROM CKPT: checkpoint_epoch121.pt


Test loss: 7.786501407623291
Time elapsed: 10.042664766311646


# Data Analysis (WIP)

In [91]:
# Accuracy function
import pretty_midi
import numpy as np

def midi_note_pitch_accuracy(piano_roll1, piano_roll2, start_time1=None, end_time1=None, start_time2=None, end_time2=None):

    # If start and end times are specified, convert them to frame indices and slice the piano rolls
    if start_time1 is not None and end_time1 is not None:
        start_frame1 = int(start_time1 * 100)
        end_frame1 = int(end_time1 * 100)
        piano_roll1 = piano_roll1[:, start_frame1:end_frame1]

    if start_time2 is not None and end_time2 is not None:
        start_frame2 = int(start_time2 * 100)
        end_frame2 = int(end_time2 * 100)
        piano_roll2 = piano_roll2[:, start_frame2:end_frame2]

    # If the piano rolls have different numbers of columns, truncate the longer one to match the shorter one
    min_length = min(piano_roll1.shape[1], piano_roll2.shape[1])
    piano_roll1 = piano_roll1[:, :min_length]
    piano_roll2 = piano_roll2[:, :min_length]

    # Get the note pitch at each time step by finding the index of the maximum value in each column
    note_sequence1 = np.argmax(piano_roll1, axis=0)
    note_sequence2 = np.argmax(piano_roll2, axis=0)

    # Compute the note pitch accuracy by comparing the note sequences
    correct_notes = np.sum(note_sequence1 == note_sequence2)
    total_notes = len(note_sequence1)
    note_accuracy = correct_notes / total_notes

    return note_accuracy

In [92]:
# Evaluate: loads the model checkpoints and computes the losses and accuracies
def evaluate(trained_epoch, load_batch_size, training_data, validation_data, 
              device, loss_func="mse"):

    # Load the current checkpoint
    cur_ckpt = os.path.join(CKPT_PATH, "checkpoint_epoch" + str(trained_epoch) + ".pt")
    state = torch.load(cur_ckpt)
    model = Net(width = 3, batch_size = 1)
    model.load_state_dict(state['state_dict'])

    # Compute losses and accuracies
    train_loss_total = 0.0
    train_acc_total = 0.0
    val_loss_total = 0.0
    val_acc_total = 0.0

    model.eval()
    batch_size = 1

    torch.cuda.empty_cache()
    model = model.to(device)
    criterion = nn.MSELoss()

    data_path_batched = []

    for i, data_path in enumerate(training_data):

        data_path_batched.append(data_path)

        if ((i + 1) % load_batch_size == 0) or i == len(training_data) - 1: # only load when we accumulated enough
            prep_load_batch_size = load_batch_size if (i + 1) % load_batch_size == 0 else len(training_data) % load_batch_size
            batch_data = prepare_batch_sample(data_path_batched, prep_load_batch_size, "train")
            train_loader = DataLoader(batch_data, batch_size=batch_size, shuffle=True) # this still has to be 1 due to vram constraints

            for data in train_loader:
                songs = data[0].to(device)

                covers = data[1].to(device)

                outputs = model(songs)
                outputs = outputs.to(device)
                outputs = postprocess(outputs, covers)

                # pad tensors to same length
                if outputs.shape[1] > covers.shape[1]:
                    covers = F.pad(covers, (0, 0, 0, outputs.shape[1] - covers.shape[1]))
                elif covers.shape[1] > outputs.shape[1]:
                    outputs = F.pad(outputs, (0, 0, 0, covers.shape[1] - outputs.shape[1]))
                assert(outputs.shape == covers.shape)
                
                if loss_func == "custom":
                    loss = blur_loss(outputs, covers, device) + criterion(outputs, covers) # warning: memory intensive
                elif loss_func == "mse":
                    loss = criterion(outputs, covers)
                else:
                    raise NotImplementedError

                train_loss_total += loss.item().detach().numpy()

                train_acc_total += midi_note_pitch_accuracy(outputs.to("cpu"), covers.to("cpu"))
                
                torch.cuda.empty_cache()

            # reset
            data_path_batched = []

    # Validation
    data_path_batched = []
    with torch.no_grad():
        for i, data_path in enumerate(validation_data):
            data_path_batched.append(data_path)

            if ((i + 1) % load_batch_size == 0) or i == len(validation_data) - 1: # only load when we accumulated enough
                prep_load_batch_size = load_batch_size if (i + 1) % load_batch_size == 0 else len(validation_data) % load_batch_size
                batch_data = prepare_batch_sample(data_path_batched, prep_load_batch_size, "val")
                validation_loader = DataLoader(batch_data, batch_size=batch_size, shuffle=True) 

                for data in validation_loader:

                    songs = data[0].to(device)

                    covers = data[1].to(device)

                    outputs = model(songs)
                    outputs = outputs.to(device)
                    outputs = postprocess(outputs, covers)

                    # pad tensors to same length
                    if outputs.shape[1] > covers.shape[1]:
                        covers = F.pad(covers, (0, 0, 0, outputs.shape[1] - covers.shape[1]))
                    elif covers.shape[1] > outputs.shape[1]:
                        outputs = F.pad(outputs, (0, 0, 0, covers.shape[1] - outputs.shape[1]))
                    assert(outputs.shape == covers.shape)
                    
                    if loss_func == "custom":
                        loss = blur_loss(outputs, covers, device) + criterion(outputs, covers) # warning: memory intensive
                    elif loss_func == "mse":
                        loss = criterion(outputs, covers)
                    else:
                        raise NotImplementedError
                    
                    val_loss_total += loss.item().detach().numpy()

                    val_acc_total += midi_note_pitch_accuracy(outputs.to("cpu"), covers.to("cpu"))
                    
                data_path_batched = []
    
    torch.cuda.empty_cache()

    return train_loss_total, train_acc_total, val_loss_total, val_acc_total



In [93]:
# # Load model from each epoch and plot its training and validation loss and accuracies
# TRAINED_EPOCHS = 121

# # Arrays for plotting
# train_losses = []
# train_accs = []

# val_losses = []
# val_accs = []

# for epoch in range(1, TRAINED_EPOCHS + 1):
#     # ckpt_path = "/media/allentao/One Touch/APS360/ckpts/aug1lr0.0001overfit/checkpoint_epoch121.pt"
#     train_loss, train_acc, val_loss, val_acc = evaluate(epoch, LOAD_BATCH_SIZE, training_data, validation_data, device, loss_func="mse")
    
#     train_losses.append(train_loss)
#     train_accs.append(train_acc)
#     val_losses.append(val_loss)
#     val_accs.append(val_acc)

#     print("Epoch: {} | Train Loss: {} | Train Acc: {} | Val Loss: {} | Val Acc: {}".format(epoch, train_loss, train_acc, val_loss, val_acc))


In [94]:
# Generate the plots (quantitative evaluation)

In [95]:
# Bonus: generate confusion matrix or etc
