In [1]:
'''
This notebook trains our proposed SketchVAE model.
'''
import torch
import os
import numpy as np
from torch import optim
from torch.distributions import kl_divergence, Normal
from torch.nn import functional as F
from torch.optim.lr_scheduler import ExponentialLR
from SketchVAE.sketchvae import SketchVAE
from torch.utils.data import Dataset, DataLoader, TensorDataset

class MinExponentialLR(ExponentialLR):
    """Exponential learning-rate decay with a lower bound.

    Every parameter group's learning rate follows
    ``base_lr * gamma ** last_epoch`` but is clamped so that it never drops
    below ``minimum``.

    Args:
        optimizer: the wrapped optimizer.
        gamma: multiplicative decay factor applied per scheduler step.
        minimum: floor applied to every computed learning rate.
        last_epoch: index of the last completed step, forwarded to
            ``ExponentialLR``; the default ``-1`` starts a fresh schedule.
    """

    def __init__(self, optimizer, gamma, minimum, last_epoch=-1):
        # Set the floor before calling the parent constructor: the parent
        # performs an initial step, which already calls get_lr().
        self.min = minimum
        # Forward the caller's last_epoch (previously a hard-coded -1 made the
        # argument a no-op, so a resumed schedule restarted from scratch).
        super(MinExponentialLR, self).__init__(optimizer, gamma, last_epoch=last_epoch)

    def get_lr(self):
        """Return the clamped, closed-form learning rate of each param group."""
        return [
            max(base_lr * self.gamma**self.last_epoch, self.min)
            for base_lr in self.base_lrs
        ]
###############################
# initial parameters

# --- locations -------------------------------------------------------------
s_dir = ""  # folder_address: prefix joined in front of every data / output path
data_path = [
    "data/irish_train_chord_rhythm.npy",
    "data/irish_validate_chord_rhythm.npy",
    "data/irish_test_chord_rhythm.npy",
]
save_path = "model_backup"  # save_model_address (joined onto s_dir)

# --- optimisation ----------------------------------------------------------
batch_size = 64
n_epochs = 100
lr = 1e-4        # learning rate handed to Adam
decay = 0.9999   # gamma of the LR scheduler; the scheduler is only used when > 0
vae_beta = 0.1   # weight of the KL terms in the training loss

# --- model sizes (all forwarded to the SketchVAE constructor) ---------------
hidden_dims = 1024
zp_dims = zr_dims = 128
input_dims = 130
pitch_dims = 129
rhythm_dims = 3
beat_num = 4
tick_num = 6
seq_len = beat_num * tick_num  # 4 * 6 = 24 positions per sequence

# set here to config your save_period (2 i.e. save the model every 2 epochs)
save_period = 2
##############################


In [2]:
# load data
def processed_data_tensor(data):
    """Pack raw samples into a ``TensorDataset``.

    Each sample is indexed positionally: ``sample[0]`` becomes ``gd``,
    ``sample[1]`` becomes ``px``, ``sample[2]`` becomes ``rx`` and
    ``sample[3]`` becomes ``len_x``. A one-hot version of ``rx`` (``nrx``) of
    shape ``(seq_len, rhythm_dims)`` per sample is built as well, using the
    module-level ``seq_len`` and ``rhythm_dims``.

    Returns:
        TensorDataset yielding ``(px, rx, len_x, nrx, gd)`` where ``px``,
        ``len_x`` and ``gd`` are long tensors and ``rx``, ``nrx`` are float
        tensors.
    """
    print("processed data:")
    gd, px, rx, len_x = (
        np.array([sample[field] for sample in data]) for field in range(4)
    )

    # One-hot encode the rhythm tokens: position t gets a 1 in column
    # (token - 1).
    # NOTE(review): this assumes tokens are 1-based; a token of 0 would wrap
    # around to the last column. Confirm against the dataset.
    one_hot_rows = []
    for rhythm in rx:
        grid = np.zeros((seq_len, rhythm_dims))
        grid[np.arange(0, len(rhythm)), rhythm - 1] = 1
        one_hot_rows.append(grid)
    nrx = np.array(one_hot_rows)

    # Index-like arrays become int64 tensors, the rest float32.
    gd, px, len_x = (torch.from_numpy(arr).long() for arr in (gd, px, len_x))
    rx, nrx = (torch.from_numpy(arr).float() for arr in (rx, nrx))
    print("processed finish!")
    return TensorDataset(px, rx, len_x, nrx, gd)
# Read the raw sample arrays for the train and validation splits.
# NOTE: allow_pickle=True unpickles objects from disk; only load trusted files.
train_set = np.load(os.path.join(s_dir, data_path[0]), allow_pickle = True)
validate_set = np.load(os.path.join(s_dir, data_path[1]), allow_pickle = True)

