In [1]:
import os
import sys
import time
import numpy as np
import torch
import random
from torch import optim
from torch.optim import lr_scheduler
sys.path.append(os.getcwd())
from torch.utils.data import DataLoader
from data.dataloader_dance import DanceDataset, seq_collate
from model.GroupNet_dance import GroupNet
import math

In [2]:
# cuDNN configuration for reproducible training.
# `deterministic = True` restricts cuDNN to deterministic algorithms, but the
# benchmark autotuner may select a different algorithm from run to run, which
# would undermine the fixed seeds set later in this notebook. Benchmarking is
# therefore switched off so that identical seeds lead to identical runs.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

class Args:
    """Hyper-parameters and paths used by this training notebook.

    Every setting has a default. Individual values can be overridden with
    keyword arguments, e.g. ``Args(lr=1e-4, batch_size=16)``; ``Args()`` with
    no arguments yields exactly the defaults listed in ``__init__``.

    Raises:
        TypeError: if a keyword does not name an existing setting, so that a
            misspelled option is reported instead of being silently ignored.
    """

    def __init__(self, **overrides):
        # Training parameters
        self.seed = 1
        self.dataset = 'dance'
        self.batch_size = 32
        self.past_length = 8 # number of frames to condition on
        self.future_length = 10 # number of frames to predict into the future
        self.traj_scale = 1 # scale factor applied to trajectory coordinates
        self.learn_prior = False # whether to learn prior distribution vs using fixed prior
        self.lr = 3e-5 # learning rate
        self.weight_decay = 0.001
        self.sample_k = 20 # number of samples to generate during testing for diverse predictions
        self.num_epochs = 100
        self.decay_step = 10 # number of epochs before applying learning rate decay
        self.decay_gamma = 0.9 # learning rate decay factor
        self.print_every_it = 18 # print training stats every N iterations
        self.test_every_it = 27 # test model every N iterations

        # Model parameters
        self.ztype = 'gaussian' # type of latent distribution: 'gaussian' or 'vmf'
        self.zdim = 32 # dimension of latent variable
        self.hidden_dim = 64 # dimension of hidden layers
        self.hyper_scales = [15,53] # scales for hyperprior ([5,11] for nba)
        self.num_decompose = 2 # number of decomposed distributions
        self.min_clip = 2.0 # presumably a lower clamp used inside the model's loss -- TODO confirm in GroupNet

        # Save/load parameters
        self.model_save_dir = 'saved_models/dance'
        self.model_save_epoch = 5 # save model every N epochs
        self.epoch_continue = 0 # epoch to continue training from, 0 if training from scratch

        # Apply caller overrides last. Updating existing keys keeps the
        # attribute order, so the printed configuration layout is unchanged.
        unknown = sorted(set(overrides) - set(self.__dict__))
        if unknown:
            raise TypeError('unknown Args setting(s): ' + ', '.join(unknown))
        self.__dict__.update(overrides)

    def __str__(self):
        """Return the string form of the instance's attribute dictionary."""
        return str(self.__dict__)

args = Args()

In [3]:
""" setup """
# Seed NumPy, Python's `random`, and torch (including all CUDA devices) from
# the single configured seed.
np.random.seed(args.seed)
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
torch.set_default_dtype(torch.float32)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('device:',device)
print(args)

# exist_ok=True makes directory creation idempotent and avoids the
# check-then-create race of a separate os.path.isdir() test.
os.makedirs(args.model_save_dir, exist_ok=True)

device: cuda
{'seed': 1, 'dataset': 'dance', 'batch_size': 32, 'past_length': 8, 'future_length': 10, 'traj_scale': 1, 'learn_prior': False, 'lr': 3e-05, 'weight_decay': 0.001, 'sample_k': 20, 'num_epochs': 100, 'decay_step': 10, 'decay_gamma': 0.9, 'print_every_it': 18, 'test_every_it': 27, 'ztype': 'gaussian', 'zdim': 32, 'hidden_dim': 64, 'hyper_scales': [15, 53], 'num_decompose': 2, 'min_clip': 2.0, 'model_save_dir': 'saved_models/dance', 'model_save_epoch': 5, 'epoch_continue': 0}


In [4]:
def train(train_loader,epoch):
	"""Run one training epoch over ``train_loader``.

	Uses the notebook-level ``model``, ``optimizer``, ``scheduler`` and
	``args``. For each batch the model is called and returns the total loss
	followed by four loss components; the total loss is back-propagated and
	the optimizer is stepped. A progress line is printed every
	``args.print_every_it`` iterations. Once all batches are consumed, the
	learning-rate scheduler and the model's annealer are each stepped once.

	Args:
		train_loader: iterable of batches that also supports ``len()``.
		epoch: epoch index, used only in the progress line.
	"""
	model.train()
	num_batches = len(train_loader)
	log_template = 'Epochs: {:02d}/{:02d}| It: {:04d}/{:04d} | Train loss: {:6.3f} (pred: {:6.3f}| recover: {:6.3f}| kl: {:6.3f}| diverse: {:6.3f})'

	for step, batch in enumerate(train_loader):
		losses = model(batch)
		total_loss, loss_pred, loss_recover, loss_kl, loss_diverse = losses

		# Standard update: clear stale gradients, back-propagate, apply step.
		optimizer.zero_grad()
		total_loss.backward()
		optimizer.step()

		if step % args.print_every_it == 0:
			print(log_template.format(
				epoch, args.num_epochs, step, num_batches,
				total_loss.item(), loss_pred, loss_recover, loss_kl, loss_diverse))

		# NOTE: a periodic evaluation pass (every args.test_every_it iterations,
		# under torch.no_grad() with model.eval()) was drafted here but is
		# disabled. The draft passed test_loader itself to the model, whereas
		# training passes one batch at a time, and it never switched back to
		# train mode -- both need attention before re-enabling.

	# Per-epoch bookkeeping.
	scheduler.step()
	model.step_annealer()

