Environment: `py39_torch271`

In [None]:
import os
import random
from copy import deepcopy

from typing import Optional
from dataclasses import dataclass, asdict
from pathlib import Path
from datetime import datetime

import torch
import numpy as np
import pandas as pd
from dacite import from_dict
from sklearn.model_selection import train_test_split

from pyrutils.torch.train_utils import train, save_checkpoint
from pyrutils.torch.multi_task import MultiTaskLossLearner
from vhoi.data_loading import (
    input_size_from_data_loader, 
    select_model_data_feeder, 
    select_model_data_fetcher,
)
from vhoi.data_loading_custom import (
    create_data,
    create_data_loader
)
from vhoi.losses import (
    select_loss, 
    decide_num_main_losses, 
    select_loss_types, 
    select_loss_learning_mask,
)
from vhoi.models import load_model_weights
from vhoi.models_custom import TGGCN

seed = 42

# Seed every random number generator the pipeline touches so runs repeat.
_rng_seeders = (
    random.seed,                  # Python's `random` module
    np.random.seed,               # NumPy global RNG
    torch.manual_seed,            # PyTorch CPU RNG
    torch.cuda.manual_seed,       # PyTorch RNG of the current GPU
    torch.cuda.manual_seed_all,   # PyTorch RNGs of all GPUs (multi-GPU)
)
for _seed_fn in _rng_seeders:
    _seed_fn(seed)

# NOTE: the interpreter reads PYTHONHASHSEED at startup, so setting it here only
# fixes str/bytes hashing for subprocesses launched afterwards, not this kernel.
os.environ['PYTHONHASHSEED'] = str(seed)

# cuDNN: disable kernel autotuning (benchmark=True would override determinism)
# and ask for deterministic algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
class DictMixin:
    """Give dataclass configs a dict-like read API.

    ``get`` mirrors ``dict.get`` on top of attribute access, and ``as_dict``
    converts the (possibly nested) dataclass into plain dicts.
    """

    def get(self, key, default_value=None):
        """Return attribute ``key``, or ``default_value`` if it does not exist."""
        try:
            return getattr(self, key)
        except AttributeError:
            return default_value

    def as_dict(self):
        """Return a recursive ``dict`` copy of this dataclass instance."""
        return asdict(self)
    
@dataclass
class Resources(DictMixin):
    """Compute resources for the run."""
    use_gpu: bool  # CUDA is used only if this is True AND torch.cuda.is_available()
    num_threads: int  # passed to torch.set_num_threads

@dataclass
class ModelMetadata(DictMixin):
    """Model identity used to select loss/data helpers from vhoi."""
    model_name: str
    input_type: str

@dataclass
class ModelParameters(DictMixin):
    """Architecture hyperparameters.

    Every field is forwarded verbatim as a keyword argument to the model
    constructor (via ``__dict__``), so names must match its signature.
    """
    add_segment_length: int
    add_time_position: int
    time_position_strategy: str
    positional_encoding_style: str
    attention_style: str
    bias: bool
    cat_level_states: int
    discrete_networks_num_layers: int
    discrete_optimization_strategy: str
    filter_discrete_updates: bool
    gcn_node: int
    hidden_size: int
    message_humans_to_human: bool
    message_human_to_objects: bool
    message_objects_to_human: bool
    message_objects_to_object: bool
    message_geometry_to_objects: bool
    message_geometry_to_human: bool
    message_segment: bool
    message_type: str
    message_granularity: str
    message_aggregation: str
    object_segment_update_strategy: str
    share_level_mlps: int
    update_segment_threshold: float

@dataclass
class ModelOptimization(DictMixin):
    """Optimizer / training-loop settings."""
    batch_size: int  # training DataLoader batch size
    clip_gradient_at: float  # forwarded to train(); 0.0 in this config
    epochs: int
    learning_rate: float  # Adam learning rate
    val_fraction: float  # NOTE(review): read into `val_fraction` but not used in the visible notebook; validation data comes from val_features_dirs

@dataclass
class BudgetLoss(DictMixin):
    """Budget-loss options (nested under ModelMisc)."""
    add: bool
    human_weight: float
    object_weight: float

@dataclass
class SegmentationLoss(DictMixin):
    """Segmentation-loss options (nested under ModelMisc)."""
    add: bool
    pretrain: bool
    sigma: float  # forwarded as `sigma` to create_data_loader
    weight: float

@dataclass
class ModelMisc(DictMixin):
    """Miscellaneous training options, including pretrained-weight loading."""
    anticipation_loss_weight: float
    budget_loss: BudgetLoss
    first_level_loss_weight: float
    impose_segmentation_pattern: int
    input_human_segmentation: bool
    input_object_segmentation: bool
    make_attention_distance_based: bool
    multi_task_loss_learner: bool  # enables the MultiTaskLossLearner
    pretrained: bool  # weights are loaded only if this is True and pretrained_path is set
    pretrained_path: Optional[str]
    segmentation_loss: SegmentationLoss

@dataclass
class ModelLogging(DictMixin):
    """Where checkpoints and tensorboard logs are written."""
    root_log_dir: str
    checkpoint_name: str
    log_dir: str

@dataclass
class Models(DictMixin):
    """All model-related configuration sections."""
    metadata: ModelMetadata
    parameters: ModelParameters
    optimization: ModelOptimization
    misc: ModelMisc
    logging: ModelLogging

@dataclass
class Data(DictMixin):
    """Dataset configuration.

    Only ``name``, ``scaling_strategy`` and ``downsampling`` are read in the
    visible notebook; the MPHOI path fields appear unused by the custom
    feature loading (TODO confirm).
    """
    name: str
    path: str
    path_zarr: str
    path_obb_zarr: str
    path_hbb_zarr: str
    path_hps_zarr: str
    cross_validation_test_subject: str
    scaling_strategy: Optional[str]  # None disables feature scaling
    downsampling: int  # 1 = full frame rate

