In [1]:
# In this file 
# we train our proposed PolyVAE

import torch
import os
import numpy as np
from torch import optim
from torch.distributions import kl_divergence, Normal
from torch.nn import functional as F
from torch.optim.lr_scheduler import ExponentialLR
from model import PolyVAE
from torch.utils.data import Dataset, DataLoader, TensorDataset

class MinExponentialLR(ExponentialLR):
    """Exponential learning-rate decay clamped from below.

    Each parameter group's rate is ``base_lr * gamma ** last_epoch``, but it is
    never allowed to fall under ``minimum``.

    Args:
        optimizer: wrapped optimizer.
        gamma: multiplicative decay factor applied per scheduler step.
        minimum: lower bound for every learning rate produced by the schedule.
        last_epoch: index of the last completed step; ``-1`` starts a fresh
            schedule. It is forwarded to the parent scheduler so a resumed run
            continues from the right point (previously it was ignored and the
            schedule always restarted from scratch).
    """

    def __init__(self, optimizer, gamma, minimum, last_epoch=-1):
        # The floor must exist before the parent constructor runs, because
        # get_lr() may already be invoked from inside it.
        self.min = minimum
        super().__init__(optimizer, gamma, last_epoch=last_epoch)

    def get_lr(self):
        """Return the clamped learning rate for every parameter group."""
        return [
            max(base_lr * self.gamma ** self.last_epoch, self.min)
            for base_lr in self.base_lrs
        ]
###############################
# Run configuration: everything a reader might want to tune lives here.

# --- paths ---
s_dir = ""       # prefix prepended to every dataset path
save_path = ""   # prefix prepended to checkpoint file names
# NOTE(review): the third entry repeats the training file; presumably it was
# meant to point at a separate test split -- confirm.
data_path = [
    s_dir + name
    for name in (
        "data/poly_train.npy",
        "data/poly_validate.npy",
        "data/poly_train.npy",
    )
]

# --- optimisation ---
batch_size = 64
n_epochs = 1
lr = 1e-4
decay = 0.9999    # gamma of the exponential LR schedule; a value <= 0 disables it
vae_beta = 0.1    # weight of the KL term in the loss
save_period = 1   # a checkpoint is written every `save_period` epochs

# --- model / data dimensions (all forwarded to the PolyVAE constructor) ---
input_dims = 130
hidden_dims = 512
z_dims = 1024
beat_num = 10
tick_num = 16
seq_len = beat_num * tick_num   # 10 * 16 = 160
##############################


In [2]:
# input data
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects, which
# can execute code embedded in the file -- only load dataset files you trust.
train_records = np.load(data_path[0], allow_pickle=True)
validate_records = np.load(data_path[1], allow_pickle=True)


def _flatten_layers(records):
    """Concatenate every record's "layers" entries into one flat row per record."""
    return np.array([
        [token for layer in record["layers"] for token in layer]
        for record in records
    ])


def _make_loader(features, shuffle):
    """Wrap a tensor in a TensorDataset and return a DataLoader over it."""
    return DataLoader(
        dataset=TensorDataset(features),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=8,
        pin_memory=True,
        # NOTE(review): the trailing partial batch is dropped for validation as
        # well, so a few validation samples are never evaluated.
        drop_last=True,
    )


train_x = torch.from_numpy(_flatten_layers(train_records)).long()
validate_x = torch.from_numpy(_flatten_layers(validate_records)).long()

print(train_x.size())
print(validate_x.size())

# Only the training loader is shuffled; validation order stays fixed.
train_set = _make_loader(train_x, shuffle=True)
validate_set = _make_loader(validate_x, shuffle=False)


torch.Size([58041, 160])
torch.Size([7578, 160])


In [3]:
# build the model, then the optimizer and LR schedule
# NOTE(review): the trailing 4000 is passed positionally to PolyVAE; its meaning
# is defined in model.py and is not visible here -- confirm before changing it.
model = PolyVAE(input_dims, hidden_dims, z_dims, seq_len, beat_num, tick_num, 4000)

# Move the model to its final device BEFORE creating the optimizer, as the
# torch.optim documentation recommends.
if torch.cuda.is_available():
    print('Using: ', torch.cuda.get_device_name(torch.cuda.current_device()))
    model.cuda()
else:
    print('Using: CPU')

