# Buildots Rig: Minimal inference demo

This notebook loads Buildots data via **BuildotsRigAdapter**, then runs:
1. **Original MapAnything** (pretrained, outputs full 6-DoF poses)
2. **Rig MapAnything** (rig config; outputs camera translation (xyz) + yaw as a (cos, sin) pair)

**Setup:** In the cell below, set `BUILDOTS_PYCODE_PATH` to your Buildots `pycode` directory (it provides `BuildotsDataset`) and `DATA_ROOT` to the dataset root that contains the segment folders.

In [None]:
import os
import sys

# Path to your Buildots pycode (for BuildotsDataset).
# The BUILDOTS_PYCODE_PATH environment variable overrides the default below; otherwise adjust as needed.
BUILDOTS_PYCODE_PATH = os.environ.get("BUILDOTS_PYCODE_PATH", "/Users/jenia/projects/buildots/pycode")
# Buildots dataset root (segment folders).
# The BUILDOTS_DATA_ROOT environment variable overrides the default below; otherwise adjust as needed.
DATA_ROOT = os.environ.get("BUILDOTS_DATA_ROOT", "/bd-resources/jenia/dataset/buildots_da3")

# Fail loudly here rather than with an obscure ImportError / empty dataset later.
if BUILDOTS_PYCODE_PATH and not os.path.isdir(BUILDOTS_PYCODE_PATH):
    print(f"Warning: BUILDOTS_PYCODE_PATH={BUILDOTS_PYCODE_PATH!r} does not exist; adjust it above.")
if DATA_ROOT and not os.path.isdir(DATA_ROOT):
    print(f"Warning: DATA_ROOT={DATA_ROOT!r} does not exist; adjust it above.")

# Put the Buildots pycode first on the import path so `research.*` resolves in the next cell.
if BUILDOTS_PYCODE_PATH and BUILDOTS_PYCODE_PATH not in sys.path:
    sys.path.insert(0, BUILDOTS_PYCODE_PATH)

import torch
from mapanything.datasets.buildots_rig import BuildotsRigAdapter, buildots_rig_collate_fn

## 1. Load Buildots data with BuildotsRigAdapter

In [None]:
from research.positioning_net.buildots_dataset_generator import BuildotsDataset

# Sample layout: each adapter sample holds NUM_TIMESTAMPS timestamps x NUM_CAMERAS cameras.
# Keep these as the single source of truth so the printed view count cannot drift from the adapter config.
NUM_TIMESTAMPS = 3
NUM_CAMERAS = 4  # NOTE(review): rig camera count assumed here — confirm against BuildotsRigAdapter

buildots_ds = BuildotsDataset(
    root_dir=DATA_ROOT,
    seq_len=10,
    image_size=(350, 350),
    fov=90,
    debug=False,
    training=False,
)

adapter = BuildotsRigAdapter(
    buildots_ds,
    num_timestamps=NUM_TIMESTAMPS,
    data_norm_type="dinov2",  # MapAnything encoder expects dinov2
)

print(f"Adapter length: {len(adapter)}")
print(f"Views per sample: {NUM_TIMESTAMPS * NUM_CAMERAS}")

In [None]:
def add_batch_dim(views):
    """Wrap the views of a single sample as a batch of size 1 for ``model.forward``.

    Args:
        views: Iterable of per-view dicts belonging to one sample.

    Returns:
        A new list with one new dict per view (inputs are not mutated):
        tensors gain a leading dimension of size 1, strings are wrapped in a
        one-element list, and every other value is passed through unchanged.
    """

    def _batch_value(value):
        # Tensors: add the batch axis; strings: batch as a list; anything else: as-is.
        if isinstance(value, torch.Tensor):
            return value.unsqueeze(0)
        if isinstance(value, str):
            return [value]
        return value

    return [{key: _batch_value(value) for key, value in view.items()} for view in views]

# One adapter sample = 12 views (3 timestamps x 4 cameras); give it a batch dim of 1 for the model.
views_single = adapter[0]
batch_views = add_batch_dim(views_single)
print(
    f"Batch: {len(batch_views)} views, "
    f"img shape: {batch_views[0]['img'].shape}"
)

## 2. Original MapAnything (pretrained)

In [None]:
from mapanything.models import MapAnything

# Prefer the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_orig = MapAnything.from_pretrained("facebook/map-anything")
model_orig = model_orig.to(device).eval()

# Copy of batch_views with every tensor moved to `device`; non-tensor values are reused as-is.
batch_views_device = [
    {key: (val.to(device) if isinstance(val, torch.Tensor) else val) for key, val in view.items()}
    for view in batch_views
]

In [None]:
with torch.no_grad():
    preds_orig = model_orig(batch_views_device)

