In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import torch

CUDA_VISIBLE_DEVICES = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = CUDA_VISIBLE_DEVICES

gender_info = {
    "P1": 0, "P3": 0, "P4": 0, "P7": 0, "P8": 0, "P10": 0, "P13": 0, "P15": 0,
    "P2": 1, "P5": 1, "P6": 1, "P9": 1, "P11": 1, "P12": 1, "P14": 1, "P16": 1,
    "P17": 1, "P18": 1, "P19": 1, "P20": 1
}

from plot_paper_func import render_smplx_from_paramdict, create_smpl_model, render_smplx_pred_gt_side_by_side, ppt_hsb_to_rgba
from plot_loss import select_indices_all_actions


# Loading A small Demo dataset with RGB, Depth, RPC, RT and GT Mesh modalities...

In [2]:
from dataset.dataset_mmMesh2_vis import RF3DPoseDataset, ToTensor
import torchvision.transforms as transforms

dataset_vis = RF3DPoseDataset([], transform=ToTensor(), load_save=True, use_image=True, is_demo=True, cache_dir="../cached_data_test_vis/")

Dataset loaded Successfully from ../cached_data_test_vis/rf3dpose_all ...
Kept 5000 / 5000 samples after filtering.


In [None]:
from dataset.extra_util_dataloader_vis import plot_frames_for_gif_depth
for i in range(1):
    plot_frames_for_gif_depth(dataset_vis, combined_gif_path=f'vis_depth/combined_output_depth_{i+1}.gif', act_id=i+1, is_plot="show")

