In [1]:
import os
os.chdir("/workspace/")
import json

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

import models.vqvae as vqvae
import utils.losses as losses 
# import options.option_vq as option_vq
import utils.utils_model as utils_model
from dataset import dataset_VQ, dataset_TM_eval
import utils.eval_trans as eval_trans
from options.get_eval_option import get_opt
from models.evaluator_wrapper import EvaluatorModelWrapper
import warnings
warnings.filterwarnings('ignore')
from utils.word_vectorizer import WordVectorizer

In [2]:
import easydict

args = easydict.EasyDict({
    "batch_size": 256,
    "beta": 1.0,
    "code_dim": 512,
    "commit": 0.02,
    "dataname": "t2m",
    "depth": 3,
    "dilation_growth_rate": 3,
    "down_t": 2,
    "eval_iter": 1000,
    "gamma": 0.05,
    "loss_vel": 0.5,
    "lr": 0.0002,
    "lr_scheduler": [
        200000,
        200000
    ],
    "mu": 0.99,
    "nb_code": 512,
    "nb_joints": 0,
    "nb_vis": 20,
    "out_dir": "output_vqfinal/VQ-VAE/",
    "exp_name": "test/",
    "output_emb_width": 512,
    "print_iter": 200,
    "quantizer": "ema_reset",
    "recons_loss": "l1_smooth",
    "results_dir": "visual_results/",
    "resume_gpt": None,
    "resume_pth": None,
    "seed": 123,
    "stride_t": 2,
    "total_iter": 300000,
    "vis_gt": False,
    "visual_name": "baseline",
    "vq_act": "relu",
    "vq_norm": None,
    "warm_up_iter": 1000,
    "weight_decay": 0.0,
    "width": 512,
    "window_size": 64
})

In [3]:
torch.manual_seed(args.seed)

args.out_dir = os.path.join(args.out_dir, f'{args.exp_name}')
os.makedirs(args.out_dir, exist_ok = True)

##### ---- Logger ---- #####
logger = utils_model.get_logger(args.out_dir)
writer = SummaryWriter(args.out_dir)
logger.info(json.dumps(vars(args), indent=4, sort_keys=True))

2023-08-03 11:28:43,342 INFO {
    "batch_size": 256,
    "beta": 1.0,
    "code_dim": 512,
    "commit": 0.02,
    "dataname": "t2m",
    "depth": 3,
    "dilation_growth_rate": 3,
    "down_t": 2,
    "eval_iter": 1000,
    "exp_name": "test/",
    "gamma": 0.05,
    "loss_vel": 0.5,
    "lr": 0.0002,
    "lr_scheduler": [
        200000,
        200000
    ],
    "mu": 0.99,
    "nb_code": 512,
    "nb_joints": 0,
    "nb_vis": 20,
    "out_dir": "output_vqfinal/VQ-VAE/test/",
    "output_emb_width": 512,
    "print_iter": 200,
    "quantizer": "ema_reset",
    "recons_loss": "l1_smooth",
    "results_dir": "visual_results/",
    "resume_gpt": null,
    "resume_pth": null,
    "seed": 123,
    "stride_t": 2,
    "total_iter": 300000,
    "vis_gt": false,
    "visual_name": "baseline",
    "vq_act": "relu",
    "vq_norm": null,
    "warm_up_iter": 1000,
    "weight_decay": 0.0,
    "width": 512,
    "window_size": 64
}


In [4]:
w_vectorizer = WordVectorizer('./glove', 'our_vab')

if args.dataname == 'kit' : 
    dataset_opt_path = 'checkpoints/kit/Comp_v6_KLD005/opt.txt'  
    args.nb_joints = 21
    
else :
    dataset_opt_path = 'checkpoints/t2m/Comp_v6_KLD005/opt.txt'
    args.nb_joints = 22

logger.info(f'Training on {args.dataname}, motions are with {args.nb_joints} joints')

wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda'))
eval_wrapper = EvaluatorModelWrapper(wrapper_opt)

2023-08-03 11:28:45,545 INFO Training on t2m, motions are with 22 joints
Reading checkpoints/t2m/Comp_v6_KLD005/opt.txt
Loading Evaluation Model Wrapper (Epoch 28) Completed!!


In [5]:
train_loader = dataset_VQ.DATALoader(args.dataname,
                                        args.batch_size,
                                        window_size=args.window_size,
                                        unit_length=2**args.down_t)

train_loader_iter = dataset_VQ.cycle(train_loader)

val_loader = dataset_TM_eval.DATALoader(args.dataname, False,
                                        32,
                                        w_vectorizer,
                                        unit_length=2**args.down_t)

100%|██████████| 23384/23384 [00:12<00:00, 1802.27it/s]


Total number of motions 20941


100%|██████████| 1460/1460 [00:00<00:00, 1501.36it/s]

Pointer Pointing at 0





In [6]:
##### ---- Network ---- #####
net = vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers
                       args.nb_code,
                       args.code_dim,
                       args.output_emb_width,
                       args.down_t,
                       args.stride_t,
                       args.width,
                       args.depth,
                       args.dilation_growth_rate,
                       args.vq_act,
                       args.vq_norm)

In [7]:
from lion_pytorch import Lion

if args.resume_pth : 
    logger.info('loading checkpoint from {}'.format(args.resume_pth))
    ckpt = torch.load(args.resume_pth, map_location='cpu')
    net.load_state_dict(ckpt['net'], strict=True)
net.train()
net.cuda()