# Options shared by both loaders; only the shuffling differs between them.
loader_options = dict(
    batch_size = batch_size,
    num_workers = 8,
    pin_memory = True,
    drop_last = True,
)
# From here on train_set / validate_set are DataLoaders, not raw arrays.
train_set = DataLoader(
    dataset = processed_data_tensor(train_set), shuffle = True, **loader_options
)
validate_set = DataLoader(
    dataset = processed_data_tensor(validate_set), shuffle = False, **loader_options
)


processed data:
processed finish!
processed data:
processed finish!


In [3]:
# import model
# NOTE(review): the trailing positional 4000 is interpreted by SketchVAE's
# constructor, which is not visible in this notebook -- consider naming it.
model = SketchVAE(input_dims, pitch_dims, rhythm_dims, hidden_dims, zp_dims, zr_dims, seq_len, beat_num, tick_num, 4000)
# Place the model on its final device BEFORE creating the optimizer, as the
# torch.optim documentation recommends, so the optimizer is built from the
# parameters that will actually be trained.
if torch.cuda.is_available():
    print('Using: ', torch.cuda.get_device_name(torch.cuda.current_device()))
    model.cuda()
else:
    print('Using: CPU')
optimizer = optim.Adam(model.parameters(), lr = lr)
# The scheduler only exists when decay is enabled; the training loop guards
# its use with the same `decay > 0` check.
if decay > 0:
    scheduler = MinExponentialLR(optimizer, gamma = decay, minimum = 1e-5)


Using:  NVIDIA GeForce RTX 2080 Ti


In [4]:
# Materialise every validation batch from the DataLoader up front, so the
# training loop can pick batches by index.
validate_data = list(validate_set)
print(len(validate_data))

486


