# Compare models

In [2]:
# from src.common.metrics.fid import calculate_fid
from src.models.diffusion.ddpm_trainer import DDPMTrainer
from src.models.diffusion.ddpm import DDPM
from src.models.representation.ae.auto_encoder import Autoencoder, Decoder, Encoder
from src.common.diagnostic.summary import show_summary
import pickle
import torch

In [3]:
# NOTE: unpickling can execute arbitrary code; only load this file from a trusted source.
# The context manager closes the file handle as soon as the dataset has been read.
with open("data/preprocessed_note_events.pkl", "rb") as dataset_file:
    events_dataset = pickle.load(dataset_file)

In [3]:
# Peek at the first sample. The MIDI converter later in this notebook unpacks
# each row as (pitch, velocity, delta, duration).
print(events_dataset[0])

tensor([[67.0000,  0.4016,  1.6628,  0.6615],
        [69.0000,  0.4724,  0.6497,  0.2812],
        [70.0000,  0.5354,  0.2669,  0.8659],
        [69.0000,  0.3780,  0.8190,  0.1315],
        [67.0000,  0.4803,  0.1211,  0.6797],
        [67.0000,  0.4094,  0.6849,  0.1497],
        [74.0000,  0.4882,  0.0898,  1.7865],
        [55.0000,  0.3937,  1.7904,  0.4648]])


## DDPM

In [30]:
# Build the autoencoder denoiser, restore its saved weights and print a summary.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 64

# Hyperparameters shared by the encoder and the decoder. They are defined once
# so the two halves of the autoencoder cannot silently drift apart.
NOTE_FEATURES = 4
HIDDEN_DIM = 128
LATENT_DIM = 64
NUM_LAYERS = 4

# Plain string: the path contains no placeholders, so no f-string is needed.
CHECKPOINT_PATH = "./models/denoisers/ae/ddpm_midi_autoencoder/ddpm_midi_autoencoder.pth/ddpm_midi_autoencoder.pth"

encoder = Encoder(
    input_dim=NOTE_FEATURES,
    hidden_dim=HIDDEN_DIM,
    latent_dim=LATENT_DIM,
    num_layers=NUM_LAYERS,
)

decoder = Decoder(
    latent_dim=LATENT_DIM,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
    output_dim=NOTE_FEATURES,
)

ae_model = Autoencoder(
    encoder=encoder,
    decoder=decoder,
).to(DEVICE)

# NOTE(review): 1_000 is presumably the number of diffusion timesteps; a later
# cell defines NUM_TIMESTEPS = 1000 but it is not used here - confirm and unify.
ddpm = DDPM(1_000)

# No optimizer or run name is passed: in this notebook the trainer is only
# used to load the checkpoint into `ae_model`.
trainer = DDPMTrainer(
    model=ae_model,
    optimizer=None,
    diffusion=ddpm,
    run_name=None,
)

trainer.load_model(CHECKPOINT_PATH)

show_summary(ae_model, input_shape=events_dataset[0].shape, batch_size=BATCH_SIZE, dataset=events_dataset)

