In [None]:
%load_ext autoreload
%autoreload 2

import os
import timeit

import torch
import torchaudio
import numpy as np
from mido import MidiFile, MidiTrack, Message
import matplotlib.pyplot as plt
import librosa
import librosa.display

from data.Dataset import MidiDataset, DatasetUtils, MidiIterDataset
import MidiUtils as mu
from data.Note import Note
from data.Song import Song
import PlotUtils

# Dataset and workspace locations. The defaults are the original local paths;
# set MAESTRO_DATASET_PATH / PIANO_TRANSCRIPTION_WORKSPACE in the environment
# to run the notebook on another machine without editing this cell.
dataset_path = os.environ.get(
    "MAESTRO_DATASET_PATH",
    "/Users/andreas/Development/Midi-Conversion/maestro-v3.0.0",
)
workspace = os.environ.get(
    "PIANO_TRANSCRIPTION_WORKSPACE",
    "/Users/andreas/Development/Midi-Conversion/PianoTranscription",
)

# Computing the total length of the dataset is expensive, so we cache it here
# (one value per split).
# NOTE(review): presumably computed with a discretization of 100, going by the
# constant names -- recompute if the discretization changes.
TRAIN_SET_TOTAL_LENGTH_DISCRETIZED_100 = 57412301
VAL_SET_TOTAL_LENGTH_DISCRETIZED_100 = 7009869
TEST_SET_TOTAL_LENGTH_DISCRETIZED_100 = 7214840

### Precompute dataset path files and MIDI files

In [None]:
# Discretization value handed to the MidiDataset constructors further down.
discretization = 100

# Build the dataset files; the dataset directory is passed for both arguments.
DatasetUtils.create_dataset_files(dataset_path, dataset_path)

In [None]:
# # Precompute MIDI dataset (disabled: uncomment this cell to run the preprocessing)
# # Read train, val, test split paths
# with open(f"{dataset_path}/train.txt", "r") as f:
#     train_midi_paths = f.read().splitlines()[1::2]
# with open(f"{dataset_path}/validation.txt", "r") as f:
#     val_midi_paths = f.read().splitlines()[1::2]
# with open(f"{dataset_path}/test.txt", "r") as f:
#     test_midi_paths = f.read().splitlines()[1::2]

# paths = train_midi_paths + val_midi_paths + test_midi_paths

# DatasetUtils.preprocess_midi_dataset(paths, discretization=discretization)

In [None]:
# One MidiDataset per split, constructed in the order train, validation, test.
train_set, val_set, test_set = (
    MidiDataset(dataset_path, split_name, discretization=discretization)
    for split_name in ("train", "validation", "test")
)

# Use the first training example as the running example for the cells below.
audio_path = train_set.get_audio_path(0)
midi_path = train_set.get_midi_path(0)

metadata = torchaudio.info(audio_path)
print("Train file 0: ", metadata)



In [None]:
waveform, sample_rate = torchaudio.load(audio_path)
PlotUtils.print_stats(waveform, sample_rate=sample_rate)

# Plot only the excerpt between 2.5 s and 7.5 s; playback uses the full waveform.
excerpt_start = int(2.5*sample_rate)
excerpt_end = int(7.5*sample_rate)
excerpt = waveform[:, excerpt_start:excerpt_end]

PlotUtils.plot_waveform(excerpt, sample_rate)
PlotUtils.plot_specgram(excerpt, sample_rate)
PlotUtils.play_audio(waveform, sample_rate)

In [None]:
# Parse the MIDI file at midi_path and report some basic facts about it.
midi = MidiFile(midi_path, clip=True)

file_summary = 'Filename: {}, length: {}'.format(midi.filename, midi.length)
track_summary = "Number of tracks: {}".format(len(midi.tracks))
print(file_summary)
print(track_summary)

mu.print_midi_info(midi_path)

# Convert the parsed file into the project's note representation.
notes = Note.midi_to_notes(midi)

In [None]:
# Tempo (microseconds per beat) used when track 0 carries no set_tempo message;
# 500000 us/beat corresponds to 120 BPM, the MIDI default.
DEFAULT_TEMPO = 500000

# Take the tempo from the first set_tempo message in track 0, if there is one.
tempo = next(
    (msg.tempo for msg in midi.tracks[0] if msg.type == 'set_tempo'),
    DEFAULT_TEMPO,
)