@dataclass
class Config(DictMixin):
    """Root configuration, built from a nested dict with dacite.from_dict."""
    resources: Resources
    models: Models
    data: Data
    
# Raw configuration sections; they are assembled into `cfg_dict` and validated
# against the dataclasses above.
metadata_dict = dict(
    model_name="2G-GCN",
    input_type="multiple",
)

parameters_dict = dict(
    add_segment_length=0,  # segment length fed to the segment-level RNN: 0 off, 1 on
    add_time_position=0,  # absolute time position fed to the segment-level RNN: 0 off, 1 on
    time_position_strategy="s",  # time position into segment [s] or discrete update [u]
    positional_encoding_style="e",  # e [embedding] or p [periodic]
    attention_style="v3",  # v1 [concat], v2 [dot-product], v3 [scaled dot-product], v4 [general]
    bias=True,
    cat_level_states=0,  # concatenate first/second level hidden states for predictor MLPs
    discrete_networks_num_layers=1,  # depth of the state-change-detector MLP
    discrete_optimization_strategy="gs",  # straight-through [st] or gumbel-sigmoid [gs]
    filter_discrete_updates=False,  # maxima filter on the state-change detector's soft output
    gcn_node=25,  # custom; upstream: 19 for cad120, 30 for bimanual, 26 for mphoi
    hidden_size=512,  # 512 for cad120 & mphoi; 64 for bimanual
    message_humans_to_human=False,  # custom; upstream: True
    message_human_to_objects=True,
    message_objects_to_human=True,
    message_objects_to_object=False,  # custom; upstream: True
    message_geometry_to_objects=True,
    message_geometry_to_human=True,  # custom; upstream: False
    message_segment=True,
    message_type="v2",  # v1 [relational] or v2 [non-relational]
    message_granularity="v1",  # v1 [generic] or v2 [specific]
    message_aggregation="att",  # mean pooling [mp] or attention [att]
    object_segment_update_strategy="ind",  # same_as_human [sah], independent [ind], conditional_on_human [coh]
    share_level_mlps=0,  # share [1] or not [0] the prediction MLPs of the levels
    update_segment_threshold=0.5,  # in [0.0, 1.0)
)

optimization_dict = dict(
    batch_size=64,  # upstream: mphoi 8; cad120 16; bimanual 32
    clip_gradient_at=0.0,
    epochs=10,  # custom; upstream: cad120 & mphoi 40; bimanual 60
    learning_rate=1e-4,  # upstream: mphoi 1e-4; cad120 & bimanual 1e-3
    val_fraction=0.1,
)

data_dict = dict(
    name="mphoi",
    path=f"{os.getcwd()}/data/MPHOI/MPHOI/mphoi_ground_truth_labels.json",
    path_zarr=f"{os.getcwd()}/data/MPHOI/MPHOI/mphoi_derived_features/faster_rcnn.zarr",
    path_obb_zarr=f"{os.getcwd()}/data/MPHOI/MPHOI/mphoi_derived_features/object_bounding_boxes.zarr",
    path_hbb_zarr=f"{os.getcwd()}/data/MPHOI/MPHOI/mphoi_derived_features/human_bounding_boxes.zarr",
    path_hps_zarr=f"{os.getcwd()}/data/MPHOI/MPHOI/mphoi_derived_features/human_pose.zarr",
    cross_validation_test_subject="Subject14",  # Subject45, Subject25, Subject14
    scaling_strategy=None,  # None or "standard"
    downsampling=1,  # custom; upstream: 3. 1 = full FPS, 2 = half FPS, ...
)

# Alternative per-dataset/model layout:
# root_log_dir = f"{os.getcwd()}/outputs_hiergat/{data_dict['name']}/{metadata_dict['model_name']}"
root_log_dir = f"{os.getcwd()}/outputs_hiergat/custom"

# Run name = key hyperparameters + launch timestamp (second resolution), so runs
# started at different times get distinct directories.
_hp = parameters_dict
_opt = optimization_dict
_run_tag = datetime.now().strftime('%Y%m%d%H%M%S')
checkpoint_name = "_".join([
    f"hs{_hp['hidden_size']}",
    f"e{_opt['epochs']}",
    f"bs{_opt['batch_size']}",
    f"lr{_opt['learning_rate']}",
    f"{_hp['update_segment_threshold']}",
    _run_tag,
])

log_dir = f"{root_log_dir}/{checkpoint_name}"
print("Log directory:", log_dir)
Path(log_dir).mkdir(parents=True, exist_ok=True)

# Nested config mirroring the Config dataclass tree.
cfg_dict = {
    "resources": {
        "use_gpu": True,
        "num_threads": 32
    },
    "models": {
        "metadata": metadata_dict,
        "parameters": parameters_dict,
        "optimization": optimization_dict,
        "misc": {
            "anticipation_loss_weight": 1.0,
            "budget_loss": {
                "add": False,
                "human_weight": 1.0,
                "object_weight": 1.0
            },
            "first_level_loss_weight": 0.0,  # if positive, first level does frame-level prediction
            "impose_segmentation_pattern": 1,  # 0 [no pattern], 1 [all ones]
            "input_human_segmentation": False,  # (was "flase" in YAML, corrected here)
            "input_object_segmentation": False,
            "make_attention_distance_based": True,  # only meaningful if message_aggregation is attention
            "multi_task_loss_learner": False,
            "pretrained": False,  # unfortunately need two entries for checkpoint name
            "pretrained_path": None,  # specified parameters must match pre-trained model
            "segmentation_loss": {
                "add": False,
                "pretrain": False,
                "sigma": 0.0,  # Gaussian smoothing
                "weight": 1.0
            }
        },
        "logging": {
            "root_log_dir": root_log_dir,
            "checkpoint_name": checkpoint_name,
            "log_dir": log_dir
        },
    },
    "data": data_dict,
}

