In [None]:
%load_ext autoreload
%autoreload 2

import torch
import torchaudio
import numpy as np
from mido import MidiFile, MidiTrack, Message
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import librosa
import timeit

from data.Dataset import MidiDataset, DatasetUtils, MidiIterDataset, MidiTransformerDataset
import MidiUtils as mu
from data.Note import Note
from data.Song import Song
import PlotUtils
from models.Transformer import TransformerModel
from data.Dataloader import MidiDataloader

# Machine-specific locations: the maestro-v3.0.0 dataset checkout and the project workspace.
dataset_path = "/Users/andreas/Development/Midi-Conversion/maestro-v3.0.0"
workspace = "/Users/andreas/Development/Midi-Conversion/PianoTranscription"

# Cached total lengths of the train / val / test splits. Recomputing them is expensive,
# so the values are pinned here (the names suggest a discretization of 100 — confirm
# against the dataset code before reusing them with another setting).
TRAIN_SET_TOTAL_LENGTH_DISCRETIZED_100 = 57_412_301
VAL_SET_TOTAL_LENGTH_DISCRETIZED_100 = 7_009_869
TEST_SET_TOTAL_LENGTH_DISCRETIZED_100 = 7_214_840

In [None]:
# Load a single MP3 file and report its tensor shape, sample rate and duration.
audio_file = './output/Planet_Earth_II.mp3'
audio_tensor, file_sample_rate = torchaudio.load(audio_file, normalize=True)

# Dimension 1 of the loaded tensor is taken as the number of samples.
samples = audio_tensor.size(1)
print(f'audio_tensor length: {audio_tensor.shape}, sample_rate: {file_sample_rate}')
print(f"file_length seconds: {samples / file_sample_rate}")

In [None]:
# Build the 'train' split of the transformer dataset and keep a single iterator over it;
# the stacking/padding cell further down draws its samples from `data_iter`.
dataset = MidiTransformerDataset(
    dataset_path,
    'train',
    100,
    TRAIN_SET_TOTAL_LENGTH_DISCRETIZED_100,
    precomputed_midi=True,
)
data_iter = iter(dataset)

# Counters for the disabled scan below (480 is the item length the scan compares against).
length = 480
total_length = iteration_length = 0

# Disabled: walks the iterator while every item has length 480, accumulating the total
# length and printing it roughly every 10M elements and at the first deviating item.
# while length == 480:
#     audio, midi = next(data_iter)
#     length = len(audio)
#     total_length += length
#     iteration_length += length
#     if iteration_length > 10_000_000:
#         print(F'total_length: {total_length}')
#         iteration_length = 0
#     if length != 480:
#         print(F'total length: {total_length} length: {length}')


In [None]:
# Plot the audio waveform and the MIDI tensor of one recording, one above the other,
# with both x-axes limited to the same duration.
# The files are located through `dataset_path` (config cell) instead of repeating the
# absolute dataset location in every literal.
recording_stem = f"{dataset_path}/2018/MIDI-Unprocessed_Chamber2_MID--AUDIO_09_R3_2018_wav--1"
audio_path = f"{recording_stem}.wav"
midi_path = f"{recording_stem}.midi"

audio_tensor, file_sample_rate = torchaudio.load(audio_path, normalize=True)
midi = MidiFile(midi_path)

num_samples = audio_tensor.shape[1]
audio_length_s = num_samples / file_sample_rate
# Sub-second remainder of the audio length, expressed in milliseconds.
audio_length_ms = (num_samples % file_sample_rate) / file_sample_rate * 1000
print(f"Audio length samples: {num_samples}, sample rate: {file_sample_rate}, length in time: {int(audio_length_s)}s {audio_length_ms}ms")

midi_length = midi.length
print(f"Midi length: {midi_length}")
discretization = 100
# Transform the MIDI file into a tensor.
# NOTE(review): the transpose presumably puts time on the x-axis for imshow — confirm
# against DatasetUtils.transform_midi_file.
midi_tensor = DatasetUtils.transform_midi_file(midi_path, discretization).T

# Use the longer of the two durations so both panels cover the same span of time.
x_lim_seconds = max(audio_length_s, midi_length)
x_lim_midi = int(x_lim_seconds * discretization)
x_lim_audio = int(x_lim_seconds * file_sample_rate)

print(f'lim_x midi: {x_lim_midi}, audio: {x_lim_audio}')

fig = plt.figure(figsize=(10, 3))
# fig = plt.figure(figsize=(400, 20))

# Top panel: waveform of the first audio channel, indexed by sample.
ax1 = fig.add_subplot(2, 1, 1)
ax1.plot(audio_tensor[0].numpy())
ax1.set_xlim(0, x_lim_audio)

# Bottom panel: the MIDI tensor drawn as an image.
ax2 = fig.add_subplot(2, 1, 2)
ax2.set_xlim(0, x_lim_midi)
midi_tensor = midi_tensor.to('cpu')
ax2.imshow(midi_tensor, aspect='auto')

ax1.axis("off")
# Trailing semicolon keeps the returned axis limits out of the cell output.
ax2.axis("off");

In [None]:
# Draw 10 and then 15 consecutive items from `data_iter`, stack the first element of
# each item into one tensor per group, and pad the two groups to a common length.
audio_seq1 = torch.stack([next(data_iter)[0] for _ in range(10)])
audio_seq2 = torch.stack([next(data_iter)[0] for _ in range(15)])

padded_seq = torch.nn.utils.rnn.pad_sequence(
    [audio_seq1, audio_seq2],
    batch_first=True,
)
print(f"padded_seq: {padded_seq.shape}")

In [None]:
# Enumerate the start offsets of consecutive, non-overlapping windows of `seq_len`
# elements that fit completely inside `elems` elements.
seq_len = 10
elems = 30
sequence_starts = range(0, elems - seq_len + 1, seq_len)
for i in sequence_starts:
    print(f"i: {i}")

# `elems // seq_len` is the number of windows, not a start index. The last start
# offset is the final value of the range (None when no complete window fits).
last_start_index = sequence_starts[-1] if len(sequence_starts) > 0 else None
print(f"last sequence start index: {last_start_index}")

In [None]:
# Smoke test: run the first batch of the 'train' split through a small
# TransformerModel and print the tensor shapes involved.
dataset = MidiTransformerDataset(dataset_path, 'train', 100, TRAIN_SET_TOTAL_LENGTH_DISCRETIZED_100, precomputed_midi=True)
model = TransformerModel(ntoken=480, d_model=512, nhead=1, d_hid=512, nlayers=1, dropout=0.1)
dataloader = DataLoader(dataset, batch_size=3, shuffle=False)

# Only the first batch is inspected, hence the unconditional break.
for audio, midi, mask in dataloader:
    print(f"audio: {audio.shape}, midi: {midi.shape}")
    # Call the module instance rather than .forward() so that any registered
    # forward hooks run as part of the call.
    output = model(audio, mask)
    print(f"output: {output.shape}")
    break