In [1]:
import os
import json

import torch
os.chdir("/workspace/")
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import models.vqvae as vqvae
# import options.option_vq as option_vq
import utils.utils_model as utils_model
from dataset import dataset_TM_eval
import utils.eval_trans as eval_trans
from options.get_eval_option import get_opt
from models.evaluator_wrapper import EvaluatorModelWrapper
import warnings
warnings.filterwarnings('ignore')

In [2]:
import easydict

opt = easydict.EasyDict({
    "dataname": "t2m",
    "batch_size": 256,
    "window_size": 64,
    "total_iter": 300000,
    "warm_up_iter": 1000,
    "lr": 2e-4,
    "lr_scheduler": [200000],#[50000, 400000],
    "gamma": 0.05,
    "weight_decay": 0.0,
    "commit": 0.02,
    "loss_vel": 0.5,
    "recons_loss": "l1_smooth",
    "code_dim": 512,
    "nb_code": 512,
    "mu": 0.99,
    "down_t": 2,
    "stride_t": 2,
    "width": 512,
    "depth": 3,
    "dilation_growth_rate": 3,
    "output_emb_width": 512,
    "vq_act": "relu",
    "vq_norm": None,
    "quantizer": "ema_reset",
    "beta": 1.0,
    "resume_pth": None,
    "resume_gpt": None,
    "out_dir": "output_vqfinal/TEST_VQVAE/eval",
    "results_dir": "visual_results/",
    "visual_name": "baseline",
    "exp_name": "TEST_VQVAE/eval",
    "print_iter": 200,
    "eval_iter": 1000,
    "seed": 123,
    "vis_gt": False,
    "nb_vis": 20,
    "nb_joints" : 0,
    "resume_pth" : "output_vqfinal/VQ-VAE/eval/net_last.pth"
})

In [3]:
logger = utils_model.get_logger(opt.out_dir)
writer = SummaryWriter(opt.out_dir)
logger.info(json.dumps(vars(opt), indent=4, sort_keys=True))


2023-06-19 07:34:04,634 INFO {
    "batch_size": 256,
    "beta": 1.0,
    "code_dim": 512,
    "commit": 0.02,
    "dataname": "t2m",
    "depth": 3,
    "dilation_growth_rate": 3,
    "down_t": 2,
    "eval_iter": 1000,
    "exp_name": "TEST_VQVAE/eval",
    "gamma": 0.05,
    "loss_vel": 0.5,
    "lr": 0.0002,
    "lr_scheduler": [
        200000
    ],
    "mu": 0.99,
    "nb_code": 512,
    "nb_joints": 0,
    "nb_vis": 20,
    "out_dir": "output_vqfinal/TEST_VQVAE/eval",
    "output_emb_width": 512,
    "print_iter": 200,
    "quantizer": "ema_reset",
    "recons_loss": "l1_smooth",
    "results_dir": "visual_results/",
    "resume_gpt": null,
    "resume_pth": "output_vqfinal/VQ-VAE/eval/net_last.pth",
    "seed": 123,
    "stride_t": 2,
    "total_iter": 300000,
    "vis_gt": false,
    "visual_name": "baseline",
    "vq_act": "relu",
    "vq_norm": null,
    "warm_up_iter": 1000,
    "weight_decay": 0.0,
    "width": 512,
    "window_size": 64
}


In [4]:
from utils.word_vectorizer import WordVectorizer
w_vectorizer = WordVectorizer('./glove', 'our_vab')


In [5]:
dataset_opt_path = 'checkpoints/kit/Comp_v6_KLD005/opt.txt' if opt.dataname == 'kit' else 'checkpoints/t2m/Comp_v6_KLD005/opt.txt'

In [6]:
wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda'))
eval_wrapper = EvaluatorModelWrapper(wrapper_opt)


Reading checkpoints/t2m/Comp_v6_KLD005/opt.txt
Loading Evaluation Model Wrapper (Epoch 28) Completed!!


In [7]:
opt.nb_joints = 21 if opt.dataname == 'kit' else 22

val_loader = dataset_TM_eval.DATALoader(opt.dataname, True, 32, w_vectorizer, unit_length=2**opt.down_t)

##### ---- Network ---- #####
net = vqvae.HumanVQVAE(opt, ## use opt to define different parameters in different quantizers
                       opt.nb_code,
                       opt.code_dim,
                       opt.output_emb_width,
                       opt.down_t,
                       opt.stride_t,
                       opt.width,
                       opt.depth,
                       opt.dilation_growth_rate,
                       opt.vq_act,
                       opt.vq_norm)