# Build the typed, validated config tree from the nested dict.
cfg = from_dict(data_class=Config, data=cfg_dict)

torch.set_num_threads(cfg.resources.num_threads)
model_name, model_input_type = cfg.models.metadata.model_name, cfg.models.metadata.input_type
batch_size, val_fraction = cfg.models.optimization.batch_size, cfg.models.optimization.val_fraction
# The misc options live under cfg.models.misc. Config has no top-level `misc`
# field, so `cfg.get('misc', ...)` would always fall back to {} and silently
# ignore every misc setting (pretrained weights, multi-task loss learner,
# segmentation sigma, and the kwargs given to the vhoi helper selectors).
# as_dict() yields plain nested dicts, which support both .get() and ** unpacking.
misc_dict = cfg.models.misc.as_dict()
sigma = misc_dict.get('segmentation_loss', {}).get('sigma', 0.0)
scaling_strategy = cfg.data.scaling_strategy
downsampling = cfg.data.downsampling

# Full label vocabulary. The integer after "_action_" in each feature-directory
# name indexes into this list.
action_classes = [
    # human (0-5)
    *['supervise', 'collaborate with', 'assist', 'lead', 'coordinate with', 'listen to'],
    # rebar (6-9)
    *['tie', 'erect', 'prepare_rebar', 'transport'],
    # formwork (10-11)
    *['install', 'prepare_formwork'],
    # concrete (12-13)
    *['pour', 'finish'],
    # equipment (14-15)
    *['use', 'carry'],
    # all (16-17)
    *['inspect', 'no interaction'],
]

# Alternative 8-class subset, kept for reference:
# new_action_classes = ['tie', 'erect', 'prepare_rebar', 'transport',
#                       'use', 'carry', 'inspect', 'no interaction']

# Classes the model is trained on. The position in this list is the training
# label id; samples of any other class are dropped.
new_action_classes = ['inspect', 'prepare_rebar', 'erect', 'tie']
num_classes = len(new_action_classes)

# One `features` directory per recording.
train_features_dirs = list(map(Path, [
    '/root/vs-gats-plaster/deepsort/hiergat_data/C0085_full_MP4_anno_for_labelling_done_yoga_full_temporal_3s/features',
    '/root/vs-gats-plaster/deepsort/hiergat_data/C0083_full_MP4_anno_for_labelling_done_putu_full_temporal_3s/features',
    '/root/vs-gats-plaster/deepsort/hiergat_data/C0087_full_MP4_anno_for_labelling_done_arga_full_temporal_3s/features',
    '/root/vs-gats-plaster/deepsort/hiergat_data/C0078_full_MP4_anno_for_labelling_done_anne_full_temporal_3s/features',
    '/root/vs-gats-plaster/deepsort/hiergat_data/C0090_full_MP4_anno_for_labelling_rizky_full_temporal_3s/features',
    '/root/vs-gats-plaster/deepsort/hiergat_data/C0099_full_MP4_anno_for_labelling_done_arga_full_temporal_3s/features',
    '/root/vs-gats-plaster/deepsort/hiergat_data/C0101_full_MP4_anno_for_labelling_done_faridz_full_temporal_3s/features',
    '/root/vs-gats-plaster/deepsort/hiergat_data/C0098_full_MP4_anno_for_labelling_done_akbar_full_temporal_3s/features',
    # Alternative source for C0098:
    # '/root/vs-gats-plaster/deepsort/hiergat_data_3/C0098_full_MP4_anno_for_labelling_done_akbar_full_temporal_3s/features',
]))

val_features_dirs = list(map(Path, [
    '/root/vs-gats-plaster/deepsort/hiergat_data/C0100_full_MP4_anno_for_labelling_done_putu_full_temporal_3s/features',
]))

test_features_dirs = list(map(Path, [
    '/root/vs-gats-plaster/deepsort/hiergat_data/C0074_full_MP4_anno_for_labelling_done_ray_full_temporal_3s/features',
]))

In [None]:
# Data
def get_features_dirs_df(features_dirs, source_classes=None, target_classes=None):
    """Index per-sample feature entries and map their labels to the training label set.

    Every entry inside each directory of ``features_dirs`` is expected to be
    named ``..._action_<k>`` or ``..._action_<k>_interp``. Here ``<k>`` indexes
    into ``source_classes``, and the ``_interp`` suffix marks an interpolated
    sample.

    Args:
        features_dirs: iterable of ``pathlib.Path`` objects, one per recording.
        source_classes: full label vocabulary. Defaults to the module-level
            ``action_classes``.
        target_classes: classes to keep. The position in this list becomes
            ``new_action_label``. Defaults to the module-level
            ``new_action_classes``.

    Returns:
        DataFrame with columns ``base_dir``, ``dir``, ``action_label``,
        ``new_action_label`` and ``is_interpolated``. It holds only entries
        whose class appears in ``target_classes``. Rows are ordered by recording
        (input order), then by entry path. The columns exist even when no entry
        matches.

    Raises:
        ValueError: if an entry name does not end in an integer label.
        IndexError: if a label is out of range for ``source_classes``.
    """
    if source_classes is None:
        source_classes = action_classes
    if target_classes is None:
        target_classes = new_action_classes

    # class name -> new label id (first occurrence wins, like list.index).
    target_index = {}
    for idx, name in enumerate(target_classes):
        target_index.setdefault(name, idx)

    rows = []
    for base_dir in features_dirs:
        # Path.iterdir() yields entries in an arbitrary, filesystem-dependent
        # order. Sorting keeps row order stable, which the seeded per-group
        # DataFrame.sample() downstream needs to be reproducible.
        for entry in sorted(base_dir.iterdir()):
            label_str = entry.name.rsplit('_action_', 1)[-1]
            is_interpolated = int('_interp' in label_str)
            action_label = int(label_str.replace('_interp', ''))
            action_class = source_classes[action_label]
            rows.append({
                'base_dir': str(base_dir),
                'dir': str(entry),
                'action_label': action_label,
                'new_action_label': target_index.get(action_class, -1),  # -1 = not a training class
                'is_interpolated': is_interpolated,
            })

    columns = ['base_dir', 'dir', 'action_label', 'new_action_label', 'is_interpolated']
    feature_dirs_df = pd.DataFrame(rows, columns=columns)
    # Drop entries whose class is not among target_classes.
    return feature_dirs_df[feature_dirs_df['new_action_label'] != -1]

