**Synphony**

Deep Learning Final Project - MSDS Spring Module 2 - 2025

Aditi Puttur & Emma Juan

# 1. Data Preprocessing

In [2]:
import pandas as pd
import numpy as np

import os
import json

from tqdm import tqdm

import re
import unicodedata

import warnings
warnings.filterwarnings("ignore")

from miditok import REMI, TokenizerConfig, TokSequence
from miditoolkit import MidiFile
from symusic import Score

os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

import math
from typing import Optional

import traceback

## Loading the data

### LMD: Midi Files

LMD-matched: [Download Link](https://colinraffel.com/projects/lmd/#get)

On the list of places to download, select any of the mirrors for **LMD-matched**

In [3]:
# Open and read the JSON file
with open('data/LMD/md5_to_paths.json', 'r') as file:
    md5_to_paths = json.load(file)

In [7]:
md5_to_paths['1c83fc02b8c57fbc2605900bb31793fb']

['E/Exaltasamba - Megastar.mid',
 'Midis Samba e Pagode/Exaltasamba - Megastar.mid',
 'Midis Samba e Pagode/Exaltasamba - Megastar.mid']

In [9]:
lmd_catalog = []

for dirpath, dirnames, filenames in os.walk('data/LMD/lmd_matched'):
    for file in filenames:
        full_path = os.path.join(dirpath, file)
        if full_path.endswith('.mid'):
            lmd_catalog.append(full_path)

In [10]:
lmd_catalog.sort()
lmd_catalog[:10]

['data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/1d9d16a9da90c090809c153754823c2b.mid',
 'data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/5dd29e99ed7bd3cc0c5177a6e9de22ea.mid',
 'data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/b97c529ab9ef783a849b896816001748.mid',
 'data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/dac3cdd0db6341d8dc14641e44ed0d44.mid',
 'data/LMD/lmd_matched/A/A/A/TRAAAZF12903CCCF6B/05f21994c71a5f881e64f45c8d706165.mid',
 'data/LMD/lmd_matched/A/A/A/TRAAAZF12903CCCF6B/10288ea8e07b70c17f872fda82b94330.mid',
 'data/LMD/lmd_matched/A/A/A/TRAAAZF12903CCCF6B/6304d2bba4282f8bd74322828c30f0c7.mid',
 'data/LMD/lmd_matched/A/A/A/TRAAAZF12903CCCF6B/c24989559d170135b9c6546d1d2df20b.mid',
 'data/LMD/lmd_matched/A/A/A/TRAAAZF12903CCCF6B/ddb6a3db65461dca1a43de72f5375d8b.mid',
 'data/LMD/lmd_matched/A/A/A/TRAAAZF12903CCCF6B/dfea6fd75926c571a87db789280d059d.mid']

In [6]:
len(lmd_catalog)

116189

In [7]:
lmd_catalog_all = {'path': [],
                   'MSD_name': [],
                   'LMD_name': []}

lmd_catalog_all['path'] = lmd_catalog
lmd_catalog_all['MSD_name'] = [path.split('/')[-2] for path in lmd_catalog]
lmd_catalog_all['LMD_name'] = [path.split('/')[-1].split('.')[-2] for path in lmd_catalog]

lmd_df = pd.DataFrame(lmd_catalog_all)
lmd_df

Unnamed: 0,path,MSD_name,LMD_name
0,data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/...,TRAAAGR128F425B14B,1d9d16a9da90c090809c153754823c2b
1,data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/...,TRAAAGR128F425B14B,5dd29e99ed7bd3cc0c5177a6e9de22ea
2,data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/...,TRAAAGR128F425B14B,b97c529ab9ef783a849b896816001748
3,data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/...,TRAAAGR128F425B14B,dac3cdd0db6341d8dc14641e44ed0d44
4,data/LMD/lmd_matched/A/A/A/TRAAAZF12903CCCF6B/...,TRAAAZF12903CCCF6B,05f21994c71a5f881e64f45c8d706165
...,...,...,...
116184,data/LMD/lmd_matched/Z/Z/Z/TRZZZTN128EF35C42F/...,TRZZZTN128EF35C42F,165e156e5192569e41dc8390b80a1465
116185,data/LMD/lmd_matched/Z/Z/Z/TRZZZTN128EF35C42F/...,TRZZZTN128EF35C42F,87e403b5fcb06718767aee0a9386f86c
116186,data/LMD/lmd_matched/Z/Z/Z/TRZZZTN128EF35C42F/...,TRZZZTN128EF35C42F,c56e00ecc890dfdfbdd551cb9ea15ca5
116187,data/LMD/lmd_matched/Z/Z/Z/TRZZZYV128F92E996D/...,TRZZZYV128F92E996D,1b966417a9aa703873c5fa1cfe18da32


In [8]:
lmd_df["MSD_name"].nunique()

31034

**Lackh MIDI Dataset (only tracks with matching metadata files) → 31,034 tracks / 116,189 MIDI files.**

### LMD-matched metadata (MillionSongDataset): The Metadata

LMD-matched metadata: [Download Link](https://colinraffel.com/projects/lmd/#get)

On the list of places to download, select any of the mirrors for **LMD-matched metadata**

We will extract title, artist and year from the metadata and add it to our dataset.

In [11]:
import hdf5_getters

In [12]:
msd_catalog = []
titles = []
artists = []
releases = []
years = []

for dirpath, dirnames, filenames in tqdm(os.walk('data/LMD-matched-MSD')):
    for file in filenames:
        full_path = os.path.join(dirpath, file)
        if full_path.endswith('.h5'):

            # Append the path to the list
            msd_catalog.append(full_path)

            # Get the metadata
            h5 = hdf5_getters.open_h5_file_read(full_path)
            titles.append(hdf5_getters.get_title(h5))
            artists.append(hdf5_getters.get_artist_name(h5))
            releases.append(hdf5_getters.get_release(h5))
            years.append(hdf5_getters.get_year(h5))
            # danceability = hdf5_getters.get_danceability(h5)
            # get_energy = hdf5_getters.get_energy(h5)


15298it [07:23, 34.52it/s]


In [13]:
msd_catalog[:10]

['data/LMD-matched-MSD/R/R/U/TRRRUFD12903CD7092.h5',
 'data/LMD-matched-MSD/R/R/U/TRRRUTV12903CEA11B.h5',
 'data/LMD-matched-MSD/R/R/U/TRRRUJO128E07813E7.h5',
 'data/LMD-matched-MSD/R/R/I/TRRRIYO128F428CF6F.h5',
 'data/LMD-matched-MSD/R/R/I/TRRRILO128F422FFED.h5',
 'data/LMD-matched-MSD/R/R/I/TRRRIVC12903CA6C5A.h5',
 'data/LMD-matched-MSD/R/R/I/TRRRILD128F92CB682.h5',
 'data/LMD-matched-MSD/R/R/I/TRRRION128F145EBB7.h5',
 'data/LMD-matched-MSD/R/R/N/TRRRNPV128F42AAA55.h5',
 'data/LMD-matched-MSD/R/R/N/TRRRNGS12903CD16D9.h5']

In [12]:
len(msd_catalog)

31034

In [13]:
len(msd_catalog) == lmd_df["MSD_name"].nunique()

True

In [14]:
titles[:5]

[b'Wastelands',
 b'Runaway',
 b'Have You Met Miss Jones? (Swing When Version)',
 b'Goodbye',
 b'La Colegiala']

In [15]:
artists[:5]

[b'Hawkwind',
 b'Del Shannon',
 b'Robbie Williams',
 b'Volebeats',
 b'Rodolfo Y Su Tipica Ra7']

In [16]:
years[:5]

[1994, 1961, 2001, 0, 1997]

In [17]:
titles = [title.decode('utf-8') for title in titles]
artists = [artist.decode('utf-8') for artist in artists]

In [18]:
msd_catalog_all = {'path': [],
                   'MSD_name': [],
                   'title': [],
                   'artist': [],
                   'year': []}

msd_catalog_all['path'] = msd_catalog
msd_catalog_all['title'] = titles
msd_catalog_all['artist'] = artists
msd_catalog_all['year'] = years
msd_catalog_all['MSD_name'] = [path.split('/')[-1].split('.')[-2] for path in msd_catalog]

msd_df = pd.DataFrame(msd_catalog_all)
msd_df

Unnamed: 0,path,MSD_name,title,artist,year
0,data/LMD-matched-MSD/R/R/U/TRRRUFD12903CD7092.h5,TRRRUFD12903CD7092,Wastelands,Hawkwind,1994
1,data/LMD-matched-MSD/R/R/U/TRRRUTV12903CEA11B.h5,TRRRUTV12903CEA11B,Runaway,Del Shannon,1961
2,data/LMD-matched-MSD/R/R/U/TRRRUJO128E07813E7.h5,TRRRUJO128E07813E7,Have You Met Miss Jones? (Swing When Version),Robbie Williams,2001
3,data/LMD-matched-MSD/R/R/I/TRRRIYO128F428CF6F.h5,TRRRIYO128F428CF6F,Goodbye,Volebeats,0
4,data/LMD-matched-MSD/R/R/I/TRRRILO128F422FFED.h5,TRRRILO128F422FFED,La Colegiala,Rodolfo Y Su Tipica Ra7,1997
...,...,...,...,...,...
31029,data/LMD-matched-MSD/W/W/Y/TRWWYHD12903CC42B1.h5,TRWWYHD12903CC42B1,Gethsemane (I Only Want to Say) (Live-LP Version),Michael Crawford,0
31030,data/LMD-matched-MSD/W/W/Y/TRWWYNJ128F426541F.h5,TRWWYNJ128F426541F,Cold Feelings,Social Distortion,1992
31031,data/LMD-matched-MSD/W/W/P/TRWWPSV128F4244C71.h5,TRWWPSV128F4244C71,Ases Death,At Vance,2001
31032,data/LMD-matched-MSD/W/W/P/TRWWPBK128F42911E9.h5,TRWWPBK128F42911E9,Drowned Maid,Amorphis,1993


In [19]:
msd_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 31034 entries, 0 to 31033
Data columns (total 5 columns):
 #   Column    Non-Null Count  Dtype 
---  ------    --------------  ----- 
 0   path      31034 non-null  object
 1   MSD_name  31034 non-null  object
 2   title     31034 non-null  object
 3   artist    31034 non-null  object
 4   year      31034 non-null  int32 
dtypes: int32(1), object(4)
memory usage: 1.1+ MB


**Million Song Dataset (only the files for the matched LMD dataset) → 31,034 metadata files (.h5 format)**

### tagtraum: Adding Genre Tags

tagtraum genre annotations for the Million Song Dataset: [Download Link](https://www.tagtraum.com/msd_genre_datasets.html)

Scroll to **Genre Ground Truth** and download the `msd_tagtraum_cd2c.cls.zip` file.

In [20]:
tagtraum = {'MSD_name': [],
            'genre': []}

with open("data/tagtraum/msd_tagtraum_cd2c.cls", "r") as file:
    lines = file.readlines()
    for line in lines:
        if not line.startswith('#'):
            track, genre = line.strip().split('\t')
            tagtraum['MSD_name'].append(track)
            tagtraum['genre'].append(genre)

In [21]:
tagtraum_df = pd.DataFrame(tagtraum)
tagtraum_df

Unnamed: 0,MSD_name,genre
0,TRAAAAK128F9318786,Rock
1,TRAAAAW128F429D538,Rap
2,TRAAADJ128F4287B47,Rock
3,TRAAADZ128F9348C2E,Latin
4,TRAAAED128E0783FAB,Jazz
...,...,...
191396,TRZZZMY128F426D7A2,Reggae
191397,TRZZZRJ128F42819AF,Rock
191398,TRZZZUK128F92E3C60,Folk
191399,TRZZZZD128F4236844,Rock


In [22]:
tagtraum_df["genre"].unique()

array(['Rock', 'Rap', 'Latin', 'Jazz', 'Electronic', 'Pop', 'Metal',
       'RnB', 'Country', 'Reggae', 'Blues', 'Folk', 'Punk', 'World',
       'New Age'], dtype=object)

**Tagtraum genre tags → 191,401 tags**

## Creating our dataset: MIDI + Metadata + Genres

### Midi + Metadata

**Each track (MSD_name -> track_id) has one metadata file, and different MIDI files (LMD_name -> midi_id) associated with it.**

In [23]:
len(lmd_df), len(msd_df)

(116189, 31034)

In [24]:
lmd_df["MSD_name"].nunique(), len(msd_df)

(31034, 31034)

In [25]:
dataset = lmd_df.merge(msd_df, how="inner", on="MSD_name", suffixes=('_lmd', '_msd'))
dataset = dataset.rename(columns={"path_lmd": "midi_filepath",
                                  "path_msd": "metadata_filepath",
                                  "MSD_name": "track_id",
                                  "LMD_name": "midi_id"})
dataset = dataset[["track_id", "midi_id", "midi_filepath",
                   "title", "artist", "year"]]
dataset

Unnamed: 0,track_id,midi_id,midi_filepath,title,artist,year
0,TRAAAGR128F425B14B,1d9d16a9da90c090809c153754823c2b,data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/...,Into The Nightlife,Cyndi Lauper,2008
1,TRAAAGR128F425B14B,5dd29e99ed7bd3cc0c5177a6e9de22ea,data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/...,Into The Nightlife,Cyndi Lauper,2008
2,TRAAAGR128F425B14B,b97c529ab9ef783a849b896816001748,data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/...,Into The Nightlife,Cyndi Lauper,2008
3,TRAAAGR128F425B14B,dac3cdd0db6341d8dc14641e44ed0d44,data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/...,Into The Nightlife,Cyndi Lauper,2008
4,TRAAAZF12903CCCF6B,05f21994c71a5f881e64f45c8d706165,data/LMD/lmd_matched/A/A/A/TRAAAZF12903CCCF6B/...,Break My Stride,Matthew Wilder,1983
...,...,...,...,...,...,...
116184,TRZZZTN128EF35C42F,165e156e5192569e41dc8390b80a1465,data/LMD/lmd_matched/Z/Z/Z/TRZZZTN128EF35C42F/...,Funky Dance Music Vol 1,DJ Rob E,0
116185,TRZZZTN128EF35C42F,87e403b5fcb06718767aee0a9386f86c,data/LMD/lmd_matched/Z/Z/Z/TRZZZTN128EF35C42F/...,Funky Dance Music Vol 1,DJ Rob E,0
116186,TRZZZTN128EF35C42F,c56e00ecc890dfdfbdd551cb9ea15ca5,data/LMD/lmd_matched/Z/Z/Z/TRZZZTN128EF35C42F/...,Funky Dance Music Vol 1,DJ Rob E,0
116187,TRZZZYV128F92E996D,1b966417a9aa703873c5fa1cfe18da32,data/LMD/lmd_matched/Z/Z/Z/TRZZZYV128F92E996D/...,Dear Lie,TLC,1999


Some tracks have multiple MIDI files. We will only keep one MIDI file per track.

In [26]:
grouped_dataset = dataset.groupby('track_id').first().reset_index()
grouped_dataset = grouped_dataset[['track_id', 'midi_id', 'midi_filepath']]
grouped_dataset = grouped_dataset.merge(
    dataset[
        ['track_id', "title", "artist", "year"]
    ].drop_duplicates(), on='track_id', how='left' )
grouped_dataset = grouped_dataset[["track_id", "midi_id", "midi_filepath",
                                   "title", "artist", "year"]]
grouped_dataset

Unnamed: 0,track_id,midi_id,midi_filepath,title,artist,year
0,TRAAAGR128F425B14B,1d9d16a9da90c090809c153754823c2b,data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/...,Into The Nightlife,Cyndi Lauper,2008
1,TRAAAZF12903CCCF6B,05f21994c71a5f881e64f45c8d706165,data/LMD/lmd_matched/A/A/A/TRAAAZF12903CCCF6B/...,Break My Stride,Matthew Wilder,1983
2,TRAABVM128F92CA9DC,0dd4d2b9fbcf96a0fa363a1918255e58,data/LMD/lmd_matched/A/A/B/TRAABVM128F92CA9DC/...,Caught In A Dream,Tesla,2004
3,TRAABXH128F42955D6,01ffb8729a2465bfa7f9ba0288c89e24,data/LMD/lmd_matched/A/A/B/TRAABXH128F42955D6/...,Keep An Eye On Summer (Album Version),Brian Wilson,1998
4,TRAACQE12903CC706C,1ee7c9ad5f18b2659789d9608c951ca5,data/LMD/lmd_matched/A/A/C/TRAACQE12903CC706C/...,Summer,Old Man River,2007
...,...,...,...,...,...,...
31029,TRZZYLO12903CAC06C,128551e12d6dec38ad7ce00665c77fe5,data/LMD/lmd_matched/Z/Z/Y/TRZZYLO12903CAC06C/...,I've Never Seen The Righteous Forsaken,Dallas Holm,0
31030,TRZZYTX128F92EBE33,538838021299e65875a8bec61a87a368,data/LMD/lmd_matched/Z/Z/Y/TRZZYTX128F92EBE33/...,I Don't Want To Do It (2009 Digital Remaster),George Harrison,0
31031,TRZZZBU128F426811B,0702ddab7728f7b0e51321d8a0366367,data/LMD/lmd_matched/Z/Z/Z/TRZZZBU128F426811B/...,Dame Una Se񡬢 size=,Los Iracundos,0
31032,TRZZZTN128EF35C42F,165e156e5192569e41dc8390b80a1465,data/LMD/lmd_matched/Z/Z/Z/TRZZZTN128EF35C42F/...,Funky Dance Music Vol 1,DJ Rob E,0


### Adding the genre tags

In [27]:
dataset = dataset.merge(tagtraum_df, how="inner", left_on="track_id", right_on="MSD_name")
dataset = dataset.drop(columns=["MSD_name"])
dataset

Unnamed: 0,track_id,midi_id,midi_filepath,title,artist,year,genre
0,TRAAAGR128F425B14B,1d9d16a9da90c090809c153754823c2b,data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/...,Into The Nightlife,Cyndi Lauper,2008,Pop
1,TRAAAGR128F425B14B,5dd29e99ed7bd3cc0c5177a6e9de22ea,data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/...,Into The Nightlife,Cyndi Lauper,2008,Pop
2,TRAAAGR128F425B14B,b97c529ab9ef783a849b896816001748,data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/...,Into The Nightlife,Cyndi Lauper,2008,Pop
3,TRAAAGR128F425B14B,dac3cdd0db6341d8dc14641e44ed0d44,data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/...,Into The Nightlife,Cyndi Lauper,2008,Pop
4,TRAABVM128F92CA9DC,0dd4d2b9fbcf96a0fa363a1918255e58,data/LMD/lmd_matched/A/A/B/TRAABVM128F92CA9DC/...,Caught In A Dream,Tesla,2004,Rock
...,...,...,...,...,...,...,...
21348,TRZZROL12903CAC4A8,0f0aaf2f90bc66da732f4371e703eae4,data/LMD/lmd_matched/Z/Z/R/TRZZROL12903CAC4A8/...,Love Love,Amy MacDonald,2010,Pop
21349,TRZZSML12903CBB7BD,bc4aae694e7c433a6da16284e52e11be,data/LMD/lmd_matched/Z/Z/S/TRZZSML12903CBB7BD/...,Airwave (Radio Edit),Rank 1,2000,Electronic
21350,TRZZTHP128F427F139,b085f5c3571f570bdc44fa0c9b6a0672,data/LMD/lmd_matched/Z/Z/T/TRZZTHP128F427F139/...,Briaris,The Sweetest Ache,1992,Rock
21351,TRZZTHP128F427F139,f10a54a5e8b4d169eec5231bb6b15c94,data/LMD/lmd_matched/Z/Z/T/TRZZTHP128F427F139/...,Briaris,The Sweetest Ache,1992,Rock


In [28]:
grouped_dataset = grouped_dataset.merge(tagtraum_df, how="inner", left_on="track_id", right_on="MSD_name")
grouped_dataset = grouped_dataset.drop(columns=["MSD_name"])
grouped_dataset

Unnamed: 0,track_id,midi_id,midi_filepath,title,artist,year,genre
0,TRAAAGR128F425B14B,1d9d16a9da90c090809c153754823c2b,data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/...,Into The Nightlife,Cyndi Lauper,2008,Pop
1,TRAABVM128F92CA9DC,0dd4d2b9fbcf96a0fa363a1918255e58,data/LMD/lmd_matched/A/A/B/TRAABVM128F92CA9DC/...,Caught In A Dream,Tesla,2004,Rock
2,TRAAGMC128F4292D0F,0644195d1a3d14e0a0bd3d8b30dc68da,data/LMD/lmd_matched/A/A/G/TRAAGMC128F4292D0F/...,My Love (Album Version),LITTLE TEXAS,0,Country
3,TRAANZE128F148BF55,0597bf18743a5aacfedc981eb58c9da9,data/LMD/lmd_matched/A/A/N/TRAANZE128F148BF55/...,The Name Of The Game,Abba,1977,Pop
4,TRAAPPQ128F14961F5,d39a20f33af4fb6b307529db8cf0cc3f,data/LMD/lmd_matched/A/A/P/TRAAPPQ128F14961F5/...,Wig,The B-52's,1986,Rock
...,...,...,...,...,...,...,...
6175,TRZZQGM128F9311E60,34d27fedd8dca07e36f50d69ba477e5b,data/LMD/lmd_matched/Z/Z/Q/TRZZQGM128F9311E60/...,Sun Of Jamaica,Goombay Dance Band,1991,Pop
6176,TRZZROL12903CAC4A8,0f0aaf2f90bc66da732f4371e703eae4,data/LMD/lmd_matched/Z/Z/R/TRZZROL12903CAC4A8/...,Love Love,Amy MacDonald,2010,Pop
6177,TRZZSML12903CBB7BD,bc4aae694e7c433a6da16284e52e11be,data/LMD/lmd_matched/Z/Z/S/TRZZSML12903CBB7BD/...,Airwave (Radio Edit),Rank 1,2000,Electronic
6178,TRZZTHP128F427F139,b085f5c3571f570bdc44fa0c9b6a0672,data/LMD/lmd_matched/Z/Z/T/TRZZTHP128F427F139/...,Briaris,The Sweetest Ache,1992,Rock


When we put the three datasets together, we eneded up with 31,034 data points (MIDI file, metadata file, and genre tag)

## Sluggifying our parameters

**Slug‑safe metadata** – ASCII‑safe, ALL\_CAPS slugs for 15 genres, 2,956 artists, 60 years (between 1945 - 2010).

In [29]:
genres = dataset["genre"].unique()
artists = dataset["artist"].unique()
years = dataset["year"].unique()

In [30]:
def slug(text: str) -> str:
    """Return an ALL_CAPS alnum/underscore version of `text`."""
    # 1) strip accents → ascii
    text = unicodedata.normalize("NFKD", text).encode("ascii", "ignore").decode()
    # 2) replace non‑alnum with underscore
    text = re.sub(r"[^\w]+", "_", text)
    # 3) collapse multiple underscores and upper‑case
    return re.sub(r"_+", "_", text).strip("_").upper()

In [31]:
genres_slugged = np.array([slug(genre) for genre in genres])
artists_slugged = np.array([slug(artist) for artist in artists])
years = np.array([int(year) for year in years if not pd.isna(year)])

In [32]:
genres = pd.DataFrame({
    'genre': genres,
    'slugged_genre': genres_slugged
})

artists = pd.DataFrame({
    'artist': artists,
    'slugged_artist': artists_slugged
})

years = pd.DataFrame({
    'year': years
})

In [33]:
genres = genres.sort_values(by='genre')
artists = artists.sort_values(by='artist')
years = years.sort_values(by='year')

In [34]:
dataset["slugged_genre"] = dataset["genre"].map(genres.set_index('genre')['slugged_genre'])
dataset["slugged_artist"] = dataset["artist"].map(artists.set_index('artist')['slugged_artist'])

grouped_dataset["slugged_genre"] = grouped_dataset["genre"].map(genres.set_index('genre')['slugged_genre'])
grouped_dataset["slugged_artist"] = grouped_dataset["artist"].map(artists.set_index('artist')['slugged_artist'])

## Saving our data

### Saving the metadata datasets

In [35]:
dataset.to_csv("data/metadata.csv", index=False)

In [36]:
grouped_dataset.to_csv("data/grouped_metadata.csv", index=False)

### Saving the different parameters to csvs

In [37]:
genres.to_csv("data/genres.csv", index=False)
artists.to_csv("data/artists.csv", index=False)
years.to_csv("data/years.csv", index=False)

# 2. Model Implementation

In [85]:
dataset = pd.read_csv("data/metadata.csv")
grouped_dataset = pd.read_csv("data/grouped_metadata.csv")

genres = pd.read_csv("data/genres.csv")
titles = pd.read_csv("data/titles.csv")
artists = pd.read_csv("data/artists.csv")
years = pd.read_csv("data/years.csv")

In [86]:
genres_slugged = genres["slugged_genre"].values
artists_slugged = artists["slugged_artist"].values
years_vals = years["year"].values

In [87]:
# Config whith which the model was trained
# MAX_TOKENS = 512
# BATCH_SIZE = 2

# D_MODEL    = 512
# N_LAYERS   = 6
# N_HEADS    = 8

# New config to try
MAX_TOKENS = 1024
BATCH_SIZE = 8

D_MODEL = 768
N_LAYERS = 8
N_HEADS = 12 # 768 / 12 = 64 per head

## Tokenisation

Tokenisation converts variable‑length MIDI into a single integer stream compatible with text‑style language modelling, while injecting controllable style cues.

### Defining the tokenizer

Library: miditok‑REMI with config:

- use_chords=True
- use_programs=True
- 32 velocity bins
- beat‑resolution {(0‑4):8,(4‑8):4}
- rests and time‑signatures enabled

In [88]:
config = TokenizerConfig(
    num_velocities=32,
    use_chords=True,
    use_programs=True,
    beat_res={(0,4): 8, (4,8): 4},
    use_rests=True,
    rest_range=(2,8),
    use_time_signatures=True
)

tokenizer = REMI(config)

### Adding our special tokens

Conditioning We prepend three special tokens per piece: <GENRE_X> <ARTIST_Y> <YEAR_Z> (vocab extended programmatically). Each full sequence ends with <EOS>.

In [89]:
special_toks = \
    [f"<GENRE_{g}>"  for g in genres_slugged] + \
        [f"<ARTIST_{a}>" for a in artists_slugged] + \
            [f"<YEAR_{y}>"   for y in years_vals]  + \
                ["<EOS>", "<PAD>"]

for tok in special_toks:
    tokenizer.add_to_vocab(tok)

### Tokenising: Storing each track as a numpy int32 array.

In [90]:
tokenizing = False

In [91]:
# ─── 1. Helpers ──────────────────────────────────────────────────────────
def build_prefix(genre, artist, year, tokenizer):
    """Convert metadata row → list[int] conditioning tokens."""
    genre_tok  = f"<GENRE_{genre}>"
    artist_tok = f"<ARTIST_{artist}>"
    year_tok   = f"<YEAR_{year}>"

    # NOTE: use tokenizer.vocab[...]  (or .token_to_id(...))
    return [
        tokenizer.vocab[genre_tok],
        tokenizer.vocab[artist_tok],
        tokenizer.vocab[year_tok],
    ]

# ─── 3. Output directory -------------------------------------------------
out_dir = "data/tokens/train"

# ─── 4. Iterate files ----------------------------------------------------
if tokenizing:
    rows, _ = grouped_dataset.shape
    for row in tqdm(range(rows)):
        try:
            # 4.0. Get row
            row = grouped_dataset.iloc[row]

            # 4.1. Get MIDI filepath
            midi_path = row["midi_filepath"]

            # 4.2. Get the track ID
            track_id = row["track_id"]

            # 4a. Build CONDITIONING prefix
            genre = row["slugged_genre"]
            artist = row["slugged_artist"]
            year = row["year"]
            prefix_ids = build_prefix(genre, artist, year, tokenizer)          # list[int]

            # 4b. Encode MIDI to tokens
            midi = Score(midi_path)
            midi_tokens = tokenizer(midi)                 # list[int]

            # 4c. Concatenate prefix + midi + <EOS>
            seq_ids = prefix_ids + midi_tokens.ids + [tokenizer.vocab["<EOS>"]]

            # 4d. Save as int32 .npy
            np.save(f"{out_dir}/{track_id}.npy", np.array(seq_ids, dtype=np.int32))
        except Exception as e:
            print(f"Error processing {midi_path}: {e}")
            traceback.print_exc()
            continue

## The Model

**Synphony** is a *decoder‑only Transformer* built in PyTorch 2.2 for autoregressive token prediction:

1. **Embedding** – each of the 3 534 tokens is projected to a 768‑dimensional vector.
2. **Relative sinusoidal positional encoding** – max sequence length 1024; lets the model extrapolate beyond training lengths.
3. **8 × `TransformerDecoderBlock`** – every block contains 12‑head self‑attention (64 d per head), residual LayerNorm, a GELU feed‑forward layer, and dropout 0.1.
    - Causal and pad masks are merged into a single FP32 attention mask to avoid memory blow‑ups.
4. **LayerNorm + Linear head** – normalise the final hidden state and project back to the vocabulary for next‑token logits.

In shorthand: **Embedding → PosEnc → 8 × DecoderBlock → LayerNorm → Linear**, giving the model enough depth and width to capture harmonic and rhythmic structure while remaining trainable on Apple‑Silicon hardware.

In [92]:
class RelativePositionalEncoding(nn.Module):
    """
    Sinusoidal *relative‑style* positional encoding.
    The tensor it returns has the same shape as `x`
    so you can just add it:  x + pos(x)

    Args
    ----
    d_model : int            # embedding size
    max_len : int, optional  # maximum sequence length
    """
    def __init__(self, d_model: int, max_len: int = 2048):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len

        # Create the (max_len, d_model) sinusoid table once
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * -(math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, d_model)          # (L, D)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # Register as a buffer so it moves with .to(device)
        self.register_buffer("pe", pe)              # (L, D)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : Tensor, shape (batch, seq_len, d_model)

        Returns
        -------
        pos : Tensor, same shape as `x`
        """
        seq_len = x.size(1)
        if seq_len > self.max_len:
            raise ValueError(f"Sequence length {seq_len} exceeds max_len {self.max_len}")
        # (1, L, D) – broadcast over batch dimension
        return self.pe[:seq_len].unsqueeze(0)


In [93]:
class TransformerDecoderBlock(nn.Module):
    """
    Decoder block that merges causal + pad masking into a (B×H, L, L) float mask,
    so no hidden bool→float blow-ups occur.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        max_len: int = 2048,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim   = d_model,
            num_heads   = n_heads,
            dropout     = dropout,
            batch_first = True,
        )
        self.ln1      = nn.LayerNorm(d_model)
        self.ln2      = nn.LayerNorm(d_model)
        self.ff       = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )
        self.dropout  = nn.Dropout(dropout)

        # Precompute float causal mask: 0 on/under diag, -inf above
        causal = torch.triu(
            torch.full((max_len, max_len), float("-inf")),
            diagonal=1
        )
        self.register_buffer("causal_mask", causal, persistent=False)

    def forward(
        self,
        x: torch.Tensor,            # (B, L, D)
        pad_mask: torch.Tensor=None  # (B, L), True=keep token, False=pad
    ) -> torch.Tensor:
        B, L, _ = x.shape
        H       = self.self_attn.num_heads
        device  = x.device
        dtype   = x.dtype

        # 1) slice the (L×L) causal mask
        causal = self.causal_mask[:L, :L]              # float32, (L, L)

        # 2) build a (B, L) float pad mask: 0 on tokens, -inf on pads
        if pad_mask is not None:
            pad_float = torch.zeros((B, L), device=device, dtype=dtype)
            pad_float = pad_float.masked_fill(~pad_mask, float("-inf"))
            # 3) expand pad_float to (B, L, L) and add causal
            #    pad_float.unsqueeze(1): (B, 1, L) → broadcast over src_len
            attn_batch = causal.unsqueeze(0) + pad_float.unsqueeze(1)  # (B, L, L)
        else:
            attn_batch = causal                               # (L, L)

        # 4) if we have a batch, repeat per-head to (B×H, L, L)
        if pad_mask is not None:
            # attn_batch: (B, L, L) → repeat each batch H times
            attn_mask = attn_batch.repeat_interleave(H, dim=0)  # (B*H, L, L)
        else:
            attn_mask = attn_batch   # 2D mask

        # 5) self-attention with ONLY attn_mask
        attn_out, _ = self.self_attn(
            x, x, x,
            attn_mask=attn_mask
        )

        # 6) residual + norm + feed-forward + norm
        x = self.ln1(x + self.dropout(attn_out))
        x = self.ln2(x + self.dropout(self.ff(x)))
        return x


In [94]:
class Synphony(nn.Module):
    def __init__(self, vocab_size, d_model=512, n_layers=6, n_heads=8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos = RelativePositionalEncoding(d_model, max_len=2048)
        self.blocks = nn.ModuleList([
            TransformerDecoderBlock(d_model, n_heads) for _ in range(n_layers)
        ])
        self.ln = nn.LayerNorm(d_model)
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, x, pad_mask=None):
        x = self.embed(x) + self.pos(x)
        for blk in self.blocks:
            x = blk(x, pad_mask)
        x = self.ln(x)
        return self.out(x)

## The Training Loop

- **Perplexity (PPL)** as primary intrinsic metric.
    
    Lower PPL ≈ model is less “surprised” by true sequences, correlating with better musical coherence.
    
    $PPL = exp{(\text{avg cross entropy} = \frac{\text{running loss}}{\text{train loader size}})}$
    
- **Training details**
    - Machine: 1 NVIDIA L4 GPU → g2-standard-8 (8 vCPUs, 32 GB Memory)
    - `MAX_TOKENS` = 1024, `BATCH_SIZE` = 8 → 50 epochs (≈ 7 h).
- **Vocabulary** – 3,534 tokens, incl. 125 special conditioning IDs.
- **Training objective**
    
    Classic language‑model framing turns music generation into a well‑studied optimisation problem.
    
    - **Next‑token prediction** (teacher‑forcing):
        - `loss = cross_entropy(logits, target, ignore_index=PAD_ID, label_smoothing=0.1)`.
        - Padded positions are masked; causal + pad masks are merged to keep attention logits float32‐sized.
- **Optimisation regimen**
    - **AdamW**
        - Learning Rate = 3 e‑4
        - weight‑decay = 1 e‑2
            - A smaller weight-decay was applied for regularisation purposes. The small decay term (1 × 10⁻²) discourages the weights from growing too large, helping generalisation.
    - **Label Smoothing** = 0.1
        
        Instead of treating the target token as probability 1.0, we soften it to 0.9 and spread 0.1 across the rest of the vocabulary. This prevents the model from becoming over‑confident and generally speeds convergence.
        
    - **Gradient clipping** (‖g‖₂ ≤ 1).
        
        Keeps exploding gradients in check by scaling the entire gradient vector to length 1 when it gets too large. That stabilises training, especially with long sequences.
        
    - **LR Scheduler → ReduceLROnPlateau**
        
        Watches the **validation loss**; if it hasn’t improved for **2 epochs**, the scheduler cuts the current learning‑rate in half. That lets us start fast (3 e‑4) and automatically slow down when improvements plateau.
        
        - factor = 0.5
        - patience = 2
        - floor = 1 e‑6
    - **50 epochs**, batch 8, max 1024 tokens.

In [95]:
from torch.utils.data import Dataset, DataLoader

import random
random.seed(42)  # For reproducibility

In [96]:
tok_paths = []

for dirpath, dirnames, filenames in os.walk('data/tokens/train'):
    for file in filenames:
        full_path = os.path.join(dirpath, file)
        if full_path.endswith('.npy'):
            tok_paths.append(full_path)

In [97]:
len(tok_paths)

6150

In [98]:
split_index = int(len(tok_paths) * 0.8)  # 80% train, 20% test
random.shuffle(tok_paths)

train_paths = tok_paths[:split_index]
test_paths = tok_paths[split_index:]

In [104]:
# ─── 1. Dataset + collate ────────────────────────────────────────────────
class MidiTokenDataset(Dataset):
    def __init__(self, npy_paths):
        self.paths = npy_paths

    def __len__(self):               # number of songs in split
        return len(self.paths)

    def __getitem__(self, idx):      # returns 1‑D np.ndarray[int]
        return np.load(self.paths[idx]).astype(np.int64)

def collate_fn(batch, pad_id):
    B, L = len(batch), MAX_TOKENS
    x = torch.full((B, L), pad_id, dtype=torch.long)
    for i, seq in enumerate(batch):
        seq = torch.from_numpy(seq)
        if seq.numel() > L:
            start = torch.randint(0, seq.numel() - L + 1, (1,)).item()
            seq = seq[start : start + L]
        x[i, : seq.numel()] = seq
    pad_mask = ~x.eq(pad_id)
    return x, pad_mask


# ─── 2. DataLoaders ──────────────────────────────────────────────────────
PAD_ID = tokenizer.vocab['<PAD>']          # or use the ID you chose for <PAD>

train_ds = MidiTokenDataset(train_paths)
val_ds   = MidiTokenDataset(test_paths)

train_loader = DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True,
    collate_fn=lambda b: collate_fn(b, PAD_ID)
)
val_loader   = DataLoader(
    val_ds,   batch_size=BATCH_SIZE, shuffle=False,
    collate_fn=lambda b: collate_fn(b, PAD_ID)
)

