Environment: `py39_torch271`

In [309]:
import os
import random
import itertools
from typing import Optional
from dataclasses import dataclass, asdict
from pathlib import Path

import torch
import numpy as np
from dacite import from_dict
from torch.utils.data import TensorDataset, DataLoader

from pyrutils.torch.train_utils import numpy_to_torch, train, save_checkpoint
from pyrutils.torch.multi_task import MultiTaskLossLearner
from vhoi.data_loading import (
    segmentation_from_output_class, 
    compute_centroid, 
    ignore_last_step_end_flag_general, 
    smooth_segmentation, 
    maybe_scale_input_tensors, 
    input_size_from_data_loader, 
    select_model_data_feeder, 
    select_model_data_fetcher,
)
from vhoi.losses import (
    select_loss, 
    decide_num_main_losses, 
    select_loss_types, 
    select_loss_learning_mask,
)
from vhoi.models import load_model_weights
from vhoi.models_custom import TGGCN_Custom

seed = 42
random.seed(seed)   # Python的随机性
os.environ['PYTHONHASHSEED'] = str(seed)    # 设置Python哈希种子，为了禁止hash随机化，使得实验可复现
np.random.seed(seed)   # numpy的随机性
torch.manual_seed(seed)   # torch的CPU随机性，为CPU设置随机种子
torch.cuda.manual_seed(seed)   # torch的GPU随机性，为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.   torch的GPU随机性，为所有GPU设置随机种子
torch.backends.cudnn.benchmark = False   # if benchmark=True, deterministic will be False
torch.backends.cudnn.deterministic = True   # 选择确定性算法

In [310]:
class DictMixin:
    def get(self, key, default_value=None):
        return getattr(self, key, default_value)

    def as_dict(self):
        return asdict(self)
    
@dataclass
class Resources(DictMixin):
    use_gpu: bool
    num_threads: int

@dataclass
class ModelMetadata(DictMixin):
    model_name: str
    input_type: str

@dataclass
class ModelParameters(DictMixin):
    add_segment_length: int
    add_time_position: int
    time_position_strategy: str
    positional_encoding_style: str
    attention_style: str
    bias: bool
    cat_level_states: int
    discrete_networks_num_layers: int
    discrete_optimization_strategy: str
    filter_discrete_updates: bool
    gcn_node: int
    hidden_size: int
    message_humans_to_human: bool
    message_human_to_objects: bool
    message_objects_to_human: bool
    message_objects_to_object: bool
    message_geometry_to_objects: bool
    message_geometry_to_human: bool
    message_segment: bool
    message_type: str
    message_granularity: str
    message_aggregation: str
    object_segment_update_strategy: str
    share_level_mlps: int
    update_segment_threshold: float

@dataclass
class ModelOptimization(DictMixin):
    batch_size: int
    clip_gradient_at: float
    epochs: int
    learning_rate: float
    val_fraction: float

@dataclass
class BudgetLoss(DictMixin):
    add: bool
    human_weight: float
    object_weight: float

@dataclass
class SegmentationLoss(DictMixin):
    add: bool
    pretrain: bool
    sigma: float
    weight: float

@dataclass
class ModelMisc(DictMixin):
    anticipation_loss_weight: float
    budget_loss: BudgetLoss
    first_level_loss_weight: float
    impose_segmentation_pattern: int
    input_human_segmentation: bool
    input_object_segmentation: bool
    make_attention_distance_based: bool
    multi_task_loss_learner: bool
    pretrained: bool
    pretrained_path: Optional[str]
    segmentation_loss: SegmentationLoss

@dataclass
class ModelLogging(DictMixin):
    root_log_dir: str
    checkpoint_name: str
    log_dir: str

@dataclass
class Models(DictMixin):
    metadata: ModelMetadata
    parameters: ModelParameters
    optimization: ModelOptimization
    misc: ModelMisc
    logging: ModelLogging

@dataclass
class Data(DictMixin):
    name: str
    path: str
    path_zarr: str
    path_obb_zarr: str
    path_hbb_zarr: str
    path_hps_zarr: str
    cross_validation_test_subject: str
    scaling_strategy: Optional[str]
    downsampling: int

@dataclass
class Config(DictMixin):
    resources: Resources
    models: Models
    data: Data
    