# One index DataFrame per split.
train_feature_dirs_df, val_feature_dirs_df, test_feature_dirs_df = (
    get_features_dirs_df(split_dirs)
    for split_dirs in (train_features_dirs, val_features_dirs, test_features_dirs)
)

# Per-split label histogram, with a blank line between splits.
_split_summaries = [
    ("Training set action label counts:", train_feature_dirs_df),
    ("Validation set action label counts:", val_feature_dirs_df),
    ("Testing set action label counts:", test_feature_dirs_df),
]
for _pos, (_heading, _split_df) in enumerate(_split_summaries):
    if _pos > 0:
        print()
    print(_heading)
    print(_split_df['new_action_label'].value_counts())

In [None]:
_train_labels = train_feature_dirs_df['new_action_label']
_train_is_original = train_feature_dirs_df['is_interpolated'] == 0
# Label 2 keeps every sample, including interpolated ones; label 0 keeps only
# original (non-interpolated) samples.
train_df_2 = train_feature_dirs_df.loc[_train_labels == 2]
train_df_0 = train_feature_dirs_df.loc[(_train_labels == 0) & _train_is_original]

def sample_by_group_size(df, groupby_col, size_map, seed=42):
    """Draw a fixed number of random rows from each group of ``df``.

    Args:
        df: DataFrame to sample from.
        groupby_col: column whose values define the groups.
        size_map: mapping from group value to the number of rows to draw from
            that group. Every group present in ``df`` needs an entry.
        seed: ``random_state`` passed to ``DataFrame.sample`` for every group.

    Returns:
        DataFrame with all of ``df``'s columns and original index labels.
        Groups are concatenated in sorted group-key order.
    """
    def _draw(group):
        # Inside apply, `group.name` is this group's key.
        return group.sample(n=size_map[group.name], random_state=seed)

    grouped = df.groupby(groupby_col, group_keys=False)
    # Selecting all columns explicitly keeps the grouping column in each group
    # frame (presumably also to avoid pandas' grouping-column apply deprecation).
    return grouped[df.columns].apply(_draw)

def _original_rows(label):
    """Return training rows with ``new_action_label == label`` that are not interpolated."""
    frame = train_feature_dirs_df
    keep = (frame['new_action_label'] == label) & (frame['is_interpolated'] == 0)
    return frame[keep]


# Label 3: cap each recording's contribution (original count -> kept count).
train_df_3 = _original_rows(3)
# print(train_df_3.groupby('base_dir').size()) # Sanity Check
train_size_map_3 = {
    '/root/vs-gats-plaster/deepsort/hiergat_data/C0078_full_MP4_anno_for_labelling_done_anne_full_temporal_3s/features': 65,    # 130 -> 65
    '/root/vs-gats-plaster/deepsort/hiergat_data/C0083_full_MP4_anno_for_labelling_done_putu_full_temporal_3s/features': 65,    # 240 -> 65
    '/root/vs-gats-plaster/deepsort/hiergat_data/C0085_full_MP4_anno_for_labelling_done_yoga_full_temporal_3s/features': 65,    # 200 -> 65
    '/root/vs-gats-plaster/deepsort/hiergat_data/C0087_full_MP4_anno_for_labelling_done_arga_full_temporal_3s/features': 65,    # 205 -> 65
    '/root/vs-gats-plaster/deepsort/hiergat_data/C0098_full_MP4_anno_for_labelling_done_akbar_full_temporal_3s/features': 2,    # 2 -> 2
    '/root/vs-gats-plaster/deepsort/hiergat_data/C0099_full_MP4_anno_for_labelling_done_arga_full_temporal_3s/features': 4,     # 4 -> 4
    '/root/vs-gats-plaster/deepsort/hiergat_data/C0101_full_MP4_anno_for_labelling_done_faridz_full_temporal_3s/features': 65,  # 147 -> 65
}  # Total: 331
train_df_3_downsampled = sample_by_group_size(train_df_3, 'base_dir', train_size_map_3)

# Label 1: same per-recording cap.
train_df_1 = _original_rows(1)
# print(train_df_1.groupby('base_dir').size()) # Sanity Check
train_size_map_1 = {
    '/root/vs-gats-plaster/deepsort/hiergat_data/C0083_full_MP4_anno_for_labelling_done_putu_full_temporal_3s/features': 9,     # 56 -> 9
    '/root/vs-gats-plaster/deepsort/hiergat_data/C0085_full_MP4_anno_for_labelling_done_yoga_full_temporal_3s/features': 10,    # 181 -> 10
    '/root/vs-gats-plaster/deepsort/hiergat_data/C0099_full_MP4_anno_for_labelling_done_arga_full_temporal_3s/features': 10,    # 274 -> 10
}  # Total: 29
train_df_1_downsampled = sample_by_group_size(train_df_1, 'base_dir', train_size_map_1)

