# Set up all settings here! (Except the plotting code)

In [None]:
from pathlib import Path
import os
import logging
import sys

logging.basicConfig(level=logging.DEBUG)

# Project root as exported by pixi. When the variable is missing, `root` is the
# empty string and Path(root) resolves relative to the current working
# directory (not to '/').
root = os.environ.get("PIXI_PROJECT_ROOT", "")
if not root:
    logging.warning(
        "PIXI_PROJECT_ROOT environment variable not set. "
        "Falling back to the current working directory (%s), which may not be correct.",
        Path.cwd(),
    )

# Make the PyriteML directory importable; skip the append when the entry is
# already present so that re-running this cell does not grow sys.path.
_pyrite_path = str(Path(root) / "PyriteML")
if _pyrite_path not in sys.path:
    sys.path.append(_pyrite_path)

In [None]:
# Name of the task / workspace config (YAML stem) to evaluate.
task_name = "train_multimodal_conv_workspace"

# Training run folder and checkpoint stem (the ".ckpt" suffix is added later).
train_name = "2026.02.28_19.50.59_cable_mounting_multimodal_conv_230"
ckpt_name = "epoch=0000-train_loss=0.602"

# Directory that holds the training run folders.
data_path = Path(root, "training_outputs")

# The kinect data is disabled by default, so enable it explicitly.
workspace_config_overrides = ["kinect=enabled"]

# Output folder for files written by this notebook; created if missing.
output_dir = Path("ckpt_test_output")
output_dir.mkdir(exist_ok=True)

In [None]:
# Build the checkpoint file path and fail fast when the file is missing.
ckpt_path = data_path / train_name / "checkpoints" / f"{ckpt_name}.ckpt"
if ckpt_path.exists():
    logging.info(f"Checkpoint found at {ckpt_path}")
else:
    raise FileNotFoundError(f"Checkpoint not found at {ckpt_path}")

# Load the checkpoint

In [None]:
import sys
import os
from typing import Dict, Callable, Tuple, List

import numpy as np
import torch
import time
import dill
import hydra
from torch.utils.data import DataLoader

from PyriteML.diffusion_policy.workspace.base_workspace import BaseWorkspace
from PyriteML.diffusion_policy.dataset.base_dataset import BaseImageDataset, BaseDataset
from PyriteML.diffusion_policy.workspace.train_diffusion_unet_image_workspace import TrainDiffusionUnetImageWorkspace


# Everything in this notebook runs on the CPU.
device = torch.device('cpu')


# NOTE(review): the checkpoint is unpickled with dill, which can execute
# arbitrary code; only load checkpoints from trusted sources.
# The file is opened through a context manager so that the handle is closed as
# soon as loading finishes instead of being leaked.
with open(ckpt_path, 'rb') as ckpt_file:
    payload = torch.load(ckpt_file, map_location='cpu', pickle_module=dill)
cfg = payload['cfg']
print("model_name:", cfg.policy.obs_encoder._target_)
print("dataset_path:", cfg.task.dataset.dataset_path)

# Re-create the workspace class named in the checkpoint config and restore
# its state from the payload.
cls = hydra.utils.get_class(cfg._target_)
workspace: BaseWorkspace = cls(cfg)
workspace.load_payload(payload, exclude_keys=None, include_keys=None)

# Use the EMA model when the config asks for it, otherwise the raw model.
policy = workspace.model
if cfg.training.use_ema:
    policy = workspace.ema_model
policy.num_inference_steps = cfg.policy.num_inference_steps  # DDIM inference iterations

policy.eval().to(device)
policy.reset()

# use normalizer saved in the policy
sparse_normalizer = policy.get_normalizer()

shape_meta = cfg.task.shape_meta

In [None]:
# for item in cfg:
#     logging.info(f"{item}: {cfg[item]}")
# for item in cfg.policy:
#     logging.info(f"policy.{item}: {cfg.policy[item]}")

# for item in cfg.policy.obs_encoder:
#     logging.info(f"policy.obs_encoder.{item}: {cfg.policy.obs_encoder[item]}")

# for item in cfg.task:
#     logging.info(f"task.{item}: {cfg.task[item]}")

# Load a dataset

## Set up the task name here!

In [None]:
# Relative path of the task YAML: PyriteML/diffusion_policy/config/<task_name>.yaml
task_config_path = Path("PyriteML", "diffusion_policy", "config", f"{task_name}.yaml")

