In [1]:
import os
import shutil
import numpy as np
import argparse
import errno
import tensorboardX
from time import time
import random
import prettytable
import traceback
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from itertools import cycle
from collections import Counter
import importlib

from lib.utils.tools import *
from lib.utils.learning import *
from lib.data.datareader_h36m import DataReaderH36M
from lib.data.dataset_2DAR import ActionRecognitionDataset2D, get_AR_labels, collate_fn_2DAR
from lib.model.loss import *
from lib.utils.viz_skel_seq import viz_skel_seq_anim
from lib.utils.viz_img_seq import viz_img_seq

def import_class(class_name):
    """
    Dynamically import a class (or any module-level attribute) by dotted path.

    Args:
        class_name: fully qualified name such as "package.module.ClassName".

    Returns:
        The attribute named by the last dotted component of ``class_name``.

    Raises:
        ImportError: if the module imports but has no such attribute. A module
            that cannot be imported raises ``ModuleNotFoundError`` (an
            ``ImportError`` subclass) from ``importlib.import_module``.
    """
    mod_str, _sep, class_str = class_name.rpartition('.')
    # import_module returns the leaf module itself, so no sys.modules lookup is
    # needed (the previous version relied on a `sys` name never imported here).
    module = importlib.import_module(mod_str)
    try:
        return getattr(module, class_str)
    except AttributeError as err:
        details = traceback.format_exception(type(err), err, err.__traceback__)
        raise ImportError('Class %s cannot be found (%s)' % (class_str, details)) from err


def import_function(func_name=None):
    """
    Dynamically import a function by its fully qualified name.

    Args:
        func_name: dotted path such as "mymodule.my_function". The parameter
            defaults to None only for backward compatibility; a value is required.

    Returns:
        The imported function object.

    Raises:
        ValueError: if ``func_name`` is not a string containing a '.' separator.
        ModuleNotFoundError: if the module part cannot be imported.
        AttributeError: if the module has no attribute with the given name.
    """
    # Fail fast with a clear message instead of an opaque AttributeError on None
    # or an unpacking error on a name without a module part.
    if not isinstance(func_name, str) or '.' not in func_name:
        raise ValueError(
            f"func_name must be a dotted path like 'module.function', got {func_name!r}"
        )

    # Split "pkg.mod.func" into the module path and the function name.
    module_name, attr_name = func_name.rsplit('.', 1)

    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


def parse_args(argv=None):
    """
    Parse the command-line options for this script/notebook.

    ``parse_known_args`` is used instead of ``parse_args`` so that unrecognized
    arguments are ignored rather than causing an error. This is needed when
    running inside a Jupyter kernel, which is presumably launched with extra
    arguments of its own.

    Args:
        argv: optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with attributes ``config``, ``checkpoint``,
        ``resume``, ``evaluate``, ``seed`` and ``visualize``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to the config file.")
    parser.add_argument('-c', '--checkpoint', default='ckpt/default', type=str, metavar='PATH', help='checkpoint directory')
    parser.add_argument('-r', '--resume', default='', type=str, metavar='FILENAME', help='checkpoint to resume (file name)')
    parser.add_argument('-e', '--evaluate', default='', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)')
    parser.add_argument('-sd', '--seed', default=0, type=int, help='random seed')
    parser.add_argument('-v', '--visualize', action='store_true', help='whether to activate visualization')
    # Unknown arguments (second return value) are deliberately discarded.
    opts, _ = parser.parse_known_args(argv)
    return opts


def set_random_seed(seed):
    """
    Seed Python's ``random`` module, NumPy's global RNG and PyTorch with one value.

    Args:
        seed: integer seed forwarded unchanged to each library's seeding call.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)


def save_checkpoint(chk_path, epoch, lr, optimizer, model_pos, eval_dict):
    """
    Serialize the training state to ``chk_path`` with ``torch.save``.

    The stored ``'epoch'`` is ``epoch + 1``, presumably so a resumed run
    continues from the next epoch.

    When ``chk_path`` is a filesystem path, the state is first written to a
    sibling temporary file and then moved into place with ``os.replace``. An
    interrupted save therefore never leaves a truncated file at ``chk_path``,
    and any previous checkpoint survives. Other targets (e.g. file-like
    objects) are passed to ``torch.save`` directly, as before.

    Args:
        chk_path: destination path (str / PathLike) or a file-like object.
        epoch: index of the epoch just completed.
        lr: learning rate to record.
        optimizer: object whose ``state_dict()`` is saved under 'optimizer'.
        model_pos: model whose ``state_dict()`` is saved under 'model_pos'.
        eval_dict: evaluation results stored as-is under 'eval_dict'.
    """
    print('\tSaving checkpoint to', chk_path)
    state = {
        'epoch': epoch + 1,
        'lr': lr,
        'optimizer': optimizer.state_dict(),
        'model_pos': model_pos.state_dict(),
        'eval_dict' : eval_dict
    }

    if not isinstance(chk_path, (str, bytes, os.PathLike)):
        torch.save(state, chk_path)
        return

    target = os.fspath(chk_path)
    # Same directory as the target => same filesystem, so os.replace is atomic on POSIX.
    tmp_path = target + (b'.tmp' if isinstance(target, bytes) else '.tmp')
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, target)
    except BaseException:
        # Best-effort cleanup of the partial temp file; the original error is re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