song = Song(notes, midi.length, ticks_per_beat=midi.ticks_per_beat, tempo=tempo)
print("song: ", song)
midi_tensor = song.to_start_time_tensor(discretization_step=100)

In [None]:
# Print occurring notes: take the maximum over dim 0 and list the indices of
# the entries that are non-zero.
# The indices get their own name so that `notes` (the value returned by
# Note.midi_to_notes) is not overwritten with a tensor; overwriting it would
# break re-running the cell that builds Song(notes, ...).
occuring_notes = torch.max(midi_tensor, dim=0).values
occurring_note_indices = torch.nonzero(occuring_notes).flatten()
print("Occuring notes: ", occurring_note_indices)



In [None]:
  
# Largest value along dim 1 for every index of dim 0.
anyNotePlayed = midi_tensor.max(dim=1).values

# Get times at which any note is played
times = anyNotePlayed.nonzero().flatten()
print("Times at which any note is played: ", times)

In [None]:
# Show the transposed tensor as an image.
# NOTE(review): presumably midi_tensor is 2-D (time steps x notes), so the
# transpose puts time on the horizontal axis -- confirm against Song.
PlotUtils.plot_tensor_as_image(midi_tensor.T)

In [None]:
# Constant-Q transform of the recording at audio_path, drawn on a dB scale.
# NOTE: `librosa.display` is a submodule. On librosa releases without lazy
# submodule loading it is only available after an explicit
# `import librosa.display` (done in the notebook's import cell).
y, sr = librosa.load(audio_path)
C = np.abs(librosa.cqt(y, sr=sr))

# Create the figure at its final size (100 x 6 inches) instead of resizing it
# after drawing.
fig, ax = plt.subplots(figsize=(100, 6))
img = librosa.display.specshow(librosa.amplitude_to_db(C, ref=np.max),
                               sr=sr, x_axis='time', y_axis='cqt_note', ax=ax)
ax.set_title('Constant-Q power spectrum')
# Trailing semicolon keeps the Colorbar repr out of the cell output.
fig.colorbar(img, ax=ax, format="%+2.0f dB");

In [None]:
# Rebuild the objects the timing cells below operate on, starting from the
# first training example. A fixed tempo of 500000 is used here rather than one
# read from the file.
train_set = MidiDataset(dataset_path, "train", discretization=100)
midi_path = train_set.get_midi_path(0)
midi = MidiFile(midi_path, clip=True)
notes = Note.midi_to_notes(midi)
song = Song(
    notes,
    midi.length,
    tempo=500_000,
    ticks_per_beat=midi.ticks_per_beat,
)

In [None]:
%%timeit

# Timed statement: constructing a MidiFile from midi_path.
midi = MidiFile(midi_path, clip=True)

In [None]:
%%timeit

# Timed statement: converting the parsed MIDI file with Note.midi_to_notes.
notes = Note.midi_to_notes(midi)

In [None]:
%%timeit

# Timed statement: constructing a Song from the notes, with a fixed tempo of 500_000.
song = Song(notes, midi.length, ticks_per_beat=midi.ticks_per_beat, tempo=500_000)

In [None]:
# Untimed call: binds midi_tensor to the result of the "faster" conversion.
midi_tensor = song.to_start_time_tensor_faster(discretization_step=100)

In [None]:
%%timeit

# Timed statement: Song.to_start_time_tensor_faster, for comparison with
# Song.to_start_time_tensor.
midi_tensor = song.to_start_time_tensor_faster(discretization_step=100)

In [None]:
%%timeit

# Timed statement: Song.to_start_time_tensor, the baseline for the "faster" variant.
midi_tensor = song.to_start_time_tensor(discretization_step=100)

In [None]:
# Broadcasting demo: comparing a length-5 vector against a (2, 1) column
# yields a (2, 5) boolean mask, one row per entry of `comp`.
basis = torch.arange(1, 6)
comp = torch.tensor([2, 3])[:, None]

print(f'basis shape {basis.shape}, comp shape {comp.shape}')

torch.gt(basis, comp)

## Profile Code

In [None]:
import cProfile

# Profile a single call of the "faster" tensor conversion; cProfile.run
# executes the statement string and prints the profile report.
profiled_statement = 'song.to_start_time_tensor_faster(discretization_step=100)'
cProfile.run(profiled_statement)