metadata_dict = {
    "model_name": "2G-GCN",
    "input_type": "multiple"
}

parameters_dict = {
    "add_segment_length": 0,  # length of the segment to the segment-level rnn. 0 is off and 1 is on.
    "add_time_position": 0,  # absolute time position to the segment-level rnn. 0 is off and 1 is on.
    "time_position_strategy": "s",  # input time position to segment [s] or discrete update [u].
    "positional_encoding_style": "e",  # e [embedding] or p [periodic].
    "attention_style": "v3",  # v1 [concat], v2 [dot-product], v3 [scaled_dot-product], v4 [general]
    "bias": True,
    "cat_level_states": 0,  # concatenate first and second level hidden states for predictors MLPs.
    "discrete_networks_num_layers": 1,  # depth of the state change detector MLP.
    "discrete_optimization_strategy": "gs",  # straight-through [st] or gumbel-sigmoid [gs]
    "filter_discrete_updates": False,  # maxima filter for soft output of state change detector.
    "gcn_node": 25,  # custom, original: 19 for cad120, 30 for bimanual, 26 for mphoi
    "hidden_size": 512,  # 512 for cad120 & mphoi; 64 for bimanual
    "message_humans_to_human": False, # custom, original: True
    "message_human_to_objects": True,
    "message_objects_to_human": True,
    "message_objects_to_object": False, # custom, original: True
    "message_geometry_to_objects": True,
    "message_geometry_to_human": True,  # custom, original: False
    "message_segment": True,
    "message_type": "v2",  # v1 [relational] or v2 [non-relational]
    "message_granularity": "v1",  # v1 [generic] or v2 [specific]
    "message_aggregation": "att",  # mean_pooling [mp] or attention [att]
    "object_segment_update_strategy": "ind",  # same_as_human [sah], independent [ind], or conditional_on_human [coh]
    "share_level_mlps": 0,  # whether to share [1] or not [0] the prediction MLPs of the levels.
    "update_segment_threshold": 0.5  # [0.0, 1.0)
}

optimization_dict = {
    "batch_size": 8,  # mphoi:8; cad120:16; bimanual: 32
    "clip_gradient_at": 0.0,
    "epochs": 10, # custom, original: cad120 & mphoi:40; bimanual: 60
    "learning_rate": 1e-4,  # mphoi:1e-4; cad120 & bimanual:1e-3
    "val_fraction": 0.1
}

data_dict = {
    "name": "mphoi",
    "path": f"{os.getcwd()}/data/MPHOI/MPHOI/mphoi_ground_truth_labels.json",
    "path_zarr": f"{os.getcwd()}/data/MPHOI/MPHOI/mphoi_derived_features/faster_rcnn.zarr",
    "path_obb_zarr": f"{os.getcwd()}/data/MPHOI/MPHOI/mphoi_derived_features/object_bounding_boxes.zarr",
    "path_hbb_zarr": f"{os.getcwd()}/data/MPHOI/MPHOI/mphoi_derived_features/human_bounding_boxes.zarr",
    "path_hps_zarr": f"{os.getcwd()}/data/MPHOI/MPHOI/mphoi_derived_features/human_pose.zarr",
    "cross_validation_test_subject": "Subject14",  # Subject45, Subject25, Subject14
    "scaling_strategy": None,  # null or "standard"
    "downsampling": 1 # custom, original: 3, 1 = full FPS, 2 = half FPS, ...
}

# root_log_dir = f"{os.getcwd()}/outputs_hiergat/{data_dict['name']}/{metadata_dict['model_name']}"
root_log_dir = f"{os.getcwd()}/outputs_hiergat/custom"
checkpoint_name = (
    f"hs{parameters_dict['hidden_size']}_e{optimization_dict['epochs']}_bs{optimization_dict['batch_size']}_"
    f"lr{optimization_dict['learning_rate']}_{parameters_dict['update_segment_threshold']}_{data_dict['cross_validation_test_subject']}"
)
log_dir = f"{root_log_dir}/{checkpoint_name}"
os.makedirs(log_dir, exist_ok=True)

