In [1]:
import os
import numpy as np
from humor.datasets.amass_discrete_dataset import AmassDiscreteDataset
from humor.datasets.amass_fit_dataset import AMASSFitDataset
from torch.utils.data import Dataset, DataLoader

# Enable IPython's autoreload extension. Mode 2 reloads all imported modules
# before each cell is executed, so edits to the project sources are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# Notebook-wide settings consumed by the dataset cells further down.

# Relative path to the processed AMASS data; `data_roots` is the name the
# dataset constructors below are given.
data_path = "../datasets/AMASS/amass_processed"
data_roots = data_path

# Options forwarded to AmassDiscreteDataset (split_by, sample_num_frames,
# step_frames_in / step_frames_out).
split_by = "sequence"
sample_num_frames = 10
data_steps_in = data_steps_out = 1

# Rotation representation and the set of fields the dataset should return.
# NOTE(review): 'mat' presumably selects rotation matrices - confirm against
# the dataset implementation.
data_rot_rep = "mat"
data_return_config = "smpl+joints+contacts"

print("Data roots: ", data_roots)

Data roots:  ../datasets/AMASS/amass_processed


## Training dataset

In [4]:
# Build the 'train' split of the discrete AMASS dataset from the values set in
# the config cell, then wrap it in a DataLoader that serves batches of one
# sample in shuffled order.
dataset_options = {
    "split": "train",
    "data_paths": data_roots,
    "split_by": split_by,
    "sample_num_frames": sample_num_frames,
    "step_frames_in": data_steps_in,
    "step_frames_out": data_steps_out,
    "data_rot_rep": data_rot_rep,
    "data_return_config": data_return_config,
}
dataset = AmassDiscreteDataset(**dataset_options)

loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0,  # load in the main process; no worker subprocesses are started
    pin_memory=True,
    drop_last=False,
    # get around numpy RNG seed bug: re-seed NumPy inside every worker process
    worker_init_fn=lambda _: np.random.seed(),
)

Loading data from../datasets/AMASS/amass_processed
Logger must be initialized before logging!
This split contains 286 sequences (that meet the duration criteria).
Logger must be initialized before logging!
The dataset contains 18411 sub-sequences in total.
Logger must be initialized before logging!


In [5]:
# Pull a single batch from the training loader and report the shape stored
# under every key of the input dictionary, then stop iterating.
for i, (data_in, data_out, meta) in enumerate(loader):
    for k in data_in.keys():
        v = data_in[k]
        print(f"data_in[{k}]: {v.shape}")
    break  # one batch is enough for a shape check

data_in[pose_body]: torch.Size([1, 10, 1, 189])
data_in[root_orient]: torch.Size([1, 10, 1, 9])
data_in[root_orient_vel]: torch.Size([1, 10, 1, 3])
data_in[trans]: torch.Size([1, 10, 1, 3])
data_in[trans_vel]: torch.Size([1, 10, 1, 3])
data_in[joints]: torch.Size([1, 10, 1, 22, 3])
data_in[joints_vel]: torch.Size([1, 10, 1, 22, 3])
data_in[contacts]: torch.Size([1, 10, 1, 9])


In [None]:
# Build the AMASS fitting dataset from the same data root as the training
# dataset. No `split` argument is passed, so the class default applies.
fit_dataset = AMASSFitDataset(
    data_path=data_roots,
    seq_len=60,
    # Only joints are requested as observations; vertices and points are off.
    return_joints=True,
    return_verts=False,
    return_points=False,
    noise_std=0.04,
    make_partial=False,
    partial_height=0.9,
    drop_middle=False,
    root_only=False,
    # Reuse the value from the config cell instead of repeating the literal,
    # so this dataset and the training dataset are always split the same way.
    split_by=split_by,
    custom_split=None,
)

Loading data from../datasets/AMASS/amass_processed
Logger must be initialized before logging!
This split contains 66 sequences (that meet the duration criteria).
Logger must be initialized before logging!
The dataset contains 1281 sub-sequences in total.
Logger must be initialized before logging!


## FitDataset with joints config

In [10]:
# Loader over the fit dataset: one sample per batch, served in dataset order
# (no shuffling), loaded in the main process.
fit_loader_options = dict(
    batch_size=1,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
    drop_last=False,
    # Re-seed NumPy inside every worker process.
    worker_init_fn=lambda _: np.random.seed(),
)
data_loader = DataLoader(fit_dataset, **fit_loader_options)

In [23]:
# Look at the first batch of the fit loader: list the keys of the observed and
# ground-truth dictionaries, then print each ground-truth entry.
for i, data in enumerate(data_loader):
    observed_data, gt_data = data

    print(f"observed_data keys: {observed_data.keys()}")
    print(f"gt_data keys: {gt_data.keys()}")
    print(f"observed_data['joints3d']: {observed_data['joints3d'].shape}")

    for k, v in gt_data.items():
        # 'gender' and 'name' are printed by value; everything else by shape.
        if k in ("gender", "name"):
            print(f"gt_data[{k}]: {v}")
        else:
            print(f"gt_data[{k}]: {v.shape}")

    break  # only the first batch is inspected

observed_data keys: dict_keys(['joints3d'])
gt_data keys: dict_keys(['root_orient', 'trans', 'joints', 'verts', 'trans_vel', 'root_orient_vel', 'joints_vel', 'pose_body', 'contacts', 'betas', 'gender', 'name'])
observed_data['joints3d']: torch.Size([1, 60, 22, 3])
gt_data[root_orient]: torch.Size([1, 60, 3])
gt_data[trans]: torch.Size([1, 60, 3])
gt_data[joints]: torch.Size([1, 60, 22, 3])
gt_data[verts]: torch.Size([1, 60, 43, 3])
gt_data[trans_vel]: torch.Size([1, 60, 3])
gt_data[root_orient_vel]: torch.Size([1, 60, 3])
gt_data[joints_vel]: torch.Size([1, 60, 22, 3])
gt_data[pose_body]: torch.Size([1, 60, 63])
gt_data[contacts]: torch.Size([1, 60, 22])
gt_data[betas]: torch.Size([1, 60, 16])
gt_data[gender]: ['female']
gt_data[name]: ['datasets_AMASS_amass_processed\\ACCAD\\Female1General_c3d\\A7 - crouch_poses_121_frames_30_fps']
