In [None]:
import torch
import torch.nn as nn
import pandas as pd
import ast
from dataclasses import dataclass
from xformers.factory.model_factory import xFormer, xFormerConfig
from xformers.components.positional_embedding import (PositionEmbedding, PositionEmbeddingConfig, register_positional_embedding)

import pretty_midi
import librosa
import numpy as np
import uuid
import os
import mir_eval.display
import matplotlib.pyplot as plt
%matplotlib inline
import IPython.display

In [None]:
# Hyper-parameters and runtime settings shared by the whole notebook

# Token ids: 0-127 are MIDI pitches (C-1 to G9); the three ids after that are special.
PAD_IDX, BOS_IDX, EOS_IDX = 128, 129, 130
# Filler used for the continuous features of special tokens.
PAD_VALUE = 0.0

# Transformer size.
EMB_SIZE = 64
NHEAD = 4
HIDDEN_LAYER_MULTIPLIER = 4
NUM_ENCODER_LAYERS = NUM_DECODER_LAYERS = 6
DROPOUT = 0.2
MAX_LEN = 256

# Vocabulary = 128 pitches + PAD + BOS + EOS, i.e. the highest token id plus one.
SRC_VOCAB_SIZE = EOS_IDX + 1
TGT_VOCAB_SIZE = SRC_VOCAB_SIZE

# Run on the GPU when one is available.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
MODEL_SAVE_PATH = './Model/MidiGen.pth'

@dataclass
class MidiEmbeddingConfig(PositionEmbeddingConfig):
    """Config for the "midi" positional embedding (see ``MidiEmbedding``).

    Extends ``PositionEmbeddingConfig`` with the two extra constructor
    arguments that ``MidiEmbedding`` accepts.
    """
    pitch_size: int  # number of rows in the pitch-token embedding table
    dropout: float  # probability passed to the embedding's Dropout layer


@register_positional_embedding("midi", MidiEmbeddingConfig)
class MidiEmbedding(PositionEmbedding):
    """Embed a note sequence as pitch embedding + continuous features + position.

    Every step is represented by a learned pitch embedding of width
    ``dim_model - NUM_EXTRA_FEATURES`` concatenated with that step's
    continuous features, so the result is ``dim_model`` wide. A learned
    position embedding of width ``dim_model`` is then added, followed by
    dropout.
    """

    # Width of the continuous feature vector concatenated to each pitch embedding.
    NUM_EXTRA_FEATURES = 3

    def __init__(self, dim_model: int, seq_len: int, pitch_size: int, dropout: float = 0.0, *args, **kwargs):
        """Create the embedding tables.

        Args:
            dim_model: Width of the output vectors.
            seq_len: Number of positions in the position-embedding table.
            pitch_size: Number of entries in the pitch-embedding table.
            dropout: Dropout probability applied to the final embedding.
            *args, **kwargs: Forwarded unchanged to ``PositionEmbedding``.
        """
        super().__init__(*args, **kwargs)
        self.dim_model = dim_model
        self.seq_len = seq_len
        self.pitch_size = pitch_size
        self.dropout = torch.nn.Dropout(p=dropout)

        self.position_embeddings = nn.Embedding(seq_len, self.dim_model)
        self.pitch_embeddings = nn.Embedding(
            self.pitch_size, self.dim_model - self.NUM_EXTRA_FEATURES)

        # Kept for attribute compatibility; never read by this class.
        # (Previously annotated with `Optional`, which is not imported in this file.)
        self.position_ids = None

    def init_weights(self, gain: float = 1.0):
        """Re-initialise both embedding tables from N(0, (0.02 * gain)^2)."""
        torch.nn.init.normal_(self.position_embeddings.weight, std=0.02 * gain)
        torch.nn.init.normal_(self.pitch_embeddings.weight, std=0.02 * gain)

    def forward(self, x: tuple):
        """Embed one batch.

        Args:
            x: Pair ``(sentence, extra)``. ``sentence`` holds integer pitch
                tokens indexed as (batch, step); ``extra`` holds the
                continuous features and is concatenated to the pitch
                embedding along the last axis.

        Returns:
            The embedded batch after dropout.

        Raises:
            ValueError: If the sequence is longer than ``seq_len``.
        """
        sentence, extra = x[0], x[1]

        num_steps = sentence.shape[1]
        if num_steps > self.seq_len:
            # Fail early with a readable message rather than an embedding
            # index error (or a device-side assert on CUDA).
            raise ValueError(
                f"sequence length {num_steps} exceeds the position table size {self.seq_len}")

        pitch_token = self.pitch_embeddings(sentence)
        features = torch.cat([pitch_token, extra], dim=-1)

        # Positions are the same for every batch row, so embed them once and
        # let broadcasting add them to each row.
        position_ids = torch.arange(num_steps, dtype=torch.long, device=sentence.device)
        pos = self.position_embeddings(position_ids)

        return self.dropout(features + pos)