print("Original model outputs (per-view):")
p0 = preds_orig[0]  # prediction dict of the first view
# Keys absent from the prediction dict are reported as "N/A".
for key in ("cam_trans", "cam_quats"):
    print(f"  {key}: {p0.get(key, 'N/A')}")
if "cam_trans" in p0:
    print(f"  cam_trans shape: {p0['cam_trans'].shape}")

## 3. Rig MapAnything (xyz + yaw)

In [None]:
import os
import hydra
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf
from mapanything.models import model_factory

# Build full rig config: base MapAnything (encoder, transformer) + rig head & geometric inputs
pkg_dir = os.path.dirname(os.path.abspath(__import__("mapanything").__file__))
project_root = os.path.dirname(pkg_dir)
config_dir = os.path.join(project_root, "configs")

# hydra.initialize() only accepts a config_path relative to the calling module and raises for
# absolute paths; initialize_config_dir() is the API for an absolute config directory.
GlobalHydra.instance().clear()  # allows this cell to be re-run in the same kernel
with hydra.initialize_config_dir(version_base=None, config_dir=config_dir):
    cfg = hydra.compose(config_name="train", overrides=["model=mapanything"])
base_model_config = OmegaConf.to_container(cfg.model.model_config, resolve=True)

# Look for the rig YAML under <pkg_dir>/configs first, then under the base config dir.
# NOTE(review): confirm which location actually holds mapanything_rig.yaml.
rig_config_candidates = [
    os.path.join(pkg_dir, "configs", "model", "mapanything_rig.yaml"),
    os.path.join(config_dir, "model", "mapanything_rig.yaml"),
]
rig_config_path = next(
    (path for path in rig_config_candidates if os.path.isfile(path)),
    rig_config_candidates[0],  # neither exists: OmegaConf.load raises for the primary path
)
rig_cfg = OmegaConf.load(rig_config_path)
rig_overrides = OmegaConf.to_container(rig_cfg.model.model_config, resolve=True)
# Shallow merge: each top-level key in the rig config replaces the base entry wholesale
# (nested dicts are NOT merged key-by-key).
merged_config = {**base_model_config, **rig_overrides}

model_rig = model_factory("mapanything", **merged_config).to(device).eval()

# Optional: path to a trained rig checkpoint (None keeps the weights model_factory produced).
RIG_CHECKPOINT_PATH = None
if RIG_CHECKPOINT_PATH:
    # weights_only=False unpickles arbitrary objects — only load checkpoints you trust.
    ckpt = torch.load(RIG_CHECKPOINT_PATH, map_location="cpu", weights_only=False)
    load_result = model_rig.load_state_dict(ckpt["model"], strict=False)
    # strict=False skips mismatched keys silently, so report them.
    print(
        f"Loaded {RIG_CHECKPOINT_PATH}: "
        f"{len(load_result.missing_keys)} missing keys, "
        f"{len(load_result.unexpected_keys)} unexpected keys"
    )

In [None]:
with torch.no_grad():
    preds_rig = model_rig(batch_views_device)

print("Rig model outputs (per-view):")
p0_rig = preds_rig[0]  # prediction dict of the first view
# (label, key) pairs; cam_yaw is labelled as a (cos, sin) pair. Missing keys print "N/A".
for label, key in (("cam_trans", "cam_trans"), ("cam_yaw (cos, sin)", "cam_yaw")):
    print(f"  {label}: {p0_rig.get(key, 'N/A')}")
for key in ("cam_trans", "cam_yaw"):
    if key in p0_rig:
        print(f"  {key} shape: {p0_rig[key].shape}")

## 4. Compare (first 4 views = anchor at t=0)

In [None]:
# Per the section header, the first 4 views are the anchor timestamp t=0.
# No frame alignment or scale normalization is applied here: confirm that predictions and GT
# are expressed in the same frame before comparing the numbers directly.
NUM_ANCHOR_VIEWS = 4


def _as_printable(value):
    """Return a squeezed CPU copy of a tensor for printing; other values pass through unchanged."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().squeeze()
    return value


num_views_to_show = min(NUM_ANCHOR_VIEWS, len(preds_orig), len(preds_rig), len(batch_views))
for view_idx in range(num_views_to_show):
    pred_orig = preds_orig[view_idx]
    pred_rig = preds_rig[view_idx]
    view_gt = batch_views[view_idx]
    print(f"View {view_idx} (anchor t=0):")
    print(f"  Original cam_trans: {_as_printable(pred_orig.get('cam_trans', 'N/A'))}")
    print(f"  Rig      cam_trans: {_as_printable(pred_rig.get('cam_trans', 'N/A'))}")
    print(f"  Rig      cam_yaw:   {_as_printable(pred_rig.get('cam_yaw', 'N/A'))}")
    print(f"  GT camera_pose_trans_gt: {_as_printable(view_gt.get('camera_pose_trans_gt', 'N/A'))}")