In [1]:
import os
import signal
import time
import csv
import sys
import warnings
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
import numpy as np
import time
import pprint
from loguru import logger
import wandb
import yaml

from utils import config, logger_tools, other_tools
from dataloaders import data_tools
from dataloaders.build_vocab import Vocab
from optimizers.optim_factory import create_optimizer
from optimizers.scheduler_factory import create_scheduler
from optimizers.loss_factory import get_loss_func

In [2]:
import time
from typing import List, Union


class Config:
    """Experiment configuration: hard-coded defaults that a YAML file can override.

    Every attribute assigned in ``__init__`` is a recognized option.
    ``load_yaml`` reads ``self.config`` and replaces matching defaults;
    ``update_name`` derives ``self.name`` from the config file name.
    """

    def __init__(self):
        self.config: str = "configs/cnn_vqvae_hand_30_skeleton.yaml"
        self.project: str = "audio2pose"
        self.csv: str = "git3.csv"
        self.trainer: str = "camn"
        self.notes: str = ""
        self.out_path = ""
        self.stat = ""
        self.ori_joints = ""
        self.tar_joints = ""
        self.vae_test_len = 64
        self.vae_test_dim = 144
        self.vae_test_stride = 20
        self.vae_codebook_size = 256
        self.vae_layer = 2
        self.vae_grow = [1,1,2,1]
        self.variational = False
        self.rec_ver_weight = 1

        # Path and save name
        self.is_train: bool = True
        self.root_path: str = ""
        self.out_root_path: str = "/outputs/audio2pose/"
        self.train_data_path: str = "/datasets/trinity/train/"
        self.val_data_path: str = "/datasets/trinity/val/"
        self.test_data_path: str = "/datasets/trinity/test/"
        self.mean_pose_path: str = "/datasets/trinity/train/"
        self.std_pose_path: str = "/datasets/trinity/train/"

        # Pretrained weights
        self.torch_hub_path: str = "../../datasets/checkpoints/"
        self.model_name_last: str = "last.pth"
        self.model_name_best: str = "best.pth"
        self.eval_model: str = "vae"
        self.e_name: Union[str, None] = None
        self.e_path: str = "/datasets/beat/generated_data/self_vae_128.bin"
        self.test_ckpt: str = "/datasets/beat_cache/beat_4english_15_141/last.bin"
        self.variational_encoding: bool = False
        self.vae_length: int = 256

        # Data
        self.dataset: str = "beat"
        self.pose_version: str = "spine_neck_141"
        self.new_cache: bool = True
        self.use_aug: bool = False
        self.disable_filtering: bool = False
        self.clean_first_seconds: int = 0
        self.clean_final_seconds: int = 0

        self.audio_rep: Union[str, None] = None
        self.word_rep: Union[str, None] = None
        self.emo_rep: Union[str, None] = None
        self.sem_rep: Union[str, None] = None
        self.facial_rep: Union[str, None] = None
        self.pose_rep: str = "skeleton_cache"
        self.speaker_id: Union[str, None] = None
        self.freeze_wordembed: bool = True
        self.audio_fps: int = 16000
        self.facial_fps: int = 15
        self.pose_fps: int = 15

        self.audio_dims: int = 1
        self.facial_dims: int = 39
        self.pose_dims: int = 123
        self.word_index_num: int = 5793
        self.word_dims: int = 300
        self.speaker_dims: int = 4
        self.emotion_dims: int = 8

        self.audio_norm: bool = False
        self.facial_norm: bool = False
        self.pose_norm: bool = True

        self.pose_length: int = 34
        self.pre_frames: int = 4
        self.stride: int = 10
        self.pre_type: str = "zero"

        self.audio_f: int = 128
        self.facial_f: int = 128
        self.speaker_f: int = 0
        self.word_f: int = 0
        self.emotion_f: int = 0
        self.aud_prob: float = 1.0
        self.pos_prob: float = 1.0
        self.txt_prob: float = 1.0
        self.fac_prob: float = 1.0
        self.multi_length_training: List[float] = [1.0]

        # Model
        self.pretrain: bool = False
        self.model: str = "camn"
        self.g_name: str = "CaMN"
        self.d_name: Union[str, None] = None
        self.dropout_prob: float = 0.3
        self.n_layer: int = 4
        self.hidden_size: int = 300
        self.finger_net: str = "original"

        # Training
        self.epochs: int = 120
        self.grad_norm: int = 0
        self.no_adv_epochs: int = 4
        self.batch_size: int = 128
        self.opt: str = "adam"
        self.lr_base: float = 0.00025
        self.opt_betas: List[float] = [0.5, 0.999]
        self.weight_decay: float = 0.0
        self.lr_min: float = 1e-7
        self.warmup_lr: float = 5e-4
        self.warmup_epochs: int = 0
        self.decay_epochs: int = 9999
        self.decay_rate: float = 0.1
        self.lr_policy: str = "step"
        self.momentum: float = 0.8
        self.nesterov: bool = True
        self.amsgrad: bool = False
        self.d_lr_weight: float = 0.2
        self.rec_weight: float = 500
        self.adv_weight: float = 20.0
        self.fid_weight: float = 0.0
        self.vel_weight: float = 0.0
        self.acc_weight: float = 0.0
        self.kld_weight: float = 0.0
        self.kld_aud_weight: float = 0.0
        self.kld_fac_weight: float = 0.0
        self.ali_weight: float = 0.0
        self.div_reg_weight: float = 0.0
        self.rec_aud_weight: float = 0.0
        self.rec_pos_weight: float = 0.0
        self.rec_fac_weight: float = 0.0
        self.rec_txt_weight: float = 0.0

        # Device
        self.random_seed: int = 2021
        self.deterministic: bool = True
        self.benchmark: bool = True
        self.cudnn_enabled: bool = True
        self.apex: bool = False
        self.gpus: List[int] = [0]
        self.loader_workers: int = 0
        self.ddp: bool = False

        # Logging
        self.log_period: int = 10
        self.test_period: int = 20

        self.name: str = ""

    def update_name(self):
        """Set ``self.name`` from the config file name.

        The name is the config file's base name without its extension
        (``configs/foo.yaml`` -> ``foo``). When ``is_train`` is true, a
        local-time ``MMDD_HHMMSS_`` prefix is prepended.
        """
        # splitext handles any extension (.yaml, .yml, or none). The former
        # fixed ``[:-5]`` slice silently assumed a 5-character ".yaml" suffix.
        base_name = os.path.splitext(os.path.basename(self.config))[0]
        if self.is_train:
            # Same format as before: month, day, '_', hour, minute, second, '_'.
            base_name = time.strftime("%m%d_%H%M%S_", time.localtime()) + base_name
        self.name = base_name

    def load_yaml(self):
        """Override defaults with the top-level key/value pairs of ``self.config``.

        Unknown keys are reported with a printed warning and otherwise ignored.
        An empty YAML file leaves every default untouched.

        Raises:
            TypeError: if the YAML document is not a mapping at the top level.
        """
        with open(self.config, 'r', encoding='utf-8') as file:
            yaml_data = yaml.safe_load(file)
        self._apply_overrides(yaml_data)

    def _apply_overrides(self, overrides):
        """Copy recognized options from the ``overrides`` mapping onto ``self``.

        Only names created in ``__init__`` (present in ``vars(self)``) are
        accepted, so a config key can never shadow a method such as
        ``load_yaml``. ``None`` (what ``yaml.safe_load`` returns for an empty
        document) means "no overrides".
        """
        if overrides is None:
            return
        if not isinstance(overrides, dict):
            raise TypeError(
                "Config file must contain a mapping at the top level, "
                f"got {type(overrides).__name__}"
            )
        known_options = vars(self)
        for key, value in overrides.items():
            if key in known_options:
                setattr(self, key, value)
            else:
                print(f"Warning: '{key}' is not a valid configuration parameter.")

    def print_attributes(self):
        """Print every instance attribute as ``name: value``, one per line."""
        for attr, value in vars(self).items():
            print(f"{attr}: {value}")
            
