In [18]:
# In this file 
# we train our proposed PolyVAE

import torch
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
from torch import optim
from torch.distributions import kl_divergence, Normal
from torch.nn import functional as F
from torch.optim.lr_scheduler import ExponentialLR
from model import PolyVAE
from torch.utils.data import Dataset, DataLoader, TensorDataset

class MinExponentialLR(ExponentialLR):
    """Exponential learning-rate decay clamped to a lower bound.

    On every ``step()`` each parameter group's learning rate becomes
    ``max(base_lr * gamma ** last_epoch, minimum)``.

    Args:
        optimizer: Wrapped optimizer.
        gamma: Multiplicative decay factor applied per scheduler step.
        minimum: Floor below which the learning rate is never lowered.
        last_epoch: Index of the last completed scheduler step. ``-1``
            (default) starts a fresh schedule; any other value resumes one
            and is forwarded to the parent scheduler.
    """

    def __init__(self, optimizer, gamma, minimum, last_epoch=-1):
        # Assigned before the parent constructor runs, because the parent
        # may already call get_lr() while initialising.
        self.min = minimum
        # Forward the caller's last_epoch instead of a hard-coded -1 so a
        # resumed schedule does not silently restart from step 0.
        super().__init__(optimizer, gamma, last_epoch=last_epoch)

    def get_lr(self):
        """Return the clamped learning rate for every parameter group."""
        return [
            max(base_lr * self.gamma ** self.last_epoch, self.min)
            for base_lr in self.base_lrs
        ]
###############################
# Experiment configuration: everything a reader might want to tune.

# --- paths and naming ---
s_dir = ""
save_path = ""
experiment_name = "default"
# NOTE(review): the third entry repeats the training file — confirm whether a
# separate test split was intended here.
data_path = [
    s_dir + relative_path
    for relative_path in (
        "data/poly_train.npy",
        "data/poly_validate.npy",
        "data/poly_train.npy",
    )
]

# --- optimisation ---
n_epochs = 4
batch_size = 64
lr = 1e-4
decay = 0.9999      # per-step LR decay factor; a value <= 0 disables the scheduler
vae_beta = 0.1      # weight of the KL term in the loss
save_period = 1     # write a checkpoint every `save_period` epochs

# --- model / sequence dimensions ---
input_dims = 130
hidden_dims = 512
z_dims = 1024
beat_num = 10
tick_num = 16
seq_len = beat_num * tick_num
##############################


In [19]:
# input data
def _flatten_layers(records):
    """Flatten every record's "layers" entry into one row of a LongTensor.

    Args:
        records: Iterable of mappings, each holding a "layers" sequence whose
            items are themselves iterables of numbers.

    Returns:
        A ``torch.long`` tensor with one row per record; each row is the
        concatenation of that record's layers in order.
    """
    rows = [
        [value for layer in record["layers"] for value in layer]
        for record in records
    ]
    return torch.from_numpy(np.array(rows)).long()


# SECURITY: allow_pickle=True unpickles arbitrary objects from the file;
# only load .npy files that come from a trusted source.
train_records = np.load(data_path[0], allow_pickle=True)
validate_records = np.load(data_path[1], allow_pickle=True)

train_x = _flatten_layers(train_records)
validate_x = _flatten_layers(validate_records)

print(train_x.size())
print(validate_x.size())

# Options shared by both loaders.
# NOTE(review): drop_last=True also discards the final partial validation
# batch — confirm this is intended.
_loader_options = dict(
    batch_size=batch_size,
    num_workers=8,
    pin_memory=True,
    drop_last=True,
)

# `train_set` / `validate_set` are the DataLoaders used by the later cells.
train_set = DataLoader(dataset=TensorDataset(train_x), shuffle=True, **_loader_options)
validate_set = DataLoader(dataset=TensorDataset(validate_x), shuffle=False, **_loader_options)


torch.Size([58041, 160])
torch.Size([7578, 160])


In [20]:
# Build the model, its optimizer and (when decay is enabled) the LR scheduler.
model = PolyVAE(
    input_dims,
    hidden_dims,
    z_dims,
    seq_len,
    beat_num,
    tick_num,
    4000,  # NOTE(review): meaning of this positional argument is defined by PolyVAE — confirm
)
optimizer = optim.Adam(model.parameters(), lr=lr)
if decay > 0:
    scheduler = MinExponentialLR(optimizer, gamma=decay, minimum=1e-5)

# Report the compute device and move the model to the GPU when one exists.
if not torch.cuda.is_available():
    print('Using: CPU')
else:
    print('Using: ', torch.cuda.get_device_name(torch.cuda.current_device()))
    model.cuda()


Using:  TITAN RTX


In [21]:
# Cache the validation batches from the DataLoader in a list so the training
# loop can index them directly; each batch is a one-element sequence, of which
# the first item is the tensor.
validate_data = [batch[0] for batch in validate_set]
print(len(validate_data))

118


