In [None]:
!pip install pretty_midi
!pip install midi2audio
!apt install fluidsynth
!pip install tensorflow==2.15
!pip install tensorflow-probability==0.23

In [None]:
import glob, random
import numpy as np
import pretty_midi
from sklearn.model_selection import train_test_split
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
from midi2audio import FluidSynth
import IPython.display as ipd

In [None]:
class UnsupportedMidiFileException(Exception):
    """Signals that a MIDI file cannot be used by this notebook's loaders.

    Raised by ``read_midi`` when a file does not have exactly one key
    signature, does not have exactly one tempo, or (in separated
    soprano/alto mode) has fewer than two instruments, and by
    ``get_pianoroll`` when the rendered piano roll is shorter than the
    requested sequence length.
    """

def transpose_to_c(midi: pretty_midi.PrettyMIDI, key_number: int) -> None:
    """Lower every non-drum note by ``key_number % 12`` semitones, in place.

    Args:
        midi: Object whose ``instruments`` each expose ``is_drum`` and a
            ``notes`` list of objects with a mutable ``pitch``.
        key_number: Key identifier; only its value modulo 12 is used as the
            downward shift. (pretty_midi numbers keys 0-11 as major and
            12-23 as minor, so this presumably moves the tonic onto C in
            both modes -- see the pretty_midi KeySignature docs.)

    Returns:
        None. ``midi`` is modified directly.
    """
    semitones_down = key_number % 12
    # Drum tracks are left untouched.
    pitched_tracks = (track for track in midi.instruments if not track.is_drum)
    for track in pitched_tracks:
        for note in track.notes:
            note.pitch -= semitones_down

def get_pianoroll(midi: pretty_midi.PrettyMIDI, note_low: int, note_high: int, seqlen: int, tempo: float) -> np.ndarray:
    """Render ``midi`` as a binary piano roll cropped to ``seqlen`` time steps.

    Args:
        midi: Any object providing ``get_piano_roll(fs=...)``. ``read_midi``
            passes either the whole file or a single entry of
            ``midi.instruments``.
        note_low: First piano-roll row (inclusive) to keep.
        note_high: Last piano-roll row (exclusive) to keep.
        seqlen: Number of time steps to keep, counted from the start.
        tempo: Used to derive the sampling rate ``2 * tempo / 60``; if
            ``tempo`` is in beats per minute this is two frames per beat.

    Returns:
        Array of shape ``(seqlen, rows kept)`` holding 1 where the rendered
        roll is positive and 0 where it is zero or negative.

    Raises:
        UnsupportedMidiFileException: If the rendered roll has fewer than
            ``seqlen`` time steps.
    """
    frames_per_second = 2*tempo/60
    full_roll: np.ndarray = midi.get_piano_roll(fs=frames_per_second)

    available_steps = full_roll.shape[1]
    if available_steps < seqlen:
        raise UnsupportedMidiFileException

    # Crop first (pitch rows, then time columns), then binarise.
    window = full_roll[note_low:note_high, :seqlen]
    on_off = np.heaviside(window, 0)
    # Callers expect time on the first axis.
    return on_off.T

def read_midi(filename: str, is_sep_sop_alt: bool, seqlen: int) -> tuple[np.ndarray, np.ndarray, np.ndarray] \
                                                                 | tuple[np.ndarray, np.ndarray]:
    """Load a MIDI file, transpose it, and return fixed-length piano rolls.

    The file must contain exactly one key signature and exactly one tempo.
    Notes are shifted with ``transpose_to_c`` before rendering, and piano
    rolls are cropped to rows 36..83 and ``seqlen`` time steps.

    Args:
        filename: Path of the MIDI file to read.
        is_sep_sop_alt: If true, render the first two instruments
            separately; otherwise render the whole file as one roll.
        seqlen: Number of time steps each returned roll must have.

    Returns:
        ``(pianoroll_sop, pianoroll_alt, key_mode)`` when ``is_sep_sop_alt``
        is true, else ``(pianoroll, key_mode)``. ``key_mode`` is a
        one-element array holding ``int(key_number / 12)``.

    Raises:
        UnsupportedMidiFileException: If the key-signature or tempo count is
            not exactly one, if fewer than two instruments exist in
            separated mode, or if a roll is shorter than ``seqlen``.
    """
    note_low, note_high = 36, 84

    midi = pretty_midi.PrettyMIDI(filename)

    key_changes = midi.key_signature_changes
    if len(key_changes) != 1:
        raise UnsupportedMidiFileException

    key_number: int = key_changes[0].key_number
    transpose_to_c(midi, key_number)
    key_mode: np.ndarray = np.array([int(key_number / 12)])

    # Only the tempo values are needed; the change times are discarded.
    _, tempi = midi.get_tempo_changes()
    if len(tempi) != 1:
        raise UnsupportedMidiFileException
    tempo = tempi[0]

    if not is_sep_sop_alt:
        pianoroll: np.ndarray = get_pianoroll(midi, note_low, note_high, seqlen, tempo)
        return pianoroll, key_mode

    tracks = midi.instruments
    if len(tracks) < 2:
        raise UnsupportedMidiFileException
    # Instrument 0 is treated as soprano and instrument 1 as alto.
    pianoroll_sop: np.ndarray = get_pianoroll(tracks[0], note_low, note_high, seqlen, tempo)
    pianoroll_alt: np.ndarray = get_pianoroll(tracks[1], note_low, note_high, seqlen, tempo)
    return pianoroll_sop, pianoroll_alt, key_mode