def create_config(config_path: Union[str, None] = None) -> Config:
    """Build a Config, optionally retarget it at another YAML file, and load it.

    Args:
        config_path: Path of the YAML file to load. When None, the default
            path stored in ``Config().config`` is used.

    Returns:
        The Config instance with YAML overrides applied.
    """
    config = Config()
    if config_path is not None:
        config.config = config_path
    config.load_yaml()
    return config

In [3]:
# Rendezvous address/port for torch.distributed's default env:// init, used
# by init_process_group in main_worker below.
os.environ.update({"MASTER_ADDR": "localhost", "MASTER_PORT": "2224"})

# Load the YAML-backed configuration, then derive the run name from it.
args = create_config()
args.update_name()

In [4]:

@logger.catch
def main_worker(rank, world_size, args):
    """Initialise the process group, logging and seeding for one worker.

    Args:
        rank: Index of this process within the process group.
        world_size: Total number of processes in the group.
        args: Config object passed to the logger and seeding helpers.
    """
    # Silence Python warnings unless specific ones were requested (-W / PYTHONWARNINGS).
    if not sys.warnoptions:
        warnings.simplefilter("ignore")
    # init_process_group raises if the default group already exists, which
    # happens whenever this cell is re-run in the same kernel.
    if not dist.is_initialized():
        # NCCL requires CUDA; fall back to Gloo so the cell also runs on CPU-only hosts.
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        dist.init_process_group(backend, rank=rank, world_size=world_size)

    logger_tools.set_args_and_logger(args, rank)
    other_tools.set_random_seed(args)
    other_tools.print_exp_info(args)