# ─── 3. Model, optimiser, scheduler ─────────────────────────────────────
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

model = Synphony(
    vocab_size=len(tokenizer), d_model=D_MODEL,
    n_layers=N_LAYERS, n_heads=N_HEADS).to(device)

# 1. Switch to AdamW with weight decay
optim = torch.optim.AdamW(model.parameters(),
                          lr=3e-4,           # whatever your current LR is
                          weight_decay=1e-2) # small wd to regularize

# 2. Set up a Reduce-on-Plateau scheduler
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optim,
                                                  mode='min',        # val loss should go down
                                                  factor=0.5,        # cut LR in half
                                                  patience=2,        # wait 2 epochs
                                                  min_lr=1e-6,       # floor on LR
                                                  verbose=True)


# ─── 4. Training loop ────────────────────────────────────────────────────
best_val_loss = float("inf")

for epoch in tqdm(range(1, 51)):                         # 50 epochs
    # ---- train ----------------------------------------------------------
    model.train()
    running_loss = 0.0

    for x, pad_mask in train_loader:          # pad_mask: (B, L)
        x, pad_mask = x.to(device), pad_mask.to(device)

        logits = model(x[:, :-1], pad_mask=pad_mask[:, :-1])

        loss   = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            x[:, 1:].reshape(-1),
            ignore_index=PAD_ID,
            label_smoothing=0.1
        )

        optim.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optim.step()

        running_loss += loss.item()

    train_ppl = math.exp(running_loss / len(train_loader))

    # ---- validation -----------------------------------------------------
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for x, pad_mask in val_loader:             # pad_mask is (B, L)
            x, pad_mask = x.to(device), pad_mask.to(device)

            # exactly like in training
            logits  = model(x[:, :-1], pad_mask=pad_mask[:, :-1])
            val_loss += F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                x[:, 1:].reshape(-1),
                ignore_index=PAD_ID
            ).item()

    val_ppl = math.exp(val_loss / len(val_loader))
    print(f"val PPL {val_ppl:6.2f}")
    print(f"Epoch {epoch:02d} ▸ train PPL {train_ppl:6.2f} | val PPL {val_ppl:6.2f}")
    
    # ---- scheduler step -----------------------------------------------
    sched.step(val_loss / len(val_loader))  # pass your avg val_loss
    
    # log current LR
    current_lr = optim.param_groups[0]['lr']
    print(f"         ↳ LR now = {current_lr:.2e}")

    # ---- checkpoint -----------------------------------------------------
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "synphony_best.pt")
        print("  ✓ new best model saved")

