In [26]:
# --- Run selection ---------------------------------------------------------
# Name of the training run; later joined onto the runs directory to locate
# the gin config and checkpoints.
model_name = "platune_maestrov4_win2"
step = None
suffix = ""
test = True
data = "real"

# --- Hardware / reproducibility --------------------------------------------
# GPU index, kept as a string.
device_id = "2"
SEED = 42

# --- Sizes -----------------------------------------------------------------
max_samples = 1000
batch_size = 16
# NOTE: N_SIGNAL is recomputed in a later cell as AE_RATIO * N_SIGNAL_LATENT,
# and NB_STEPS is reassigned in the experiment cell; the values here are only
# the initial ones.
N_SIGNAL = 131072
NB_STEPS = 10

# --- Optional metrics ------------------------------------------------------
DO_CLAP = True

In [None]:
%load_ext autoreload
%autoreload 2

import os

import sys
# Extend the module search path with the codecs_benchmark evaluation
# directory so modules located there can be imported.
sys.path.append("/data/nils/repos2/codecs_benchmark/evaluation")
# Switch the working directory to the platune checkout.
# NOTE(review): both paths are absolute and machine-specific; running this
# notebook elsewhere requires editing them.
os.chdir("/data/nils/repos2/platune")

import json
import torch

from IPython.display import display, Audio
from tqdm import tqdm
import time
import numpy as np
from after.dataset import CombinedDataset, SimpleDataset
from after.autoencoder.wrappers import M2LWrapper

import gin

# Put gin into interactive mode.
# NOTE(review): presumably needed so that re-running cells (with autoreload)
# does not fail when configurables get registered again — confirm.
gin.enter_interactive_mode()

# This notebook never back-propagates, so autograd is switched off globally.
torch.set_grad_enabled(False)

# Build the device string from the `device_id` parameter instead of
# hard-coding the GPU index, so the parameter cell actually controls which
# GPU is used.
device = f"cuda:{device_id}"

In [None]:
# Location of the dataset handed to SimpleDataset.
# NOTE(review): absolute, machine-specific path.
path = "/data/nils/datasets/maestro/m2l_midiv2"
# Items are later read as dicts with at least the "z" and "midi" entries
# (see get_data). NOTE(review): keys="all" presumably requests every stored
# field — confirm against SimpleDataset.
dataset = SimpleDataset(path, keys="all")

In [None]:
from platune.model import PLaTune, SDEdit

# Start from an empty gin configuration before parsing this run's config.
gin.clear_config()

folder = os.path.join("/data/nils/repos2/platune/runs", model_name)
autoencoder_path = "music2latent"
config = os.path.join(folder, "config.gin")

# parse config
gin.parse_config_file(config)

# Candidate checkpoints in priority order: highest version directory first,
# and within a version "last-v1.ckpt" before "last.ckpt".
paths = [
    os.path.join(folder, version_dir, "checkpoints", ckpt_name)
    for version_dir in
    ["version_4", "version_3", "version_2", "version_1", "version_0"]
    for ckpt_name in ["last-v1.ckpt", "last.ckpt"]
]

# Load the first candidate that exists and can be read. Unreadable files are
# reported and skipped (best-effort, as before), but KeyboardInterrupt is no
# longer swallowed and a total failure is reported here instead of surfacing
# later as a NameError on `state_dict`.
state_dict = None
for p in paths:
    if not os.path.isfile(p):
        continue
    try:
        state_dict = torch.load(p, map_location="cpu")["state_dict"]
    except Exception as err:
        print(f"Skipping unreadable checkpoint {p}: {err!r}")
        continue
    # Only the checkpoint that was actually loaded is printed.
    print(p)
    break

if state_dict is None:
    raise FileNotFoundError(
        "No loadable checkpoint found for run "
        f"{model_name!r}; tried:\n" + "\n".join(paths))

# The model class is chosen from the run name: runs whose name contains
# "sdedit" (case-insensitive) use SDEdit, everything else uses PLaTune.
SDEDIT = "sdedit" in model_name.lower()
model = SDEdit() if SDEDIT else PLaTune()

# strict=True: any mismatch between checkpoint and model keys raises.
model.load_state_dict(state_dict, strict=True)
model.eval()
model.to(device)

# Parse config
gin.parse_config_file(config)
SR = 44100

# Latent sequence length comes from the SEQ_LEN macro of the run config.
# gin.query_parameter raises ValueError when the key cannot be resolved; only
# that case falls back to 64, so unrelated errors are no longer hidden.
try:
    N_SIGNAL_LATENT = gin.query_parameter("%SEQ_LEN")
except ValueError:
    N_SIGNAL_LATENT = 64
# NOTE(review): this prints N_SIGNAL (the sample-domain length currently in
# scope), not the N_SIGNAL_LATENT resolved just above — confirm which one
# was meant.
print(N_SIGNAL)

# Latent codec used to decode latents back to audio.
# NOTE(review): AE_RATIO is only assigned in the music2latent branch; with a
# TorchScript autoencoder it stays undefined and the later
# `N_SIGNAL = AE_RATIO * N_SIGNAL_LATENT` cell would fail unless it is set
# elsewhere.
if autoencoder_path != "music2latent":
    emb_model = torch.jit.load(autoencoder_path).eval().to(device)
