In [1]:
import sys
sys.path.append('../')

In [2]:
import os, pandas as pd
from sklearn.preprocessing import MinMaxScaler
import tensorflow as tf
from tensorflow.keras.layers import Input, Embedding, LSTM, Dense, Concatenate, TimeDistributed, Dropout
from tensorflow.keras.models import Model
import glob
import numpy as np
from tqdm import tqdm
import config

2025-09-09 21:59:30.641136: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# Switch for GPU usage: leave False to force CPU execution, set True to keep
# whatever GPUs TensorFlow detects.
use_gpu = False

if not use_gpu:
    # Passing an empty device list for the 'GPU' device type makes no GPU
    # visible to TensorFlow for the rest of this kernel session.
    tf.config.set_visible_devices([], 'GPU')

W0000 00:00:1757442574.645129    9382 gpu_device.cc:2431] TensorFlow was not built with CUDA kernel binaries compatible with compute capability 5.0. CUDA kernels will be jit-compiled from PTX, which could take 30 minutes or longer.


In [4]:
# glob returns matches in a filesystem-dependent order. Sorting makes the song
# order identical on every run, so the seeded shuffle applied to the dataset
# later actually yields a reproducible sample order.
file_paths = sorted(glob.glob(os.path.join(config.MidiFiles.preprocessed_csv_files, "*.csv")))

In [5]:
# Event vocabulary; the position of each name in the list is its integer class id.
EVENT_TYPES = ['Control_c', 'Note_on_c', 'Program_c', 'Pitch_bend_c']
num_event_classes = len(EVENT_TYPES)
event2id = dict(zip(EVENT_TYPES, range(num_event_classes)))

In [6]:
def preprocess_file(df, event_types=None):
  """Turn one raw event table into the numeric feature table used for training.

  Steps, in order:
    1. ``delta_time`` is replaced by ``log1p(delta_time)``.
    2. ``event`` is one-hot encoded into exactly the columns listed in
       ``event_types`` (missing types become all-False columns; event names
       outside the list are dropped, leaving such rows all-False).
    3. Remaining NaNs are replaced by 0.
    4. ``arg1``..``arg3`` are clipped to [0, 127] and divided by 127.

  The input frame is not modified; a new DataFrame is returned. All columns
  other than ``event`` keep their original relative order and are followed by
  the one-hot columns in ``event_types`` order.

  Args:
    df: DataFrame with at least the columns ``delta_time``, ``event``,
      ``arg1``, ``arg2`` and ``arg3``.
    event_types: optional sequence of event names to encode. Defaults to
      ``['Control_c', 'Note_on_c', 'Program_c', 'Pitch_bend_c']``.

  Returns:
    A new DataFrame with the transformed features.
  """
  if event_types is None:
    event_types = ['Control_c', 'Note_on_c', 'Program_c', 'Pitch_bend_c']
  else:
    event_types = list(event_types)

  # Work on a copy: assigning into the caller's frame would apply log1p a
  # second time if the same frame were ever preprocessed again.
  df = df.copy()
  df['delta_time'] = np.log1p(df['delta_time'])

  # reindex both adds the event columns absent from this file (filled with
  # False) and fixes the column order, so every file yields the same layout.
  one_hot = pd.get_dummies(df['event']).reindex(columns=event_types, fill_value=False)

  df = pd.concat([df.drop(columns=['event']), one_hot], axis=1)
  df = df.fillna(0)

  arg_cols = ["arg1", "arg2", "arg3"]
  df[arg_cols] = df[arg_cols].clip(0, 127) / 127.0
  return df

In [7]:
# Read and preprocess every CSV; a file that fails at any stage is reported
# and skipped so one bad file does not abort the whole load.
songs = []
for p in tqdm(file_paths):
    try:
        df = pd.read_csv(p)
        df = preprocess_file(df)
    except Exception as e:
        print("Skipping", p, e)
    else:
        # Only reached when both reading and preprocessing succeeded.
        songs.append(df)

  0%|          | 0/1068 [00:00<?, ?it/s]

100%|██████████| 1068/1068 [00:13<00:00, 78.03it/s]


