# Audio-Midi Transcription

## Setup and Imports

In [None]:
%load_ext autoreload
%autoreload 2

import os
import time

import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import torch.utils.data as data_utils


import MidiUtils as mu
import PlotUtils as pu
from data.Dataset import MidiDataset, DatasetUtils, MidiIterDataset
from models.DNNs import MidiTranscriptionModel
from data.Note import Note
from data.Song import Song

# Machine-specific absolute paths: adjust these before running the notebook elsewhere.
# dataset_path is where get_data_loaders looks for train.txt / validation.txt / test.txt
# and is handed to the dataset classes.
dataset_path = "/Users/andreas/Development/Midi-Conversion/maestro-v3.0.0"
# workspace is the root under which the training loops create 'logs/' and 'out_models/'.
workspace = "/Users/andreas/Development/Midi-Conversion/PianoTranscription"

# Computing the total length of the dataset is expensive, so we cache it here.
# NOTE(review): the names indicate these were computed for a time discretization of 100;
# get_data_loaders passes them as total_length for any time_discretization value, so they
# are presumably stale for other discretizations - confirm before changing it.
TRAIN_SET_TOTAL_LENGTH_DISCRETIZED_100 = 57412301
VAL_SET_TOTAL_LENGTH_DISCRETIZED_100 = 7009869
TEST_SET_TOTAL_LENGTH_DISCRETIZED_100 = 7214840

In [None]:
def _resolve_subset_limit(size_spec, dataset_length):
    """Turn a ``(limit_type, limit)`` spec into a concrete number of items.

    Args:
        size_spec: Tuple ``(limit_type, limit)``. ``limit_type`` is ``'items'``
            (``limit`` is an absolute count) or ``'percentage'`` / ``'percent'``
            (``limit`` is a fraction such as ``0.1`` for 10 percent). Any other
            ``limit_type`` selects the whole dataset.
        dataset_length: Number of items in the dataset being limited.

    Returns:
        An ``int`` clamped to ``[1, dataset_length]`` for the known limit types,
        or ``dataset_length`` for an unknown limit type.
    """
    limit_type, limit = size_spec
    if limit_type == 'items':
        count = int(limit)
    elif limit_type in ('percentage', 'percent'):
        # The product is a float; truncate so it can be used as an index count.
        count = int(limit * dataset_length)
    else:
        return dataset_length
    # At least 1 item, at most all items
    return max(min(count, dataset_length), 1)


def get_data_loaders(batch_size=1, time_discretization=100, train_set_size=None, val_set_size=None, shuffle=True, iter_dataset=False, precomp=False):
    """Build the train, validation and test ``DataLoader`` objects.

    The split files (``train.txt``, ``validation.txt``, ``test.txt``) are created
    in ``dataset_path`` only when at least one of them is missing, because
    creating them is slow.

    Args:
        batch_size: Batch size used for all three loaders.
        time_discretization: Forwarded to the dataset constructors.
        train_set_size: ``None`` for the full training set, or a
            ``(limit_type, limit)`` tuple (see ``_resolve_subset_limit``).
            Only applied when ``iter_dataset`` is ``False``.
        val_set_size: Same as ``train_set_size`` but for the validation set.
        shuffle: Whether the loaders shuffle. Ignored for iterable-style
            datasets, for which ``DataLoader`` rejects the shuffle option.
        iter_dataset: Use ``MidiIterDataset`` instead of ``MidiDataset``.
        precomp: Forwarded as ``precomputed_midi`` to ``MidiIterDataset``.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.

    Note:
        The cached ``*_TOTAL_LENGTH_DISCRETIZED_100`` constants are passed as
        ``total_length`` whatever ``time_discretization`` is; their names
        suggest they only match a discretization of 100.
    """
    # Only create train, validation, test splits if they don't exist (unnecessarily slow)
    split_files = ('train.txt', 'test.txt', 'validation.txt')
    if not all(os.path.exists(os.path.join(dataset_path, name)) for name in split_files):
        DatasetUtils.create_dataset_files(dataset_path)

    if iter_dataset:
        train_set = MidiIterDataset(dataset_path, 'train', time_discretization, total_length=TRAIN_SET_TOTAL_LENGTH_DISCRETIZED_100, precomputed_midi=precomp)
        val_set = MidiIterDataset(dataset_path, 'validation', time_discretization, total_length=VAL_SET_TOTAL_LENGTH_DISCRETIZED_100, precomputed_midi=precomp)
        test_set = MidiIterDataset(dataset_path, 'test', time_discretization, total_length=TEST_SET_TOTAL_LENGTH_DISCRETIZED_100, precomputed_midi=precomp)
    else:
        train_set = MidiDataset(dataset_path, 'train', time_discretization)
        val_set = MidiDataset(dataset_path, 'validation', time_discretization)
        test_set = MidiDataset(dataset_path, 'test', time_discretization)

        # Optionally keep only the first N items of the train / validation sets.
        if train_set_size is not None:
            limit = _resolve_subset_limit(train_set_size, len(train_set))
            train_set = data_utils.Subset(train_set, torch.arange(limit))

        if val_set_size is not None:
            limit = _resolve_subset_limit(val_set_size, len(val_set))
            val_set = data_utils.Subset(val_set, torch.arange(limit))

    def make_loader(dataset):
        # DataLoader raises if a shuffle option is given for an IterableDataset,
        # so shuffling is only requested for map-style datasets.
        use_shuffle = shuffle and not isinstance(dataset, data_utils.IterableDataset)
        return DataLoader(dataset, batch_size=batch_size, shuffle=use_shuffle)

    return make_loader(train_set), make_loader(val_set), make_loader(test_set)


def create_tqdm_bar(iterable, desc):
    """Wrap ``iterable`` in a tqdm progress bar that yields ``(index, item)`` pairs.

    Args:
        iterable: A sized iterable; its ``len()`` is used as the bar total.
        desc: Text shown as the bar description.

    Returns:
        A ``tqdm`` object, 150 columns wide, iterating over ``enumerate(iterable)``.
    """
    indexed_items = enumerate(iterable)
    item_count = len(iterable)
    return tqdm(indexed_items, total=item_count, ncols=150, desc=desc)

In [None]:
# Choose the compute backend: CUDA when available, otherwise Apple's MPS
# backend when available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

print(device)

## Training

### Overfit on single song

In [None]:
def train_loop(model: torch.nn.Module, data_loaders, params, workspace, log_name, function=None):
    """Train ``model`` for ``params['epochs']`` epochs, validating after each epoch.

    Per-step losses are written to TensorBoard under
    ``<workspace>/logs/<log_name>/run_<k>``. Whenever the mean validation loss
    improves, the model state is saved to
    ``<workspace>/out_models/<log_name>/best_model.pt``.

    Args:
        model: Module providing ``training_step(batch)``, ``validation_step(batch)``
            (each returning a scalar loss) and ``save_state(path)``.
        data_loaders: Tuple ``(train_loader, val_loader, test_loader)``; the
            test loader is not used here.
        params: Dict that must contain ``'epochs'``, ``'device'``,
            ``'batch_size'`` and ``'learning_rate'``.
        workspace: Root directory for logs and saved models.
        log_name: Name of the run group; used for directories and scalar tags.
        function: Unused. Kept (now optional) for backward compatibility.
    """
    assert 'epochs' in params, 'Number of epochs not specified in params (\'epochs\')'
    assert 'device' in params, 'Device not specified in params (\'device\')'
    assert 'batch_size' in params, 'Batch size not specified in params (\'batch_size\')'
    assert 'learning_rate' in params, 'Learning rate not specified in params (\'learning_rate\')'

    model.to(params['device'])

    # Each call logs into a fresh 'run_<k>' directory, numbered after the existing entries.
    logger_path = os.path.join(workspace, 'logs', log_name)
    num_of_runs = len(os.listdir(logger_path)) if os.path.exists(logger_path) else 0
    logger = SummaryWriter(os.path.join(logger_path, f'run_{num_of_runs + 1}'))

    epochs = params['epochs']
    train_loader, val_loader, _ = data_loaders
    num_train_batches = len(train_loader)
    num_val_batches = len(val_loader)
    best_model_path = os.path.join(workspace, 'out_models', log_name, 'best_model.pt')

    # Throttle progress-bar / TensorBoard updates to one per 100 ms. Starting at
    # -inf guarantees that the very first training step is always reported.
    progress_interval_s = 0.1
    last_progress_time = float('-inf')
    val_loss = 0.0
    best_loss = float('inf')

    try:
        for epoch in range(epochs):
            # Validation below switches to eval mode, so re-enable training mode each epoch.
            model.train()

            # Create a progress bar for the training loop.
            training_loop = create_tqdm_bar(train_loader, desc=f'Training Epoch [{epoch + 1}/{epochs}]')

            train_loss = 0.0
            for train_iteration, batch in training_loop:
                # Actual training. float() detaches the loss so the running sum
                # does not keep autograd graphs alive.
                loss = float(model.training_step(batch))
                train_loss += loss
                running_loss = train_loss / (train_iteration + 1)

                if train_iteration % 1000 == 0:
                    print(f'Training Epoch [{epoch + 1}/{epochs}] - Train Loss: {running_loss}')

                # Progress indicator
                now = time.monotonic()
                if now - last_progress_time > progress_interval_s:
                    training_loop.set_postfix(curr_train_loss="{:.8f}".format(running_loss),
                                              val_loss="{:.8f}".format(val_loss), refresh=True)
                    last_progress_time = now
                    logger.add_scalar(f'{log_name}/train_loss', loss, epoch * num_train_batches + train_iteration)

            # max(..., 1) avoids a ZeroDivisionError on an empty loader.
            train_loss = train_loss / max(num_train_batches, 1)
            print(f'Training Epoch [{epoch + 1}/{epochs}] - Train Loss: {train_loss}')

            # Validation: eval mode and no gradient tracking.
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for val_iteration, batch in enumerate(val_loader):
                    loss = float(model.validation_step(batch))
                    val_loss += loss
                    logger.add_scalar(f'{log_name}/val_loss', loss,
                                      epoch * num_val_batches + val_iteration)
            # Report the mean so it is comparable to the mean training loss.
            val_loss = val_loss / max(num_val_batches, 1)

            if val_loss < best_loss:
                best_loss = val_loss
                # Create path if it doesn't exist
                os.makedirs(os.path.dirname(best_model_path), exist_ok=True)
                model.save_state(best_model_path)
    finally:
        # Flush pending events and release the event file even if training fails.
        logger.close()

