In [None]:
#installations
!pip install torchaudio
!sudo apt install -y fluidsynth
!pip install --upgrade pyfluidsynth
!pip install pretty_midi

In [None]:
#imports
from itertools import product
import os
import music21
import pretty_midi
from music21 import midi
import pandas as pd
from IPython.display import Audio
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm
import statistics
import math, random, time
from IPython.display import Audio
from google.colab import drive
from google.colab import files
import torchaudio
from collections import Counter
import numpy as np
import torch
import torch.nn as nn
import scipy
import shutil

In [None]:
# Mount Google Drive so the MAESTRO data stored under MyDrive is readable from this runtime.
from google.colab import drive

drive.mount('/content/drive')

# Dataset root; the same folder MaestroDataset uses as its default `directory`.
data_directory = '/content/drive/MyDrive/maestro-v2.0.0-csv/'

In [None]:
# generate token dict
def generate_token_dict(time_step=0.01, max_shift_secs=1.0, num_velocity_bins=32):
    """Build the event-token vocabulary and its inverse.

    Token ids are assigned contiguously, in this order:
      * ``time_shift_1`` .. ``time_shift_N`` with N = round(max_shift_secs / time_step)
      * ``velocity_0`` .. ``velocity_{num_velocity_bins - 1}`` (velocity bin indices)
      * ``note_on_pitch_0`` .. ``note_on_pitch_127``
      * ``note_off_pitch_0`` .. ``note_off_pitch_127``
      * ``<bos>``, ``<eos>``

    Args:
        time_step: quantisation step in seconds.
        max_shift_secs: largest time shift represented by a single token index.
        num_velocity_bins: number of velocity tokens.

    Returns:
        (token_dict, inverse_token_dict): name -> id and id -> name.
    """
    # round() instead of int(): float division can land just below an integer
    # (0.3 / 0.1 == 2.9999999999999996) and int() would silently drop the last bin.
    num_time_bins = int(round(max_shift_secs / time_step))

    names = [f"time_shift_{t}" for t in range(1, num_time_bins + 1)]
    names += [f"velocity_{b}" for b in range(num_velocity_bins)]
    names += [f"note_on_pitch_{p}" for p in range(128)]
    names += [f"note_off_pitch_{p}" for p in range(128)]
    names += ["<bos>", "<eos>"]

    token_dict = {name: idx for idx, name in enumerate(names)}
    inverse_token_dict = {idx: name for name, idx in token_dict.items()}
    return token_dict, inverse_token_dict

In [None]:
import os
import shutil
from tqdm import tqdm
import pretty_midi
import numpy as np
import torch
from torch.utils.data import Dataset

