In [None]:
!pip install pretty_midi
!pip install midi2audio
!apt install fluidsynth

In [None]:
import glob, random
import numpy as np
import pretty_midi
import tensorflow as tf
import matplotlib.pyplot as plt
from midi2audio import FluidSynth
import IPython.display as ipd

In [None]:
# Show which GPU devices TensorFlow detects (displayed as this cell's output).
tf.config.list_physical_devices("GPU")

In [None]:
class UnsupportedMidiFileException(Exception):
    """Signals that a MIDI file cannot be used by the loaders in this notebook.

    Raised when a file has a number of key signatures or tempo changes other
    than one, too few instruments, or fewer time steps than requested.
    """

def transpose_to_c(midi: pretty_midi.PrettyMIDI, key_number: int) -> None:
    """Shift all pitched notes of ``midi`` down by ``key_number % 12`` semitones.

    The notes are modified in place and nothing is returned. Instruments flagged
    as drums are skipped. With pretty_midi's key numbering (0 = C), the shift
    moves the tonic onto C.

    Args:
        midi: Score whose notes are rewritten.
        key_number: Key signature number of the piece.
    """
    semitones_down = key_number % 12
    pitched_instruments = (inst for inst in midi.instruments if not inst.is_drum)
    for inst in pitched_instruments:
        for note in inst.notes:
            note.pitch = note.pitch - semitones_down

def get_pianoroll(midi: pretty_midi.PrettyMIDI, note_low: int, note_high: int, seqlen: int, tempo: float) -> np.ndarray:
    """Return a binary piano roll of shape ``(seqlen, note_high - note_low)``.

    The roll is sampled at two steps per beat (``fs = 2 * tempo / 60``), cut to
    the first ``seqlen`` steps and to pitches ``note_low`` (inclusive) through
    ``note_high`` (exclusive). Every positive velocity becomes 1, silence is 0.

    Args:
        midi: Object providing ``get_piano_roll(fs=...)``; this notebook passes
            both whole scores and single instruments.
        note_low: Lowest MIDI pitch kept.
        note_high: One past the highest MIDI pitch kept.
        seqlen: Number of time steps to keep.
        tempo: Tempo in beats per minute.

    Raises:
        UnsupportedMidiFileException: If the roll has fewer than ``seqlen`` steps.
    """
    steps_per_second = 2*tempo/60
    full_roll: np.ndarray = midi.get_piano_roll(fs=steps_per_second)
    available_steps = full_roll.shape[1]
    if available_steps < seqlen:
        raise UnsupportedMidiFileException
    window = full_roll[note_low:note_high, :seqlen]
    # heaviside(v, 0): 1 where v > 0, otherwise 0 (including exactly 0).
    return np.heaviside(window, 0).T

def read_midi(filename: str, is_sep_sop_alt: bool, seqlen: int) -> tuple[np.ndarray, np.ndarray, np.ndarray] \
                                                                 | tuple[np.ndarray, np.ndarray]:
    """Load a MIDI file, transpose it with ``transpose_to_c`` and extract piano rolls.

    Only files with exactly one key signature and exactly one tempo are
    accepted. Pitches 36 through 83 are kept.

    Args:
        filename: Path of the MIDI file.
        is_sep_sop_alt: If true, return separate rolls for the first and second
            instrument; otherwise return one roll for the whole score.
        seqlen: Number of time steps to extract.

    Returns:
        ``(pianoroll_sop, pianoroll_alt, key_mode)`` when ``is_sep_sop_alt`` is
        true, else ``(pianoroll, key_mode)``. ``key_mode`` is a one-element
        array holding ``int(key_number / 12)``: 0 for key numbers below 12 and
        1 for 12-23 (major / minor in pretty_midi's numbering).

    Raises:
        UnsupportedMidiFileException: If the file does not meet the
            requirements above, has fewer than two instruments when separate
            voices are requested, or is shorter than ``seqlen`` steps.
    """
    midi = pretty_midi.PrettyMIDI(filename)

    key_changes = midi.key_signature_changes
    if len(key_changes) != 1:
        raise UnsupportedMidiFileException
    key_number: int = key_changes[0].key_number
    transpose_to_c(midi, key_number)
    key_mode: np.ndarray = np.array([int(key_number/12)])

    # Only the tempo values are needed; the change times are discarded.
    _, tempi = midi.get_tempo_changes()
    if len(tempi) != 1:
        raise UnsupportedMidiFileException
    bpm = tempi[0]

    if not is_sep_sop_alt:
        return get_pianoroll(midi, 36, 84, seqlen, bpm), key_mode

    voices = midi.instruments
    if len(voices) < 2:
        raise UnsupportedMidiFileException
    pianoroll_sop: np.ndarray = get_pianoroll(voices[0], 36, 84, seqlen, bpm)
    pianoroll_alt: np.ndarray = get_pianoroll(voices[1], 36, 84, seqlen, bpm)
    return pianoroll_sop, pianoroll_alt, key_mode

In [None]:
def make_midi(pianorolls: list[np.ndarray], filename: str) -> None:
    """Write piano rolls to a MIDI file, one instrument per roll.

    Each roll is indexed as ``[time_step, pitch_offset]``. Every cell above 0.5
    becomes its own note of half a second, starting at ``time_step / 2``
    seconds, with pitch ``pitch_offset + 36`` and velocity 100. Consecutive
    active cells are therefore written as repeated notes, not one held note.

    Args:
        pianorolls: Rolls to convert.
        filename: Destination path of the MIDI file.
    """
    song: pretty_midi.PrettyMIDI = pretty_midi.PrettyMIDI(resolution=480)
    for roll in pianorolls:
        track: pretty_midi.Instrument = pretty_midi.Instrument(program=1)
        track.notes.extend(
            pretty_midi.Note(velocity=100, pitch=offset+36, start=step/2, end=(step+1)/2)
            for step in range(roll.shape[0])
            for offset in range(roll.shape[1])
            if roll[step, offset] > 0.5
        )
        song.instruments.append(track)
    song.write(filename)

def show_and_play_midi(pianorolls: list[np.ndarray], filename: str) -> None:
    """Plot each piano roll, save them as MIDI and embed a playable rendering.

    Every roll is drawn transposed, so time runs along the horizontal axis.
    The rolls are written to ``filename`` with ``make_midi``, rendered to
    ``output.wav`` with FluidSynth, and shown as an audio widget.

    Args:
        pianorolls: Rolls indexed as ``[time_step, pitch_offset]``.
        filename: Path of the MIDI file to create.
    """
    for roll in pianorolls:
        plt.matshow(roll.T)
        plt.show()

    make_midi(pianorolls, filename)

    wav_path = "output.wav"
    synth: FluidSynth = FluidSynth(sound_font="/usr/share/sounds/sf2/FluidR3_GM.sf2")
    synth.midi_to_audio(filename, wav_path)
    ipd.display(ipd.Audio(wav_path))

In [None]:
# Folder holding the chorale MIDI files; the trailing "/" is required because the
# glob pattern below is built by plain string concatenation.
# NOTE(review): the path looks like a Google Drive mount in Colab, but no cell in
# this notebook mounts it - confirm the drive is mounted before running.
# Named `midi_dir` (not `dir`) so the builtin dir() stays usable in later cells.
midi_dir: str = "/content/drive/MyDrive/impl/musdl/chorales/midi/"
filenames: list[str] = []
xs: list[np.ndarray] = []

# glob returns paths in arbitrary order; sorting keeps the dataset order stable across runs.
for f in sorted(glob.glob(f"{midi_dir}*.mid")):
    print(f)
    try:
        x, _ = read_midi(f, is_sep_sop_alt=False, seqlen=16)
    except UnsupportedMidiFileException:
        # Files that do not meet read_midi's requirements are left out of the dataset.
        print("skip")
        continue
    filenames.append(f)
    xs.append(x)

# Fail here with a clear message rather than later with an opaque shape error.
if not xs:
    raise RuntimeError(f"No usable MIDI files found in {midi_dir}")

# Stack to (n_files, seqlen, n_pitches) and add a trailing channel axis for the conv layers.
x_all: np.ndarray = np.expand_dims(np.array(xs), axis=-1)

In [None]:
# Sanity check of the dataset tensor: (accepted files, time steps, pitches, 1).
print(x_all.shape)

In [None]:
# Data-dependent sizes: axis 1 of x_all is the number of time steps, axis 2 the number of pitches.
seq_length, dim = x_all.shape[1:3]

# Free hyperparameters: size of the generator's latent input and width of the hidden layers.
encoded_dim, hidden_dim = 32, 1024

In [None]:
# Generator: latent vector (encoded_dim,) -> piano roll (16, dim, 1) with values in (0, 1).
# Stem: project the latent vector to hidden_dim features and view it as a 1x1 feature map.
_gen_layers: list = [
    tf.keras.layers.Input(shape=(encoded_dim, )),
    tf.keras.layers.Dense(hidden_dim),
    tf.keras.layers.LeakyReLU(0.3),
    tf.keras.layers.Reshape(target_shape=(1, 1, hidden_dim)),
]

# Two identical upsampling stages; each multiplies the time axis by 4 (1 -> 4 -> 16 steps).
for _stage in range(2):
    _gen_layers += [
        tf.keras.layers.Conv2DTranspose(
            hidden_dim, (4, 1), strides=(4, 1), padding="valid"
        ),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.LeakyReLU(0.3),
    ]

# Head: expand the pitch axis from 1 to dim and squash to a single sigmoid channel.
_gen_layers.append(
    tf.keras.layers.Conv2DTranspose(
        1, (1, dim), strides=1, padding="valid", activation="sigmoid"
    )
)

generator: tf.keras.Sequential = tf.keras.Sequential(_gen_layers)
generator.summary()

In [None]:
# Discriminator: piano roll (seq_length, dim, 1) -> single sigmoid score.
# Stem: a (1, dim) kernel collapses the whole pitch axis into hidden_dim features per time step.
_disc_layers: list = [
    tf.keras.layers.Input(shape=(seq_length, dim, 1)),
    tf.keras.layers.Conv2D(
        hidden_dim, (1, dim), strides=1, padding="valid"
    ),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.LeakyReLU(0.3),
]

# Two identical downsampling stages; each divides the time axis by 4.
for _stage in range(2):
    _disc_layers += [
        tf.keras.layers.Conv2D(
            hidden_dim, (4, 1), strides=(4, 1), padding="valid"
        ),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.LeakyReLU(0.3),
    ]

# Head: flatten, regularize with dropout, and emit one probability.
_disc_layers += [
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.4),
    tf.keras.layers.Dense(1, use_bias=True, activation="sigmoid"),
]

discriminator: tf.keras.Sequential = tf.keras.Sequential(_disc_layers)
discriminator.summary()

In [None]:
# Compile the discriminator on its own first, so it can be trained directly with train_on_batch.
discriminator.compile(optimizer="adam", loss="binary_crossentropy", metrics=["binary_accuracy"])
# Freeze the discriminator after its own compile() and before the stacked model is compiled,
# so that training `gan` is meant to update only the generator's weights.
# NOTE(review): this relies on Keras applying the `trainable` flag as it stood at compile()
# time - confirm this holds for the installed Keras version.
discriminator.trainable = False

# Stacked model: latent vector -> generator -> discriminator score.
gan = tf.keras.Model(generator.inputs, discriminator(generator.outputs))
gan.compile(optimizer="adam", loss="binary_crossentropy", metrics=["binary_accuracy"])

In [None]:
label_noisy: bool = True   # soften discriminator targets with up to 0.2 of uniform noise
iterations: int = 10000    # number of training steps
batch_size: int = 64       # real samples (and generated samples) per discriminator step

# Label convention used everywhere in this loop: real = 1, fake = 0.
# The noisy branch previously used the opposite convention (real ~ 0, fake ~ 1) while the
# generator step still targeted 1, so with label_noisy=True the generator was being trained
# to look *fake*. Both branches and the generator target now agree.
idx_from: int = 0
for step in range(1, iterations+1):
    # ---- train discriminator ----
    x_real: np.ndarray = x_all[idx_from:idx_from+batch_size]  # real pianoroll
    rvs: np.ndarray = np.random.normal(size=(batch_size, encoded_dim))  # seed for generation
    x_gen: np.ndarray = generator.predict(rvs, verbose=0)  # generated pianoroll
    x: np.ndarray = np.concatenate([x_real, x_gen])

    # Size the labels from the actual slices: x_real is shorter than batch_size when the
    # dataset holds fewer than batch_size items, and labels must match x row for row.
    real_labels: np.ndarray = np.ones((len(x_real), 1))
    fake_labels: np.ndarray = np.zeros((len(x_gen), 1))
    if label_noisy:
        # np.random.random is already in [0, 1), so no abs() is needed.
        real_labels -= 0.2 * np.random.random(real_labels.shape)
        fake_labels += 0.2 * np.random.random(fake_labels.shape)
    labels: np.ndarray = np.concatenate([real_labels, fake_labels])  # same order as x

    loss = discriminator.train_on_batch(x, labels)
    if step % 50 == 0:
        print(f"{step}: D loss = {loss}")

    # ---- train generator (through the frozen discriminator) ----
    rvs = np.random.normal(size=(batch_size, encoded_dim))  # seed for generation
    # Target "real" (1) for generated samples: the generator is rewarded for fooling the discriminator.
    mislead_labels: np.ndarray = np.ones((batch_size, 1))
    # Five generator updates per discriminator update, all on the same latent batch.
    for _update in range(5):
        loss = gan.train_on_batch(rvs, mislead_labels)
    if step % 50 == 0:
        print(f"{step}: G loss = {loss}")

    # Advance to the next slice of real data; wrap around once a full batch no longer fits.
    idx_from += batch_size
    if idx_from + batch_size > len(x_all):
        idx_from = 0

In [None]:
# Draw one latent vector from a normal distribution with zero mean and identity covariance.
latent_mean = np.zeros(encoded_dim)
latent_cov = np.identity(encoded_dim)
my_z: np.ndarray = np.random.multivariate_normal(latent_mean, latent_cov)
print(my_z)

# predict() expects a batch, so add a leading batch axis; squeeze drops the size-1 axes again
# before the roll is plotted, written to MIDI and played.
my_x = generator.predict(my_z[np.newaxis, :])
generated_roll = np.squeeze(my_x)
show_and_play_midi([generated_roll], "output.mid")