print("Done!")


  0%|          | 0/50 [00:00<?, ?it/s]

val PPL  12.04
Epoch 01 ▸ train PPL  41.38 | val PPL  12.04
         ↳ LR now = 3.00e-04


  2%|▏         | 1/50 [08:14<6:43:27, 494.03s/it]

  ✓ new best model saved
val PPL   6.79
Epoch 02 ▸ train PPL  21.59 | val PPL   6.79
         ↳ LR now = 3.00e-04


  4%|▍         | 2/50 [16:31<6:36:43, 495.90s/it]

  ✓ new best model saved
val PPL   4.01
Epoch 03 ▸ train PPL  12.46 | val PPL   4.01
         ↳ LR now = 3.00e-04


  6%|▌         | 3/50 [24:48<6:28:57, 496.54s/it]

  ✓ new best model saved
val PPL   3.48
Epoch 04 ▸ train PPL   9.59 | val PPL   3.48
         ↳ LR now = 3.00e-04


  8%|▊         | 4/50 [33:05<6:20:55, 496.85s/it]

  ✓ new best model saved
val PPL   3.28
Epoch 05 ▸ train PPL   8.84 | val PPL   3.28
         ↳ LR now = 3.00e-04


 10%|█         | 5/50 [41:23<6:12:51, 497.16s/it]

  ✓ new best model saved
