In [1]:
import torch
import numpy as np

In [2]:
import sys
# Make ./base_model importable as a top-level location.
# NOTE(review): presumably needed because modules inside base_model import
# their siblings by bare name — confirm before removing.
sys.path.append('./base_model')

In [3]:
from base_model.dataset import wrap_dataset
from base_model.model import DisentangleVAE
from interface import PolyDisVAE

In [4]:
# Notebook-wide default device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
def load_sample(path):
    """Return the ``(x, c, pr_mat)`` fields of the first item stored at ``path``.

    A dataset is built from the single file ``path`` via ``wrap_dataset``
    (``shift_low=0``, ``shift_high=0``, ``num_bar=2``, ``contain_chord=True``)
    and only its item at index 0 is read.

    Args:
        path: Location of the data file handed to ``wrap_dataset``.

    Returns:
        tuple: ``(x, c, pr_mat)`` taken from the item; the remaining three
        fields of the item are discarded.
    """
    options = dict(shift_low=0, shift_high=0, num_bar=2, contain_chord=True)
    first_item = wrap_dataset([path], [0], **options)[0]
    # An item unpacks into exactly six fields; only three are used here.
    _mel, _prs, pr_mat, x, c, _dt_x = first_item
    return x, c, pr_mat

In [6]:
def prepare_tensors(x, c, pr_mat, device):
    """Convert one sample into batched tensors living on ``device``.

    Each input is copied into a new tensor, cast, given a leading axis of
    size 1 and moved to ``device``.

    Args:
        x: Array-like cast to ``torch.long``.
        c: Array-like cast to ``torch.float``.
        pr_mat: Array-like cast to ``torch.float``.
        device: Target device for all three tensors.

    Returns:
        tuple: ``(x, c, pr_mat)`` as tensors, in that order.
    """
    def _batched(values, dtype):
        # copy -> cast -> add leading axis -> move to the target device
        return torch.tensor(values).to(dtype).unsqueeze(0).to(device)

    return (
        _batched(x, torch.long),
        _batched(c, torch.float),
        _batched(pr_mat, torch.float),
    )

In [7]:
def load_model(device):
    """Build a ``DisentangleVAE`` in eval mode from the PolyDis checkpoint.

    A ``PolyDisVAE`` interface is created, the checkpoint at
    ``./model_param/polydis-v1.pt`` is loaded into it, and four of its
    sub-modules are handed to a new ``DisentangleVAE``.

    Args:
        device: Device passed both to ``PolyDisVAE.init_model`` and to the
            ``DisentangleVAE`` constructor.

    Returns:
        The ``DisentangleVAE`` instance after ``eval()`` has been called on it.
    """
    polydis = PolyDisVAE.init_model(device=device)
    polydis.load_model('./model_param/polydis-v1.pt')

    # Positional order expected by the constructor call below:
    # chord encoder, texture encoder, piano-tree decoder, chord decoder.
    submodules = (
        polydis.chd_encoder,
        polydis.txt_encoder,
        polydis.pnotree_decoder,
        polydis.chd_decoder,
    )
    vae = DisentangleVAE('disvae', device, *submodules)
    vae.eval()
    return vae

In [8]:
# Build the model once and let the notebook display its module tree as the cell output.
load_model(device)

  dic = torch.load(model_path, map_location=self.device)


