In [None]:
%load_ext autoreload
%autoreload 2

import torch
import torchaudio
import numpy as np
from mido import MidiFile, MidiTrack, Message
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import librosa
import timeit
import numpy as np
import random

from data.Dataset import MidiDataset, DatasetUtils, MidiIterDataset, MidiTransformerDataset
import MidiUtils as mu
from data.Note import Note
from data.Song import Song
import PlotUtils
from models.Transformer import TransformerModel
from data.Dataloader import MidiDataloader

from models.Transformer import MidiTransformer
from torch import nn

# Machine-specific locations of the maestro-v3.0.0 dataset checkout and of this
# project's workspace. NOTE(review): absolute local paths — adjust per machine.
dataset_path = "/Users/andreas/Development/Midi-Conversion/maestro-v3.0.0"
workspace = "/Users/andreas/Development/Midi-Conversion/PianoTranscription"

# Cached total length of each dataset split; recomputing these is expensive.
# (Presumably measured at a discretization of 100, as the names suggest — confirm.)
TRAIN_SET_TOTAL_LENGTH_DISCRETIZED_100 = 57_412_301
VAL_SET_TOTAL_LENGTH_DISCRETIZED_100 = 7_009_869
TEST_SET_TOTAL_LENGTH_DISCRETIZED_100 = 7_214_840

In [None]:
# Load one example recording and report its tensor shape, sample rate and duration.
audio_file = './output/Planet_Earth_II.mp3'
audio_tensor, file_sample_rate = torchaudio.load(audio_file, normalize=True)

# Dimension 1 of the loaded tensor is used as the number of samples.
samples = audio_tensor.shape[1]
print(
    F'audio_tensor length: {audio_tensor.shape}, sample_rate: {file_sample_rate}',
    F"file_length seconds: {samples/file_sample_rate}",
    sep="\n",
)

In [None]:
# Training split of the dataset plus an iterator over its items; the iterator is
# consumed by later cells.
dataset = MidiTransformerDataset(
    dataset_path,
    'train',
    100,
    TRAIN_SET_TOTAL_LENGTH_DISCRETIZED_100,
    precomputed_midi=True,
)
data_iter = iter(dataset)

# Bookkeeping for the disabled scan below: 480 is the item length the scan treats
# as "still a full item"; the two counters accumulate the lengths seen so far.
length, total_length, iteration_length = 480, 0, 0

# Disabled sanity check: walk the iterator until an item whose audio length is
# not 480 shows up, printing progress roughly every 10M units.
# while length == 480:
#     audio, midi = next(data_iter)
#     length = len(audio)
#     total_length += length
#     iteration_length += length
#     if iteration_length > 10_000_000:
#         print(F'total_length: {total_length}')
#         iteration_length = 0
#     if length != 480:
#         print(F'total length: {total_length} length: {length}')

In [None]:
# Plot the waveform of one recording above the matrix produced from its MIDI file,
# with both x-axes spanning the same amount of time.
audio_path = '/Users/andreas/Development/Midi-Conversion/maestro-v3.0.0/2018/MIDI-Unprocessed_Chamber2_MID--AUDIO_09_R3_2018_wav--1.wav'
midi_path = '/Users/andreas/Development/Midi-Conversion/maestro-v3.0.0/2018/MIDI-Unprocessed_Chamber2_MID--AUDIO_09_R3_2018_wav--1.midi'

audio_tensor, file_sample_rate = torchaudio.load(audio_path, normalize=True)
midi = MidiFile(midi_path)

# Audio duration in seconds, plus the sub-second remainder expressed in milliseconds.
audio_length_s = audio_tensor.shape[1] / file_sample_rate
audio_length_ms = (audio_tensor.shape[1] % file_sample_rate) / file_sample_rate * 1000
print(f"Audio length samples: {audio_tensor.shape[1]}, sample rate: {file_sample_rate}, length in time: {int(audio_length_s)}s {audio_length_ms}ms")

midi_length = midi.length
print(f"Midi length: {midi_length}")

# Transform the MIDI file at `discretization` steps (presumably per second — confirm
# against DatasetUtils), transpose it and move it to the CPU so it can be drawn.
discretization = 100
midi_tensor = DatasetUtils.transform_midi_file(midi_path, discretization).T.to('cpu')

# Common time span: the longer of the two sources, converted to each plot's x units.
x_lim_seconds = max(audio_length_s, midi_length)
x_lim_midi, x_lim_audio = (
    int(x_lim_seconds * rate) for rate in (discretization, file_sample_rate)
)

print(f'lim_x midi: {x_lim_midi}, audio: {x_lim_audio}')

fig = plt.figure(figsize=(10, 3))
# fig = plt.figure(figsize=(400, 20))
# Two stacked subplots: waveform on top, MIDI matrix below.
ax1, ax2 = (fig.add_subplot(2, 1, row) for row in (1, 2))

# Waveform of the first audio channel, from start to end.
ax1.plot(audio_tensor[0].numpy())
ax1.set_xlim(0, x_lim_audio)

# MIDI matrix on the same time span.
ax2.set_xlim(0, x_lim_midi)
ax2.imshow(midi_tensor, aspect='auto')

