Replacement Metric Flow (to work around the lack of JAAD ground truth)
======================================================================

Our aim is to get a model that can be used to "fix" the 2D poses of the JAAD dataset. General idea:
1. JAAD has no ground-truth poses, only 2D poses extracted from the images with OpenPose, which are not very accurate.
2. We have datasets containing the 3D poses in various formats, namely SMPL (we use CMU and HumanEva from AMASS) and CARLA skeleton (we use data recorded from the simulator, denoted as CarlaRec dataset). We can project the 3D poses (joint positions) to 2D poses using the camera parameters and get a "good" ground truth.
3. We need to train a model using "good" datasets and then run the inference on JAAD and save the results.
4. Then, we need to train another model (with the same architecture), this time using only the results from the inference on JAAD as the input, and then run the inference on the "good" datasets.
5. Finally, we can compare the results from the second model with ground truth from "good" datasets and get a metric. If the model trained on "fixed" JAAD gets satisfactory results, we can assume that JAAD was correctly fixed and our model is good at generalizing/fixing the 2D poses.

Pick a model to use
-------------------

In [None]:
# Name of the movements model; passed to every run below as --movements_model_name.
model_name = 'LinearAE2D'
# Extra model-specific command-line arguments appended to every run (none here).
model_args = []

...and setup some common imports/data/training args as well:

In [None]:
from pedestrians_video_2_carla.modeling import main
import os
import glob

In [None]:
# Command-line arguments shared by every training and prediction run in this
# notebook; each run unpacks this list first and then appends its own options.
common_args = [
    # Flow, batch size and skeleton layout.
    "--flow=autoencoder",
    "--batch_size=256",
    "--input_nodes=CARLA_SKELETON",
    "--output_nodes=CARLA_SKELETON",
    # Space-separated options: the flag and its value are separate argv entries.
    "--loss_modes", "loc_2d",
    "--check_val_every_n_epoch=1",
    "--renderers", "none",
    # Hardware and clip settings.
    "--gpus=1",
    "--clip_length=15",
    "--clip_offset=15",
    # Logging, masking, scheduler and data loading.
    "--prefer_tensorboard",
    "--mask_missing_joints=true",
    "--disable_lr_scheduler",
    "--num_workers=4",
]

Train a model using CarlaRec + AMASS
---------------------------------------

In [None]:
# Arguments for training the first model with the JAADCarlaRecAMASS data module.
model_one_train = [
    *common_args,
    "--data_module_name=JAADCarlaRecAMASS",
    # Input corruption applied during training.
    "--missing_point_probability=0.1",
    "--noise=gaussian",
    "--noise_param=1.0",
    # NOTE(review): presumably one proportion per source in JAAD / CarlaRec /
    # AMASS order, so JAAD is excluded from this run — confirm against the
    # data module.
    "--train_proportions",
    "0",
    "0.5",
    "0.5",
    "--val_proportions",
    "0",
    "-1",
    "-1",
    "--limit_val_batches=10",
    "--seed=1",
    "--max_epochs=100",
    f"--movements_model_name={model_name}",
    *model_args,
]
# The value returned by main() is used later as this run's log directory
# (its 'checkpoints' subdirectory is searched for the trained model).
model_one_log_dir = main(model_one_train)

Get the trained model and run the inference on JAAD
---------------------------------------------------

In [None]:
# Locate the checkpoint written by the first training run. glob returns
# matches in arbitrary order, so fail clearly when there are none and pick
# the most recently modified file when there are several.
model_one_checkpoints_dir = os.path.join(model_one_log_dir, 'checkpoints')
model_one_checkpoints = glob.glob(os.path.join(model_one_checkpoints_dir, '*.*'))
if not model_one_checkpoints:
    raise FileNotFoundError(
        f"No checkpoint files found in {model_one_checkpoints_dir}"
    )
model_checkpoint = max(model_one_checkpoints, key=os.path.getmtime)

# Run the first model in predict mode on the JAADOpenPose train and val sets.
model_one_predict = [
    *common_args,
    "--mode=predict",
    "--data_module_name=JAADOpenPose",
    "--predict_sets",
    "train",
    "val",
    f"--ckpt_path={model_checkpoint}",
    "--seed=2",
    f"--movements_model_name={model_name}",
    *model_args,
]
main(model_one_predict)

Train model on the captured data
--------------------------------

The captured data is from JAAD, but it is in CARLA_SKELETON format. Therefore, we need to remember to force the Dataset to use CARLA_SKELETON format.

In [None]:
# Fixed location of the JAADOpenPose prediction subsets; the per-run directory
# below it is named after the first run's log directory.
jaad_predictions_root = '/outputs/JAADOpenPoseDataModulePredictions/subsets/598268da7cf7978df3eed284e07970c5'
run_one_name = os.path.basename(model_one_log_dir)
jaad_subsets_dir = f'{jaad_predictions_root}/{run_one_name}'