val PPL   3.17
Epoch 06 ▸ train PPL   8.44 | val PPL   3.17
         ↳ LR now = 3.00e-04


 12%|█▏        | 6/50 [49:40<6:04:39, 497.25s/it]

  ✓ new best model saved
val PPL   3.06
Epoch 07 ▸ train PPL   8.14 | val PPL   3.06
         ↳ LR now = 3.00e-04


 14%|█▍        | 7/50 [57:58<5:56:26, 497.36s/it]

  ✓ new best model saved
val PPL   3.03
Epoch 08 ▸ train PPL   7.97 | val PPL   3.03
         ↳ LR now = 3.00e-04


 16%|█▌        | 8/50 [1:06:15<5:48:09, 497.36s/it]

  ✓ new best model saved
val PPL   2.97
Epoch 09 ▸ train PPL   7.81 | val PPL   2.97
         ↳ LR now = 3.00e-04


 18%|█▊        | 9/50 [1:14:33<5:39:53, 497.39s/it]

  ✓ new best model saved
val PPL   2.92
Epoch 10 ▸ train PPL   7.64 | val PPL   2.92
         ↳ LR now = 3.00e-04


 20%|██        | 10/50 [1:22:50<5:31:34, 497.37s/it]

  ✓ new best model saved
val PPL   2.87
Epoch 11 ▸ train PPL   7.55 | val PPL   2.87
         ↳ LR now = 3.00e-04


 22%|██▏       | 11/50 [1:31:08<5:23:17, 497.38s/it]

  ✓ new best model saved


 24%|██▍       | 12/50 [1:39:24<5:14:51, 497.14s/it]

