In [None]:
# Install what the rest of the notebook imports or shells out to:
# mido (imported below for MIDI messages/files), the FluidSynth system package,
# and midi2audio (its FluidSynth wrapper is used below to render MIDI to audio).
# NOTE(review): versions are not pinned, so a fresh run may pull newer releases.
!pip install mido
!apt install fluidsynth
!pip install midi2audio

Collecting mido
  Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)
Downloading mido-1.3.3-py3-none-any.whl (54 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/54.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.6/54.6 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: mido
Successfully installed mido-1.3.3
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  fluid-soundfont-gm libevdev2 libfluidsynth3 libgudev-1.0-0 libinput-bin
  libinput10 libinstpatch-1.0-2 libmd4c0 libmtdev1 libqt5core5a libqt5dbus5
  libqt5gui5 libqt5network5 libqt5svg5 libqt5widgets5 libwacom-bin
  libwacom-common libwacom9 libxcb-icccm4 libxcb-image0 libxcb-keysyms1
  libxcb-render-util0 libxcb-util1 libxcb-xinerama0 libxcb-xinput0 libxcb-xkb1
  libxkbcommon-x11-0 qsynth qt5-gtk-platformtheme qt

In [None]:
import pickle
import os
import random
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, Embedding, Dropout, LayerNormalization, Dense
from tensorflow.keras.layers import MultiHeadAttention
from tensorflow.keras.models import Model

import mido
from mido import Message, MidiFile, MidiTrack, MetaMessage

from midi2audio import FluidSynth
from IPython.display import Audio

# Shared synthesizer handle, configured with the soundfont file at this path;
# generate() calls fs.midi_to_audio() to turn the written .mid files into .wav.
fs = FluidSynth("/usr/share/sounds/sf2/FluidR3_GM.sf2")

import warnings
# Suppress every Python warning for the remainder of the session.
# NOTE(review): this also hides deprecation warnings from pandas/TensorFlow.
warnings.filterwarnings('ignore')

In [None]:
# Directory locations, all relative to the notebook's working directory.
data_path = "./data"                # x_test.npy is read from here
model_path = "./model"              # tokenizer.pkl and model_10epochs.keras are read from here
generations_path = "./generations"  # generate() writes one sub-folder per sample here

In [None]:
# Pitch classes in semitone order starting from C; only sharps are spelled out.
NOTE_NAMES = ['C', 'C#', 'D', 'D#', 'E', 'F',
              'F#', 'G', 'G#', 'A', 'A#', 'B']

def note_to_int(note_str):
    """Convert a note name such as ``"C4"`` or ``"F#3"`` to a note number.

    The string is a pitch letter, an optional ``#``, then an integer octave.
    The result is ``index_in_NOTE_NAMES + 12 * (octave + 1)``, so ``"C4"``
    maps to 60.

    Args:
        note_str: Note name; must be at least two characters long.

    Returns:
        The integer note number.

    Raises:
        ValueError: If the string is shorter than two characters, the octave
            part is not an integer, or the pitch is not in ``NOTE_NAMES``.
    """
    if len(note_str) < 2:
        raise ValueError(f"Invalid note string: {note_str}")

    # A sharp sign in second position makes the pitch name two characters wide.
    name_width = 2 if note_str[1] == '#' else 1
    pitch_name = note_str[:name_width]
    octave_number = int(note_str[name_width:])

    semitone = NOTE_NAMES.index(pitch_name)
    return semitone + (octave_number + 1) * 12

def duration_from_token(token):
    """Look up the length, in beats, of a named note value.

    Args:
        token: Name of the note value, e.g. ``"Eighth"`` or ``"DottedHalf"``.

    Returns:
        The length in beats as a float (a quarter note is 1.0). Names that are
        not in the table fall back to 0.25.
    """
    known_durations = dict([
        ("SixtyFourth", 0.0625),
        ("TripletSixtyFourth", 0.0417),
        ("ThirtySecond", 0.125),
        ("TripletThirtySecond", 0.0833),
        ("Sixteenth", 0.25),
        ("TripletSixteenth", 0.1667),
        ("DottedSixteenth", 0.375),
        ("Eighth", 0.5),
        ("Triplet", 0.3333),
        ("DottedEighth", 0.75),
        ("Quarter", 1.0),
        ("DottedQuarter", 1.5),
        ("TiedQuarter-Sixteenth", 1.25),
        ("TiedQuarter-ThirtySecond", 1.125),
        ("Half", 2.0),
        ("DottedHalf", 3.0),
        ("Whole", 4.0),
        ("Unknown", 4.0),
    ])

    if token in known_durations:
        return known_durations[token]
    # Unrecognised names are treated as a sixteenth note.
    return 0.25

def seconds_to_ticks(seconds, tempo, ticks_per_beat):
    """Convert a duration in seconds to a whole number of ticks.

    Args:
        seconds: Duration in seconds.
        tempo: Tempo in microseconds per beat.
        ticks_per_beat: Number of ticks in one beat.

    Returns:
        The tick count, truncated toward zero by ``int()``.
    """
    beats_per_second = 1_000_000 / tempo
    return int(seconds * beats_per_second * ticks_per_beat)

In [None]:
def append_row(abs_start, abs_end, duration, value, notes, hand, event, shift, df):
    """Return a copy of ``df`` with one extra event row appended.

    The positional arguments fill the columns ``abs_start``, ``abs_end``,
    ``duration``, ``value``, ``note(s)``, ``hand``, ``event`` and ``shift``,
    in that order.

    Args:
        df: Frame to extend; it is not modified.

    Returns:
        A new DataFrame (index renumbered from 0) ending with the new row.
    """
    column_names = ('abs_start', 'abs_end', 'duration', 'value',
                    'note(s)', 'hand', 'event', 'shift')
    cell_values = (abs_start, abs_end, duration, value,
                   notes, hand, event, shift)

    single_row_frame = pd.DataFrame([dict(zip(column_names, cell_values))])
    return pd.concat([df, single_row_frame], ignore_index=True)

def tokens_to_dataframe(token_sequence, ticks_per_beat=480):
    """Turn a flat token sequence into a table of per-hand note events.

    Recognised tokens:
      * ``[HAND_LEFT]`` / ``[HAND_RIGHT]`` select the hand for following notes.
      * ``TIME_SHIFT_<beats>`` advances the running clock by that many beats.
      * ``NOTE_<name>`` starts a note; its length comes from the first
        ``VALUE_<name>`` token found within the next 10 tokens (the note
        itself included). Without one the note lasts one beat.
      * ``[CHORD_START]`` ... ``[CHORD_END]`` buffer the enclosed notes and
        emit them together when the chord closes.
    Any other token is ignored.

    Each note yields a ``note_on`` row at its start tick and a ``note_off``
    row at its end tick. Rows are sorted by ``abs_start`` and ``shift`` holds
    the tick delta to the previous event of the same hand. The first event of
    a hand gets its absolute start tick, so both hands stay aligned on a
    common clock starting at tick 0.

    Args:
        token_sequence: List of token strings.
        ticks_per_beat: Tick resolution used to convert beats to ticks.

    Returns:
        A DataFrame with columns ``abs_start``, ``abs_end``, ``duration``,
        ``value``, ``note(s)``, ``hand``, ``event`` and ``shift``. It is
        empty when the sequence contains no usable notes.
    """
    columns = ["abs_start", "abs_end", "duration", "value", "note(s)", "hand", "event", "shift"]
    value_prefix = "VALUE_"
    lookahead = 10            # how many tokens after a NOTE_ may carry its VALUE_
    default_duration = 1.0    # beats, used when no VALUE_ token is found

    def beats_to_ticks(beats):
        return int(beats * ticks_per_beat)

    # Rows are collected in a plain list and converted once at the end;
    # concatenating a DataFrame per row is quadratic in the number of events.
    records = []

    def emit_note(start_tick, duration_beats, value_name, pitch, hand):
        """Append the note_on / note_off pair for one note."""
        length = beats_to_ticks(duration_beats)
        end_tick = start_tick + length
        records.append({"abs_start": start_tick, "abs_end": end_tick, "duration": length,
                        "value": value_name, "note(s)": pitch, "hand": hand,
                        "event": "note_on", "shift": 0})
        records.append({"abs_start": end_tick, "abs_end": end_tick, "duration": 0,
                        "value": value_name, "note(s)": pitch, "hand": hand,
                        "event": "note_off", "shift": 0})

    def find_duration(note_index):
        """Return (value name, beats) for the note at ``note_index``."""
        window_end = min(note_index + lookahead, len(token_sequence))
        for j in range(note_index, window_end):
            candidate = token_sequence[j]
            if candidate.startswith(value_prefix):
                # The value name is the suffix of the VALUE_ token itself
                # (e.g. "VALUE_Eighth" -> "Eighth"); the following token
                # belongs to the next event.
                name = candidate[len(value_prefix):]
                if not name and j + 1 < len(token_sequence):
                    # Bare "VALUE_" token: fall back to the next token, as a
                    # split "VALUE_ <name>" encoding would require.
                    name = token_sequence[j + 1]
                return name, duration_from_token(name)
        return "", default_duration

    chord_mode = False
    chord_notes = []
    current_tick = 0
    current_hand = "left"

    # Every token is visited, including the last one, so a sequence ending in
    # [CHORD_END] still flushes its chord.
    for index, token in enumerate(token_sequence):
        if token == "[HAND_LEFT]":
            current_hand = "left"

        elif token == "[HAND_RIGHT]":
            current_hand = "right"

        elif token.startswith("TIME_SHIFT_"):
            try:
                current_tick += beats_to_ticks(float(token.split("_")[-1]))
            except Exception as e:
                # Best effort: a malformed shift is reported and skipped.
                print("TIME_SHIFT error:", e)

        elif token == "[CHORD_START]":
            chord_mode = True
            chord_notes = []

        elif token == "[CHORD_END]":
            for pitch, duration, value_name in chord_notes:
                emit_note(current_tick, duration, value_name, pitch, current_hand)
            chord_mode = False
            chord_notes = []

        elif token.startswith("NOTE_"):
            try:
                pitch = note_to_int(token.split("_")[1])
            except Exception as e:
                # Drop only the bad note. Skipping further ahead would swallow
                # the TIME_SHIFT / hand tokens of the next event and shift
                # every later note.
                print("NOTE conversion error:", e)
                continue

            value_name, duration = find_duration(index)

            if chord_mode:
                chord_notes.append((pitch, duration, value_name))
            else:
                emit_note(current_tick, duration, value_name, pitch, current_hand)

    df = pd.DataFrame(records, columns=columns)
    if df.empty:
        return df

    # Stable sort: events sharing a tick keep emission order, so the note_off
    # of an earlier note stays ahead of a later note_on at the same tick.
    df = df.sort_values(by="abs_start", kind="stable")

    for hand in ("left", "right"):
        mask = df["hand"] == hand
        if not mask.any():
            continue
        starts = df.loc[mask, "abs_start"]
        # Delta to the previous event of this hand; the first event's delta is
        # measured from tick 0 rather than being forced to 0.
        df.loc[mask, "shift"] = starts - starts.shift(1, fill_value=0)

    return df

In [None]:
def csv_to_midi(df, midi_path, ticks_per_beat=480):
    """Write an event table to a two-track MIDI file (right hand, then left).

    Each row becomes one message on the track of its hand. ``shift`` is used
    as the message's delta time in ticks, so it must be the tick distance to
    the previous event of the same hand.

    Args:
        df: Event table with at least the columns ``note(s)``, ``event``,
            ``hand`` and ``shift``. Rows are written in the order given.
        midi_path: Destination path of the ``.mid`` file.
        ticks_per_beat: Tick resolution stored in the file header. It should
            equal the resolution used to compute ``shift``; otherwise the
            file plays back faster or slower than intended. Defaults to 480,
            mido's own default.

    Notes:
        No tempo message is written. Rows whose hand is neither ``'right'``
        nor ``'left'`` are dropped. Every message uses velocity 80.
    """
    mid = MidiFile(ticks_per_beat=ticks_per_beat)
    tracks = {'right': MidiTrack(), 'left': MidiTrack()}
    # Track order in the file: right hand first, then left hand.
    mid.tracks.append(tracks['right'])
    mid.tracks.append(tracks['left'])

    for raw_note, event, hand, raw_shift in zip(df['note(s)'], df['event'], df['hand'], df['shift']):
        note = int(raw_note)
        shift = int(raw_shift)

        # Anything that is not an explicit note_on is written as note_off.
        msg_type = 'note_on' if event == 'note_on' else 'note_off'
        msg = Message(msg_type, note=note, velocity=80, time=shift)

        track = tracks.get(hand)
        if track is not None:
            track.append(msg)

    mid.save(midi_path)

In [None]:
def sample_next_token(probs, temperature=1.0):
    """Draw a token index from a probability vector, reshaped by temperature.

    The probabilities are floored at 1e-8, converted to log space, divided by
    ``temperature`` and re-normalised with a softmax before sampling. Values
    below 1 sharpen the distribution, values above 1 flatten it.

    Args:
        probs: 1-D sequence of probabilities, one per token id.
        temperature: Sampling temperature. ``0`` selects the most probable
            token deterministically instead of dividing by zero.

    Returns:
        The sampled index as a NumPy integer.
    """
    probs = np.asarray(probs).astype("float64")

    if temperature == 0:
        # Limit of the softmax as temperature -> 0: greedy decoding.
        return np.argmax(probs)

    logits = np.log(np.maximum(probs, 1e-8)) / temperature
    # Subtracting the maximum leaves the softmax unchanged but keeps exp() from
    # underflowing to all zeros (NaN probabilities) at small temperatures.
    logits = logits - np.max(logits)
    weights = np.exp(logits)
    softmax = weights / np.sum(weights)
    return np.random.choice(len(softmax), p=softmax)

def generate_tokens(model, seed_sequence, tokenizer, num_tokens=50, max_seq_length=256, temperature=1.0):
    """Autoregressively extend a seed with sampled token ids.

    At each step the last ``max_seq_length`` ids are sent to ``model.predict``
    as a batch of one, the first row of the prediction is treated as the
    next-token distribution, and one id is sampled from it.

    Args:
        model: Object with ``predict(batch, verbose=0)``.
        seed_sequence: Starting token ids. Any sliceable sequence works (list
            or NumPy array); it is copied and never modified.
        tokenizer: Unused; kept for call compatibility.
        num_tokens: Number of ids to generate.
        max_seq_length: Context window; also the number of seed ids kept.
        temperature: Passed through to ``sample_next_token``.

    Returns:
        A new list holding the (truncated) seed followed by the generated ids.
    """
    # list() copies the seed and gives array inputs an .append method.
    sequence = list(seed_sequence[:max_seq_length])

    for _ in range(num_tokens):
        # NOTE(review): no padding is applied, so a seed shorter than the
        # model's input length is passed through as-is — confirm the model
        # accepts that.
        window = np.array(sequence[-max_seq_length:]).reshape(1, -1)
        preds = model.predict(window, verbose=0)

        next_token_id = sample_next_token(preds[0], temperature)
        sequence.append(int(next_token_id))

    return sequence

In [None]:
class MusicTokenizer:
    """Whitespace tokenizer mapping token strings to integer ids and back.

    Ids are assigned in order of first appearance, starting at 0.
    """

    def __init__(self):
        # Forward and reverse vocabularies, kept in sync by build_vocab().
        self.token_to_id = {}
        self.id_to_token = {}

    def build_vocab(self, sequences):
        """Add every unseen token in ``sequences`` (strings) to the vocabulary."""
        for text in sequences:
            for word in text.strip().split():
                if word in self.token_to_id:
                    continue
                new_id = len(self.token_to_id)
                self.token_to_id[word] = new_id
                self.id_to_token[new_id] = word

    def encode(self, sequences):
        """Map each string to a list of ids; raises KeyError on unknown tokens."""
        encoded = []
        for text in sequences:
            words = text.strip().split()
            encoded.append([self.token_to_id[word] for word in words])
        return encoded

    def decode(self, id_sequences):
        """Map each list of ids to one space-joined string."""
        decoded = []
        for ids in id_sequences:
            words = [self.id_to_token[token_id] for token_id in ids]
            decoded.append(" ".join(words))
        return decoded

    def decode2(self, id_sequences):
        """Decode ONE flat list of ids; returns a single-element list of text."""
        words = [self.id_to_token[token_id] for token_id in id_sequences]
        return [" ".join(words)]

In [None]:
class PositionalEmbeddingAdder(tf.keras.layers.Layer):
    """Adds a learned position embedding to a sequence of token embeddings.

    Args:
        max_seq_length: Number of distinct positions the embedding table holds.
        d_model: Width of each position vector; must match the last dimension
            of the input so the two can be added.
        **kwargs: Forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, max_seq_length, d_model, **kwargs):
        super().__init__(**kwargs)
        self.max_seq_length = max_seq_length
        self.d_model = d_model
        self.position_embeddings = Embedding(max_seq_length, d_model)

    def call(self, x):
        """Return ``x`` plus the embeddings of positions ``0 .. x.shape[1] - 1``."""
        # Positions are derived from the runtime length of axis 1.
        positions = tf.range(start=0, limit=tf.shape(x)[1], delta=1)
        pos_embeds = self.position_embeddings(positions)
        # pos_embeds has no batch axis; the addition broadcasts it over axis 0.
        return x + pos_embeds

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created.

        Needed for saving and reloading models that contain this layer,
        because ``__init__`` takes required arguments beyond the base ones.
        """
        config = super().get_config()
        config.update({
            "max_seq_length": self.max_seq_length,
            "d_model": self.d_model,
        })
        return config

class LastToken(tf.keras.layers.Layer):
    """Keeps only the final step along axis 1 of a rank-3 input."""

    def call(self, x):
        """Return ``x[:, -1, :]``, dropping axis 1 from the result."""
        return x[:, -1, :]

def transformer_model(input_vocab_size, output_vocab_size, max_seq_length, d_model=128, num_heads=4, num_layers=2, dropout_rate=0.25):
    """Build a next-token prediction model from stacked self-attention blocks.

    Args:
        input_vocab_size: Size of the token embedding table.
        output_vocab_size: Number of classes in the final softmax.
        max_seq_length: Fixed length of the integer input sequence; also the
            size of the position table and of the attention mask.
        d_model: Embedding width used throughout the blocks.
        num_heads: Number of attention heads per block.
        num_layers: Number of attention + feed-forward blocks.
        dropout_rate: Rate for the dropout after attention and after the
            feed-forward layers.

    Returns:
        A Keras ``Model`` taking int32 ids of shape ``(max_seq_length,)`` per
        sample and producing one softmax vector of size ``output_vocab_size``.
    """
    inputs = Input(shape=(max_seq_length,), dtype=tf.int32)

    # Token embedding: one d_model-wide vector per input id.
    token_embedding = Embedding(input_vocab_size, d_model)(inputs)

    # Add the learned positional embedding to the token embedding.
    outputs = PositionalEmbeddingAdder(max_seq_length, d_model)(token_embedding)

    # Transformer blocks. Layers are created in this exact order on every
    # iteration; NOTE(review): reordering would presumably change the
    # auto-generated layer names and weight order — confirm before refactoring.
    for _ in range(num_layers):
        # Self-attention (query and value are both `outputs`). The mask is
        # lower-triangular, diagonal included: band_part(..., -1, 0) keeps
        # everything on and below the diagonal. Each head uses key_dim=d_model.
        attention_output = MultiHeadAttention(num_heads=num_heads, key_dim=d_model)(
            outputs, outputs,
            attention_mask=tf.linalg.band_part(tf.ones((max_seq_length, max_seq_length)), -1, 0)
        )
        attention_output = Dropout(dropout_rate)(attention_output)
        # Residual connection, then layer normalisation.
        attention_output = LayerNormalization(epsilon=1e-7)(outputs + attention_output)

        # Feed-forward: expand to 4 * d_model, project back to d_model.
        # Both Dense layers use a GELU activation.
        ffn_output = Dense(d_model * 4, activation='gelu')(attention_output)
        ffn_output = Dense(d_model, activation='gelu')(ffn_output)
        ffn_output = Dropout(dropout_rate)(ffn_output)

        # Second residual connection and normalisation close the block.
        outputs = LayerNormalization(epsilon=1e-7)(attention_output + ffn_output)

    # Only keep the last token's output to predict the next token
    outputs = LastToken()(outputs)

    # Final prediction layer: softmax over the output vocabulary.
    outputs = Dense(output_vocab_size, activation='softmax')(outputs)

    return Model(inputs=inputs, outputs=outputs)

In [None]:
# Restore the fitted tokenizer.
# SECURITY: pickle.load can execute arbitrary code; only load a tokenizer.pkl
# that comes from a trusted source.
# NOTE(review): unpickling presumably needs the MusicTokenizer class defined
# earlier in this notebook — confirm the cell order on a fresh kernel.
with open(f"{model_path}/tokenizer.pkl", "rb") as f:
    tokenizer = pickle.load(f)

# Test inputs; the length of the first row is used as the model's context length.
x_test = np.load(f'{data_path}/x_test.npy')
max_seq_length = len(x_test[0])

# This import repeats the one in the imports cell at the top of the notebook.
from tensorflow import keras
# The custom layers must be passed explicitly so Keras can rebuild them on load.
model = keras.models.load_model(f"{model_path}/model_10epochs.keras", custom_objects={
    "PositionalEmbeddingAdder": PositionalEmbeddingAdder,
    "LastToken": LastToken
})

In [None]:
def generate(max_seq_length, sample_idx, temperature=0.7, num_tokens=500, ticks_per_beat=360):
    """Generate a continuation for one test sample and render seed and result.

    Uses the notebook globals ``x_test``, ``model``, ``tokenizer``, ``fs`` and
    ``generations_path``. Under ``<generations_path>/generation_<sample_idx>/``
    it writes ``seed.mid``, ``generated.mid`` and the matching ``.wav`` files,
    then prints the generated token sequence (seed included).

    Args:
        max_seq_length: Context window passed to ``generate_tokens``.
        sample_idx: Row of ``x_test`` used as the seed.
        temperature: Sampling temperature.
        num_tokens: Number of tokens to generate after the seed.
        ticks_per_beat: Tick resolution used when converting tokens to events.
    """
    # NOTE(review): csv_to_midi creates its MidiFile with mido's default of
    # 480 ticks per beat, while events here are computed at `ticks_per_beat`
    # (default 360), so playback is presumably faster than notated — confirm
    # whether that is intended.
    seed_sequence = list(x_test[sample_idx])

    generated_token_ids = generate_tokens(
        model,
        seed_sequence,
        tokenizer,
        num_tokens=num_tokens,
        max_seq_length=max_seq_length,
        temperature=temperature,
    )

    # Reverse vocabulary built from the forward mapping; unknown ids become "<UNK>".
    id_to_token = {}
    for token_text, token_id in tokenizer.token_to_id.items():
        id_to_token[token_id] = token_text
    generated_tokens = [id_to_token.get(tok_id, "<UNK>") for tok_id in generated_token_ids]
    generated_sequence = " ".join(generated_tokens)

    folder_name = f"generation_{sample_idx}"
    os.makedirs(os.path.join(generations_path, folder_name), exist_ok=True)
    output_prefix = f'{generations_path}/{folder_name}'

    # Seed: decode ids -> tokens -> events -> MIDI.
    seed_tokens = tokenizer.decode2(seed_sequence)[0].split(" ")
    df_seed = tokens_to_dataframe(seed_tokens, ticks_per_beat=ticks_per_beat)
    seed_path = f'{output_prefix}/seed.mid'
    csv_to_midi(df_seed, seed_path)

    # Full sequence (seed + continuation) -> events -> MIDI.
    df_gen = tokens_to_dataframe(generated_sequence.split(" "), ticks_per_beat=ticks_per_beat)
    gen_path = f'{output_prefix}/generated.mid'
    csv_to_midi(df_gen, gen_path)

    # Render both MIDI files to audio.
    fs.midi_to_audio(seed_path, f'{output_prefix}/seed.wav')
    fs.midi_to_audio(gen_path, f'{output_prefix}/generated.wav')

    print(f"Generated Token Sequence: \n{generated_sequence}")

In [58]:
# Render ten generations, each seeded from a randomly chosen test sequence.
# NOTE(review): no random seed is set, so the chosen samples differ per run.
for run_number in range(10):
    sample_idx = random.randrange(len(x_test))
    generate(max_seq_length=max_seq_length, sample_idx=sample_idx, num_tokens=500)

Generated Token Sequence: 
TIME_SHIFT_0.000 [HAND_RIGHT] NOTE_E6 VALUE_Eighth TIME_SHIFT_0.500 [HAND_LEFT] NOTE_C4 VALUE_Eighth TIME_SHIFT_0.000 [HAND_RIGHT] NOTE_E5 VALUE_Eighth TIME_SHIFT_0.500 [HAND_LEFT] NOTE_A2 VALUE_Eighth TIME_SHIFT_0.000 [HAND_RIGHT] NOTE_C5 VALUE_Eighth TIME_SHIFT_0.500 [HAND_LEFT] NOTE_C4 VALUE_Eighth TIME_SHIFT_0.500 [HAND_LEFT] NOTE_A3 VALUE_Eighth TIME_SHIFT_0.500 [HAND_LEFT] NOTE_C4 VALUE_Eighth TIME_SHIFT_0.000 [HAND_RIGHT] NOTE_C5 VALUE_Eighth TIME_SHIFT_0.500 [HAND_LEFT] NOTE_E3 VALUE_Eighth TIME_SHIFT_0.000 [HAND_RIGHT] NOTE_C6 VALUE_Eighth TIME_SHIFT_0.500 [HAND_LEFT] NOTE_C4 VALUE_Eighth TIME_SHIFT_0.000 [HAND_RIGHT] NOTE_C5 VALUE_Eighth TIME_SHIFT_0.500 [HAND_LEFT] NOTE_A2 VALUE_Eighth TIME_SHIFT_0.000 [HAND_RIGHT] NOTE_A4 VALUE_Half TIME_SHIFT_0.500 [HAND_LEFT] NOTE_C4 VALUE_Eighth TIME_SHIFT_0.500 [HAND_LEFT] NOTE_A3 VALUE_Eighth TIME_SHIFT_0.000 [HAND_RIGHT] NOTE_A5 VALUE_DottedQuarter TIME_SHIFT_0.500 [HAND_LEFT] NOTE_C4 VALUE_Eighth TIME_SHIFT