In [2]:
!pip install music21

Collecting music21
  Downloading music21-9.9.1-py3-none-any.whl.metadata (5.2 kB)
Collecting chardet (from music21)
  Downloading chardet-5.2.0-py3-none-any.whl.metadata (3.4 kB)
Collecting more-itertools (from music21)
  Using cached more_itertools-10.8.0-py3-none-any.whl.metadata (39 kB)
Collecting webcolors>=1.5 (from music21)
  Downloading webcolors-25.10.0-py3-none-any.whl.metadata (2.2 kB)
Downloading music21-9.9.1-py3-none-any.whl (20.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m20.1/20.1 MB[0m [31m13.4 MB/s[0m  [33m0:00:01[0mm0:00:01[0m00:01[0m
[?25hDownloading webcolors-25.10.0-py3-none-any.whl (14 kB)
Downloading chardet-5.2.0-py3-none-any.whl (199 kB)
Using cached more_itertools-10.8.0-py3-none-any.whl (69 kB)
Installing collected packages: webcolors, more-itertools, chardet, music21
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4/4[0m [music21]m3/4[0m [music21]
[1A[2KSuccessfully installed chardet-5.2.0 more-itertools-10.8.

In [None]:
import os
import json
import random
from dataclasses import dataclass, asdict
from typing import List, Dict, Tuple, Optional

import numpy as np
from collections import Counter

from music21 import converter, instrument, note, chord, stream

import tensorflow as tf
from tensorflow.keras import layers, regularizers, callbacks, Model

from sklearn.model_selection import train_test_split

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [2]:
@dataclass
class ModelConfig:
    name: str
    sequence_length: int = 50
    batch_size: int = 64
    embedding_dim: int = 64
    lstm_units: int = 256
    num_lstm_layers: int = 1
    dropout: float = 0.3
    recurrent_dropout: float = 0.0
    l1_reg: float = 1e-5
    learning_rate: float = 1e-3
    max_epochs: int = 50
    patience: int = 5
    min_note_freq: int = 3 

In [16]:
DATA_DIR = "data/midi"             
OUTPUT_DIR = "outputs/music_lstm" 

In [3]:
def load_midi_file_paths(data_dir: str, max_files: Optional[int] = None) -> List[str]:
    midi_files = []
    for root, _, files in os.walk(data_dir):
        for f in files:
            if f.lower().endswith((".mid", ".midi")):
                midi_files.append(os.path.join(root, f))
    midi_files.sort()
    if max_files is not None:
        midi_files = midi_files[:max_files]
    print(f"Found {len(midi_files)} MIDI files.")
    return midi_files

In [None]:

def extract_notes_from_midi(file_paths: List[str]) -> List[str]:
    notes = []
    for i, fp in enumerate(file_paths):
        print(f"[{i+1}/{len(file_paths)}] Parsing {fp} ...")
        try:
            midi = converter.parse(fp)
        except Exception as e:
            print(f"  Skipping {fp} due to parse error: {e}")
            continue

        parts = instrument.partitionByInstrument(midi)
        if parts: 
            part_stream = None
            for p in parts.parts:
                if "Piano" in p.partName if p.partName else False:
                    part_stream = p
                    break
            if part_stream is None:
                part_stream = parts.parts[0]
        else:
            part_stream = midi.flat

        for element in part_stream.recurse():
            if isinstance(element, note.Note):
                notes.append(str(element.pitch))
            elif isinstance(element, chord.Chord):
                pitches = ".".join(sorted(str(n) for n in element.pitches))
                notes.append(pitches)
    print(f"Extracted {len(notes)} note/chord events.")
    return notes

In [None]:
def build_vocabulary(notes: List[str], min_freq: int = 3) -> Tuple[Dict[str, int], Dict[int, str]]:
    freq = Counter(notes)
    vocab_tokens = [tok for tok, c in freq.items() if c >= min_freq]
    vocab_tokens.sort()

    token_to_int = {"<PAD>": 0, "<UNK>": 1}
    idx = 2
    for tok in vocab_tokens:
        token_to_int[tok] = idx
        idx += 1

    int_to_token = {i: t for t, i in token_to_int.items()}

    print(f"Vocabulary size (including PAD & UNK): {len(token_to_int)}")
    return token_to_int, int_to_token

In [6]:
def encode_notes(notes: List[str], token_to_int: Dict[str, int]) -> np.ndarray:
    unk_id = token_to_int["<UNK>"]
    encoded = np.array([token_to_int.get(tok, unk_id) for tok in notes], dtype=np.int32)
    print(f"Encoded note sequence length: {len(encoded)}")
    return encoded

In [None]:
def create_sequences(
    encoded_notes: np.ndarray,
    seq_length: int
) -> Tuple[np.ndarray, np.ndarray]:
    
    inputs = []
    targets = []

    for i in range(0, len(encoded_notes) - seq_length):
        inputs.append(encoded_notes[i : i + seq_length])
        targets.append(encoded_notes[i + seq_length])

    X = np.array(inputs, dtype=np.int32)
    y = np.array(targets, dtype=np.int32)

    print(f"Created {X.shape[0]} sequences of length {seq_length}.")
    return X, y

In [None]:
def build_lstm_model(config: ModelConfig, vocab_size: int) -> Model:
    reg = regularizers.l1(config.l1_reg)

    inputs = layers.Input(shape=(config.sequence_length,), name="input_tokens")

    x = layers.Embedding(
        input_dim=vocab_size,
        output_dim=config.embedding_dim,
        mask_zero=True,
        name="embedding"
    )(inputs)

    if config.num_lstm_layers == 1:
        x = layers.LSTM(
            config.lstm_units,
            return_sequences=False,
            dropout=config.dropout,
            recurrent_dropout=config.recurrent_dropout,
            kernel_regularizer=reg,
            name="lstm_1"
        )(x)
    elif config.num_lstm_layers == 2:
        x = layers.LSTM(
            config.lstm_units,
            return_sequences=True,
            dropout=config.dropout,
            recurrent_dropout=config.recurrent_dropout,
            kernel_regularizer=reg,
            name="lstm_1"
        )(x)
        x = layers.LSTM(
            config.lstm_units // 2, 
            return_sequences=False,
            dropout=config.dropout,
            recurrent_dropout=config.recurrent_dropout,
            kernel_regularizer=reg,
            name="lstm_2"
        )(x)
    else:
        raise ValueError("num_lstm_layers must be 1 or 2 for this project.")

    x = layers.Dense(
        config.lstm_units // 2,
        activation="relu",
        kernel_regularizer=reg,
        name="dense_projection"
    )(x)

    outputs = layers.Dense(vocab_size, activation="softmax", name="output")(x)

    model = Model(inputs=inputs, outputs=outputs, name=f"lstm_music_{config.name}")

    optimizer = tf.keras.optimizers.Adam(learning_rate=config.learning_rate)
    model.compile(
        optimizer=optimizer,
        loss="sparse_categorical_crossentropy",
        metrics=["sparse_categorical_accuracy"]
    )

    model.summary()
    return model

In [9]:
def make_tf_dataset(
    X: np.ndarray,
    y: np.ndarray,
    batch_size: int,
    shuffle: bool = True
) -> tf.data.Dataset:
    ds = tf.data.Dataset.from_tensor_slices((X, y))
    if shuffle:
        ds = ds.shuffle(buffer_size=len(X), seed=SEED, reshuffle_each_iteration=True)
    ds = ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return ds

In [None]:
def train_with_configs(
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_val: np.ndarray,
    y_val: np.ndarray,
    vocab_size: int,
    configs: List[ModelConfig]
) -> Tuple[Model, ModelConfig, Dict[str, float]]:
    best_model = None
    best_config = None
    best_val_loss = float("inf")
    best_metrics = {}

    for cfg in configs:
        print("\n" + "=" * 80)
        print(f"Training config: {cfg.name}")
        print(cfg)
        print("=" * 80 + "\n")

        model = build_lstm_model(cfg, vocab_size)

        train_ds = make_tf_dataset(X_train, y_train, batch_size=cfg.batch_size, shuffle=True)
        val_ds = make_tf_dataset(X_val, y_val, batch_size=cfg.batch_size, shuffle=False)

        ckpt_path = os.path.join(OUTPUT_DIR, f"model_{cfg.name}.keras")

        cb_early = callbacks.EarlyStopping(
            monitor="val_loss",
            patience=cfg.patience,
            restore_best_weights=True
        )
        cb_ckpt = callbacks.ModelCheckpoint(
            filepath=ckpt_path,
            monitor="val_loss",
            save_best_only=True
        )
        cb_lr = callbacks.ReduceLROnPlateau(
            monitor="val_loss", factor=0.5, patience=2, verbose=1
        )

        history = model.fit(
            train_ds,
            validation_data=val_ds,
            epochs=cfg.max_epochs,
            callbacks=[cb_early, cb_ckpt, cb_lr],
            verbose=2
        )

        val_loss, val_acc = model.evaluate(val_ds, verbose=0)
        print(f"[{cfg.name}] Final val_loss={val_loss:.4f}, val_acc={val_acc:.4f}")

        metrics = {
            "val_loss": float(val_loss),
            "val_acc": float(val_acc),
            "history": {
                "loss": [float(x) for x in history.history["loss"]],
                "val_loss": [float(x) for x in history.history["val_loss"]],
                "sparse_categorical_accuracy": [
                    float(x) for x in history.history["sparse_categorical_accuracy"]
                ],
                "val_sparse_categorical_accuracy": [
                    float(x) for x in history.history["val_sparse_categorical_accuracy"]
                ]
            }
        }

        metrics_path = os.path.join(OUTPUT_DIR, f"metrics_{cfg.name}.json")
        with open(metrics_path, "w") as f:
            json.dump(metrics, f, indent=2)
        print(f"Saved metrics to {metrics_path}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model = model
            best_config = cfg
            best_metrics = metrics

    print("\nBest config:", best_config)
    print(f"Best validation loss: {best_val_loss:.4f}")
    return best_model, best_config, best_metrics


In [None]:
def sample_from_logits(logits: np.ndarray, temperature: float = 1.0) -> int:

    if temperature <= 0:
        return int(np.argmax(logits))

    logits = logits.astype(np.float64)
    logits = logits / temperature
    probs = np.exp(logits) / np.sum(np.exp(logits))
    return int(np.random.choice(len(probs), p=probs))

In [None]:
def generate_continuation(
    model: Model,
    seed_sequence: List[int],
    int_to_token: Dict[int, str],
    token_to_int: Dict[str, int],
    num_generate: int = 200,
    temperature: float = 1.0,
    sequence_length: int = 50
) -> List[str]:
    model_input = list(seed_sequence)
    generated_tokens = []

    pad_id = token_to_int["<PAD>"]

    for i in range(num_generate):
        if len(model_input) < sequence_length:
            input_seq = [pad_id] * (sequence_length - len(model_input)) + model_input
        else:
            input_seq = model_input[-sequence_length:]

        input_arr = np.array([input_seq], dtype=np.int32)
        preds = model.predict(input_arr, verbose=0)[0]  # (vocab_size,)
        next_id = sample_from_logits(np.log(preds + 1e-9), temperature=temperature)

        generated_tokens.append(int_to_token.get(next_id, "<UNK>"))
        model_input.append(next_id)

    return generated_tokens

In [None]:
def tokens_to_midi(
    tokens: List[str],
    output_path: str,
    quarter_length: float = 0.5
) -> None:

    out_stream = stream.Stream()
    offset = 0.0

    for tok in tokens:
        if tok in ("<PAD>", "<UNK>"):
            offset += quarter_length
            continue

        if "." in tok:
            pitches = tok.split(".")
            chord_notes = []
            for p in pitches:
                try:
                    n = note.Note(p)
                    n.storedInstrument = instrument.Piano()
                    chord_notes.append(n)
                except Exception:
                    pass
            if chord_notes:
                c = chord.Chord(chord_notes)
                c.offset = offset
                out_stream.append(c)
        else:
            try:
                n = note.Note(tok)
                n.offset = offset
                n.storedInstrument = instrument.Piano()
                out_stream.append(n)
            except Exception:
                pass

        offset += quarter_length

    out_stream.write("midi", fp=output_path)
    print(f"Saved generated MIDI to {output_path}")

In [None]:
def main():

    midi_files = load_midi_file_paths(DATA_DIR, max_files=None) 
    notes = extract_notes_from_midi(midi_files)

    base_config = ModelConfig(name="base")
    token_to_int, int_to_token = build_vocabulary(notes, min_freq=base_config.min_note_freq)
    encoded = encode_notes(notes, token_to_int)

    X, y = create_sequences(encoded, seq_length=base_config.sequence_length)

    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=0.1, random_state=SEED, shuffle=True
    )
    print(f"Train sequences: {X_train.shape[0]}, Val sequences: {X_val.shape[0]}")

    vocab_size = len(token_to_int)

    configs = [
        ModelConfig(
            name="small_lstm",
            sequence_length=50,
            batch_size=64,
            embedding_dim=48,
            lstm_units=128,
            num_lstm_layers=1,
            dropout=0.3,
            recurrent_dropout=0.1,
            l1_reg=1e-5,
            learning_rate=1e-3,
            max_epochs=40,
            patience=5,
            min_note_freq=base_config.min_note_freq
        ),
        ModelConfig(
            name="medium_lstm",
            sequence_length=50,
            batch_size=64,
            embedding_dim=64,
            lstm_units=256,
            num_lstm_layers=1,
            dropout=0.4,
            recurrent_dropout=0.1,
            l1_reg=1e-5,
            learning_rate=1e-3,
            max_epochs=50,
            patience=6,
            min_note_freq=base_config.min_note_freq
        ),
        ModelConfig(
            name="two_layer_lstm",
            sequence_length=50,
            batch_size=64,
            embedding_dim=64,
            lstm_units=256,
            num_lstm_layers=2,
            dropout=0.3,
            recurrent_dropout=0.1,
            l1_reg=1e-5,
            learning_rate=5e-4,
            max_epochs=60,
            patience=7,
            min_note_freq=base_config.min_note_freq
        ),
    ]

    best_model, best_config, best_metrics = train_with_configs(
        X_train, y_train, X_val, y_val, vocab_size, configs
    )

    best_model_path = os.path.join(OUTPUT_DIR, f"best_model_{best_config.name}.keras")
    best_model.save(best_model_path)
    print(f"Saved best model to {best_model_path}")

    mapping_path = os.path.join(OUTPUT_DIR, "vocab_mappings.json")
    with open(mapping_path, "w") as f:
        json.dump(
            {
                "token_to_int": token_to_int,
                "int_to_token": int_to_token,
                "best_config": asdict(best_config),
                "best_metrics": {k: best_metrics[k] for k in ("val_loss", "val_acc")}
            },
            f,
            indent=2
        )
    print(f"Saved vocabulary and config metadata to {mapping_path}")

    idx = np.random.randint(0, X_train.shape[0])
    seed_seq = X_train[idx].tolist()

    generated_tokens = generate_continuation(
        model=best_model,
        seed_sequence=seed_seq,
        int_to_token=int_to_token,
        token_to_int=token_to_int,
        num_generate=200,
        temperature=1.0,
        sequence_length=best_config.sequence_length
    )

    decoded_seed = [int_to_token.get(i, "<UNK>") for i in seed_seq]
    full_tokens = decoded_seed + generated_tokens

    midi_output_path = os.path.join(OUTPUT_DIR, f"demo_continuation_{best_config.name}.mid")
    tokens_to_midi(full_tokens, midi_output_path, quarter_length=0.5)

    print("Done. Use the saved metrics, activations, and MIDI to build your report.")

In [17]:
if __name__ == "__main__":
    main()

Found 295 MIDI files.
[1/295] Parsing data/midi/albeniz/alb_esp1.mid ...




[2/295] Parsing data/midi/albeniz/alb_esp2.mid ...
[3/295] Parsing data/midi/albeniz/alb_esp3.mid ...
[4/295] Parsing data/midi/albeniz/alb_esp4.mid ...
[5/295] Parsing data/midi/albeniz/alb_esp5.mid ...
[6/295] Parsing data/midi/albeniz/alb_esp6.mid ...
[7/295] Parsing data/midi/albeniz/alb_se1.mid ...
[8/295] Parsing data/midi/albeniz/alb_se2.mid ...
[9/295] Parsing data/midi/albeniz/alb_se3.mid ...
[10/295] Parsing data/midi/albeniz/alb_se4.mid ...
[11/295] Parsing data/midi/albeniz/alb_se5.mid ...
[12/295] Parsing data/midi/albeniz/alb_se6.mid ...
[13/295] Parsing data/midi/albeniz/alb_se7.mid ...
[14/295] Parsing data/midi/albeniz/alb_se8.mid ...
[15/295] Parsing data/midi/bach/bach_846.mid ...




[16/295] Parsing data/midi/bach/bach_847.mid ...




[17/295] Parsing data/midi/bach/bach_850.mid ...




[18/295] Parsing data/midi/balakir/islamei.mid ...




[19/295] Parsing data/midi/beeth/appass_1.mid ...




[20/295] Parsing data/midi/beeth/appass_2.mid ...
[21/295] Parsing data/midi/beeth/appass_3.mid ...
[22/295] Parsing data/midi/beeth/beethoven_hammerklavier_1.mid ...




[23/295] Parsing data/midi/beeth/beethoven_hammerklavier_2.mid ...
[24/295] Parsing data/midi/beeth/beethoven_hammerklavier_3.mid ...
[25/295] Parsing data/midi/beeth/beethoven_hammerklavier_4.mid ...
[26/295] Parsing data/midi/beeth/beethoven_les_adieux_1.mid ...
[27/295] Parsing data/midi/beeth/beethoven_les_adieux_2.mid ...
[28/295] Parsing data/midi/beeth/beethoven_les_adieux_3.mid ...
[29/295] Parsing data/midi/beeth/beethoven_opus10_1.mid ...
[30/295] Parsing data/midi/beeth/beethoven_opus10_2.mid ...
[31/295] Parsing data/midi/beeth/beethoven_opus10_3.mid ...
[32/295] Parsing data/midi/beeth/beethoven_opus22_1.mid ...




[33/295] Parsing data/midi/beeth/beethoven_opus22_2.mid ...
[34/295] Parsing data/midi/beeth/beethoven_opus22_3.mid ...
[35/295] Parsing data/midi/beeth/beethoven_opus22_4.mid ...
[36/295] Parsing data/midi/beeth/beethoven_opus90_1.mid ...




[37/295] Parsing data/midi/beeth/beethoven_opus90_2.mid ...
[38/295] Parsing data/midi/beeth/elise.mid ...




[39/295] Parsing data/midi/beeth/mond_1.mid ...




[40/295] Parsing data/midi/beeth/mond_2.mid ...




[41/295] Parsing data/midi/beeth/mond_3.mid ...




[42/295] Parsing data/midi/beeth/pathetique_1.mid ...




[43/295] Parsing data/midi/beeth/pathetique_2.mid ...




[44/295] Parsing data/midi/beeth/pathetique_3.mid ...




[45/295] Parsing data/midi/beeth/waldstein_1.mid ...
[46/295] Parsing data/midi/beeth/waldstein_2.mid ...
[47/295] Parsing data/midi/beeth/waldstein_3.mid ...




[48/295] Parsing data/midi/borodin/bor_ps1.mid ...
[49/295] Parsing data/midi/borodin/bor_ps2.mid ...
[50/295] Parsing data/midi/borodin/bor_ps3.mid ...
[51/295] Parsing data/midi/borodin/bor_ps4.mid ...
[52/295] Parsing data/midi/borodin/bor_ps5.mid ...
[53/295] Parsing data/midi/borodin/bor_ps6.mid ...
[54/295] Parsing data/midi/borodin/bor_ps7.mid ...
[55/295] Parsing data/midi/brahms/BR_IM6.MID ...




[56/295] Parsing data/midi/brahms/br_im2.mid ...




[57/295] Parsing data/midi/brahms/br_im5.mid ...




[58/295] Parsing data/midi/brahms/br_rhap.mid ...




[59/295] Parsing data/midi/brahms/brahms_opus117_1.mid ...
[60/295] Parsing data/midi/brahms/brahms_opus117_2.mid ...
[61/295] Parsing data/midi/brahms/brahms_opus1_1.mid ...
[62/295] Parsing data/midi/brahms/brahms_opus1_2.mid ...
[63/295] Parsing data/midi/brahms/brahms_opus1_3.mid ...
[64/295] Parsing data/midi/brahms/brahms_opus1_4.mid ...
[65/295] Parsing data/midi/burgm/burg_agitato.mid ...




[66/295] Parsing data/midi/burgm/burg_erwachen.mid ...




[67/295] Parsing data/midi/burgm/burg_geschwindigkeit.mid ...




[68/295] Parsing data/midi/burgm/burg_gewitter.mid ...
[69/295] Parsing data/midi/burgm/burg_perlen.mid ...
[70/295] Parsing data/midi/burgm/burg_quelle.mid ...
[71/295] Parsing data/midi/burgm/burg_spinnerlied.mid ...
[72/295] Parsing data/midi/burgm/burg_sylphen.mid ...
[73/295] Parsing data/midi/burgm/burg_trennung.mid ...
[74/295] Parsing data/midi/chopin/chp_op18.mid ...
[75/295] Parsing data/midi/chopin/chp_op31.mid ...
[76/295] Parsing data/midi/chopin/chpn-p1.mid ...




[77/295] Parsing data/midi/chopin/chpn-p10.mid ...
[78/295] Parsing data/midi/chopin/chpn-p11.mid ...




[79/295] Parsing data/midi/chopin/chpn-p12.mid ...
[80/295] Parsing data/midi/chopin/chpn-p13.mid ...
[81/295] Parsing data/midi/chopin/chpn-p14.mid ...
[82/295] Parsing data/midi/chopin/chpn-p15.mid ...
[83/295] Parsing data/midi/chopin/chpn-p16.mid ...
[84/295] Parsing data/midi/chopin/chpn-p17.mid ...
[85/295] Parsing data/midi/chopin/chpn-p18.mid ...
[86/295] Parsing data/midi/chopin/chpn-p19.mid ...
[87/295] Parsing data/midi/chopin/chpn-p2.mid ...
[88/295] Parsing data/midi/chopin/chpn-p20.mid ...
[89/295] Parsing data/midi/chopin/chpn-p21.mid ...
[90/295] Parsing data/midi/chopin/chpn-p22.mid ...
[91/295] Parsing data/midi/chopin/chpn-p23.mid ...
[92/295] Parsing data/midi/chopin/chpn-p24.mid ...
[93/295] Parsing data/midi/chopin/chpn-p3.mid ...
[94/295] Parsing data/midi/chopin/chpn-p4.mid ...
[95/295] Parsing data/midi/chopin/chpn-p5.mid ...
[96/295] Parsing data/midi/chopin/chpn-p6.mid ...
[97/295] Parsing data/midi/chopin/chpn-p7.mid ...
[98/295] Parsing data/midi/chopin/chp



[99/295] Parsing data/midi/chopin/chpn-p9.mid ...
[100/295] Parsing data/midi/chopin/chpn_op10_e01.mid ...




[101/295] Parsing data/midi/chopin/chpn_op10_e05.mid ...




[102/295] Parsing data/midi/chopin/chpn_op10_e12.mid ...




[103/295] Parsing data/midi/chopin/chpn_op23.mid ...




[104/295] Parsing data/midi/chopin/chpn_op25_e1.mid ...




[105/295] Parsing data/midi/chopin/chpn_op25_e11.mid ...




[106/295] Parsing data/midi/chopin/chpn_op25_e12.mid ...
[107/295] Parsing data/midi/chopin/chpn_op25_e2.mid ...
[108/295] Parsing data/midi/chopin/chpn_op25_e3.mid ...
[109/295] Parsing data/midi/chopin/chpn_op25_e4.mid ...
[110/295] Parsing data/midi/chopin/chpn_op27_1.mid ...
[111/295] Parsing data/midi/chopin/chpn_op27_2.mid ...
[112/295] Parsing data/midi/chopin/chpn_op33_2.mid ...
[113/295] Parsing data/midi/chopin/chpn_op33_4.mid ...
[114/295] Parsing data/midi/chopin/chpn_op35_1.mid ...




[115/295] Parsing data/midi/chopin/chpn_op35_2.mid ...
[116/295] Parsing data/midi/chopin/chpn_op35_3.mid ...




[117/295] Parsing data/midi/chopin/chpn_op35_4.mid ...
[118/295] Parsing data/midi/chopin/chpn_op53.mid ...
[119/295] Parsing data/midi/chopin/chpn_op66.mid ...
[120/295] Parsing data/midi/chopin/chpn_op7_1.mid ...
[121/295] Parsing data/midi/chopin/chpn_op7_2.mid ...
[122/295] Parsing data/midi/debussy/DEB_CLAI.MID ...




[123/295] Parsing data/midi/debussy/DEB_PASS.MID ...
[124/295] Parsing data/midi/debussy/deb_menu.mid ...




[125/295] Parsing data/midi/debussy/deb_prel.mid ...
[126/295] Parsing data/midi/debussy/debussy_cc_1.mid ...
[127/295] Parsing data/midi/debussy/debussy_cc_2.mid ...
[128/295] Parsing data/midi/debussy/debussy_cc_3.mid ...
[129/295] Parsing data/midi/debussy/debussy_cc_4.mid ...
[130/295] Parsing data/midi/debussy/debussy_cc_6.mid ...
[131/295] Parsing data/midi/granados/gra_esp_2.mid ...
[132/295] Parsing data/midi/granados/gra_esp_3.mid ...
[133/295] Parsing data/midi/granados/gra_esp_4.mid ...
[134/295] Parsing data/midi/grieg/grieg_album.mid ...




[135/295] Parsing data/midi/grieg/grieg_berceuse.mid ...




[136/295] Parsing data/midi/grieg/grieg_brooklet.mid ...
[137/295] Parsing data/midi/grieg/grieg_butterfly.mid ...




[138/295] Parsing data/midi/grieg/grieg_elfentanz.mid ...




[139/295] Parsing data/midi/grieg/grieg_halling.mid ...




[140/295] Parsing data/midi/grieg/grieg_kobold.mid ...
[141/295] Parsing data/midi/grieg/grieg_march.mid ...




[142/295] Parsing data/midi/grieg/grieg_once_upon_a_time.mid ...




[143/295] Parsing data/midi/grieg/grieg_spring.mid ...




[144/295] Parsing data/midi/grieg/grieg_voeglein.mid ...




[145/295] Parsing data/midi/grieg/grieg_waechter.mid ...
[146/295] Parsing data/midi/grieg/grieg_walzer.mid ...
[147/295] Parsing data/midi/grieg/grieg_wanderer.mid ...




[148/295] Parsing data/midi/grieg/grieg_wedding.mid ...
[149/295] Parsing data/midi/grieg/grieg_zwerge.mid ...
[150/295] Parsing data/midi/haydn/hay_40_1.mid ...
[151/295] Parsing data/midi/haydn/hay_40_2.mid ...
[152/295] Parsing data/midi/haydn/haydn_33_1.mid ...
[153/295] Parsing data/midi/haydn/haydn_33_2.mid ...
[154/295] Parsing data/midi/haydn/haydn_33_3.mid ...
[155/295] Parsing data/midi/haydn/haydn_35_1.mid ...
[156/295] Parsing data/midi/haydn/haydn_35_2.mid ...
[157/295] Parsing data/midi/haydn/haydn_35_3.mid ...
[158/295] Parsing data/midi/haydn/haydn_43_1.mid ...
[159/295] Parsing data/midi/haydn/haydn_43_2.mid ...
[160/295] Parsing data/midi/haydn/haydn_43_3.mid ...
[161/295] Parsing data/midi/haydn/haydn_7_1.mid ...
[162/295] Parsing data/midi/haydn/haydn_7_2.mid ...
[163/295] Parsing data/midi/haydn/haydn_7_3.mid ...
[164/295] Parsing data/midi/haydn/haydn_8_1.mid ...
[165/295] Parsing data/midi/haydn/haydn_8_2.mid ...
[166/295] Parsing data/midi/haydn/haydn_8_3.mid ..



[172/295] Parsing data/midi/liszt/liz_et1.mid ...
[173/295] Parsing data/midi/liszt/liz_et2.mid ...
[174/295] Parsing data/midi/liszt/liz_et3.mid ...
[175/295] Parsing data/midi/liszt/liz_et4.mid ...
[176/295] Parsing data/midi/liszt/liz_et5.mid ...
[177/295] Parsing data/midi/liszt/liz_et6.mid ...
[178/295] Parsing data/midi/liszt/liz_et_trans4.mid ...




[179/295] Parsing data/midi/liszt/liz_et_trans5.mid ...
[180/295] Parsing data/midi/liszt/liz_et_trans8.mid ...
[181/295] Parsing data/midi/liszt/liz_liebestraum.mid ...
[182/295] Parsing data/midi/liszt/liz_rhap02.mid ...
[183/295] Parsing data/midi/liszt/liz_rhap09.mid ...
[184/295] Parsing data/midi/liszt/liz_rhap10.mid ...
[185/295] Parsing data/midi/liszt/liz_rhap12.mid ...
[186/295] Parsing data/midi/liszt/liz_rhap15.mid ...
[187/295] Parsing data/midi/mendelssohn/mendel_op19_1.mid ...




[188/295] Parsing data/midi/mendelssohn/mendel_op19_2.mid ...
[189/295] Parsing data/midi/mendelssohn/mendel_op19_3.mid ...




[190/295] Parsing data/midi/mendelssohn/mendel_op19_4.mid ...
[191/295] Parsing data/midi/mendelssohn/mendel_op19_5.mid ...
[192/295] Parsing data/midi/mendelssohn/mendel_op19_6.mid ...
[193/295] Parsing data/midi/mendelssohn/mendel_op30_1.mid ...




[194/295] Parsing data/midi/mendelssohn/mendel_op30_2.mid ...
[195/295] Parsing data/midi/mendelssohn/mendel_op30_3.mid ...
[196/295] Parsing data/midi/mendelssohn/mendel_op30_4.mid ...
[197/295] Parsing data/midi/mendelssohn/mendel_op30_5.mid ...
[198/295] Parsing data/midi/mendelssohn/mendel_op53_5.mid ...
[199/295] Parsing data/midi/mendelssohn/mendel_op62_3.mid ...
[200/295] Parsing data/midi/mendelssohn/mendel_op62_4.mid ...
[201/295] Parsing data/midi/mendelssohn/mendel_op62_5.mid ...
[202/295] Parsing data/midi/mozart/mz_311_1.mid ...




[203/295] Parsing data/midi/mozart/mz_311_2.mid ...
[204/295] Parsing data/midi/mozart/mz_311_3.mid ...
[205/295] Parsing data/midi/mozart/mz_330_1.mid ...




[206/295] Parsing data/midi/mozart/mz_330_2.mid ...
[207/295] Parsing data/midi/mozart/mz_330_3.mid ...
[208/295] Parsing data/midi/mozart/mz_331_1.mid ...
[209/295] Parsing data/midi/mozart/mz_331_2.mid ...
[210/295] Parsing data/midi/mozart/mz_331_3.mid ...
[211/295] Parsing data/midi/mozart/mz_332_1.mid ...
[212/295] Parsing data/midi/mozart/mz_332_2.mid ...
[213/295] Parsing data/midi/mozart/mz_332_3.mid ...
[214/295] Parsing data/midi/mozart/mz_333_1.mid ...
[215/295] Parsing data/midi/mozart/mz_333_2.mid ...
[216/295] Parsing data/midi/mozart/mz_333_3.mid ...
[217/295] Parsing data/midi/mozart/mz_545_1.mid ...
[218/295] Parsing data/midi/mozart/mz_545_2.mid ...
[219/295] Parsing data/midi/mozart/mz_545_3.mid ...
[220/295] Parsing data/midi/mozart/mz_570_1.mid ...
[221/295] Parsing data/midi/mozart/mz_570_2.mid ...
[222/295] Parsing data/midi/mozart/mz_570_3.mid ...
[223/295] Parsing data/midi/muss/muss_1.mid ...
[224/295] Parsing data/midi/muss/muss_2.mid ...




[225/295] Parsing data/midi/muss/muss_3.mid ...
[226/295] Parsing data/midi/muss/muss_4.mid ...
[227/295] Parsing data/midi/muss/muss_5.mid ...




[228/295] Parsing data/midi/muss/muss_6.mid ...
[229/295] Parsing data/midi/muss/muss_7.mid ...
[230/295] Parsing data/midi/muss/muss_8.mid ...




[231/295] Parsing data/midi/schubert/schu_143_1.mid ...
[232/295] Parsing data/midi/schubert/schu_143_2.mid ...
[233/295] Parsing data/midi/schubert/schu_143_3.mid ...
[234/295] Parsing data/midi/schubert/schub_d760_1.mid ...
[235/295] Parsing data/midi/schubert/schub_d760_2.mid ...
[236/295] Parsing data/midi/schubert/schub_d760_3.mid ...
[237/295] Parsing data/midi/schubert/schub_d760_4.mid ...
[238/295] Parsing data/midi/schubert/schub_d960_1.mid ...
[239/295] Parsing data/midi/schubert/schub_d960_2.mid ...
[240/295] Parsing data/midi/schubert/schub_d960_3.mid ...
[241/295] Parsing data/midi/schubert/schub_d960_4.mid ...
[242/295] Parsing data/midi/schubert/schubert_D850_1.mid ...
[243/295] Parsing data/midi/schubert/schubert_D850_2.mid ...
[244/295] Parsing data/midi/schubert/schubert_D850_3.mid ...
[245/295] Parsing data/midi/schubert/schubert_D850_4.mid ...
[246/295] Parsing data/midi/schubert/schubert_D935_1.mid ...
[247/295] Parsing data/midi/schubert/schubert_D935_2.mid ...
[2



[254/295] Parsing data/midi/schubert/schumm-1.mid ...




[255/295] Parsing data/midi/schubert/schumm-2.mid ...
[256/295] Parsing data/midi/schubert/schumm-3.mid ...
[257/295] Parsing data/midi/schubert/schumm-4.mid ...
[258/295] Parsing data/midi/schubert/schumm-5.mid ...
[259/295] Parsing data/midi/schubert/schumm-6.mid ...
[260/295] Parsing data/midi/schumann/schum_abegg.mid ...
[261/295] Parsing data/midi/schumann/scn15_1.mid ...




[262/295] Parsing data/midi/schumann/scn15_10.mid ...
[263/295] Parsing data/midi/schumann/scn15_11.mid ...




[264/295] Parsing data/midi/schumann/scn15_12.mid ...
[265/295] Parsing data/midi/schumann/scn15_13.mid ...
[266/295] Parsing data/midi/schumann/scn15_2.mid ...
[267/295] Parsing data/midi/schumann/scn15_3.mid ...
[268/295] Parsing data/midi/schumann/scn15_4.mid ...
[269/295] Parsing data/midi/schumann/scn15_5.mid ...




[270/295] Parsing data/midi/schumann/scn15_6.mid ...
[271/295] Parsing data/midi/schumann/scn15_7.mid ...
[272/295] Parsing data/midi/schumann/scn15_8.mid ...




[273/295] Parsing data/midi/schumann/scn15_9.mid ...
[274/295] Parsing data/midi/schumann/scn16_1.mid ...
[275/295] Parsing data/midi/schumann/scn16_2.mid ...
[276/295] Parsing data/midi/schumann/scn16_3.mid ...
[277/295] Parsing data/midi/schumann/scn16_4.mid ...
[278/295] Parsing data/midi/schumann/scn16_5.mid ...
[279/295] Parsing data/midi/schumann/scn16_6.mid ...
[280/295] Parsing data/midi/schumann/scn16_7.mid ...
[281/295] Parsing data/midi/schumann/scn16_8.mid ...
[282/295] Parsing data/midi/schumann/scn68_10.mid ...
[283/295] Parsing data/midi/schumann/scn68_12.mid ...




[284/295] Parsing data/midi/tschai/ty_april.mid ...
[285/295] Parsing data/midi/tschai/ty_august.mid ...
[286/295] Parsing data/midi/tschai/ty_dezember.mid ...
[287/295] Parsing data/midi/tschai/ty_februar.mid ...
[288/295] Parsing data/midi/tschai/ty_januar.mid ...
[289/295] Parsing data/midi/tschai/ty_juli.mid ...
[290/295] Parsing data/midi/tschai/ty_juni.mid ...
[291/295] Parsing data/midi/tschai/ty_maerz.mid ...
[292/295] Parsing data/midi/tschai/ty_mai.mid ...
[293/295] Parsing data/midi/tschai/ty_november.mid ...
[294/295] Parsing data/midi/tschai/ty_oktober.mid ...
[295/295] Parsing data/midi/tschai/ty_september.mid ...
Extracted 2800 note/chord events.
Vocabulary size (including PAD & UNK): 165
Encoded note sequence length: 2800
Created 2750 sequences of length 50.
Train sequences: 2475, Val sequences: 275

Training config: small_lstm
ModelConfig(name='small_lstm', sequence_length=50, batch_size=64, embedding_dim=48, lstm_units=128, num_lstm_layers=1, dropout=0.3, recurrent_dr

2025-11-19 21:08:29.475970: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M4
2025-11-19 21:08:29.476031: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2025-11-19 21:08:29.476039: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.92 GB
2025-11-19 21:08:29.476227: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2025-11-19 21:08:29.476239: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


Epoch 1/40


2025-11-19 21:08:30.474338: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.


39/39 - 158s - 4s/step - loss: 4.8255 - sparse_categorical_accuracy: 0.0513 - val_loss: 4.6009 - val_sparse_categorical_accuracy: 0.0473 - learning_rate: 0.0010
Epoch 2/40
39/39 - 153s - 4s/step - loss: 4.4199 - sparse_categorical_accuracy: 0.0667 - val_loss: 4.3849 - val_sparse_categorical_accuracy: 0.0800 - learning_rate: 0.0010
Epoch 3/40
39/39 - 154s - 4s/step - loss: 4.2157 - sparse_categorical_accuracy: 0.0933 - val_loss: 4.2564 - val_sparse_categorical_accuracy: 0.0873 - learning_rate: 0.0010
Epoch 4/40
39/39 - 153s - 4s/step - loss: 3.9922 - sparse_categorical_accuracy: 0.1204 - val_loss: 3.9761 - val_sparse_categorical_accuracy: 0.1127 - learning_rate: 0.0010
Epoch 5/40
39/39 - 152s - 4s/step - loss: 3.7538 - sparse_categorical_accuracy: 0.1430 - val_loss: 3.7935 - val_sparse_categorical_accuracy: 0.1236 - learning_rate: 0.0010
Epoch 6/40
39/39 - 154s - 4s/step - loss: 3.5487 - sparse_categorical_accuracy: 0.1499 - val_loss: 3.6685 - val_sparse_categorical_accuracy: 0.1345 - l

Epoch 1/50
39/39 - 159s - 4s/step - loss: 4.7758 - sparse_categorical_accuracy: 0.0618 - val_loss: 4.5915 - val_sparse_categorical_accuracy: 0.0473 - learning_rate: 0.0010
Epoch 2/50
39/39 - 158s - 4s/step - loss: 4.4498 - sparse_categorical_accuracy: 0.0784 - val_loss: 4.3762 - val_sparse_categorical_accuracy: 0.0691 - learning_rate: 0.0010
Epoch 3/50
39/39 - 155s - 4s/step - loss: 4.1201 - sparse_categorical_accuracy: 0.1160 - val_loss: 4.0279 - val_sparse_categorical_accuracy: 0.1018 - learning_rate: 0.0010
Epoch 4/50
39/39 - 153s - 4s/step - loss: 3.8312 - sparse_categorical_accuracy: 0.1354 - val_loss: 3.7670 - val_sparse_categorical_accuracy: 0.1382 - learning_rate: 0.0010
Epoch 5/50
39/39 - 153s - 4s/step - loss: 3.5357 - sparse_categorical_accuracy: 0.1612 - val_loss: 3.6607 - val_sparse_categorical_accuracy: 0.1309 - learning_rate: 0.0010
Epoch 6/50
39/39 - 153s - 4s/step - loss: 3.3143 - sparse_categorical_accuracy: 0.1628 - val_loss: 3.4116 - val_sparse_categorical_accuracy:

Epoch 1/60
39/39 - 289s - 7s/step - loss: 4.8895 - sparse_categorical_accuracy: 0.0687 - val_loss: 4.6855 - val_sparse_categorical_accuracy: 0.0473 - learning_rate: 5.0000e-04
Epoch 2/60
39/39 - 286s - 7s/step - loss: 4.5670 - sparse_categorical_accuracy: 0.0634 - val_loss: 4.6140 - val_sparse_categorical_accuracy: 0.0473 - learning_rate: 5.0000e-04
Epoch 3/60
39/39 - 285s - 7s/step - loss: 4.5294 - sparse_categorical_accuracy: 0.0590 - val_loss: 4.6172 - val_sparse_categorical_accuracy: 0.0473 - learning_rate: 5.0000e-04
Epoch 4/60
39/39 - 286s - 7s/step - loss: 4.5193 - sparse_categorical_accuracy: 0.0541 - val_loss: 4.6064 - val_sparse_categorical_accuracy: 0.0473 - learning_rate: 5.0000e-04
Epoch 5/60
39/39 - 286s - 7s/step - loss: 4.5104 - sparse_categorical_accuracy: 0.0578 - val_loss: 4.5875 - val_sparse_categorical_accuracy: 0.0473 - learning_rate: 5.0000e-04
Epoch 6/60
39/39 - 292s - 7s/step - loss: 4.5007 - sparse_categorical_accuracy: 0.0651 - val_loss: 4.5709 - val_sparse_c

In [10]:
import json
import pandas as pd

def load_metrics(path: str, name: str) -> dict:
    with open(path, "r") as f:
        d = json.load(f)

    history = d["history"]

    final_train_loss = history["loss"][-1]
    final_val_loss = history["val_loss"][-1]
    final_train_acc = history["sparse_categorical_accuracy"][-1]
    final_val_acc = history["val_sparse_categorical_accuracy"][-1]

    best_val_loss = min(history["val_loss"])
    best_epoch = history["val_loss"].index(best_val_loss) + 1
    val_acc_at_best = history["val_sparse_categorical_accuracy"][best_epoch - 1]

    return {
        "model": name,
        "final_train_loss": final_train_loss,
        "final_val_loss": final_val_loss,
        "final_train_acc": final_train_acc,
        "final_val_acc": final_val_acc,
        "best_val_loss": best_val_loss,
        "best_epoch": best_epoch,
        "val_acc_at_best_epoch": val_acc_at_best,
    }

def main():
    rows = []
    path = '/Users/tanmayswami/Downloads/Northwestern/MSDS 458/Final Project/outputs/music_lstm/'
    rows.append(load_metrics(f"{path}metrics_small_lstm.json", "small_lstm"))
    rows.append(load_metrics(f"{path}metrics_medium_lstm.json", "medium_lstm"))
    rows.append(load_metrics(f"{path}metrics_two_layer_lstm.json", "two_layer_lstm"))

    df = pd.DataFrame(rows)
    df = df.sort_values("best_val_loss").reset_index(drop=True)
    # print(df.to_string(index=False, float_format=lambda x: f"{x:.4f}"))

    return df

In [12]:
df = main()

In [14]:
df

Unnamed: 0,model,final_train_loss,final_val_loss,final_train_acc,final_val_acc,best_val_loss,best_epoch,val_acc_at_best_epoch
0,medium_lstm,0.368674,1.385466,0.933737,0.738182,1.376033,47,0.741818
1,small_lstm,1.010676,1.71888,0.717576,0.541818,1.71888,40,0.541818
2,two_layer_lstm,1.41949,1.992827,0.539394,0.410909,1.992827,60,0.410909


In [13]:
df.to_json()

'{"model":{"0":"medium_lstm","1":"small_lstm","2":"two_layer_lstm"},"final_train_loss":{"0":0.3686742783,"1":1.010676384,"2":1.4194897413},"final_val_loss":{"0":1.3854655027,"1":1.7188800573,"2":1.9928267002},"final_train_acc":{"0":0.9337373972,"1":0.7175757289,"2":0.5393939614},"final_val_acc":{"0":0.7381818295,"1":0.5418182015,"2":0.4109090865},"best_val_loss":{"0":1.3760325909,"1":1.7188800573,"2":1.9928267002},"best_epoch":{"0":47,"1":40,"2":60},"val_acc_at_best_epoch":{"0":0.7418181896,"1":0.5418182015,"2":0.4109090865}}'