cfg_dict = {
    "resources": {
        "use_gpu": True,
        "num_threads": 32
    },
    "models": {
        "metadata": metadata_dict,
        "parameters": parameters_dict,
        "optimization": optimization_dict,
        "misc": {
            "anticipation_loss_weight": 1.0,
            "budget_loss": {
                "add": False,
                "human_weight": 1.0,
                "object_weight": 1.0
            },
            "first_level_loss_weight": 0.0,  # if positive, first level does frame-level prediction
            "impose_segmentation_pattern": 1,  # 0 [no pattern], 1 [all ones]
            "input_human_segmentation": False,  # (was "flase" in YAML, corrected here)
            "input_object_segmentation": False,
            "make_attention_distance_based": True,  # only meaningful if message_aggregation is attention
            "multi_task_loss_learner": False,
            "pretrained": False,  # unfortunately need two entries for checkpoint name
            "pretrained_path": None,  # specified parameters must match pre-trained model
            "segmentation_loss": {
                "add": False,
                "pretrain": False,
                "sigma": 0.0,  # Gaussian smoothing
                "weight": 1.0
            }
        },
        "logging": {
            "root_log_dir": root_log_dir,
            "checkpoint_name": checkpoint_name,
            "log_dir": log_dir
        },
    },
    "data": data_dict,
}

cfg = from_dict(data_class=Config, data=cfg_dict)

torch.set_num_threads(cfg.resources.num_threads)
model_name, model_input_type = cfg.models.metadata.model_name, cfg.models.metadata.input_type
batch_size, val_fraction = cfg.models.optimization.batch_size, cfg.models.optimization.val_fraction
misc_dict = cfg.get('misc', default_value={})
sigma = misc_dict.get('segmentation_loss', {}).get('sigma', 0.0)
scaling_strategy = cfg.data.scaling_strategy
downsampling = cfg.data.downsampling

num_classes = 18
features_dir = Path('/root/vs-gats-plaster/deepsort/outputs/anno_test_8/features')

In [311]:
def create_data(feature_dirs, downsampling: int = 1):
    human_features_list = []
    human_boxes_list = []
    human_poses_list = []
    object_features_list = []
    object_boxes_list = []
    gt_list = []
    xs_steps = []

    for feature_dir in feature_dirs:
        # Load and store human (subject) features
        subject_visual_features = np.load(feature_dir / 'subject_visual_features.npy')
        subject_boxes = np.load(feature_dir / 'subject_boxes.npy')
        subject_poses = np.zeros((subject_visual_features.shape[0], 17, 2))
        
        human_features_list.append([subject_visual_features])
        human_boxes_list.append([subject_boxes])
        human_poses_list.append([subject_poses])
        
        # Load and store object features
        object_visual_features = np.load(feature_dir / 'object_visual_features.npy')
        object_boxes = np.load(feature_dir / 'object_boxes.npy')
        
        object_features_list.append(object_visual_features[:, np.newaxis, :])
        object_boxes_list.append(object_boxes[:, np.newaxis, :])
        
        # Extract and store ground-truth action label
        action_label = int(str(feature_dir).split('_action_')[-1])
        seq_len = subject_visual_features.shape[0]
        gt_list.append({
            'Human1': [action_label] * seq_len,
        })
        
        # Store number of steps
        num_steps = len(subject_visual_features[downsampling - 1::downsampling])
        xs_steps.append(num_steps)

    xs_steps = np.array(xs_steps, dtype=np.float32)

    return (
        human_features_list,
        human_boxes_list,
        human_poses_list,
        object_features_list,
        object_boxes_list,
        gt_list,
        xs_steps,
    )

def split_list(lst, ratio=0.2):
    n = len(lst)
    split_idx = int(n * ratio) # number of items in first list
    list1 = lst[split_idx:]
    list2 = lst[:split_idx]
    return list1, list2