ax1.axis("off")
ax2.axis("off")

In [None]:
# Stack the first element of the next 10, then the next 15, items from `data_iter`
# (the audio part, going by how items are unpacked elsewhere in this notebook).
audio_seq1, audio_seq2 = (
    torch.stack([next(data_iter)[0] for _ in range(count)]) for count in (10, 15)
)

# Pad the two stacks to a common length; batch_first=True puts the batch dimension first.
padded_seq = nn.utils.rnn.pad_sequence([audio_seq1, audio_seq2], batch_first=True)
print(f"padded_seq: {padded_seq.shape}")

In [None]:
# Smoke test: push a single batch of the training split through a small TransformerModel.
dataset = MidiTransformerDataset(
    dataset_path,
    'train',
    100,
    TRAIN_SET_TOTAL_LENGTH_DISCRETIZED_100,
    precomputed_midi=True,
)
model = TransformerModel(
    output_depth=128,
    d_model=512,
    nhead=1,
    d_hid=512,
    nlayers=1,
    dropout=0.1,
)
dataloader = DataLoader(dataset, batch_size=3, shuffle=False)

for i, value in enumerate(dataloader):
    # Each batch unpacks into audio, midi and a mask; the same mask is passed as
    # both the source and the target padding mask.
    audio, midi, mask = value
    print(f"audio: {audio.shape}, midi: {midi.shape}")
    output = model.forward(
        src=audio,
        tgt=midi,
        src_pad_mask=mask,
        tgt_pad_mask=mask,
    )
    print(f"output: {output.shape}")
    # Only the first batch is needed for this check.
    break

In [None]:
    
def generate_random_data(n):
    """Create a shuffled toy dataset of ``[X, y]`` token-sequence pairs.

    Three families of sequences are generated, ``n // 3`` pairs each: all ones,
    all zeros, and alternating ones/zeros whose phase is picked at random. Every
    sequence has 8 body tokens framed by a start token (2) and an end token (3).

    Args:
        n: Requested number of pairs; ``3 * (n // 3)`` pairs are actually produced.

    Returns:
        A list of ``[X, y]`` lists of 1-D NumPy arrays (length 10), shuffled with
        ``np.random.shuffle``.
    """
    sos, eos = np.array([2]), np.array([3])
    body_len = 8
    per_family = n // 3

    def framed(body):
        # Surround the body tokens with the start and end markers.
        return np.concatenate((sos, body, eos))

    # 1,1,1,1,1,1 -> 1,1,1,1,1
    pairs = [
        [framed(np.ones(body_len)), framed(np.ones(body_len))]
        for _ in range(per_family)
    ]

    # 0,0,0,0 -> 0,0,0,0
    pairs.extend(
        [framed(np.zeros(body_len)), framed(np.zeros(body_len))]
        for _ in range(per_family)
    )

    # 1,0,1,0 -> 1,0,1,0,1
    for _ in range(per_family):
        source = np.zeros(body_len)
        # Random phase: the ones start at index 0 or at index 1.
        source[random.randint(0, 1)::2] = 1

        # The target's phase is chosen from the last source token.
        target = np.zeros(body_len)
        target[(0 if source[-1] == 0 else 1)::2] = 1

        pairs.append([framed(source), framed(target)])

    np.random.shuffle(pairs)

    return pairs


def batchify_data(data, batch_size=16, padding=False, padding_token=-1):
    """Split ``data`` into consecutive, complete batches of ``batch_size`` items.

    A trailing group with fewer than ``batch_size`` items is dropped; every
    complete batch is kept, including the last one when ``len(data)`` is an
    exact multiple of ``batch_size``.

    Args:
        data: Sequence of samples. With ``padding=True`` each sample must be a
            flat sequence of tokens (its ``len`` is the sequence length).
        batch_size: Number of samples per batch.
        padding: If True, right-pad every sample of a batch with
            ``padding_token`` up to the longest sample of that batch.
        padding_token: Value used for padding.

    Returns:
        A list of ``np.int64`` arrays, one per batch. The number of batches is
        also printed.
    """
    batches = []

    # The stop value is chosen so that only start indices with a full batch
    # behind them are visited, which drops an incomplete trailing group.
    for idx in range(0, len(data) - batch_size + 1, batch_size):
        batch = data[idx : idx + batch_size]

        if padding:
            # Longest sequence in this batch decides the padded length.
            max_batch_length = max(len(seq) for seq in batch)

            # Build padded copies instead of extending the caller's sequences in place.
            batch = [
                list(seq) + [padding_token] * (max_batch_length - len(seq))
                for seq in batch
            ]

        batches.append(np.array(batch).astype(np.int64))

    print(f"{len(batches)} batches of size {batch_size}")

    return batches


# Toy datasets: 9000 training and 3000 validation pairs, generated in that order.
train_data, val_data = (generate_random_data(size) for size in (9000, 3000))

# Pre-batched arrays that stand in for dataloaders in the training code.
train_dataloader, val_dataloader = (
    batchify_data(split) for split in (train_data, val_data)
)