main_worker(0, 1, args)



[34m 09-12 12:37:29[0m | [1m{'acc_weight': 0.0,
 'adv_weight': 20.0,
 'ali_weight': 0.0,
 'amsgrad': False,
 'apex': False,
 'aud_prob': 1.0,
 'audio_dims': 1,
 'audio_f': 128,
 'audio_fps': 16000,
 'audio_norm': False,
 'audio_rep': None,
 'batch_size': 512,
 'benchmark': True,
 'clean_final_seconds': 0,
 'clean_first_seconds': 0,
 'config': 'configs/cnn_vqvae_hand_30_skeleton.yaml',
 'csv': 'git3.csv',
 'cudnn_enabled': True,
 'd_lr_weight': 0.2,
 'd_name': None,
 'dataset': 'beat_skeleton',
 'ddp': False,
 'decay_epochs': 9999,
 'decay_rate': 0.1,
 'deterministic': True,
 'disable_filtering': False,
 'div_reg_weight': 0.0,
 'dropout_prob': 0.3,
 'e_name': None,
 'e_path': '/datasets/beat/generated_data/self_vae_128.bin',
 'emo_rep': None,
 'emotion_dims': 8,
 'emotion_f': 0,
 'epochs': 500,
 'eval_model': 'vae',
 'fac_prob': 1.0,
 'facial_dims': 51,
 'facial_f': 128,
 'facial_fps': 15,
 'facial_norm': False,
 'facial_rep': None,
 'fid_weight': 0.0,
 'finger_net': 'original',
 'fr

In [14]:
# Build the "test" split from the dataloader module named by the config
# (e.g. dataset "beat_skeleton" -> module dataloaders.beat_skeleton).
# importlib.import_module returns the submodule directly, replacing the
# __import__(..., fromlist=["something"]) workaround.
import importlib

dataset_module = importlib.import_module(f"dataloaders.{args.dataset}")
test_data = dataset_module.CustomDataset(args, "test")


In [15]:
# Shape of the "pose" entry of the first test sample (axis meaning, e.g.
# frames x pose dims, is an assumption — TODO confirm against the dataloader).
test_data[0]["pose"].shape

torch.Size([450, 144])

In [13]:
# Number of items in the test dataset (last expression, so the notebook displays it).
len(test_data)

56548