In [None]:
# # load the dataset used in training
# dataset: BaseImageDataset
# dataset = hydra.utils.instantiate(cfg.task.dataset)
# assert isinstance(dataset, BaseImageDataset) or isinstance(dataset, BaseDataset)
# print("Test Script: Creating dataloader.")
# train_dataloader = DataLoader(dataset, **cfg.dataloader)
# print('train dataset:', len(dataset), 'train dataloader:', len(train_dataloader))

# load the dataset specified in config
from hydra import compose, initialize
from omegaconf import OmegaConf

# Fail early with a clear message when the task YAML is not where we expect it.
if not task_config_path.exists():
    raise FileNotFoundError(f"Task config not found at {task_config_path}")

# Compose the task config with Hydra, using the directory of the YAML file as
# the config search path and applying the overrides chosen in the settings cell.
with initialize(
    version_base=None,
    config_path=str(task_config_path.parent),
    job_name="test_multi_modal_checkpoint",
):
    # NOTE: this rebinds `cfg`; a config previously bound to that name (e.g. the
    # one read from the checkpoint payload) is no longer reachable through it.
    cfg = compose(config_name=task_name, overrides = workspace_config_overrides)
    # Resolve all interpolations in place before instantiating anything.
    OmegaConf.resolve(cfg)

    logging.info("Test Script: configuring dataset.")
    dataset: BaseImageDataset
    # Build the dataset object described by the `task.dataset` config node.
    dataset = hydra.utils.instantiate(cfg.task.dataset)
    # assert isinstance(dataset, BaseImageDataset) or isinstance(dataset, BaseDataset)
    logging.info("Test Script: Creating dataloader.")
    # The dataloader keyword arguments come straight from the `dataloader` config node.
    train_dataloader = DataLoader(dataset, **cfg.dataloader)
    logging.info('train dataset: %d train dataloader: %d', len(dataset), len(train_dataloader))

# Run some tests (deprecated)

In [None]:
# import torch.nn.functional as F
# from einops import rearrange, reduce
# import json
# from einops import rearrange
# def log_action_mse(step_log, category, pred_action, gt_action):
#     pred_naction = {
#         'sparse': sparse_normalizer['action'].normalize(pred_action['sparse']),
#         # 'dense': dense_normalizer['action'].normalize(pred_action['dense'])
#     }
#     gt_naction = {
#         'sparse': sparse_normalizer['action'].normalize(gt_action['sparse']),
#         # 'dense': dense_normalizer['action'].normalize(gt_action['dense'])
#     }

#     B, T, _ = pred_naction['sparse'].shape
#     pred_naction_sparse = rearrange(pred_naction['sparse'], 'batch time action_dim -> batch time action_dim')
#     gt_naction_sparse = rearrange(gt_naction['sparse'], 'batch time action_dim -> batch time action_dim')
#     sparse_loss = F.mse_loss(pred_naction_sparse, gt_naction_sparse, reduction='none')
#     sparse_loss = sparse_loss.type(sparse_loss.dtype)
#     sparse_loss = reduce(sparse_loss, 'b ... -> b (...)', 'mean')
#     sparse_loss = sparse_loss.mean()            

#     step_log[f'{category}_sparse_naction_mse_error'] = float(sparse_loss.detach())
#     # step_log[f'{category}_sparse_naction_mse_error_pos'] = F.mse_loss(pred_naction_sparse[..., :3], gt_naction_sparse[..., :3])
#     # step_log[f'{category}_sparse_naction_mse_error_rot'] = F.mse_loss(pred_naction_sparse[..., 3:9], gt_naction_sparse[..., 3:9])
#     # B, T, _, _= pred_naction['dense'].shape
#     # pred_naction_dense = pred_naction['dense'].view(B, T, -1, 9)
#     # gt_naction_dense = gt_naction['dense'].view(B, T, -1, 9)
#     # dense_loss = F.mse_loss(pred_naction_dense, gt_naction_dense, reduction='none')
#     # dense_loss = dense_loss.type(dense_loss.dtype)
#     # dense_loss = reduce(dense_loss, 'b ... -> b (...)', 'mean')
#     # dense_loss = dense_loss.mean()            
#     # step_log[f'{category}_dense_naction_mse_error'] = float(dense_loss.detach())
#     # step_log[f'{category}_dense_naction_mse_error_pos'] = F.mse_loss(pred_naction_dense[..., :3], gt_naction_dense[..., :3])
#     # step_log[f'{category}_dense_naction_mse_error_rot'] = F.mse_loss(pred_naction_dense[..., 3:9], gt_naction_dense[..., 3:9])
    
# # get a batch of data'
# print('get a batch of data')
# batch = next(iter(train_dataloader))