In [5]:
# loss function
def std_normal(shape, device=None):
    """Return a standard normal distribution N(0, 1) with the given batch shape.

    Args:
        shape: shape of the ``loc`` / ``scale`` tensors (tuple or torch.Size).
        device: device on which to create the tensors. When ``None`` the
            tensors go to CUDA if it is available and to the CPU otherwise,
            which is what this function did before the parameter existed.

    Returns:
        torch.distributions.Normal with zero mean and unit scale.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Allocate directly on the target device instead of building on the CPU
    # and overwriting the distribution's attributes afterwards.
    return Normal(
        torch.zeros(shape, device=device),
        torch.ones(shape, device=device),
    )

def loss_function(recon, target, p_dis, r_dis, beta):
    """Compute reconstruction accuracy and the beta-weighted VAE loss.

    Args:
        recon: logits whose last dimension is the number of classes; all
            leading dimensions are flattened before comparison with target.
        target: 1-D tensor of class indices, one per flattened row of recon.
        p_dis: first posterior distribution (a Normal).
        r_dis: second posterior distribution (a Normal).
        beta: weight applied to the sum of both KL terms.

    Returns:
        (acc, loss): ``acc`` is the fraction of rows whose arg-max equals the
        target; ``loss`` is ``CE + beta * (KL(p_dis) + KL(r_dis))`` where each
        KL is measured against a standard normal and averaged.
    """
    # Flatten once and reuse for both the cross-entropy and the accuracy.
    logits = recon.view(-1, recon.size(-1))
    CE = F.cross_entropy(logits, target, reduction = "mean")

    # Build each N(0, 1) prior from the posterior's own mean tensor so it
    # always shares the posterior's shape, dtype and device, rather than
    # guessing the device from global CUDA availability.
    prior_p = Normal(torch.zeros_like(p_dis.mean), torch.ones_like(p_dis.mean))
    prior_r = Normal(torch.zeros_like(r_dis.mean), torch.ones_like(r_dis.mean))
    KLD1 = kl_divergence(p_dis, prior_p).mean()
    KLD2 = kl_divergence(r_dis, prior_r).mean()

    max_indices = logits.max(-1)[-1]
    correct = max_indices == target
    acc = torch.sum(correct.float()) / target.size(0)
    return acc, CE + beta * (KLD1 + KLD2)


In [None]:
# start training
# Each epoch iterates over the training loader; after every optimisation step
# one cached validation batch is evaluated so train/val metrics are logged
# side by side. Per-epoch means are appended to `logs` and saved to disk.
logs = []
# Use the current CUDA device when one exists, otherwise fall back to the CPU
# (the model-setup cell supports CPU, so this cell must not require CUDA).
if torch.cuda.is_available():
    device = torch.device(torch.cuda.current_device())
else:
    device = torch.device("cpu")
# Create the checkpoint folder up front so torch.save cannot fail on a missing
# directory after several epochs of training.
checkpoint_dir = os.path.join(s_dir, save_path)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok = True)
iteration = 0
step = 0
for epoch in range(n_epochs):
    print("epoch: %d\n__________________________________________" % (epoch), flush = True)
    mean_loss = 0.0
    mean_acc = 0.0
    v_mean_loss = 0.0
    v_mean_acc = 0.0
    total = 0
    for i, d in enumerate(train_set):
        # validate display
        # Re-enter train mode every batch: the validation pass below switches
        # the model to eval mode.
        model.train()
        # Cycle through the cached validation batches.
        j = i % len(validate_data)
        px, rx, len_x, nrx, gd = d
        v_px, v_rx, v_len_x, v_nrx, v_gd = validate_data[j]
        
        px = px.to(device = device,non_blocking = True)
        len_x = len_x.to(device = device,non_blocking = True)
        nrx = nrx.to(device = device,non_blocking = True)
        gd = gd.to(device = device,non_blocking = True)

        v_px = v_px.to(device = device,non_blocking = True)
        v_len_x = v_len_x.to(device = device,non_blocking = True)
        v_nrx = v_nrx.to(device = device,non_blocking = True)
        v_gd = v_gd.to(device = device,non_blocking = True)
            
        optimizer.zero_grad()
        # The model also returns its own iteration counter, which is logged.
        recon, p_dis, r_dis, iteration = model(px, nrx, len_x, gd)
        
        acc, loss = loss_function(recon, gd.view(-1), p_dis, r_dis, vae_beta)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()
        mean_loss += loss.item()
        mean_acc += acc.item()
        
        # Validation pass on one cached batch, without gradients.
        model.eval()
        with torch.no_grad():
            v_recon, v_p_dis, v_r_dis, _ = model(v_px, v_nrx, v_len_x, v_gd)
            v_acc, v_loss = loss_function(v_recon, v_gd.view(-1), v_p_dis, v_r_dis, vae_beta)
            v_mean_loss += v_loss.item()
            v_mean_acc += v_acc.item()
        step += 1
        total += 1
        # The learning rate decays once per batch, not once per epoch.
        if decay > 0:
            scheduler.step()
        print("batch %d loss: %.5f acc: %.5f | val loss %.5f acc: %.5f iteration: %d"  
              % (i,loss.item(), acc.item(), v_loss.item(),v_acc.item(),iteration),flush = True)
    mean_loss /= total
    mean_acc /= total
    v_mean_loss /= total
    v_mean_acc /= total
    print("epoch %d loss: %.5f acc: %.5f | val loss %.5f acc: %.5f iteration: %d"  
              % (epoch, mean_loss, mean_acc, v_mean_loss, v_mean_acc, iteration),flush = True)
    logs.append([mean_loss,mean_acc,v_mean_loss,v_mean_acc,iteration])
    if (epoch + 1) % save_period == 0:
        filename = "sketchvae-" + 'loss_' + str(v_mean_loss) + "_acc_" + str(v_mean_acc) + "_epoch_" +  str(epoch+1) + "_it_" + str(iteration) + ".pt"
        # Save CPU weights, then move the model back to the training device
        # (which may itself be the CPU).
        torch.save(model.cpu().state_dict(),os.path.join(checkpoint_dir, filename))
        model.to(device)
    # Rewrite the full log after every epoch so progress survives interruption.
    np.save(os.path.join(s_dir,"sketchvae-log.npy"), logs)
        


epoch: 0
__________________________________________
batch 0 loss: 4.87248 acc: 0.00195 | val loss 4.80699 acc: 0.33854 iteration: 1
batch 1 loss: 4.79179 acc: 0.48307 | val loss 4.68266 acc: 0.66406 iteration: 2
batch 2 loss: 4.69104 acc: 0.68359 | val loss 4.56285 acc: 0.69792 iteration: 3
batch 3 loss: 4.59149 acc: 0.69987 | val loss 4.40933 acc: 0.69141 iteration: 4
batch 4 loss: 4.45620 acc: 0.71745 | val loss 4.21867 acc: 0.72070 iteration: 5
batch 5 loss: 4.25593 acc: 0.69857 | val loss 3.98189 acc: 0.69206 iteration: 6
batch 6 loss: 4.08158 acc: 0.69206 | val loss 3.60086 acc: 0.68945 iteration: 7
batch 7 loss: 3.72593 acc: 0.69727 | val loss 3.10402 acc: 0.69661 iteration: 8
batch 8 loss: 3.24531 acc: 0.69661 | val loss 2.48939 acc: 0.70247 iteration: 9
batch 9 loss: 2.55781 acc: 0.71094 | val loss 1.98744 acc: 0.71745 iteration: 10
batch 10 loss: 2.04878 acc: 0.70508 | val loss 2.12163 acc: 0.70117 iteration: 11
batch 11 loss: 1.98912 acc: 0.71224 | val loss 2.25247 acc: 0.704