else:
    emb_model = M2LWrapper(device=device)
    AE_RATIO = 4096

# Descriptor names declared by the run config.
CONTINUOUS_KEYS = gin.query_parameter("%CONTINUOUS_KEYS")
DISCRETE_KEYS = gin.query_parameter("%DISCRETE_KEYS")

# Discrete descriptors come first, followed by the continuous ones.
descriptors = DISCRETE_KEYS + CONTINUOUS_KEYS
# Same list without the "pitch" and "octave" entries.
descriptors_nonmidi = list(
    filter(lambda name: name not in ("pitch", "octave"), descriptors))

In [30]:
# Excerpt length recomputed from the latent length; this replaces the
# N_SIGNAL value assigned in the parameter cell.
# NOTE(review): presumably AE_RATIO is the number of audio samples per latent
# frame — confirm.
N_SIGNAL = AE_RATIO * N_SIGNAL_LATENT
hop_length = AE_RATIO

from platune.datasets.midi_descriptors import compute_midi_descriptors


def get_midi(midi_data, chunk_number, n_signal=None, sr=None):
    """Extract the notes of one fixed-length chunk, re-timed to start at 0.

    The chunk covers ``[chunk_number * n_signal / sr,
    (chunk_number + 1) * n_signal / sr)`` seconds. Notes overlapping that
    window are kept, shifted so the chunk starts at time 0, and clipped to
    the chunk boundaries.

    Unlike the previous version, the input is NOT modified: notes and the
    container are copied before being edited, so the same ``midi_data`` can
    safely be chunked several times.

    Args:
        midi_data: Object exposing ``instruments[0].notes``, each note having
            ``start`` and ``end`` attributes in seconds (presumably a
            ``pretty_midi.PrettyMIDI`` — confirm). Only the first instrument
            is considered.
        chunk_number: Zero-based index of the chunk to extract.
        n_signal: Chunk length in samples. Defaults to the notebook-level
            ``N_SIGNAL``.
        sr: Sample rate in Hz. Defaults to the notebook-level ``SR``.

    Returns:
        ``(True, None)`` when no note overlaps the chunk, otherwise
        ``(False, chunk_midi)`` where ``chunk_midi`` is a copy of
        ``midi_data`` whose first instrument holds only the re-timed notes.
    """
    import copy

    if n_signal is None:
        n_signal = N_SIGNAL
    if sr is None:
        sr = SR

    length = n_signal / sr
    tstart = chunk_number * n_signal / sr
    tend = (chunk_number + 1) * n_signal / sr

    out_notes = []
    for note in midi_data.instruments[0].notes:
        if note.end > tstart and note.start < tend:
            # Edit a copy so the caller's note keeps its absolute timing.
            clipped = copy.copy(note)
            clipped.start = max(0, note.start - tstart)
            clipped.end = min(note.end - tstart, length)
            out_notes.append(clipped)

    if not out_notes:
        return True, None

    # Shallow-copy the container, the instrument list and the first
    # instrument so that swapping in the chunk's notes leaves the original
    # object untouched.
    chunk_midi = copy.copy(midi_data)
    chunk_midi.instruments = list(midi_data.instruments)
    chunk_midi.instruments[0] = copy.copy(midi_data.instruments[0])
    chunk_midi.instruments[0].notes = out_notes
    # midi_data.adjust_times([0, length], [0, length])
    return False, chunk_midi

In [None]:
# Dataset indices of the two excerpts compared below: idx1 provides the
# conditioning descriptors and the reference audio, idx2 is the second
# excerpt fetched alongside it.
idx1 = 1079
idx2 = 490


def get_data(idx):
    """Fetch one dataset item and build its latent and descriptor tensors.

    Reads ``dataset[idx]``, keeps the first ``N_SIGNAL_LATENT`` entries of the
    last axis of its ``"z"`` array, extracts chunk 0 of its ``"midi"`` entry
    with ``get_midi`` and computes the MIDI descriptors for that chunk. The
    shapes of the descriptor arrays and of the stacked tensor are printed.

    NOTE(review): the silence flag returned by ``get_midi`` is ignored, so a
    chunk without notes passes ``None`` to ``compute_midi_descriptors`` —
    confirm that this is handled.

    Args:
        idx: Index into the notebook-level ``dataset``.

    Returns:
        Tuple ``(z, out, cont_features)``: the latent as a torch tensor with a
        new leading axis, the dict returned by ``compute_midi_descriptors``,
        and a float tensor stacking ``out[name]`` for every name in
        ``descriptors``, in that order.
    """
    item = dataset[idx]
    latent = item["z"][..., :N_SIGNAL_LATENT]

    _is_silent, chunk_midi = get_midi(item["midi"], 0)
    chunk_seconds = N_SIGNAL / SR
    out = compute_midi_descriptors(chunk_midi,
                                   target_length=N_SIGNAL_LATENT,
                                   total_time=chunk_seconds)

    z = torch.from_numpy(latent).unsqueeze(0)

    print([a.shape for a in out.values()])

    per_descriptor = []
    for name in descriptors:
        per_descriptor.append(torch.from_numpy(out[name]).float())
    cont_features = torch.stack(per_descriptor)

    print(cont_features.shape)
    return z, out, cont_features