# # print(batch.keys())
# # for key, attr in batch['obs']['sparse'].items():
# #     print("   obs.sparse.key: ", key, attr.shape)
# # for key, attr in batch['obs']['dense'].items():
# #     print("   obs.dense.key: ", key, attr.shape)
# # for key, attr in batch['action'].items():
# #     print("   action.key: ", key, attr.shape)


# Let's plot something here!

## Define the action mapping

In [None]:

from torch.nn import functional as F
from scipy.spatial.transform import Rotation as R
def rot6d_to_R(rot6d : torch.Tensor) -> R:
    """Convert a 6-D rotation representation into a scipy ``Rotation``.

    The last axis is split into two 3-vectors that are orthonormalised
    (Gram-Schmidt); their cross product supplies the third axis. The three
    axes are stacked along dim ``-2``, i.e. as the rows of the matrix handed
    to ``Rotation.from_matrix``.

    Args:
        rot6d: tensor whose last dimension has size 6.

    Returns:
        The corresponding scipy ``Rotation`` (built from a detached CPU copy).
    """
    assert rot6d.shape[-1] == 6
    first, second = rot6d[..., :3], rot6d[..., 3:6]

    axis_a = F.normalize(first, dim=-1)
    # Remove the component of `second` that lies along `axis_a`, then normalise.
    overlap = (axis_a * second).sum(dim=-1, keepdim=True)
    axis_b = F.normalize(second - overlap * axis_a, dim=-1)
    axis_c = torch.cross(axis_a, axis_b, dim=-1)

    matrices = torch.stack((axis_a, axis_b, axis_c), dim=-2)
    return R.from_matrix(matrices.detach().cpu().numpy())

def identity(x : torch.Tensor | np.ndarray) -> np.ndarray:
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return x

def R_to_euler(R_mat : R) -> np.ndarray:
    """Return the Euler angles of ``R_mat`` for scipy's ``'xyz'`` sequence, in degrees."""
    euler_deg = R_mat.as_euler('xyz', degrees=True)
    return euler_deg

def cal_R_err(pred : R, gt : R) -> np.ndarray:
    """Return the angle, in degrees, of the relative rotation ``pred.inv() * gt``."""
    relative = pred.inv() * gt
    return np.rad2deg(relative.magnitude())

def cal_err(pred : np.ndarray, gt : np.ndarray) -> np.ndarray:
    """Return the Euclidean norm of ``pred - gt`` taken over the last axis."""
    residual = pred - gt
    return np.linalg.norm(residual, axis=-1)

def _ensure_2d(arr: np.ndarray) -> np.ndarray:
    """Return ``arr`` as an ndarray, turning a 1-D input of length N into shape (N, 1).

    Inputs of any other rank are returned unchanged (after ``np.asarray``).
    """
    as_array = np.asarray(arr)
    return as_array[:, None] if as_array.ndim == 1 else as_array

def _dim_labels(name: str, d: int):
    """Return one legend label per dimension of a ``d``-dimensional quantity.

    ``name`` is accepted but not used. A single dimension is labelled
    ``'value'``, three dimensions (Euler angles or translation) are labelled
    ``x``/``y``/``z``, and any other size gets ``dim0``, ``dim1``, ...
    """
    if d == 1:
        return ['value']
    if d == 3:
        # Euler angles or translation
        return ['x', 'y', 'z']
    return [f'dim{i}' for i in range(d)]


from typing import Any
# Layout of the 30-dim sparse action vector covered by the slices below:
#   columns [0, 20):  pose9 + vt_pose9 + stiffness1 + grip1 = 20
#   columns [20, 30): pose9 + grip1 = 10
#   Total = 30
# NOTE(review): the original comment labelled the 20-dim group "Right" and the
# 10-dim group "Left", while the keys below name them "*_left" and "*_right"
# respectively — confirm which side is which.
#
# Each entry is (index, convert_fn, vis_fn, err_fn):
#   index      - (row slice, column slice) selecting the group from a 2-D action array
#   convert_fn - turns the raw slice into the object that is compared
#   vis_fn     - turns that object into the array that is plotted
#   err_fn     - error between a prediction and the ground truth