In [312]:
def assemble_mphoi_frame_level_recurrent_human(
    human_features_list, human_poses_list, object_boxes_list, gt_list,
    downsampling: int = 1, 
    # test_data: bool = False, 
    max_no_objects: int = 4
):
    xs_h, xs_hp, x_obb = [], [], []
    max_len, max_len_downsampled = 0, 0

    if max_no_objects is None:
        max_no_objects = max(
            max(len(frame) for frame in video) for video in object_boxes_list
        )

    for humans, poses, objects_bounding_box in zip(human_features_list, human_poses_list, object_boxes_list):
        num_humans = len(humans)
        max_len = max(max_len, humans[0].shape[0])

        humans_ds = [h[downsampling - 1::downsampling] for h in humans]
        poses_ds  = [p[downsampling - 1::downsampling] / 1000 for p in poses]
        max_len_downsampled = max(
            max_len_downsampled,
            max(h.shape[0] for h in humans_ds),
            max(p.shape[0] for p in poses_ds),
            objects_bounding_box[downsampling - 1::downsampling].shape[0],
        )
        xs_h.append(humans_ds)
        xs_hp.append(poses_ds)

        obb_ds = objects_bounding_box[downsampling - 1::downsampling] / 1000
        x_obb.append(obb_ds)

    xs_obb = []
    for video in x_obb:
        bb = []
        for frame in video:
            b = np.zeros((max_no_objects, 4))
            n = min(len(frame), max_no_objects)
            b[:n] = frame[:n]
            b = b.reshape(max_no_objects * 2, 2)
            bb.append(b)
        xs_obb.append(bb)

    # keypoints = [1, 2, 4, 6, 7, 11, 13, 14, 27]
    keypoints = list(range(human_poses_list[0][0].shape[1]))
    xs_h_with_context = []
    for i, (humans_ds, poses_ds, obb_video) in enumerate(zip(xs_h, xs_hp, xs_obb)):
        num_humans = len(humans_ds)
        humans_context = [[] for _ in range(num_humans)]

        for j in range(len(humans_ds[0])):
            obb = obb_video[j]

            if j + 1 < len(humans_ds[0]):
                next_poses = [p[j+1][keypoints] for p in poses_ds]
                pose_velos = [(next_pose - poses_ds[h][j][keypoints]) * 100 for h, next_pose in enumerate(next_poses)]
                obb_velo = (obb_video[j+1] - obb) * 100
            else:
                pose_velos = [np.zeros((len(keypoints), 2)) for _ in poses_ds]
                obb_velo = np.zeros((max_no_objects * 2, 2))

            obbvelo = np.hstack((obb, obb_velo)).reshape(1, -1)

            context = []
            for h in range(num_humans):
                pose = poses_ds[h][j][keypoints]
                velo = pose_velos[h]
                posevelo = np.hstack((pose, velo)).reshape(1, -1)
                context.append(posevelo[0])

            context = np.concatenate(context + [obbvelo[0]])

            for h in range(num_humans):
                h_con = np.concatenate((humans_ds[h][j], context))
                humans_context[h].append(h_con)

        xs_h_with_context.append([np.array(hc) for hc in humans_context])

    feature_size = xs_h_with_context[0][0].shape[-1]
    num_humans = len(xs_h_with_context[0])
    x_hs = np.full([len(xs_h_with_context), max_len_downsampled, num_humans, feature_size],
                   fill_value=np.nan, dtype=np.float32)

    for m, humans in enumerate(xs_h_with_context):
        for h, feats in enumerate(humans):
            seq_len = min(len(feats), max_len_downsampled)
            x_hs[m, :seq_len, h] = feats[:seq_len]

    xs = [x_hs]

    # ----------------------
    # Outputs
    # ----------------------
    y_rec_hs = np.full([len(x_hs), max_len_downsampled, num_humans], fill_value=-1, dtype=np.int64)
    y_pred_hs = np.full_like(y_rec_hs, fill_value=-1)
    
    for m, video_hands_ground_truth in enumerate(gt_list):
        for h in range(num_humans):
            human_key = f"Human{h+1}"
            if human_key not in video_hands_ground_truth:
                continue

            y_h = video_hands_ground_truth[human_key]

            # Ground truth (downsampled)
            y_h_ds = y_h[downsampling - 1::downsampling]
            seq_len = min(len(y_h_ds), y_rec_hs.shape[1])
            y_rec_hs[m, :seq_len, h] = y_h_ds[:seq_len]

            # Prediction: shift labels forward
            y_h_p = np.roll(y_h, -1)
            y_h_p[-1] = -1  # last frame has no "next"
            y_h_p_ds = y_h_p[downsampling - 1::downsampling]
            seq_len_p = min(len(y_h_p_ds), y_pred_hs.shape[1])
            y_pred_hs[m, :seq_len_p, h] = y_h_p_ds[:seq_len_p]
            # y_pred_hs[m, :seq_len, h] = y_h_ds[:seq_len]
            
    x_hs_segmentation = segmentation_from_output_class(y_rec_hs, segmentation_type="input")
    xs.append(x_hs_segmentation)

    y_hs_segmentation = segmentation_from_output_class(y_rec_hs, segmentation_type="output")
    ys = [y_rec_hs, y_pred_hs, y_hs_segmentation]

    return xs, ys