In [None]:
# Arguments for training the second model on the JAADOpenPose data module,
# reading its subsets from jaad_subsets_dir (the first model's predictions).
model_two_train = [
    *common_args,
    "--data_module_name=JAADOpenPose",
    # The captured predictions are in CARLA_SKELETON format, so force it.
    "--data_nodes=CARLA_SKELETON",
    # Same input corruption settings as the first training run.
    "--missing_point_probability=0.1",
    "--noise=gaussian",
    "--noise_param=1.0",
    "--seed=3",
    "--max_epochs=100",
    f"--movements_model_name={model_name}",
    f"--subsets_dir={jaad_subsets_dir}",
    *model_args,
]
# The value returned by main() is used later as this run's log directory.
model_two_log_dir = main(model_two_train)

Get the trained model and run the inference on CarlaRec + AMASS
---------------------------------------------------------------

In [None]:
# Locate the checkpoint written by the second training run. glob returns
# matches in arbitrary order, so fail clearly when there are none and pick
# the most recently modified file when there are several.
model_two_checkpoints_dir = os.path.join(model_two_log_dir, 'checkpoints')
model_two_checkpoints = glob.glob(os.path.join(model_two_checkpoints_dir, '*.*'))
if not model_two_checkpoints:
    raise FileNotFoundError(
        f"No checkpoint files found in {model_two_checkpoints_dir}"
    )
model_two_checkpoint = max(model_two_checkpoints, key=os.path.getmtime)

# Predict with the second model on the CarlaRecorded train and val sets.
model_two_predict_a = [
    *common_args,
    "--mode=predict",
    "--data_module_name=CarlaRecorded",
    "--predict_sets",
    "train",
    "val",
    f"--ckpt_path={model_two_checkpoint}",
    "--seed=4",
    f"--movements_model_name={model_name}",
    *model_args,
]
# Predict with the second model on the AMASS train and val sets.
model_two_predict_b = [
    *common_args,
    "--mode=predict",
    "--data_module_name=AMASS",
    "--predict_sets",
    "train",
    "val",
    f"--ckpt_path={model_two_checkpoint}",
    "--seed=5",
    f"--movements_model_name={model_name}",
    *model_args,
]

main(model_two_predict_a)
main(model_two_predict_b)

Calculate the target metrics
----------------------------

In [None]:
# NOTE(review): the run name is pinned to a specific earlier run. To evaluate
# the model trained in this session, use os.path.basename(model_two_log_dir)
# instead — confirm which one is intended.
run_two_name = 'relative-granite'

# Directory names under subsets/ shared by the ground-truth and prediction paths.
carla_rec_subsets_key = '7db72382d7a13dc69f8d4919228ea591'
amass_subsets_key = '136b27b5869bd98ec133f22e327f6ec4'

# Ground-truth subsets.
gt_carla_rec_subsets_dir = f'/outputs/CarlaRecordedDataModule/subsets/{carla_rec_subsets_key}'
gt_amass_subsets_dir = f'/outputs/AMASSDataModule/subsets/{amass_subsets_key}'

# Prediction subsets, stored per run name.
pred_carla_rec_subsets_dir = (
    f'/outputs/CarlaRecordedDataModulePredictions/subsets/{carla_rec_subsets_key}/{run_two_name}'
)
pred_amass_subsets_dir = (
    f'/outputs/AMASSDataModulePredictions/subsets/{amass_subsets_key}/{run_two_name}'
)

In [None]:
from pedestrians_video_2_carla.transforms.normalization import Normalizer
from pedestrians_video_2_carla.transforms.hips_neck_bbox_fallback import HipsNeckBBoxFallbackExtractor
from pedestrians_video_2_carla.data.carla.skeleton import CARLA_SKELETON
from pedestrians_video_2_carla.data.carla.carla_recorded_dataset import CarlaRecordedDataset
from pedestrians_video_2_carla.data.smpl.skeleton import SMPL_SKELETON
from pedestrians_video_2_carla.data.smpl.smpl_dataset import SMPLDataset

# Normalizer shared by every dataset that is read with CARLA_SKELETON data nodes.
carla_normalizer = Normalizer(HipsNeckBBoxFallbackExtractor(CARLA_SKELETON))

# Keyword arguments common to all four datasets built below.
common_kwargs = dict(
    input_nodes=CARLA_SKELETON,
    skip_metadata=True,
    transform=carla_normalizer,
)

# Only the validation subsets are evaluated.
gt_carla_rec_val = os.path.join(gt_carla_rec_subsets_dir, 'val.hdf5')
pred_carla_rec_val = os.path.join(pred_carla_rec_subsets_dir, 'val.hdf5')
gt_amass_val = os.path.join(gt_amass_subsets_dir, 'val.hdf5')
pred_amass_val = os.path.join(pred_amass_subsets_dir, 'val.hdf5')

