In [None]:
!pip install pretty_midi
!pip install midi2audio
!apt install fluidsynth
!pip install tensorflow==2.15
!pip install tensorflow-probability==0.23

In [None]:
import glob, random
import numpy as np
import pretty_midi
from sklearn.model_selection import train_test_split
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
from midi2audio import FluidSynth
import IPython.display as ipd

In [None]:
class UnsupportedMidiFileException(Exception):
    """Signals that a MIDI file cannot be turned into a training example.

    Raised by the MIDI-loading helpers when a file does not have exactly one key
    signature or exactly one tempo, has fewer than two instruments when the
    voices are requested separately, or renders to fewer frames than the
    requested sequence length. The dataset-loading loop catches it and skips
    the file.
    """

def transpose_to_c(midi: pretty_midi.PrettyMIDI, key_number: int) -> None:
    """Move every non-drum note of `midi` down so the key's tonic lands on C.

    The notes are modified in place; nothing is returned.

    Args:
        midi: Parsed MIDI file whose notes are shifted.
        key_number: Key number of the file's key signature. Assumes pretty_midi's
            numbering, where `key_number % 12` is the tonic's distance above C in
            semitones (0-11 major, 12-23 minor); the mode itself is left to the
            caller.

    NOTE(review): the shift is always downward (0-11 semitones), so low notes may
    fall below the 36-84 pitch window that read_midi keeps; transposing to the
    nearest C would preserve more notes. Confirm the intended behaviour.
    """
    semitones_down = key_number % 12
    pitched_tracks = (track for track in midi.instruments if not track.is_drum)
    for track in pitched_tracks:
        for note in track.notes:
            note.pitch -= semitones_down

def get_pianoroll(midi: "pretty_midi.PrettyMIDI | pretty_midi.Instrument", note_low: int, note_high: int,
                  seqlen: int, tempo: float) -> np.ndarray:
    """Render the first `seqlen` frames of `midi` as a binary (time, pitch) piano roll.

    Args:
        midi: Any object with pretty_midi's `get_piano_roll(fs=...)` method. This is
            either a whole `PrettyMIDI` file or a single `Instrument`; read_midi
            passes both.
        note_low: Lowest MIDI pitch kept (inclusive).
        note_high: Upper MIDI pitch bound (exclusive).
        seqlen: Number of time frames to keep.
        tempo: Tempo in beats per minute, used to sample two frames per beat.

    Returns:
        Float array of shape (seqlen, note_high - note_low), assuming
        0 <= note_low < note_high <= 128. It holds 1.0 where a note sounds and
        0.0 elsewhere.

    Raises:
        UnsupportedMidiFileException: If the rendered roll is shorter than
            `seqlen` frames.
    """
    # tempo/60 beats per second, two frames per beat (an eighth note if the beat is a quarter).
    frames_per_second = 2 * tempo / 60
    pianoroll: np.ndarray = midi.get_piano_roll(fs=frames_per_second)
    if pianoroll.shape[1] < seqlen:
        raise UnsupportedMidiFileException
    pianoroll = pianoroll[note_low:note_high, :seqlen]
    # Binarise the cell values: heaviside(x, 0) is 1 for x > 0 and 0 for x == 0.
    pianoroll = np.heaviside(pianoroll, 0)
    # get_piano_roll is (pitch, time); callers work in (time, pitch).
    return pianoroll.T

def read_midi(filename: str, is_sep_sop_alt: bool, seqlen: int,
              note_low: int = 36, note_high: int = 84
              ) -> tuple[np.ndarray, np.ndarray, np.ndarray] | tuple[np.ndarray, np.ndarray]:
    """Load one MIDI file as training data: transpose it to C and render its piano roll(s).

    Args:
        filename: Path of the MIDI file.
        is_sep_sop_alt: If True, render the first two instruments as separate
            rolls (presumably soprano and alto, per the name; confirm against the
            data). Otherwise render all instruments together.
        seqlen: Number of time frames to keep (two frames per beat, see get_pianoroll).
        note_low: Lowest MIDI pitch kept (inclusive). Defaults to 36.
        note_high: Upper MIDI pitch bound (exclusive). Defaults to 84.

    Returns:
        `(soprano_roll, alto_roll, key_mode)` when `is_sep_sop_alt`, else
        `(pianoroll, key_mode)`. Each roll has shape (seqlen, note_high - note_low).
        `key_mode` is a length-1 int array holding `key_number // 12`.

    Raises:
        UnsupportedMidiFileException: If the file does not have exactly one key
            signature and exactly one tempo, has fewer than two instruments when
            `is_sep_sop_alt` is True, or is shorter than `seqlen` frames.
    """
    midi = pretty_midi.PrettyMIDI(filename)

    # Reject files whose key or tempo changes part-way; a single key/tempo is assumed below.
    if len(midi.key_signature_changes) != 1:
        raise UnsupportedMidiFileException
    _, tempo = midi.get_tempo_changes()
    if len(tempo) != 1:
        raise UnsupportedMidiFileException

    key_number: int = midi.key_signature_changes[0].key_number
    transpose_to_c(midi, key_number)
    # With pretty_midi's numbering (0-11 major, 12-23 minor) this is 0 for major, 1 for minor.
    key_mode: np.ndarray = np.array([key_number // 12])

    if is_sep_sop_alt:
        if len(midi.instruments) < 2:
            raise UnsupportedMidiFileException
        pianoroll_sop: np.ndarray = get_pianoroll(midi.instruments[0], note_low, note_high, seqlen, tempo[0])
        pianoroll_alt: np.ndarray = get_pianoroll(midi.instruments[1], note_low, note_high, seqlen, tempo[0])
        return pianoroll_sop, pianoroll_alt, key_mode

    pianoroll: np.ndarray = get_pianoroll(midi, note_low, note_high, seqlen, tempo[0])
    return pianoroll, key_mode

In [None]:
def make_midi(pianorolls: list[np.ndarray], filename: str, pitch_offset: int = 36,
              step_seconds: float = 0.5, threshold: float = 0.5) -> None:
    """Write piano rolls to `filename` as a MIDI file with one instrument track per roll.

    Every cell above `threshold` becomes its own note, exactly one step long. A
    pitch held across consecutive steps is therefore written as repeated notes.

    Args:
        pianorolls: Arrays indexed [time step, pitch index]. A trailing singleton
            channel axis, e.g. decoder output of shape (T, P, 1), is also accepted.
        filename: Output path of the MIDI file.
        pitch_offset: MIDI pitch of pitch index 0. Defaults to 36, the lower bound
            read_midi uses when building rolls.
        step_seconds: Duration of one time step in seconds. Defaults to 0.5.
            NOTE(review): get_pianoroll samples two frames per beat, so at 120 bpm
            a frame is 0.25 s; 0.5 s plays such rolls at half speed. Confirm the
            intended playback tempo.
        threshold: Cells strictly greater than this are written as notes.
    """
    midi: pretty_midi.PrettyMIDI = pretty_midi.PrettyMIDI(resolution=480)
    for pianoroll in pianorolls:
        roll = np.asarray(pianoroll)
        # Drop a trailing singleton channel axis so argwhere yields (step, pitch) pairs.
        if roll.ndim == 3 and roll.shape[-1] == 1:
            roll = roll[..., 0]
        # NOTE(review): pretty_midi programs are 0-based General MIDI numbers, so 1 is
        # Bright Acoustic Piano rather than Acoustic Grand Piano (0); confirm intent.
        inst: pretty_midi.Instrument = pretty_midi.Instrument(program=1)
        # np.argwhere walks the roll in row-major order, i.e. the same order as nested step/pitch loops.
        for step, pitch_index in np.argwhere(roll > threshold):
            inst.notes.append(
                pretty_midi.Note(
                    start=int(step) * step_seconds,
                    end=(int(step) + 1) * step_seconds,
                    pitch=int(pitch_index) + pitch_offset,
                    velocity=100,
                )
            )
        midi.instruments.append(inst)
    midi.write(filename)

def show_and_play_midi(pianorolls: list[np.ndarray], filename: str,
                       wav_filename: str = "output.wav",
                       sound_font: str = "/usr/share/sounds/sf2/FluidR3_GM.sf2") -> None:
    """Plot each piano roll, save them as MIDI, then render and play the audio inline.

    Args:
        pianorolls: 2-D arrays indexed [time step, pitch index], one per track.
        filename: Path the MIDI file is written to (via make_midi).
        wav_filename: Path the rendered audio is written to. Defaults to
            "output.wav".
        sound_font: SoundFont file FluidSynth renders with. The default is the
            path where the apt `fluidsynth` install in this notebook is expected
            to place FluidR3_GM; confirm on other systems.
    """
    for pianoroll in pianorolls:
        # Transpose so time runs along x; origin="lower" puts low pitches at the bottom.
        plt.matshow(np.transpose(pianoroll), origin="lower")
        plt.xlabel("time step")
        plt.ylabel("pitch index")
        plt.show()
    make_midi(pianorolls, filename)

    fs: FluidSynth = FluidSynth(sound_font=sound_font)
    fs.midi_to_audio(filename, wav_filename)
    ipd.display(ipd.Audio(wav_filename))

In [None]:
# Directory of chorale MIDI files (the path looks like a Colab Google Drive mount).
# Named midi_dir rather than dir so the builtin dir() is not shadowed.
midi_dir: str = "/content/drive/MyDrive/impl/musdl/chorales/midi/"
filenames: list[str] = []
xs: list[np.ndarray] = []

# sorted() makes the example order independent of the filesystem's listing order.
for f in sorted(glob.glob(f"{midi_dir}*.mid")):
    print(f)
    try:
        x, _ = read_midi(f, is_sep_sop_alt=False, seqlen=32)
        filenames.append(f)
        xs.append(x)
    except UnsupportedMidiFileException:
        print("skip")

# Without this check an empty xs gives a 1-D x_all and an IndexError later when reading its shape.
if not xs:
    raise RuntimeError(f"No usable MIDI files found in {midi_dir}")
x_all: np.ndarray = np.array(xs)

In [None]:
# Expected (number of usable files, 32, 48): seqlen=32 frames by the 36-84 pitch window read_midi keeps.
print(x_all.shape)

In [None]:
# Width of every hidden layer and size of the latent code z.
hidden_dim: int = 2048
encoded_dim: int = 16

# Time steps per example and pitches per time step, read off the stacked piano rolls
# (x_all is indexed as [example, time step, pitch]).
seq_length: int = int(x_all.shape[1])
input_dim: int = int(x_all.shape[2])

In [None]:
# Latent prior p(z): encoded_dim independent standard normals. reinterpreted_batch_ndims=1
# turns them into a single event, so log_prob returns one value per latent vector.
prior: tfp.distributions.Independent = tfp.distributions.Independent(
    distribution=tfp.distributions.Normal(
        loc=tf.zeros(encoded_dim),
        scale=1,
    ),
    reinterpreted_batch_ndims=1,
)

In [None]:
# Encoder: maps a (seq_length, input_dim, 1) piano roll to a full-covariance Gaussian
# q(z|x) over the encoded_dim-dimensional latent space.
encoder: tf.keras.Sequential = tf.keras.Sequential([
    tf.keras.layers.Input(
        shape=(seq_length, input_dim, 1)
    ),
    # Kernel spans the whole pitch axis, collapsing each time step to one hidden_dim vector:
    # (seq_length, input_dim, 1) -> (seq_length, 1, hidden_dim).
    tf.keras.layers.Conv2D(
        filters=hidden_dim, kernel_size=(1, input_dim), strides=1, padding="valid", activation="relu"
    ),
    tf.keras.layers.BatchNormalization(),
    # NOTE(review): LeakyReLU is applied after a ReLU-activated conv plus BatchNorm in every
    # stage; confirm the double activation is intended.
    tf.keras.layers.LeakyReLU(0.3),

    # Downsample time by 4: (seq_length, 1, h) -> (seq_length // 4, 1, h).
    tf.keras.layers.Conv2D(
        filters=hidden_dim, kernel_size=(4, 1), strides=(4, 1), padding="valid", activation="relu"
    ),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.LeakyReLU(0.3),

    # Downsample time by 4 again -> (seq_length // 16, 1, h). With "valid" padding,
    # trailing steps are dropped when seq_length is not a multiple of 16.
    tf.keras.layers.Conv2D(
        filters=hidden_dim, kernel_size=(4, 1), strides=(4, 1), padding="valid", activation="relu"
    ),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.LeakyReLU(0.3),

    tf.keras.layers.Flatten(),
    # Linear layer producing the loc and lower-triangular scale parameters of q(z|x).
    tf.keras.layers.Dense(
        tfp.layers.MultivariateNormalTriL.params_size(encoded_dim), activation=None
    ),
    # Distribution-valued output. The activity regularizer adds KL(q(z|x) || prior),
    # weighted by 0.001, to the model's losses.
    tfp.layers.MultivariateNormalTriL(
        encoded_dim,
        activity_regularizer=tfp.layers.KLDivergenceRegularizer(prior, weight=0.001)
    )
])
encoder.summary()

In [None]:
# Decoder: maps a latent vector z (encoded_dim,) to per-cell note-on probabilities of shape
# (seq_length, input_dim, 1), mirroring the encoder in reverse.
# Time is upsampled 1 -> 4 -> 16 -> seq_length. The last factor is derived from seq_length
# (2 when seq_length == 32), so the output length always matches the training data.
if seq_length % 16 != 0:
    raise ValueError(f"seq_length must be a multiple of 16 for this architecture, got {seq_length}")
final_upsample: int = seq_length // 16

decoder: tf.keras.Sequential = tf.keras.Sequential([
    tf.keras.layers.Dense(
        hidden_dim, input_dim=encoded_dim, activation="relu"
    ),
    # One time step with a single "pitch" column: (1, 1, hidden_dim).
    tf.keras.layers.Reshape(
        target_shape=(1, 1, hidden_dim)
    ),

    # Upsample time by 4: (1, 1, h) -> (4, 1, h).
    tf.keras.layers.Conv2DTranspose(
        filters=hidden_dim, kernel_size=(4, 1), strides=(4, 1), padding="valid", activation="relu"
    ),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.LeakyReLU(0.3),

    # Upsample time by 4: (4, 1, h) -> (16, 1, h).
    tf.keras.layers.Conv2DTranspose(
        filters=hidden_dim, kernel_size=(4, 1), strides=(4, 1), padding="valid", activation="relu"
    ),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.LeakyReLU(0.3),

    # Upsample time by final_upsample: (16, 1, h) -> (seq_length, 1, h).
    tf.keras.layers.Conv2DTranspose(
        filters=hidden_dim, kernel_size=(final_upsample, 1), strides=(final_upsample, 1), padding="valid", activation="relu"
    ),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.LeakyReLU(0.3),
    tf.keras.layers.Dropout(0.3),

    # Spread each time step across the pitch axis: (seq_length, 1, h) -> (seq_length, input_dim, 1).
    # The sigmoid gives an independent note-on probability per cell.
    tf.keras.layers.Conv2DTranspose(
        filters=1, kernel_size=(1, input_dim), strides=1, padding="valid", activation="sigmoid"
    )
])
decoder.summary()

In [None]:
# Chain encoder and decoder into one trainable model. The encoder has a single,
# distribution-valued output. When the decoder's first Dense layer consumes it, it is
# converted to a tensor (tfp's default conversion draws a sample; confirm if changed).
vae: tf.keras.Model = tf.keras.Model(encoder.inputs, decoder(encoder.outputs[0]))
vae.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002),
    # Reconstruction loss on the sigmoid note-on probabilities. The KL term comes from
    # the encoder's activity regularizer and is added on top.
    loss="binary_crossentropy",
    metrics=["binary_accuracy"],
)
vae.summary()

In [None]:
# Add the channel axis explicitly so the data matches the encoder's Input shape
# (seq_length, input_dim, 1) and the decoder's (..., 1) output, rather than depending
# on Keras to reconcile the missing axis implicitly. A VAE reconstructs its input,
# so the same array is both features and targets.
x_train: np.ndarray = x_all[..., np.newaxis]
history: tf.keras.callbacks.History = vae.fit(x_train, x_train, batch_size=32, epochs=256)

In [None]:
# Sample one latent vector from the N(0, I) prior and decode it into a piano roll.
# Set sample_seed to an int for a reproducible sample; None draws fresh entropy each run.
sample_seed: int | None = None
rng: np.random.Generator = np.random.default_rng(sample_seed)
# standard_normal draws from the same N(0, I) as multivariate_normal(zeros, identity),
# without factorising a covariance matrix.
z: np.ndarray = rng.standard_normal(encoded_dim)
x: np.ndarray = decoder.predict(z[np.newaxis, :], verbose=0)
# x is (1, seq_length, input_dim, 1); squeeze to the 2-D (time, pitch) roll for plotting/playback.
show_and_play_midi([np.squeeze(x)], "output.mid")