In [8]:
# Number of consecutive events in one input window.
seq_len = 32
# Derived from the event vocabulary rather than hard-coded, so the model heads
# and the dataset signature cannot drift from EVENT_TYPES.
num_event_classes = len(EVENT_TYPES)
# delta_time + arg1..arg3, followed by one one-hot column per event type.
num_features = 4 + num_event_classes
# Not referenced by any of the cells visible in this notebook.
embedding_dim_event = 16

batch_size = 128
epochs = 30

In [9]:
def sequence_generator(songs, seq_len):
    """Endlessly yield (window, targets) training pairs from a list of songs.

    For every song and every start index ``i`` with ``i + seq_len < len(song)``,
    the window is rows ``i .. i+seq_len-1`` and the target is row ``i+seq_len``.
    After the last song the generator starts over from the first one, so it
    never terminates on its own when at least one window exists.

    The target row is split by column position:
      * column 0      -> ``'out_delta'``
      * columns 1..3  -> ``'out_arg1'``, ``'out_arg2'``, ``'out_arg3'``
      * columns 4..   -> ``'out_event'`` (the one-hot event columns)
    This assumes the frames are laid out as
    delta_time, arg1, arg2, arg3, <event one-hots>.

    Args:
        songs: sequence of DataFrames whose values are castable to float32.
        seq_len: number of rows in each input window.

    Yields:
        Tuple ``(X, y)`` where ``X`` is a float32 array of shape
        ``(seq_len, n_columns)`` and ``y`` is a dict of float32 arrays keyed by
        output head name.

    Raises:
        ValueError: if a full pass over ``songs`` produces no window at all
            (empty list, or every song has at most ``seq_len`` rows). Without
            this check the generator would spin forever without yielding.
    """
    while True:
        produced_any = False
        for df in songs:
            data = df.values.astype(np.float32)

            for i in range(len(data) - seq_len):
                X = data[i:i + seq_len]
                y = data[i + seq_len]
                produced_any = True

                yield X, {'out_event': y[4:],
                          'out_arg1': y[1:2],
                          'out_arg2': y[2:3],
                          'out_arg3': y[3:4],
                          'out_delta': y[0:1]}

        if not produced_any:
            raise ValueError(
                "sequence_generator: no song has more than seq_len="
                f"{seq_len} rows, so no training window can be produced"
            )

In [10]:
# Per-sample signature: one (seq_len, num_features) window plus one target
# tensor per model output head.
window_spec = tf.TensorSpec(shape=(seq_len, num_features), dtype=tf.float32)

target_specs = {
    'out_event': tf.TensorSpec(shape=(num_event_classes,), dtype=tf.float32),
}
target_specs.update({
    head: tf.TensorSpec(shape=(1,), dtype=tf.float32)
    for head in ('out_arg1', 'out_arg2', 'out_arg3', 'out_delta')
})

dataset = tf.data.Dataset.from_generator(
    lambda: sequence_generator(songs, seq_len),
    output_signature=(window_spec, target_specs),
)

