In [None]:
import torch
from dataset import prepare_dataset
from model import DisentangleVAE
from ptvae import RnnEncoder, PtvaeEncoder, PtvaeDecoder, RnnDecoder
import matplotlib.pyplot as plt
import numpy as np


In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Create the model components (adjust the constructor arguments if necessary).
chd_encoder = RnnEncoder(36, 1024, 256)
rhy_encoder = PtvaeEncoder(
    device=device,
    z_size=256,
    max_pitch=31,
    min_pitch=0,
)
chd_decoder = RnnDecoder(z_dim=256)
pt_decoder = PtvaeDecoder(
    note_embedding=None,
    dec_dur_hid_size=64,
    z_size=512,
)

# Assemble the full model from the four components built above.
model = DisentangleVAE(
    'disvae', device,
    chd_encoder, rhy_encoder, pt_decoder, chd_decoder,
)

# Restore the saved parameters, remapping stored tensors onto `device`.
# NOTE(review): torch.load deserializes the file with pickle (how strictly
# depends on the installed PyTorch version) -- only load trusted checkpoints.
model.load_state_dict(
    torch.load('model_param/polydis-v1.pt', map_location=device)
)

# Move the model to `device` and switch it to evaluation mode.
model.to(device)
model.eval()


In [None]:
# Arguments forwarded to the project's `prepare_dataset` helper.
dataset_options = dict(
    seed=123,
    bs_train=16,
    bs_val=16,
    portion=8,
    shift_low=-6,
    shift_high=5,
    num_bar=2,
    contain_chord=True,
    random_train=False,
    random_val=False,
)

# `prepare_dataset` returns two values; only the first one is kept here and
# the second is discarded.
train_loader, _ = prepare_dataset(**dataset_options)


In [None]:
# Per-batch metrics collected over one pass through `train_loader`.
# The lists are re-created here so re-running the cell starts from scratch.
losses = []       # total loss of each batch
kl_chd_list = []  # `kl_chd` value of each batch
kl_rhy_list = []  # `kl_rhy` value of each batch

for batch in train_loader:
    # Each batch unpacks into six fields; the first two are not needed here.
    # `dt_x` is unpacked but never read below, so it is deliberately NOT
    # copied to `device` (that transfer was pure overhead on every batch).
    _, _, pr_mat, x, c, dt_x = batch
    x, c, pr_mat = x.to(device), c.to(device), pr_mat.to(device)

    # Forward pass and loss computation only: gradients are not needed.
    with torch.no_grad():
        (pitch_outs, dur_outs, dist_chd, dist_rhy,
         recon_root, recon_chroma, recon_bass) = model.run(
            x, c, pr_mat,
            tfr1=0.0, tfr2=0.0, tfr3=0.0, confuse=False
        )

        loss_vals = model.loss_function(
            x, c,
            pitch_outs, dur_outs,
            dist_chd, dist_rhy,
            recon_root, recon_chroma, recon_bass,
            beta=0.1, weights=[1.0, 0.5]
        )

    # NOTE(review): this relies on the positional return order of
    # `model.loss_function` (index 0 = total loss, indices 5 and 6 = the two
    # KL terms) -- confirm against the model's definition.
    total_loss, _, _, _, _, kl_chd, kl_rhy, *_ = loss_vals

    # `.item()` converts each value to a plain Python number for plotting.
    losses.append(total_loss.item())
    kl_chd_list.append(kl_chd.item())
    kl_rhy_list.append(kl_rhy.item())


In [None]:
# Draw the three per-batch series side by side in a single 1x3 figure.
fig, axes = plt.subplots(1, 3, figsize=(12, 4))

# (series, title) for each panel, left to right.
panels = [
    (losses, 'Total Loss'),
    (kl_chd_list, 'KL Divergence - z_chd'),
    (kl_rhy_list, 'KL Divergence - z_txt'),
]

for ax, (series, title) in zip(axes, panels):
    ax.plot(series)
    ax.set_title(title)

# Adjust spacing so the panel titles and tick labels do not overlap.
fig.tight_layout()
plt.show()
