In [1]:
# Load from USB
import json
import os
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim

import pretty_midi
from model import Net
from midi_to_piano_roll import midi_to_piano_roll
from loss import blur_loss

In [2]:
# Test code on small dataset, or on full dataset
mode = "full"

In [3]:
if mode == "full":
        IN_FOLDER = '/media/allentao/One Touch/APS360/data/clean_data/'
        file_path = "../data/songs.json" # full dataset is too large to load into memory
elif mode == "small":
        IN_FOLDER = '../data/clean_data/'
        file_path = "../data/songs_small.json"
else:
        raise NotImplementedError

with open(file_path, "r") as json_file:
        songs_file = json.load(json_file)

In [4]:
def extend(data, max_length):
    new_data = []
    for i in range(len(data)):
        rows_needed = max_length - data[i][0].shape[0]
        zeros_to_add = torch.zeros((rows_needed, 128), dtype=data[i][0].dtype)
        new_song= torch.concatenate((data[i][0], zeros_to_add), axis=0)

        rows_needed = max_length - data[i][1].shape[0]
        zeros_to_add = torch.zeros((rows_needed, 128), dtype=data[i][1].dtype)
        new_cover= torch.concatenate((data[i][1], zeros_to_add), axis=0)
        
        new_data.append((new_song, new_cover))
    return new_data