# Balanced training index: labels 0 and 2 as selected above, labels 1 and 3 downsampled.
_balanced_parts = [train_df_0, train_df_1_downsampled, train_df_2, train_df_3_downsampled]
train_feature_dirs_df_downsampled = pd.concat(_balanced_parts)

# Sanity Check
print("Downsampled training set action label counts:")
print(train_feature_dirs_df_downsampled['new_action_label'].value_counts())

In [None]:
# Flatten each split's index into a list of sample directories, then build the
# model inputs with create_data. Every split uses the same label mapping.
train_feature_dirs, val_feature_dirs, test_feature_dirs = (
    frame['dir'].tolist()
    for frame in (train_feature_dirs_df_downsampled, val_feature_dirs_df, test_feature_dirs_df)
)
train_data, val_data, test_data = (
    create_data(split_dirs, action_classes, new_action_classes)
    for split_dirs in (train_feature_dirs, val_feature_dirs, test_feature_dirs)
)

In [None]:
# create_data returns a 7-tuple of per-sample lists.
(
    train_human_features_list, train_human_boxes_list, train_human_poses_list,
    train_object_features_list, train_object_boxes_list,
    train_gt_list, train_xs_steps,
) = train_data

# Stack into arrays and keep index 0 of the entity axis: axis 1 for human
# arrays, axis 2 for object arrays (presumably first human / first object).
train_human_features_list = np.stack(train_human_features_list)[:, 0, :, :]     # (N, seq_len, 1024)
train_human_boxes_list = np.stack(train_human_boxes_list)[:, 0, :, :]           # (N, seq_len, 4)
train_human_poses_list = np.stack(train_human_poses_list)[:, 0, :, :]           # (N, seq_len, 17, 2)
train_object_features_list = np.stack(train_object_features_list)[:, :, 0, :]   # (N, seq_len, 1024)
train_object_boxes_list = np.stack(train_object_boxes_list)[:, :, 0, :]         # (N, seq_len, 4)
# One label per sample: the first per-frame label of 'Human1'.
train_gt_list = np.array([gt['Human1'][0] for gt in train_gt_list])             # (N,)
# train_xs_steps is passed through unchanged                                    # (N,)

# Sanity Check
for _caption, _array in (
    ("train_human_features_list.shape:", train_human_features_list),
    ("train_human_boxes_list.shape:", train_human_boxes_list),
    ("train_human_poses_list.shape:", train_human_poses_list),
    ("train_object_features_list.shape:", train_object_features_list),
    ("train_object_boxes_list.shape:", train_object_boxes_list),
    ("train_gt_list.shape:", train_gt_list),
):
    print(_caption, _array.shape)

In [None]:
# def pad_or_crop(seq, target_len):
#     seq_len = seq.shape[0]
#     if seq_len == target_len:
#         return seq
#     elif seq_len > target_len:
#         return seq[:target_len]
#     else:
#         pad_shape = (target_len - seq_len,) + seq.shape[1:]
#         pad = np.zeros(pad_shape, dtype=seq.dtype)
#         return np.concatenate([seq, pad], axis=0)

# def to_float(x): return x.astype(np.float32, copy=False)
# def add_noise(x, scale=0.01): return to_float(x) + np.random.normal(0, scale, x.shape).astype(np.float32)

# def translate_jitter(boxes, poses, max_shift=0.05):
#     boxes, poses = to_float(boxes), to_float(poses)
#     shift = np.random.uniform(-max_shift, max_shift, size=(1, 2)).astype(np.float32)
#     boxes[..., :2] += shift
#     poses += shift
#     return boxes, poses

# def scale_zoom(boxes, poses, scale_range=(0.9, 1.1)):
#     boxes, poses = to_float(boxes), to_float(poses)
#     scale = np.random.uniform(*scale_range)
#     boxes[..., :2] *= scale
#     boxes[..., 2:] *= scale
#     poses *= scale
#     return boxes, poses

# def horizontal_flip(boxes, poses, img_width=1.0):
#     boxes, poses = to_float(boxes), to_float(poses)
#     boxes[..., 0] = img_width - boxes[..., 0] - boxes[..., 2]
#     poses[..., 0] = img_width - poses[..., 0]
#     return boxes, poses

# def time_warp_stretch(seq, stretch_factor_range=(1.1, 1.3)):
#     seq = to_float(seq)
#     factor = np.random.uniform(*stretch_factor_range)
#     new_len = int(seq.shape[0] * factor)
#     new_idx = np.linspace(0, seq.shape[0]-1, new_len).astype(np.int32)
#     return seq[new_idx, ...]

# def time_warp_compress(seq, compress_factor_range=(0.7, 0.9)):
#     seq = to_float(seq)
#     factor = np.random.uniform(*compress_factor_range)
#     new_len = int(seq.shape[0] * factor)
#     new_idx = np.linspace(0, seq.shape[0]-1, new_len).astype(np.int32)
#     return seq[new_idx, ...]

# from scipy.interpolate import interp1d
# def interpolate_seq(seq, factor_range=(0.7, 1.3)):
#     """
#     Interpolates the sequence to a new length using linear interpolation.
#     Unlike stretch/compress which samples by index, this does real interpolation.
#     """
#     seq = to_float(seq)
#     old_len = seq.shape[0]
#     factor = np.random.uniform(*factor_range)
#     new_len = int(old_len * factor)
#     old_time = np.linspace(0, 1, old_len)
#     new_time = np.linspace(0, 1, new_len)
#     # interp1d supports multi-d interpolation along axis 0
#     f = interp1d(old_time, seq, axis=0, kind='linear')
#     new_seq = f(new_time)
#     return new_seq