action_map : Dict[str, Tuple[Tuple[slice, slice], Callable[[torch.Tensor], Any], Callable[[Any], np.ndarray], Callable[[Any, Any], np.ndarray]]] = {
    'trans_left': ((slice(None), slice(0, 3)), identity, identity, cal_err),
    'rotation_left': ((slice(None), slice(3, 9)), rot6d_to_R,R_to_euler,  cal_R_err),
    'vt_trans_left': ((slice(None), slice(9, 12)), identity, identity, cal_err),
    'vt_rotation_left': ((slice(None), slice(12, 18)), rot6d_to_R, R_to_euler, cal_R_err),
    'stiffness_left': ((slice(None), slice(18, 19)), identity, identity, cal_err),
    'gripper_left': ((slice(None), slice(19, 20)), identity, identity, cal_err),
    'trans_right': ((slice(None), slice(20, 23)), identity, identity, cal_err),
    'rotation_right': ((slice(None), slice(23, 29)), rot6d_to_R, R_to_euler, cal_R_err),
    'gripper_right': ((slice(None), slice(29, 30)), identity, identity, cal_err),
}


In [None]:
# Take one batch from the dataloader and keep its ground-truth actions.
batch = next(iter(train_dataloader))
gt_action = batch['action']
# Pure inference: disable autograd so no computation graph is built or kept
# alive for the predicted actions.
with torch.no_grad():
    pred_action = policy.predict_action(batch['obs'])

## Plot whatever you want