DisentangleVAE(
  (chd_encoder): ChordEncoder(
    (gru): GRU(36, 1024, batch_first=True, bidirectional=True)
    (linear_mu): Linear(in_features=2048, out_features=256, bias=True)
    (linear_var): Linear(in_features=2048, out_features=256, bias=True)
  )
  (rhy_encoder): TextureEncoder(
    (cnn): Sequential(
      (0): Conv2d(1, 10, kernel_size=(4, 12), stride=(4, 1))
      (1): ReLU()
      (2): MaxPool2d(kernel_size=(1, 4), stride=(1, 4), padding=0, dilation=1, ceil_mode=False)
    )
    (fc1): Linear(in_features=290, out_features=1000, bias=True)
    (fc2): Linear(in_features=1000, out_features=256, bias=True)
    (gru): GRU(256, 1024, batch_first=True, bidirectional=True)
    (linear_mu): Linear(in_features=2048, out_features=256, bias=True)
    (linear_var): Linear(in_features=2048, out_features=256, bias=True)
  )
  (decoder): PianoTreeDecoder(
    (note_embedding): Linear(in_features=135, out_features=128, bias=True)
    (z2dec_hid_linear): Linear(in_features=512, out_feature

## run function

```
 def run(self, x, c, pr_mat, tfr1, tfr2, tfr3, confuse=True):
        embedded_x, lengths = self.decoder.emb_x(x)
        # cc = self.get_chroma(pr_mat)
        dist_chd = self.chd_encoder(c)
        # pr_mat = self.confuse_prmat(pr_mat)
        dist_rhy = self.rhy_encoder(pr_mat)
        z_chd, z_rhy = get_zs_from_dists([dist_chd, dist_rhy], True)
        dec_z = torch.cat([z_chd, z_rhy], dim=-1)
        pitch_outs, dur_outs = self.decoder(dec_z, False, embedded_x,
                                            lengths, tfr1, tfr2)
        recon_root, recon_chroma, recon_bass = self.chd_decoder(z_chd, False,
                                                                tfr3, c)
        return pitch_outs, dur_outs, dist_chd, dist_rhy, recon_root, \
            recon_chroma, recon_bass
```

In [9]:
# To see how the outputs of run() are being computed

def debug_run(model, x, c, pr_mat):
    """Call ``model.run`` and print the shape of every output.

    ``model.run`` is invoked with its last three positional arguments set to
    ``0.``. Outputs that are ``torch.distributions.Distribution`` instances
    are reported through the shapes of their ``mean`` and ``stddev``; every
    other output is reported through its own ``shape``.

    Args:
        model: Object exposing ``run(x, c, pr_mat, tfr1, tfr2, tfr3)``.
        x, c, pr_mat: Inputs forwarded unchanged to ``model.run``.

    Returns:
        Whatever ``model.run`` returned, untouched.
    """
    outputs = model.run(x, c, pr_mat, 0., 0., 0.)

    # Labels follow the order in which run() returns its values.
    output_names = (
        'recon_pitch', 'recon_dur', 'dist_chd', 'dist_rhy',
        'recon_root', 'recon_chroma', 'recon_bass',
    )

    print('--- run() outputs ---')
    for name, value in zip(output_names, outputs):
        if not isinstance(value, torch.distributions.Distribution):
            print(f'{name}: {tuple(value.shape)}')
            continue
        print(f'{name}: mean {tuple(value.mean.shape)}, std {tuple(value.stddev.shape)}')
    return outputs

## loss function

```
def loss_function(self, x, c, recon_pitch, recon_dur, dist_chd,
                      dist_rhy, recon_root, recon_chroma, recon_bass,
                      beta, weights, weighted_dur=False):
        recon_loss, pl, dl = self.decoder.recon_loss(x, recon_pitch, recon_dur,
                                                     weights, weighted_dur)
        kl_loss, kl_chd, kl_rhy = self.kl_loss(dist_chd, dist_rhy)
        chord_loss, root, chroma, bass = self.chord_loss(c, recon_root,
                                                         recon_chroma,
                                                         recon_bass)
        loss = recon_loss + beta * kl_loss + chord_loss
        return loss, recon_loss, pl, dl, kl_loss, kl_chd, kl_rhy, chord_loss, \
               root, chroma, bass

```

In [10]:
# To see how the loss function works

def debug_loss_function(model, x, c, run_outputs):
    """Call ``model.loss_function`` and print each returned term with a label.

    The call uses ``beta=0.1`` and ``weights=(1, 0.5)``. Zero-dimensional
    tensors are printed as Python scalars, other tensors as their shape, and
    non-tensor values as-is.

    Args:
        model: Object exposing ``loss_function``.
        x, c: Targets forwarded to ``model.loss_function``.
        run_outputs: Iterable unpacked positionally after ``x`` and ``c``.

    Returns:
        Whatever ``model.loss_function`` returned, untouched.
    """
    def _render(value):
        # Scalars print their number, larger tensors only their shape.
        if not torch.is_tensor(value):
            return f'{value}'
        if value.dim() == 0:
            return f'{value.item()}'
        return f'{tuple(value.shape)}'

    term_labels = (
        'total_loss', 'recon_loss', 'pitch_loss', 'dur_loss',
        'kl_loss', 'kl_chd', 'kl_rhy',
        'chord_loss', 'root_loss', 'chroma_loss', 'bass_loss',
    )
    loss_values = model.loss_function(x, c, *run_outputs, beta=0.1, weights=(1, 0.5))

    print('--- loss_function() breakdown ---')
    for term, value in zip(term_labels, loss_values):
        print(f'{term}: {_render(value)}')
    return loss_values

In [11]:
def calc_loss(data_path='./dataSet/POP09-PIANOROLL-4-bin-quantization/001.npz',
              seed=None):
    """Load the model and one sample, then print the loss from ``model.loss``.

    Args:
        data_path: File passed to ``load_sample``. Defaults to the sample this
            notebook was originally written against.
        seed: Optional integer passed to ``torch.manual_seed`` right before
            the forward pass. Repeated evaluations of the same sample in this
            notebook print different losses, so the forward pass appears to
            be stochastic; a fixed seed makes the printed value repeatable.
            ``None`` (default) leaves the global RNG untouched.

    Returns:
        None. The loss is printed as ``Loss: <value>``.
    """
    model = load_model(device=device)
    x, c, pr_mat = load_sample(data_path)
    x, c, pr_mat = prepare_tensors(x, c, pr_mat, model.device)

    if seed is not None:
        torch.manual_seed(seed)

    # Evaluation only: no gradients are needed, so skip building the autograd graph.
    with torch.no_grad():
        loss, *_ = model.loss(x, c, pr_mat)
    print('Loss:', loss.item())

In [12]:
# Print the loss of the sample file as computed through model.loss().
calc_loss()

Loss: 0.3808215856552124


In [13]:
from compute_single_loss import main as compute_single_loss

In [14]:
# Run the entry point of the compute_single_loss module.
# NOTE(review): presumably the script counterpart of calc_loss(), run here for
# comparison — confirm against compute_single_loss.py.
compute_single_loss()

Loss: 0.4973600208759308


In [15]:
def calc_loss_debug(data_path='./dataSet/POP09-PIANOROLL-4-bin-quantization/001.npz',
                    seed=None):
    """Print the ``run()`` output shapes, the loss breakdown and the final loss.

    The sample is pushed through ``debug_run`` and ``debug_loss_function``,
    and then once more through ``model.loss`` for comparison.

    Args:
        data_path: File passed to ``load_sample``. Defaults to the sample this
            notebook was originally written against.
        seed: Optional integer passed to ``torch.manual_seed`` before each of
            the two forward passes. ``model.loss`` performs its own forward
            pass, so without a shared seed its result is not expected to
            equal the ``total_loss`` printed in the breakdown. ``None``
            (default) leaves the global RNG untouched.

    Returns:
        None. Everything is reported through ``print``.
    """
    # Use the notebook-level device, consistent with calc_loss().
    model = load_model(device=device)
    x, c, pr_mat = load_sample(data_path)
    x, c, pr_mat = prepare_tensors(x, c, pr_mat, model.device)

    # Evaluation only: no gradients are needed, so skip building the autograd graph.
    with torch.no_grad():
        if seed is not None:
            torch.manual_seed(seed)
        run_outs = debug_run(model, x, c, pr_mat)
        debug_loss_function(model, x, c, run_outs)

        # Re-seed so the second forward pass draws the same random numbers as
        # the first. NOTE(review): the totals only match if model.loss() uses
        # the same beta/weights/teacher-forcing values as the debug helpers —
        # confirm against the model source.
        if seed is not None:
            torch.manual_seed(seed)
        loss, *_ = model.loss(x, c, pr_mat)
    print('Final loss from model.loss():', loss.item())

In [16]:
# Print the run() output shapes, the per-term loss breakdown and the model.loss() value.
calc_loss_debug()

--- run() outputs ---
recon_pitch: (1, 32, 15, 130)
recon_dur: (1, 32, 15, 5, 2)
dist_chd: mean (1, 256), std (1, 256)
dist_rhy: mean (1, 256), std (1, 256)
recon_root: (1, 8, 12)
recon_chroma: (1, 8, 12, 2)
recon_bass: (1, 8, 12)
--- loss_function() breakdown ---
total_loss: 0.39179641008377075
recon_loss: 0.24111223220825195
pitch_loss: 0.1323319375514984
dur_loss: 0.2175605744123459
kl_loss: 1.2786529064178467
kl_chd: 0.3809277415275574
kl_rhy: 0.8977251648902893
chord_loss: 0.02281886339187622
root_loss: 0.0028574406169354916
chroma_loss: 0.01977548561990261
bass_loss: 0.00018593735876493156
Final loss from model.loss(): 0.3384990692138672