Loading model from models\denoisers\ae\ddpm_midi_autoencoder\ddpm_midi_autoencoder.pth\ddpm_midi_autoencoder.pth
Model loaded from ./models/denoisers/ae/ddpm_midi_autoencoder/ddpm_midi_autoencoder.pth/ddpm_midi_autoencoder.pth
Autoencoder(
  (encoder): Encoder(
    (diff_timestep_embedding): Embedding(1000, 128)
    (lstm): LSTM(4, 128, num_layers=4, batch_first=True, dropout=0.1, bidirectional=True)
    (linear): Linear(in_features=256, out_features=64, bias=True)
  )
  (decoder): Decoder(
    (lstm): LSTM(4, 128, num_layers=4, batch_first=True)
    (mom): MemoryOverwriteModule(
      (forget_gate): Sequential(
        (0): Linear(in_features=128, out_features=128, bias=True)
        (1): Sigmoid()
      )
      (overwrite_sig): Sequential(
        (0): Linear(in_features=128, out_features=128, bias=True)
        (1): Sigmoid()
      )
      (overwrite_tanh): Sequential(
        (0): Linear(in_features=128, out_features=128, bias=True)
        (1): Tanh()
      )
    )
    (latent_to_

In [19]:
import pretty_midi
import numpy as np

@torch.inference_mode
def note_events_to_pretty_midi(note_array: torch.Tensor | np.ndarray, path="eg.mid", default_program=0):
    if isinstance(note_array, torch.Tensor):
        note_array = note_array.detach().cpu().numpy()
    #scale  and velocities
    note_array[:, 1] *= 127
    # clamp pitches and velocities
    note_array[:, 0] = np.clip(note_array[:, 0], 0, 127)
    note_array[:, 1] = np.clip(note_array[:, 1], 0, 127)
    pm = pretty_midi.PrettyMIDI()
    instrument = pretty_midi.Instrument(program=default_program)

    current_time = 0.0
    # print(note_array.shape)
    for row in note_array:
        pitch, velocity, delta, duration = row
        current_time += delta
        start = current_time
        end = start + duration

        note = pretty_midi.Note(
            velocity=int(velocity),
            pitch=int(pitch),
            start=start,
            end=end
        )
        instrument.notes.append(note)

    pm.instruments.append(instrument)
    pm.write(path)

In [29]:
from scipy.io.wavfile import write

def midiToWav(midi_path, wav_path, sample_rate=44100):
    """Synthesize a MIDI file with pretty_midi's fluidsynth backend and save it as WAV.

    Args:
        midi_path: Path of the MIDI file to read.
        wav_path: Path of the WAV file to write.
        sample_rate: Sampling rate in Hz. The same value is used for synthesis
            and for the WAV header so the two can never disagree.
    """
    midi_data = pretty_midi.PrettyMIDI(midi_path)
    audio_data = midi_data.fluidsynth(fs=sample_rate)
    # Stored as 32-bit float samples.
    write(wav_path, sample_rate, audio_data.astype(np.float32))

@torch.inference_mode()
def sampler(model, diffusion: DDPM, noise,) -> np.ndarray:
    """Draw samples by running ``diffusion.p_sample_loop`` from ``noise``.

    ``model`` is switched to evaluation mode first and is left in that mode.
    The loop is called with ``clip=True`` and the whole call runs under
    ``torch.inference_mode``.

    NOTE(review): the return annotation says ``np.ndarray``, while the result is
    later handed to ``diff_adapter``, which is annotated for ``torch.Tensor`` -
    confirm what ``p_sample_loop`` actually returns.
    """
    model.eval()
    return diffusion.p_sample_loop(model, noise, clip=True)

def diff_adapter(tensor: torch.Tensor) -> torch.Tensor:
    """Rescale model output into the value ranges of the note-event format.

    Along the last axis, channel 0 is mapped with ``(x + 1) * 63.5`` (so -1 -> 0
    and 1 -> 127) and every remaining channel with ``(x + 1) / 2`` (so -1 -> 0
    and 1 -> 1).

    The input is left untouched and a rescaled copy is returned, so calling the
    adapter twice on the same data cannot apply the scaling twice.

    Args:
        tensor: 3-D tensor whose last axis holds the per-note channels.

    Returns:
        A new tensor of the same shape with the rescaled values.
    """
    # `.copy()` keeps array inputs working should the sampler return NumPy data.
    rescaled = tensor.clone() if isinstance(tensor, torch.Tensor) else tensor.copy()
    rescaled[:, :, 0] = (tensor[:, :, 0] + 1) * 63.5
    rescaled[:, :, 1:] = (tensor[:, :, 1:] + 1) / 2
    return rescaled

In [None]:
import os
import math
from tqdm import tqdm
import torch

# Root folder for everything this notebook renders.
OUTPUT_DATA_PATH = "data/output"

# Real samples, as MIDI and as WAV.
REAL_MIDI_PATH = f"{OUTPUT_DATA_PATH}/real_midi"
REAL_WAV_PATH = f"{OUTPUT_DATA_PATH}/real_wav"
# DDPM-generated samples, as MIDI and as WAV.
GENERATED_MIDI_PATH = f"{OUTPUT_DATA_PATH}/generated_midi_ddpm"
GENERATED_WAV_PATH = f"{OUTPUT_DATA_PATH}/generated_wav_ddpm"

# Number of samples to render per set, and generation settings.
FID_SAMPLE_SIZE = 1000
BATCH_SIZE = 64
NUM_TIMESTEPS = 1000

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Create every output folder up front; folders that already exist are kept.
for _output_dir in (REAL_MIDI_PATH, REAL_WAV_PATH, GENERATED_MIDI_PATH, GENERATED_WAV_PATH):
    os.makedirs(_output_dir, exist_ok=True)

In [None]:
# Only rendered WAV files count towards the cache check, so stray files in the
# folder cannot make an incomplete conversion look finished.
transformed_real_files = [
    name for name in os.listdir(REAL_WAV_PATH) if name.endswith(".wav")
][:FID_SAMPLE_SIZE]
# zip() against a range stops after FID_SAMPLE_SIZE items, so the full dataset
# is never copied into a list just to be sliced.
dataset_samples = [sample for _, sample in zip(range(FID_SAMPLE_SIZE), events_dataset)]

# Convert dataset to WAV
if len(dataset_samples) == len(transformed_real_files):
    print("Dataset already converted to WAV.")
else:
    for i, real_sample in enumerate(tqdm(dataset_samples, desc="Converting dataset to WAV")):
        midi_path = f"{REAL_MIDI_PATH}/data_{i}.mid"
        wav_path = f"{REAL_WAV_PATH}/data_{i}.wav"
        # Hand the converter its own copy so the dataset sample can never be
        # modified by the conversion, no matter how often this cell is re-run.
        note_events_to_pretty_midi(
            torch.as_tensor(real_sample).clone(), path=midi_path, default_program=0
        )
        midiToWav(midi_path, wav_path)

Converting dataset to WAV: 100%|██████████| 1000/1000 [00:20<00:00, 49.64it/s]


In [34]:
# Generate fake samples
# Shape of a single sample, read from the real data so the noise always matches it.
sample_shape = tuple(events_dataset[0].shape)
iterations = math.ceil(FID_SAMPLE_SIZE / BATCH_SIZE)

midi_count = 0
for _ in tqdm(range(iterations), desc="Generating fake samples"):
    # Shrink the final batch so exactly FID_SAMPLE_SIZE files are written;
    # full batches throughout would overshoot whenever BATCH_SIZE does not
    # divide FID_SAMPLE_SIZE.
    current_batch = min(BATCH_SIZE, FID_SAMPLE_SIZE - midi_count)
    noise = torch.randn(current_batch, *sample_shape, device=device)
    generated = sampler(
        model=ae_model,
        diffusion=ddpm,
        noise=noise
    )
    generated = diff_adapter(generated)

    for midi in generated:
        midi_path = f"{GENERATED_MIDI_PATH}/data_{midi_count}.mid"
        wav_path = f"{GENERATED_WAV_PATH}/data_{midi_count}.wav"

        note_events_to_pretty_midi(midi, path=midi_path, default_program=0)
        midiToWav(midi_path, wav_path)
        midi_count += 1

Generating fake samples: 100%|██████████| 16/16 [01:53<00:00,  7.11s/it]


In [11]:
# Run the FAD script in a subprocess with the current directory on PYTHONPATH.
# NOTE(review): `set` and `%CD%` are Windows cmd.exe syntax, so this line only
# works on Windows as written.
!set PYTHONPATH=%CD% && python ./scripts/calculate_fad.py

Calculating FAD for ddpm...
FAD score for ddpm: 5.207745154545066
FAD scores: {'ddpm': 5.207745154545066}


  import pkg_resources
Using cache found in C:\Users\xconv/.cache\torch\hub\harritaylor_torchvggish_master