In [2]:
# Parse CLI options, seed every RNG from --seed, then load the config file named by --config.
opts = parse_args()
set_random_seed(seed=opts.seed)
args = get_config(opts.config)

In [3]:
# Validate CLI options. The checkpoint option is a directory; anything containing 'bin' is
# rejected, presumably to catch a checkpoint *file* being passed by mistake. An explicit
# raise (rather than `assert`) keeps the check active under `python -O` and explains why.
if 'bin' in opts.checkpoint:
    raise ValueError(
        f"opts.checkpoint should be a checkpoint directory and must not contain 'bin', got {opts.checkpoint!r}"
    )
args.data = args.partial_data if args.use_partial_data else args.full_data

# Import specified classes and functions.
# `args.func_ver` maps a component name to the version of its implementation module
# (the `ver{N}` package). Missing entries use the defaults below; a missing or empty
# `func_ver` is treated as "all defaults".
_func_ver = getattr(args, 'func_ver', None) or {}
_AR_PKG = 'funcs_and_classes.AR'
_NON_AR_PKG = 'funcs_and_classes.Non_AR'

## dataset AR
dataset_action_recognition_VER = _func_ver.get('dataset_action_recognition', 1)
dataset_action_recognition = import_class(class_name=f'{_AR_PKG}.dataset_AR.ver{dataset_action_recognition_VER}.Dataset_ActionRecognition')
## evaluate AR
evaluate_action_recognition_VER = _func_ver.get('evaluate_action_recognition', 2)
evaluate_action_recognition = import_function(func_name=f'{_AR_PKG}.eval_AR.ver{evaluate_action_recognition_VER}.evaluate_action_recognition')
## train epoch AR
train_epoch_action_recognition_VER = _func_ver.get('train_epoch_action_recognition', 2)
train_epoch_action_recognition = import_function(func_name=f'{_AR_PKG}.train_epoch.ver{train_epoch_action_recognition_VER}.train_epoch')
## dataset non-AR
dataset_VER = _func_ver.get('dataset_non_AR', 1)
dataset = import_class(class_name=f'{_NON_AR_PKG}.dataset.ver{dataset_VER}.MotionDataset3D')
## evaluate non-AR: all four evaluators come from the same versioned module
evaluate_VER = _func_ver.get('evaluate_non_AR', 1)
_eval_module = f'{_NON_AR_PKG}.eval_funcs.ver{evaluate_VER}'
evaluate_future_pose_estimation = import_function(func_name=f'{_eval_module}.evaluate_future_pose_estimation')
evaluate_motion_completion = import_function(func_name=f'{_eval_module}.evaluate_motion_completion')
evaluate_motion_prediction = import_function(func_name=f'{_eval_module}.evaluate_motion_prediction')
evaluate_pose_estimation = import_function(func_name=f'{_eval_module}.evaluate_pose_estimation')
## train epoch non-AR
train_epoch_VER = _func_ver.get('train_epoch_non_AR', 1)
train_epoch = import_function(func_name=f'{_NON_AR_PKG}.train_epoch.ver{train_epoch_VER}.train_epoch')

In [4]:
# Build the training split using the non-AR dataset class (`dataset`) that was
# resolved dynamically from args.func_ver in the previous cell.
train_dataset = dataset(args, data_split='train')

train sample count: {'PE': 8964}


In [5]:
# i = 7200
# data = train_dataset[i][1]
# data = data[:256][::5]
# viz_skel_seq_anim(data, if_print=False, file_name=f"{i:08d}", file_folder="tmp", lim3d=0.3, lw=4, if_rot=True, fs=1, azim=-107, elev=8, interval=75)

In [6]:
# i=8113; frames=np.arange(90, 150, 5); azim=-66; elev=10; lim3d=0.25; lw=15; print(len(frames))

In [7]:
# Settings for the visualization cell below.
i = 5000                  # index into train_dataset
frames = np.arange(128)   # frame indices kept from the sample (0..127)
azim, elev = -107, 8      # passed to viz_skel_seq_anim as azim / elev
lim3d = 0.25
lw = 5
fs = 1
interval = 40
print(len(frames))

128


In [8]:
# i=4950; frames=np.arange(0, 128, 10); azim=-107; elev=8; lim3d=0.5; lw=15; print(len(frames))

In [9]:
# i=7200; frames=np.arange(0, 128, 10); azim=-40; elev=8; lim3d=0.4; lw=15; print(len(frames))

In [10]:
# Take element [1] of the i-th training item and keep only the selected frames.
# Indexing with the `frames` index array yields a new array, so the in-place
# shift below does not touch the dataset's own storage.
data = train_dataset[i][1][frames]
# (disabled) root-relative centering: data = data - data[:, [0], :]
# Shift the last-axis channel 1 by +0.15 before rendering
# (presumably the vertical coordinate — TODO confirm).
data[..., 1] = data[..., 1] + 0.15
viz_skel_seq_anim(
    data,
    if_print=1,
    file_name=f"{i:08d}",
    file_folder="viz_results/outstanding_thesis_ppt",
    lim3d=lim3d,
    lw=lw,
    if_rot=True,
    fs=fs,
    azim=azim,
    elev=elev,
    interval=interval,
)