# Shuffle within a 10000-sample buffer, batch, then overlap input preparation
# with training.
dataset = (
    dataset
    .shuffle(10000, seed=42)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

In [11]:
# Shared trunk: one LSTM over the input window, keeping only its final output,
# followed by dropout.
inputs = Input(shape=(seq_len, num_features))

x = LSTM(256, return_sequences=False)(inputs)
x = Dropout(0.2)(x)

# Output heads. The event head is a softmax over the event classes; the three
# argument heads are sigmoid-bounded; the delta head is unbounded (linear).
out_event = Dense(num_event_classes, activation='softmax', name='out_event')(x)
out_arg1 = Dense(1, activation='sigmoid', name='out_arg1')(x)
out_arg2 = Dense(1, activation='sigmoid', name='out_arg2')(x)
out_arg3 = Dense(1, activation='sigmoid', name='out_arg3')(x)
out_delta = Dense(1, activation='linear', name='out_delta')(x)

model = Model(
    inputs=inputs,
    outputs=[out_event, out_arg1, out_arg2, out_arg3, out_delta],
)

# Every head except the event classifier is trained with MSE and a weight of
# 0.1; the event head uses categorical cross-entropy with weight 1.0.
regression_heads = ('out_arg1', 'out_arg2', 'out_arg3', 'out_delta')

head_losses = {'out_event': 'categorical_crossentropy'}
head_weights = {'out_event': 1.0}
for head in regression_heads:
    head_losses[head] = 'mse'
    head_weights[head] = 0.1

model.compile(optimizer='adam', loss=head_losses, loss_weights=head_weights)

model.summary()

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.models import load_model
import re

# The generator-backed dataset is infinite, so an epoch is defined by a fixed
# number of batches.
steps_per_epoch = 61_000

output_path = config.MidiFiles.weights_path
# os.listdir below would raise on a first run if the directory did not exist.
os.makedirs(output_path, exist_ok=True)

# {epoch} and {loss} are filled in by ModelCheckpoint when it saves.
checkpoint_filepath = os.path.join(output_path, 'lstm-{epoch:02d}-{loss:.4f}.keras')

# --- Resume from the most recent checkpoint, if any -------------------------
# Collect (epoch_number, filename) for every checkpoint whose name carries an
# epoch number. Files that end in .keras but do not match are ignored instead
# of crashing on `None.group(1)`.
checkpoint_name_re = re.compile(r"lstm-(\d+)-")
saved_checkpoints = []
for filename in os.listdir(output_path):
    if not filename.endswith(".keras"):
        continue
    name_match = checkpoint_name_re.search(filename)
    if name_match:
        saved_checkpoints.append((int(name_match.group(1)), filename))

last_epoch = 0
if saved_checkpoints:
    # NOTE: the parsed epoch numbers are kept in their own variable. Rebinding
    # the hyperparameter `epochs` here would hand a list to model.fit(epochs=...).
    last_epoch, last_checkpoint = max(saved_checkpoints, key=lambda item: item[0])
    last_checkpoint_path = os.path.join(output_path, last_checkpoint)

    print(f"Resuming from checkpoint: {last_checkpoint_path}")

    model = load_model(last_checkpoint_path)

checkpoint_callback = ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=False,
    save_best_only=False, # Set to True to save only the best model based on a monitored metric
    monitor='loss', # Metric to monitor if save_best_only is True
    mode='min',     # Mode for the monitor metric ('min' for loss, 'max' for accuracy)
    save_freq='epoch' # Save after each epoch
)

early_stopping_callback = EarlyStopping(
    monitor='loss',
    patience=5,
    restore_best_weights=True
)

# initial_epoch continues the epoch counter after the checkpoint that was
# loaded (assumes the number in the filename counts completed epochs —
# TODO confirm against the installed Keras version).
history = model.fit(
    dataset,
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    callbacks=[checkpoint_callback, early_stopping_callback],
    initial_epoch=last_epoch
)

Epoch 1/30


2025-09-09 22:00:00.321626: I external/local_xla/xla/service/service.cc:163] XLA service 0x61416fd60f50 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2025-09-09 22:00:00.321647: I external/local_xla/xla/service/service.cc:171]   StreamExecutor device (0): Host, Default Version
2025-09-09 22:00:00.419023: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.


[1m    1/61000[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m168:59:45[0m 10s/step - loss: 2.2367 - out_arg1_loss: 0.2569 - out_arg2_loss: 0.0097 - out_arg3_loss: 0.1621 - out_delta_loss: 7.3975 - out_event_loss: 1.4541

I0000 00:00:1757442601.953440    9562 device_compiler.h:196] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m  523/61000[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m3:29:08[0m 207ms/step - loss: 0.4361 - out_arg1_loss: 0.0269 - out_arg2_loss: 0.0141 - out_arg3_loss: 0.1297 - out_delta_loss: 2.8020 - out_event_loss: 0.1388

KeyboardInterrupt: 

In [None]:
import pickle

history_filepath = f"{output_path}/lstm-history.pkl"

# Store only the plain metrics dict held by the History object returned by
# model.fit.
with open(history_filepath, "wb") as history_file:
    history_file.write(pickle.dumps(history.history))