100%|██████████| 4384/4384 [00:00<00:00, 5026.18it/s]


Pointer Pointing at 0


In [8]:
if opt.resume_pth : 
    logger.info('loading checkpoint from {}'.format(opt.resume_pth))
    ckpt = torch.load(opt.resume_pth, map_location='cpu')
    net.load_state_dict(ckpt['net'], strict=True)
net.train()
net.cuda()


2023-06-19 07:34:11,655 INFO loading checkpoint from output_vqfinal/VQ-VAE/eval/net_last.pth


HumanVQVAE(
  (vqvae): VQVAE_251(
    (encoder): Encoder(
      (model): Sequential(
        (0): Conv1d(263, 512, kernel_size=(3,), stride=(1,), padding=(1,))
        (1): ReLU()
        (2): Sequential(
          (0): Conv1d(512, 512, kernel_size=(4,), stride=(2,), padding=(1,))
          (1): Resnet1D(
            (model): Sequential(
              (0): ResConv1DBlock(
                (norm1): Identity()
                (norm2): Identity()
                (activation1): ReLU()
                (activation2): ReLU()
                (conv1): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(9,), dilation=(9,))
                (conv2): Conv1d(512, 512, kernel_size=(1,), stride=(1,))
              )
              (1): ResConv1DBlock(
                (norm1): Identity()
                (norm2): Identity()
                (activation1): ReLU()
                (activation2): ReLU()
                (conv1): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(3,), dilation=(3,))
 

In [9]:
fid = []
div = []
top1 = []
top2 = []
top3 = []
matching = []
repeat_time = 20
for i in range(repeat_time):
    best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger = eval_trans.evaluation_vqvae(opt.out_dir, val_loader, net, logger, writer, 0, best_fid=1000, best_iter=0, best_div=100, best_top1=0, best_top2=0, best_top3=0, best_matching=100, eval_wrapper=eval_wrapper, draw=True, savegif=True, save=False, savenpy=False)
    fid.append(best_fid)
    div.append(best_div)
    top1.append(best_top1)
    top2.append(best_top2)
    top3.append(best_top3)
    matching.append(best_matching)
print('final result:')
print('fid: ', sum(fid)/repeat_time)
print('div: ', sum(div)/repeat_time)
print('top1: ', sum(top1)/repeat_time)
print('top2: ', sum(top2)/repeat_time)
print('top3: ', sum(top3)/repeat_time)
print('matching: ', sum(matching)/repeat_time)

fid = np.array(fid)
div = np.array(div)
top1 = np.array(top1)
top2 = np.array(top2)
top3 = np.array(top3)
matching = np.array(matching)
msg_final = f"FID. {np.mean(fid):.3f}, conf. {np.std(fid)*1.96/np.sqrt(repeat_time):.3f}, Diversity. {np.mean(div):.3f}, conf. {np.std(div)*1.96/np.sqrt(repeat_time):.3f}, TOP1. {np.mean(top1):.3f}, conf. {np.std(top1)*1.96/np.sqrt(repeat_time):.3f}, TOP2. {np.mean(top2):.3f}, conf. {np.std(top2)*1.96/np.sqrt(repeat_time):.3f}, TOP3. {np.mean(top3):.3f}, conf. {np.std(top3)*1.96/np.sqrt(repeat_time):.3f}, Matching. {np.mean(matching):.3f}, conf. {np.std(matching)*1.96/np.sqrt(repeat_time):.3f}"
logger.info(msg_final)

2023-06-19 07:34:41,908 INFO --> 	 Eva. Iter 0 :, FID. 0.1140, Diversity Real. 9.4657, Diversity. 9.6670, R_precision_real. [0.51228448 0.71142241 0.80732759], R_precision. [0.49633621 0.69568966 0.79331897], matching_score_real. 2.977429518206366, matching_score_pred. 3.0757760343880487
moviepy is installed, but can't import moviepy.editor. Some packages could be missing [imageio, requests]
moviepy is installed, but can't import moviepy.editor. Some packages could be missing [imageio, requests]
moviepy is installed, but can't import moviepy.editor. Some packages could be missing [imageio, requests]
moviepy is installed, but can't import moviepy.editor. Some packages could be missing [imageio, requests]
moviepy is installed, but can't import moviepy.editor. Some packages could be missing [imageio, requests]
moviepy is installed, but can't import moviepy.editor. Some packages could be missing [imageio, requests]
moviepy is installed, but can't import moviepy.editor. Some packages could 