In [1]:
import os
import sys
import argparse
import time
import numpy as np
import torch
import random
from torch import optim
from torch.optim import lr_scheduler
sys.path.append(os.getcwd())
from torch.utils.data import DataLoader
from data.dataloader_dance import DanceDataset, seq_collate
from model.GroupNet_dance import GroupNet
import math

In [2]:
# cuDNN configuration for a reproducible run.
# `benchmark = True` lets cuDNN auto-tune and pick convolution algorithms per
# input shape, and that choice can differ between runs; this works against
# `deterministic = True` and the fixed seeds set later in the notebook.
# Keep benchmarking off so algorithm selection is deterministic as well.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

class Args:
    """Plain attribute bag holding every hyper-parameter used by this notebook.

    It replaces a parsed command-line namespace: each setting becomes an
    instance attribute, and ``str(instance)`` renders the attribute dict.
    """

    def __init__(self):
        # Training parameters. `lr`, `decay_step` and `decay_gamma` feed the
        # Adam optimizer / StepLR scheduler; `past_length` / `future_length`
        # are passed to the dataset as obs_len / pred_len; `iternum_print`
        # controls how often `train` logs.
        training_cfg = dict(
            seed=1,
            dataset='dance',
            batch_size=32,
            past_length=8,
            future_length=10,
            traj_scale=1,
            learn_prior=False,
            lr=1e-4,
            sample_k=20,
            num_epochs=100,
            decay_step=10,
            decay_gamma=0.5,
            iternum_print=100,
        )

        # Model parameters -- presumably consumed inside GroupNet (not visible
        # in this notebook); verify against the model code.
        model_cfg = dict(
            ztype='gaussian',
            zdim=32,
            hidden_dim=64,
            hyper_scales=[5, 11],
            num_decompose=2,
            min_clip=2.0,
        )

        # Save/load parameters. `gpu` is only referenced by commented-out
        # code in the setup cell.
        checkpoint_cfg = dict(
            model_save_dir='saved_models/dance',
            model_save_epoch=5,
            epoch_continue=0,
            gpu=0,
        )

        # Assign group by group so the attribute order (and therefore the
        # printed representation) stays training -> model -> save/load.
        for group in (training_cfg, model_cfg, checkpoint_cfg):
            for key, value in group.items():
                setattr(self, key, value)

    def __str__(self):
        return repr(vars(self))


args = Args()

In [3]:
""" setup """
# Seed NumPy, Python's `random`, PyTorch (CPU) and every CUDA device with the
# same value so a fresh-kernel run starts from identical RNG state.
for seed_rng in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_rng(args.seed)
torch.set_default_dtype(torch.float32)

# Use the default CUDA device when one is available, otherwise the CPU.
# `args.gpu` is not consulted here; pinning a specific GPU index is not enabled.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device:', device)
print(args)

device: cuda
{'seed': 1, 'dataset': 'dance', 'batch_size': 32, 'past_length': 8, 'future_length': 10, 'traj_scale': 1, 'learn_prior': False, 'lr': 0.0001, 'sample_k': 20, 'num_epochs': 100, 'decay_step': 10, 'decay_gamma': 0.5, 'iternum_print': 100, 'ztype': 'gaussian', 'zdim': 32, 'hidden_dim': 64, 'hyper_scales': [5, 11], 'num_decompose': 2, 'min_clip': 2.0, 'model_save_dir': 'saved_models/dance', 'model_save_epoch': 5, 'epoch_continue': 0, 'gpu': 0}


In [4]:
def train(train_loader,epoch):
    """Run one training epoch over ``train_loader``.

    Relies on the notebook-level ``model``, ``optimizer``, ``scheduler`` and
    ``args``. For each batch the model is called directly on the batch and
    returns five values (total, prediction, recover, KL and diversity
    losses); only the total loss is back-propagated. A progress line is
    printed every ``args.iternum_print`` iterations, starting with the first.
    After the last batch the LR scheduler and the model's annealer are each
    stepped once.

    Args:
        train_loader: iterable of batches that also supports ``len()``.
        epoch: epoch index, used only in the progress line.
    """
    model.train()
    batches_per_epoch = len(train_loader)
    log_template = (
        'Epochs: {:02d}/{:02d}| It: {:04d}/{:04d} | Total loss: {:03f}| Loss_pred: {:03f}| Loss_recover: {:03f}| Loss_kl: {:03f}| Loss_diverse: {:03f}'
    )

    for step, batch in enumerate(train_loader):
        total_loss, loss_pred, loss_recover, loss_kl, loss_diverse = model(batch)

        """ optimize """
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        should_log = step % args.iternum_print == 0
        if should_log:
            # Only `total_loss` is unwrapped with .item(); the component
            # losses are formatted as returned by the model.
            print(log_template.format(
                epoch, args.num_epochs, step, batches_per_epoch,
                total_loss.item(), loss_pred, loss_recover, loss_kl, loss_diverse,
            ))

    # Per-epoch bookkeeping: learning-rate decay, then the model's annealer.
    scheduler.step()
    model.step_annealer()