In [5]:
# Sanity check on the raw training file: load the array and display its shape.
# The recorded output is (1702, 18, 53, 6); 18 equals past_length +
# future_length (8 + 10). Presumably the axes are (sequences, frames, points,
# features) -- TODO confirm against DanceDataset, which reports a last
# dimension of 2 when it loads the data.
foo = np.load('datasets/dance/train.npy') 
foo.shape

(1702, 18, 53, 6)

In [6]:
# for ct, data_test in enumerate(test_loader):
#     print(model(data_test))

In [7]:
""" model & optimizer """
# Model, AdamW optimizer, and a step learning-rate schedule that multiplies
# the rate by args.decay_gamma every args.decay_step scheduler steps.
model = GroupNet(args,device)
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
scheduler = lr_scheduler.StepLR(optimizer, step_size=args.decay_step, gamma=args.decay_gamma)

# Dataloaders
train_set = DanceDataset(
    obs_len=args.past_length,
    pred_len=args.future_length,
    training=True)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    train_set,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=seq_collate,
    pin_memory=True)

test_set = DanceDataset(
    obs_len=args.past_length,
    pred_len=args.future_length,
    training=False)

# The whole test split is served as one batch. Shuffling is disabled so the
# samples keep the dataset's order on every pass, which keeps predictions
# aligned with their source sequences and makes evaluations comparable.
test_loader = DataLoader(
    test_set,
    batch_size=len(test_set),
    shuffle=False,
    num_workers=4,
    collate_fn=seq_collate,
    pin_memory=True)

Loaded Train data with shape: (1702, 18, 53, 2)
1702
Loaded Test data with shape: (426, 18, 53, 2)
426


In [8]:
""" Loading if needed """
# Resume only when a starting epoch is configured; checkpoints are stored as
# '<epoch>.p' inside args.model_save_dir.
if args.epoch_continue > 0:
    checkpoint_path = os.path.join(args.model_save_dir,str(args.epoch_continue)+'.p')
    # f-string so the real path is interpolated into the message.
    print(f'load model from: {checkpoint_path}')
    # SECURITY: torch.load is pickle-based; only load checkpoints you trust.
    # NOTE(review): checkpoints written by this notebook also store the Args
    # object under 'model_cfg'; PyTorch versions that default to
    # weights_only=True may refuse to unpickle it -- confirm on your version.
    model_load = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(model_load['model_dict'])
    # Optimizer and scheduler state are optional so that checkpoints holding
    # only model weights can still be loaded.
    if 'optimizer' in model_load:
        optimizer.load_state_dict(model_load['optimizer'])
    if 'scheduler' in model_load:
        scheduler.load_state_dict(model_load['scheduler'])
    # NOTE(review): train() calls model.step_annealer() once per epoch; whether
    # the annealer's progress is restored by load_state_dict is not visible
    # here -- verify when resuming.

In [9]:
""" start training """
# Move the model to the selected device, then train from the resume epoch
# (0 when starting from scratch) up to args.num_epochs.
model.set_device(device)
for epoch in range(args.epoch_continue, args.num_epochs):
    train(train_loader,epoch)

    # Checkpoint every args.model_save_epoch completed epochs.
    epochs_done = epoch + 1
    if epochs_done % args.model_save_epoch != 0:
        continue

    # The file is named after the number of completed epochs, matching the
    # '<epoch>.p' pattern used when resuming.
    model_saved = {
        'model_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'epoch': epochs_done,
        'model_cfg': args,
    }
    saved_path = os.path.join(args.model_save_dir, str(epochs_done) + '.p')
    torch.save(model_saved, saved_path)

  soft_max_1d = F.softmax(trans_input)


Epochs: 00/100| It: 0000/0054 | Train loss: 16.032 (pred:  0.986| recover: 12.902| kl:  2.000| diverse:  0.145)
Epochs: 00/100| It: 0018/0054 | Train loss: 10.707 (pred:  0.647| recover:  7.964| kl:  2.000| diverse:  0.096)
Epochs: 00/100| It: 0036/0054 | Train loss:  6.604 (pred:  0.395| recover:  4.154| kl:  2.000| diverse:  0.056)
Epochs: 01/100| It: 0000/0054 | Train loss:  4.454 (pred:  0.607| recover:  1.754| kl:  2.000| diverse:  0.092)
Epochs: 01/100| It: 0018/0054 | Train loss:  3.899 (pred:  0.469| recover:  1.358| kl:  2.000| diverse:  0.071)
Epochs: 01/100| It: 0036/0054 | Train loss:  3.641 (pred:  0.317| recover:  1.279| kl:  2.000| diverse:  0.045)
Epochs: 02/100| It: 0000/0054 | Train loss:  3.669 (pred:  0.408| recover:  1.202| kl:  2.000| diverse:  0.059)
Epochs: 02/100| It: 0018/0054 | Train loss:  3.635 (pred:  0.452| recover:  1.116| kl:  2.000| diverse:  0.068)
Epochs: 02/100| It: 0036/0054 | Train loss:  3.175 (pred:  0.334| recover:  0.794| kl:  2.000| diverse: 

KeyboardInterrupt: 