##### ---- Optimizer & Scheduler ---- #####
optimizer = optim.AdamW(net.parameters(), lr=args.lr, betas=(0.9, 0.99), weight_decay=args.weight_decay)
# optimizer = Lion(net.parameters(), lr=1e-4, weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_scheduler, gamma=args.gamma)
  
Loss = losses.ReConsLoss(args.recons_loss, args.nb_joints)

In [8]:
def update_lr_warm_up(optimizer, nb_iter, warm_up_iter, lr):

    current_lr = lr * (nb_iter + 1) / (warm_up_iter + 1)
    for param_group in optimizer.param_groups:
        param_group["lr"] = current_lr

    return optimizer, current_lr


In [9]:
##### ------ warm-up ------- #####
avg_recons, avg_perplexity, avg_commit = 0., 0., 0.

for nb_iter in range(1, args.warm_up_iter):
    
    optimizer, current_lr = update_lr_warm_up(optimizer, nb_iter, args.warm_up_iter, args.lr)
    
    gt_motion = next(train_loader_iter)
    gt_motion = gt_motion.cuda().float() # (bs, 64, dim)

    pred_motion, loss_commit, perplexity = net(gt_motion)
    loss_motion = Loss(pred_motion, gt_motion)
    loss_vel = Loss.forward_vel(pred_motion, gt_motion)
    
    loss = loss_motion + args.commit * loss_commit + args.loss_vel * loss_vel
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    avg_recons += loss_motion.item()
    avg_perplexity += perplexity.item()
    avg_commit += loss_commit.item()
    
    if nb_iter % args.print_iter ==  0 :
        avg_recons /= args.print_iter
        avg_perplexity /= args.print_iter
        avg_commit /= args.print_iter
        
        logger.info(f"Warmup. Iter {nb_iter} :  lr {current_lr:.5f} \t Commit. {avg_commit:.5f} \t PPL. {avg_perplexity:.2f} \t Recons.  {avg_recons:.5f}")
        
        avg_recons, avg_perplexity, avg_commit = 0., 0., 0.


2023-08-03 11:29:59,420 INFO Warmup. Iter 200 :  lr 0.00004 	 Commit. 0.28895 	 PPL. 84.34 	 Recons.  0.71186
2023-08-03 11:30:20,473 INFO Warmup. Iter 400 :  lr 0.00008 	 Commit. 1.05421 	 PPL. 130.10 	 Recons.  0.51088
2023-08-03 11:30:42,436 INFO Warmup. Iter 600 :  lr 0.00012 	 Commit. 2.26659 	 PPL. 222.63 	 Recons.  0.40075
2023-08-03 11:31:03,715 INFO Warmup. Iter 800 :  lr 0.00016 	 Commit. 3.17441 	 PPL. 279.31 	 Recons.  0.33115


In [10]:
##### ---- Training ---- #####
avg_recons, avg_perplexity, avg_commit = 0., 0., 0.
best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger = eval_trans.evaluation_vqvae(args.out_dir, val_loader, net, logger, writer, 0, best_fid=1000, best_iter=0, best_div=100, best_top1=0, best_top2=0, best_top3=0, best_matching=100, eval_wrapper=eval_wrapper)

for nb_iter in range(1, args.total_iter + 1):
    
    gt_motion = next(train_loader_iter)
    gt_motion = gt_motion.cuda().float() # bs, nb_joints, joints_dim, seq_len
    
    pred_motion, loss_commit, perplexity = net(gt_motion)
    loss_motion = Loss(pred_motion, gt_motion)
    loss_vel = Loss.forward_vel(pred_motion, gt_motion)
    
    loss = loss_motion + args.commit * loss_commit + args.loss_vel * loss_vel
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
    
    avg_recons += loss_motion.item()
    avg_perplexity += perplexity.item()
    avg_commit += loss_commit.item()
    
    if nb_iter % args.print_iter ==  0 :
        avg_recons /= args.print_iter
        avg_perplexity /= args.print_iter
        avg_commit /= args.print_iter
        
        writer.add_scalar('./Train/L1', avg_recons, nb_iter)
        writer.add_scalar('./Train/PPL', avg_perplexity, nb_iter)
        writer.add_scalar('./Train/Commit', avg_commit, nb_iter)
        
        logger.info(f"Train. Iter {nb_iter} : \t Commit. {avg_commit:.5f} \t PPL. {avg_perplexity:.2f} \t Recons.  {avg_recons:.5f}")
        
        avg_recons, avg_perplexity, avg_commit = 0., 0., 0.,

    if nb_iter % args.eval_iter==0 :
        best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger = eval_trans.evaluation_vqvae(args.out_dir, val_loader, net, logger, writer, nb_iter, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, eval_wrapper=eval_wrapper)
        

2023-08-03 11:31:53,047 INFO --> 	 Eva. Iter 0 :, FID. 1.8976, Diversity Real. 9.7029, Diversity. 9.1374, R_precision_real. [0.51529255 0.70345745 0.79321809], R_precision. [0.39827128 0.57646277 0.67952128], matching_score_real. 2.887122757891391, matching_score_pred. 3.686022038155414
moviepy is installed, but can't import moviepy.editor. Some packages could be missing [imageio, requests]
moviepy is installed, but can't import moviepy.editor. Some packages could be missing [imageio, requests]
moviepy is installed, but can't import moviepy.editor. Some packages could be missing [imageio, requests]
moviepy is installed, but can't import moviepy.editor. Some packages could be missing [imageio, requests]
moviepy is installed, but can't import moviepy.editor. Some packages could be missing [imageio, requests]
moviepy is installed, but can't import moviepy.editor. Some packages could be missing [imageio, requests]
moviepy is installed, but can't import moviepy.editor. Some packages could b