def add_rest_notes(pianoroll: np.ndarray) -> np.ndarray:
    """Append a "rest" column that is 1 wherever no pitch is active.

    Args:
        pianoroll: 2-D array of shape ``(time, pitches)`` with 0/1 entries.

    Returns:
        Array of shape ``(time, pitches + 1)``: the input followed by one
        extra column equal to ``1 - (row sum)``, clamped at 0.

    The clamp matters when a frame contains more than one active pitch
    (e.g. overlapping notes): without it the rest entry would be negative,
    which is not a valid target for a softmax / categorical cross-entropy
    output. For frames with at most one active pitch the result is the same
    as the plain ``1 - sum``.
    """
    active_per_step: np.ndarray = np.sum(pianoroll, axis=1)
    rests: np.ndarray = np.maximum(1 - active_per_step, 0)
    return np.concatenate([pianoroll, np.expand_dims(rests, 1)], axis=1)

In [None]:
def make_midi(pianorolls: list[np.ndarray], filename: str) -> None:
    """Write piano rolls to a MIDI file, one instrument per roll.

    Args:
        pianorolls: 2-D arrays indexed ``[time step, pitch column]``. Every
            cell greater than 0.5 becomes one note.
        filename: Destination path of the MIDI file.

    Each note lasts half a second (step ``i`` spans ``i/2`` to ``(i+1)/2``
    seconds), has velocity 100, and uses MIDI pitch ``column + 36`` -- the
    same lower bound that ``read_midi`` passes to ``get_pianoroll``.
    Consecutive active cells of the same pitch are emitted as separate notes.
    """
    song: pretty_midi.PrettyMIDI = pretty_midi.PrettyMIDI(resolution=480)
    for roll in pianorolls:
        track: pretty_midi.Instrument = pretty_midi.Instrument(program=1)
        # argwhere walks the array in row-major order, i.e. the same
        # (step, column) order as a nested loop; tolist() gives plain ints.
        for step, column in np.argwhere(roll > 0.5).tolist():
            track.notes.append(
                pretty_midi.Note(start=step/2, end=(step+1)/2, pitch=column+36, velocity=100)
            )
        song.instruments.append(track)
    song.write(filename)

def show_and_play_midi(pianorolls: list[np.ndarray], filename: str) -> None:
    """Plot each piano roll, save them as MIDI, and embed an audio player.

    Args:
        pianorolls: Arrays indexed ``[time step, pitch column]``; each one is
            plotted transposed in its own figure.
        filename: Path of the MIDI file written via ``make_midi``.

    Side effects:
        Writes ``filename`` and overwrites ``output.wav`` in the working
        directory on every call, then displays an audio widget for it.
    """
    for roll in pianorolls:
        image = np.transpose(roll)
        plt.matshow(image)
        plt.show()

    make_midi(pianorolls, filename)

    sound_font_path = "/usr/share/sounds/sf2/FluidR3_GM.sf2"
    wav_path = "output.wav"
    synth: FluidSynth = FluidSynth(sound_font=sound_font_path)
    synth.midi_to_audio(filename, wav_path)
    ipd.display(ipd.Audio(wav_path))

In [None]:
# Load the soprano part of every usable chorale as a (64, pitches + rest) array.
# NOTE: `dir` shadows the Python builtin; the name is kept so that any other
# cell relying on it keeps working.
dir: str = "/content/drive/MyDrive/impl/musdl/chorales/midi/"
filenames: list[str] = []
xs: list[np.ndarray] = []
keymodes: list[np.ndarray] = []

# glob returns files in a filesystem-dependent order. The train/test split
# below uses shuffle=False, so the order decides which chorales end up in
# which set -- sort to make that reproducible across machines and runs.
for f in sorted(glob.glob(f"{dir}*.mid")):
    print(f)
    try:
        # The alto roll is read but not used in this notebook.
        sop, _, key_mode = read_midi(f, is_sep_sop_alt=True, seqlen=64)
        x = add_rest_notes(sop)
        filenames.append(f)
        xs.append(x)
        keymodes.append(key_mode)
    except UnsupportedMidiFileException:
        # Files that fail the loader's checks are skipped on purpose.
        print("skip")

# Fail here with a clear message instead of with an obscure shape error later.
if not xs:
    raise RuntimeError(f"No usable MIDI files found in {dir!r}")

x_all: np.ndarray = np.array(xs)

In [None]:
# Sanity check: expected layout is (number of chorales, time steps, features).
print(x_all.shape)

In [None]:
# Split by position (no shuffling): half of the chorales, rounded down,
# are held out as the test set.
n_test: int = len(x_all) // 2
idxs_train, idxs_test = train_test_split(
    range(len(x_all)),
    test_size=n_test,
    shuffle=False,
)
x_train: np.ndarray = x_all[idxs_train]
x_test: np.ndarray = x_all[idxs_test]