val PPL   2.87
Epoch 12 ▸ train PPL   7.42 | val PPL   2.87
         ↳ LR now = 3.00e-04
val PPL   2.82
Epoch 13 ▸ train PPL   7.35 | val PPL   2.82
         ↳ LR now = 3.00e-04


 26%|██▌       | 13/50 [1:47:42<5:06:43, 497.40s/it]

  ✓ new best model saved
val PPL   2.80
Epoch 14 ▸ train PPL   7.28 | val PPL   2.80
         ↳ LR now = 3.00e-04


 28%|██▊       | 14/50 [1:56:00<4:58:31, 497.54s/it]

  ✓ new best model saved
val PPL   2.77
Epoch 15 ▸ train PPL   7.23 | val PPL   2.77
         ↳ LR now = 3.00e-04


 30%|███       | 15/50 [2:04:18<4:50:18, 497.66s/it]

  ✓ new best model saved
val PPL   2.76
Epoch 16 ▸ train PPL   7.15 | val PPL   2.76
         ↳ LR now = 3.00e-04


 32%|███▏      | 16/50 [2:12:36<4:42:01, 497.70s/it]

  ✓ new best model saved
val PPL   2.74
Epoch 17 ▸ train PPL   7.10 | val PPL   2.74
         ↳ LR now = 3.00e-04


 34%|███▍      | 17/50 [2:20:54<4:33:44, 497.73s/it]

  ✓ new best model saved