# xs, ys = assemble_mphoi_frame_level_recurrent_human(human_features_list, human_poses_list, object_boxes_list, gt_list)

# for ys_i, xs_i in zip(ys, xs):
#     print("ys_i.shape:", ys_i.shape)
#     print("ys_i[0]:", ys_i[0].flatten())
#     print("ys_i[1]:", ys_i[1].flatten())
#     print("ys_i[2]:", ys_i[2].flatten())
#     print()
#     print("xs_i:", xs_i.shape)
#     print()

In [313]:
def assemble_mphoi_frame_level_recurrent_objects(object_features_list, downsampling: int = 1):
    xs_objects = []
    max_len, max_len_downsampled, max_num_objects = 0, 0, 0
    for objects in object_features_list:
        max_len = max(max_len, objects.shape[0])
        max_num_objects = max(max_num_objects, objects.shape[1])
        objects = objects[downsampling - 1::downsampling]
        max_len_downsampled = max(max_len_downsampled, objects.shape[0])
        xs_objects.append(objects)
    feature_size = xs_objects[-1].shape[-1]
    x_objects = np.full([len(xs_objects), max_len_downsampled, max_num_objects, feature_size],
                        fill_value=np.nan, dtype=np.float32)
    x_objects_mask = np.zeros([len(xs_objects), max_num_objects], dtype=np.float32)
    for m, x_o in enumerate(xs_objects):
        x_objects[m, :x_o.shape[0], :x_o.shape[1], :] = x_o
        x_objects_mask[m, :x_o.shape[1]] = 1.0
    xs = [x_objects, x_objects_mask]
    return xs

# xs_objects = assemble_mphoi_frame_level_recurrent_objects(object_features_list)
# for xs_i in xs_objects:
#     print(xs_i.shape)

In [314]:
def assemble_mphoi_human_human_distances(human_boxes_list, downsampling: int = 1):
    """
    Compute pairwise human-human distances for multiple humans across videos.

    Args:
        human_boxes_list: list of list of human bounding boxes per video
                          (outer list: videos, inner list: humans, array: frames x 4)
        downsampling:     frame downsampling factor

    Returns:
        x_hh_dists: tensor [num_videos, max_len, N, N] with pairwise distances
    """
    mphoi_dims = np.array([3840, 2160], dtype=np.float32)
    max_len, max_num_humans = 0, 0
    all_dists = []

    for video_bbs in human_boxes_list:
        num_humans = len(video_bbs)
        max_num_humans = max(max_num_humans, num_humans)

        # Downsample and compute centroids
        centroids = []
        for bb in video_bbs:
            bb = bb[downsampling - 1::downsampling]
            c = compute_centroid(bb) / mphoi_dims
            centroids.append(c)

        # Length of this video (frames)
        max_len = max(max_len, centroids[0].shape[0])

        # Compute pairwise distances (frames x N x N)
        T = centroids[0].shape[0]
        dists_matrix = np.zeros((T, num_humans, num_humans), dtype=np.float32)

        for i, j in itertools.combinations(range(num_humans), 2):
            d = np.linalg.norm(centroids[i] - centroids[j], ord=2, axis=-1)
            dists_matrix[:, i, j] = d
            dists_matrix[:, j, i] = d

        all_dists.append(dists_matrix)

    # Pad into a tensor [num_videos, max_len, max_num_humans, max_num_humans]
    tensor_shape = [len(all_dists), max_len, max_num_humans, max_num_humans]
    x_hh_dists = np.full(tensor_shape, fill_value=np.nan, dtype=np.float32)

    for m, dists_matrix in enumerate(all_dists):
        T, N, _ = dists_matrix.shape
        x_hh_dists[m, :T, :N, :N] = dists_matrix

    return x_hh_dists