class MaestroDataset(Dataset):
    """One split of MAESTRO v2.0.0, tokenized into fixed-length event windows.

    Metadata is read from ``<directory>/maestro-v2.0.0.csv`` (rows filtered to ``split``);
    MIDI files are read from ``<directory>/<split>/``.
    NOTE(review): no ``__getitem__`` is defined here; the class is used as a tokenizer.
    """

    def __init__(self, split, token_dict=None, directory='/content/drive/MyDrive/maestro-v2.0.0-csv/'):
        import pandas as pd

        self.split = split              # 'train' / 'test' / 'validation'
        self.token_dict = token_dict    # token name -> id (as built by generate_token_dict)
        self.directory = directory.rstrip('/') + '/'
        # keep only the metadata rows that belong to this split
        df = pd.read_csv(os.path.join(self.directory, 'maestro-v2.0.0.csv'))
        self.df = df[df['split'] == split].reset_index(drop=True)

    def __len__(self):
        """Number of metadata rows (performances) in this split."""
        return len(self.df)

    def create_token_sequences(self, pm: "pretty_midi.PrettyMIDI", time_step: float = 0.01,
                               token_dict: dict = None, velocity_bins: np.ndarray = None,
                               fs: int = 16, seq_len: int = 500, add_bos_eos: bool = True,
                               max_shift_secs: float = 1.0):
        """Encode one performance as event tokens and cut it into ``seq_len``-token windows.

        Every note of every instrument yields a velocity token + note-on token at its start
        and a note-off token at its end. Elapsed time between events is written as one
        ``time_shift_*`` token per ``time_step``; the k-th token of a gap is
        ``time_shift_k``, capped at ``round(max_shift_secs / time_step)``.

        Args:
            pm: parsed MIDI; only ``pm.instruments[*].notes`` (start, end, pitch, velocity) are read.
            time_step: quantisation step in seconds.
            token_dict: vocabulary; defaults to ``self.token_dict``.
            velocity_bins: velocity bin centres; defaults to an even grid over 0..127 with one
                bin per ``velocity_*`` token in the vocabulary.
            fs: unused; kept for call compatibility.
            seq_len: window length in tokens. Windows do not overlap and a trailing partial
                window is dropped (so ``<eos>`` only survives if it lands in a full window).
            add_bos_eos: wrap the whole token stream in ``<bos>`` / ``<eos>``.
            max_shift_secs: cap for the time-shift token index; should match the vocabulary.

        Returns:
            list of token-id lists, each exactly ``seq_len`` long.

        Raises:
            ValueError: if no token_dict is given and ``self.token_dict`` is None.
        """
        if token_dict is None:
            token_dict = self.token_dict
        if token_dict is None:
            raise ValueError("create_token_sequences needs a token_dict (argument or self.token_dict)")
        if velocity_bins is None:
            n_velocity_tokens = sum(1 for name in token_dict if name.startswith('velocity_'))
            velocity_bins = np.linspace(0, 127, n_velocity_tokens)
        velocity_bins = np.asarray(velocity_bins)
        max_bin = int(round(max_shift_secs / time_step))

        events = [(kind, t, note.pitch, note.velocity)
                  for inst in pm.instruments
                  for note in inst.notes
                  for kind, t in (('on', note.start), ('off', note.end))]
        # Time order; at equal times 'off' < 'on', so a re-struck pitch is released before it
        # is played again regardless of the order notes are stored in.
        events.sort(key=lambda e: (e[1], e[0], e[2]))

        seq = [token_dict['<bos>']] if add_bos_eos else []
        prev_time = 0.0
        for kind, t, pitch, velocity in events:
            num_steps = int(round((t - prev_time) / time_step))
            seq.extend(token_dict[f"time_shift_{min(k, max_bin)}"] for k in range(1, num_steps + 1))
            # track the time actually emitted, not the raw event time
            prev_time += num_steps * time_step
            if kind == 'on':
                vel_bin = int(np.argmin(np.abs(velocity_bins - velocity)))
                seq.append(token_dict[f'velocity_{vel_bin}'])
                seq.append(token_dict[f'note_on_pitch_{pitch}'])
            else:
                seq.append(token_dict[f'note_off_pitch_{pitch}'])
        if add_bos_eos:
            seq.append(token_dict['<eos>'])

        return [seq[start:start + seq_len] for start in range(0, len(seq) - seq_len + 1, seq_len)]

    def create_all_token_sequences(self, fs: int = 16, seq_len: int = 500):
        """Tokenize every MIDI file in ``<directory>/<split>/`` and add transposed copies.

        Returns:
            list of dicts with keys 'title', 'sequence', 'composer' and 'transposition amount'
            (0 for the original window, otherwise the semitone shift from augment_idxs).
            Files without a matching CSV row, or that fail to parse, are skipped with a message.
        """
        split_dir = os.path.join(self.directory, self.split)

        # basename -> (composer, title); the first matching CSV row wins
        metadata = {}
        for midi_name, composer, title in zip(self.df['midi_filename'],
                                              self.df['canonical_composer'],
                                              self.df['canonical_title']):
            metadata.setdefault(os.path.basename(midi_name), (composer, title))
        velocity_bins = np.linspace(0, 127, 32)

        all_seqs = []
        # sorted() makes the output order independent of the filesystem's listing order
        for file_name in tqdm(sorted(os.listdir(split_dir))):
            if not file_name.lower().endswith(('.mid', '.midi')):
                continue
            path = os.path.join(split_dir, file_name)
            meta = metadata.get(os.path.basename(file_name))
            if meta is None:
                print(f"Warning: MIDI file not found in CSV: {path}")
                continue
            try:
                pm = pretty_midi.PrettyMIDI(path)
                pm.remove_invalid_notes()
            except Exception as e:
                print(f"Skipping {file_name}: {e}")
                continue
            composer_label, title_label = meta

            seqs = self.create_token_sequences(pm, fs=fs, seq_len=seq_len,
                                               token_dict=self.token_dict,
                                               velocity_bins=velocity_bins)
            for s in seqs:
                all_seqs.append({'title': title_label, 'sequence': s, 'composer': composer_label, 'transposition amount': 0})
                for pshift, aug in self.augment_idxs(s):
                    all_seqs.append({'title': title_label, 'sequence': aug, 'composer': composer_label, 'transposition amount': pshift})
        return all_seqs

    def augment_idxs(self, seq, shifts=None):
        """Return ``(shift, transposed_seq)`` pairs for pitch-transposition augmentation.

        Only note-on / note-off tokens (looked up through ``self.token_dict``) are moved;
        time-shift, velocity and special tokens are copied unchanged. Shifted pitches are
        clamped to 0..127.

        Args:
            seq: list of token ids.
            shifts: semitone shifts to apply. Defaults to -12..+12 without 0, because
                create_all_token_sequences already stores the untransposed window.
        """
        if shifts is None:
            shifts = [s for s in range(-12, 13) if s != 0]
        vocab = self.token_dict
        # token id -> (token-name prefix, pitch) for every note-on / note-off token
        pitch_tokens = {}
        for p in range(128):
            pitch_tokens[vocab[f'note_on_pitch_{p}']] = ('note_on_pitch_', p)
            pitch_tokens[vocab[f'note_off_pitch_{p}']] = ('note_off_pitch_', p)

        augmented_shift = []
        for shift in shifts:
            out = []
            for idx in seq:
                hit = pitch_tokens.get(idx)
                if hit is None:
                    out.append(idx)
                else:
                    prefix, pitch = hit
                    out.append(vocab[f'{prefix}{max(0, min(127, pitch + shift))}'])
            augmented_shift.append((shift, out))
        return augmented_shift