In [None]:
# Model dimensions. The first two are read from the data so the network
# always matches the loaded piano rolls.
seq_length: int = x_train.shape[1]  # time steps per melody
input_dim: int = x_train.shape[2]   # features per time step (pitch columns + rest column)
encoded_dim: int = 16               # size of the latent vector
lstm_dim: int = 1024                # units in the encoder and decoder LSTMs

In [None]:
# Latent prior: `encoded_dim` unit normals (mean 0, scale 1), wrapped in
# Independent so that the last axis is treated as a single event.
standard_normal = tfp.distributions.Normal(loc=tf.zeros(encoded_dim), scale=1)
prior: tfp.distributions.Independent = tfp.distributions.Independent(
    standard_normal, reinterpreted_batch_ndims=1
)

In [None]:
# Encoder: melody -> LSTM summary vector -> parameters of a full-covariance
# Gaussian over the latent space.
# Number of values the Dense layer must emit to parameterise that Gaussian.
latent_params: int = tfp.layers.MultivariateNormalTriL.params_size(encoded_dim)
# KL(posterior || prior), scaled by 0.001, is attached as an activity
# regularizer on the distribution layer.
kl_penalty = tfp.layers.KLDivergenceRegularizer(prior, weight=0.001)

encoder: tf.keras.Sequential = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(seq_length, input_dim)),
    # return_sequences=False: only the final LSTM output is kept.
    tf.keras.layers.LSTM(lstm_dim, use_bias=True, activation="tanh", return_sequences=False),
    tf.keras.layers.Dense(latent_params, activation=None),
    tfp.layers.MultivariateNormalTriL(encoded_dim, activity_regularizer=kl_penalty),
])
encoder.summary()

In [None]:
# Decoder: latent vector -> repeated once per time step -> LSTM -> per-step
# softmax over the `input_dim` features.
decoder: tf.keras.Sequential = tf.keras.Sequential()
decoder.add(tf.keras.layers.Input(shape=(encoded_dim,)))
# Feed the same latent vector to the LSTM at every one of the seq_length steps.
decoder.add(tf.keras.layers.RepeatVector(seq_length))
decoder.add(
    tf.keras.layers.LSTM(lstm_dim, use_bias=True, activation="tanh", return_sequences=True)
)
decoder.add(tf.keras.layers.Dense(input_dim, use_bias=True, activation="softmax"))
decoder.summary()

In [None]:
# End-to-end autoencoder: the decoder is applied to the encoder's output.
# The compiled loss is the per-step categorical cross-entropy between input
# and reconstruction; the KL term comes from the activity regularizer that
# the encoder's distribution layer was built with.
vae: tf.keras.Model = tf.keras.Model(encoder.inputs, decoder(encoder.outputs))
vae.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["categorical_accuracy"])
vae.summary()

In [None]:
# Train the autoencoder to reproduce its own input (target == input).
# NOTE(review): no random seed is set anywhere in the notebook, so results
# differ from run to run; 1000 epochs also makes a full re-run slow.
vae.fit(x_train, x_train, batch_size=32, epochs=1000)

In [None]:
# Encode the held-out melodies and decode them again.
# NOTE(review): the encoder ends in a distribution layer, so `z` is presumably
# a sample drawn from each posterior rather than its mean -- confirm against
# the tfp DistributionLambda `convert_to_tensor_fn` default.
z: np.ndarray = encoder.predict(x_test)
x_new: np.ndarray = decoder.predict(z)

In [None]:
# Pick one held-out melody at random and compare it with its reconstruction.
# randrange(n) yields 0..n-1. The previous randint(0, n) also returned n,
# which is one past the last valid index and raised IndexError now and then.
k: int = random.randrange(len(idxs_test))
print(f"melody id: {k}")
# 0:-1 drops the last feature column (the rest indicator added by
# add_rest_notes) so that only pitch columns are plotted and played.
show_and_play_midi([x_test[k, :, 0:-1]], "input.mid")
show_and_play_midi([x_new[k, :, 0:-1]], "output.mid")

In [None]:
# Blend the latent vectors of two random held-out melodies and decode the mix.
rate: float = 0.5  # weight of melody k; melody l gets (1 - rate)
# randrange(n) yields 0..n-1. The previous randint(0, n) also returned n,
# which is one past the last valid index and raised IndexError now and then.
k: int = random.randrange(len(idxs_test))
l: int = random.randrange(len(idxs_test))
print(f"melody id: {k}, {l}")

z_new: np.ndarray = rate * z[k] + (1 - rate) * z[l]
# NOTE: this overwrites the `x_new` produced by the reconstruction cell with a
# single-item batch; re-run that cell before replaying reconstructions.
x_new: np.ndarray = decoder.predict(np.array([z_new]))
# 0:-1 drops the last feature column (the rest indicator) before playback.
show_and_play_midi([x_test[k, :, 0:-1]], "input1.mid")
show_and_play_midi([x_test[l, :, 0:-1]], "input2.mid")
show_and_play_midi([x_new[0, :, 0:-1]], "output.mid")