In [22]:
# loss function
def std_normal(shape, device=None):
    """Build a standard normal distribution N(0, 1) with the given shape.

    Args:
        shape: Shape of the ``loc`` and ``scale`` tensors.
        device: Device on which to create the parameters. When ``None``
            (default) CUDA is used if it is available, otherwise the CPU.

    Returns:
        A ``torch.distributions.Normal`` with zero mean and unit scale.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # Create the parameters on the target device up front rather than
    # overwriting attributes of an already constructed distribution.
    loc = torch.zeros(shape, device=device)
    scale = torch.ones(shape, device=device)
    return Normal(loc, scale)

def loss_function(recon, target, r_dis, beta):
    """Compute reconstruction accuracy and the beta-weighted VAE loss.

    Args:
        recon: Logits whose last dimension is the class dimension; all
            leading dimensions are flattened into rows.
        target: 1-D tensor of class indices, one per flattened row of
            ``recon``.
        r_dis: Distribution exposing ``.mean`` for which ``kl_divergence``
            against a ``Normal`` is defined (presumably the Normal posterior
            returned by the model).
        beta: Weight applied to the KL term.

    Returns:
        Tuple ``(acc, loss)``: the fraction of rows whose arg-max equals the
        target, and ``cross_entropy + beta * KL``.
    """
    # reshape (not view) so non-contiguous logits are accepted as well.
    logits = recon.reshape(-1, recon.size(-1))
    ce = F.cross_entropy(logits, target, reduction="mean")

    # Build the N(0, 1) prior next to the posterior's parameters so both
    # distributions share device and dtype, whether or not CUDA is available.
    posterior_mean = r_dis.mean
    prior = Normal(torch.zeros_like(posterior_mean), torch.ones_like(posterior_mean))
    kld = kl_divergence(r_dis, prior).mean()

    predicted = logits.max(-1)[1]
    acc = torch.sum((predicted == target).float()) / target.size(0)
    return acc, ce + beta * kld


In [23]:
def save_model(model, optimizer):
    """Write the model and optimizer state dicts to ``best_model.pt``.

    The file is placed in the module-level ``experiment_dir`` and holds a
    dict with the keys ``'model'`` and ``'optimizer'``.
    """
    checkpoint = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(checkpoint, os.path.join(experiment_dir, 'best_model.pt'))

def load_model(model, optimizer):
    """Restore model and optimizer state from ``best_model.pt``.

    The checkpoint is read from the module-level ``experiment_dir`` and must
    contain the keys ``'model'`` and ``'optimizer'``. Both objects are
    updated in place.

    Args:
        model: Module whose parameters are overwritten.
        optimizer: Optimizer whose state is overwritten.
    """
    checkpoint_path = os.path.join(experiment_dir, 'best_model.pt')
    # Deserialize onto the CPU so a checkpoint written on a GPU machine also
    # loads on a CPU-only host; load_state_dict then copies the values into
    # the existing parameters/state wherever they live.
    # NOTE: torch.load unpickles data — only load checkpoints you trust.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict['model'])
    optimizer.load_state_dict(state_dict['optimizer'])

def record_stats(train_loss, train_acc, val_loss, val_acc):
    """Append one epoch's metrics to the module-level histories and replot.

    Args:
        train_loss: Value appended to ``training_losses``.
        train_acc: Value appended to ``training_accs``.
        val_loss: Value appended to ``val_losses``.
        val_acc: Value appended to ``val_accs``.
    """
    updates = (
        (training_losses, train_loss),
        (training_accs, train_acc),
        (val_losses, val_loss),
        (val_accs, val_acc),
    )
    for history, value in updates:
        history.append(value)

    # Refresh the saved figures so they always include the newest epoch.
    plot_stats()

def plot_stats():
    """Redraw the loss and accuracy curves and save them as PNG files.

    Reads the module-level history lists and writes ``loss_plot.png`` and
    ``acc_plot.png`` into ``experiment_dir``; the x axis counts epochs from 1.
    """
    epochs = np.arange(1, len(training_losses) + 1, 1)

    # One entry per output file: (file name, ((series, legend label), ...)).
    figure_specs = (
        ("loss_plot.png", (
            (training_losses, "Training Loss"),
            (val_losses, "Validation Loss"),
        )),
        ("acc_plot.png", (
            (training_accs, "Training Accuracy"),
            (val_accs, "Validation Accuracy"),
        )),
    )

    for file_name, curves in figure_specs:
        plt.figure()
        for series, legend_label in curves:
            plt.plot(epochs, series, label=legend_label)
        plt.xlabel("Epochs")
        plt.legend(loc='best')
        plt.savefig(os.path.join(experiment_dir, file_name))
        # Close each figure so repeated calls do not accumulate open figures.
        plt.close()

In [24]:
# rename the experiment dir name every time when you tune the hyperparameters
experiment_dir = os.path.join("experiment_data/", experiment_name)
os.makedirs(experiment_dir, exist_ok=True)

# Per-epoch metric histories; record_stats() appends to them and replots.
training_losses = []
val_losses = []
training_accs = []
val_accs = []

# Lowest epoch-mean validation loss seen so far; gates writing best_model.pt.
best_loss = float("inf")

# start training
logs = []
# Pick the device the same way the model cell does, so the loop also runs on
# hosts without CUDA instead of failing in torch.cuda.current_device().
if torch.cuda.is_available():
    device = torch.device("cuda", torch.cuda.current_device())
else:
    device = torch.device("cpu")
iteration = 0
step = 0
for epoch in range(n_epochs):
    print("epoch: %d\n__________________________________________" % (epoch), flush = True)
    mean_loss = 0.0
    mean_acc = 0.0
    v_mean_loss = 0.0
    v_mean_acc = 0.0
    total = 0
    for i, d in enumerate(tqdm(train_set)):
        # Every training step is paired with one cached validation batch,
        # cycling through validate_data when the training loader is longer.
        x = gd = d[0]
        model.train()
        j = i % len(validate_data)
        v_x = v_gd = validate_data[j]

        x = x.to(device = device, non_blocking = True)
        gd = gd.to(device = device, non_blocking = True)
        v_x = v_x.to(device = device, non_blocking = True)
        v_gd = v_gd.to(device = device, non_blocking = True)

        # --- training step ---
        optimizer.zero_grad()
        # The third value returned by the model is kept as `iteration` and
        # is only used for logging and checkpoint file names below.
        recon, r_dis, iteration = model(x, gd)

        acc, loss = loss_function(recon, gd.view(-1), r_dis, vae_beta)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()
        mean_loss += loss.item()
        mean_acc += acc.item()

        # --- validation step on the paired batch (no gradients) ---
        model.eval()
        with torch.no_grad():
            v_recon, v_r_dis, _ = model(v_x, v_gd)
            v_acc, v_loss = loss_function(v_recon, v_gd.view(-1), v_r_dis, vae_beta)
            v_mean_loss += v_loss.item()
            v_mean_acc += v_acc.item()
        step += 1
        total += 1
        # The learning rate is decayed once per batch, not once per epoch.
        if decay > 0:
            scheduler.step()
        # print("batch %d loss: %.5f acc: %.5f | val loss %.5f acc: %.5f iteration: %d"
        #       % (i,loss.item(), acc.item(), v_loss.item(),v_acc.item(),iteration),flush = True)
    mean_loss /= total
    mean_acc /= total
    v_mean_loss /= total
    v_mean_acc /= total

    record_stats(mean_loss, mean_acc, v_mean_loss, v_mean_acc)

    print("epoch %d loss: %.5f acc: %.5f | val loss %.5f acc: %.5f iteration: %d"
          % (epoch, mean_loss, mean_acc, v_mean_loss, v_mean_acc, iteration), flush = True)
    logs.append([mean_loss, mean_acc, v_mean_loss, v_mean_acc, iteration])
    if (epoch + 1) % save_period == 0:
        filename = "sketchvae-" + 'loss_' + str(mean_loss) + "_" + str(epoch+1) + "_" + str(iteration) + ".pt"
        # Weights are written from the CPU; afterwards the model goes back to
        # the training device (not an unconditional .cuda()).
        torch.save(model.cpu().state_dict(), save_path + filename)
        model.to(device)
    np.save("sketchvae-log.npy", logs)

    if v_mean_loss < best_loss:
        # Remember the new best value; otherwise every epoch would overwrite
        # best_model.pt regardless of whether validation loss improved.
        best_loss = v_mean_loss
        save_model(model, optimizer)

c: 0.35557 iteration: 1578
batch 672 loss: 1.58351 acc: 0.55801 | val loss 2.14852 acc: 0.41904 iteration: 1579
batch 673 loss: 1.42782 acc: 0.58779 | val loss 2.04590 acc: 0.43506 iteration: 1580
batch 674 loss: 1.48352 acc: 0.57686 | val loss 1.97702 acc: 0.38203 iteration: 1581
batch 675 loss: 1.56484 acc: 0.56299 | val loss 1.45002 acc: 0.56465 iteration: 1582
batch 676 loss: 1.38614 acc: 0.60635 | val loss 1.86669 acc: 0.47656 iteration: 1583
batch 677 loss: 1.56394 acc: 0.54980 | val loss 1.60041 acc: 0.53145 iteration: 1584
batch 678 loss: 1.38407 acc: 0.62207 | val loss 1.51593 acc: 0.53691 iteration: 1585
batch 679 loss: 1.45253 acc: 0.57129 | val loss 1.56095 acc: 0.53760 iteration: 1586
batch 680 loss: 1.47279 acc: 0.56689 | val loss 1.70205 acc: 0.49492 iteration: 1587
batch 681 loss: 1.47808 acc: 0.59658 | val loss 1.89621 acc: 0.47598 iteration: 1588
batch 682 loss: 1.44367 acc: 0.58730 | val loss 2.12022 acc: 0.41113 iteration: 1589
batch 683 loss: 1.53246 acc: 0.55088 |

In [25]:
# os.makedirs(experiment_dir, exist_ok=True)
# record_stats(mean_loss, mean_acc, v_mean_loss, v_mean_acc)