gt_carla_rec = CarlaRecordedDataset(
    set_filepath=gt_carla_rec_val, data_nodes=CARLA_SKELETON, **common_kwargs
)
pred_carla_rec = CarlaRecordedDataset(
    set_filepath=pred_carla_rec_val, data_nodes=CARLA_SKELETON, **common_kwargs
)

# The AMASS ground truth is read with SMPL_SKELETON data nodes, so its
# normalizer is built for SMPL_SKELETON as well; the other kwargs are unchanged.
gt_amass_kwargs = dict(common_kwargs)
gt_amass_kwargs['transform'] = Normalizer(HipsNeckBBoxFallbackExtractor(SMPL_SKELETON))
gt_amass = SMPLDataset(
    set_filepath=gt_amass_val, data_nodes=SMPL_SKELETON, **gt_amass_kwargs
)

# NOTE(review): the AMASS predictions are read with CARLA_SKELETON data nodes,
# presumably because the model outputs CARLA_SKELETON — confirm.
pred_amass = SMPLDataset(
    set_filepath=pred_amass_val, data_nodes=CARLA_SKELETON, **common_kwargs
)

In [None]:
from pedestrians_video_2_carla.transforms.hips_neck import HipsNeckExtractor
from torchmetrics import MetricCollection
from torchmetrics import MeanSquaredError
from pedestrians_video_2_carla.metrics.multiinput_wrapper import MultiinputWrapper
from pedestrians_video_2_carla.metrics.pck import PCK

# Key handed to the metrics below (both MultiinputWrapper keys and PCK's `key`);
# presumably the name of the item entry that gets compared — confirm.
outputs_key = 'projection_2d_transformed'

def get_normalization_tensor(x):
    """Return the second element of ``HipsNeckExtractor.get_shift_scale(x)``.

    A fresh extractor configured for CARLA_SKELETON is created on each call.
    Judging by the method name, the returned element is presumably the scale;
    it is passed to PCK as its ``get_normalization_tensor`` callback.
    """
    extractor = HipsNeckExtractor(input_nodes=CARLA_SKELETON)
    shift_and_scale = extractor.get_shift_scale(x)
    return shift_and_scale[1]

# Skeleton and masking settings shared by every metric in the collection.
skeleton_kwargs = dict(
    input_nodes=CARLA_SKELETON,
    output_nodes=CARLA_SKELETON,
    mask_missing_joints=True,
)

# Mean squared error, wrapped so it reads `outputs_key` from both inputs.
mse_metric = MultiinputWrapper(
    MeanSquaredError(dist_sync_on_step=True),
    outputs_key,
    outputs_key,
    **skeleton_kwargs
)

# PCK at threshold 0.1, normalized via get_normalization_tensor (hips-neck extractor).
pck_hips_neck_01 = PCK(
    dist_sync_on_step=True,
    key=outputs_key,
    threshold=0.1,
    get_normalization_tensor=get_normalization_tensor,
    **skeleton_kwargs
)

# PCK at threshold 0.05 with no custom normalization callback
# (standard bbox normalization).
pck_bbox_005 = PCK(
    dist_sync_on_step=True,
    key=outputs_key,
    threshold=0.05,
    get_normalization_tensor=None,
    **skeleton_kwargs
)

metrics_collection = MetricCollection({
    'MSE': mse_metric,
    'PCKhn@01': pck_hips_neck_01,
    'PCK@005': pck_bbox_005,
})

In [None]:
import itertools
from tqdm.auto import tqdm

# Ground truth and predictions are paired purely by position. If a pair of
# datasets differs in length, zip() over the chained datasets would silently
# misalign the items (CarlaRec tail against AMASS head) or drop the remainder,
# so check the lengths up front.
for dataset_label, gt_dataset, pred_dataset in (
    ('CarlaRec', gt_carla_rec, pred_carla_rec),
    ('AMASS', gt_amass, pred_amass),
):
    if len(gt_dataset) != len(pred_dataset):
        raise ValueError(
            f'{dataset_label}: ground truth has {len(gt_dataset)} items '
            f'but predictions have {len(pred_dataset)}'
        )

for gt_item, pred_item in tqdm(
    zip(itertools.chain(gt_carla_rec, gt_amass), itertools.chain(pred_carla_rec, pred_amass)),
    total=len(gt_carla_rec) + len(gt_amass)
):
    # Element 1 of each dataset item is what the metrics consume.
    # NOTE(review): the ground truth is passed first and the prediction second;
    # confirm this matches the argument order the project metrics expect.
    metrics_collection.update(gt_item[1], pred_item[1])

results = metrics_collection.compute()
# Bare expression so the notebook cell displays the computed metrics.
results