## Import libraries

In [1]:
from models.model_v2 import *
from data.midi_preprocessing import *
from utils.dataset_loader import MaestroV3DataModule, MaestroV3DataSet

  import pkg_resources


## Loss callback

In [2]:
# Callback to track the loss of the minmax game.
class LossTracker(L.Callback):
    """Lightning callback that records the generator and discriminator losses.

    After every training batch it reads the ``g_loss`` and ``d_loss`` entries of
    ``trainer.callback_metrics`` and appends them as Python floats, so the
    history contains one point per training batch (not per epoch).
    """

    def __init__(self):
        super().__init__()
        self.g_losses = []  # generator loss, one float per recorded training batch
        self.d_losses = []  # discriminator loss, one float per recorded training batch

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Append the most recently logged ``g_loss`` / ``d_loss`` values.

        Batches for which either metric has not been logged yet are skipped,
        which keeps both lists the same length.
        """
        g_loss = trainer.callback_metrics.get("g_loss")
        d_loss = trainer.callback_metrics.get("d_loss")
        # callback_metrics only holds keys that have been logged so far; calling
        # .item() on a missing (None) metric would raise AttributeError.
        if g_loss is None or d_loss is None:
            return
        self.g_losses.append(float(g_loss))
        self.d_losses.append(float(d_loss))

    def plot(self, save_path=None):
        """Plot the recorded loss curves.

        Args:
            save_path: optional file path; when given, the figure is also
                written there (e.g. ``"losses.pdf"``) before being shown.
        """
        display.clear_output(wait=True)
        fig, ax = plt.subplots()
        ax.plot(self.g_losses, label="Generator Loss")
        ax.plot(self.d_losses, label="Discriminator Loss")
        # Values are recorded in on_train_batch_end, so the x axis counts batches.
        ax.set_xlabel("Training step (batch)")
        ax.set_ylabel("Loss")
        ax.set_title("Training Losses")
        ax.legend()
        ax.grid(True)
        if save_path is not None:
            fig.savefig(save_path)
        plt.show()

## Train the model

In [3]:
# Seed Python's `random`, NumPy and torch RNGs so that training and the
# sampling cells below are reproducible under Restart & Run All.
L.seed_everything(42)

# Create the GAN.
model = GAN(
    lr=0.0002,
    b1=0.5,
    b2=0.999,
    lambda_1=0.1,
    lambda_2=1,
    latent_dim=100,  # noise sampled at generation time must have this size
    gen_updates=2,
    dis_updates=1,
    mbd_B_dim=50,
    mbd_C_dim=5,
    batch_size=72,
    a=32,
    apply_mbd=True
)

# Dataset (paired samples).
data_file_path = "data/preprocessed/maestro-v3.0.0/dataset2/dataset.h5"
dm = MaestroV3DataModule(data_file_path, mode="pair")

# Records g_loss / d_loss after every training batch.
loss_tracker = LossTracker()

# Hyperparameters from an earlier Optuna search. They are NOT used by the GAN
# above (which uses hand-picked values); kept here for reference only:
#   lr: 1.3422089190108735e-05
#   b1: 0.6685047603190793
#   b2: 0.9570758815656579
#   lambda_1: 0.5571400649008067
#   lambda_2: 0.36707966570014705
#   latent_dim: 94
#   gen_updates: 1
#   dis_updates: 4
#   mbd_B_dim: 28
#   mbd_C_dim: 21
#   batch_size: 16
#   a: 120
#   apply_mbd: False

# Define the trainer: single device, 100 epochs, with the loss tracker attached.
trainer = L.Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=100,
    callbacks=[loss_tracker]
)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(model, dm)

/usr/lib/python3.13/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
You are using a CUDA device ('NVIDIA GeForce RTX 4070 SUPER') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type          | Params | Mode 
--------------------------------------------------------
0 | generator     | Generator     | 523 K  | train
1 | discriminator | Discriminator | 355 K  | train
--------------------------------------------------------
879 K     Trainable params
0         Non-trainable params
879 K     Total params
3.516     Total estimated model params size (MB)
57        Modules 

Epoch 0:  57%|████████████████████████████████████████████████████████████████████████████████████████████████▏                                                                       | 653/1140 [00:37<00:27, 17.56it/s, v_num=402, d_loss=1.360, g_loss=0.785]

In [None]:
# Plot the generator and discriminator losses recorded by LossTracker
# (appended in on_train_batch_end, i.e. one point per training batch).
loss_tracker.plot()

In [None]:
# Optional manual version of the loss plot that also saves it to PDF; uncomment to use.
# NOTE(review): LossTracker records one value per training batch, so the
# "Epoch" x-axis label below is misleading — consider "Training step".
#plt.plot(loss_tracker.g_losses, label="Generator Loss")
#plt.plot(loss_tracker.d_losses, label="Discriminator Loss")
#plt.xlabel("Epoch")
#plt.ylabel("Loss")
#plt.title("Training Losses")
#plt.legend()
#plt.grid(True)
#plt.savefig("200_epoch_model_v2_losses.pdf")

## Testing: generate an 8-bar sequence and export it to MIDI

In [None]:
import random

# Load the paired dataset from the same file used for training.
dataset = MaestroV3DataSet(data_file_path, mode="pair")

# Uniform random valid index. random.randint(a, b) includes b, so
# randint(0, len(dataset)) could return len(dataset) and raise IndexError;
# randrange(n) draws from 0..n-1.
rnd_idx = random.randrange(len(dataset))

# Use the first bar of the randomly chosen pair as the seed bar.
bar_0, _ = dataset[rnd_idx]  # [1, 128, 16] per the original note — TODO confirm
bar_0 = bar_0.unsqueeze(0)   # prepend a batch dimension
print(bar_0.shape)

In [None]:
# Put the model in evaluation mode for sampling.
model.eval()

# Must match the latent_dim passed to GAN(...) when the model was created (100);
# the previous value 94 was the Optuna setting, which the current model does not use.
LATENT_DIM = 100
N_NEW_BARS = 7  # together with bar_0 this gives 8 bars

# One noise vector per new bar: shape (N_NEW_BARS, 1, LATENT_DIM).
noise = torch.randn(N_NEW_BARS, 1, LATENT_DIM)

# Generate the bars autoregressively: each new bar is conditioned on the most
# recent one. (The old `bars[i-1]` lagged one step behind the latest bar.)
bars = [bar_0]
with torch.no_grad():  # inference only; no autograd graph needed
    for z in noise:
        prev = bars[-1]
        x = z, prev
        curr = model(x)
        # Save the generated bar.
        bars.append(curr)

In [None]:
# Convert every bar tensor to a 2-D numpy array: drop the two leading
# singleton dims (batch + one more — presumably channel; TODO confirm), detach
# from autograd and move to CPU, since .numpy() fails on CUDA tensors.
bars_numpy = [bar.squeeze(0, 1).detach().cpu().numpy() for bar in bars]

In [None]:
# Render each bar's piano roll on its own, labelled with its position in the sequence.
for bar_idx, bar_roll in enumerate(bars_numpy):
    print("Bar", bar_idx)
    show_piano_roll(bar_roll)

In [None]:
# Concatenate all bars along the time axis (axis 1 for 2-D arrays) into one roll.
full_piano_roll = np.hstack(bars_numpy)
print("Full piano roll")
print("Shape:", full_piano_roll.shape)
show_piano_roll(full_piano_roll)

# Scale the roll before MIDI export.
# NOTE(review): presumably maps generator outputs (assumed in [0, 1]) to a MIDI
# velocity — confirm the generator's output range.
VELOCITY = 50
full_piano_roll = full_piano_roll * VELOCITY

In [None]:
# Export the scaled piano roll as a MIDI file on disk.
# NOTE(review): fs is presumably the frame rate (columns per second) expected by
# piano_roll_to_pretty_midi — confirm against its definition.
output_fn = "output_test.midi"
pm = piano_roll_to_pretty_midi(full_piano_roll, fs=8)
pm.write(output_fn)