In [None]:
# Configuration for the single-song overfitting experiment.
params = {
    'device': device,
    'batch_size': 1,
    # Size limits are (limit_type, limit) tuples understood by get_data_loaders:
    # ('items', n) keeps the first n items, ('percentage', f) keeps a fraction f.
    'train_set_size': ('items', 1), # ('percentage', 0.1) = 10 percent, None = all
    'val_set_size': ('items', 1), # ('percentage', 0.1) = 10 percent, None = all
    'discretization': 100,

    # Layer sizes; presumably read by MidiTranscriptionModel from params - TODO confirm
    'input_size': 480,
    'hidden_size_1': 720,
    'hidden_size_2': 1024,
    'hidden_size_3': 720,
    'hidden_size_4': 256,
    'output_size': 128,

    'learning_rate': 2e-3,
    'epochs': 3,
}

# shuffle=False keeps the loader order fixed.
loaders = get_data_loaders(params['batch_size'], params['discretization'], train_set_size=params['train_set_size'], val_set_size=params['val_set_size'], shuffle=False)
model = MidiTranscriptionModel(params=params)

# Logs go to <workspace>/logs/overfitting, the best model to <workspace>/out_models/overfitting.
train_loop(function=None, model=model, data_loaders=loaders, params=params, workspace=workspace, log_name='overfitting')

## Inference 

In [None]:
# Load a song: take the first item of the dataset behind the training loader
train_loader, _, _ = loaders
audio, midi = train_loader.dataset[0]
# Move both tensors to the selected device
audio = audio.to(device)
midi = midi.to(device)
# Run the model on the audio and plot the transposed prediction as an image
pred_midi = model.predict(audio)
pu.plot_tensor_as_image(pred_midi.T)

In [None]:
# Write the prediction to a MIDI file and play it back.
# NOTE(review): 100 presumably is the time discretization (it matches
# params['discretization']) - confirm against Song.start_time_tensor_to_midi.
Song.start_time_tensor_to_midi(pred_midi, 'output/predicted_midi.midi', 100, note_threshold=0.1)
mu.play_midi('output/predicted_midi.midi')

In [None]:
# Convert the ground-truth tensor back to a MIDI file and play it, for comparison
# with the prediction. The note threshold here is 0.5 (0.1 is used for the prediction).
Song.start_time_tensor_to_midi(midi, 'output/reconverted.midi', 100, note_threshold=0.5)
mu.play_midi('output/reconverted.midi')

In [None]:
def _next_cycled_batch(loader, iterator):
    """Return ``(batch, iterator)``, restarting ``loader`` once it is exhausted.

    ``iterator`` may be ``None`` to start a new pass. ``(None, None)`` is
    returned when ``loader`` yields no batches at all.
    """
    if iterator is not None:
        try:
            return next(iterator), iterator
        except StopIteration:
            pass  # exhausted: fall through and start a new pass
    iterator = iter(loader)
    try:
        return next(iterator), iterator
    except StopIteration:
        return None, None