# xs_hh_dists = assemble_mphoi_human_human_distances(human_boxes_list)
# print(xs_hh_dists.shape)

In [315]:
def assemble_mphoi_human_object_distances(human_boxes_list, object_boxes_list, downsampling: int = 1):
    """
    Compute human-object distances for multiple humans and objects across videos.

    Args:
        human_boxes_list:  list of list of human bounding boxes per video
                           (outer list: videos, inner list: humans, array: frames x 4)
        object_boxes_list: list of object bounding box arrays per video (frames x num_objects x 4)
        downsampling:      frame downsampling factor

    Returns:
        x_ho_dists: tensor [num_videos, max_len, max_num_humans, max_num_objects]
    """
    mphoi_dims = np.array([3840, 2160], dtype=np.float32)
    max_len, max_num_humans, max_num_objects = 0, 0, 0
    all_dists = []

    for video_bbs, obj_bbs in zip(human_boxes_list, object_boxes_list):
        num_humans = len(video_bbs)

        # Downsample humans → centroids
        human_centroids = []
        for bb in video_bbs:
            bb = bb[downsampling - 1::downsampling]
            c = compute_centroid(bb) / mphoi_dims
            human_centroids.append(c)

        # Downsample objects → centroids
        obj_bbs = obj_bbs[downsampling - 1::downsampling]
        obj_centroids = compute_centroid(obj_bbs) / mphoi_dims

        T = obj_centroids.shape[0]
        max_len = max(max_len, T)
        max_num_humans = max(max_num_humans, num_humans)
        max_num_objects = max(max_num_objects, obj_centroids.shape[1])

        # Compute distances [frames, num_humans, num_objects]
        dists_matrix = np.zeros((T, num_humans, obj_centroids.shape[1]), dtype=np.float32)
        for h, h_c in enumerate(human_centroids):
            d = np.linalg.norm(obj_centroids - np.expand_dims(h_c, axis=1), ord=2, axis=-1)
            dists_matrix[:, h, :] = d

        all_dists.append(dists_matrix)

    # Pad into a tensor [num_videos, max_len, max_num_humans, max_num_objects]
    tensor_shape = [len(all_dists), max_len, max_num_humans, max_num_objects]
    x_ho_dists = np.full(tensor_shape, fill_value=np.nan, dtype=np.float32)

    for m, dists_matrix in enumerate(all_dists):
        T, H, O = dists_matrix.shape
        x_ho_dists[m, :T, :H, :O] = dists_matrix

    return x_ho_dists

# xs_ho_dists = assemble_mphoi_human_object_distances(human_boxes_list, object_boxes_list)
# print(xs_ho_dists.shape)

In [316]:
def assemble_mphoi_object_object_distances(object_boxes_list, downsampling: int = 1):
    """
    Compute pairwise object-object distances across videos.

    Args:
        object_boxes_list: list of object bounding box arrays per video (frames x num_objects x 4)
        downsampling:      frame downsampling factor

    Returns:
        x_oo_dists: tensor [num_videos, max_len, max_num_objects, max_num_objects]
    """
    mphoi_dims = np.array([3840, 2160], dtype=np.float32)
    max_len, max_num_objects = 0, 0
    all_dists = []

    for obj_bbs in object_boxes_list:
        # Downsample and compute centroids
        obj_bbs = obj_bbs[downsampling - 1::downsampling]
        objs_centroid = compute_centroid(obj_bbs) / mphoi_dims   # (frames, num_objects, 2)
        num_objects = objs_centroid.shape[1]

        # Compute pairwise distances per frame
        dists = []
        for k in range(num_objects):
            kth_object_centroid = objs_centroid[:, k:k+1]  # (frames, 1, 2)
            kth_dist = np.linalg.norm(objs_centroid - kth_object_centroid, ord=2, axis=-1)  # (frames, num_objects)
            dists.append(kth_dist)

        dists = np.stack(dists, axis=1)  # (frames, num_objects, num_objects)
        all_dists.append(dists)

        max_len = max(max_len, obj_bbs.shape[0])
        max_num_objects = max(max_num_objects, num_objects)

    # Pad into tensor [num_videos, max_len, max_num_objects, max_num_objects]
    tensor_shape = [len(all_dists), max_len, max_num_objects, max_num_objects]
    x_oo_dists = np.full(tensor_shape, fill_value=np.nan, dtype=np.float32)

    for m, dists in enumerate(all_dists):
        T, O1, O2 = dists.shape
        x_oo_dists[m, :T, :O1, :O2] = dists

    return x_oo_dists