In [None]:
from matplotlib import pyplot as plt
def visualize_sparse_action_comparison(
    gt_action_sparse: torch.Tensor,
    pred_action_sparse: torch.Tensor,
    action_map: Dict[str, Tuple[Tuple[slice, slice], Callable[[torch.Tensor], Any], Callable[[Any], np.ndarray], Callable[[Any, Any], np.ndarray]]],
    save_path: Path | str | None = None,
    dpi: int = 180,
):
    """
    Plot predicted vs. ground-truth sparse actions, one pair of axes per action group.

    For every entry of ``action_map`` the first axis of the pair shows the
    visualised values (ground truth solid, prediction dashed) and the second
    axis shows the per-timestep error returned by the entry's error function.
    Three groups (six axes) are placed on each row of the figure.

    Args:
        gt_action_sparse: ground-truth actions, 2-D ``[T, A]``
            (the original note gives ``[T, 30]``).
        pred_action_sparse: predicted actions, same shape as ``gt_action_sparse``.
        action_map: maps a group name to ``(index, convert_fn, vis_fn, err_fn)``.
            Entries whose ``vis_fn`` is ``R_to_euler`` are treated as rotations:
            ``err_fn`` is then called on the converted objects. For all other
            entries ``err_fn`` is called on the visualised arrays.
        save_path: where to save the figure; ``None`` skips saving.
        dpi: resolution passed to ``fig.savefig``.

    Returns:
        ``(fig, summary)`` where ``summary`` maps each group name to a dict of
        error statistics (mean / max / rmse, plus std for numeric groups).

    Raises:
        AssertionError: if the two inputs differ in shape or are not 2-D.
    """
    if isinstance(save_path, Path):
        save_path = str(save_path)
    assert gt_action_sparse.shape == pred_action_sparse.shape, (
        f"Shape mismatch: gt={gt_action_sparse.shape}, pred={pred_action_sparse.shape}"
    )
    assert gt_action_sparse.ndim == 2, (
        f"Expect [T, A], got {gt_action_sparse.shape}"
    )

    T, A = gt_action_sparse.shape
    time_idx = np.arange(T)

    n_items = len(action_map)

    # Each group occupies two columns (values + error); three groups per row.
    groups_per_row = 3
    nrows = np.ceil(n_items / groups_per_row).astype(np.int32)
    ncols = groups_per_row * 2

    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(40, 4.2 * nrows),
        sharex=True,
        constrained_layout=True
    )

    # Guarantee 2-D indexing even when there is only a single row of axes.
    axes = np.atleast_2d(axes)

    summary = {}

    for idx_item, (name, (idx, convert_fn, vis_fn, err_fn)) in enumerate(action_map.items()):
        row = idx_item // groups_per_row
        group_col = idx_item % groups_per_row

        ax_val = axes[row, group_col * 2]
        ax_err = axes[row, group_col * 2 + 1]

        # raw slices: [T, d_raw]
        gt_raw = gt_action_sparse[idx]
        pred_raw = pred_action_sparse[idx]

        # converted objects
        gt_obj = convert_fn(gt_raw)
        pred_obj = convert_fn(pred_raw)

        # values to visualize
        gt_vis = vis_fn(gt_obj)
        pred_vis = vis_fn(pred_obj)

        # Rotation groups are recognised by their visualisation function.
        is_rotation = (vis_fn is R_to_euler)

        if is_rotation:
            # gt_vis, pred_vis: [T, 3], in degrees
            gt_vis = _ensure_2d(gt_vis)
            pred_vis = _ensure_2d(pred_vis)

            # Per-timestep rotation error computed on the converted objects.
            # It is plotted as-is and labelled "deg": cal_R_err already
            # converts the angle to degrees.
            rot_err = err_fn(pred_obj, gt_obj)   # shape [T]
            rot_err = np.asarray(rot_err)

            rot_err_plot = rot_err
            err_unit = "deg"

            labels = ['x', 'y', 'z']
            for j in range(3):
                ax_val.plot(time_idx, gt_vis[:, j], linestyle='-',  label=f'gt_{labels[j]}')
                ax_val.plot(time_idx, pred_vis[:, j], linestyle='--', label=f'pred_{labels[j]}')

            ax_err.plot(time_idx, rot_err_plot, linewidth=1.8, label='rotation error')

            mean_err = float(np.mean(rot_err_plot))
            max_err = float(np.max(rot_err_plot))
            rmse_err = float(np.sqrt(np.mean(rot_err_plot ** 2)))

            summary[name] = {
                'type': 'rotation',
                'mean_err_deg': mean_err,
                'max_err_deg': max_err,
                'rmse_err_deg': rmse_err,
            }

            ax_val.set_title(f'{name} | Euler xyz (deg): pred vs gt')
            ax_err.set_title(
                f'{name} | geodesic error ({err_unit}) | mean={mean_err:.3f}, max={max_err:.3f}'
            )
            ax_val.set_ylabel('deg')
            ax_err.set_ylabel(err_unit)

        else:
            # numerical values
            gt_vis = _ensure_2d(gt_vis)
            pred_vis = _ensure_2d(pred_vis)

            d = gt_vis.shape[1]
            labels = _dim_labels(name, d)

            # plot pred vs gt
            for j in range(d):
                ax_val.plot(time_idx, gt_vis[:, j], linestyle='-',  label=f'gt_{labels[j]}')
                ax_val.plot(time_idx, pred_vis[:, j], linestyle='--', label=f'pred_{labels[j]}')

            # err_fn now returns the error at every timestep rather than an overall MSE
            # - translation: shape [T], each step = ||pred - gt||_2
            # - stiffness/gripper (1D): shape [T], each step = |pred - gt|
            err_curve = np.asarray(err_fn(pred_vis, gt_vis)).reshape(-1)

            ax_err.plot(time_idx, err_curve, linewidth=1.8, label='error')

            mean_err = float(np.mean(err_curve))
            max_err = float(np.max(err_curve))
            rmse_err = float(np.sqrt(np.mean(err_curve ** 2)))
            std_err = float(np.std(err_curve))

            summary[name] = {
                'type': 'numeric',
                'mean_err': mean_err,
                'max_err': max_err,
                'rmse_err': rmse_err,
                'std_err': std_err,
            }

            ax_val.set_title(f'{name} | pred vs gt')
            ax_err.set_title(
                f'{name} | error norm | mean={mean_err:.4f}, max={max_err:.4f}, rmse={rmse_err:.4f}'
            )
            ax_val.set_ylabel('value')
            ax_err.set_ylabel('error')

        ax_val.grid(True, alpha=0.3)
        ax_err.grid(True, alpha=0.3)

        ax_val.legend(fontsize=8, ncol=3, loc='best')
        ax_err.legend(fontsize=8, ncol=2, loc='best')


    # Hide the axes of the last row that no group was drawn into.
    total_axes_used = n_items * 2
    total_axes = nrows * ncols

    for k in range(total_axes_used, total_axes):
        r = k // ncols
        c = k % ncols
        axes[r, c].axis('off')

    # Only the bottom row carries the x-axis label.
    for c in range(ncols):
        axes[-1, c].set_xlabel('timestep')

    fig.suptitle('Sparse Action Comparison: Prediction vs Ground Truth', fontsize=16)

    if save_path is not None:
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight')

    plt.show()
    return fig, summary

In [None]:
# Compare a single sample of the batch; both tensors are expected to have
# shape [T, dim_action].
batch_idx = 0
gt_action_sparse, pred_action_sparse = (
    actions['sparse'][batch_idx] for actions in (gt_action, pred_action)
)

# Draw the comparison figure and save it into the output directory.
fig, summary = visualize_sparse_action_comparison(
    gt_action_sparse,
    pred_action_sparse,
    action_map,
    save_path=output_dir / "action_comparison.png",
)

# Log the error statistics collected for every action group.
for name, metrics in summary.items():
    logging.info(f"{name}: {metrics}")