In [5]:
def get_data():
    training_data = []
    validation_data = []
    testing_data = []

    count = 0
    max_length = 0
    for song in songs_file["songs"]:
        song_file = song["filename"]
        song_num = int(song_file.split("_")[0])
        for piano_file in song["piano covers"]["filename"]:
            
            name = os.path.splitext(piano_file)[0].split('_')[0] + "_" + os.path.splitext(piano_file)[0].split('_')[1]
            song_file_path = IN_FOLDER + name + "_song.midi"
            cover_file_path = IN_FOLDER + name + "_cover.midi"
            print("Parsing", song_file_path, cover_file_path)

            song_piano_roll = midi_to_piano_roll(song_file_path)
            cover_piano_roll = midi_to_piano_roll(cover_file_path)

            if song_piano_roll == None or cover_piano_roll == None:
                continue

            song_piano_roll_val = song_piano_roll[:song_piano_roll.shape[-1]//2, :]
            cover_piano_roll_val = cover_piano_roll[:cover_piano_roll.shape[-1]//2, :]
            
            song_length = song_piano_roll.shape[0]
            cover_length = cover_piano_roll.shape[0]

            if song_length > max_length:
                max_length = song_length
            if cover_length > max_length:
                max_length = cover_length
            training_data.append((song_piano_roll, cover_piano_roll))
            if count < 200:
                validation_data.append((song_piano_roll_val, cover_piano_roll_val))
            elif count < 400:
                testing_data.append((song_piano_roll_val, cover_piano_roll_val))
            
            print("Processed", count, "songs")
            count += 1

    # training_data = extend(training_data, max_length)
    # validation_data = extend(validation_data, max_length)
    # testing_data = extend(testing_data, max_length)
    
    return training_data, validation_data, testing_data

In [6]:
def model_train(model, lr, batch_size, training_data, validation_data, num_epochs, device, loss_func="mse"):
    torch.cuda.empty_cache()

    model = model.to(device)
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    batch_size = batch_size
    train_loader = DataLoader(training_data, batch_size=batch_size, shuffle=True)

    validation_loader = DataLoader(validation_data, batch_size=batch_size, shuffle = True)

    train_loss = np.zeros(num_epochs)
    val_loss = np.zeros(num_epochs)


    for epoch in range(num_epochs):
        train_loss_total = 0.0
        val_loss_total = 0.0

        # Training
        model.train()
        count = 0
        for data in train_loader:
            count += 1
            songs = data[0].to(device)

            covers = data[1].to(device)
            optimizer.zero_grad()

            outputs = model(songs)
            outputs = outputs.to(device)

            # pad tensors to same length
            if outputs.shape[1] > covers.shape[1]:
                covers = F.pad(covers, (0, 0, 0, outputs.shape[1] - covers.shape[1]))
            elif covers.shape[1] > outputs.shape[1]:
                outputs = F.pad(outputs, (0, 0, 0, covers.shape[1] - outputs.shape[1]))
            assert(outputs.shape == covers.shape)
            
            if loss_func == "custom":
                loss = blur_loss(outputs, covers, device) + criterion(outputs, covers) # warning: memory intensive
            elif loss_func == "mse":
                loss = criterion(outputs, covers)
            else:
                raise NotImplementedError

            loss.backward(retain_graph = True)
            optimizer.step()

            train_loss_total += loss.item()
            
            torch.cuda.empty_cache()

        checkpoint = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            # Add any other information you want to save (e.g., training loss, validation loss, etc.)
        }

        torch.save(checkpoint, f'/media/allentao/One Touch/APS360/ckpts/checkpoint_epoch{epoch + 1}.pt')
        # Validation
        model.eval()
        with torch.no_grad():
            for data in validation_loader:
                # loss = criterion(outputs, labels)
                songs = data[0].to(device)

                covers = data[1].to(device)

                outputs = model(songs)
                outputs = outputs.to(device)

                # pad tensors to same length
                if outputs.shape[1] > covers.shape[1]:
                    covers = F.pad(covers, (0, 0, 0, outputs.shape[1] - covers.shape[1]))
                elif covers.shape[1] > outputs.shape[1]:
                    outputs = F.pad(outputs, (0, 0, 0, covers.shape[1] - outputs.shape[1]))
                assert(outputs.shape == covers.shape)
                
                if loss_func == "custom":
                    loss = blur_loss(outputs, covers, device) + criterion(outputs, covers) # warning: memory intensive
                elif loss_func == "mse":
                    loss = criterion(outputs, covers)
                else:
                    raise NotImplementedError
                
                val_loss_total += loss.item()
                
        train_loss[epoch] = train_loss_total
        val_loss[epoch] = val_loss_total

        print(f'Epoch [{epoch+1}/{num_epochs}], '
                f'Train Loss: {train_loss_total:.7f}, Train Loss: {train_loss_total:.7f}, '
                f'Val Loss: {val_loss_total:.7f}, Val Loss: {val_loss_total:.7f}')
        torch.cuda.empty_cache()


    model_path = str(lr) + '_' + str(batch_size) + '_' + str(num_epochs)
    torch.save(model.state_dict(), 'model' + model_path)
    np.savetxt("{}_train_loss.csv".format(model_path), train_loss)
    np.savetxt("{}_val_loss.csv".format(model_path), val_loss)

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = "cpu" # force cpu
device

device(type='cuda')

In [8]:
# Set the batch size
batch_size = 1

In [9]:
model = Net(width = 3, batch_size = batch_size)

In [10]:
print(IN_FOLDER)
print(file_path)
training_data, validation_data, testing_data = get_data()

/media/allentao/One Touch/APS360/data/clean_data/
../data/songs_med.json
Parsing /media/allentao/One Touch/APS360/data/clean_data/0_0_song.midi /media/allentao/One Touch/APS360/data/clean_data/0_0_cover.midi


Processed 0 songs
Parsing /media/allentao/One Touch/APS360/data/clean_data/0_1_song.midi /media/allentao/One Touch/APS360/data/clean_data/0_1_cover.midi
Processed 1 songs
Parsing /media/allentao/One Touch/APS360/data/clean_data/0_2_song.midi /media/allentao/One Touch/APS360/data/clean_data/0_2_cover.midi
Parsing /media/allentao/One Touch/APS360/data/clean_data/1_0_song.midi /media/allentao/One Touch/APS360/data/clean_data/1_0_cover.midi
Processed 2 songs
Parsing /media/allentao/One Touch/APS360/data/clean_data/1_1_song.midi /media/allentao/One Touch/APS360/data/clean_data/1_1_cover.midi
Processed 3 songs
Parsing /media/allentao/One Touch/APS360/data/clean_data/1_2_song.midi /media/allentao/One Touch/APS360/data/clean_data/1_2_cover.midi
Processed 4 songs
Parsing /media/allentao/One Touch/APS360/data/clean_data/2_0_song.midi /media/allentao/One Touch/APS360/data/clean_data/2_0_cover.midi
Processed 5 songs
Parsing /media/allentao/One Touch/APS360/data/clean_data/2_1_song.midi /media/alle

KeyboardInterrupt: 

In [None]:
# train_loader = DataLoader(training_data, batch_size=batch_size, shuffle=True, drop_last=True)

In [None]:
# for data in train_loader:
#     model = model.to(device)
#     out = model(data[0].to(device))

In [None]:
model_train(model, 1e4, batch_size, training_data, validation_data, 10, device, loss_func="mse")

Epoch [1/10], Train Loss: 0.1448535, Train Loss: 0.1448535, Val Loss: 0.0238449, Val Loss: 0.0238449
Epoch [2/10], Train Loss: 0.1448556, Train Loss: 0.1448556, Val Loss: 0.0226380, Val Loss: 0.0226380
Epoch [3/10], Train Loss: 0.1448541, Train Loss: 0.1448541, Val Loss: 0.0205448, Val Loss: 0.0205448
Epoch [4/10], Train Loss: 0.1448511, Train Loss: 0.1448511, Val Loss: 0.0229805, Val Loss: 0.0229805
Epoch [5/10], Train Loss: 0.1448465, Train Loss: 0.1448465, Val Loss: 0.0281214, Val Loss: 0.0281214
Epoch [6/10], Train Loss: 0.1448440, Train Loss: 0.1448440, Val Loss: 0.0344856, Val Loss: 0.0344856
Epoch [7/10], Train Loss: 0.1448424, Train Loss: 0.1448424, Val Loss: 0.0192781, Val Loss: 0.0192781
Epoch [8/10], Train Loss: 0.1448418, Train Loss: 0.1448418, Val Loss: 0.0356231, Val Loss: 0.0356231
Epoch [9/10], Train Loss: 0.1448443, Train Loss: 0.1448443, Val Loss: 0.0210562, Val Loss: 0.0210562
Epoch [10/10], Train Loss: 0.1448433, Train Loss: 0.1448433, Val Loss: 0.0282223, Val Loss: