# Compare models

In [1]:
# from src.common.metrics.fid import calculate_fid
from src.models.diffusion.ddpm_trainer import DDPMTrainer
from src.models.diffusion.ddpm import DDPM
from src.models.representation.ae.auto_encoder import Autoencoder, Decoder, Encoder
from src.common.diagnostic.summary import show_summary
import pickle
import torch

In [2]:
# NOTE: unpickling can execute arbitrary code; only load pickle files from a trusted source.
# The context manager closes the file handle as soon as the dataset has been read.
with open("data/preprocessed_note_events.pkl", "rb") as events_file:
    events_dataset = pickle.load(events_file)

In [3]:
# Sanity check: show the first element of the loaded dataset before building any model.
print(events_dataset[0])

tensor([[67.0000,  0.4016,  1.6628,  0.6615],
        [69.0000,  0.4724,  0.6497,  0.2812],
        [70.0000,  0.5354,  0.2669,  0.8659],
        [69.0000,  0.3780,  0.8190,  0.1315],
        [67.0000,  0.4803,  0.1211,  0.6797],
        [67.0000,  0.4094,  0.6849,  0.1497],
        [74.0000,  0.4882,  0.0898,  1.7865],
        [55.0000,  0.3937,  1.7904,  0.4648]])


## DDPM

In [4]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 64

# Encoder and decoder are configured with matching feature (4), hidden (128)
# and latent (64) sizes, and the same number of layers.
encoder = Encoder(input_dim=4, hidden_dim=128, latent_dim=64, num_layers=4)
decoder = Decoder(latent_dim=64, hidden_dim=128, num_layers=4, output_dim=4)
ae_model = Autoencoder(encoder=encoder, decoder=decoder).to(DEVICE)

# presumably the number of diffusion timesteps — confirm against DDPM's signature
diffusion = DDPM(1_000)

# No optimizer or run name is passed: in this cell the trainer is only used
# to load an existing checkpoint into `ae_model`.
trainer = DDPMTrainer(model=ae_model, optimizer=None, diffusion=diffusion, run_name=None)
trainer.load_model(
    "./models/denoisers/ae/ddpm_midi_autoencoder/ddpm_midi_autoencoder.pth/ddpm_midi_autoencoder.pth",
)

# Summarize the model using the shape of one dataset sample as the input shape.
show_summary(
    ae_model,
    input_shape=events_dataset[0].shape,
    batch_size=BATCH_SIZE,
    dataset=events_dataset,
)

Loading model from models\denoisers\ae\ddpm_midi_autoencoder\ddpm_midi_autoencoder.pth\ddpm_midi_autoencoder.pth
Model loaded from ./models/denoisers/ae/ddpm_midi_autoencoder/ddpm_midi_autoencoder.pth/ddpm_midi_autoencoder.pth
Autoencoder(
  (encoder): Encoder(
    (diff_timestep_embedding): Embedding(1000, 128)
    (lstm): LSTM(4, 128, num_layers=4, batch_first=True, dropout=0.1, bidirectional=True)
    (linear): Linear(in_features=256, out_features=64, bias=True)
  )
  (decoder): Decoder(
    (lstm): LSTM(4, 128, num_layers=4, batch_first=True)
    (mom): MemoryOverwriteModule(
      (forget_gate): Sequential(
        (0): Linear(in_features=128, out_features=128, bias=True)
        (1): Sigmoid()
      )
      (overwrite_sig): Sequential(
        (0): Linear(in_features=128, out_features=128, bias=True)
        (1): Sigmoid()
      )
      (overwrite_tanh): Sequential(
        (0): Linear(in_features=128, out_features=128, bias=True)
        (1): Tanh()
      )
    )
    (latent_to_

In [None]:
import pretty_midi
from scipy.io.wavfile import write
import numpy as np

def midiToWav(midi_path, wav_path, sample_rate=44100):
    """
    Convert MIDI file to WAV file using pretty_midi.

    Args:
        midi_path: Path of the MIDI file to read.
        wav_path: Path (or file-like target accepted by scipy) the WAV data is written to.
        sample_rate: Sampling rate in Hz. It is used both for synthesis and for the
            WAV header so the two can never disagree. Defaults to 44100.
    """
    midi_data = pretty_midi.PrettyMIDI(midi_path)
    # Synthesize at the same rate that is written to the header; a mismatch
    # would change pitch and duration on playback.
    audio_data = midi_data.fluidsynth(fs=sample_rate)
    # Cast to float32 so scipy writes a 32-bit floating-point WAV.
    write(wav_path, sample_rate, audio_data.astype(np.float32))

In [6]:
import os
import math
from tqdm import tqdm


# Output locations, all rooted at OUTPUT_DATA_PATH.
OUTPUT_DATA_PATH = "data/output"
REAL_PATH = f"{OUTPUT_DATA_PATH}/real"
# Misspelled legacy name, kept as an alias so cells that still reference it keep working.
REAL_PAHT = REAL_PATH
GENERATED_PATH = f"{OUTPUT_DATA_PATH}/generated"
MIDI_PATH = f"{OUTPUT_DATA_PATH}/midi"

# Directory holding the real MIDI files that are converted to WAV.
REAL_MIDI_PATH = "data/midi_dataset/midis"

# Caps both the number of real files converted and the number of samples generated.
FID_SAMPLE_SIZE = 1000
# Number of samples drawn per generation batch.
BATCH_SIZE = 64
# Passed to the sampler as the number of inference steps.
NUM_TIMESTEPS = 1000

device = "cuda" if torch.cuda.is_available() else "cpu"

# Create every output directory up front; exist_ok makes the cell safe to re-run.
for output_dir in (REAL_PATH, GENERATED_PATH, MIDI_PATH):
    os.makedirs(output_dir, exist_ok=True)

In [None]:
# transform real dataset to .wav
# os.listdir returns entries in arbitrary order, so sort first: this makes both the
# FID_SAMPLE_SIZE subset and the index -> file mapping (data_{i}.wav) deterministic.
real_files = sorted(f for f in os.listdir(REAL_MIDI_PATH) if f.endswith('.mid'))[:FID_SAMPLE_SIZE]

# Pair every source MIDI with its target WAV and keep only the targets that are
# still missing, so an interrupted run resumes instead of starting over.
conversion_jobs = [
    (f"{REAL_MIDI_PATH}/{midi_name}", f"{REAL_PAHT}/data_{i}.wav")
    for i, midi_name in enumerate(real_files)
]
pending_jobs = [(src, dst) for src, dst in conversion_jobs if not os.path.exists(dst)]

if not pending_jobs:
    print("Real data already converted to WAV.")
else:
    for src, dst in tqdm(pending_jobs, desc="Converting MIDI to WAV"):
        # Write to a temporary name and rename afterwards: a conversion that is
        # interrupted midway never leaves a truncated file under the final name.
        partial_path = f"{dst}.part"
        midiToWav(src, partial_path)
        os.replace(partial_path, dst)

Converting MIDI to WAV:   0%|          | 0/10854 [00:00<?, ?it/s]




Converting MIDI to WAV:   0%|          | 1/10854 [00:00<1:27:11,  2.07it/s]




Converting MIDI to WAV:   0%|          | 2/10854 [00:04<7:23:46,  2.45s/it]




Converting MIDI to WAV:   0%|          | 3/10854 [00:04<4:46:38,  1.59s/it]




Converting MIDI to WAV:   0%|          | 4/10854 [00:08<7:23:26,  2.45s/it]




Converting MIDI to WAV:   0%|          | 5/10854 [00:09<5:37:50,  1.87s/it]




Converting MIDI to WAV:   0%|          | 6/10854 [00:10<4:22:54,  1.45s/it]




Converting MIDI to WAV:   0%|          | 7/10854 [00:12<5:17:37,  1.76s/it]





KeyboardInterrupt: 

In [None]:
# generate fake samples
# NOTE(review): the original cell referenced `model` and `sampler`, which no visible cell
# defines. `ae_model` and `trainer.sampler` are used instead — confirm `trainer.sampler` exists.
# NOTE(review): `note_events_to_pretty_midi` is not imported by any visible cell — import it
# before running this cell.
iterations = math.ceil(FID_SAMPLE_SIZE / BATCH_SIZE)

# Per-sample shape taken from the dataset (the same shape passed to show_summary); the
# original (3, 32, 32) noise shape does not match these tensors.
sample_shape = tuple(events_dataset[0].shape)

midi_count = 0
for _ in range(iterations):
    # Shrink the final batch so exactly FID_SAMPLE_SIZE samples are written.
    current_batch_size = min(BATCH_SIZE, FID_SAMPLE_SIZE - midi_count)
    noise = torch.randn(current_batch_size, *sample_shape, device=device)
    generated = trainer.sampler.p_sample_loop(
        ae_model, noise, num_inference_steps=NUM_TIMESTEPS, clip=True, quiet=True
    )

    # assumes p_sample_loop returns a NumPy array (the original wrapped the result with
    # torch.from_numpy) — TODO confirm. Iterating it directly avoids a tensor round trip.
    for midi in generated:
        note_events_to_pretty_midi(midi, path=f"{MIDI_PATH}/sample_{midi_count}.mid", default_program=0)
        midi_count += 1

In [None]:
# # %pip install numba==0.48.0

# # Download and load a popular pretrained encoder for symbolic music (MusicVAE - Magenta)
# import urllib.request
# import os

# encoder_path = './models/magenta/cat-mel_2bar_small.tar'
# # if not os.path.exists(encoder_path):
# #     url = 'https://huggingface.co/magenta/music-vae/resolve/main/mel_2bar_small.pth'  # Example: MusicVAE 2-bar melody encoder
# #     urllib.request.urlretrieve(url, encoder_path)

# # MusicVAE encoder architecture (simplified, adjust as needed for your use case)
# from magenta.models.music_vae.trained_model import TrainedModel

# music_vae = TrainedModel(
#     config='cat-mel_2bar_small',
#     batch_size=64,
#     checkpoint_dir_or_path=encoder_path
# )
# pretrained_encoder = music_vae._config.data_converter
# # Note: You may need to adapt your pipeline to use the MusicVAE encoder output for FID calculation.

In [None]:
# fid = calculate_fid(
#     model=ae_model,
#     sampler=trainer.sampler,
#     encoder=pretrained_encoder, # use pretrained encoder for FID calculation
# )

---