# HOPNet Autoregressive Rollout Visualization

This notebook can be used to visualize in 3D the autoregressive rollout output of a 
pre-trained model on a single dataset sample.

Modify the [1. Sample & Model Configuration](#1-sample--model-configuration) code cell
to adjust the model and the dataset sample you want to process.

## 0. Notebook Configuration

This section should not be modified.

In [None]:
%matplotlib widget
%load_ext autoreload
%autoreload 2

# Packages import
from copy import deepcopy
import json
from os import path as opath
from sys import path as spath

spath.append(opath.join(opath.abspath(''), ".."))

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
import torch
from tqdm.auto import tqdm
import trimesh

from data.movi_dataset import MoviNormalization
from scripts.main import get_meshes, get_model, quat_multiply
from utils.complexes import compute_nodes_and_objects_positions
from utils.plots import plot_errors
from utils.rollout import (
    build_base_complex,
    build_featured_complex,
    compute_nodes_positions,
    model_input_from_ccc,
    positions_from_model_output,
    shape_matching,
)

# General configuration shared by every cell below (do not change)
H: int = 2  # Horizon
MESHES_LOCATION: str = "../data/objects/"
MESH_FILENAME: str = "collision_geometry.obj"
METADATA_FILENAME: str = "metadata.json"
COLOR_PALETTE = px.colors.qualitative.Plotly

# Physical properties and extent of the floor
FLOOR_FRICTION: float = 0.30  # Constant from Kubric
FLOOR_RESTITUTION: float = 0.50  # Constant from Kubric
FLOOR_SIZE: float = 20.0  # Minimal approximation, Kubric uses 40

# The floor is a square of half-width FLOOR_SIZE lying in the z = 0 plane.
# Corners are enumerated as (-,-), (-,+), (+,-), (+,+) and split in two triangles.
# NOTE(review): as listed, the two faces are wound in opposite directions
# (normals along -z and +z) - confirm whether anything downstream relies on
# face normals before touching the face definitions.
_floor_corners = [
    [sign_x * FLOOR_SIZE, sign_y * FLOOR_SIZE, 0]
    for sign_x in (-1, 1)
    for sign_y in (-1, 1)
]
FLOOR = trimesh.Trimesh(vertices=_floor_corners, faces=[[0, 1, 2], [1, 2, 3]])


## 1. Sample & Model Configuration

First, choose the model checkpoint that you want to test. Make sure that you select the
right normalization file (*i.e.* the one with which the model was trained).

Then, choose the dataset sample you want to process.

In [None]:
# ---------------------------------------------------------------------------
# Model: architecture name, checkpoint and the normalization it was trained with
# ---------------------------------------------------------------------------
MODEL_NAME: str = "HOPNet"
MODEL_CHECKPOINT_FILE: str = "../checkpoints/models_seed0_e39.pt"
NORMALIZATION_FILE: str = "../samples/normalization-movis.npy"

# Hyper-parameters; the values below are the defaults
COLLISION_RADIUS: float = 0.1
NUM_CHANNELS: int = 128
NUM_LAYERS: int = 1
MLP_LAYERS: int = 2

# ---------------------------------------------------------------------------
# Dataset sample: directory containing the sample's metadata file
# ---------------------------------------------------------------------------
SAMPLE_PATH: str = "../samples/MoVi-B/1"


In [None]:
# Sanity-check the configuration before anything is loaded (do not change)
# The model name must be one of the supported architectures...
assert MODEL_NAME in ("HOPNet", "NoObjectCells", "NoSequential")
# ...and every size hyper-parameter must be strictly positive.
assert all(size > 0 for size in (NUM_CHANNELS, NUM_LAYERS, MLP_LAYERS))


## 2. Load Sample Data

This section loads the sample data from the location specified in
[1. Sample & Model Configuration](#1-sample--model-configuration).

In [None]:
# Read the JSON description of the sample
metadata_path: str = opath.join(SAMPLE_PATH, METADATA_FILENAME)
with open(metadata_path) as metadata_file:
    metadata: dict = json.load(metadata_file)

# Pull out the scene-level information and the per-object records
sample_info: dict = metadata["metadata"]
num_frames: int = sample_info["num_frames"]
num_objects: int = sample_info["num_instances"]
step_rate_hz: int = sample_info["step_rate"]
objects: list[dict] = metadata["instances"]

# One mesh per object of the sample
meshes = get_meshes(objects)

# The floor is not part of the sample metadata: describe it as one more object
# that stays at the origin, with the identity quaternion and zero velocities,
# for every frame. Each trajectory gets its own freshly built list.
floor_metadata = dict(
    asset_id="floor",
    angular_velocities=[[0.0, 0.0, 0.0] for _ in range(num_frames)],
    friction=FLOOR_FRICTION,
    mass=0.0,
    positions=[[0.0, 0.0, 0.0] for _ in range(num_frames)],
    quaternions=[[1.0, 0.0, 0.0, 0.0] for _ in range(num_frames)],
    restitution=FLOOR_RESTITUTION,
    size=1.0,
    velocities=[[0.0, 0.0, 0.0] for _ in range(num_frames)],
)

# The floor always comes last, both in the records and in the meshes
objects.append(floor_metadata)
meshes.append(FLOOR)

assert len(objects) == len(meshes)


## 3. Load the Model

This section loads the model using the specified location provided in 
[1. Sample & Model Configuration](#1-sample--model-configuration).

In [None]:
# Everything in this notebook runs on the CPU
device = torch.device("cpu")

# Normalization that was used when the model was trained
normalization = MoviNormalization(NORMALIZATION_FILE)

# Instantiate the architecture, then restore the trained weights into it.
# NOTE(review): torch.load unpickles the checkpoint file, so only point
# MODEL_CHECKPOINT_FILE at checkpoints coming from a trusted source.
model = get_model(MODEL_NAME, NUM_CHANNELS, NUM_LAYERS, MLP_LAYERS).to(device)
weights = torch.load(MODEL_CHECKPOINT_FILE, map_location=device)
model.load_state_dict(weights, strict=True)


## 4. Compute and Visualize the Ground Truth

Recompute the next-step positions from the ground-truth positions and the
ground-truth accelerations, and check that they match the positions stored in the sample.

In [None]:
def _positions_from_accelerations(
    positions: np.ndarray, accelerations: np.ndarray
) -> np.ndarray:
    """Return the positions obtained by applying the target accelerations.

    Entry ``t`` (for ``t >= 1``) is
    ``accelerations[t] + 2 * positions[t] - positions[t - 1]``; entry 0 is
    left at zero because it has no previous timestep.
    """
    integrated = np.zeros_like(positions)
    integrated[1:] = accelerations[1:] + 2 * positions[1:] - positions[:-1]
    return integrated


# Positions, target accelerations and velocity features of nodes and objects.
# Note: the name `complex` shadows the Python builtin; it is kept unchanged.
(
    nodes_positions,
    nodes_target_a,
    nodes_feature_v,
    objects_positions,
    objects_target_a,
    objects_feature_v,
    complex,
    triangles_ids,
) = compute_nodes_and_objects_positions(objects, meshes, MODEL_NAME == "NoObjectCells")

# Ground-truth orientations with time as the leading axis
objects_quaternions = np.swapaxes(
    np.array([obj["quaternions"] for obj in objects]), 0, 1
)

# --- Nodes: what a perfect model would output ---
gt_nodes_positions = _positions_from_accelerations(nodes_positions, nodes_target_a)

# Window aligned with the model output, given the horizon
t_start: int = H
t_end: int = gt_nodes_positions.shape[0] - 1

# The value recomputed at time t must be the node position stored at t + 1
assert np.allclose(
    gt_nodes_positions[t_start:t_end],
    nodes_positions[t_start + 1 : t_end + 1],
)

# --- Objects: same recomputation and same check ---
gt_objs_positions = _positions_from_accelerations(objects_positions, objects_target_a)

t_start: int = H
t_end: int = gt_objs_positions.shape[0] - 1

# The value recomputed at time t must be the object position stored at t + 1
assert np.allclose(
    gt_objs_positions[t_start:t_end],
    objects_positions[t_start + 1 : t_end + 1],
)


The following cell plots the distribution of the target acceleration for nodes.

In [None]:
# Distribution of the nodes' target acceleration, one subplot per axis
axis_titles = ["X-Acceleration", "Y-acceleration", "Z-acceleration"]
accel_window = nodes_target_a[t_start:t_end]  # same time window as the checks above

fig = make_subplots(
    rows=1,
    cols=len(axis_titles),
    subplot_titles=axis_titles,
    shared_yaxes=True,
)

# One normalized histogram per spatial component
for axis_idx, axis_title in enumerate(axis_titles):
    fig.add_histogram(
        x=accel_window[:, :, axis_idx].flatten(),
        row=1,
        col=axis_idx + 1,
        histnorm="percent",
        name=axis_title,
    )

# Axis titles; the percentage axis is logarithmic on every subplot
for subplot_col in (1, 2, 3):
    fig.update_xaxes(title=r"Acceleration ms^{-2}", row=1, col=subplot_col)
fig.update_yaxes(title_text=r"Percentage [%]", row=1, col=1)
fig.update_layout(yaxis1_type="log", yaxis2_type="log", yaxis3_type="log")
fig.show()

# Summary statistics of each component over the same window
for axis_idx, axis_letter in enumerate("XYZ"):
    component = accel_window[:, :, axis_idx]
    print(f"{axis_letter}-accel mean:", np.mean(component))
    print(f"{axis_letter}-accel std :", np.std(component))


The following cells show the ground truth of the selected dataset sample. Use the
sliding bar at the bottom of the figure to navigate the timeline.

In [None]:
def plot_time(time: int, pos: np.ndarray, nodes_and_edges: bool) -> list:
    """Build the Plotly traces that draw every mesh of the scene at one timestep.

    All geometry is taken from ``pos`` so that the surface, the vertex markers
    and the triangle edges always show the same state, whether ``pos`` holds
    ground-truth or predicted positions.

    Args:
        time: Index along the first axis of ``pos``.
        pos: Node positions indexed as ``pos[time, node, xyz]``. The nodes of
            the different objects are read as consecutive slices, following
            the order of the global ``meshes`` list.
        nodes_and_edges: If True, the mesh vertices (markers) and the triangle
            edges (lines) are drawn on top of each surface.

    Returns:
        For each mesh, one ``go.Mesh3d`` trace, followed when
        ``nodes_and_edges`` is True by two ``go.Scatter3d`` traces
        (vertex markers, then triangle edges).
    """
    traces = []
    floor_idx = len(meshes) - 1  # the last mesh gets a fixed color
    node_offset = 0
    for obj_idx, mesh in enumerate(meshes):
        node_count = len(mesh.vertices)
        vertices = np.asarray(
            pos[time, node_offset : node_offset + node_count], dtype=float
        )
        node_offset += node_count

        # Cycle through the palette so large scenes cannot index past its end
        color = (
            COLOR_PALETTE[obj_idx % len(COLOR_PALETTE)]
            if obj_idx != floor_idx
            else "#1F77B4"
        )
        traces.append(
            go.Mesh3d(
                x=vertices[:, 0],
                y=vertices[:, 1],
                z=vertices[:, 2],
                i=mesh.faces[:, 0],
                j=mesh.faces[:, 1],
                k=mesh.faces[:, 2],
                showscale=True,
                color=color,
            )
        )

        if not nodes_and_edges:
            continue

        # Mesh vertices
        traces.append(
            go.Scatter3d(
                x=vertices[:, 0],
                y=vertices[:, 1],
                z=vertices[:, 2],
                mode="markers",
                marker=dict(size=2, color="black"),
            )
        )

        # Triangle edges: every face is drawn as the closed loop 0-1-2-0. The
        # loops are separated by a NaN row, which Plotly renders as a gap, so a
        # single trace per mesh is enough instead of one trace per triangle.
        faces = np.asarray(mesh.faces)
        loops = vertices[faces[:, [0, 1, 2, 0]]]  # (num_faces, 4, 3)
        gaps = np.full((len(faces), 1, 3), np.nan)
        outline = np.concatenate((loops, gaps), axis=1).reshape(-1, 3)
        traces.append(
            go.Scatter3d(
                x=outline[:, 0],
                y=outline[:, 1],
                z=outline[:, 2],
                mode="lines",
                line=dict(color="black", width=2),
            )
        )
    return traces


In [None]:
# Last frame (exclusive) of the animation. The preferred value is 478, but it
# is capped by the number of available timesteps so that shorter samples do
# not index past the end of the ground-truth arrays.
t_end = min(478, gt_nodes_positions.shape[0] - 1)

# Edges and vertices are only drawn for samples that are not from MoVi-B
show_nodes_and_edges = "MoVi-B" not in SAMPLE_PATH

# One slider step per animated frame; each step jumps to the frame whose name
# matches the timestep.
# NOTE(review): Plotly durations are expressed in milliseconds, so 1 / 240 is
# effectively zero - confirm this is the intended frame duration.
slider_steps = [
    {
        "args": [
            [time],
            {
                "frame": {"duration": 1 / 240, "redraw": True},
                "mode": "immediate",
                "transition": {"duration": 300},
            },
        ],
        "label": time,
        "method": "animate",
    }
    for time in range(t_start + 1, t_end)
]

sliders_dict = {
    "active": 0,
    "yanchor": "top",
    "xanchor": "left",
    "transition": {"duration": t_end - t_start, "easing": "cubic-in-out"},
    "pad": {"b": 10, "t": 50},
    "len": 0.9,
    "x": 0.1,
    "y": 0,
    "steps": slider_steps,
}

# The figure initially shows timestep t_start; the frames cover the rest
frames = [
    go.Frame(
        data=plot_time(t, gt_nodes_positions, show_nodes_and_edges), name=str(t)
    )
    for t in range(t_start + 1, t_end)
]
fig = go.Figure(
    data=plot_time(t_start, gt_nodes_positions, show_nodes_and_edges),
    frames=frames,
)
fig.update_layout(
    # Fixed cubic viewport so the camera does not rescale between frames
    scene=dict(
        xaxis=dict(
            nticks=4,
            range=[-8, 8],
        ),
        yaxis=dict(
            nticks=4,
            range=[-8, 8],
        ),
        zaxis=dict(
            nticks=4,
            range=[-8, 8],
        ),
    ),
    height=1200,
    margin=dict(r=20, l=10, b=10, t=10),
    updatemenus=[
        dict(
            type="buttons",
            buttons=[
                {
                    # Play every frame, starting from the one currently shown
                    "args": [
                        None,
                        {
                            "frame": {"duration": 0, "redraw": True},
                            "fromcurrent": True,
                        },
                    ],
                    "label": "Play",
                    "method": "animate",
                },
                {
                    # Animating to [None] interrupts the running animation
                    "args": [
                        [None],
                        {
                            "frame": {"duration": 0, "redraw": False},
                            "mode": "immediate",
                            "transition": {"duration": 0},
                        },
                    ],
                    "label": "Pause",
                    "method": "animate",
                },
            ],
        )
    ],
    sliders=[sliders_dict],
)

fig.show()


## 5. Compute Autoregressive Rollout

This section computes the auto-regressive rollout on the selected sample using the
selected model.

In [None]:
# Rollout window
START_TIME: int = 100  # First ground-truth frame fed to the rollout
ROLLOUT_DURATION: int = 100  # Number of autoregressive steps to predict

# Both values must stay above their minimum
assert 0 < START_TIME
assert 3 < ROLLOUT_DURATION


In [None]:
# Step 0: Build the static part of the complex and seed the rollout with the
# ground-truth poses of the first H + 2 frames starting at START_TIME
objects_init = deepcopy(objects)
base_ccc, triangle_ids, obj_idx_from_node_idx = build_base_complex(
    meshes, objects, MODEL_NAME == "NoObjectCells"
)

# Store the trajectories of the copy as plain nested lists. No truncation
# happens here: the seed window is selected by the slices below.
for obj in objects_init:
    obj["positions"] = np.array(obj["positions"]).tolist()
    obj["quaternions"] = np.array(obj["quaternions"]).tolist()

# Initial object positions
obj_pos = np.array(
    [
        np.array(obj["positions"])[START_TIME : START_TIME + H + 2]
        for obj in objects_init
    ]
)
obj_pos = np.moveaxis(obj_pos, 0, 1)  # Shape [timesteps, obj_count, 3]

# Initial object orientations
obj_quat = np.array(
    [
        np.array(obj["quaternions"])[START_TIME : START_TIME + H + 2]
        for obj in objects_init
    ]
)
obj_quat = np.moveaxis(obj_quat, 0, 1)  # Shape [timesteps, obj_count, 4]

# The first iteration below reads frames 1 to H + 1 of the seed window, so
# fail early with a clear message if START_TIME is too close to the end.
assert obj_pos.shape[0] == H + 2 and obj_quat.shape[0] == H + 2, (
    "START_TIME leaves fewer than H + 2 ground-truth frames to seed the rollout"
)

# Ground truth from metadata.json
objects_quaternions = np.array([o["quaternions"] for o in objects])
objects_quaternions = np.swapaxes(objects_quaternions, 0, 1)

# Inference mode only needs to be selected once, not at every step
model.eval()

# Run the rollout for a certain number of timesteps
for t in tqdm(range(H + 1, H + 1 + ROLLOUT_DURATION)):
    # Step 1: Create the CCC from the last H + 1 known poses
    ccc, nodes_pos, _ = build_featured_complex(
        base_ccc,
        None,
        triangle_ids,
        obj_idx_from_node_idx,
        obj_pos[t - H : t + 1],
        obj_quat[t - H : t + 1],
        objects,
        meshes,
        MODEL_NAME == "NoObjectCells",
        H,
        COLLISION_RADIUS,
    )

    # Step 2: Model inference
    with torch.no_grad():
        out0, out4 = model(
            *model_input_from_ccc(
                ccc,
                horizon=H,
                norm=normalization,
                model=model,
                model_name=MODEL_NAME,
                device=device,
            )
        )

    # Step 3: Compute predicted positions from accelerations
    pred_nodes_pos = positions_from_model_output(
        out0, normalization, nodes_pos[-2:]
    )

    # Step 4: Compute shape matching for the new object rotation
    new_obj_pos, new_obj_quat = shape_matching(pred_nodes_pos, meshes)

    # Step 5: Append the predicted pose so the next step can consume it
    obj_pos = np.concatenate((obj_pos, np.expand_dims(new_obj_pos, 0)), axis=0)
    obj_quat = np.concatenate((obj_quat, np.expand_dims(new_obj_quat, 0)), axis=0)

# Node positions for every frame of the rollout (seed frames + predictions)
nodes_pos = np.stack(
    [
        compute_nodes_positions(meshes, obj_pos[t], obj_quat[t])
        for t in range(obj_pos.shape[0])
    ]
)


## 6. Visualize the Autoregressive Rollout

This section shows the autoregressive rollout. Use the sliding bar at the bottom of the
figure to navigate the timeline.

In [None]:
def _rollout_slider_step(frame_id: int) -> dict:
    """Return the slider step that jumps to the rollout frame ``frame_id``."""
    animate_options = dict(
        frame=dict(duration=1 / 240, redraw=True),
        mode="immediate",
        transition=dict(duration=300),
    )
    return dict(args=[[frame_id], animate_options], label=frame_id, method="animate")


# The rollout timeline starts at index 0 and spans every computed frame
num_rollout_frames = nodes_pos.shape[0]
show_nodes_and_edges = "MoVi-B" not in SAMPLE_PATH

slider_steps = [_rollout_slider_step(frame_id) for frame_id in range(num_rollout_frames)]

sliders_dict = dict(
    active=0,
    yanchor="top",
    xanchor="left",
    transition=dict(duration=t_end - t_start, easing="cubic-in-out"),
    pad=dict(b=10, t=50),
    len=0.9,
    x=0.1,
    y=0,
    steps=slider_steps,
)

# One frame per rollout timestep; the figure itself starts on the first one
frames3 = [
    go.Frame(
        data=plot_time(frame_id, nodes_pos, show_nodes_and_edges),
        name=str(frame_id),
    )
    for frame_id in range(num_rollout_frames)
]
fig3 = go.Figure(
    data=plot_time(0, nodes_pos, show_nodes_and_edges), frames=frames3
)

# Play runs through the frames from the current one; Pause animates to [None]
play_button = dict(
    args=[None, dict(frame=dict(duration=0, redraw=True), fromcurrent=True)],
    label="Play",
    method="animate",
)
pause_button = dict(
    args=[
        [None],
        dict(
            frame=dict(duration=0, redraw=False),
            mode="immediate",
            transition=dict(duration=0),
        ),
    ],
    label="Pause",
    method="animate",
)

fig3.update_layout(
    # Same fixed viewport on the three axes
    scene={
        axis_name: dict(nticks=4, range=[-8, 8])
        for axis_name in ("xaxis", "yaxis", "zaxis")
    },
    height=1200,
    margin=dict(r=20, l=10, b=10, t=10),
    updatemenus=[dict(type="buttons", buttons=[play_button, pause_button])],
    sliders=[sliders_dict],
)

fig3.show()


In [None]:
# Number of frames held by the rollout arrays: H + 2 seed frames followed by
# ROLLOUT_DURATION predicted frames
num_eval_frames: int = ROLLOUT_DURATION + H + 2
gt_window = slice(START_TIME, START_TIME + num_eval_frames)
gt_positions = objects_positions[gt_window]
gt_quaternions = objects_quaternions[gt_window]

# Fail with a clear message rather than an obscure (or silent) broadcast when
# the rollout and the ground truth do not cover the same frames
assert obj_pos.shape[0] == num_eval_frames, (
    "The rollout arrays do not match START_TIME / ROLLOUT_DURATION: "
    "re-run the rollout cell"
)
assert gt_positions.shape[0] == num_eval_frames, (
    "The ground truth ends before the rollout window does: "
    "reduce START_TIME or ROLLOUT_DURATION"
)

# Rows 0-3 hold position MAE, position RMSE, orientation MAE and orientation
# RMSE for each frame; row 4 is left at zero by this cell
all_errors = np.zeros((1, 1, 5, num_eval_frames))

# Compute position RMSE and MAE over the objects, for each frame
pos_dist = np.linalg.norm(obj_pos - gt_positions, axis=-1)
pos_mae = np.mean(pos_dist, axis=-1)
pos_rmse = np.sqrt(np.mean(pos_dist**2, axis=-1))
all_errors[0, 0, 0] = pos_mae
all_errors[0, 0, 1] = pos_rmse

# Compute orientation RMSE and MAE: the rotation between ground truth and
# prediction is ground_truth * conj(prediction), and its angle is
# 2 * arcsin(norm of the vector part)
obj_quat_conj = deepcopy(obj_quat)
obj_quat_conj[:, :, 1:4] *= -1
rel_vec_norm = np.linalg.norm(
    quat_multiply(gt_quaternions, obj_quat_conj)[:, :, 1:],
    axis=-1,
)
# Rounding can push the norm slightly above 1, which would make arcsin return
# NaN and silently drop the largest errors from the means below
ori_err = np.degrees(2 * np.arcsin(np.clip(rel_vec_norm, 0.0, 1.0)))
ori_rmse = np.sqrt(np.nanmean(ori_err**2, axis=-1))
ori_mae = np.nanmean(ori_err, axis=-1)
all_errors[0, 0, 2] = ori_mae
all_errors[0, 0, 3] = ori_rmse

fig = plot_errors(all_errors)
fig.show()
