In [None]:
# In this notebook we train the MeasureVAE from "Learning to Traverse Latent Spaces
# for Musical Score Inpainting" (published at ISMIR 2019).
# The core model code is taken from the authors' released implementation.

# NOTE: the wildcard import is kept ahead of the explicit imports so that the
# explicit names below always win if utils.helpers exports a same-named symbol.
from utils.helpers import *

import os
import random

import numpy as np
import torch
from torch import optim
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

from MeasureVAE.measure_vae import MeasureVAE
from data_processor import DatasetProcessor
# Input files; each one is loaded, split and processed by the data cell.
data_path = [
    "data/bachinv_1.npy",
    "data/bachinv_2.npy"
]

# parameters initialization

# --- reproducibility -------------------------------------------------------
# Seed every random number generator this notebook imports so that a
# "Restart Kernel -> Run All" gives repeatable results.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- model hyper-parameters (forwarded to the MeasureVAE constructor) -------
num_notes = 86
note_embedding_dim = 10
metadata_embedding_dim = 2
has_metadata = False
latent_space_dim = 512
num_encoder_layers = 2
encoder_hidden_size = 512
encoder_dropout_prob = 0.5
num_decoder_layers = 2
decoder_hidden_size = 512
decoder_dropout_prob = 0.5

# --- optimisation -----------------------------------------------------------
batch_size = 128        # batch size of both DataLoaders
lr = 1e-4               # learning rate given to Adam
num_epochs = 50
# The training loop reads `n_epochs`; it is aliased to `num_epochs` so the two
# names can never drift apart.
n_epochs = num_epochs
save_period = 5         # a checkpoint is written every `save_period` epochs

# --- data -------------------------------------------------------------------
seq_len = 8 * 6
# Passed to DatasetProcessor.split(); presumably train/valid/test fractions
# -- TODO confirm against DatasetProcessor.
ratio = [0.8, 0.2, 0]

# --- flags (not referenced by the cells visible in this notebook section) ---
train = True
plot = False
log = True


In [None]:
# import data
# Each file is split according to `ratio`, then its train and validation parts
# are processed with identical settings and pooled across files.
# bar_num=1 / spb=48 and the shift range -4..+4 are forwarded as-is; the shifts
# are presumably semitone transpositions -- TODO confirm in DatasetProcessor.
trainset, validset = [], []
for dpath in data_path:
    dp = DatasetProcessor(data=dpath)
    dp.split(ratio[:])  # hand over a copy of the ratio list
    trainset.extend(
        dp.process(
            dataset=dp.trainset,
            vae="MeasureVAE",
            spb=48,
            bar_num=1,
            shift_note=list(range(-4, 5)),
        )
    )
    validset.extend(
        dp.process(
            dataset=dp.validset,
            vae="MeasureVAE",
            spb=48,
            bar_num=1,
            shift_note=list(range(-4, 5)),
        )
    )

# Stack the pooled samples and convert them to int64 tensors.
trainset = torch.from_numpy(np.asarray(trainset)).to(torch.long)
validset = torch.from_numpy(np.asarray(validset)).to(torch.long)
print(trainset.shape, validset.shape)

In [None]:
# Wrap the tensors only once: re-running this cell in the same kernel would
# otherwise try to build a TensorDataset out of a TensorDataset and fail.
if not isinstance(trainset, TensorDataset):
    trainset = TensorDataset(trainset)
if not isinstance(validset, TensorDataset):
    validset = TensorDataset(validset)

# Training batches are reshuffled every epoch; validation keeps a fixed order.
train_loader = DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(dataset=validset, batch_size=batch_size, shuffle=False)

In [None]:
# loss function
def compute_kld_loss(z_dist, prior_dist, beta=0.001):
    """
    Beta-weighted KL divergence between the posterior and the prior.

    :param z_dist: torch.distributions object (posterior)
    :param prior_dist: torch.distributions object (prior)
    :param beta: scalar weight applied to the averaged divergence
    :return: scalar torch tensor, beta * mean over dim 0 of the KL summed over dim 1
    """
    # Element-wise divergence, then one value per row by summing over dim 1.
    per_sample = torch.distributions.kl_divergence(z_dist, prior_dist).sum(dim=1)
    return beta * per_sample.mean()

def mean_crossentropy_loss(weights, targets):
    """
    Mean cross-entropy between predicted scores and target class indices.

    :param weights: torch tensor of unnormalised scores,
            (batch_size, seq_len, num_notes)
    :param targets: torch tensor of class indices with batch_size * seq_len
            elements; it is flattened internally
    :return: scalar torch tensor, loss averaged over all positions
    """
    # Unpacking exactly three sizes keeps the requirement that `weights` is 3-D.
    _, _, n_classes = weights.size()
    flat_scores = weights.reshape(-1, n_classes)
    flat_labels = targets.reshape(-1)
    return torch.nn.functional.cross_entropy(flat_scores, flat_labels, reduction='mean')