val PPL   2.71
Epoch 18 ▸ train PPL   7.05 | val PPL   2.71
         ↳ LR now = 3.00e-04


 36%|███▌      | 18/50 [2:29:11<4:25:23, 497.62s/it]

  ✓ new best model saved
val PPL   2.67
Epoch 19 ▸ train PPL   7.00 | val PPL   2.67
         ↳ LR now = 3.00e-04


 38%|███▊      | 19/50 [2:37:29<4:17:08, 497.69s/it]

  ✓ new best model saved


 40%|████      | 20/50 [2:45:45<4:08:35, 497.17s/it]

val PPL   2.70
Epoch 20 ▸ train PPL   6.96 | val PPL   2.70
         ↳ LR now = 3.00e-04


 42%|████▏     | 21/50 [2:54:01<4:00:11, 496.95s/it]

val PPL   2.68
Epoch 21 ▸ train PPL   6.92 | val PPL   2.68
         ↳ LR now = 3.00e-04
val PPL   2.66
Epoch 22 ▸ train PPL   6.88 | val PPL   2.66
         ↳ LR now = 3.00e-04


 44%|████▍     | 22/50 [3:02:19<3:52:00, 497.15s/it]

  ✓ new best model saved


 46%|████▌     | 23/50 [3:10:35<3:43:35, 496.88s/it]

