Load and preprocess the data (based on 001).

In [1]:
import os
import numpy as np
import pretty_midi
import tensorflow as tf
from tensorflow.keras import layers
import matplotlib.pyplot as plt

In [2]:
# Model / training hyper-parameters.
LATENT_DIM = 100   # length of the generator's input noise vector
NOTE_LENGTH = 32   # notes per training sequence
FEATURE_DIM = 4    # per-note features: pitch, velocity, duration, gap to next onset
BATCH_SIZE = 32    # real (and fake) samples per training step
EPOCHS = 1000      # training steps; each step uses one randomly drawn batch
# Root directory searched recursively for .mid/.midi files. Set the MIDI_ROOT
# environment variable (or edit the default below) to point at your corpus.
MIDI_ROOT = os.environ.get("MIDI_ROOT", "midi")
def midi_to_note_vector(midi_file):
    """Turn the opening notes of a MIDI file into a fixed-size feature array.

    Notes from every non-drum instrument are merged and ordered by onset
    (stable sort, so simultaneous onsets keep instrument order). For each
    of the first NOTE_LENGTH notes that has a successor, four features are
    emitted:

        [pitch / 128, velocity / 128, duration (end - start),
         onset of the next note - onset of this note]

    Duration and onset gap are taken as-is from pretty_midi, without rescaling.

    Args:
        midi_file: Path of the MIDI file to read.

    Returns:
        A ``(NOTE_LENGTH, 4)`` float ``np.ndarray``. Returns ``None`` when
        the file has too few notes (it needs at least NOTE_LENGTH + 1).
        Also returns ``None`` when anything raises while loading or parsing;
        in that case the error is printed and the file is skipped.
    """
    try:
        midi = pretty_midi.PrettyMIDI(midi_file)
        events = [
            [n.pitch, n.velocity, n.end - n.start, n.start]
            for inst in midi.instruments
            if not inst.is_drum
            for n in inst.notes
        ]
        events.sort(key=lambda ev: ev[3])

        # Pair each note with its successor; only the first NOTE_LENGTH pairs matter.
        heads = events[:NOTE_LENGTH]
        successors = events[1:NOTE_LENGTH + 1]
        features = [
            [cur[0] / 128, cur[1] / 128, cur[2], nxt[3] - cur[3]]
            for cur, nxt in zip(heads, successors)
        ]

        if len(features) != NOTE_LENGTH:
            return None
        return np.array(features)
    except Exception as err:
        print(f"Skipping {midi_file} due to error: {err}")
        return None
def load_midi_dataset(path):
    """Recursively collect every usable MIDI file under *path* into one array.

    Files are matched by a case-insensitive ``.mid`` / ``.midi`` extension,
    so ``.MID`` files are included. Directories and files are visited in
    sorted order, so repeated runs on the same tree give the same sample
    order. Each file is converted with ``midi_to_note_vector``; files for
    which it returns ``None`` are skipped.

    Args:
        path: Root directory to search.

    Returns:
        A float32 ``np.ndarray`` stacking the per-file arrays. When no file
        is usable the result is an empty array.
    """
    midi_exts = ('.mid', '.midi')
    data = []
    for root, dirs, files in os.walk(path):
        # os.walk honours in-place edits of `dirs`, so this fixes the descent order.
        dirs.sort()
        for name in sorted(files):
            if not name.lower().endswith(midi_exts):
                continue
            vec = midi_to_note_vector(os.path.join(root, name))
            if vec is not None:
                data.append(vec)
    return np.array(data, dtype=np.float32)

In [None]:
def build_generator():
    """Build the generator network.

    A single hidden ReLU layer maps a LATENT_DIM noise vector to
    NOTE_LENGTH * FEATURE_DIM values. These are reshaped to
    ``(NOTE_LENGTH, FEATURE_DIM)`` and squashed by a sigmoid, so every
    generated feature lies in (0, 1).
    """
    hidden_units = 256
    flat_size = NOTE_LENGTH * FEATURE_DIM
    stack = [
        layers.Input(shape=(LATENT_DIM,)),
        layers.Dense(hidden_units, activation='relu'),
        layers.Dense(flat_size),
        layers.Reshape((NOTE_LENGTH, FEATURE_DIM)),
        layers.Activation('sigmoid'),
    ]
    return tf.keras.Sequential(stack)
def build_discriminator():
    """Build the discriminator network.

    The ``(NOTE_LENGTH, FEATURE_DIM)`` input is flattened and passed through
    one hidden ReLU layer. A single sigmoid unit then gives a score in
    (0, 1) that the training loop reads as "probability the sample is real".
    """
    hidden_units = 256
    stack = [
        layers.Input(shape=(NOTE_LENGTH, FEATURE_DIM)),
        layers.Flatten(),
        layers.Dense(hidden_units, activation='relu'),
        layers.Dense(1, activation='sigmoid'),
    ]
    return tf.keras.Sequential(stack)

def visualize_output(generator):
    """Draw one random sample from *generator* and display it as a heatmap.

    The sample is transposed before plotting, so rows are feature channels
    and columns are time steps.
    """
    latent = tf.random.normal([1, LATENT_DIM])
    sequence = generator(latent, training=False)[0].numpy()

    plt.imshow(sequence.T, aspect='auto', cmap='viridis')
    axes = plt.gca()
    axes.set_title("Generated Note Sequence")
    axes.set_xlabel("Time Step")
    axes.set_ylabel("Features")
    plt.colorbar()
    plt.show()

In [None]:

def _plot_losses(gen_losses, disc_losses):
    """Plot per-step generator and discriminator losses on one figure."""
    plt.figure(figsize=(10, 5))
    plt.plot(gen_losses, label='Generator Loss', color='blue')
    plt.plot(disc_losses, label='Discriminator Loss', color='red')
    plt.title('Epoch vs Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.show()


def train_gan(generator, discriminator, dataset, *, epochs=None, batch_size=None, log_every=100):
    """Train *generator* against *discriminator*, then plot both loss curves.

    Each "epoch" here is one optimisation step on a single batch drawn with
    replacement from *dataset*; it is not a full pass over the data. Every
    ``log_every`` steps the losses are printed and a generated sample is
    shown with ``visualize_output``.

    Args:
        generator: Model mapping noise of shape (batch, LATENT_DIM) to fake
            samples.
        discriminator: Model returning a score per sample. The loss is
            ``BinaryCrossentropy`` without ``from_logits``, so the score is
            expected to lie in (0, 1).
        dataset: Array of real samples indexed along axis 0.
        epochs: Number of training steps. Defaults to the module-level
            ``EPOCHS``, read at call time.
        batch_size: Real (and fake) samples per step. Defaults to the
            module-level ``BATCH_SIZE``, read at call time.
        log_every: Interval, in steps, between progress prints and previews.

    Raises:
        ValueError: If *dataset* contains no samples.
    """
    if len(dataset) == 0:
        raise ValueError("train_gan needs at least one real sample, but the dataset is empty.")
    if epochs is None:
        epochs = EPOCHS
    if batch_size is None:
        batch_size = BATCH_SIZE

    loss_fn = tf.keras.losses.BinaryCrossentropy()
    gen_opt = tf.keras.optimizers.Adam(1e-4)
    disc_opt = tf.keras.optimizers.Adam(1e-4)

    @tf.function
    def train_step(real_batch):
        noise = tf.random.normal([batch_size, LATENT_DIM])
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            fake_batch = generator(noise, training=True)
            real_output = discriminator(real_batch, training=True)
            fake_output = discriminator(fake_batch, training=True)
            # Generator wants its fakes classified as real (label 1).
            gen_loss = loss_fn(tf.ones_like(fake_output), fake_output)
            # Discriminator wants real -> 1 and fake -> 0.
            disc_loss = loss_fn(tf.ones_like(real_output), real_output) + \
                        loss_fn(tf.zeros_like(fake_output), fake_output)
        grads_g = gen_tape.gradient(gen_loss, generator.trainable_variables)
        grads_d = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
        gen_opt.apply_gradients(zip(grads_g, generator.trainable_variables))
        disc_opt.apply_gradients(zip(grads_d, discriminator.trainable_variables))
        return gen_loss, disc_loss

    gen_losses = []
    disc_losses = []
    for epoch in range(epochs):
        # Sample indices with replacement, so datasets smaller than a batch still work.
        idx = np.random.randint(0, dataset.shape[0], batch_size)
        gen_loss, disc_loss = train_step(dataset[idx])

        gen_losses.append(gen_loss.numpy())
        disc_losses.append(disc_loss.numpy())

        if epoch % log_every == 0:
            print(f"Epoch {epoch}: Gen Loss={gen_losses[-1]:.4f}, Disc Loss={disc_losses[-1]:.4f}")
            visualize_output(generator)

    _plot_losses(gen_losses, disc_losses)
# Load the real note sequences, build both networks, and run adversarial training.
dataset = load_midi_dataset(MIDI_ROOT)
print(f"Loaded {len(dataset)} MIDI samples.")

generator, discriminator = build_generator(), build_discriminator()
train_gan(generator, discriminator, dataset)