class MidiTransformer(nn.Module):
    """Encoder-decoder xFormer with two output heads for MIDI note sequences.

    ``generator`` maps decoder states to scores over the pitch vocabulary and
    ``extra_generator`` maps them to three continuous values
    ``[time_since_last_note, duration, velocity]``.
    """

    def __init__(self, model_config) -> None:
        """Build the transformer and both heads.

        Args:
            model_config: List of xFormer block configs. Entry 1 supplies
                ``dim_model`` and the ``pitch_size`` that sizes the pitch head.
        """
        super().__init__()
        self.dim_model = model_config[1]['dim_model']
        self.model_config = xFormerConfig(model_config)
        self.transformer = xFormer.from_config(self.model_config)
        # NOTE: the layer layout of both heads defines the state_dict keys of
        # saved checkpoints; do not reorder or insert layers.
        self.generator = nn.Sequential(
            nn.Linear(self.dim_model, self.dim_model*2), 
            nn.LeakyReLU(), 
            nn.Linear(self.dim_model*2, model_config[1]['position_encoding_config']['pitch_size']))
        self.extra_generator = nn.Sequential(
            nn.Linear(self.dim_model, self.dim_model*2), 
            nn.Linear(self.dim_model*2, 3))  # [time_since_last_note, duration, velocity]

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Encode ``src``, decode ``tgt`` against it and apply both heads.

        Returns:
            Tuple ``(pitch_scores, extra_values)`` computed from the decoder output.
        """
        memory = self.encode(src, src_mask)
        out = self.decode(tgt, memory, tgt_mask)
        return self.generator(out), self.extra_generator(out)

    def encode(self, src, src_mask=None):
        """Run the encoder stack and return the memory for the decoder.

        Args:
            src: Encoder input; passed to the first encoder block (or to the
                reversible encoder's position encoding).
            src_mask: Optional mask. It is handed to each block as
                ``input_mask`` in the standard path, and added to the
                activations in the reversible path.
        """
        encoders = self.transformer.encoders
        memory = src[:]
        if isinstance(encoders, torch.nn.ModuleList):
            # Standard (non-reversible) encoder: apply the blocks in order.
            for encoder in encoders:
                memory = encoder(memory, input_mask=src_mask)
        else:
            if self.transformer.rev_enc_pose_encoding:
                memory = self.transformer.rev_enc_pose_encoding(src)

            # Reversible encoder: it works on a channel-duplicated input.
            x = torch.cat([memory, memory], dim=-1)

            # Apply the optional (additive) input masking.
            if src_mask is not None:
                if x.dim() - src_mask.dim() > 1:
                    # Add the leading batch axis. The result used to be
                    # discarded; broadcasting gave the same outcome, so this
                    # only makes the intended shape explicit.
                    src_mask = src_mask.unsqueeze(0)
                x += src_mask.unsqueeze(-1)

            x = encoders(x)
            # Average the two channel halves back into one representation.
            memory = torch.stack(x.chunk(2, dim=-1)).mean(dim=0)
        return memory

    def decode(self, tgt, memory, tgt_mask=None):
        """Run every decoder block on ``tgt`` with ``memory`` as the encoder output."""
        for decoder in self.transformer.decoders:
            tgt = decoder(target=tgt, memory=memory, input_mask=tgt_mask)
        return tgt


# Block descriptions handed to xFormerConfig: entry 0 is the encoder stack,
# entry 1 the decoder stack (MidiTransformer reads its sizes from entry 1).
model_config = [
    # ---- encoder ----
    dict(
        reversible=True,  # reversible layers save a lot of memory during training
        block_type="encoder",
        num_layers=NUM_ENCODER_LAYERS,
        dim_model=EMB_SIZE,
        residual_norm_style="pre",
        # "midi" is the embedding registered in this notebook: it combines the
        # token embedding with the position embedding.
        position_encoding_config=dict(
            name="midi",
            seq_len=MAX_LEN,
            pitch_size=SRC_VOCAB_SIZE,
        ),
        multi_head_config=dict(
            num_heads=NHEAD,
            residual_dropout=0,
            attention=dict(
                name="linformer",
                dropout=0,
                causal=False,
                seq_len=MAX_LEN,
            ),
        ),
        feedforward_config=dict(
            name="MLP",
            dropout=DROPOUT,
            activation="relu",
            # hidden width = HIDDEN_LAYER_MULTIPLIER * dim_model
            hidden_layer_multiplier=HIDDEN_LAYER_MULTIPLIER,
        ),
    ),
    # ---- decoder ----
    dict(
        reversible=False,
        block_type="decoder",
        num_layers=NUM_DECODER_LAYERS,
        dim_model=EMB_SIZE,
        residual_norm_style="pre",
        position_encoding_config=dict(
            name="midi",
            seq_len=MAX_LEN,
            pitch_size=TGT_VOCAB_SIZE,
        ),
        multi_head_config_masked=dict(
            num_heads=NHEAD,
            residual_dropout=0,
            attention=dict(
                name="nystrom",
                dropout=0,
                causal=True,  # causal: the decoder must not attend to future target tokens
                seq_len=MAX_LEN,
            ),
        ),
        multi_head_config_cross=dict(
            num_heads=NHEAD,
            residual_dropout=0,
            attention=dict(
                name="favor",
                dropout=0,
                causal=False,
                seq_len=MAX_LEN,
            ),
        ),
        feedforward_config=dict(
            name="MLP",
            dropout=DROPOUT,
            activation="relu",
            hidden_layer_multiplier=HIDDEN_LAYER_MULTIPLIER,
        ),
    ),
]

# Build the network and move it to the selected device.
model = MidiTransformer(model_config=model_config).to(DEVICE)

In [None]:
# Load the trained weights. map_location remaps tensors that were saved on a
# different device (e.g. a CUDA checkpoint) onto DEVICE, so this also works on
# CPU-only machines.
model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))

In [None]:
def greedy_decode(model, src, src_mask=None, max_len=MAX_LEN, start_symbol=BOS_IDX):
    """Generate one sequence by always taking the highest-scoring pitch token.

    Only a single sequence (batch size 1) is supported.

    Args:
        model: Object exposing ``encode``, ``decode``, ``generator`` and
            ``extra_generator`` (e.g. ``MidiTransformer``).
        src: Pair ``(tokens, extra)`` of tensors forming the encoder input.
        src_mask: Optional mask forwarded to ``model.encode``.
        max_len: Maximum length of the generated sequence, start token included.
        start_symbol: Token id the generated sequence starts with.

    Returns:
        Pair ``(tokens, extra)``: the generated token ids, shape (1, steps),
        and the matching continuous features, shape (1, steps, 3). Generation
        stops after ``EOS_IDX`` is produced or at ``max_len`` steps.
    """
    src = (src[0].to(DEVICE), src[1].to(DEVICE))
    memory = model.encode(src, src_mask)

    ys_token = torch.full((1, 1), start_symbol, dtype=torch.long, device=DEVICE)
    # The extra features are continuous, so keep them floating point. They were
    # previously cast to long, which truncates a non-integer PAD_VALUE and
    # relied on torch.cat promoting the dtype later.
    ys_extra = torch.full((1, 1, 3), PAD_VALUE, dtype=torch.float, device=DEVICE)
    ys = (ys_token, ys_extra)

    for _ in range(max_len - 1):
        out = model.decode(ys, memory)
        last_state = out[:, -1, :]  # decoder state of the newest step
        scores = model.generator(last_state)
        # Insert the step axis at position 1 -> (batch, 1, 3).
        extra = model.extra_generator(last_state).unsqueeze(1)

        next_token = int(torch.argmax(scores, dim=1).item())
        next_column = torch.full((1, 1), next_token, dtype=ys_token.dtype, device=ys_token.device)
        ys_token = torch.cat([ys_token, next_column], dim=1)
        ys_extra = torch.cat([ys_extra, extra], dim=1)
        ys = (ys_token, ys_extra)
        if next_token == EOS_IDX:
            break
    return ys

def predict_next_midi_sentence(model, src):
    """Greedy-decode the sentence that follows ``src``.

    Puts ``model`` into evaluation mode and runs ``greedy_decode`` with
    gradient tracking disabled.

    Returns:
        Whatever ``greedy_decode`` returns for ``(model, src)``.
    """
    model.eval()
    with torch.no_grad():
        return greedy_decode(model, src)



In [None]:
def plot_piano_roll(pm:pretty_midi.PrettyMIDI, start_pitch=pretty_midi.note_name_to_number('C2'), end_pitch=pretty_midi.note_name_to_number('G7'), fs=100):
    """
    Plot a piano roll of a pretty_midi object on the current matplotlib axes.

    Args:
        pm (pretty_midi.PrettyMIDI): A pretty_midi object.
        start_pitch (int): The MIDI pitch number to start plotting from (inclusive).
        end_pitch (int): The MIDI pitch number to end plotting at (exclusive).
        fs (int, optional): The sampling frequency used to generate the piano roll. Defaults to 100.

    Returns:
        None
    """
    # A plain `import librosa` does not load the display submodule on every
    # librosa release, so import it explicitly before using it.
    import librosa.display

    # Rows of the piano roll are MIDI pitches; keep only the requested range.
    piano_roll = pm.get_piano_roll(fs)[start_pitch:end_pitch]
    # Use librosa's specshow function for displaying the piano roll
    librosa.display.specshow(piano_roll,
                             hop_length=1, sr=fs, x_axis='time', y_axis='cqt_note',
                             fmin=pretty_midi.note_number_to_hz(start_pitch))
    
def play_pm(pm:pretty_midi.PrettyMIDI, fs=44100) -> None:
    """Synthesize ``pm`` at ``fs`` Hz and show an inline audio player for it."""
    waveform = pm.synthesize(fs=fs)
    player = IPython.display.Audio(waveform, rate=fs)
    display(player)


def print_midi_info(pm:pretty_midi.PrettyMIDI):
    """Print a summary of ``pm`` followed by one note table per instrument.

    The summary lists the number of time signature changes, the number of
    instruments and the estimated tempo. Each instrument's notes are printed
    as a table sorted by tick.

    Args:
        pm (pretty_midi.PrettyMIDI): The MIDI object to describe.

    Returns:
        None
    """
    # Fixed column list so an instrument without notes still yields a table
    # that has a 'Tick' column to sort by (previously a KeyError).
    columns = ['Tick', 'StartTime', 'EndTime', 'Pitch', 'Note', 'Velocity', 'Duration']

    print(f'There are {len(pm.time_signature_changes)} time signature changes')
    print(f'There are {len(pm.instruments)} instruments')
    print(f'Tempo: {pm.estimate_tempo()}')
    for i, instrument in enumerate(pm.instruments):
        print('-'*40)
        print(f'Instrument {i} has {len(instrument.notes)} notes')
        notes_info = [
            {
                'Tick': pm.time_to_tick(note.start),
                'StartTime': note.start,
                'EndTime': note.end,
                'Pitch': note.pitch,
                'Note': pretty_midi.note_number_to_name(note.pitch),
                'Velocity': note.velocity,
                'Duration': note.end - note.start
            }
            for note in instrument.notes
        ]
        df = pd.DataFrame(notes_info, columns=columns).sort_values('Tick')
        print(df)
    print('-'*40)

def plot_piano_roll_with_beats(pm, start_pitch=pretty_midi.note_name_to_number('C2'), end_pitch=pretty_midi.note_name_to_number('G7'), fs=100, xlim=(0, 5)):
    """
    Draw a piano roll of ``pm`` in a new figure and overlay beat markers.

    Beats are drawn as grey lines and downbeats as red lines.

    Args:
        pm (pretty_midi.PrettyMIDI): A pretty_midi object.
        start_pitch (int): The MIDI pitch number to start plotting from.
        end_pitch (int): The MIDI pitch number to end plotting at.
        fs (int, optional): The sampling frequency used to generate the piano roll. Defaults to 100.
        xlim (tuple or None, optional): Value passed to ``plt.xlim``. Defaults to (0, 5).

    Returns:
        None
    """
    # Marker times paired with their line colour: beats first, then downbeats.
    markers = (
        (pm.get_beats(), '#AAAAAA'),
        (pm.get_downbeats(), 'r'),
    )

    plt.figure(figsize=(16, 4))
    plot_piano_roll(pm, start_pitch, end_pitch, fs=fs)
    y_low, y_high = plt.ylim()

    for times, colour in markers:
        mir_eval.display.events(times, base=y_low, height=y_high, color=colour)

    # Restrict the time axis for clarity, then render.
    plt.xlim(xlim)
    plt.show()

def quantize_pretty_midi(pm:pretty_midi.PrettyMIDI, threshold=1/8):
    """Return a copy of ``pm`` whose note times are snapped to nearby beats.

    A note start is moved to the nearest beat when it lies within
    ``threshold`` beat lengths of it (beat length = distance between the
    first two beats). The input object and its notes are left untouched.

    Args:
        pm (pretty_midi.PrettyMIDI): The MIDI object to quantize.
        threshold (float, optional): Snap tolerance as a fraction of one beat.
            Defaults to 1/8.

    Returns:
        pretty_midi.PrettyMIDI: A new object with one instrument per input
        instrument; each instrument's notes are sorted by start time. When
        ``pm`` has fewer than two beats the notes are copied without snapping.
    """
    quantized_pm = pretty_midi.PrettyMIDI()
    beats = np.asarray(pm.get_beats(), dtype=float)

    # A beat length needs two beats; without it there is no grid to snap to.
    can_snap = len(beats) >= 2
    tolerance = threshold * (beats[1] - beats[0]) if can_snap else 0.0

    def nearest_beat(time):
        # Vectorised nearest-neighbour lookup; ties resolve to the earlier beat.
        return float(beats[np.abs(beats - time).argmin()])

    for instrument in pm.instruments:
        quantized_instrument = pretty_midi.Instrument(
            program=instrument.program, is_drum=instrument.is_drum, name=instrument.name)
        # Carry non-note events over so the copy keeps what the original had.
        quantized_instrument.pitch_bends = list(instrument.pitch_bends)
        quantized_instrument.control_changes = list(instrument.control_changes)

        notes = []
        for note in instrument.notes:
            start, end = note.start, note.end
            if can_snap:
                candidate = nearest_beat(start)
                if abs(candidate - start) <= tolerance:
                    start = candidate
                # NOTE(review): only notes SHORTER than the tolerance get their
                # end snapped, which can produce zero-length notes. Kept as in
                # the original; confirm whether `>` was intended.
                if abs(end - start) < tolerance:
                    candidate = nearest_beat(end)
                    if abs(candidate - end) <= tolerance:
                        end = candidate
            # Build a new Note instead of mutating the caller's object.
            notes.append(pretty_midi.Note(
                velocity=note.velocity, pitch=note.pitch, start=start, end=end))

        notes.sort(key=lambda quantized: quantized.start)
        quantized_instrument.notes = notes
        # Append the quantized copy (the original instrument used to be appended here).
        quantized_pm.instruments.append(quantized_instrument)

    return quantized_pm

def pm_to_df(pm: pretty_midi.PrettyMIDI) -> pd.DataFrame:
    """Split ``pm`` into bars and return one DataFrame row per non-empty bar.

    A bar spans from one downbeat up to (but excluding) the next. The last
    bar is given the length of the bar before it; if there is only a single
    downbeat, the bar extends to the end of the piece.

    Args:
        pm (pretty_midi.PrettyMIDI): The MIDI object to convert.

    Returns:
        pandas.DataFrame: Columns ``Sentence`` (pitches), ``TimeSinceLastNoteStart``,
        ``TimeSinceDownbeat``, ``Duration``, ``Velocity``, ``InstrumentProgram``
        and ``Track`` hold one list per bar, ordered by note start time.
        ``SentenceIndex`` is the bar's downbeat index, ``Tempo`` the estimated
        tempo of ``pm`` and ``MIDI`` a random 8-character id shared by all rows.
    """
    columns = [
        "Sentence", "TimeSinceLastNoteStart", "TimeSinceDownbeat", "Duration",
        "Velocity", "InstrumentProgram", "Track", "SentenceIndex",
    ]
    downbeats = pm.get_downbeats()

    # Flatten every note once together with its instrument program and track
    # index. The sort is stable, so notes with equal start keep their
    # instrument/note order, exactly as a per-bar collect-then-sort would.
    tagged = [
        (note, instrument.program, track)
        for track, instrument in enumerate(pm.instruments)
        for note in instrument.notes
    ]
    tagged.sort(key=lambda item: item[0].start)
    starts = np.array([item[0].start for item in tagged], dtype=float)

    rows = []
    # Start time of the most recent note seen in any earlier bar.
    last_note_start_time = 0
    for i, downbeat in enumerate(downbeats):
        if i + 1 < len(downbeats):
            next_beat = downbeats[i + 1]
        elif len(downbeats) > 1:
            # Last bar: assume it is as long as the previous one.
            next_beat = downbeat + downbeats[-1] - downbeats[-2]
        else:
            # Single downbeat: no bar length is known, so take everything after it.
            next_beat = float("inf")

        # Notes with downbeat <= start < next_beat form this bar.
        lo = np.searchsorted(starts, downbeat, side="left")
        hi = np.searchsorted(starts, next_beat, side="left")
        section = tagged[lo:hi]
        if not section:
            continue

        notes = [item[0] for item in section]
        # Reference start for each note: the previous note, or for the first
        # note of the bar the last note of the preceding non-empty bar.
        previous_starts = [last_note_start_time] + [note.start for note in notes[:-1]]
        rows.append({
            "Sentence": [note.pitch for note in notes],
            "TimeSinceLastNoteStart": [
                note.start - previous for note, previous in zip(notes, previous_starts)],
            "TimeSinceDownbeat": [note.start - downbeat for note in notes],
            "Duration": [note.end - note.start for note in notes],
            "Velocity": [note.velocity for note in notes],
            "InstrumentProgram": [item[1] for item in section],
            "Track": [item[2] for item in section],
            "SentenceIndex": i,
        })
        last_note_start_time = notes[-1].start

    # Construct the DataFrame
    df = pd.DataFrame(rows, columns=columns)
    df['Tempo'] = pm.estimate_tempo()
    df['MIDI'] = str(uuid.uuid4())[:8]

    return df

def df_to_pretty_midi(df:pd.DataFrame, program=0)->pretty_midi.PrettyMIDI:
    """
    Construct a PrettyMIDI object from a DataFrame with columns "Sentence", "Duration", "Velocity", "TimeSinceLastNoteStart".

    Rows are processed in order and all notes are placed on a single
    instrument. Each note starts ``TimeSinceLastNoteStart`` seconds after the
    previous one, beginning at time 0.

    Values may come straight from the model, so they are sanitised: tokens
    outside the MIDI pitch range 0-127 (e.g. special tokens) produce no note
    but still advance the clock, and velocities are rounded and clamped to
    the integer range 0-127.

    Args:
        df (pandas.DataFrame): The DataFrame to construct the PrettyMIDI object from.
        program (int, optional): MIDI program number of the instrument. Defaults to 0.

    Returns:
        pretty_midi.PrettyMIDI: A PrettyMIDI object constructed from the DataFrame.
    """
    # Create an empty PrettyMIDI object with one instrument
    pm = pretty_midi.PrettyMIDI()

    inst = pretty_midi.Instrument(program=program)
    pm.instruments.append(inst)

    time = 0
    # Iterate over each row in the DataFrame
    for _, row in df.iterrows():
        sentence = row["Sentence"]
        duration = row["Duration"]
        velocity = row["Velocity"]
        time_since_last_note_start = row["TimeSinceLastNoteStart"]

        for j, pitch in enumerate(sentence):
            # Always advance the clock so later notes keep their position.
            time += time_since_last_note_start[j]

            pitch = int(pitch)
            if not 0 <= pitch <= 127:
                # Not a playable MIDI pitch; it would break piano-roll indexing.
                continue
            note_velocity = min(127, max(0, int(round(velocity[j]))))

            # Create a note object
            note = pretty_midi.Note(velocity=note_velocity, pitch=pitch, start=time, end=time+duration[j])
            inst.notes.append(note)

    return pm

def df_to_src(df:pd.DataFrame, idx=0):
    """Turn the row labelled ``idx`` into a model input pair.

    The token sequence is wrapped in ``BOS_IDX`` / ``EOS_IDX`` and each
    feature column is padded with ``PAD_VALUE`` at both ends to match.
    Velocity is divided by 100.

    Returns:
        Tuple ``(tokens, features)`` with shapes (1, n + 2) and (1, n + 2, 3),
        where n is the number of notes in the row; ``features`` is float and
        ordered [TimeSinceLastNoteStart, Duration, Velocity].
    """
    record = df.loc[idx]

    def padded_column(name):
        # One feature as a column vector with a pad value before and after.
        values = [PAD_VALUE] + record.at[name] + [PAD_VALUE]
        return torch.tensor(values).unsqueeze(-1)

    features = torch.cat(
        [
            padded_column('TimeSinceLastNoteStart'),
            padded_column('Duration'),
            padded_column('Velocity') / 100,
        ],
        dim=-1,
    )
    tokens = torch.cat(
        [torch.tensor([BOS_IDX]), torch.tensor(record.at['Sentence']), torch.tensor([EOS_IDX])],
        dim=0,
    )
    return (tokens.unsqueeze(0), features.unsqueeze(0).float())

def src_to_df(src):
    """Convert a ``(tokens, features)`` pair back into a one-row DataFrame.

    Only the first sequence of the batch is used. Its first and last step are
    dropped, and the velocity feature is multiplied by 100.

    Returns:
        pandas.DataFrame: One row with list-valued columns "Sentence",
        "TimeSinceLastNoteStart", "Duration" and "Velocity".
    """
    tokens, features = src
    pitches = tokens[0][1:-1].tolist()

    # Strip the first/last step, then pick the three feature columns.
    inner = features[0, 1:-1]
    gaps, lengths, loudness = (inner[:, column] for column in range(3))

    return pd.DataFrame({
        "Sentence": [pitches],
        "TimeSinceLastNoteStart": [gaps.tolist()],
        "Duration": [lengths.tolist()],
        "Velocity": [(loudness * 100).tolist()]
    })


In [None]:
# Load one MIDI file from the local MAESTRO copy to inspect.
pm = pretty_midi.PrettyMIDI('./Data/MaestroPianoMidi/maestro-v3.0.0/2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi')
# Quantize it, then print its note tables, plot the first 5 seconds and play it.
pm_quantized = quantize_pretty_midi(pm)
# NOTE(review): the label says "Original", but the quantized MIDI is what is shown below.
print('Original MIDI:')
print_midi_info(pm_quantized)
plot_piano_roll_with_beats(pm_quantized, xlim=(0, 5))
play_pm(pm_quantized)

# Convert the quantized MIDI into the per-bar DataFrame and show it.
df = pm_to_df(pm_quantized)
display(df)

In [None]:

# Encode the DataFrame row labelled 2 as the model input
src = df_to_src(df, 2)
# Predict the next MIDI sentence (still in model tensor form)
predicted_src = predict_next_midi_sentence(model, src)
# Convert the prediction to a one-row DataFrame
next_sentence = src_to_df(predicted_src)
print('Predicted Next MIDI:')
display(next_sentence)
# Append it to the original rows; ignore_index keeps the row labels unique
# (the predicted row would otherwise reuse label 0)
df_new = pd.concat([df, next_sentence], ignore_index=True)
# Convert to a pretty_midi object and quantize it
pm_new = df_to_pretty_midi(df_new)
pm_new = quantize_pretty_midi(pm_new)
print_midi_info(pm_new)
# Plot the full piece and play it
plot_piano_roll_with_beats(pm_new, xlim=None)
play_pm(pm_new)