def train_loop_iter(model: torch.nn.Module, data_loaders, params, workspace, log_name):
    """Train ``model`` with time-sliced validation interleaved into each epoch.

    Intended for very long epochs: instead of validating once per epoch, a
    validation window of about 2 seconds is run roughly every 10 seconds of
    training (and once right at the start). Losses are written to TensorBoard
    under ``<workspace>/logs/<log_name>/run_<k>``.

    Args:
        model: Module providing ``training_step(batch)`` and
            ``validation_step(batch)``, each returning a scalar loss.
        data_loaders: Tuple ``(train_loader, val_loader, test_loader)``; the
            test loader is not used here.
        params: Dict that must contain ``'epochs'``, ``'device'``,
            ``'batch_size'`` and ``'learning_rate'``.
        workspace: Root directory for logs.
        log_name: Name of the run group; used for directories and scalar tags.

    Note:
        No model checkpoint is saved by this loop (unlike ``train_loop``).
    """
    assert 'epochs' in params, 'Number of epochs not specified in params (\'epochs\')'
    assert 'device' in params, 'Device not specified in params (\'device\')'
    assert 'batch_size' in params, 'Batch size not specified in params (\'batch_size\')'
    assert 'learning_rate' in params, 'Learning rate not specified in params (\'learning_rate\')'

    model.to(params['device'])

    # Each call logs into a fresh 'run_<k>' directory, numbered after the existing entries.
    logger_path = os.path.join(workspace, 'logs', log_name)
    num_of_runs = len(os.listdir(logger_path)) if os.path.exists(logger_path) else 0
    logger = SummaryWriter(os.path.join(logger_path, f'run_{num_of_runs + 1}'))

    epochs = params['epochs']
    train_loader, val_loader, _ = data_loaders
    num_train_batches = len(train_loader)
    num_val_batches = len(val_loader)

    # Wall-clock budgets in seconds, measured with a monotonic clock.
    progress_interval_s = 0.2      # progress bar / TensorBoard train-loss updates
    validation_interval_s = 10.0   # time between the starts of two validation windows
    validation_budget_s = 2.0      # duration of a single validation window
    # -inf makes the first training step report progress and trigger validation.
    last_progress_time = float('-inf')
    last_validation_time = float('-inf')
    val_loss = 0.0
    # Kept across validation windows so successive windows see different batches
    # instead of always re-reading the first one.
    val_iterator = None

    try:
        for epoch in range(epochs):
            model.train()

            # Create a progress bar for the training loop.
            training_loop = create_tqdm_bar(train_loader, desc=f'Training Epoch [{epoch + 1}/{epochs}]')

            train_loss = 0.0
            val_iteration = 0
            for train_iteration, batch in training_loop:
                # float() detaches the loss so the running sum does not keep
                # autograd graphs alive.
                loss = float(model.training_step(batch))
                train_loss += loss

                # Progress indicator
                now = time.monotonic()
                if now - last_progress_time > progress_interval_s:
                    training_loop.set_postfix(curr_train_loss="{:.8f}".format(
                        train_loss / (train_iteration + 1)), val_loss="{:.8f}".format(val_loss), refresh=True)
                    last_progress_time = now
                    logger.add_scalar(f'{log_name}/train_loss', loss, epoch * num_train_batches + train_iteration)

                window_start = time.monotonic()
                if window_start - last_validation_time > validation_interval_s:
                    last_validation_time = window_start
                    # Validation: eval mode and no gradient tracking.
                    window_total = 0.0
                    window_batches = 0
                    model.eval()
                    with torch.no_grad():
                        while time.monotonic() - window_start < validation_budget_s:
                            # Get next batch from validation loader
                            val_batch, val_iterator = _next_cycled_batch(val_loader, val_iterator)
                            if val_batch is None:
                                break  # validation loader is empty

                            # Actual validation
                            loss = float(model.validation_step(val_batch))
                            window_total += loss
                            window_batches += 1
                            logger.add_scalar(f'{log_name}/val_loss', loss,
                                              epoch * num_val_batches + val_iteration)
                            val_iteration += 1
                    model.train()

                    # Mean over the window: the number of batches that fit into
                    # the time budget varies from window to window.
                    if window_batches:
                        val_loss = window_total / window_batches

            # max(..., 1) avoids a ZeroDivisionError on an empty loader.
            train_loss = train_loss / max(num_train_batches, 1)
            print(f'Training Epoch [{epoch + 1}/{epochs}] - Train Loss: {train_loss}')
    finally:
        # Flush pending events and release the event file even if training fails.
        logger.close()

        