val PPL   2.66
Epoch 23 ▸ train PPL   6.82 | val PPL   2.66
         ↳ LR now = 3.00e-04
val PPL   2.63
Epoch 24 ▸ train PPL   6.80 | val PPL   2.63
         ↳ LR now = 3.00e-04


 48%|████▊     | 24/50 [3:18:53<3:35:24, 497.11s/it]

  ✓ new best model saved


 50%|█████     | 25/50 [3:27:09<3:26:59, 496.79s/it]

val PPL   2.65
Epoch 25 ▸ train PPL   6.79 | val PPL   2.65
         ↳ LR now = 3.00e-04
val PPL   2.62
Epoch 26 ▸ train PPL   6.75 | val PPL   2.62
         ↳ LR now = 3.00e-04


 52%|█████▏    | 26/50 [3:35:26<3:18:49, 497.05s/it]

  ✓ new best model saved


 54%|█████▍    | 27/50 [3:43:42<3:10:24, 496.72s/it]

val PPL   2.63
Epoch 27 ▸ train PPL   6.73 | val PPL   2.63
         ↳ LR now = 3.00e-04
val PPL   2.61
Epoch 28 ▸ train PPL   6.68 | val PPL   2.61
         ↳ LR now = 3.00e-04


 56%|█████▌    | 28/50 [3:52:00<3:02:12, 496.93s/it]

  ✓ new best model saved


 58%|█████▊    | 29/50 [4:00:16<2:53:51, 496.73s/it]

val PPL   2.61
Epoch 29 ▸ train PPL   6.67 | val PPL   2.61
         ↳ LR now = 3.00e-04
val PPL   2.60
Epoch 30 ▸ train PPL   6.64 | val PPL   2.60
         ↳ LR now = 3.00e-04


 60%|██████    | 30/50 [4:08:34<2:45:42, 497.11s/it]

  ✓ new best model saved
val PPL   2.57
Epoch 31 ▸ train PPL   6.65 | val PPL   2.57
         ↳ LR now = 3.00e-04


 62%|██████▏   | 31/50 [4:16:52<2:37:28, 497.31s/it]

  ✓ new best model saved


 64%|██████▍   | 32/50 [4:25:08<2:29:05, 496.96s/it]

val PPL   2.59
Epoch 32 ▸ train PPL   6.59 | val PPL   2.59
         ↳ LR now = 3.00e-04


 66%|██████▌   | 33/50 [4:33:24<2:20:44, 496.72s/it]

val PPL   2.58
Epoch 33 ▸ train PPL   6.57 | val PPL   2.58
         ↳ LR now = 3.00e-04
val PPL   2.55
Epoch 34 ▸ train PPL   6.56 | val PPL   2.55
         ↳ LR now = 3.00e-04


 68%|██████▊   | 34/50 [4:41:42<2:12:32, 497.02s/it]

  ✓ new best model saved


 70%|███████   | 35/50 [4:49:58<2:04:11, 496.74s/it]

val PPL   2.57
Epoch 35 ▸ train PPL   6.53 | val PPL   2.57
         ↳ LR now = 3.00e-04


 72%|███████▏  | 36/50 [4:58:14<1:55:51, 496.52s/it]

val PPL   2.57
Epoch 36 ▸ train PPL   6.53 | val PPL   2.57
         ↳ LR now = 3.00e-04
val PPL   2.55
Epoch 37 ▸ train PPL   6.53 | val PPL   2.55
         ↳ LR now = 3.00e-04


 74%|███████▍  | 37/50 [5:06:31<1:47:38, 496.79s/it]

  ✓ new best model saved