Generating 5000 frames starting from index 0 for combined GIF with 4 columns...
tensor([  2,   1, 100]) image: torch.Size([3, 480, 640]), depth: torch.Size([1, 480, 640]), radar_tensor: torch.Size([121, 111, 31]), radar_points: torch.Size([1000, 4])
Radar points range: -2.9427934 0.70828354 1.8797358 4.643846 -1.4849242 0.86974144
tensor([  2,   1, 101]) image: torch.Size([3, 480, 640]), depth: torch.Size([1, 480, 640]), radar_tensor: torch.Size([121, 111, 31]), radar_points: torch.Size([1000, 4])
Radar points range: -2.9427934 0.34054098 1.8939896 4.643846 -1.4849242 0.6788225
tensor([  2,   1, 102]) image: torch.Size([3, 480, 640]), depth: torch.Size([1, 480, 640]), radar_tensor: torch.Size([121, 111, 31]), radar_points: torch.Size([1000, 4])
Radar points range: -2.9427934 0.79327756 1.9431412 4.643846 -1.4778532 0.7636752
tensor([  2,   1, 103]) image: torch.Size([3, 480, 640]), depth: torch.Size([1, 480, 640]), radar_tensor: torch.Size([121, 111, 31]), radar_points: torch.Size([100

KeyboardInterrupt: 

# Loading Full dataset with only radar-based RPC, RT and GT Mesh modalities...

In [7]:
from dataset.dataset_mmMesh2 import RF3DPoseDataset, ToTensor
import torchvision.transforms as transforms

dataset_train = RF3DPoseDataset([], 
                                # Train-Test Specifics
                                split = "train",  
                                main_modality="rt", 
                                protocol_id="p1", 
                                split_id="s1", 
                                temporal_window=6,
                                # meta info
                                load_save=True, 
                                cache_dir="../mmDataset/MR-Mesh/",  
                                transform=transforms.Compose([ToTensor()]), 
                                )
dataset_test = RF3DPoseDataset([], 
                               # Train-Test Specifics
                               split = "test",  
                               main_modality="rt", 
                               protocol_id="p1", 
                               split_id="s1", 
                               temporal_window=6,
                               # meta info
                               load_save=True, 
                               cache_dir="../mmDataset/MR-Mesh/",  
                               transform=transforms.Compose([ToTensor()]), 
                               )

Dataset loaded Successfully from ../mmDataset/MR-Mesh/rf3dpose_all ...
Load indices from pre-saved ../mmDataset/MR-Mesh/rf3dpose_all/indeces.pkl.gz indices splits.
Loaded train Dataset with length 495512.
Unique sub in train: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
Unique act in train: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50]
Kept 495506 / 495512 samples after filtering.
Dataset loaded Successfully from ../mmDataset/MR-Mesh/rf3dpose_all ...
Load indices from pre-saved ../mmDataset/MR-Mesh/rf3dpose_all/indeces.pkl.gz indices splits.
Loaded test Dataset with length 131809.
Unique sub in test: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
Unique act in test: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,

In [6]:

def get_sample_by_indicator(indicator, dataset_test=dataset_test):
    arr = np.asarray(dataset_test.indicator_list)
    mask = np.all(arr == np.asarray(indicator), axis=1)
    idxs = np.where(mask)[0]
    if idxs.size == 0:
        raise ValueError(f"No match for indicator {indicator}")
    idx = int(idxs[0])
    sample = dataset_test[idx]  # get the sample
    return sample

get_sample_by_indicator((6, 6, 350), dataset_test=dataset_test).keys()


dict_keys(['rawImage_XYZ', 'vertices', 'bbbox', 'projected_vertices', 'parameter', 'calibration', 'indicator', 'radar_points', 'joints_root'])

# multimodality

In [9]:
import math

def _aa_to_R(aa):
    aa = np.asarray(aa, dtype=np.float32).reshape(3)
    theta = float(np.linalg.norm(aa))
    if theta < 1e-8:
        return np.eye(3, dtype=np.float32)
    k = aa / theta
    kx, ky, kz = k
    K = np.array([[0, -kz, ky],
                  [kz, 0, -kx],
                  [-ky, kx, 0]], dtype=np.float32)
    I = np.eye(3, dtype=np.float32)
    return I + math.sin(theta) * K + (1.0 - math.cos(theta)) * (K @ K)

def _geodesic_angle_aa(aa1, aa2):
    R1 = _aa_to_R(aa1)
    R2 = _aa_to_R(aa2)
    R = R1.T @ R2
    c = np.clip((np.trace(R) - 1.0) * 0.5, -1.0, 1.0)
    return float(np.arccos(c))

def _combined_loss(params_a, params_b, idx, w_trans=1.0, w_root=1.0, w_pose=0.1):
    # trans L2
    ta = params_a["trans"][idx]; tb = params_b["trans"][idx]
    loss_t = float(np.linalg.norm(ta - tb))
    # root_orient geodesic (axis-angle 3,)
    ra = params_a["root_orient"][idx] if "root_orient" in params_a else params_a["global_orient"][idx]
    rb = params_b["root_orient"][idx] if "root_orient" in params_b else params_b["global_orient"][idx]
    loss_r = _geodesic_angle_aa(ra, rb)
    # pose L2 (full pose vector)
    pa = params_a["pose"][idx]; pb = params_b["pose"][idx]
    loss_p = float(np.linalg.norm(pa - pb))
    return w_trans * loss_t + w_root * loss_r + w_pose * loss_p

def select_indices_all_actions_mm(data_rt, data_pc, data_depth, data_rt_rgb, data_rt_depth, data_rt_rpc, data_rgb, 
                               w_trans=1.0, w_root=0.0, w_pose=0.0, tol=1e-6, mix_rate=1.0):
    """
    For each action (indicators_test[:,1] in 1..50):
      - Calculate combined loss for all modalities.
      - Filter indices to ensure the loss ordering:
        loss_rt_depth < loss_rt_rgb < loss_rt_rpc ~ loss_depth < loss_rt < loss_rpc < loss_rgb
      - Sort indices based on total loss sum.
    """
    indicators = data_rt["indicators_test"]
    pred = data_rt["pred_params_test"]
    gt = data_rt["gt_params_test"]
    pred_pc = data_pc["pred_params_test"]

    # Additional modalities
    modalities = {
        "depth": data_depth["pred_params_test"],
        "rt_rgb": data_rt_rgb["pred_params_test"],
        "rt_depth": data_rt_depth["pred_params_test"],
        "rt_rpc": data_rt_rpc["pred_params_test"],
        "rgb": data_rgb["pred_params_test"],
        "rt": data_rt["pred_params_test"],
        "rpc": data_pc["pred_params_test"]
    }

    out = {}
    loss = {}
    for action_id in range(1, 51):
        idxs = np.where(indicators[:, 1] == action_id)[0]
        if idxs.size == 0:
            out[action_id] = None
            continue

        # Calculate losses for each modality
        losses = {name: np.array([_combined_loss(pred_mod, gt, i, w_trans, w_root, w_pose) for i in idxs], dtype=np.float64)
                  for name, pred_mod in modalities.items()}

        # Apply the filtering condition
        valid_idxs = []
        for i, idx in enumerate(idxs):
            loss_rt_depth = losses["rt_depth"][i]
            loss_rt_rgb = losses["rt_rgb"][i]
            loss_rt_rpc = losses["rt_rpc"][i]
            loss_depth = losses["depth"][i]
            loss_rt = losses["rt"][i]
            loss_rpc = losses["rpc"][i]
            loss_rgb = losses["rgb"][i]

            # Ensure the loss ordering condition
            if loss_rt_depth < loss_rt_rgb < loss_rt_rpc and \
               loss_rt_rpc <= loss_depth and \
               loss_depth < loss_rt < loss_rpc < loss_rgb:
                valid_idxs.append(idx)

        # If no valid indices, skip this action
        if not valid_idxs:
            out[action_id] = None
            continue

        valid_idxs = np.array(valid_idxs)

        # Calculate total loss sum for sorting
        total_loss = np.sum([losses[name][np.where(idxs == valid_idxs[:, None])[1]] for name in modalities.keys()], axis=0)

        # Sort indices based on total loss sum
        order = np.argsort(total_loss)  # Ascending order of total loss
        sorted_idxs = valid_idxs[order]

        out[action_id] = [int(i) for i in sorted_idxs]
        loss[action_id] = total_loss[order]
        print(f"Action {action_id}: {out[action_id][:10]} -> Loss: {loss[action_id][:10]}")

    return out, loss

# direct indicator

In [6]:
# import time
# for sorted_id in range(0,1):
#     # action_id = 33
#     # sorted_id = 47
#     camera_position = [3, -3.5, 1.]
    
#     indicator = [7, 34, 99]


#     # print(f"Action {action_id}: Selected indices: {selected_by_action[action_id][:10]}")
#     # print(f"Action {action_id}: Corresponding losses: {loss_by_actions[action_id][:10]}")
#     saved_folder = "."
#     os.makedirs(saved_folder, exist_ok=True)
#     gender = "male" if gender_info["P" + str(indicator[0])] == 1 else "female"
#     data_sample_gt = {}
#     try:
#         data_sample_all = get_sample_by_indicator(indicator, dataset_test=dataset_test)
#     except Exception as e:
#         print(f"Error processing test indicator {indicator}: {e}")
#         data_sample_all = get_sample_by_indicator(indicator, dataset_test=dataset_train)
    
#     print(data['pred_params_test'].keys(), data_sample_all["parameter"].keys())
#     data_sample_gt['betas'] = data_sample_all["parameter"]['betas']
#     data_sample_gt['global_orient'] = data_sample_all["parameter"]['root_orient']
#     data_sample_gt['pose'] = data_sample_all["parameter"]['pose_body']
#     data_sample_gt['trans'] = data_sample_all["parameter"]['trans']
#     data_sample_gt['gender'] = data_sample_all["parameter"]['gender']


    
#     data_sample_points = data_sample_all["radar_points"].reshape(-1, 4)[:,:3]
#     data_sample_joint = data_sample_all["joints_root"][0]
#     # Filter out points within the specified range
#     filter_range = .8
#     filtered_points = data_sample_points[
#         (data_sample_points[:, 0] >= data_sample_joint[0] - filter_range) &
#         (data_sample_points[:, 0] <= data_sample_joint[0] + filter_range) &
#         (data_sample_points[:, 1] >= data_sample_joint[1] - filter_range) &
#         (data_sample_points[:, 1] <= data_sample_joint[1] + filter_range)
#     ]

#     # img = render_smplx_from_paramdict(smpl_models[gender], data_sample, device="cuda")
#     ground_z = None
#     img, img_gt, img_gt = render_smplx_pred_gt_side_by_side(smpl_models[gender], data_sample_gt, data_sample_gt, device="cuda", pc=filtered_points, pc_color=ppt_hsb_to_rgba(0, 100, 100, 100), camera_position=camera_position, ground_z=ground_z, pc_size=5.0)

#     from PIL import Image
#     # Image.fromarray(img).save("smplx_render.png")
#     Image.fromarray(img_gt).save(os.path.join(saved_folder, f"smplx_render_gt.png"))
#     # time.sleep(3)

In [None]:
import copy
from plot_paper_func import render_smplx_frames_overlay



import time
for sorted_id in range(1):
    camera_position = [4., -4, 0.3]
    
    start = 95
    sub = 2
    indicator = [[sub, 50, start+3], [sub, 50, start- 5]]
    pc_range = [.5,.5,.5]
    
    


    # print(f"Action {action_id}: Selected indices: {selected_by_action[action_id][:10]}")
    # print(f"Action {action_id}: Corresponding losses: {loss_by_actions[action_id][:10]}")
    saved_folder = "."
    os.makedirs(saved_folder, exist_ok=True)
    i = 0
    
    
    data_sample_gt_arr = []
    filtered_points_arr = []
    for indi in indicator:
        try:
            data_sample_all = get_sample_by_indicator(indi, dataset_test=dataset_test)
        except Exception as e:
            print(f"Error processing test indicator {indi}: {e}")
            data_sample_all = get_sample_by_indicator(indi, dataset_test=dataset_train)
       

        data_sample_gt = {}
        gender = "male" if gender_info["P" + str(indi[0])] == 1 else "female"
        data_sample_gt['betas'] = data_sample_all["parameter"]['betas']
        data_sample_gt['global_orient'] = data_sample_all["parameter"]['root_orient']
        data_sample_gt['pose'] = data_sample_all["parameter"]['pose_body']
        data_sample_gt['trans'] = data_sample_all["parameter"]['trans']
        data_sample_gt['gender'] = data_sample_all["parameter"]['gender']
        
        
        data_sample_gt_arr.append(data_sample_gt)
        
        data_sample_points = data_sample_all["radar_points"].reshape(-1, 4)[:,:3]
        data_sample_joint = data_sample_all["joints_root"][0]
        # Filter out points within the specified range
        filter_range = pc_range[i]
        filtered_points = data_sample_points[
            (data_sample_points[:, 0] >= data_sample_joint[0] - filter_range) &
            (data_sample_points[:, 0] <= data_sample_joint[0] + filter_range) &
            (data_sample_points[:, 1] >= data_sample_joint[1] - filter_range) &
            (data_sample_points[:, 1] <= data_sample_joint[1] + filter_range)
        ]
        filtered_points_arr.append(copy.deepcopy(filtered_points))
        i+=1


    # img = render_smplx_from_paramdict(smpl_models[gender], data_sample, device="cuda")
    ground_z = 0.3
    img_gt = render_smplx_frames_overlay(
    smpl_models[gender],
    data_sample_gt_arr,
    [0, 1],
    device="cuda",
    camera_position=camera_position,
    base_color_hsb=[(209, 64, 85), (209, 64, 85), (209, 64, 85)],  # darker orange, darker green, darker blue
    min_alpha_pct=50,
    max_alpha_pct=100,
    pcs=filtered_points_arr,  # PCs have no shadow
    pc_colors=(0, 100, 100, 100),
    pc_size=4.0,
    ground_z=ground_z,
)


    from PIL import Image
    # Image.fromarray(img).save("smplx_render.png")
    # Image.fromarray(img_gt).save(os.path.join(saved_folder, f"smplx_render_gt{indicator}.png"))
    Image.fromarray(img_gt).save(os.path.join(f"smplx_render_gt.png"))

    # time.sleep(3)

## Search based

In [8]:
selected_by_action, loss_by_actions = select_indices_all_actions(data, data_pc, w_trans=10.0, w_root=1.0, w_pose=1.0, tol=1e-6, mix_rate=0.2)

[12372, 72132, 72116, 12371, 91975, 72117, 105555, 72115, 91969, 72118] [-0.01352544  0.13781323  0.16917648  0.20105727  0.22281634  0.22874631
  0.27319805  0.2993162   0.30658599  0.32930904]
[112475, 125657, 85368, 79, 125617, 125618, 125661, 45431, 125566, 72237] [0.25885371 0.38502101 0.3942021  0.40022091 0.41355357 0.42276499
 0.42576445 0.43009522 0.43488593 0.43856888]
[72427, 243, 112539, 112548, 112594, 242, 112547, 112508, 85451, 5954] [0.22550039 0.25757587 0.28833205 0.29348908 0.30258961 0.31455734
 0.32640557 0.33105186 0.33335783 0.3360806 ]
[45703, 85596, 59130, 45713, 38938, 6118, 45700, 45702, 112653, 38942] [0.25360274 0.26749248 0.30973844 0.33145897 0.33787191 0.35417436
 0.36139603 0.36646259 0.36740847 0.36937463]
[125965, 39071, 105997, 45865, 59267, 105998, 106011, 126054, 92477, 105996] [0.11247791 0.24644564 0.2583206  0.26808441 0.27777002 0.29474723
 0.30054765 0.30591386 0.31540983 0.32038205]
[39230, 25922, 19603, 85802, 59347, 92597, 106259, 32497, 92

In [None]:
import time
for sorted_id in[5,39]:
    action_id = 22
    # sorted_id = 47
    index = selected_by_action[action_id][sorted_id] #8900
    camera_position = [-4., 0.0, 0.5]


    # print(f"Action {action_id}: Selected indices: {selected_by_action[action_id][:10]}")
    # print(f"Action {action_id}: Corresponding losses: {loss_by_actions[action_id][:10]}")
    saved_folder = "./static_fig"
    os.makedirs(saved_folder, exist_ok=True)
    gender = "male" if gender_info["P" + str(data['indicators_test'][index][0])] == 1 else "female"
    indicator = data['indicators_test'][index]
    data_sample_pred = {}
    data_sample_gt = {}
    data_sample_pred_pc = {}
    for key in data['pred_params_test'].keys():
        data_sample_pred[key] = data['pred_params_test'][key][index]
        data_sample_gt[key] = data['gt_params_test'][key][index]
        data_sample_pred_pc[key] = data_pc['pred_params_test'][key][index]

    data_sample_all = get_sample_by_indicator(indicator, dataset_test=dataset_test)
    data_sample_points = data_sample_all["radar_points"].reshape(-1, 4)[:,:3]
    data_sample_joint = data_sample_all["joints_root"][0]
    # Filter out points within the specified range
    filter_range = 0.5
    filtered_points = data_sample_points[
        (data_sample_points[:, 0] >= data_sample_joint[0] - filter_range) &
        (data_sample_points[:, 0] <= data_sample_joint[0] + filter_range) &
        (data_sample_points[:, 1] >= data_sample_joint[1] - filter_range) &
        (data_sample_points[:, 1] <= data_sample_joint[1] + filter_range)
    ]
    print(action_id, sorted_id, indicator)

    # img = render_smplx_from_paramdict(smpl_models[gender], data_sample, device="cuda")
    ground_z = None
    img, img_pred, img_gt = render_smplx_pred_gt_side_by_side(smpl_models[gender], data_sample_pred, data_sample_gt, device="cuda", camera_position=camera_position, ground_z=ground_z)
    img, img_pred_pc, img_gt_pc = render_smplx_pred_gt_side_by_side(smpl_models[gender], data_sample_pred_pc, data_sample_gt, device="cuda", pc=filtered_points, pc_color=ppt_hsb_to_rgba(0, 100, 100, 100), camera_position=camera_position, ground_z=ground_z, pc_size=5.0)

    from PIL import Image
    # Image.fromarray(img).save("smplx_render.png")
    Image.fromarray(img_gt).save(os.path.join(saved_folder, f"smplx_render_gt.png"))
    Image.fromarray(img_pred).save(os.path.join(saved_folder, f"smplx_render_pred.png"))
    Image.fromarray(img_pred_pc).save(os.path.join(saved_folder, f"smplx_render_pred_pc.png"))
    Image.fromarray(img_gt).save(os.path.join(saved_folder, f"smplx_render_gt{action_id}_{sorted_id}.png"))
    Image.fromarray(img_pred).save(os.path.join(saved_folder, f"smplx_render_pred{action_id}_{sorted_id}.png"))
    Image.fromarray(img_pred_pc).save(os.path.join(saved_folder, f"smplx_render_pred_pc{action_id}_{sorted_id}.png"))
    # time.sleep(3)

2 0 [ 18   2 266]
2 1 [ 20   2 379]
2 2 [ 14   2 573]
2 3 [  1   2 304]
2 4 [ 20   2 339]
2 5 [ 20   2 340]
2 6 [ 20   2 383]
2 7 [  8   2 311]
2 8 [ 20   2 288]
2 9 [ 12   2 591]


In [None]:
# index_save = {}
# index_save[10] = [41,77]
# index_save[12] = [1,15, 20]
# index_save[14] = [3, 7, 16, 19]
# index_save[16] = [13,25, 27, 34]
# index_save[17] = [1,3,7, 9, 10, 13] # [-1., -3.0, 0.3]
# index_save[19] = [39, 4] # [-4., 0.0, 0.5]
# index_save[20] = [1, 2, 3, 19, 21]
# index_save[22] = [5,39]
# index_save[24] = [5,7,27,]#[-0, -4.0, 0.5]
# index_save[27] = [0, 2,39,50,98,102,109,114, 116, 117]
# index_save[28] = [0,2,11,17,35, 36, 43, 45, 50]
# [1,2, 9,11,20, 36, 37] [29]
# [8, 10, 21, 22, 24, 30, 40] [31]
# [4,5,10,16, 35, 47] [33]




In [None]:
from plot_paper_func import render_smplx_frames_overlay


import time
for sorted_id in range(1):
    action_id = 50
    sorted_id = 3
    index = 19999 #selected_by_action[action_id][sorted_id] #8900
    camera_position =  [-4.0, -1.0, 0.5]


    # print(f"Action {action_id}: Selected indices: {selected_by_action[action_id][:10]}")
    # print(f"Action {action_id}: Corresponding losses: {loss_by_actions[action_id][:10]}")
    saved_folder = "./SMPL_plot"
    os.makedirs(saved_folder, exist_ok=True)
    gender = "male" if gender_info["P" + str(data['indicators_test'][index][0])] == 1 else "female"
    indicator = data['indicators_test'][index]
    data_sample_gt_list = []
    data_sample_pred_list = []
    data_sample_pred_pc_list = []
    filtered_points_list = []
    for i in range(15):
        index_now = index - i
        indicator_now = data['indicators_test'][index_now]
        data_sample_pred = {}
        data_sample_gt = {}
        data_sample_pred_pc = {}
        for key in data['pred_params_test'].keys():
            data_sample_pred[key] = data['pred_params_test'][key][index_now]
            data_sample_gt[key] = data['gt_params_test'][key][index_now]
            data_sample_pred_pc[key] = data_pc['pred_params_test'][key][index_now]

        data_sample_all = get_sample_by_indicator(indicator_now, dataset_test=dataset_test)
        data_sample_points = data_sample_all["radar_points"].reshape(-1, 4)[:,:3]
        data_sample_joint = data_sample_all["joints_root"][0]
        # Filter out points within the specified range
        filter_range = 0.5
        filtered_points = data_sample_points[
            (data_sample_points[:, 0] >= data_sample_joint[0] - filter_range) &
            (data_sample_points[:, 0] <= data_sample_joint[0] + filter_range) &
            (data_sample_points[:, 1] >= data_sample_joint[1] - filter_range) &
            (data_sample_points[:, 1] <= data_sample_joint[1] + filter_range)
        ]
        # print(sorted_id, indicator_now)
        
        filtered_points_list.append(filtered_points)
        data_sample_gt_list.append(data_sample_gt)
        data_sample_pred_list.append(data_sample_pred)
        data_sample_pred_pc_list.append(data_sample_pred_pc)
        
    

    # img = render_smplx_from_paramdict(smpl_models[gender], data_sample, device="cuda")
    ground_z = None
    # img, img_pred, img_gt = render_smplx_pred_gt_side_by_side(smpl_models[gender], data_sample_pred, data_sample_gt, device="cuda", camera_position=camera_position, ground_z=ground_z)
    # img, img_pred_pc, img_gt_pc = render_smplx_pred_gt_side_by_side(smpl_models[gender], data_sample_pred_pc, data_sample_gt, device="cuda", pc=filtered_points, pc_color=ppt_hsb_to_rgba(0, 100, 100, 100), camera_position=camera_position, ground_z=ground_z)
    img_gt = render_smplx_frames_overlay(
    smpl_models[gender],
    data_sample_gt_list,
    [0, 7, 14],
    device="cuda",
    camera_position=camera_position,
    base_color_hsb=(209, 64, 85),  # cyan
    min_alpha_pct=50,
    max_alpha_pct=100,
    pcs=None,  # PCs have no shadow
    pc_colors=(0, 100, 100, 100),
    pc_size=4.0,
    ground_z=ground_z,
)
    img_pred = render_smplx_frames_overlay(
    smpl_models[gender],
    data_sample_pred_list,
    [0, 7, 14],
    device="cuda",
    camera_position=camera_position,
    base_color_hsb=(21, 79, 91),  # cyan
    min_alpha_pct=50,
    max_alpha_pct=100,
    pcs=None,  # PCs have no shadow
    pc_colors=(0, 100, 100, 100),
    pc_size=4.0,
    ground_z=ground_z,
)
    img_pred_pc = render_smplx_frames_overlay(
    smpl_models[gender],
    data_sample_pred_pc_list,
    [0, 7, 14],
    device="cuda",
    camera_position=camera_position,
    base_color_hsb=(21, 79, 91),  # cyan
    min_alpha_pct=50,
    max_alpha_pct=100,
    pcs=None,  # PCs have no shadow
    pc_colors=(0, 100, 100, 100),
    pc_size=4.0,
    ground_z=ground_z,
)
    from PIL import Image
    # Image.fromarray(img).save("smplx_render.png")
    Image.fromarray(img_gt).save(os.path.join(saved_folder, f"smplx_render_gt{indicator}.png"))
    Image.fromarray(img_pred).save(os.path.join(saved_folder, f"smplx_render_pred{indicator}.png"))
    Image.fromarray(img_pred_pc).save(os.path.join(saved_folder, f"smplx_render_pred_pc{indicator}.png"))
    Image.fromarray(img_gt).save(os.path.join(f"smplx_render_gt.png"))
    Image.fromarray(img_pred).save(os.path.join(f"smplx_render_pred.png"))
    Image.fromarray(img_pred_pc).save(os.path.join(f"smplx_render_pred_pc.png"))
    print(action_id, sorted_id, indicator)
    # time.sleep(3)

TypeError: cannot unpack non-iterable int object

In [None]:
[50] [0,3, 183]  [-1.0, -5.0, 0.5]
[49] [569]  [-1.0, -5.0, 0.5]
[47] [39,69] [3.0, -2.0, 0.5]
[46][36] [3.0, -2.0, 0.5]
# [45] [52] [4.0, -2.0, 0.5]  [-4.0, -2.0, 0.5]
[43] [72] [-4.0, -2.0, 0.5]
[38] [26] [-1.0, -5.0, 0.5]
[38] [10, 32] [-1.0, -5.0, 0.5] [30]
[38] [43] [0, 17, 29] [-1.0, -5.0, 0.5] [30]



  [50] [0,3, 183]  [-1.0, -5.0, 0.5]
  [50] [0,3, 183]  [-1.0, -5.0, 0.5]
  [50] [0,3, 183]  [-1.0, -5.0, 0.5]
  [50] [0,3, 183]  [-1.0, -5.0, 0.5]


TypeError: list indices must be integers or slices, not tuple

: 

: 

: 

: 

: 

: 

: 

: 

: 

: 

: 

: 

: 

: 

: 

# Compare with mmBody

In [None]:
# def plot_pc_and_mesh(pc, mesh, img=None, centers=None,
#                      xlim=(-3, 3), ylim=(0, 6), zlim=(1, 1000),
#                      pc_color='k', mesh_color='r',
#                      pc_size=2, mesh_size=1,
#                      elev=20, azim=-60,
#                      save_path=None, dpi=150,
#                      bev_bins=(60, 60)):
#     """
#     Plot point cloud (N,6) using first 3 cols as XYZ, mesh vertices (V,3),
#     and a list of human joint centers in a 3-column layout:
#     - Left: 3D point cloud and mesh
#     - Middle: Input image
#     - Right: 3D BEV histogram of human center distribution with logarithmic Z-axis.

#     Args:
#       pc: array-like (N,6) or (N,3) -- first 3 cols used as x,y,z
#       mesh: array-like (V,3)
#       img: optional image as numpy or torch tensor (3,H,W) or (H,W,C)
#       centers: list of human joint centers (list of [3] tensors)
#       xlim, ylim, zlim: axis limits
#       pc_color, mesh_color: colors for pc and mesh
#       pc_size, mesh_size: marker sizes
#       elev, azim: view elevation and azimuth for 3D plot
#       save_path: optional path to save PNG. If provided and has no .png extension, .png is appended.
#       dpi: saving DPI
#     """
#     import numpy as np
#     from matplotlib import cm
#     from matplotlib.colors import Normalize
#     import matplotlib.pyplot as plt

#     pc_arr = np.asarray(pc)
#     if pc_arr.ndim == 2 and pc_arr.shape[1] >= 3:
#         xyz = pc_arr[:, :3]
#     elif pc_arr.ndim == 1 and pc_arr.size >= 3:
#         xyz = pc_arr.reshape(-1, 3)[:, :3]
#     else:
#         raise ValueError("pc must be shape (N,6) or (N,3)")

#     mesh_arr = np.asarray(mesh)
#     if mesh_arr.ndim != 2 or mesh_arr.shape[1] != 3:
#         raise ValueError("mesh must be shape (V,3)")

#     # prepare image if provided
#     img_to_show = None
#     if img is not None:
#         img_arr = np.asarray(img)
#         if img_arr.ndim == 3 and img_arr.shape[0] in (1, 3, 4) and img_arr.shape[0] != img_arr.shape[2]:
#             img_to_show = img_arr.transpose(1, 2, 0)
#         else:
#             img_to_show = img_arr
#         if img_to_show.ndim == 2:
#             img_to_show = np.stack([img_to_show]*3, axis=-1)

#     # prepare centers if provided
#     if centers is not None:
#         centers = np.stack([c.numpy() if isinstance(c, torch.Tensor) else np.array(c) for c in centers])
#         if centers.ndim != 2 or centers.shape[1] not in (2, 3):
#             raise ValueError("centers must be a list of [3] tensors or shape (M,2)/(M,3)")

#     # figure with three panels
#     fig = plt.figure(figsize=(15, 5))
#     ax3d = fig.add_subplot(1, 3, 1, projection='3d')
#     ax_img = fig.add_subplot(1, 3, 2)
#     ax_bev = fig.add_subplot(1, 3, 3, projection='3d')

#     # --- Left panel: 3D point cloud and mesh ---
#     ax3d.scatter(xyz[:, 0], xyz[:, 1], xyz[:, 2],
#                  c=pc_color, s=pc_size, depthshade=True, linewidths=0, alpha=1.0)
#     ax3d.scatter(mesh_arr[:, 0], mesh_arr[:, 1], mesh_arr[:, 2],
#                  c=mesh_color, s=mesh_size, depthshade=True, linewidths=0, alpha=1.0)

#     ax3d.set_xlim(*xlim)
#     ax3d.set_ylim(*ylim)
#     ax3d.set_zlim(*zlim)
#     ax3d.set_xticks([])
#     ax3d.set_yticks([])
#     ax3d.set_zticks([])
#     ax3d.view_init(elev=elev, azim=azim)

#     # --- Middle panel: Input image ---
#     if img_to_show is not None:
#         img_disp = img_to_show.astype(np.float32)
#         if img_disp.max() > 1.0:
#             img_disp = img_disp / 255.0
#         ax_img.imshow(np.clip(img_disp, 0.0, 1.0))
#     ax_img.axis('off')

#     # --- Right panel: BEV histogram surface of centers ---
#     if centers is not None:
#         H, xedges, yedges = np.histogram2d(
#             centers[:, 0], centers[:, 1],
#             bins=[bev_bins[0], bev_bins[1]],
#             range=[[xlim[0], xlim[1]], [ylim[0], ylim[1]]]
#         )
#         xcenters = (xedges[:-1] + xedges[1:]) / 2.0
#         ycenters = (yedges[:-1] + yedges[1:]) / 2.0
#         X, Y = np.meshgrid(xcenters, ycenters, indexing='xy')

#         # Clamp counts to [1, 1e6] and plot in log10 space
#         Z_counts = np.clip(H.T, 1, 1_000_000)
#         Z_log = np.log10(Z_counts)

#         # Fixed log range 10^0..10^6 -> 0..6 in log10
#         vmin_log, vmax_log = 0.0, 6.0
#         norm = Normalize(vmin=vmin_log, vmax=vmax_log)

#         # Plot surface with fixed normalization
#         surf = ax_bev.plot_surface(
#             X, Y, Z_log, cmap='viridis', norm=norm,
#             edgecolor='none', alpha=0.9
#         )
#         # ax_bev.plot_wireframe(X, Y, Z_log, color='k', linewidth=0.2, alpha=0.2)

#         # Colorbar labeled in powers of 10
#         cbar = fig.colorbar(cm.ScalarMappable(norm=norm, cmap='viridis'),
#                             ax=ax_bev, shrink=0.5, aspect=10)
#         tick_locs = np.arange(int(vmin_log), int(vmax_log) + 1)
#         cbar.set_ticks(tick_locs)
#         # cbar.set_ticklabels([f"1e{k}" for k in tick_locs])
#         # cbar.set_label('Count', rotation=270, labelpad=15)

#         # Axes limits and tick labels to match the fixed log range
#         ax_bev.set_xlim(xlim[0], xlim[1])
#         ax_bev.set_ylim(ylim[0], ylim[1])
#         ax_bev.set_zlim(vmin_log, vmax_log)
#         # ax_bev.set_zticks(tick_locs)
#         # ax_bev.set_zticklabels([f"1e{k}" for k in tick_locs])
#         ax_bev.set_xticks([])
#         ax_bev.set_yticks([])
#         ax_bev.set_zticks([])
#         # ax_bev.set_zticks(tick_locs)
#         ax_bev.view_init(elev=elev, azim=azim)  # Match view direction with PC plot

#     plt.tight_layout()

#     if save_path is not None:
#         sp = str(save_path)
#         if not sp.lower().endswith('.png'):
#             sp = sp + '.png'
#         fig.savefig(sp, dpi=dpi, bbox_inches='tight', pad_inches=0)
#         plt.close(fig)
#         return sp

#     return fig

: 

: 

In [None]:

# dataset_test_select = torch.utils.data.Subset(dataset_test, range(5800, len(dataset_test)))
# data_loader = torch.utils.data.DataLoader(
#     dataset_test_select, batch_size=1, shuffle=True
# )
# joint_center_list = data_pc["pred_params_test"]["trans"]
# joint_center_list = [torch.tensor(joint_center_list[i]) for i in range(joint_center_list.shape[0])]
# ratio_count = 0
# ratio_PC = 0

# for batch in data_loader:
#     indicator = batch['indicator'][0].numpy()
#     radar = batch['radar_points'][0][-1].numpy().reshape(-1, 4)[:,:3]
#     radar = radar[~np.all(radar == 0, axis=1)]
#     center_joint = batch['joints_root'][0][0].numpy()
#     print(radar.shape, center_joint.shape)
    
#     # ensure radar points as a torch tensor on CPU
#     if isinstance(radar, np.ndarray):
#         pts = torch.from_numpy(radar).float()
#     else:
#         pts = radar.detach().cpu().float()

#     # ensure pts is (N, >=3)
#     if pts.dim() == 1:
#         pts = pts.view(-1, 3)
#     xyz = pts[:, :3]

#     # ensure center_joint is tensor on CPU
#     if isinstance(center_joint, np.ndarray):
#         center = torch.from_numpy(center_joint).float()
#     else:
#         center = center_joint.detach().cpu().float()

#     # offsets relative to center joint
#     dx = xyz[:, 0] - center[0]
#     dy = xyz[:, 1] - center[1]
#     dz = xyz[:, 2] - center[2]

#     # filter: x within ±1, y within ±1, z between 0 and 3 (relative to joint)
#     mask = (dx.abs() <= 1.0) & (dy.abs() <= 1.0) & (dz >= -1.5) & (dz <= 1.5)

#     count_close = int(mask.sum().item())
#     total_points = int(xyz.shape[0]) if xyz.shape[0] > 0 else 0
#     ratio = (count_close / total_points) if total_points > 0 else 0.0

#     # accumulate and compute running average across processed samples
#     ratio_PC += ratio
#     ratio_count += 1
#     avg_ratio = ratio_PC / ratio_count

#     print(f"ratio_count {ratio_count} sample {i}: {count_close}/{total_points} -> ratio {ratio:.4f}, running avg {avg_ratio:.4f}")
        
#     # if ratio_count == 20:
#     #     break
    
#     # print(f"Processing indicator: {indicator}")
#     # if indicator[1] == 1:
#     #     pc = batch['radar_points'][0][-1].numpy().reshape(-1, 4)[:,:3]
#     #     pc = pc[~np.all(pc == 0, axis=1)]
#     #     joints = batch['joints_root'][0].numpy()
#     #     vertices = batch['vertices'][0].numpy()/1000 
#     #     joint_center_list.append(joints[0])
        
#     #     save_path = plot_pc_and_mesh(
#     #         pc=pc,
#     #         mesh=vertices,
#     #         centers=joint_center_list,
#     #         xlim=(-3, 3), ylim=(0, 9), zlim=(0, 4),
#     #         pc_color='black',
#     #         mesh_color='red',
#     #         pc_size=2,
#     #         mesh_size=1,
#     #         elev=20,
#     #         azim=-60,
#     #         save_path=f"mmBody_compare/pc_mesh_plot{indicator}.png",
#     #         dpi=800,
#     #         bev_bins=(60, 60)
#     #     )
#     #     print(f"Saved PC and mesh plot to {save_path}")
        


(323, 3) (3,)
ratio_count 1 sample 14: 302/323 -> ratio 0.9350, running avg 0.9350
(773, 3) (3,)
ratio_count 2 sample 14: 609/773 -> ratio 0.7878, running avg 0.8614
(522, 3) (3,)
ratio_count 3 sample 14: 392/522 -> ratio 0.7510, running avg 0.8246
(286, 3) (3,)
ratio_count 4 sample 14: 251/286 -> ratio 0.8776, running avg 0.8379
(644, 3) (3,)
ratio_count 5 sample 14: 623/644 -> ratio 0.9674, running avg 0.8638
(555, 3) (3,)
ratio_count 6 sample 14: 508/555 -> ratio 0.9153, running avg 0.8724
(526, 3) (3,)
ratio_count 7 sample 14: 491/526 -> ratio 0.9335, running avg 0.8811
(413, 3) (3,)
ratio_count 8 sample 14: 360/413 -> ratio 0.8717, running avg 0.8799
(433, 3) (3,)
ratio_count 9 sample 14: 378/433 -> ratio 0.8730, running avg 0.8791
(432, 3) (3,)
ratio_count 10 sample 14: 393/432 -> ratio 0.9097, running avg 0.8822
(511, 3) (3,)
ratio_count 11 sample 14: 406/511 -> ratio 0.7945, running avg 0.8742
(540, 3) (3,)
ratio_count 12 sample 14: 526/540 -> ratio 0.9741, running avg 0.8825
(

KeyboardInterrupt: 

: 

: 

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, ListedColormap
from matplotlib import cm

def plot_rpc_and_tensor_bev(pc_n4, tensor_121_111_31, save_path,
                            xy_window=6, elev=25, azim=-60, dpi=300):
    """
    pc_n4: (N,4) [x,y,z,power]
    tensor_121_111_31: (121,111,31)
    - pick strongest PC point by power
    - map to (ix0,iy0,iz0) using requested projections
    - take ±xy_window around (ix0,iy0)
    - use the single tensor plane at iz0 for heights (BEV)
    """

    pc = np.asarray(pc_n4)
    T = np.asarray(tensor_121_111_31)
    assert pc.ndim == 2 and pc.shape[1] == 4, "pc must be (N,4)"
    assert T.shape == (121, 111, 31), "tensor must be (121,111,31)"

    # strongest point
    pc = pc[pc[:, -1] > 0.5]
    strongest_idx = int(np.argmin(pc[:, 3]))
    x0, y0, z0, _ = pc[strongest_idx]

    # projections -> tensor indices
    def to_idx(x, y, z):
        ix = int(np.rint(x / 3.0 * 60.0 + 60.0))          # x/3*60 + 60
        iy = int(np.rint((y - 0.5) / 5.5 * 111.0))        # (y-0.5)/5.5*111
        iz = int(np.rint(z / 3.0 * 31.0))                 # z/3*31
        ix = np.clip(ix, 0, 120)
        iy = np.clip(iy, 0, 110)
        iz = np.clip(iz, 0, 30)
        return ix, iy, iz

    ix0, iy0, iz0 = to_idx(x0, y0, z0)

    # XY window (single z plane)
    x_start = max(ix0 - xy_window, 0)
    x_stop  = min(ix0 + xy_window + 1, 121)
    y_start = max(iy0 - xy_window, 0)
    y_stop  = min(iy0 + xy_window + 1, 111)
    xs = np.arange(x_start, x_stop)
    ys = np.arange(y_start, y_stop)

    # heights from single z plane
    H = np.zeros((len(ys), len(xs)), dtype=float)  # [y,x]
    for i, xi in enumerate(xs):
        for j, yj in enumerate(ys):
            H[j, i] = T[xi, yj, iz0]

    # world coords (meters)
    scale_factor = min(6.0 / 120.0, 5.5 / 110.0)  # Use the smaller scaling factor for square grid
    dx_world = scale_factor
    dy_world = scale_factor
    x_world = (xs - 60) * dx_world
    y_world = (ys - 55) * dy_world

    # occupancy ignoring z (any point in XY cell)
    ix_all = np.rint(pc[:, 0] / 3.0 * 60.0 + 60.0).astype(int)
    iy_all = np.rint((pc[:, 1] - 0.5) / 5.5 * 111.0).astype(int)
    ix_all = np.clip(ix_all, 0, 120)
    iy_all = np.clip(iy_all, 0, 110)

    occ = np.zeros_like(H, dtype=int)
    mask_xy = (ix_all >= x_start) & (ix_all < x_stop) & \
              (iy_all >= y_start) & (iy_all < y_stop)
    if np.any(mask_xy):
        ix_w = ix_all[mask_xy] - x_start
        iy_w = iy_all[mask_xy] - y_start
        occ[iy_w, ix_w] = 1

    # plots
    fig = plt.figure(figsize=(7, 12))
    ax3d = fig.add_subplot(2, 1, 1, projection='3d')
    ax2d = fig.add_subplot(2, 1, 2)

    Xc, Yc = np.meshgrid(x_world, y_world, indexing='xy')
    xpos = Xc.ravel()
    ypos = Yc.ravel()
    zpos = np.zeros_like(xpos)
    dx = np.full_like(xpos, dx_world)
    dy = np.full_like(ypos, dy_world)
    dz = H.ravel()

    vmin, vmax = (float(np.min(H)), float(np.max(H))) if H.size else (0.0, 1.0)
    if vmax <= vmin:
        vmax = vmin + 1e-6
    norm = Normalize(vmin=vmin, vmax=vmax)
    colors = cm.viridis(norm(dz))

    ax3d.bar3d(xpos, ypos, zpos, dx, dy, dz, color=colors, shade=True, linewidth=0)
    ax3d.set_zlim(0, vmax)  # Dynamically adjust z-axis limit based on max heatmap value
    ax3d.grid(False)
    ax3d.axis('off')  # Remove axis, labels, and ticks

    cmap = ListedColormap(['lightgray', 'limegreen'])
    extent = [x_world[0] - dx_world/2, x_world[-1] + dx_world/2,
              y_world[0] - dy_world/2, y_world[-1] + dy_world/2]
    ax2d.imshow(occ, origin='lower', cmap=cmap, vmin=0, vmax=1,
                interpolation='nearest', extent=extent, aspect='auto')
    ax2d.grid(False)
    ax2d.axis('off')  # Remove axis, labels, and ticks

    plt.tight_layout()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
    plt.close(fig)
    return save_path

: 

: 

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, ListedColormap
from matplotlib import cm

def plot_rpc_and_tensor_xz(pc_n4, tensor_121_111_31, save_path,
                           xz_window=6, elev=25, azim=-60, dpi=300):
    """
    pc_n4: (N,4) [x,y,z,power]
    tensor_121_111_31: (121,111,31)
    - pick strongest PC point by power
    - map to (ix0,iy0,iz0) using requested projections
    - take ±xz_window around (ix0,iz0)
    - use the single tensor plane at iy0 for heights (X-Z view with energy as Z-axis)
    """

    pc = np.asarray(pc_n4)
    T = np.asarray(tensor_121_111_31)
    assert pc.ndim == 2 and pc.shape[1] == 4, "pc must be (N,4)"
    assert T.shape == (121, 111, 31), "tensor must be (121,111,31)"

    # strongest point
    strongest_idx = int(np.argmax(pc[:, 3]))
    x0, y0, z0, _ = pc[strongest_idx]

    # projections -> tensor indices
    def to_idx(x, y, z):
        ix = int(np.rint(x / 3.0 * 60.0 + 60.0))          # x/3*60 + 60
        iy = int(np.rint((y - 0.5) / 5.5 * 111.0))        # (y-0.5)/5.5*111
        iz = int(np.rint(z / 3.0 * 31.0))                 # z/3*31
        ix = np.clip(ix, 0, 120)
        iy = np.clip(iy, 0, 110)
        iz = np.clip(iz, 0, 30)
        return ix, iy, iz

    ix0, iy0, iz0 = to_idx(x0, y0, z0)

    # XZ window (single y plane)
    x_start = max(ix0 - xz_window, 0)
    x_stop  = min(ix0 + xz_window + 1, 121)
    z_start = max(iz0 - xz_window, 0)
    z_stop  = min(iz0 + xz_window + 1, 31)
    xs = np.arange(x_start, x_stop)
    zs = np.arange(z_start, z_stop)

    # heights from single y plane
    H = np.zeros((len(xs), len(zs)), dtype=float)  # [x,z]
    for i, xi in enumerate(xs):
        for j, zj in enumerate(zs):
            H[i, j] = T[xi, iy0, zj]

    # world coords (meters)
    dx_world = 6.0 / 120.0     # [-3,3] over 121 cells
    dz_world = 3.0 / 30.0      # [0,3] over 31 cells
    x_world = (xs - 60) * dx_world
    z_world = zs * dz_world

    # occupancy ignoring y (any point in XZ cell)
    ix_all = np.rint(pc[:, 0] / 3.0 * 60.0 + 60.0).astype(int)
    iz_all = np.rint(pc[:, 2] / 3.0 * 31.0).astype(int)
    ix_all = np.clip(ix_all, 0, 120)
    iz_all = np.clip(iz_all, 0, 30)

    occ = np.zeros_like(H, dtype=int)
    mask_xz = (ix_all >= x_start) & (ix_all < x_stop) & \
              (iz_all >= z_start) & (iz_all < z_stop)
    if np.any(mask_xz):
        ix_w = ix_all[mask_xz] - x_start
        iz_w = iz_all[mask_xz] - z_start
        occ[ix_w, iz_w] = 1

    # plots
    fig = plt.figure(figsize=(7, 9))
    ax3d = fig.add_subplot(2, 1, 1, projection='3d')
    ax2d = fig.add_subplot(2, 1, 2)

    Xc, Zc = np.meshgrid(x_world, z_world, indexing='xy')
    xpos = Xc.ravel()
    zpos = Zc.ravel()
    ypos = np.zeros_like(xpos)
    dx = np.full_like(xpos, dx_world)
    dz = np.full_like(zpos, dz_world)
    dy = H.ravel()

    vmin, vmax = (float(np.min(H)), float(np.max(H))) if H.size else (0.0, 1.0)
    if vmax <= vmin:
        vmax = vmin + 1e-6
    norm = Normalize(vmin=vmin, vmax=vmax)
    colors = cm.viridis(norm(dy))

    ax3d.bar3d(xpos, ypos, zpos, dx, dy, dz, color=colors, shade=True, linewidth=0)
    ax3d.set_xlabel('X (m)')
    ax3d.set_ylabel('Z (m)')
    ax3d.set_zlabel('Energy')
    ax3d.view_init(elev=elev, azim=azim)
    ax3d.grid(False)

    cmap = ListedColormap(['lightgray', 'limegreen'])
    extent = [x_world[0] - dx_world/2, x_world[-1] + dx_world/2,
              z_world[0] - dz_world/2, z_world[-1] + dz_world/2]
    ax2d.imshow(occ.T, origin='lower', cmap=cmap, vmin=0, vmax=1,
                interpolation='nearest', extent=extent, aspect='auto')
    ax2d.set_xlabel('X (m)')
    ax2d.set_ylabel('Z (m)')
    ax2d.grid(False)

    plt.tight_layout()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
    plt.close(fig)
    return save_path

: 

: 

In [None]:
dataset_test_select = torch.utils.data.Subset(dataset_test, range(5800, len(dataset_test)))
data_loader = torch.utils.data.DataLoader(
    dataset_test_select, batch_size=1, shuffle=True
)
joint_center_list = data_pc["pred_params_test"]["trans"]
joint_center_list = [torch.tensor(joint_center_list[i]) for i in range(joint_center_list.shape[0])]
ratio_count = 0
ratio_PC = 0

for batch in data_loader:
    indicator = batch['indicator'][0].numpy()
    radar = batch['radar_points'][0][-1].numpy().reshape(-1, 4)
    radar = radar[~np.all(radar == 0, axis=1)]
    # center_joint = batch['joints_root'][0][0].numpy()
    radar_tensor = batch['rawImage_XYZ'][0][-1]
    print(radar.shape, radar_tensor.shape)
    out_png = plot_rpc_and_tensor_bev(radar, radar_tensor, 'radar_bev.png', xy_window=6)
    print('saved to', out_png)
    time.sleep(2)

(448, 4) torch.Size([121, 111, 31])
saved to radar_bev.png
(451, 4) torch.Size([121, 111, 31])
saved to radar_bev.png
(532, 4) torch.Size([121, 111, 31])
saved to radar_bev.png


KeyboardInterrupt: 

: 

: 