In [None]:


# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# device = 'mps'

# Small transformer for the 4-token toy task, trained with plain SGD and a
# cross-entropy loss.
model = MidiTransformer(
    num_tokens=4,
    dim_model=8,
    num_heads=2,
    num_encoder_layers=3,
    num_decoder_layers=3,
    dropout_p=0.1,
)
model = model.to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()



In [None]:
def train_loop(model, opt, loss_fn, dataloader):
    """
    Train ``model`` for one pass over ``dataloader`` and return the mean batch loss.

    Each batch is indexed as ``batch[:, 0]`` (source) and ``batch[:, 1]`` (target).
    The target is shifted by one position so that the decoder input starts at
    <SOS> and the expected output is the following token. Relies on the
    notebook-level ``device`` variable.

    Args:
        model: Model exposing ``get_tgt_mask(size)`` and callable as
            ``model(src, tgt_input, tgt_mask)``.
        opt: Optimizer stepped once per batch.
        loss_fn: Loss applied to the permuted predictions and expected tokens.
        dataloader: Sized iterable of batches.

    Returns:
        Sum of the per-batch losses divided by ``len(dataloader)``.
    """
    model.train()
    total_loss = 0

    for batch in dataloader:
        X, y = batch[:, 0], batch[:, 1]
        # Create integer tensors directly on the target device, matching
        # validation_loop, rather than relying on the dtype of the incoming array.
        X = torch.tensor(X, dtype=torch.long, device=device)
        y = torch.tensor(y, dtype=torch.long, device=device)

        # Now we shift the tgt by one so with the <SOS> we predict the token at pos 1
        y_input = y[:, :-1]
        y_expected = y[:, 1:]

        # Get mask to mask out the next words
        sequence_length = y_input.size(1)
        tgt_mask = model.get_tgt_mask(sequence_length).to(device)

        # Standard training except we pass in y_input and tgt_mask
        pred = model(X, y_input, tgt_mask)

        # Permute pred to have batch size first again
        pred = pred.permute(1, 2, 0)
        loss = loss_fn(pred, y_expected)

        opt.zero_grad()
        loss.backward()
        opt.step()

        total_loss += loss.detach().item()

    return total_loss / len(dataloader)

In [None]:
def validation_loop(model, loss_fn, dataloader):
    """
    Evaluate ``model`` on every batch of ``dataloader`` with gradients disabled
    and return the sum of the batch losses divided by ``len(dataloader)``.
    Reads the notebook-level ``device`` variable.

    Method from "A detailed guide to Pytorch's nn.Transformer() module.", by
    Daniel Melchor: https://medium.com/@danielmelchor/a-detailed-guide-to-pytorchs-nn-transformer-module-c80afbc9ffb1
    """
    model.eval()
    running_loss = 0

    with torch.no_grad():
        for batch in dataloader:
            # Column 0 of the batch is the source, column 1 the target.
            src = torch.tensor(batch[:, 0], dtype=torch.long, device=device)
            tgt = torch.tensor(batch[:, 1], dtype=torch.long, device=device)

            # Shift the target by one: the decoder sees everything but the last
            # token and has to predict everything but the first (<SOS>).
            decoder_input, expected = tgt[:, :-1], tgt[:, 1:]

            # Mask that hides the upcoming target positions from the decoder.
            causal_mask = model.get_tgt_mask(decoder_input.size(1)).to(device)

            logits = model(src, decoder_input, causal_mask)

            # Reorder the prediction axes so the batch dimension comes first
            # again before computing the loss.
            batch_loss = loss_fn(logits.permute(1, 2, 0), expected)
            running_loss += batch_loss.detach().item()

    return running_loss / len(dataloader)

In [None]:
def fit(model, opt, loss_fn, train_dataloader, val_dataloader, epochs):
    """
    Alternate one training pass and one validation pass for ``epochs`` epochs,
    printing both losses after each epoch.

    Returns a tuple ``(train_losses, validation_losses)`` with one entry per epoch.

    Method from "A detailed guide to Pytorch's nn.Transformer() module.", by
    Daniel Melchor: https://medium.com/@danielmelchor/a-detailed-guide-to-pytorchs-nn-transformer-module-c80afbc9ffb1
    """
    # Per-epoch loss histories, kept for plotting later on.
    train_history = []
    validation_history = []
    rule = "-" * 25

    print("Training and validating model")
    for epoch in range(epochs):
        print(rule, f"Epoch {epoch + 1}", rule)

        epoch_train_loss = train_loop(model, opt, loss_fn, train_dataloader)
        train_history.append(epoch_train_loss)

        epoch_validation_loss = validation_loop(model, loss_fn, val_dataloader)
        validation_history.append(epoch_validation_loss)

        print(f"Training loss: {epoch_train_loss:.4f}")
        print(f"Validation loss: {epoch_validation_loss:.4f}")
        print()

    return train_history, validation_history
    
# Train the toy model for 10 epochs and keep both per-epoch loss histories.
train_loss_list, validation_loss_list = fit(
    model=model,
    opt=opt,
    loss_fn=loss_fn,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=10,
)