# def mixup(x1, x2, alpha=0.2):
#     x1, x2 = to_float(x1), to_float(x2)
#     lam = np.random.beta(alpha, alpha)
#     return lam * x1 + (1 - lam) * x2

# def augment_by_label(
#     label,
#     train_human_features_list,
#     train_human_boxes_list,
#     train_human_poses_list,
#     train_object_features_list,
#     train_object_boxes_list,
#     train_gt_list,
#     train_xs_steps,
#     seed=42,
#     aug_samples_per_func=None
# ):
#     np.random.seed(seed)
#     default_aug = 5
#     if aug_samples_per_func is None:
#         aug_samples_per_func = {}

#     aug_samples_per_func = {k: aug_samples_per_func.get(k, default_aug) for k in [
#         'noise_spatial', 'translate', 'scale', 'flip',
#         'noise_temporal', 'stretch', 'compress', 'interpolate',
#         'noise_feature', 'mixup'
#     ]}

#     print(f"Augmenting for label {label}...")

#     target_indices = np.where(train_gt_list == label)[0]
#     if len(target_indices) == 0:
#         print(f"No samples found for label {label}. Skipping.")
#         return (
#             train_human_features_list,
#             train_human_boxes_list,
#             train_human_poses_list,
#             train_object_features_list,
#             train_object_boxes_list,
#             train_gt_list,
#             train_xs_steps
#         )

#     human_features = train_human_features_list[target_indices]
#     human_boxes = train_human_boxes_list[target_indices]
#     human_poses = train_human_poses_list[target_indices]
#     object_features = train_object_features_list[target_indices]
#     object_boxes = train_object_boxes_list[target_indices]
#     xs_steps = train_xs_steps[target_indices]
#     gts = train_gt_list[target_indices]

#     original_seq_len = human_features.shape[1]

#     augmented = dict(
#         human_features=[], human_boxes=[], human_poses=[],
#         object_features=[], object_boxes=[], gts=[], xs_steps=[]
#     )

#     def append_augmented(hf, hb, hp, of, ob, gt, step):
#         augmented['human_features'].append(hf)
#         augmented['human_boxes'].append(hb)
#         augmented['human_poses'].append(hp)
#         augmented['object_features'].append(of)
#         augmented['object_boxes'].append(ob)
#         augmented['gts'].append(gt)
#         augmented['xs_steps'].append(step)

#     def sample_and_apply(func, n_times):
#         for _ in range(n_times):
#             i = np.random.randint(0, len(gts))
#             hf, hb, hp, of, ob, gt, step = map(deepcopy,
#                 [human_features[i], human_boxes[i], human_poses[i],
#                  object_features[i], object_boxes[i], gts[i], xs_steps[i]])

#             if func == 'noise_spatial':
#                 hb, hp = add_noise(hb), add_noise(hp)
#             elif func == 'translate':
#                 hb, hp = translate_jitter(hb, hp)
#             elif func == 'scale':
#                 hb, hp = scale_zoom(hb, hp)
#             elif func == 'flip':
#                 hb, hp = horizontal_flip(hb, hp)
#             elif func == 'noise_temporal':
#                 hf, of = add_noise(hf), add_noise(of)
#             elif func == 'stretch':
#                 hf = time_warp_stretch(hf)
#                 of = time_warp_stretch(of)
#             elif func == 'compress':
#                 hf = time_warp_compress(hf)
#                 of = time_warp_compress(of)
#             elif func == 'interpolate':
#                 hf = interpolate_seq(hf)
#                 of = interpolate_seq(of)
#                 hb = interpolate_seq(hb)
#                 hp = interpolate_seq(hp)
#                 ob = interpolate_seq(ob)
#             elif func == 'noise_feature':
#                 hf, of = add_noise(hf), add_noise(of)
#             elif func == 'mixup':
#                 j = np.random.randint(0, len(gts))
#                 hf = mixup(hf, human_features[j])
#                 of = mixup(of, object_features[j])

#             hf = pad_or_crop(hf, original_seq_len)
#             of = pad_or_crop(of, original_seq_len)
#             hb = pad_or_crop(hb, original_seq_len)
#             hp = pad_or_crop(hp, original_seq_len)
#             ob = pad_or_crop(ob, original_seq_len)

#             append_augmented(hf, hb, hp, of, ob, gt, step)

#     for func_name, n_samples in aug_samples_per_func.items():
#         sample_and_apply(func_name, n_samples)

#     aug_hf = np.stack(augmented['human_features'])
#     aug_hb = np.stack(augmented['human_boxes'])
#     aug_hp = np.stack(augmented['human_poses'])
#     aug_of = np.stack(augmented['object_features'])
#     aug_ob = np.stack(augmented['object_boxes'])
#     aug_gt = np.array(augmented['gts'])
#     aug_xs = np.array(augmented['xs_steps'])

#     train_human_features_list = np.concatenate([train_human_features_list, aug_hf], axis=0)
#     train_human_boxes_list = np.concatenate([train_human_boxes_list, aug_hb], axis=0)
#     train_human_poses_list = np.concatenate([train_human_poses_list, aug_hp], axis=0)
#     train_object_features_list = np.concatenate([train_object_features_list, aug_of], axis=0)
#     train_object_boxes_list = np.concatenate([train_object_boxes_list, aug_ob], axis=0)
#     train_gt_list = np.concatenate([train_gt_list, aug_gt], axis=0)
#     train_xs_steps = np.concatenate([train_xs_steps, aug_xs], axis=0)

#     print(f"Done! Total: {np.sum(train_gt_list == label)}")

#     return (
#         train_human_features_list,
#         train_human_boxes_list,
#         train_human_poses_list,
#         train_object_features_list,
#         train_object_boxes_list,
#         train_gt_list,
#         train_xs_steps
#     )