def mean_accuracy(weights, targets):
    """
    Fraction of positions whose highest-scoring class equals the target.

    :param weights: torch tensor of scores,
            (batch_size, seq_len, num_notes)
    :param targets: torch tensor of class indices with batch_size * seq_len
            elements; it is flattened internally
    :return: scalar torch tensor, accuracy
    """
    n_classes = weights.size()[2]
    # Requiring exactly three dimensions mirrors the three-way size unpacking.
    if weights.dim() != 3:
        _, _, n_classes = weights.size()
    flat_scores = weights.contiguous().view(-1, n_classes)
    flat_labels = targets.contiguous().view(-1)

    predicted = flat_scores.max(1)[1]
    hits = (predicted == flat_labels).float()
    return torch.sum(hits) / flat_labels.size(0)



In [None]:
# import measureVAE
# Location used by the training cell when it writes checkpoints.
save_path = "model_backup/"

# Build the network from the hyper-parameters of the configuration cell.
model = MeasureVAE(
    num_notes=num_notes,
    has_metadata=has_metadata,
    # embeddings and latent space
    note_embedding_dim=note_embedding_dim,
    metadata_embedding_dim=metadata_embedding_dim,
    latent_space_dim=latent_space_dim,
    # encoder
    num_encoder_layers=num_encoder_layers,
    encoder_hidden_size=encoder_hidden_size,
    encoder_dropout_prob=encoder_dropout_prob,
    # decoder
    num_decoder_layers=num_decoder_layers,
    decoder_hidden_size=decoder_hidden_size,
    decoder_dropout_prob=decoder_dropout_prob,
)
optimizer = optim.Adam(model.parameters(), lr=lr)

# Report the compute device and move the model to the GPU when one is present.
if not torch.cuda.is_available():
    print('Using: CPU')
else:
    print('Using: ', torch.cuda.get_device_name(torch.cuda.current_device()))
    model.cuda()



In [None]:
# start training
# Each epoch runs one optimisation pass over train_loader and one gradient-free
# pass over valid_loader, prints the per-batch averages, and writes a
# checkpoint every `save_period` epochs.

# Query CUDA availability once; the answer does not change during the run.
use_cuda = torch.cuda.is_available()

# torch.save does not create missing directories, so make sure the checkpoint
# folder exists before the first save is attempted.
if save_path:
    os.makedirs(save_path, exist_ok=True)

# Number of batches per pass; also the denominators of the epoch averages.
tlen = len(train_loader)
vlen = len(valid_loader)

step = 0
for epoch in range(n_epochs):
    total_loss = 0.0
    total_acc = 0.0
    total_vloss = 0.0
    total_vacc = 0.0

    # ---- training pass ----
    model.train()
    for i, data in enumerate(train_loader):
        x = data[0]
        if use_cuda:
            x = x.cuda()
        # The flattened input itself is the reconstruction target.
        target = x.view(-1)
        optimizer.zero_grad()
        weights, samples, z_dist, prior_dist, z_tilde, z_prior = model(measure_score_tensor=x, train=True)
        recons_loss = mean_crossentropy_loss(weights=weights, targets=target)
        dist_loss = compute_kld_loss(z_dist, prior_dist)
        loss = recons_loss + dist_loss
        accuracy = mean_accuracy(weights=weights, targets=target)
        loss.backward()
        # Clip the global gradient norm to 1 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()
        print("epoch: %d batch: %d/%d loss: %.5f acc: %.5f"  % (epoch, i, tlen, loss.item(), accuracy.item()),end = "\r")
        total_loss += loss.item()
        total_acc += accuracy.item()

    # ---- validation pass (eval mode, no gradients) ----
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(valid_loader):
            v_x = data[0]
            if use_cuda:
                v_x = v_x.cuda()
            v_target = v_x.view(-1)
            v_weights, v_samples, v_z_dist, v_prior_dist, v_z_tilde, v_z_prior = model(measure_score_tensor=v_x, train=False)
            v_recons_loss = mean_crossentropy_loss(weights=v_weights, targets=v_target)
            v_dist_loss = compute_kld_loss(v_z_dist, v_prior_dist)
            v_loss = v_recons_loss + v_dist_loss
            v_accuracy = mean_accuracy(weights=v_weights, targets=v_target)
            total_vloss += v_loss.item()
            total_vacc += v_accuracy.item()
            print("epoch: %d batch: %d/%d val loss: %.5f val acc: %.5f"  % (epoch ,i, vlen, v_loss.item(), v_accuracy.item()),end = "\r")

    # Per-batch averages; max(..., 1) keeps an empty loader from raising
    # ZeroDivisionError (the corresponding totals are then simply 0.0).
    total_loss /= max(tlen, 1)
    total_acc /= max(tlen, 1)
    total_vloss /= max(vlen, 1)
    total_vacc /= max(vlen, 1)
    print("epoch: %d loss: %.5f acc: %.5f val_loss: %.5f val_acc: %.5f"  % (epoch, total_loss, total_acc, total_vloss, total_vacc))

    if (epoch + 1) % save_period == 0:
        # The file name records the epoch's training loss, training accuracy
        # and the 1-based epoch number.
        filename = "measure-vae-" + 'loss_' + str(total_loss) + "_" + str(total_acc) + "_" + str(epoch+1) + ".pt"
        # The state dict is saved from the CPU copy of the model.
        torch.save(model.cpu().state_dict(), os.path.join(save_path, filename))
        # Only move back to the GPU when there is one; an unconditional
        # model.cuda() raises on CPU-only machines.
        if use_cuda:
            model.cuda()