In [None]:
# Tokenize every MAESTRO split into one CSV of fixed-length token windows. Each row carries
# title, composer, transposition amount, split and the window as space-separated token ids.
token_dict, inverse_token_dict = generate_token_dict()
csv_rows = []

for phase in ['train', 'test', 'validation']:
    ds = MaestroDataset(phase, token_dict,
                        directory='/content/drive/MyDrive/maestro-v2.0.0-csv/')
    seqs = ds.create_all_token_sequences(fs=16, seq_len=500)
    for seq in seqs:
        seq['split'] = phase
        # store ids as a space-separated string so each CSV cell is a plain scalar
        seq['sequence'] = ' '.join(map(str, seq['sequence']))
    csv_rows.extend(seqs)

df = pd.DataFrame(csv_rows)
dir_csv = '/content/drive/MyDrive/maestro-v2.0.0-csv/'
# The CSV goes to the current working directory (not dir_csv); the re-split cell reads it
# back by this relative name, so report the real location instead of dir_csv.
token_csv_path = 'maestro_token_sequences_bigger_vocab.csv'
df.to_csv(token_csv_path, index=False)
print(f'Wrote {len(df)} rows to {os.path.abspath(token_csv_path)}')

In [None]:
import pandas as pd
from tqdm import tqdm
from sklearn.model_selection import KFold, train_test_split