val PPL   2.54
Epoch 38 ▸ train PPL   6.49 | val PPL   2.54
         ↳ LR now = 3.00e-04


 76%|███████▌  | 38/50 [5:14:49<1:39:24, 497.00s/it]

  ✓ new best model saved


 78%|███████▊  | 39/50 [5:23:05<1:31:03, 496.68s/it]

val PPL   2.55
Epoch 39 ▸ train PPL   6.45 | val PPL   2.55
         ↳ LR now = 3.00e-04


 80%|████████  | 40/50 [5:31:21<1:22:44, 496.46s/it]

val PPL   2.55
Epoch 40 ▸ train PPL   6.45 | val PPL   2.55
         ↳ LR now = 3.00e-04
val PPL   2.54
Epoch 41 ▸ train PPL   6.44 | val PPL   2.54
         ↳ LR now = 3.00e-04


 82%|████████▏ | 41/50 [5:39:38<1:14:30, 496.69s/it]

  ✓ new best model saved
val PPL   2.52
Epoch 42 ▸ train PPL   6.40 | val PPL   2.52
         ↳ LR now = 3.00e-04


 84%|████████▍ | 42/50 [5:47:55<1:06:15, 496.91s/it]

  ✓ new best model saved


 86%|████████▌ | 43/50 [5:56:11<57:56, 496.60s/it]  

val PPL   2.52
Epoch 43 ▸ train PPL   6.41 | val PPL   2.52
         ↳ LR now = 3.00e-04


 88%|████████▊ | 44/50 [6:04:27<49:38, 496.37s/it]

val PPL   2.53
Epoch 44 ▸ train PPL   6.38 | val PPL   2.53
         ↳ LR now = 3.00e-04
val PPL   2.52
Epoch 45 ▸ train PPL   6.38 | val PPL   2.52
         ↳ LR now = 1.50e-04


 90%|█████████ | 45/50 [6:12:45<41:23, 496.68s/it]

  ✓ new best model saved
val PPL   2.46
Epoch 46 ▸ train PPL   6.20 | val PPL   2.46
         ↳ LR now = 1.50e-04


 92%|█████████▏| 46/50 [6:21:02<33:07, 496.86s/it]

  ✓ new best model saved
val PPL   2.45
Epoch 47 ▸ train PPL   6.12 | val PPL   2.45
         ↳ LR now = 1.50e-04


 94%|█████████▍| 47/50 [6:29:19<24:50, 496.93s/it]

  ✓ new best model saved


 96%|█████████▌| 48/50 [6:37:35<16:33, 496.73s/it]

val PPL   2.46
Epoch 48 ▸ train PPL   6.09 | val PPL   2.46
         ↳ LR now = 1.50e-04


 98%|█████████▊| 49/50 [6:45:51<08:16, 496.51s/it]

val PPL   2.45
Epoch 49 ▸ train PPL   6.06 | val PPL   2.45
         ↳ LR now = 1.50e-04
val PPL   2.43
Epoch 50 ▸ train PPL   6.04 | val PPL   2.43
         ↳ LR now = 1.50e-04


100%|██████████| 50/50 [6:54:09<00:00, 496.99s/it]

  ✓ new best model saved
Done!





In [72]:
tokenizer.vocab_size

3534

# 3. Model Inference

In [105]:
model.eval()

Synphony(
  (embed): Embedding(3534, 768)
  (pos): RelativePositionalEncoding()
  (blocks): ModuleList(
    (0-7): 8 x TransformerDecoderBlock(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
      )
      (ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (ff): Sequential(
        (0): Linear(in_features=768, out_features=3072, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=3072, out_features=768, bias=True)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (out): Linear(in_features=768, out_features=3534, bias=True)
)

In [106]:
TEMPERATURE = 1.0
TOP_K = 8

# ─── 2. Helper for top-k filtering ───────────────────────────────────────
def top_k_logits(logits, k):
    v, _ = torch.topk(logits, k)
    threshold = v[-1]
    return torch.where(logits < threshold, torch.full_like(logits, -float("Inf")), logits)

# ─── 3. Autoregressive generation ────────────────────────────────────────
@torch.no_grad()
def generate(
        genre:str,
        artist:str,
        year:int,
        max_length:int = MAX_TOKENS
    ) -> list[int]:
    prefix = build_prefix(genre, artist, year, tokenizer)
    input_ids = torch.tensor([prefix], device=device)  # (1, P)
    pad_mask  = torch.ones_like(input_ids, dtype=torch.bool, device=device)

    for _ in tqdm(range(max_length - len(prefix))):
        logits = model(input_ids, pad_mask=pad_mask)
        next_logits = logits[0, -1, :]                  # (V,)
        next_logits = next_logits / TEMPERATURE
        next_logits = top_k_logits(next_logits, TOP_K)
        probs = F.softmax(next_logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)  # (1,)
        if next_id.item() == tokenizer.vocab["<EOS>"]:
            break

        # append and extend pad_mask
        input_ids = torch.cat([input_ids, next_id.unsqueeze(0)], dim=1)   # (1, L+1)
        pad_mask  = torch.ones_like(input_ids, dtype=torch.bool, device=device)

    return input_ids[0].tolist()

# ─── 4. Decode to MIDI & save ────────────────────────────────────────────
def tokens_to_midi(token_ids: list[int], out_path: str):
    """
    Drop the 3 metadata tokens + optional EOS, then decode the rest.
    """
    # 1) drop the first 3 prefix IDs (genre, artist, year)
    musical_ids = token_ids[3:]
    # 2) drop trailing <EOS> if present
    eos_id = tokenizer.vocab["<EOS>"]
    if len(musical_ids) > 0 and musical_ids[-1] == eos_id:
        musical_ids = musical_ids[:-1]

    # 3) decode only the musical tokens back to a PrettyMIDI
    pm = tokenizer(musical_ids)
    # 4) write out the .mid file
    pm.dump_midi(out_path)

In [120]:
# ─── 5. Run it! ───────────────────────────────────────────────────────────
# Example user inputs
genre_input  = "ROCK"
artist_input = "GLORIA_GAYNOR"
year_input   = 1990

gen_ids = generate(genre_input, artist_input, year_input, max_length=512)
out_file = "generated.mid"
tokens_to_midi(gen_ids, out_file)
print(f"🎹 Wrote MIDI to {out_file}")

100%|██████████| 509/509 [00:03<00:00, 131.18it/s]

🎹 Wrote MIDI to generated.mid





In [121]:
from midi2audio import FluidSynth
from IPython.display import Audio

# render your MIDI to a WAV
fs = FluidSynth()
fs.midi_to_audio('generated.mid', 'generated.wav')

# now embed the WAV inline
Audio('generated.wav')

Parameter '/home/jupyter/.fluidsynth/default_sound_font.sf2' not a SoundFont or MIDI file or error occurred identifying it.


FluidSynth runtime version 2.1.7
Copyright (C) 2000-2021 Peter Hanappe and others.
Distributed under the LGPL license.
SoundFont(R) is a registered trademark of E-mu Systems, Inc.

Rendering audio to file 'generated.wav'..