# # Augment training data for label 2 (Total: 52)
# train_human_features_list, train_human_boxes_list, train_human_poses_list, \
# train_object_features_list, train_object_boxes_list, train_gt_list, train_xs_steps = augment_by_label(
#     label=2,
#     train_human_features_list=train_human_features_list,
#     train_human_boxes_list=train_human_boxes_list,
#     train_human_poses_list=train_human_poses_list,
#     train_object_features_list=train_object_features_list,
#     train_object_boxes_list=train_object_boxes_list,
#     train_gt_list=train_gt_list,
#     train_xs_steps=train_xs_steps,
#     aug_samples_per_func={
#         # Spatial Augmentation
#         'noise_spatial': 0,
#         'translate': 0,
#         'scale': 0,
#         'flip': 0,
        
#         # Temporal Augmentation
#         'noise_temporal': 0,
#         'stretch': 0,
#         'compress': 0,
        
#         # Feature-Space Augmentation
#         'noise_feature': 0,
#         'mixup': 0,
        
#         # Misc.
#         'interpolate': 52,
#     },
# )

# # Sanity Check
# print(pd.DataFrame(train_gt_list).value_counts())

In [None]:
# Reinsert the singleton entity axes removed earlier (axis 1 for human arrays,
# axis 2 for object arrays), presumably the layout create_data_loader expects.
train_human_features_list = np.stack(train_human_features_list)[:, np.newaxis, :, :]      # (N, 1, seq_len, 1024)
train_human_boxes_list = np.stack(train_human_boxes_list)[:, np.newaxis, :, :]            # (N, 1, seq_len, 4)
train_human_poses_list = np.stack(train_human_poses_list)[:, np.newaxis, :, :]            # (N, 1, seq_len, 17, 2)
train_object_features_list = np.stack(train_object_features_list)[:, :, np.newaxis, :]    # (N, seq_len, 1, 1024)
train_object_boxes_list = np.stack(train_object_boxes_list)[:, :, np.newaxis, :]          # (N, seq_len, 1, 4)

# Broadcast each sequence-level label to every frame, in the {'Human1': [...]}
# layout that create_data produced.
_seq_len = train_human_features_list.shape[2]
train_gt_list = [{'Human1': [label] * _seq_len} for label in train_gt_list]  # N

# Sanity Check
for _caption, _array in (
    ("train_human_features_list:", train_human_features_list),
    ("train_human_boxes_list:", train_human_boxes_list),
    ("train_human_poses_list:", train_human_poses_list),
    ("train_object_features_list:", train_object_features_list),
    ("train_object_boxes_list:", train_object_boxes_list),
):
    print(_caption, _array.shape)
print("train_xs_steps:", len(train_xs_steps))
print("train_gt_list:", len(train_gt_list))

train_data = (
    train_human_features_list, train_human_boxes_list, train_human_poses_list,
    train_object_features_list, train_object_boxes_list,
    train_gt_list, train_xs_steps,
)

In [None]:
# Options shared by all three loaders.
_common_loader_kwargs = dict(sigma=sigma, downsampling=downsampling)

# Training loader: shuffled. It also returns the scalers for scaling_strategy.
train_loader, scalers, _ = create_data_loader(
    *train_data, model_name,
    batch_size=batch_size, shuffle=True, scaling_strategy=scaling_strategy,
    **_common_loader_kwargs,
)
# Evaluation loaders reuse the training scalers and serve the whole split as a
# single batch.
val_loader, _, _ = create_data_loader(
    *val_data, model_name,
    batch_size=len(val_data[0]), shuffle=False, scalers=scalers,
    **_common_loader_kwargs,
)
test_loader, _, _ = create_data_loader(
    *test_data, model_name,
    batch_size=len(test_data[0]), shuffle=False, scalers=scalers,
    **_common_loader_kwargs,
)

input_size = input_size_from_data_loader(train_loader, model_name, model_input_type)
data_info = dict(input_size=input_size)

In [None]:
# Model
dataset_name = cfg.data.name
# Constructor kwargs: data-derived input size, every architecture parameter,
# and the class count (second level left as None).
model_creation_args = {
    **data_info,
    **vars(cfg.models.parameters),
    'num_classes': (num_classes, None),
}
device = 'cuda' if torch.cuda.is_available() and cfg.resources.use_gpu else 'cpu'
model = TGGCN(feat_dim=1024, **model_creation_args).to(device)

# Optional warm start. strict=False means checkpoint keys that don't match the
# model are skipped instead of raising.
if misc_dict.get('pretrained', False) and misc_dict.get('pretrained_path') is not None:
    model.load_state_dict(load_model_weights(misc_dict['pretrained_path']), strict=False)

In [None]:
_optim_cfg = cfg.models.optimization

params = model.parameters()
optimizer = torch.optim.Adam(params, lr=_optim_cfg.learning_rate)
criterion, loss_names = select_loss(model_name, model_input_type, dataset_name, cfg=cfg)

# Optional learned multi-task loss weighting. Its parameters are added to the
# same Adam optimizer so they train jointly with the model.
mtll_model = None
if misc_dict.get('multi_task_loss_learner', False):
    loss_types = select_loss_types(model_name, dataset_name, cfg=cfg)
    mask = select_loss_learning_mask(model_name, dataset_name, cfg=cfg)
    mtll_model = MultiTaskLossLearner(loss_types=loss_types, mask=mask).to(device)
    optimizer.add_param_group({'params': mtll_model.parameters()})

tensorboard_log_dir = cfg.models.logging.root_log_dir
checkpoint_name = cfg.models.logging.checkpoint_name