# Token-window CSV written by the tokenization cell; every re-split below starts from it.
input_csv = "maestro_token_sequences_bigger_vocab.csv"

def prepare_data(file, augmentations=False, collaborations=False):
    """Load the token-window CSV, filter rows and attach integer composer labels.

    Args:
        file: path to a CSV with at least 'sequence' (space-separated token ids),
            'composer' and 'transposition amount' columns.
        augmentations: keep transposed copies (non-zero 'transposition amount') when True.
        collaborations: keep rows whose composer field contains '/' when True.

    Returns:
        (data, composer_label_dict, index_composer_dict): ``data`` has columns
        'sequence_vector' (list of int), 'composer' and 'label'; labels follow the sorted
        order of composer names.

    NOTE(review): with augmentations=True each window also appears transposed; a row-wise
    random split of ``data`` can then put transposed copies of one window in different splits.
    """
    df = pd.read_csv(file)

    # Filter first so only the surviving rows pay for string -> list parsing.
    if not collaborations:
        df = df[~df['composer'].str.contains('/', na=False)]
    if not augmentations:
        df = df[df['transposition amount'] == 0]

    clean_df = df.copy()  # independent frame; avoids chained-assignment warnings below
    clean_df['sequence_vector'] = clean_df['sequence'].apply(string_to_vector)

    composer_list = sorted(set(clean_df['composer']))  # sorted => stable label ids across runs
    num_composers = len(composer_list)
    print(f"Unique composers: {composer_list}")
    print(f"Total composers: {num_composers}")

    composer_label_dict = {composer: idx for idx, composer in enumerate(composer_list)}
    index_composer_dict = {idx: composer for composer, idx in composer_label_dict.items()}
    clean_df['label'] = clean_df['composer'].map(composer_label_dict)

    data = clean_df[['sequence_vector', 'composer', 'label']].copy()
    return data, composer_label_dict, index_composer_dict

def string_to_vector(seq_str):
    """Parse a whitespace-separated string of integers (e.g. '3 14 15') into a list of ints."""
    return list(map(int, seq_str.split()))

def cross_validation_splits(data, labels, max_folds=5, random_state=None):
    """Precompute KFold (train_idx, test_idx) pairs for k = 2..max_folds, unshuffled and shuffled.

    Returns:
        dict keyed '{k}-fold' / '{k}-fold-shuffled' -> list of (train_idx, test_idx) pairs.

    ``random_state`` is applied only to the shuffled variants: scikit-learn's KFold raises
    ValueError when a random_state is combined with shuffle=False.
    """
    splits = {}
    for k in range(2, max_folds + 1):
        for shuffle in (False, True):
            kf = KFold(n_splits=k, shuffle=shuffle,
                       random_state=random_state if shuffle else None)
            key = f'{k}-fold' + ('-shuffled' if shuffle else '')
            # labels are passed for API symmetry; plain KFold does not use them
            splits[key] = list(kf.split(data, labels))
    return splits

# 60-30-10 train / test / validation split
def split_data(features, labels, random_state=42):
    """Stratified 60 / 30 / 10 train / test / validation split.

    40% of the rows are held out first; that hold-out is then split 75 / 25 into test and
    validation, i.e. 30% and 10% of the total. Both stages stratify on the labels.

    Args:
        features: samples to split.
        labels: class labels aligned with ``features``.
        random_state: seed for both train_test_split calls (default keeps the original 42).

    Returns:
        X_train, X_test, y_train, y_test, X_val, y_val  (test comes before validation).
    """
    X_train, X_temp, y_train, y_temp = train_test_split(
        features,
        labels,
        test_size=0.4,
        stratify=labels,
        random_state=random_state,
    )
    X_test, X_val, y_test, y_val = train_test_split(
        X_temp,
        y_temp,
        test_size=0.25,
        stratify=y_temp,
        random_state=random_state,
    )
    return X_train, X_test, y_train, y_test, X_val, y_val