# Latents, descriptor dicts and stacked descriptor tensors of both excerpts.
z1, descriptors1, cont_features1 = get_data(idx1)
z2, descriptors2, cont_features2 = get_data(idx2)

# Decode each latent with the codec and play it back for reference.
audio1 = emb_model.decode(z1.to(device)).squeeze(0).cpu().numpy()
display(Audio(audio1, rate=SR))
audio2 = emb_model.decode(z2.to(device)).squeeze(0).cpu().numpy()
display(Audio(audio2, rate=SR))

In [None]:
# Replaces the NB_STEPS value from the parameter cell for every call below.
NB_STEPS = 50

# Conditioning built from excerpt 1's stacked descriptors, then normalised.
# NOTE(review): the meaning of the empty first argument is not visible here —
# confirm against process_attributes.
a = model.process_attributes(torch.tensor([]), cont_features1, label=None)
c = model.normalize_attr(a)

# Map both latents to the "cs" representation. Both calls use the same
# conditioning `c`, which was derived from excerpt 1.
cs_rec = model.z_to_cs(z1, c=c, nb_steps=NB_STEPS)
cs_rec2 = model.z_to_cs(z2, c=c, nb_steps=NB_STEPS)

# Split excerpt 1's representation along axis 1 at model.control_dim.
# `s_rec` is not used by any active line below.
c_rec = cs_rec[:, :model.control_dim]
s_rec = cs_rec[:, model.control_dim:]

# Sampled representation whose leading control_dim channels are overwritten
# with c_rec. Its only consumer is the commented-out z_sample line, and `cs`
# is rebound a few lines further down.
# NOTE(review): the sampling call may still advance RNG state — confirm
# before deleting this section.
c_dist, s_dist = model.get_cs_distributions(c, warmup=False, zero_var=False)
cs = model.get_cs_samples(c_dist, s_dist)

cs[:, :model.control_dim] = c_rec
# z_sample = model.cs_to_z(cs, nb_steps=NB_STEPS)

# Swap: excerpt 1's representation with its leading control_dim channels
# replaced by those extracted from excerpt 2.
cs = cs_rec.clone()
cs[:, :model.control_dim] = cs_rec2[:, :model.control_dim]
z_sample_swap = model.cs_to_z(cs, nb_steps=NB_STEPS)

# z_rec = model.cs_to_z(cs_rec, c=c, nb_steps=NB_STEPS)

# Variation 1
# Adds 1.0 to channel index 3 of a copy of excerpt 1's representation.
# c_var is a slice of cs_var, taken after the edit.
# NOTE(review): presumably index 3 lies within the first control_dim channels
# so that c_var reflects the change — confirm.
cs_var = cs_rec.clone()
cs_var[:, 3] += 1.
c_var = cs_var[:, :model.control_dim]

# Variation 1
# (disabled alternative, carrying the same "Variation 1" label as the active
# block above: it would edit channel 0 of the normalised conditioning)
# cs_var = cs_rec
# c_var = c.clone()
# c_var[:, 0] += 0.5

z_var = model.cs_to_z(cs_var, c=c_var, nb_steps=NB_STEPS)

# Playback order: excerpt 1 as decoded earlier, the swap, the variation.
display(Audio(audio1, rate=SR))

# audio_sample = emb_model.decode(z_sample.to(device)).squeeze(0).cpu().numpy()
# display(Audio(audio_sample, rate=SR))

audio_sample = emb_model.decode(
    z_sample_swap.to(device)).squeeze(0).cpu().numpy()
display(Audio(audio_sample, rate=SR))

# audio_rec = emb_model.decode(z_rec.to(device)).squeeze(0).cpu().numpy()
# display(Audio(audio_rec, rate=SR))

audio_var = emb_model.decode(z_var.to(device)).squeeze(0).cpu().numpy()
display(Audio(audio_var, rate=SR))

In [None]:
import matplotlib.pyplot as plt

# One figure per descriptor, overlaying three curves for batch item 0:
# the normalised conditioning, the controls extracted from excerpt 1, and the
# controls of the modified variation.
# NOTE(review): assumes channel i of c / c_rec / c_var corresponds to
# descriptors[i] — confirm.
for i, descr in enumerate(descriptors):
    print(f"{descr}")

    plt.figure(figsize=(12, 2))
    plt.title(f"{descr}")
    plt.plot(c[0, i].cpu().numpy(), label="Original Normalized")
    plt.plot(c_rec[0, i].cpu().numpy(), label="Extracted")
    plt.plot(c_var[0, i].cpu().numpy(), label="Modified")
    # plt.ylim(-4, 4)
    # Without a legend the labels passed above are never drawn, leaving the
    # three curves indistinguishable.
    plt.legend()
    plt.show()