optimizer = optim.Adam(model.parameters(), lr = lr)
if decay > 0:
    scheduler = MinExponentialLR(optimizer, gamma = decay, minimum = 1e-5)
else:
    # Keep the name defined so a stale scheduler from an earlier run of this
    # cell (bound to an old optimizer) cannot linger in the kernel.
    scheduler = None


Using:  GeForce GTX 1080 Ti


In [4]:
# Materialize the validation batches from the DataLoader once, so the training
# loop can index into them directly.
validate_data = [batch[0] for batch in validate_set]
print(len(validate_data))

118


In [5]:
# loss function
def std_normal(shape):
    """Return a standard normal distribution ``Normal(0, 1)`` of the given shape.

    The parameters are allocated on the GPU when CUDA is available and on the
    CPU otherwise. They are created on the target device directly instead of
    being swapped into an already-constructed distribution.

    Args:
        shape: shape of the ``loc`` / ``scale`` tensors.

    Returns:
        A ``torch.distributions.Normal`` with zero mean and unit scale.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return Normal(
        torch.zeros(shape, device=device),
        torch.ones(shape, device=device),
    )

def loss_function(recon, target, r_dis, beta):
    """Compute the reconstruction accuracy and the beta-weighted VAE loss.

    Args:
        recon: logits whose last dimension indexes the classes; all leading
            dimensions are flattened into rows.
        target: 1-D tensor of class indices, one per flattened row of ``recon``.
        r_dis: distribution produced by the model; it is compared against a
            standard normal of the same shape as ``r_dis.mean``.
        beta: weight applied to the KL term.

    Returns:
        ``(acc, loss)`` where ``acc`` is the fraction of rows whose arg-max
        class equals the target and ``loss`` is
        ``cross_entropy + beta * mean KL(r_dis || N(0, 1))``.
    """
    logits = recon.view(-1, recon.size(-1))
    CE = F.cross_entropy(logits, target, reduction="mean")

    # Derive the prior from r_dis.mean so it always shares its device and
    # dtype; this avoids a device mismatch when CUDA is available but the
    # model is running on the CPU.
    prior = Normal(torch.zeros_like(r_dis.mean), torch.ones_like(r_dis.mean))
    KLD1 = kl_divergence(r_dis, prior).mean()

    acc = (logits.argmax(dim=-1) == target).float().mean()
    return acc, CE + beta * KLD1


In [None]:
# start training
# Every training step is followed by an evaluation of ONE cached validation
# batch (cycling through validate_data), so the "val" numbers are a running
# display rather than a full pass over the validation set.
logs = []
# Fall back to the CPU when no GPU is present, matching the device choice made
# when the model was built (torch.cuda.current_device() raises without CUDA).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
iteration = 0
step = 0
for epoch in range(n_epochs):
    print("epoch: %d\n__________________________________________" % (epoch), flush = True)
    mean_loss = 0.0
    mean_acc = 0.0
    v_mean_loss = 0.0
    v_mean_acc = 0.0
    total = 0
    for i, d in enumerate(train_set):
        # The model is fed the batch both as input and as reconstruction target.
        x = gd = d[0]
        model.train()
        # Pick the validation batch shown alongside this training step.
        j = i % len(validate_data)
        v_x = v_gd = validate_data[j]

        x = x.to(device = device,non_blocking = True)
        gd = gd.to(device = device,non_blocking = True)
        v_x = v_x.to(device = device,non_blocking = True)
        v_gd = v_gd.to(device = device,non_blocking = True)

        # ---- training step ----
        optimizer.zero_grad()
        recon, r_dis, iteration = model(x, gd)

        acc, loss = loss_function(recon, gd.view(-1), r_dis, vae_beta)
        loss.backward()
        # Clip the global gradient norm to 1 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()
        mean_loss += loss.item()
        mean_acc += acc.item()

        # ---- validation display (no gradients) ----
        model.eval()
        with torch.no_grad():
            v_recon, v_r_dis, _ = model(v_x, v_gd)
            v_acc, v_loss = loss_function(v_recon, v_gd.view(-1), v_r_dis, vae_beta)
            v_mean_loss += v_loss.item()
            v_mean_acc += v_acc.item()
        step += 1
        total += 1
        # The learning rate decays once per batch, not once per epoch.
        if decay > 0:
            scheduler.step()
        print("batch %d loss: %.5f acc: %.5f | val loss %.5f acc: %.5f iteration: %d"  
              % (i,loss.item(), acc.item(), v_loss.item(),v_acc.item(),iteration),flush = True)
    # Epoch averages over the number of training steps taken.
    mean_loss /= total
    mean_acc /= total
    v_mean_loss /= total
    v_mean_acc /= total
    print("epoch %d loss: %.5f acc: %.5f | val loss %.5f acc: %.5f iteration: %d"  
              % (epoch, mean_loss, mean_acc, v_mean_loss, v_mean_acc, iteration),flush = True)
    logs.append([mean_loss,mean_acc,v_mean_loss,v_mean_acc,iteration])
    if (epoch + 1) % save_period == 0:
        filename = "sketchvae-" + 'loss_' + str(mean_loss) + "_" + str(epoch+1) + "_" + str(iteration) + ".pt"
        # Save CPU tensors so the checkpoint loads on machines without a GPU,
        # then return the model to the training device (which may be the CPU).
        torch.save(model.cpu().state_dict(),save_path + filename)
        model.to(device)
    # The log file is rewritten after every epoch.
    np.save("sketchvae-log.npy", logs)
        


epoch: 0
__________________________________________
batch 0 loss: 5.21777 acc: 0.00088 | val loss 5.19814 acc: 0.00000 iteration: 1
batch 1 loss: 5.17792 acc: 0.05947 | val loss 5.16774 acc: 0.26045 iteration: 2
batch 2 loss: 5.14806 acc: 0.34961 | val loss 5.10428 acc: 0.35947 iteration: 3
batch 3 loss: 5.11382 acc: 0.35205 | val loss 5.01315 acc: 0.51250 iteration: 4
batch 4 loss: 5.06965 acc: 0.39336 | val loss 5.00198 acc: 0.41602 iteration: 5
batch 5 loss: 5.03020 acc: 0.38682 | val loss 4.91021 acc: 0.46875 iteration: 6
batch 6 loss: 5.01098 acc: 0.35088 | val loss 4.76562 acc: 0.58623 iteration: 7
batch 7 loss: 4.95353 acc: 0.37754 | val loss 4.85851 acc: 0.35967 iteration: 8
batch 8 loss: 4.92823 acc: 0.34629 | val loss 4.89572 acc: 0.28291 iteration: 9
batch 9 loss: 4.86625 acc: 0.36562 | val loss 4.76055 acc: 0.36572 iteration: 10
batch 10 loss: 4.81285 acc: 0.35908 | val loss 4.52244 acc: 0.45596 iteration: 11
batch 11 loss: 4.72808 acc: 0.36924 | val loss 4.39923 acc: 0.450

batch 100 loss: 3.32245 acc: 0.34756 | val loss 3.25845 acc: 0.40410 iteration: 101
batch 101 loss: 3.43810 acc: 0.33896 | val loss 3.36108 acc: 0.39502 iteration: 102
batch 102 loss: 3.28727 acc: 0.38887 | val loss 3.15915 acc: 0.43799 iteration: 103
batch 103 loss: 3.46859 acc: 0.34424 | val loss 3.13755 acc: 0.45967 iteration: 104
batch 104 loss: 3.31745 acc: 0.39189 | val loss 3.40719 acc: 0.31455 iteration: 105
batch 105 loss: 3.37716 acc: 0.35742 | val loss 3.16309 acc: 0.44209 iteration: 106
batch 106 loss: 3.28438 acc: 0.37412 | val loss 3.10720 acc: 0.45664 iteration: 107
batch 107 loss: 3.45458 acc: 0.35107 | val loss 3.40893 acc: 0.34492 iteration: 108
batch 108 loss: 3.19450 acc: 0.40430 | val loss 3.88756 acc: 0.21475 iteration: 109
batch 109 loss: 3.37214 acc: 0.36172 | val loss 3.69774 acc: 0.28311 iteration: 110
batch 110 loss: 3.06224 acc: 0.43877 | val loss 2.59375 acc: 0.50117 iteration: 111
batch 111 loss: 3.34376 acc: 0.36025 | val loss 2.87626 acc: 0.43701 iterati

batch 198 loss: 3.11952 acc: 0.37412 | val loss 3.59384 acc: 0.22383 iteration: 199
batch 199 loss: 2.85728 acc: 0.39131 | val loss 3.84734 acc: 0.12783 iteration: 200
batch 200 loss: 3.15916 acc: 0.35693 | val loss 3.35338 acc: 0.35869 iteration: 201
batch 201 loss: 3.01121 acc: 0.39590 | val loss 3.44624 acc: 0.32568 iteration: 202
batch 202 loss: 2.88698 acc: 0.41279 | val loss 3.60732 acc: 0.24697 iteration: 203
batch 203 loss: 2.93188 acc: 0.39619 | val loss 3.07759 acc: 0.40293 iteration: 204
batch 204 loss: 2.97237 acc: 0.40117 | val loss 3.23468 acc: 0.36689 iteration: 205
batch 205 loss: 3.02435 acc: 0.37852 | val loss 3.12520 acc: 0.39678 iteration: 206
batch 206 loss: 3.04444 acc: 0.39502 | val loss 3.01497 acc: 0.41094 iteration: 207
batch 207 loss: 2.94775 acc: 0.40615 | val loss 3.12972 acc: 0.38877 iteration: 208
batch 208 loss: 3.00978 acc: 0.39639 | val loss 3.24216 acc: 0.30938 iteration: 209
batch 209 loss: 2.99803 acc: 0.37627 | val loss 3.10201 acc: 0.36357 iterati

batch 296 loss: 2.32697 acc: 0.42529 | val loss 1.91049 acc: 0.61523 iteration: 297
batch 297 loss: 2.32642 acc: 0.41123 | val loss 1.88909 acc: 0.61299 iteration: 298
batch 298 loss: 2.30993 acc: 0.44609 | val loss 2.70473 acc: 0.39805 iteration: 299
batch 299 loss: 2.51247 acc: 0.38877 | val loss 2.51947 acc: 0.36123 iteration: 300
batch 300 loss: 2.42787 acc: 0.41650 | val loss 2.30934 acc: 0.34121 iteration: 301
batch 301 loss: 2.32179 acc: 0.43018 | val loss 3.30012 acc: 0.32949 iteration: 302
batch 302 loss: 2.31771 acc: 0.42920 | val loss 2.21789 acc: 0.42402 iteration: 303
batch 303 loss: 2.56081 acc: 0.36729 | val loss 2.31405 acc: 0.41895 iteration: 304
batch 304 loss: 2.47831 acc: 0.40830 | val loss 2.26271 acc: 0.46172 iteration: 305
batch 305 loss: 2.40957 acc: 0.39912 | val loss 2.45157 acc: 0.52246 iteration: 306
batch 306 loss: 2.34896 acc: 0.41924 | val loss 1.94645 acc: 0.52432 iteration: 307
batch 307 loss: 2.20366 acc: 0.45020 | val loss 1.98826 acc: 0.52656 iterati

batch 394 loss: 2.26450 acc: 0.43525 | val loss 2.15940 acc: 0.43877 iteration: 395
batch 395 loss: 2.31173 acc: 0.42373 | val loss 2.34335 acc: 0.41738 iteration: 396
batch 396 loss: 2.24161 acc: 0.44297 | val loss 2.29195 acc: 0.43018 iteration: 397
batch 397 loss: 2.12979 acc: 0.45918 | val loss 1.98600 acc: 0.48398 iteration: 398
batch 398 loss: 2.44991 acc: 0.38350 | val loss 2.00800 acc: 0.46582 iteration: 399
batch 399 loss: 2.30069 acc: 0.41885 | val loss 2.01025 acc: 0.41670 iteration: 400
batch 400 loss: 2.21902 acc: 0.44395 | val loss 2.13174 acc: 0.45117 iteration: 401
batch 401 loss: 2.25638 acc: 0.43096 | val loss 2.26520 acc: 0.41055 iteration: 402
batch 402 loss: 2.18955 acc: 0.47197 | val loss 1.26182 acc: 0.73145 iteration: 403
batch 403 loss: 2.24338 acc: 0.43799 | val loss 1.55148 acc: 0.60791 iteration: 404
batch 404 loss: 2.10256 acc: 0.43838 | val loss 1.88435 acc: 0.57402 iteration: 405
batch 405 loss: 2.16881 acc: 0.44082 | val loss 2.70730 acc: 0.28945 iterati