# k-fold cross validation: turn index pairs into concrete per-fold data
def build_fold_data(data, labels, splits):
    """Materialise the (train_idx, test_idx) pairs in ``splits`` into per-fold data.

    Returns a dict with the same keys as ``splits``; each value is a list (one entry per
    fold) of dicts with keys 'training data', 'test data', 'training labels' and
    'test labels', each selected positionally with ``.iloc`` and re-indexed from 0.
    """
    def _take(obj, positions):
        return obj.iloc[positions].reset_index(drop=True)

    return {
        key: [
            {
                'training data': _take(data, train_idx),
                'test data': _take(data, test_idx),
                'training labels': _take(labels, train_idx),
                'test labels': _take(labels, test_idx),
            }
            for train_idx, test_idx in idx_vals
        ]
        for key, idx_vals in splits.items()
    }

def export_all_folds(fold_data, output_path='all_folds.csv'):
    """Write every fold of every split scheme in ``fold_data`` to one CSV at ``output_path``.

    ``fold_data`` has the shape produced by build_fold_data. Each output row is one sample
    with extra columns 'label', 'set' ('train' / 'test'), 'split' (scheme key such as
    '5-fold') and 'fold' (1-based fold number, so folds of one scheme stay distinguishable).
    Compression is inferred from the path suffix (e.g. '.gz' writes gzip).

    Raises:
        ValueError: if ``fold_data`` contains no folds.
    """
    all_rows = []
    for split, folds in fold_data.items():
        for fold_no, fold in enumerate(folds, start=1):
            for set_name, data_key, label_key in (('train', 'training data', 'training labels'),
                                                  ('test', 'test data', 'test labels')):
                # build_fold_data stores Series; DataFrame(...) accepts a Series or a DataFrame
                # (item assignment on a Series would not add a column).
                part = pd.DataFrame(fold[data_key]).copy()
                part['label'] = pd.Series(fold[label_key]).to_numpy()
                part['set'] = set_name
                part['split'] = split
                part['fold'] = fold_no
                all_rows.append(part)

    if not all_rows:
        raise ValueError("fold_data contains no folds to export")
    whole_df = pd.concat(all_rows, ignore_index=True)
    whole_df.to_csv(output_path, index=False, compression='infer')
    print(f"Wrote combined CSV to {output_path}")

# Load the token CSV (keeping transposed copies) and make a stratified 60/30/10 split.
# NOTE(review): the split is row-wise, so transposed copies of one window can land in both
# train and test; consider a grouped split if that leakage matters for evaluation.
data, composer_label_dict, index_composer_dict = prepare_data(input_csv, augmentations=True)
prepared_data, prepared_labels = data['sequence_vector'], data['label']

# For k-fold cross-validation instead, run:
#   splits = cross_validation_splits(prepared_data, prepared_labels, max_folds=10)
#   all_fold_data = build_fold_data(prepared_data, prepared_labels, splits)
#   export_all_folds(all_fold_data)

(data_train, data_test,
 labels_train, labels_test,
 data_val, labels_val) = split_data(prepared_data, prepared_labels)

def make_df(Data, labels, split_name):
    """Pair features and labels positionally into a fresh DataFrame tagged with ``split_name``.

    Uses ``.values`` so the input indexes are discarded and the result is indexed 0..n-1.
    """
    return pd.DataFrame({
        'sequence_vector': Data.values,
        'label': labels.values,
    }).assign(split=split_name)

# Re-assemble the three splits into a single frame tagged by its 'split' column.
df_train = make_df(data_train, labels_train, 'train')
df_test = make_df(data_test, labels_test, 'test')
df_val = make_df(data_val, labels_val, 'validation')
df_all = pd.concat([df_train, df_test, df_val], ignore_index=True)

output_file = "maestro_new_splits_augmented_403010_bigger_vocab.csv"
chunk_size = 2500

# Write chunk_size-row slices so tqdm can show progress; only the first slice emits the header.
with open(output_file, 'w', newline='') as f:
    chunk_starts = range(0, len(df_all), chunk_size)
    for start in tqdm(chunk_starts, desc="Writing CSV"):
        end = start + chunk_size
        chunk = df_all.iloc[start:end]
        chunk.to_csv(f, index=False, header=start == 0)

In [None]:
# Zip the token CSV (written to the working directory by the tokenization cell; presumably /content on Colab) into data.zip.
!zip -r data.zip /content/maestro_token_sequences_bigger_vocab.csv

In [None]:
import numpy as np
import pretty_midi
from IPython.display import Audio, display

# Check that the tokenization can be reversed (encode -> decode round trip)

def _verify_collect_events(pm):
    """Flatten all notes of ``pm`` into time-ordered ('on'|'off', time, pitch, velocity) tuples."""
    events = []
    for inst in pm.instruments:
        for note in inst.notes:
            events.append(("on", note.start, note.pitch, note.velocity))
            events.append(("off", note.end, note.pitch, note.velocity))
    # at equal times 'off' sorts before 'on', then by pitch
    events.sort(key=lambda e: (e[1], e[0], e[2]))
    return events


def _verify_encode(events, token_dict, vel_bins, time_step, max_bin):
    """Encode sorted events: <bos>, one time_shift per step, velocity + note_on, note_off, <eos>."""
    seq = [token_dict["<bos>"]]
    prev_t = 0.0
    for typ, t, p, v in events:
        steps = int(round((t - prev_t) / time_step))
        for i in range(steps):
            seq.append(token_dict[f"time_shift_{min(i + 1, max_bin)}"])
        prev_t += steps * time_step
        if typ == "on":
            b = int(np.argmin(np.abs(vel_bins - v)))
            seq.append(token_dict[f"velocity_{b}"])
            seq.append(token_dict[f"note_on_pitch_{p}"])
        else:
            seq.append(token_dict[f"note_off_pitch_{p}"])
    seq.append(token_dict["<eos>"])
    return seq


def _verify_decode(seq, inverse_token_dict, vel_bins, time_step):
    """Decode token ids back to events; every time_shift token advances the clock by one step."""
    decoded = []
    cur_t = 0.0
    last_vel = None
    for tok in seq:
        name = inverse_token_dict[tok]
        if name.startswith("time_shift_"):
            cur_t += time_step
        elif name.startswith("velocity_"):
            last_vel = vel_bins[int(name.split("_")[-1])]
        elif name.startswith("note_on_pitch_"):
            decoded.append(("on", cur_t, int(name.split("_")[-1]), last_vel))
        elif name.startswith("note_off_pitch_"):
            decoded.append(("off", cur_t, int(name.split("_")[-1]), None))
    return decoded


def _verify_match(orig, decoded, tol_time):
    """Greedily pair each original event with the earliest unmatched decoded event of the
    same type and pitch whose time is within ``tol_time``.

    Returns (matches, unmatched_orig, unmatched_dec); unmatched lists keep input order.
    """
    # bucket decoded positions by (type, pitch) so each lookup scans only candidates
    buckets = {}
    for pos, d in enumerate(decoded):
        buckets.setdefault((d[0], d[2]), []).append(pos)

    matches, unmatched_orig, used = [], [], set()
    for o in orig:
        bucket = buckets.get((o[0], o[2]), [])
        for j, pos in enumerate(bucket):
            if abs(o[1] - decoded[pos][1]) <= tol_time:
                matches.append((o, decoded[pos]))
                used.add(pos)
                del bucket[j]
                break
        else:
            unmatched_orig.append(o)
    unmatched_dec = [d for pos, d in enumerate(decoded) if pos not in used]
    return matches, unmatched_orig, unmatched_dec