# xs_oo_dists = assemble_mphoi_object_object_distances(object_boxes_list)
# print(xs_oo_dists.shape)

In [317]:
def assemble_mphoi_tensors(
    human_features_list,
    human_boxes_list,
    human_poses_list,
    object_features_list,
    object_boxes_list,
    gt_list,
    xs_steps,
    model_name: str, 
    sigma: float = 0.0, 
    downsampling: int = 1,
):
    xs, ys = assemble_mphoi_frame_level_recurrent_human(human_features_list, human_poses_list, object_boxes_list, gt_list)
    xs_objects = assemble_mphoi_frame_level_recurrent_objects(object_features_list, downsampling=downsampling)
    if model_name == '2G-GCN':
        if sigma:
            ys[2] = ignore_last_step_end_flag_general(ys[2])
        ys[2] = smooth_segmentation(ys[2], sigma)
        ys_budget = ys[2]
        xs_hh_dists = assemble_mphoi_human_human_distances(human_boxes_list, downsampling=downsampling)
        xs_ho_dists = assemble_mphoi_human_object_distances(human_boxes_list, object_boxes_list, downsampling=downsampling)
        xs_oo_dists = assemble_mphoi_object_object_distances(object_boxes_list, downsampling=downsampling)
        xs = xs[:1] + xs_objects + xs[1:] + [xs_hh_dists, xs_ho_dists, xs_oo_dists, xs_steps]
        ys = [ys_budget] + ys[2:] + ys[:2]
        ys += ys[-2:]
    else:
        raise ValueError(f'MPHOI code not implemented for {model_name} yet.')
    return xs, ys

In [318]:
def create_data_loader(
    human_features_list,
    human_boxes_list,
    human_poses_list,
    object_features_list,
    object_boxes_list,
    gt_list,
    xs_steps,
    model_name: str, 
    batch_size: int, 
    shuffle: bool,
    scaling_strategy: Optional[str] = None, 
    scalers: Optional[dict] = None, 
    sigma: float = 0.0,
    downsampling: int = 1, 
):
    x, y = assemble_mphoi_tensors(
        human_features_list,
        human_boxes_list,
        human_poses_list,
        object_features_list,
        object_boxes_list,
        gt_list,
        xs_steps,
        model_name=model_name, 
        sigma=sigma, 
        downsampling=downsampling, 
    )
    
    x, scalers = maybe_scale_input_tensors(x, model_name, scaling_strategy=scaling_strategy, scalers=scalers)
    x = [np.nan_to_num(ix, copy=False, nan=0.0) for ix in x]
    x, y = numpy_to_torch(*x), numpy_to_torch(*y)
    dataset = TensorDataset(*(x + y))
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=0,
                             pin_memory=False, drop_last=False)
    segmentations = None
    return data_loader, scalers, segmentations

In [319]:
# Data
feature_dirs = [i for i in features_dir.iterdir()]
train_feature_dirs, val_feature_dirs = split_list(feature_dirs, ratio=0.2)

train_data = create_data(train_feature_dirs)
val_data = create_data(val_feature_dirs)

train_loader, scalers, _ = create_data_loader(
    *train_data, 
    model_name, 
    batch_size=batch_size, 
    shuffle=True,
    scaling_strategy=scaling_strategy, 
    sigma=sigma,
    downsampling=downsampling,
)
val_loader, _, _ = create_data_loader(
    *val_data, 
    model_name, 
    batch_size=len(val_data[0]),
    shuffle=False, 
    scalers=scalers, 
    sigma=sigma, 
    downsampling=downsampling,
)
input_size = input_size_from_data_loader(train_loader, model_name, model_input_type)
data_info = {'input_size': input_size}