In [None]:
# Configuration for training on the full dataset through the iterable dataset.
params = {
    'device': device,
    'batch_size': 4096*2,
    # NOTE: get_data_loaders only applies the size limits when iter_dataset=False,
    # so both entries below have no effect for this run.
    'train_set_size': None, # ('percentage', 0.1) = 10 percent, None = all
    'val_set_size': None, # ('items', 1) = 1 elem, ('percentage', 0.1) = 10 percent, None = all
    'discretization': 100,

    # Layer sizes; presumably read by MidiTranscriptionModel from params - TODO confirm
    'input_size': 480,
    'hidden_size_1': 1024,
    'hidden_size_2': 2048,
    'hidden_size_3': 1440,
    'hidden_size_4': 512,
    'output_size': 128,

    'learning_rate': 3e-3,
    'epochs': 1,
}

# iter_dataset=True selects MidiIterDataset; precomp=True is passed on as precomputed_midi.
loaders = get_data_loaders(params['batch_size'], params['discretization'], train_set_size=params['train_set_size'], val_set_size=params['val_set_size'], shuffle=False, iter_dataset=True, precomp=True)
model = MidiTranscriptionModel(params=params)

# Logs go to <workspace>/logs/iter_full.
train_loop_iter(model=model, data_loaders=loaders, params=params, workspace=workspace, log_name='iter_full')

In [None]:
# Qualitative check of the trained model on one training batch.
train_loader, _, _ = loaders
# Get the first batch from a fresh iterator over the training loader
audio, midi = next(iter(train_loader))
audio = audio.to(device)
# Predict the midi and plot the transposed prediction as an image
pred_midi = model.predict(audio)
pu.plot_tensor_as_image(pred_midi.T)
# Write the prediction to a MIDI file and play it
# (100 presumably is the time discretization - TODO confirm)
Song.start_time_tensor_to_midi(pred_midi, 'output/predicted_midi.midi', 100, note_threshold=0.1)
mu.play_midi('output/predicted_midi.midi')

## Test overfitting single chunk

In [None]:
# Configuration for overfitting on a single batch ("chunk") of the iterable dataset.
params = {
    'device': device,
    'batch_size': 20000,
    # NOTE: get_data_loaders only applies the size limits when iter_dataset=False,
    # so both entries below have no effect for this run.
    'train_set_size': None, # ('percentage', 0.1) = 10 percent, None = all
    'val_set_size': None, # ('items', 1) = 1 elem, ('percentage', 0.1) = 10 percent, None = all
    'discretization': 100,

    # Layer sizes; presumably read by MidiTranscriptionModel from params - TODO confirm
    'input_size': 480,
    'hidden_size_1': 1024,
    'hidden_size_2': 2048,
    'hidden_size_3': 1440,
    'hidden_size_4': 512,
    'output_size': 128,

    'learning_rate': 6e-3,
    'epochs': 1,
}

loaders = get_data_loaders(params['batch_size'], params['discretization'], train_set_size=params['train_set_size'], val_set_size=params['val_set_size'], shuffle=False, iter_dataset=True, precomp=True)
train_loader, val_loader, _ = loaders
model = MidiTranscriptionModel(params=params)

# Pull batches from the training loader; with range(1) only the first batch is
# taken. Raising the range keeps the last batch drawn.
train_iter = iter(train_loader)
for i in range(1):
    audio, midi = next(train_iter)

# Move the selected batch to the compute device once; it is reused for every step.
audio = audio.to(device)
midi = midi.to(device)

print(audio.shape)


In [None]:
# Repeatedly train on the same (audio, midi) batch to check that the model can overfit it.
model.to(device)
model.train()
for i in range(500):
    train_loss = model.training_step((audio, midi))
    # Report the loss every 25 steps
    if i % 25 == 0:
        print(f'it: {i} loss: {train_loss}')

# Loss of the final training step
print(train_loss)

In [None]:
# Plot the ground truth first, then the prediction for the same batch, so the
# two images can be compared visually.
pu.plot_tensor_as_image(midi.T, figure_shape=(16, 4))
pred_midi = model.predict(audio)
pu.plot_tensor_as_image(pred_midi.T, figure_shape=(16, 4))

In [None]:
# Write the prediction to a MIDI file and play it
# (100 presumably is the time discretization - TODO confirm)
Song.start_time_tensor_to_midi(pred_midi, 'output/predicted_midi.midi', 100, note_threshold=0.5)
mu.play_midi('output/predicted_midi.midi')