def _verify_rebuild_midi(decoded, fallback_velocity=100):
    """Rebuild a single-piano PrettyMIDI from decoded events (unterminated note-ons are dropped)."""
    new_pm = pretty_midi.PrettyMIDI()
    piano = pretty_midi.Instrument(program=0)  # piano
    note_on_map = {}
    for typ, t, p, v in decoded:
        if typ == "on":
            # `is None`, not `or`: velocity bin 0 decodes to 0.0, which is falsy
            note_on_map[p] = (t, fallback_velocity if v is None else v)
        elif p in note_on_map:
            s, vel = note_on_map.pop(p)
            piano.notes.append(pretty_midi.Note(velocity=int(round(vel)), pitch=p, start=s, end=t))
    new_pm.instruments.append(piano)
    return new_pm


def verify(pm: "pretty_midi.PrettyMIDI",
           token_dict: dict,
           inverse_token_dict: dict,
           time_step: float = 0.01,
           max_shift_secs: float = 1.0,
           num_velocity_bins: int = 32,
           tol_time: float = 0.01,
           fs_audio: int = 44100,
           return_midi: bool = False):
    """Check that ``pm`` survives an encode -> decode round trip through the token vocabulary.

    The notes are turned into note-on / note-off events, encoded with a local copy of the
    event tokenization, decoded again, and every original event is paired with a decoded
    event of the same type and pitch within ``tol_time`` seconds. A summary is printed.

    Args:
        pm: parsed MIDI; only ``pm.instruments[*].notes`` are read.
        token_dict / inverse_token_dict: vocabulary and its inverse (see generate_token_dict).
        time_step: quantisation step in seconds.
        max_shift_secs: cap for the time-shift token index; must match the vocabulary.
        num_velocity_bins: number of evenly spaced velocity bins over 0..127.
        tol_time: maximum allowed |original time - decoded time| for a match.
        fs_audio: unused; kept for call compatibility.
        return_midi: also return the PrettyMIDI rebuilt from the decoded events.

    Returns:
        True if every original and every decoded event was matched, else False; with
        ``return_midi=True`` a ``(success, rebuilt_pretty_midi)`` tuple instead.
    """
    vel_bins = np.linspace(0, 127, num_velocity_bins)
    max_bin = int(round(max_shift_secs / time_step))

    orig = _verify_collect_events(pm)
    seq = _verify_encode(orig, token_dict, vel_bins, time_step, max_bin)
    decoded = _verify_decode(seq, inverse_token_dict, vel_bins, time_step)
    matches, unmatched_orig, unmatched_dec = _verify_match(orig, decoded, tol_time)

    print(f"Original events: {len(orig)}")
    print(f"Decoded events:  {len(decoded)}")
    print(f"Matches:         {len(matches)}")
    if unmatched_orig:
        print("Missing matches for these original events (up to 5):")
        for e in unmatched_orig[:5]: print("   ", e)
    if unmatched_dec:
        print("Spurious decoded events (up to 5):")
        for e in unmatched_dec[:5]: print("   ", e)

    success = not unmatched_orig and not unmatched_dec
    if not return_midi:
        return success
    return success, _verify_rebuild_midi(decoded)

In [None]:
from pretty_midi import PrettyMIDI

# Round-trip check on one validation performance: tokenize, decode, and compare events.
token_dict, inverse_token_dict = generate_token_dict()

# NOTE(review): this path uses the maestro-v2.0.0/ folder while the dataset cells read MIDI
# from maestro-v2.0.0-csv/ — confirm the file exists at this location.
sample_midi_path = "/content/drive/MyDrive/maestro-v2.0.0/validation/ORIG-MIDI_03_7_8_13_Group__MID--AUDIO_18_R2_2013_wav--3.midi"
pm = PrettyMIDI(sample_midi_path)

ok = verify(pm, token_dict, inverse_token_dict)
print("round-trip OK!" if ok else "problems detected")