In [320]:
# Model
model_creation_args = cfg.models.parameters
model_creation_args = {**data_info, **model_creation_args.__dict__}
dataset_name = cfg.data.name
model_creation_args['num_classes'] = (num_classes, None)
device = 'cuda' if torch.cuda.is_available() and cfg.resources.use_gpu else 'cpu'
model = TGGCN_Custom(feat_dim=1024, **model_creation_args).to(device)
if misc_dict.get('pretrained', False) and misc_dict.get('pretrained_path') is not None:
    state_dict = load_model_weights(misc_dict['pretrained_path'])
    model.load_state_dict(state_dict, strict=False)
params = model.parameters()
optimizer = torch.optim.Adam(params, lr=cfg.models.optimization.learning_rate)
criterion, loss_names = select_loss(model_name, model_input_type, dataset_name, cfg=cfg)
mtll_model = None
if misc_dict.get('multi_task_loss_learner', False):
    loss_types = select_loss_types(model_name, dataset_name, cfg=cfg)
    mask = select_loss_learning_mask(model_name, dataset_name, cfg=cfg)
    mtll_model = MultiTaskLossLearner(loss_types=loss_types, mask=mask).to(device)
    optimizer.add_param_group({'params': mtll_model.parameters()})
# Some config + model training
tensorboard_log_dir = cfg.models.logging.root_log_dir
checkpoint_name = cfg.models.logging.checkpoint_name
fetch_model_data = select_model_data_fetcher(model_name, model_input_type,
                                             dataset_name=dataset_name, **{**misc_dict, **cfg.models.parameters.__dict__})
feed_model_data = select_model_data_feeder(model_name, model_input_type, dataset_name=dataset_name, **misc_dict)
num_main_losses = decide_num_main_losses(model_name, dataset_name, {**misc_dict, **cfg.models.parameters.__dict__})
checkpoint = train(
    model, 
    train_loader, 
    optimizer, 
    criterion, 
    cfg.models.optimization.epochs, 
    device, 
    loss_names,
    clip_gradient_at=cfg.models.optimization.clip_gradient_at,
    fetch_model_data=fetch_model_data, feed_model_data=feed_model_data,
    val_loader=val_loader, 
    mtll_model=mtll_model, 
    num_main_losses=num_main_losses,
    tensorboard_log_dir=tensorboard_log_dir, 
    checkpoint_name=checkpoint_name,
)
# Logging
if cfg.models.logging.log_dir is not None:
    log_dir = cfg.models.logging.log_dir
    checkpoint['scalers'] = scalers
    save_checkpoint(log_dir, checkpoint, checkpoint_name=checkpoint_name, include_timestamp=False)


Epoch: [   1/  10]


     (Train) Loss:  3.3070   B_HS:  0.0000   BCE_HS:  0.0000   NLL_SAR_F:  0.0000   NLL_SAP_F:  0.0000   NLL_SAR:  1.6399   NLL_SAP:  1.6671
(Validation) Loss:  4.4901   B_HS:  0.0000   BCE_HS:  0.0000   NLL_SAR_F:  0.0000   NLL_SAP_F:  0.0000   NLL_SAR:  2.2077   NLL_SAP:  2.2824

Epoch: [   2/  10]
     (Train) Loss:  3.1222   B_HS:  0.0000   BCE_HS:  0.0000   NLL_SAR_F:  0.0000   NLL_SAP_F:  0.0000   NLL_SAR:  1.5560   NLL_SAP:  1.5662
(Validation) Loss:  4.3519   B_HS:  0.0000   BCE_HS:  0.0000   NLL_SAR_F:  0.0000   NLL_SAP_F:  0.0000   NLL_SAR:  2.1906   NLL_SAP:  2.1614

Epoch: [   3/  10]
     (Train) Loss:  3.0030   B_HS:  0.0000   BCE_HS:  0.0000   NLL_SAR_F:  0.0000   NLL_SAP_F:  0.0000   NLL_SAR:  1.4934   NLL_SAP:  1.5096
(Validation) Loss:  3.8022   B_HS:  0.0000   BCE_HS:  0.0000   NLL_SAR_F:  0.0000   NLL_SAP_F:  0.0000   NLL_SAR:  1.9248   NLL_SAP:  1.8774

Epoch: [   4/  10]
     (Train) Loss:  2.8529   B_HS:  0.0000   BCE_HS:  0.0000   NLL_SAR_F:  0.0000   NLL_SAP_F: