## General Setup

In [None]:
%load_ext autoreload
%autoreload 2

import os

import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from data.Dataset import DatasetUtils, MidiTransformerDataset
from models.Transformer import TransformerModel

import PlotUtils as pu
import MidiUtils as mu
from data.Song import Song

# Locations of the MAESTRO dataset and of the project workspace (logs and saved
# models are written below the workspace). The defaults are the original
# author's local paths; set the MAESTRO_DATASET_PATH and
# PIANO_TRANSCRIPTION_WORKSPACE environment variables to run the notebook on
# another machine without editing this cell.
dataset_path = os.environ.get(
    "MAESTRO_DATASET_PATH",
    "/Users/andreas/Development/Midi-Conversion/maestro-v3.0.0",
)
workspace = os.environ.get(
    "PIANO_TRANSCRIPTION_WORKSPACE",
    "/Users/andreas/Development/Midi-Conversion/PianoTranscription",
)

# Computing the total length of the dataset is expensive, so we cache it here.
# NOTE(review): the DISC_100_TRANS suffix presumably refers to a time
# discretization of 100 and the transformer dataset; these values are not read
# anywhere in the cells shown here (the datasets are built with total_length=None).
TRAIN_SET_TOTAL_LEN_DISC_100_TRANS = 112624
VAL_SET_TOTAL_LEN_DISC_100_TRANS = 13764
TEST_SET_TOTAL_LEN_DISC_100_TRANS = 14183

In [None]:
def get_data_loaders(batch_size=1, time_discretization=100, shuffle=True, precomp=False):
    """Build the train, validation and test ``DataLoader`` objects.

    The split files (``train.txt``, ``test.txt``, ``validation.txt``) are
    generated under ``dataset_path`` only when at least one of them is missing,
    because regenerating them is slow.

    Args:
        batch_size: Batch size used by all three loaders.
        time_discretization: Forwarded to ``MidiTransformerDataset``.
        shuffle: Accepted for API compatibility; it is currently NOT forwarded
            to the loaders.
        precomp: Forwarded to ``MidiTransformerDataset`` as ``precomputed_midi``.

    Returns:
        A ``(train_loader, val_loader, test_loader)`` tuple.
    """
    required_split_files = ('train.txt', 'test.txt', 'validation.txt')
    splits_exist = all(
        os.path.exists(os.path.join(dataset_path, file_name))
        for file_name in required_split_files
    )
    if not splits_exist:
        DatasetUtils.create_dataset_files(dataset_path)

    # Build every dataset first, then wrap each one in a loader.
    datasets = [
        MidiTransformerDataset(dataset_path, split, time_discretization, total_length=None, precomputed_midi=precomp)
        for split in ('train', 'validation', 'test')
    ]
    return tuple(DataLoader(dataset, batch_size=batch_size) for dataset in datasets)


def create_tqdm_bar(iterable, desc):
    """Wrap ``iterable`` in a tqdm progress bar that yields ``(index, item)`` pairs.

    ``iterable`` must support ``len()``; that length is used as the bar's total.
    """
    indexed_items = enumerate(iterable)
    item_count = len(iterable)
    return tqdm(indexed_items, total=item_count, ncols=150, desc=desc)

In [None]:
# Choose the compute backend: CUDA when available, otherwise Apple's MPS
# backend, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

print(F"Using device '{device}'")

## Overfit Single Chunk

In [None]:
# Hyperparameters for the single-chunk overfitting experiment.
params = {
    'device': device,
    'discretization': 100,
    'learning_rate': 1e-4,
    'epochs': 1,
    'dropout': 0.0,
    'batch_size': 10,
}

loaders = get_data_loaders(
    batch_size=params['batch_size'],
    time_discretization=params['discretization'],
    precomp=True,
)
train_loader, val_loader, _ = loaders
model = TransformerModel(
    output_depth=129,
    d_model=512,
    nhead=1,
    d_hid=512,
    nlayers=1,
    dropout=params['dropout'],
    params=params,
)

# Pull the first batch from the training loader; raise the range to skip ahead.
train_iter = iter(train_loader)
for i in range(1):
    audio, midi, mask = next(train_iter)

# Move the batch to the device selected earlier in the notebook.
audio, midi, mask = (tensor.to(device) for tensor in (audio, midi, mask))

print(audio.shape)
print(midi[0, :, :])

In [None]:
# Load one batch from the 'single_file' split and move it to the device.
# This replaces the audio/midi/mask batch taken from the training loader.
single_file_set = MidiTransformerDataset(
    dataset_path,
    'single_file',
    params['discretization'],
    total_length=None,
    precomputed_midi=False,
    no_file_lengths=True,
)
single_file_loader = DataLoader(single_file_set, batch_size=10)
audio, midi, mask = (tensor.to(device) for tensor in next(iter(single_file_loader)))

In [None]:
# Move the model to the selected device and switch it to training mode.
model.to(device)
model.train()

# Run 500 training steps on the same audio/midi/mask batch currently held in
# the kernel (when run top to bottom, that is the single-file batch, which
# overwrites the one taken from the training loader). The same mask is passed
# for both mask arguments. The loss is printed every 25 iterations.
for i in range(500):
    train_loss = model.training_step(audio, midi, mask, mask)
    if i % 25 == 0:
        print(f'it: {i} loss: {train_loss}')

# Loss returned by the last training step.
print(train_loss)

In [None]:
# Plot the target batch next to the model's prediction for the same batch.
# NOTE(review): the overfit cell calls model.train() and no model.eval() call
# appears before this forward pass — confirm that running in train mode here
# is intended.
with torch.no_grad():
    # Flatten the batch dimension: (d0, d1, d2) -> (d0 * d1, d2), so all
    # sequences of the batch are laid end to end.
    midi_plot = midi.reshape(-1, midi.shape[2])
    # Plot the target with threshold=0.5 — presumably values below the
    # threshold are hidden; confirm against PlotUtils.plot_tensor_as_image.
    pu.plot_tensor_as_image(midi_plot.T, figure_shape=(16, 4), threshold=0.5)
    # Forward pass with the target as decoder input; the same mask is passed
    # for both mask arguments. The sigmoid maps the outputs into (0, 1).
    pred_midi = model.forward(audio, midi, mask, mask)
    pred_midi = torch.sigmoid(pred_midi)

    # Flatten the prediction the same way and plot it without a threshold.
    pred_midi_plot = pred_midi.reshape(-1, pred_midi.shape[2])
    pu.plot_tensor_as_image(pred_midi_plot.T, figure_shape=(16, 4), threshold=None)

In [None]:
# Write the flattened target tensor to a MIDI file and play it; the audio is
# also written to 'output/original_midi.wav'.
# NOTE(review): the literal 100 presumably is the time discretization used to
# build the dataset (params['discretization']) — confirm against Song.
Song.start_time_tensor_to_midi(midi_plot.squeeze(), 'output/original_midi.midi', 100, note_threshold=0.5)
mu.play_midi('output/original_midi.midi', output_path='output/original_midi.wav')

# Write the flattened prediction to a MIDI file with the same note threshold
# and play it (no output_path is passed for this one).
Song.start_time_tensor_to_midi(pred_midi_plot.squeeze(), 'output/predicted_midi.midi', 100, note_threshold=0.5)
mu.play_midi('output/predicted_midi.midi')

In [None]:
# Index of the sequence within the current batch to run prediction on.
sequence_index = 0
# Select that sequence and make sure it is 2-D with audio.shape[2] columns.
audio_sequence = audio[sequence_index, :, :].reshape(-1, audio.shape[2])
# audio_sequence = audio[sequence_index, :, :]
print(f"Shape of audio sequence: {audio_sequence.shape}")
# Run the model's predict2 on the sequence, passing its number of rows and a
# threshold of 0.5.
# NOTE(review): predict2 presumably decodes without access to the target —
# confirm in TransformerModel.
midi_prediction = model.predict2(audio_sequence, audio_sequence.shape[0], threshold=0.5)

In [None]:
# Plot the target for the selected sequence, then the prediction from predict2
# (the prediction is plotted without a threshold).
pu.plot_tensor_as_image(midi[sequence_index, :, :].reshape(-1, midi.shape[2]).T, figure_shape=(16, 4))
pu.plot_tensor_as_image(midi_prediction.T, figure_shape=(16, 4), threshold=None)

# Write the prediction to a MIDI file and play it. This overwrites
# 'output/predicted_midi.midi' written by the earlier export cell, and no
# note_threshold is passed here.
Song.start_time_tensor_to_midi(midi_prediction.squeeze(), 'output/predicted_midi.midi', 100)
mu.play_midi('output/predicted_midi.midi')

# Last expression of the cell: its return value (if any) is the cell output.
model.count_parameters()

## Default train loop

In [None]:
def _next_batch_cycling(batch_iter, loader):
    """Return ``(batch, iterator)``, restarting the iterator once ``loader`` is exhausted."""
    try:
        return next(batch_iter), batch_iter
    except StopIteration:
        # Loader exhausted: start over. An empty loader still raises StopIteration.
        batch_iter = iter(loader)
        return next(batch_iter), batch_iter


def train_loop(model: torch.nn.Module, data_loaders, params, workspace, log_name, function):
    """Train ``model`` with interleaved validation, TensorBoard logging and checkpointing.

    Training proceeds in groups of up to 100 batches. After every group, 10
    validation batches are evaluated; whenever their mean loss improves on the
    best seen so far, the model is saved to
    ``<workspace>/out_models/<log_name>/best_model.pt``. Scalars are logged to
    ``<workspace>/logs/<log_name>/run_<k>``.

    Args:
        model: Model exposing ``training_step``, ``validation_step`` and
            ``save_state``.
        data_loaders: ``(train_loader, val_loader, test_loader)``; the test
            loader is not used.
        params: Must contain ``'epochs'``, ``'device'``, ``'batch_size'`` and
            ``'learning_rate'``; only ``'epochs'`` and ``'device'`` are read here.
        workspace: Root directory for logs and saved models.
        log_name: Sub-directory name and TensorBoard tag prefix for this run.
        function: Unused; kept for interface compatibility.

    Raises:
        AssertionError: If a required key is missing from ``params``.
        StopIteration: If the validation loader yields no batches at all.
    """
    assert 'epochs' in params, 'Number of epochs not specified in params (\'epochs\')'
    assert 'device' in params, 'Device not specified in params (\'device\')'
    assert 'batch_size' in params, 'Batch size not specified in params (\'batch_size\')'
    assert 'learning_rate' in params, 'Learning rate not specified in params (\'learning_rate\')'

    device = params['device']
    model.to(device)

    # Each call logs into a fresh run_<k> directory next to the previous runs.
    logger_path = os.path.join(workspace, 'logs', log_name)
    num_of_runs = len(os.listdir(logger_path)) if os.path.exists(logger_path) else 0
    logger = SummaryWriter(os.path.join(logger_path, f'run_{num_of_runs + 1}'))

    epochs = params['epochs']
    train_loader, val_loader, _ = data_loaders
    num_train_batches = len(train_loader)
    val_iter = iter(val_loader)
    best_loss = float('inf')
    train_group_length = 100
    val_group_length = 10
    val_loss = 0.0
    # Monotonic step counter so validation points never collide across epochs.
    val_step = 0

    try:
        for epoch in range(epochs):
            # Create a progress bar for the training loop.
            training_loop = create_tqdm_bar(train_loader, desc=f'Training Epoch [{epoch + 1}/{epochs}]')
            train_iter = iter(training_loop)
            train_iteration = 0

            # group_start is already a batch offset (0, 100, 200, ...), so the
            # number of batches left is simply num_train_batches - group_start.
            for group_start in range(0, num_train_batches, train_group_length):
                group_size = min(train_group_length, num_train_batches - group_start)
                train_loss = 0.0
                # Validation below switches to eval mode; switch back here.
                model.train()

                for train_group_index in range(group_size):
                    _, batch = next(train_iter)
                    train_iteration += 1

                    # Batches are moved to the model's device, as the
                    # single-batch cells of this notebook do.
                    src, tgt, pad_mask = (tensor.to(device) for tensor in batch)
                    # Drop the first element along dim 1 of the target and of
                    # its padding mask; the source mask is passed unchanged.
                    # NOTE(review): validation below passes the unshifted
                    # target and mask — confirm which convention the model expects.
                    tgt = tgt[:, 1:]
                    tgt_pad_mask = pad_mask[:, 1:]
                    # float() so the running sum never holds on to a tensor.
                    loss = float(model.training_step(src, tgt, pad_mask, tgt_pad_mask))
                    train_loss += loss

                    # Progress indicator
                    if train_iteration % 10 == 0:
                        training_loop.set_postfix(
                            curr_train_loss="{:.8f}".format(train_loss / (train_group_index + 1)),
                            val_loss="{:.8f}".format(val_loss),
                            refresh=True)
                        logger.add_scalar(f'{log_name}/train_loss', loss, epoch * num_train_batches + train_iteration)

                # Validation: mean loss over val_group_length batches, evaluated
                # in eval mode and without building autograd graphs.
                model.eval()
                val_loss = 0.0
                with torch.no_grad():
                    for _ in range(val_group_length):
                        batch, val_iter = _next_batch_cycling(val_iter, val_loader)
                        src, tgt, pad_mask = (tensor.to(device) for tensor in batch)
                        loss = float(model.validation_step(src, tgt, pad_mask, pad_mask))
                        val_step += 1
                        logger.add_scalar(f'{log_name}/val_loss', loss, val_step)
                        val_loss += loss

                val_loss /= val_group_length

                if val_loss < best_loss:
                    best_loss = val_loss
                    path = os.path.join(workspace, 'out_models', log_name, 'best_model.pt')
                    # Create path if it doesn't exist
                    os.makedirs(os.path.dirname(path), exist_ok=True)
                    model.save_state(path)

            # The bar's iterator is advanced manually and never runs to
            # exhaustion, so close the bar explicitly.
            training_loop.close()
    finally:
        # Flush and release the TensorBoard writer even if training fails.
        logger.close()

In [None]:
# Hyperparameters for the full training run. This rebinds the `params` and
# `model` globals used by the overfitting cells above.
params = {'device': device, 'learning_rate': 5e-4, 'epochs': 1, 'batch_size': 10}

# Single-layer, single-head transformer with dropout 0.1; loaders use a time
# discretization of 100 and precomputed MIDI.
model = TransformerModel(output_depth=129, d_model=512, nhead=1, d_hid=1024, nlayers=1, dropout=0.1, params=params)
train_data, val_data, test_data = get_data_loaders(params['batch_size'], time_discretization=100, precomp=True)

# Number of batches in the training loader.
print(f"Train data length: {len(train_data)}")

# Train, logging under the 'transformer' tag; the last argument (`function`) is
# not used by train_loop.
train_loop(model, (train_data, val_data, test_data), params, workspace, 'transformer', None)