# Misc options overlaid with architecture parameters (parameters win on clashes).
_misc_and_params = {**misc_dict, **cfg.models.parameters.__dict__}
fetch_model_data = select_model_data_fetcher(
    model_name, model_input_type, dataset_name=dataset_name, **_misc_and_params,
)
feed_model_data = select_model_data_feeder(
    model_name, model_input_type, dataset_name=dataset_name, **misc_dict,
)
num_main_losses = decide_num_main_losses(model_name, dataset_name, _misc_and_params)

checkpoint = train(
    model, train_loader, optimizer, criterion, _optim_cfg.epochs, device, loss_names,
    clip_gradient_at=_optim_cfg.clip_gradient_at,
    fetch_model_data=fetch_model_data,
    feed_model_data=feed_model_data,
    val_loader=val_loader,
    mtll_model=mtll_model,
    num_main_losses=num_main_losses,
    tensorboard_log_dir=tensorboard_log_dir,
    checkpoint_name=checkpoint_name,
)

# Persist the checkpoint together with the fitted input scalers.
_configured_log_dir = cfg.models.logging.log_dir
if _configured_log_dir is not None:
    log_dir = _configured_log_dir
    checkpoint['scalers'] = scalers
    save_checkpoint(log_dir, checkpoint, checkpoint_name=checkpoint_name, include_timestamp=False)

# Predict

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import confusion_matrix

from predict import match_shape, match_att_shape

# Set to True to also collect attention scores. The prediction loop then
# expects feed_model_data to return (output, attention_scores).
inspect_model = False

model.eval();  # switch the model to inference mode; ';' silences the notebook echo

In [None]:
# Run the model over the test split and collect per-head outputs and targets,
# plus attention scores when inspect_model is set.
outputs, targets, attentions = [], [], []
for dataset in test_loader:
    data, target = fetch_model_data(dataset, device=device)
    with torch.no_grad():
        output = feed_model_data(model, data)
    if inspect_model:
        # With inspect_model, the feeder returns (predictions, attention scores).
        output, attention_scores = output
        attention_scores = [att_score[:, 0] for att_score in attention_scores]
    if num_main_losses is not None:
        # Keep only the main-task heads (the last `num_main_losses` entries).
        output = output[-num_main_losses:]
        target = target[-num_main_losses:]
    # Use a list so heads can be replaced in place below, even when the feeder
    # returns a tuple.
    output = list(output)
    if downsampling > 1:
        # Outputs are at the downsampled rate: repeat along time (dim -2) and
        # align to the target with match_shape (presumably trimming overflow).
        for head_idx, (out, tgt) in enumerate(zip(output, target)):
            if out.ndim != 4:
                raise RuntimeError(f'Number of dimensions for output is {out.ndim}')
            out = torch.repeat_interleave(out, repeats=downsampling, dim=-2)
            output[head_idx] = match_shape(out, tgt)
        if inspect_model:
            a_target = target[0]
            attention_scores = [torch.repeat_interleave(att_score, repeats=downsampling, dim=-2)
                                for att_score in attention_scores]
            attention_scores = [match_att_shape(att_score, a_target) for att_score in attention_scores]
    if inspect_model:
        # Collect attention for every batch, not only when downsampling > 1.
        attentions.append(attention_scores)
    outputs += output
    targets += target

## Predict with Mode

In [None]:
# Sequence-level labels: per-frame argmax over dim 1, then the most frequent
# frame label of each sequence. Assumes outputs[0] is (batch, classes, time, 1)
# and targets[0] is (batch, time, 1). TODO confirm.
_frame_pred = outputs[0].argmax(dim=1).squeeze(-1)
y_pred = _frame_pred.mode(dim=1).values.cpu().numpy()
y_true = targets[0].squeeze(-1).mode(dim=1).values.cpu().numpy()

In [None]:
# Support-weighted averages across classes. zero_division=0 scores classes
# with no predictions as 0 instead of warning.
_weighted = {'average': 'weighted', 'zero_division': 0}
acc = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred, **_weighted)
precision = precision_score(y_true, y_pred, **_weighted)
recall = recall_score(y_true, y_pred, **_weighted)

for _label, _value in (("Accuracy:", acc), ("F1 Score:", f1), ("Precision:", precision), ("Recall:", recall)):
    print(_label, _value)

In [None]:
# Fix the label order so every training class gets a row/column, even if absent.
cm = confusion_matrix(y_true, y_pred, labels=[*range(len(new_action_classes))])

_fig, _ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=new_action_classes, yticklabels=new_action_classes, ax=_ax)
_ax.set_xlabel("Predicted")
_ax.set_ylabel("True")
_ax.set_title("Confusion Matrix")
plt.show()

## Predict with Last Timestep

In [None]:
# Prediction from the final time step only: class logits of each sequence's last frame.
_last_step_logits = outputs[0][:, :, -1, 0]
y_pred = torch.argmax(_last_step_logits, dim=1).cpu().numpy()
y_true = targets[0][:, -1, 0].cpu().numpy()

In [None]:
# Support-weighted averages across classes. zero_division=0 scores classes
# with no predictions as 0 instead of warning.
_weighted = {'average': 'weighted', 'zero_division': 0}
acc = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred, **_weighted)
precision = precision_score(y_true, y_pred, **_weighted)
recall = recall_score(y_true, y_pred, **_weighted)

for _label, _value in (("Accuracy:", acc), ("F1 Score:", f1), ("Precision:", precision), ("Recall:", recall)):
    print(_label, _value)

In [None]:
# Fix the label order so every training class gets a row/column, even if absent.
cm = confusion_matrix(y_true, y_pred, labels=[*range(len(new_action_classes))])

_fig, _ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=new_action_classes, yticklabels=new_action_classes, ax=_ax)
_ax.set_xlabel("Predicted")
_ax.set_ylabel("Ground Truth")
_ax.set_title("Confusion Matrix")
plt.show()