In [5]:
# Sanity check: load the raw training array straight from disk and show its shape.
# The recorded output is (1702, 18, 53, 6), while DanceDataset later reports
# (1702, 18, 53, 2) -- presumably the dataset keeps only 2 of the 6 trailing
# channels; confirm in data/dataloader_dance.py.
foo = np.load('datasets/dance/train.npy') 
foo.shape

(1702, 18, 53, 6)

In [6]:
""" model & optimizer """
model = GroupNet(args, device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)
# StepLR multiplies the learning rate by `decay_gamma` once every
# `decay_step` calls to scheduler.step() (called once per epoch in `train`).
scheduler = lr_scheduler.StepLR(
    optimizer,
    step_size=args.decay_step,
    gamma=args.decay_gamma,
)

""" dataloader """
train_set = DanceDataset(obs_len=args.past_length, pred_len=args.future_length, training=True)

# Loader settings gathered in one place; batches are reshuffled every epoch
# and assembled by the project's `seq_collate`.
loader_options = dict(
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=seq_collate,
    pin_memory=True,
)
train_loader = DataLoader(train_set, **loader_options)

Loaded Train data with shape: (1702, 18, 53, 2)
1702


In [7]:
""" Loading if needed """
# Resume from `<model_save_dir>/<epoch_continue>.p` when a start epoch is set.
if args.epoch_continue > 0:
    checkpoint_path = os.path.join(args.model_save_dir, str(args.epoch_continue) + '.p')
    # The f-prefix is required; without it the literal text "{checkpoint_path}" is printed.
    print(f'load model from: {checkpoint_path}')
    # NOTE(review): torch.load unpickles the file, and the checkpoints written
    # by this notebook also store the `Args` object under 'model_cfg'.
    # Only load checkpoints from a trusted source.
    model_load = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(model_load['model_dict'])
    # Optimizer / scheduler state is optional so that checkpoints containing
    # only model weights can still be loaded.
    if 'optimizer' in model_load:
        optimizer.load_state_dict(model_load['optimizer'])
    if 'scheduler' in model_load:
        scheduler.load_state_dict(model_load['scheduler'])

In [8]:
# Sanity check: shape of the trajectory array held by the dataset object
# (recorded output: (1702, 18, 53, 2)).
train_set.trajs.shape

(1702, 18, 53, 2)

In [None]:
""" start training """
model.set_device(device)
# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists before the first save is attempted.
os.makedirs(args.model_save_dir, exist_ok=True)
for epoch in range(args.epoch_continue, args.num_epochs):
    train(train_loader,epoch)
    """ save model """
    # Checkpoint every `model_save_epoch` epochs; files are named by the
    # 1-based epoch count (e.g. "5.p"), matching what the resume cell looks for.
    if (epoch + 1) % args.model_save_epoch == 0:
        model_saved = {
            'model_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'epoch': epoch + 1,
            'model_cfg': args,
        }
        saved_path = os.path.join(args.model_save_dir, str(epoch + 1) + '.p')
        torch.save(model_saved, saved_path)

  soft_max_1d = F.softmax(trans_input)


Epochs: 00/100| It: 0000/0054 | Total loss: 16.031925| Loss_pred: 0.986369| Loss_recover: 12.900703| Loss_kl: 2.000000| Loss_diverse: 0.144853
Epochs: 01/100| It: 0000/0054 | Total loss: 3.213174| Loss_pred: 0.445052| Loss_recover: 0.701560| Loss_kl: 2.000000| Loss_diverse: 0.066562
Epochs: 02/100| It: 0000/0054 | Total loss: 2.514754| Loss_pred: 0.276757| Loss_recover: 0.200157| Loss_kl: 2.000000| Loss_diverse: 0.037840
