## Installing packages

In [None]:
# Install trimesh (used below to load the .ply object models) and upgrade ipywidgets.
!pip install trimesh
!pip install --upgrade ipywidgets

Collecting trimesh
  Downloading trimesh-4.6.10-py3-none-any.whl.metadata (18 kB)
Downloading trimesh-4.6.10-py3-none-any.whl (711 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/711.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m711.2/711.2 kB[0m [31m27.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: trimesh
Successfully installed trimesh-4.6.10
Collecting ipywidgets
  Downloading ipywidgets-8.1.7-py3-none-any.whl.metadata (2.4 kB)
Collecting comm>=0.1.3 (from ipywidgets)
  Downloading comm-0.2.2-py3-none-any.whl.metadata (3.7 kB)
Collecting widgetsnbextension~=4.0.14 (from ipywidgets)
  Downloading widgetsnbextension-4.0.14-py3-none-any.whl.metadata (1.6 kB)
Collecting jedi>=0.16 (from ipython>=6.1.0->ipywidgets)
  Downloading jedi-0.19.2-py2.py3-none-any.whl.metadata (22 kB)
Downloading ipywidgets-8.1.7-py3-none-any.whl (139 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

## Mounting Google Drive

In [None]:
import os
from google.colab import drive

# Colab mount point for Google Drive, and the project directory relative to it.
base_path, folder_path = (
    '/content/drive',
    'MyDrive/Colab Notebooks/Machine Learning and Deep Learning/Project',
)

# Absolute path of the project directory once Drive is mounted.
full_project_path = os.path.join(base_path, folder_path)

# force_remount=True mounts again even if Drive is already mounted in this runtime.
drive.mount(base_path, force_remount=True)

Mounted at /content/drive


## Creating a custom dataset

In [None]:
import os
import yaml
import numpy as np
import json
import torch
import trimesh
from torch.utils.data import Dataset
from torchvision import transforms
from sklearn.model_selection import train_test_split
from PIL import Image

class PoseEstimationDataset_RGBD(Dataset):
    """
    Custom PyTorch Dataset class for loading RGB-D images and pose annotations
    from the LineMOD dataset for 6D object pose estimation tasks.

    Each item is one scene image. Every annotated object in it is cropped from
    both the RGB and the depth image, resized to ``img_size`` and returned with
    its ground-truth pose and the camera intrinsics adjusted to the crop.

    ``img_size`` follows PIL's (width, height) convention throughout.
    """
    def __init__(self, dataset_root,
                 models_root,
                 folders=None,
                 split='train',
                 train_ratio=0.8,
                 seed=42,
                 img_size=(224, 224)):
        """
        Initialize dataset by loading metadata and preparing the image/pose samples.

        Args:
            dataset_root (str): Path to the root of the RGB-D dataset
                (one zero-padded sub-folder per object, e.g. ``01``).
            models_root (str): Path to the directory containing 3D object models and models_info.yml.
            folders (iterable of int or None): Object folders to include. None
                (default) means folders 1-15. Folders without a gt.yml are skipped.
            split (str): 'train' selects the training portion of the split; any
                other value (e.g. 'val' or 'test') selects the held-out portion.
            train_ratio (float): Fraction of samples assigned to the training portion.
            seed (int): Random seed for train/test split reproducibility.
            img_size (tuple): Target crop size (width, height) for network input.

        Raises:
            ValueError: If no samples are found under ``dataset_root``.
        """
        if folders is None:
            # Built per call to avoid a shared mutable default argument.
            folders = list(range(1, 16))

        self.dataset_root = dataset_root
        self.models_root = models_root
        self.split = split
        self.train_ratio = train_ratio
        self.seed = seed
        self.img_size = img_size
        self.models = {}          # Cache: object id -> model vertices (see load_3D_model).
        self.invalid_entries = 0  # Number of annotations skipped by cropImages.

        # Load model metadata (e.g., object dimensions) from YAML.
        self.models_info_path = os.path.join(models_root, 'models_info.yml')
        with open(self.models_info_path, 'r') as f:
            self.models_info = yaml.safe_load(f)

        # Caches for efficient loading.
        self.gt_data = {}      # Object folder -> parsed gt.yml (ground-truth annotations).
        self.info_data = {}    # Object folder -> parsed info.yml (camera intrinsics).
        self.all_samples = []  # All available (object_folder, sample_id) pairs.

        # Load gt.yml and info.yml for each object folder.
        for obj_id in folders:
            obj_folder = os.path.join(dataset_root, f"{obj_id:02d}")
            gt_path = os.path.join(obj_folder, 'gt.yml')
            info_path = os.path.join(obj_folder, 'info.yml')

            if not os.path.exists(gt_path):
                continue

            with open(gt_path, 'r') as f:
                gt = yaml.safe_load(f)
            with open(info_path, 'r') as f:
                info = yaml.safe_load(f)

            self.gt_data[obj_id] = gt
            self.info_data[obj_id] = info

            for sample_id in gt.keys():
                self.all_samples.append((obj_id, int(sample_id)))

        if not self.all_samples:
            raise ValueError(f"No samples found in {dataset_root}. Check dataset structure.")

        # Build a mapping between the original object IDs found in the annotations
        # and contiguous internal indices (used e.g. for the embedding layer).
        object_id_set = set()
        for obj_id, sample_id in self.all_samples:
            for ann in self.gt_data[obj_id][sample_id]:
                object_id_set.add(ann['obj_id'])
        self.object_ids = sorted(object_id_set)
        self.id_to_idx = {obj_id: i for i, obj_id in enumerate(self.object_ids)}
        self.idx_to_id = {i: obj_id for obj_id, i in self.id_to_idx.items()}

        # Split dataset into training and held-out sets (deterministic given seed).
        train, test = train_test_split(self.all_samples, train_size=self.train_ratio, random_state=self.seed)
        self.samples = train if self.split == 'train' else test

        # Image preprocessing pipeline. torchvision's Resize expects (height, width),
        # whereas img_size is (width, height) as used by PIL's resize in cropImages.
        self.transform = transforms.Compose([
            transforms.Resize((self.img_size[1], self.img_size[0])),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.samples)

    def save_mapping(self, filepath=None):
        """Save the original-object-ID -> index mapping as JSON (default: object_id_mapping.json)."""
        if filepath is None:
            filepath = "object_id_mapping.json"
        with open(filepath, "w") as f:
            json.dump(self.id_to_idx, f)

    @staticmethod
    def load_mapping(filepath=None):
        """
        Load an object-ID -> index mapping written by save_mapping.

        JSON object keys are always strings, so they are converted back to int.
        """
        if filepath is None:
            filepath = "object_id_mapping.json"
        with open(filepath, "r") as f:
            mapping = json.load(f)
        return {int(k): v for k, v in mapping.items()}

    def printIDMapping(self):
        """Print human-readable object ID mapping."""
        print("Object ID Mapping (Original → Mapped):")
        for orig_id in self.object_ids:
            print(f"  {orig_id:02d} → {self.id_to_idx[orig_id]}")

    def nrInvalidObjects(self):
        """Return the number of annotations skipped so far because their box lies outside the image."""
        return self.invalid_entries

    def getMappedIDs(self, ids=None):
        """
        Maps original object IDs to their training indices.

        Args:
            ids (list or None): Original object IDs. If None, returns all mappings.
                Unknown IDs are reported with a warning and skipped.
        Returns:
            mapped_ids (list): Mapped training indices.
            orig_ids (list): Corresponding original object IDs.
        """
        mapped_ids = []
        orig_ids = []
        if ids is None:
            for orig_id in self.object_ids:
                mapped_ids.append(self.id_to_idx[orig_id])
                orig_ids.append(orig_id)
        else:
            for id in ids:
                if id in self.id_to_idx:
                    mapped_ids.append(self.id_to_idx[id])
                    orig_ids.append(id)
                else:
                    print(f"⚠️ Warning: Object ID {id} not found in dataset and will be ignored.")
        return mapped_ids, orig_ids

    def get_model_info(self, object_id):
        """
        Return the models_info.yml entry (e.g. diameter) for ``object_id``.

        Raises:
            ValueError: If ``object_id`` has no entry in models_info.yml.
        """
        if object_id not in self.models_info:
            raise ValueError(f"Object ID {object_id} not in models_info.yml")
        return self.models_info[object_id]

    def load_3D_model(self, object_id):
        """
        Load the vertices of ``obj_XX.ply`` as a float32 array scaled by 1/1000
        (presumably millimetres -> metres). Results are cached per object ID.
        """
        if object_id in self.models:
            return self.models[object_id]

        model_path = os.path.join(self.models_root, f"obj_{object_id:02d}.ply")
        mesh = trimesh.load(model_path)
        points = mesh.vertices.astype(np.float32) / 1000.0
        self.models[object_id] = points
        return points

    def cropImages(self, image, annotations, cam_K, depth_image=None):
        """
        Crop every annotated object from the RGB (and optional depth) image and
        adjust the camera intrinsics to each resized crop.

        Args:
            image (PIL.Image): Original RGB image.
            annotations (list): Object annotations from gt.yml (uses 'obj_bb',
                'cam_R_m2c', 'cam_t_m2c' and 'obj_id').
            cam_K (torch.Tensor): Original 3x3 camera intrinsic matrix (not modified).
            depth_image (PIL.Image or None): Depth image aligned with ``image``.

        Returns:
            crop_entries (list): One dict per valid crop with the cropped tensors,
                adjusted intrinsics, normalized bbox and ground-truth pose.
            cam_K (torch.Tensor): The original intrinsic matrix, unchanged.
        """
        crop_entries = []
        width, height = image.size

        for ann in annotations:
            x, y, w, h = ann['obj_bb']
            # Clamp the box to the image; boxes with no overlap are counted and skipped.
            x1 = max(0, x)
            y1 = max(0, y)
            x2 = min(width, x + w)
            y2 = min(height, y + h)

            if x2 <= x1 or y2 <= y1:
                self.invalid_entries += 1
                continue

            cropped = image.crop((x1, y1, x2, y2))
            original_crop_width, original_crop_height = cropped.size

            # Shift the principal point into the crop's pixel frame.
            cropped_K = cam_K.clone()
            cropped_K[0, 2] -= x1
            cropped_K[1, 2] -= y1

            # Resize crop to img_size (width, height) and scale the intrinsics to match.
            cropped = cropped.resize(self.img_size, Image.BILINEAR)
            scale_x = self.img_size[0] / original_crop_width
            scale_y = self.img_size[1] / original_crop_height
            cropped_K[0, 0] *= scale_x
            cropped_K[0, 2] *= scale_x
            cropped_K[1, 1] *= scale_y
            cropped_K[1, 2] *= scale_y

            cropped_rgb_tensor = self.transform(cropped)

            # Ground-truth pose; translation is divided by 1000 (presumably mm -> m).
            R_mat = np.array(ann['cam_R_m2c'], dtype=np.float32).reshape(3, 3)
            t_vec = np.array(ann['cam_t_m2c'], dtype=np.float32) / 1000.0

            # Crop the depth image with the same box, if one was given.
            if depth_image is not None:
                cropped_depth = depth_image.crop((x1, y1, x2, y2))
                cropped_depth = cropped_depth.resize(self.img_size, Image.BILINEAR)
                cropped_depth_tensor = transforms.ToTensor()(cropped_depth)
            else:
                cropped_depth_tensor = None

            # Clamped box (x, y, w, h) normalized by the ORIGINAL image size to [0, 1].
            norm_bbox = torch.tensor([
                x1 / width,
                y1 / height,
                (x2 - x1) / width,
                (y2 - y1) / height
            ], dtype=torch.float32)

            crop_entries.append({
                'cropped_rgb': cropped_rgb_tensor,
                'cropped_depth': cropped_depth_tensor,
                'cropped_K': cropped_K,
                'object_id': ann['obj_id'],
                'bbox': ann['obj_bb'],
                'norm_bbox': norm_bbox,
                'rotation': R_mat,
                'translation': t_vec
            })

        return crop_entries, cam_K

    def __getitem__(self, idx):
        """
        Load an RGB-D sample and return a dict with:
        - Original RGB image, intrinsics and normalized depth array
        - Cropped objects with pose and camera intrinsics (object IDs mapped to indices)

        Raises:
            FileNotFoundError: If the RGB or depth image is missing.
            ValueError: If an annotated object ID is absent from the ID mapping.
        """
        object_id, sample_id = self.samples[idx]
        folder = os.path.join(self.dataset_root, f"{object_id:02d}")

        rgb_path = os.path.join(folder, 'rgb', f"{sample_id:04d}.png")
        if not os.path.exists(rgb_path):
            raise FileNotFoundError(f"RGB image not found: {rgb_path}")
        rgb = Image.open(rgb_path).convert("RGB")

        depth_path = os.path.join(folder, 'depth', f"{sample_id:04d}.png")
        if not os.path.exists(depth_path):
            raise FileNotFoundError(f"Depth image not found: {depth_path}")
        depth = Image.open(depth_path)
        depth_np = np.array(depth).astype(np.float32)
        # Clip raw depth to [0, 2000] and scale to [0, 1] (presumably millimetres,
        # i.e. 0-2 m — TODO confirm), then quantize to 8 bits so it can be cropped
        # and resized as a PIL image.
        depth_np = np.clip(depth_np, 0, 2000) / 2000.0
        depth_img = Image.fromarray((depth_np * 255).astype(np.uint8))

        annotations = self.gt_data[object_id][sample_id]
        cam_K = torch.tensor(np.array(self.info_data[object_id][sample_id]['cam_K']).reshape(3, 3), dtype=torch.float32)

        crop_entries, org_K = self.cropImages(rgb, annotations, cam_K, depth_image=depth_img)

        # Map object IDs to internal indices.
        for entry in crop_entries:
            true_obj_id = entry['object_id']
            if true_obj_id not in self.id_to_idx:
                raise ValueError(f"Object ID {true_obj_id} not found in id_to_idx mapping.")
            entry['object_id'] = self.id_to_idx[true_obj_id]

        return {
            'sample_id': sample_id,
            'original_rgb': rgb,
            'original_K': org_K,
            'original_depth': depth_np,
            'objects': crop_entries
        }


# Alias: later cells refer to the dataset class as `PoseEstimationDataset`
# (e.g. PoseEstimationDataset.load_mapping), which is otherwise undefined.
PoseEstimationDataset = PoseEstimationDataset_RGBD

## PoseNet6D using RGB-D data

In [None]:
import torch.nn.functional as F
from torchvision.models import resnet18, ResNet18_Weights
from torch import nn


class PoseNet_RGBD(nn.Module):
    """
    Neural network for 6D object pose estimation using RGB-D input.
    Takes cropped RGB-D images, normalized 2D bounding box, intrinsic matrix,
    and object ID to predict rotation (quaternion) and translation (3D position).

    Only the depth (Z) of the translation is regressed; X and Y are recovered by
    back-projecting the bounding-box centre through the crop intrinsics.
    """
    def __init__(self, num_objects,
                 embedding_dim=16,
                 img_size=(224, 224),
                 weights=ResNet18_Weights.DEFAULT):
        """
        Initialize PoseNet_RGBD architecture.

        Args:
            num_objects (int): Number of distinct object classes.
            embedding_dim (int): Size of the object ID embedding vector.
            img_size (tuple): Size (width, height) of the input RGB and depth crops.
            weights (ResNet18_Weights): Pretrained weights for ResNet18 RGB encoder.
        """
        super().__init__()
        self.img_size = img_size

        # === RGB Encoder ===
        # Use pretrained ResNet18, discard the final pooling and classification layers.
        rgb_backbone = resnet18(weights=weights)
        self.rgb_encoder = nn.Sequential(*list(rgb_backbone.children())[:-2])  # output: (B, 512, 7, 7) for 224x224 input.
        self.global_pool_rgb = nn.AdaptiveAvgPool2d(1)  # (B, 512, H', W') -> (B, 512, 1, 1); flattened in forward.

        # === Depth Encoder ===
        # A lightweight CNN for single-channel depth input (sizes below assume 224x224).
        self.depth_encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),  # -> 112x112
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),  # -> 56x56
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # -> (B, 128, 1, 1)
            nn.Flatten()              # -> (B, 128)
        )

        # === Object ID Embedding ===
        # Learn a small dense vector representation for each object ID.
        self.obj_embedding = nn.Embedding(num_embeddings=num_objects,
                                          embedding_dim=embedding_dim)

        # === Feature Fusion ===
        # Concatenate RGB + depth + bbox + obj_embedding vectors.
        fused_dim = 512 + 128 + 4 + embedding_dim

        # Fully connected layers for depth (Z translation) prediction.
        self.fc_depth = nn.Sequential(
            nn.Linear(fused_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )

        # Fully connected layers for rotation (quaternion) prediction.
        self.fc_rotation = nn.Sequential(
            nn.Linear(fused_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 4)  # Quaternion output.
        )

    def forward(self, x_rgb, x_depth, norm_bbox, K_crop, object_id):
        """
        Forward pass to predict 6D pose.

        Args:
            x_rgb (Tensor): RGB image crop (B, 3, H, W)
            x_depth (Tensor): Depth image crop (B, 1, H, W)
            norm_bbox (Tensor): Bounding boxes normalized by the original image size (B, 4)
            K_crop (Tensor): Camera intrinsics in the pixel frame of the resized crop (B, 3, 3)
            object_id (Tensor): Object class indices (B,)

        Returns:
            translation (Tensor): Predicted 3D translation vector (B, 3)
            quat (Tensor): Predicted rotation as unit quaternion (B, 4)
        """
        B = x_rgb.shape[0]

        # === Extract RGB features ===
        rgb_feat = self.rgb_encoder(x_rgb)                      # (B, 512, 7, 7)
        rgb_feat = self.global_pool_rgb(rgb_feat).view(B, -1)   # (B, 512)

        # === Extract depth features ===
        depth_feat = self.depth_encoder(x_depth)                # (B, 128)

        # === Object ID embedding ===
        object_id = object_id.to(self.obj_embedding.weight.device)
        obj_feat = self.obj_embedding(object_id)                # (B, embedding_dim)

        # === Concatenate all features ===
        x = torch.cat([rgb_feat, depth_feat, norm_bbox, obj_feat], dim=1)  # (B, fused_dim)

        # === Predict object depth (Z translation) ===
        depth = self.fc_depth(x).squeeze(1)                     # (B,)
        depth = torch.clamp(depth, min=0.1, max=1.5)            # limit predictions to valid range.

        # === Predict rotation as normalized quaternion ===
        quat = F.normalize(self.fc_rotation(x), dim=1)          # (B, 4)

        # === Compute translation (X, Y, Z) in camera frame ===
        fx = K_crop[:, 0, 0]       # Focal length in x-direction (crop frame).
        fy = K_crop[:, 1, 1]       # Focal length in y-direction (crop frame).
        cx_crop = K_crop[:, 0, 2]  # Principal point x-coordinate (crop frame).
        cy_crop = K_crop[:, 1, 2]  # Principal point y-coordinate (crop frame).

        # Bounding-box centre in the SAME pixel frame as K_crop. This assumes K_crop
        # comes from a crop spanning exactly the (clamped) bbox, resized to img_size
        # (as the dataset's cropImages does), so the centre is the crop's centre.
        # norm_bbox is normalized by the original image size and must not be mixed
        # with K_crop here.
        img_w, img_h = self.img_size
        u = torch.full_like(fx, 0.5 * img_w)
        v = torch.full_like(fy, 0.5 * img_h)

        # Back-project 2D bbox center + predicted depth to 3D camera coordinates.
        x_cam = (u - cx_crop) * depth / fx  # X = (u - cx) * Z / fx.
        y_cam = (v - cy_crop) * depth / fy  # Y = (v - cy) * Z / fy.
        z_cam = depth                       # Z is directly predicted.

        # Stack into final 3D translation vector (X, Y, Z).
        translation = torch.stack([x_cam, y_cam, z_cam], dim=1)  # (B, 3)

        return translation, quat

## Helper functions

In [None]:
import torch.nn.functional as F
from scipy.spatial.transform import Rotation as R
from scipy.spatial import cKDTree

def computeADD(R_pred, t_pred, R_gt, t_gt, model_points):
    """
    ADD metric: mean distance between the model points posed with the predicted
    pose and the same points posed with the ground-truth pose.

    Every argument may be a NumPy array or a torch Tensor (tensors are detached
    and moved to the CPU first).

    Args:
        R_pred: Predicted 3x3 rotation matrix.
        t_pred: Predicted translation vector (3,).
        R_gt: Ground-truth 3x3 rotation matrix.
        t_gt: Ground-truth translation vector (3,).
        model_points: Object model points (N, 3).

    Returns:
        float: Average Euclidean distance over the N points.
    """
    def _as_numpy(value):
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return value

    rot_p, trans_p, rot_g, trans_g = (_as_numpy(v) for v in (R_pred, t_pred, R_gt, t_gt))
    points = _as_numpy(model_points)

    # Row-vector convention: p' = p @ R^T + t.
    posed_pred = points @ rot_p.T + trans_p
    posed_gt = points @ rot_g.T + trans_g

    per_point = np.linalg.norm(posed_pred - posed_gt, axis=1)
    return per_point.mean()

def matrix_to_quaternion_batch(rotation_matrix):
    """
    Converts a batch of rotation matrices to quaternions.

    Uses Shepperd's method: the quaternion is reconstructed from its component
    with the largest magnitude, so the result is accurate for every rotation,
    including rotations near 180 degrees where w -> 0. A trace-only formula
    divides by ~0 there, and e.g. diag(1, -1, -1) collapses to the identity.

    Args:
        rotation_matrix (Tensor): (B, 3, 3) or (3, 3) rotation matrix.

    Returns:
        Tensor: (B, 4) batch of normalized quaternions (w, x, y, z) with w >= 0.
    """
    if rotation_matrix.dim() == 2:
        rotation_matrix = rotation_matrix.unsqueeze(0)  # Ensure batch dimension.

    m = rotation_matrix
    m00, m01, m02 = m[:, 0, 0], m[:, 0, 1], m[:, 0, 2]
    m10, m11, m12 = m[:, 1, 0], m[:, 1, 1], m[:, 1, 2]
    m20, m21, m22 = m[:, 2, 0], m[:, 2, 1], m[:, 2, 2]

    # 4w^2, 4x^2, 4y^2, 4z^2; clamped because slightly non-orthonormal input can
    # make these marginally negative.
    q_sq4 = torch.stack([
        1.0 + m00 + m11 + m22,
        1.0 + m00 - m11 - m22,
        1.0 - m00 + m11 - m22,
        1.0 - m00 - m11 + m22,
    ], dim=1).clamp(min=0.0)
    q_abs2 = torch.sqrt(q_sq4)  # 2|w|, 2|x|, 2|y|, 2|z|

    # Row i equals 4 * q_i * (w, x, y, z).
    candidates = torch.stack([
        torch.stack([q_sq4[:, 0], m21 - m12, m02 - m20, m10 - m01], dim=1),
        torch.stack([m21 - m12, q_sq4[:, 1], m10 + m01, m02 + m20], dim=1),
        torch.stack([m02 - m20, m10 + m01, q_sq4[:, 2], m12 + m21], dim=1),
        torch.stack([m10 - m01, m20 + m02, m21 + m12, q_sq4[:, 3]], dim=1),
    ], dim=1)  # (B, 4, 4)

    # Use the row of the largest component: |q_i| >= 1/2 there, so the divisor
    # 4|q_i| is >= 2 and never close to zero.
    best = q_abs2.argmax(dim=1)
    rows = torch.arange(m.size(0), device=m.device)
    quaternions = candidates[rows, best] / (2.0 * q_abs2[rows, best]).unsqueeze(1)

    # q and -q encode the same rotation; keep w >= 0.
    quaternions = torch.where(quaternions[:, :1] < 0, -quaternions, quaternions)

    # Normalize result quaternion.
    return F.normalize(quaternions, dim=1)

def quaternion_to_matrix_batch(quat):
    """
    Converts a batch of quaternions to rotation matrices.

    Args:
        quat (Tensor): (B, 4) or (4,) quaternions in (w, x, y, z) format
            (assumed to be unit length).

    Returns:
        Tensor: (B, 3, 3) for batched input — including B == 1, so batched
        callers always get a batch dimension — or (3, 3) for a single (4,)
        quaternion. The dtype and device follow ``quat``.
    """
    is_single = quat.dim() == 1
    if is_single:
        quat = quat.unsqueeze(0)  # Ensure batch dimension.

    w, x, y, z = quat.unbind(dim=1)

    # Row-major entries of the rotation matrix; built with stack (no in-place writes).
    rot = torch.stack([
        1 - 2 * (y ** 2 + z ** 2), 2 * (x * y - z * w),       2 * (x * z + y * w),
        2 * (x * y + z * w),       1 - 2 * (x ** 2 + z ** 2), 2 * (y * z - x * w),
        2 * (x * z - y * w),       2 * (y * z + x * w),       1 - 2 * (x ** 2 + y ** 2),
    ], dim=1).view(quat.size(0), 3, 3)

    # Only drop the batch dimension that was added above.
    return rot.squeeze(0) if is_single else rot

def quaternion_loss(quat_pred, quat_gt):
    """
    Sign-invariant rotation loss between two batches of quaternions.

    Both inputs are re-normalized. With d = <q_pred, q_gt> = ±cos(theta/2), the
    loss is mean(1 - d^2) = mean(sin^2(theta/2)), so q and -q score the same.

    Args:
        quat_pred (Tensor): (B, 4) predicted quaternions.
        quat_gt (Tensor): (B, 4) ground truth quaternions.

    Returns:
        Tensor: Scalar loss.
    """
    eps = 1e-4
    pred_unit = F.normalize(quat_pred, dim=1)
    gt_unit = F.normalize(quat_gt, dim=1)

    # Kept strictly inside (-1, 1) for numerical stability.
    cos_half = (pred_unit * gt_unit).sum(dim=1).clamp(-1.0 + eps, 1.0 - eps)

    return torch.mean(1 - cos_half ** 2)

def quaternion_angular_error(q1, q2):
    """
    Per-sample angle (degrees) of the relative rotation between two quaternion batches.

    The absolute value of the dot product makes q and -q equivalent. Inputs are
    not re-normalized, so they are expected to be unit quaternions.

    Args:
        q1 (Tensor): (B, 4) predicted quaternions.
        q2 (Tensor): (B, 4) ground truth quaternions.

    Returns:
        Tensor: (B,) angular error in degrees.
    """
    cos_half = (q1 * q2).sum(dim=1).clamp(-1.0, 1.0).abs()
    half_angle = torch.acos(cos_half)
    return 2 * half_angle * (180.0 / torch.pi)

def computeMSE(rot_pred, t_pred, rot_gt, t_gt, quat=False,
               weight_xyz=(1.0, 1.0, 0.1), beta=1.0, print_mse=False):
    """
    Computes combined MSE loss for translation and rotation.

    Translation: weighted MSE on X and Y, and weighted MSE of log(Z).
    Rotation: quaternion_loss when ``quat`` is True, otherwise
    MSE(R_pred^T @ R_gt, I).

    Args:
        rot_pred (Tensor): (B, 4) or (B, 3, 3) predicted rotation.
        t_pred (Tensor): (B, 3) predicted translation.
        rot_gt (Tensor): ground truth rotation.
        t_gt (Tensor): (B, 3) ground truth translation.
        quat (bool): If True, use quaternion loss; else use matrix loss.
        weight_xyz (tuple): Weights for (x, y, z) translation axes.
        beta (float): Weight for rotation loss term.
        print_mse (bool): If True, print detailed loss values.

    Returns:
        tuple: (total_loss, (x_loss, y_loss, z_loss, angle_deg)); angle_deg is
        None unless print_mse and quat are both True.
    """
    # Only Z goes through a log, so only Z is clamped to stay positive. X and Y
    # are signed camera-frame coordinates and must not be clamped.
    z_pred = torch.clamp(t_pred[:, 2], min=1e-3)
    z_gt = torch.clamp(t_gt[:, 2], min=1e-3)

    # Weighted MSE translation loss (with log scaling for Z).
    x_loss = F.mse_loss(t_pred[:, 0], t_gt[:, 0]) * weight_xyz[0]
    y_loss = F.mse_loss(t_pred[:, 1], t_gt[:, 1]) * weight_xyz[1]
    z_loss = F.mse_loss(torch.log(z_pred), torch.log(z_gt)) * weight_xyz[2]
    translation_loss = x_loss + y_loss + z_loss

    # Rotation loss: quaternion or matrix.
    if quat:
        rotation_loss = quaternion_loss(rot_pred, rot_gt)
    else:
        rot_diff = torch.bmm(rot_pred.transpose(1, 2), rot_gt)
        identity = torch.eye(3, device=rot_pred.device).unsqueeze(0).expand(rot_pred.size(0), -1, -1)
        rotation_loss = F.mse_loss(rot_diff, identity)

    # Combine the losses.
    total_loss = translation_loss + beta * rotation_loss

    # Optional logging.
    angle_deg = None
    if print_mse:
        print(f"\nX loss:           {x_loss:.6f}")
        print(f"Y loss:           {y_loss:.6f}")
        print(f"Z loss:           {z_loss:.6f}")
        print(f"Rotation loss:    {rotation_loss:.6f}")
        print(f"Total loss:       {total_loss:.6f}")
        if quat:
            angle_deg = quaternion_angular_error(rot_pred, rot_gt).mean().item()
            print(f"Angular error (deg): {angle_deg:.2f}")

    return total_loss, (x_loss, y_loss, z_loss, angle_deg)

def flatten_collate_fn(batch):
    """
    DataLoader collate function that merges the object crops of all scenes in
    ``batch`` into a single flat batch (one row per object, not per scene).

    Args:
        batch (list): Items from the dataset's __getitem__, each with an 'objects' list.

    Returns:
        dict: Stacked tensors under 'rgb', 'depth', 'rotation', 'translation',
        'object_id' (int64), 'norm_bbox' and 'cropped_K'.
    """
    rgbs, depths, rotations, translations = [], [], [], []
    ids, boxes, intrinsics = [], [], []

    every_object = (obj for scene in batch for obj in scene['objects'])
    for obj in every_object:
        rgbs.append(obj['cropped_rgb'])
        depths.append(obj['cropped_depth'])
        rotations.append(torch.tensor(obj['rotation'], dtype=torch.float32))
        translations.append(torch.tensor(obj['translation'], dtype=torch.float32))
        ids.append(obj['object_id'])
        boxes.append(obj['norm_bbox'].clone().detach())
        intrinsics.append(obj['cropped_K'].clone().detach())

    return {
        'rgb': torch.stack(rgbs),
        'depth': torch.stack(depths),
        'rotation': torch.stack(rotations),
        'translation': torch.stack(translations),
        'object_id': torch.tensor(ids, dtype=torch.int64),
        'norm_bbox': torch.stack(boxes),
        'cropped_K': torch.stack(intrinsics),
    }

## Dataloaders

In [None]:
from torch.utils.data import DataLoader
import os
import numpy as np

# Paths to the preprocessed LineMOD data and object models inside the project folder.
dataset_root = os.path.join(full_project_path, 'dataset/LineMOD/Linemod_preprocessed/data')
models_root = os.path.join(full_project_path, 'dataset/LineMOD/Linemod_preprocessed/models')

# Extra keyword arguments forwarded to both splits. For example, {'folders': [1]}
# restricts the experiment to object folder 01; leave empty to use all folders.
dataset_kwargs = {}
if dataset_kwargs:
    print(f"Dataset options: {dataset_kwargs}")

# Both splits use the RGB-D dataset class defined above (the model consumes depth crops).
# Note: `PoseEstimationDataset` is not defined in this notebook.
train_dataset = PoseEstimationDataset_RGBD(dataset_root,
                                           models_root,
                                           split='train',
                                           **dataset_kwargs)

test_dataset = PoseEstimationDataset_RGBD(dataset_root,
                                          models_root,
                                          split='test',
                                          **dataset_kwargs)

print(f"Training dataset size: {len(train_dataset)}")
print(f"Testing dataset size: {len(test_dataset)}")

# Print the list of objects present in the defined training set.
_, orig_ids = train_dataset.getMappedIDs()
print(f"Training on {len(orig_ids)} object types: {[f'{oid:02d}' for oid in sorted(orig_ids)]}")

# Define dataloaders.
num_workers = 2
batch_size = 32

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    collate_fn=flatten_collate_fn
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
    collate_fn=flatten_collate_fn
)

# Store the object-ID mapping so evaluation can translate indices back to IDs.
train_dataset.save_mapping()

## Functions for training and evaluation

In [None]:
def train_model(epoch, model, train_loader, criterion, optimizer, device, max_grad_norm=0.1):
    """
    Trains the PoseNet model for one epoch.

    Args:
        epoch (int): Current training epoch number (also selects the loss weights).
        model (nn.Module): PoseNet model to be trained.
        train_loader (DataLoader): DataLoader with training batches.
        criterion (function): Loss function (not used directly here; computeMSE is used).
        optimizer (torch.optim.Optimizer): Optimizer instance.
        device (torch.device): Target device (e.g., 'cuda').
        max_grad_norm (float or None): Gradient-norm clipping threshold applied
            after backpropagation; None disables clipping.

    Returns:
        nn.Module: The trained model.

    Note:
        Uses `tqdm`, which is not imported in this cell — it must already be in scope.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0

    # Loop over batches with tqdm progress bar.
    for batch_idx, data in enumerate(tqdm(train_loader, desc=f'Epoch {epoch}', leave=False, unit="batch")):

        # Move data to device.
        crop_rgb = data['rgb'].to(device)
        crop_depth = data['depth'].to(device)
        t_gt = data['translation'].to(device)
        R_gt = data['rotation'].to(device)
        norm_bbox = data['norm_bbox'].to(device)
        cropped_K = data['cropped_K'].to(device)
        object_ids = data['object_id'].to(device)

        # Forward pass through the model.
        t_pred, quat_pred = model(crop_rgb, crop_depth, norm_bbox, cropped_K, object_ids)

        # Convert ground truth rotation matrix to quaternion.
        quat_gt = matrix_to_quaternion_batch(R_gt)

        # Use a dynamic loss weighting schedule based on epoch.
        if epoch < 5:
            beta = 1
            weight_xyz = (0, 0, 0.1)
        elif epoch < 10:
            beta = 5
            weight_xyz = (0.5, 0.5, 0.4)
        else:
            beta = 10
            weight_xyz = (0.1, 0.1, 1)

        # Compute loss using combined MSE (translation + rotation).
        loss, _ = computeMSE(quat_pred, t_pred, quat_gt, t_gt,
                             quat=True, weight_xyz=weight_xyz, beta=beta)

        # Skip invalid loss values.
        if not torch.isfinite(loss):
            print(f"⚠️ Skipping batch {batch_idx} due to invalid loss (NaN or Inf)")
            continue

        # Backpropagation. Clipping must happen AFTER backward(): before it the
        # gradients have just been zeroed, so clipping would be a no-op.
        optimizer.zero_grad()
        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()

        # Track cumulative loss over the batches actually used.
        running_loss += loss.item()
        num_batches += 1

    train_loss = running_loss / max(num_batches, 1)
    print(f'Epoch {epoch} | Loss: {train_loss:.6f} | LR: {optimizer.param_groups[0]["lr"]:.6f}')
    return model

def evaluate_model(model, val_loader, dataset, device, track_per_object=False):
    """
    Evaluates the PoseNet model on a validation loader with the combined MSE loss
    and the ADD metric.

    Args:
        model (nn.Module): Trained model to evaluate.
        val_loader (DataLoader): Validation dataset loader.
        dataset (PoseEstimationDataset): Dataset instance used to load 3D model points.
        device (torch.device): Device to evaluate on.
        track_per_object (bool): If True, also print the mean ADD per original object ID.

    Returns:
        tuple: (avg_loss, avg_ADD, loss_tuple)
            - avg_loss: mean of the per-batch `computeMSE` losses.
            - avg_ADD: mean ADD over every evaluated sample, as a Python float.
            - loss_tuple: the detailed tuple `computeMSE` returned for the FIRST batch
              only (callers read it as (x_loss, y_loss, z_loss, angular_error)).

    Raises:
        ValueError: If `val_loader` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    add_total = 0.0
    num_samples = 0
    num_batches = 0
    first_batch_loss_tuple = tuple()

    # Inverted mapping: mapped_id -> original object ID.
    idx_to_id = {v: k for k, v in PoseEstimationDataset.load_mapping("object_id_mapping.json").items()}

    # Per-object ADD tracking (keyed by original object ID).
    add_per_object = {orig_id: [] for orig_id in idx_to_id.values()} if track_per_object else None

    model_points_cache = {}  # original_id -> model points tensor on `device`.

    with torch.no_grad():
        for data in tqdm(val_loader, desc="Evaluating", leave=False):

            # Move inputs to device.
            crop_rgb = data['rgb'].to(device)
            crop_depth = data['depth'].to(device)
            t_gt = data['translation'].to(device)
            R_gt = data['rotation'].to(device)
            norm_bbox = data['norm_bbox'].to(device)
            cropped_K = data['cropped_K'].to(device)
            # Keep a CPU copy of the IDs for the per-sample lookups below; the model
            # receives a device copy, exactly as in train_model.
            object_ids = data['object_id']

            # Forward pass and pose prediction.
            t_pred, quat_pred = model(crop_rgb, crop_depth, norm_bbox, cropped_K, object_ids.to(device))
            quat_gt = matrix_to_quaternion_batch(R_gt)
            R_pred = quaternion_to_matrix_batch(quat_pred)

            # Print the detailed per-component loss for the first batch only.
            is_first_batch = num_batches == 0
            loss, loss_tuple = computeMSE(quat_pred, t_pred, quat_gt, t_gt, quat=True, print_mse=is_first_batch)
            if is_first_batch:
                first_batch_loss_tuple = loss_tuple
            running_loss += loss.item()
            num_batches += 1

            # Compute the ADD metric for each object in the batch.
            for i in range(crop_rgb.size(0)):
                original_id = idx_to_id[int(object_ids[i])]

                # Load each object's model points once.
                if original_id not in model_points_cache:
                    model_np = dataset.load_3D_model(original_id)
                    model_points_cache[original_id] = torch.as_tensor(model_np, dtype=torch.float32, device=device)

                # float() accepts both Python numbers and 0-dim tensors, so the running
                # totals and per-object lists stay plain floats (np.mean-safe).
                add = float(computeADD(R_pred[i], t_pred[i], R_gt[i], t_gt[i], model_points_cache[original_id]))
                add_total += add
                num_samples += 1

                if track_per_object:
                    add_per_object[original_id].append(add)

    if num_samples == 0:
        raise ValueError("evaluate_model: val_loader yielded no samples.")

    avg_loss = running_loss / num_batches
    avg_add = add_total / num_samples
    print(f'Validation Loss: {avg_loss:.6f}, Avg ADD: {avg_add:.4f}')

    # Optionally print per-object ADD stats.
    if track_per_object:
        print("\nPer-object ADD (mean):")
        for obj_id, adds in sorted(add_per_object.items()):
            if adds:
                mean_add = np.mean(adds)
                print(f"  Object {obj_id:02d}: ADD = {mean_add:.4f}")
            else:
                print(f"  Object {obj_id:02d}: No samples")

    return avg_loss, avg_add, first_batch_loss_tuple

## Training

In [None]:
import os
import torch
import matplotlib.pyplot as plt
from torch import nn, optim
from tqdm import tqdm

# Persistence targets for checkpoints and best weights:
# SAVE_LOCAL -> the Colab VM disk (/content), SAVE_DRIVE -> the mounted Google Drive.
SAVE_LOCAL, SAVE_DRIVE = True, True

def _plot_training_history(train_losses, add_losses, x_losses, y_losses, z_losses, ang_losses):
    """
    Plots the per-epoch histories recorded by `train_and_evaluate`.

    Figure 1: validation MSE loss (left axis) and validation ADD (right axis).
    Figure 2: per-axis translation losses (left axis) and angular error (right axis).

    Args:
        train_losses (list[float]): Per-epoch validation loss (historical name kept).
        add_losses (list[float]): Per-epoch average ADD.
        x_losses, y_losses, z_losses (list[float]): Per-epoch translation loss per axis.
        ang_losses (list[float]): Per-epoch angular error.
    """
    epochs_run = list(range(1, len(train_losses) + 1))

    # === Validation loss and ADD ===
    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()
    ax1.plot(epochs_run, train_losses, 'g-', label='Validation MSE Loss')
    ax2.plot(epochs_run, add_losses, 'b-', label='ADD')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Validation MSE Loss', color='g')
    ax2.set_ylabel('ADD (m)', color='b')
    ax1.set_title("Validation Loss and ADD")
    ax1.grid(True)
    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
    fig.tight_layout()
    plt.show()

    # === Translation losses and angular error ===
    fig, ax1 = plt.subplots()
    ax1.plot(epochs_run, x_losses, label='X loss', color='tab:red')
    ax1.plot(epochs_run, y_losses, label='Y loss', color='tab:green')
    ax1.plot(epochs_run, z_losses, label='Z loss', color='tab:blue')
    ax1.set_ylabel('Translation Loss (MSE)', color='black')
    ax1.set_xlabel('Epoch')
    ax1.grid(True)

    ax2 = ax1.twinx()
    ax2.plot(epochs_run, ang_losses, label='Angular Error (deg)', color='tab:purple', linestyle='--')
    ax2.set_ylabel('Angular Error (°)', color='tab:purple')
    ax2.tick_params(axis='y', labelcolor='tab:purple')

    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines1 + lines2, labels1 + labels2, loc='upper right')

    ax1.set_title("Translation Losses and Angular Error Over Epochs")
    fig.tight_layout()
    plt.show()


def train_and_evaluate(model, train_loader, test_loader, train_dataset, test_dataset,
                       full_project_path, num_epochs=10, patience=5, start_epoch=1,
                       resume_from_drive=False):
    """
    Trains and evaluates PoseNet over multiple epochs, with support for checkpointing,
    early stopping, and performance plotting.

    Args:
        model (nn.Module): PoseNet_RGBD model instance.
        train_loader (DataLoader): Dataloader for training data.
        test_loader (DataLoader): Dataloader for validation/test data.
        train_dataset (Dataset): Training dataset instance (for diagnostics).
        test_dataset (Dataset): Test dataset instance (used for ADD calculation).
        full_project_path (str): Root directory for saving models and logs.
        num_epochs (int): Last epoch to train (inclusive).
        patience (int): Early stopping patience (epochs without an ADD improvement).
        start_epoch (int): Epoch to start from when no checkpoint is found; a found
            checkpoint overrides it with its saved epoch + 1.
        resume_from_drive (bool): If True and no local checkpoint exists, resume from
            the Google Drive checkpoint instead (useful after a Colab runtime reset).

    Returns:
        tuple: (trained model, per-epoch validation losses, per-epoch ADD values).
            The second list is historically named `train_losses`, but it holds the
            average loss that `evaluate_model` computes on `test_loader`.
    """

    # === Path setup ===
    checkpoint_path_local = "/content/checkpoint_OP.pth"
    checkpoint_path_drive = os.path.join(full_project_path, "models/checkpoint_OP.pth")
    best_model_path_drive = os.path.join(full_project_path, "models/best_posenet_OP.pt")
    best_model_path_local = "/content/best_posenet_OP.pt"
    if SAVE_DRIVE:
        # torch.save does not create missing parent directories.
        os.makedirs(os.path.dirname(checkpoint_path_drive), exist_ok=True)

    # === Device configuration ===
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("<<<<<<Using GPU>>>>>>" if torch.cuda.is_available() else "<<<<<<Using CPU>>>>>>")
    model.to(device)

    # === Optimizer and Learning Rate Scheduler ===
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=0.001, weight_decay=0.005)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3)

    # === Resume from checkpoint (if it's available) ===
    train_losses, add_losses, ang_losses = [], [], []
    z_losses, x_losses, y_losses = [], [], []
    best_add = float('inf')
    counter = 0  # Epochs since the last ADD improvement (early stopping).

    resume_path = None
    if os.path.exists(checkpoint_path_local):
        resume_path = checkpoint_path_local
    elif resume_from_drive and os.path.exists(checkpoint_path_drive):
        resume_path = checkpoint_path_drive

    if resume_path is not None:
        checkpoint = torch.load(resume_path, map_location=device)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Checkpoints written before these keys existed fall back to fresh values.
        if 'lr_scheduler' in checkpoint:
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        counter = checkpoint.get('counter', 0)
        best_add = checkpoint['best_add']
        train_losses = checkpoint['train_losses']
        add_losses = checkpoint['add_losses']
        x_losses = checkpoint['x_losses']
        y_losses = checkpoint['y_losses']
        z_losses = checkpoint['z_losses']
        ang_losses = checkpoint['ang_losses']
        start_epoch = checkpoint['epoch'] + 1
        print(f"Resumed training from epoch {start_epoch}.")
    else:
        print(f"No checkpoint found, starting from epoch {start_epoch}.")

    # === Main Training Loop ===
    for epoch in range(start_epoch, num_epochs + 1):
        print(f"\n--------- Starting Epoch {epoch}/{num_epochs} ---------")
        print(f">>>>>>>Current best ADD is {best_add:.4f}<<<<<<<<<")
        print(f"Invalid samples found in dataset: {train_dataset.nrInvalidObjects()}")

        # === Training Phase ===
        model = train_model(epoch, model, train_loader, computeMSE, optimizer, device)

        # === Evaluation Phase ===
        avg_loss, avg_add, loss_tuple = evaluate_model(model, test_loader, test_dataset, device)
        # Plain Python floats keep the histories plottable and the checkpoint tensor-free.
        avg_loss, avg_add = float(avg_loss), float(avg_add)

        # Recording metrics.
        train_losses.append(avg_loss)
        add_losses.append(avg_add)
        # NOTE: evaluate_model returns the detailed loss tuple of its first batch only.
        x_losses.append(float(loss_tuple[0]))
        y_losses.append(float(loss_tuple[1]))
        z_losses.append(float(loss_tuple[2]))
        ang_losses.append(float(loss_tuple[3]))

        # === Best Model Saving ===
        if avg_add < best_add:
            best_add = avg_add
            counter = 0
            if SAVE_LOCAL:
                torch.save(model.state_dict(), best_model_path_local)
                print(f"✅ New best ADD: {avg_add:.4f} (saved model locally)")
            if SAVE_DRIVE:
                torch.save(model.state_dict(), best_model_path_drive)
                print(f"☁️ New best ADD: {avg_add:.4f} (saved model on Google Drive)")
        else:
            counter += 1

        # Update the learning rate on validation loss BEFORE checkpointing, so a resumed
        # run inherits the post-step optimizer LR and scheduler state.
        lr_scheduler.step(avg_loss)

        # === Save Checkpoint (every epoch, including the one that triggers early stopping) ===
        checkpoint = {
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'counter': counter,
            'best_add': best_add,
            'train_losses': train_losses,
            'add_losses': add_losses,
            'x_losses': x_losses,
            'y_losses': y_losses,
            'z_losses': z_losses,
            'ang_losses': ang_losses
        }
        if SAVE_LOCAL:
            torch.save(checkpoint, checkpoint_path_local)
            print("💾 Checkpoint saved locally.")
        if SAVE_DRIVE:
            torch.save(checkpoint, checkpoint_path_drive)
            print("☁️ Checkpoint saved on Google Drive.")

        print(f"Epoch {epoch}/{num_epochs} | Average ADD: {avg_add:.4f}")

        if counter >= patience:
            print("⏹ Early stopping triggered.")
            break

    # === Plot the recorded histories ===
    _plot_training_history(train_losses, add_losses, x_losses, y_losses, z_losses, ang_losses)

    return model, train_losses, add_losses

# === Initialize model ===
# getMappedIDs() returns a pair; only the second value is needed here, and its
# length is used as the model's number of object classes.
_, num_objects = train_dataset.getMappedIDs()
num_obj = len(num_objects)
print(f"This model will be trained to find {num_obj} object(s).")

model = PoseNet_RGBD(num_objects=num_obj, embedding_dim=16, img_size=(224, 224))

# === Train and Evaluate ===
# Up to 65 epochs; stop early after 15 consecutive epochs without an ADD improvement.
_ = train_and_evaluate(
    model=model,
    train_loader=train_loader, test_loader=test_loader,
    train_dataset=train_dataset, test_dataset=test_dataset,
    full_project_path=full_project_path,
    num_epochs=65, patience=15,
)

## Functions for plotting the results

In [None]:
import os
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import save_image

def draw_model_projection(image, points, color, radius=1):
    """
    Draws projected 2D model points as filled circles on the image (in place).

    Points that fall outside the image, or that have non-finite coordinates
    (e.g. produced by projecting a point with zero depth), are skipped.

    Args:
        image (np.ndarray): BGR image, modified in place.
        points (np.ndarray): 2D projected points (N, 2) in pixel coordinates.
        color (tuple): BGR color (e.g., (0, 255, 0)).
        radius (int): Radius of the circle to draw.
    """
    pts = np.asarray(points, dtype=float)
    if pts.size == 0:
        return
    # Drop NaN/inf rows before the int cast, which would otherwise warn and yield garbage.
    pts = pts[np.isfinite(pts).all(axis=1)].astype(int)
    height, width = image.shape[:2]
    inside = ((pts[:, 0] >= 0) & (pts[:, 0] < width) &
              (pts[:, 1] >= 0) & (pts[:, 1] < height))
    for x, y in pts[inside]:
        # OpenCV expects plain Python ints for point coordinates.
        cv2.circle(image, (int(x), int(y)), radius, color, -1)

def draw_legend(image, labels_colors):
    """
    Draws a small color legend (e.g. "GT" vs "Pred") in the top-left corner.

    Each entry is one row: a filled square swatch followed by its label text,
    both in the entry's color. The image is modified in place.

    Args:
        image (np.ndarray): BGR image.
        labels_colors (list): List of (label, color) pairs.
    """
    left, top, row_gap = 10, 25, 25
    swatch_w = swatch_h = 10
    half_h = swatch_h // 2
    for row, (text, colour) in enumerate(labels_colors):
        row_y = top + row * row_gap
        cv2.putText(image, text, (left + swatch_w + 5, row_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, colour, 1)
        cv2.rectangle(image, (left, row_y - half_h), (left + swatch_w, row_y + half_h), colour, -1)

def draw_axes(image, R, t, K, is_gt=False, axis_length=0.05, thickness=2):
    """
    Projects and draws the coordinate axes (X, Y, Z) of a given pose.

    X is drawn red, Y green and Z blue (BGR); ground-truth axes use paler shades.
    A white dot marks the projected origin. The image is annotated in place.

    Args:
        image (np.ndarray): BGR image.
        R (np.ndarray): 3x3 rotation matrix.
        t (array-like): 3D translation; shape (3,) or anything reshapeable to (3,),
            such as a (3, 1) column vector.
        K (np.ndarray): 3x3 camera intrinsic matrix.
        is_gt (bool): Whether this is ground truth (paler colors).
        axis_length (float): Length of axes in meters.
        thickness (int): Line thickness.

    Returns:
        np.ndarray: The same, annotated image.
    """
    axes_3d = np.array([[0, 0, 0],
                        [axis_length, 0, 0],
                        [0, axis_length, 0],
                        [0, 0, axis_length]], dtype=np.float32)

    # Flatten t so (3, 1) column vectors broadcast the same way as (3,) vectors.
    t_vec = np.asarray(t).reshape(3)
    pts_2d = project(axes_3d, R, t_vec, K).astype(int)

    # OpenCV expects points as tuples of plain Python ints.
    origin, x_end, y_end, z_end = (tuple(int(c) for c in p) for p in pts_2d)
    cv2.circle(image, origin, 6, (255, 255, 255), -1)

    if is_gt:
        color_map = [(100, 100, 255), (100, 255, 100), (255, 100, 100)]
    else:
        color_map = [(0, 0, 255), (0, 255, 0), (255, 0, 0)]

    for end, color in zip((x_end, y_end, z_end), color_map):
        cv2.line(image, origin, end, color, thickness)

    return image

def project(pts, R, t, K):
    """
    Projects 3D model points into 2D pixel coordinates with a pinhole camera.

    The points are first moved into the camera frame (pts @ R.T + t), then
    multiplied by the intrinsics and divided by their depth.

    Args:
        pts (np.ndarray): 3D model points (N, 3).
        R (np.ndarray): 3x3 rotation matrix.
        t (np.ndarray): 3D translation vector.
        K (np.ndarray): 3x3 camera intrinsics.

    Returns:
        np.ndarray: 2D projected points (N, 2).
    """
    camera_frame = pts @ R.T + t
    homogeneous = (K @ camera_frame.T).T
    depth = homogeneous[:, 2:3]
    return homogeneous[:, :2] / depth

def _denormalize_crop(crop_chw, rgb_mean, rgb_std):
    """Turns a normalized (C, H, W) RGB crop tensor into a displayable uint8 BGR image."""
    crop = crop_chw.cpu().numpy().transpose(1, 2, 0)
    # Scalars and per-channel sequences both broadcast over the trailing channel axis.
    mean = np.asarray(rgb_mean, dtype=np.float32).reshape(1, 1, -1)
    std = np.asarray(rgb_std, dtype=np.float32).reshape(1, 1, -1)
    crop = ((crop * std + mean).clip(0, 1) * 255).astype(np.uint8)
    return cv2.cvtColor(crop, cv2.COLOR_RGB2BGR)


def _overlay_pose(image, K, model_points, R_pred, t_pred, R_gt, t_gt,
                  draw_axes_flag, draw_models, legend):
    """Draws the requested GT/prediction overlays on a BGR `image` using intrinsics `K`."""
    if draw_axes_flag:
        image = draw_axes(image, R_pred, t_pred, K)
        image = draw_axes(image, R_gt, t_gt, K, is_gt=True)
    if draw_models:
        # GT in green, prediction in red (BGR).
        draw_model_projection(image, project(model_points, R_gt, t_gt, K), (0, 255, 0))
        draw_model_projection(image, project(model_points, R_pred, t_pred, K), (0, 0, 255))
    if legend:
        draw_legend(image, [("GT", (0, 255, 0)), ("Pred", (0, 0, 255))])
    return image


def _show_and_save(image_bgr, title, save_path=None):
    """Shows a BGR image with matplotlib and, if `save_path` is given, writes it to disk."""
    plt.figure()
    plt.imshow(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
    plt.title(title)
    plt.axis("off")
    plt.show()
    # cv2.imwrite reports failure by returning False rather than raising.
    if save_path is not None and not cv2.imwrite(save_path, image_bgr):
        print(f"⚠️ Could not write visualization to {save_path}")


def visualize_pose_prediction(obj, model, dataset, obj_id, device,
                              draw_axes_flag=False, draw_models=True, use_full_frame=True,
                              save_dir=None, save_prefix="result", legend=False,
                              rgb_mean=(0.485, 0.456, 0.406), rgb_std=(0.229, 0.224, 0.225)):
    """
    Visualizes predicted vs ground truth pose for a given object crop.

    Runs the model on the crop, prints ADD and angular error, and shows (and
    optionally saves) the cropped view and, if requested, the full-frame view.

    Args:
        obj (dict): Single object crop entry from dataset.
        model (nn.Module): PoseNet model.
        dataset (PoseEstimationDataset): Dataset instance (for 3D model lookup).
        obj_id (int): Mapped object ID (internal index).
        device (torch.device): Device to run inference on.
        draw_axes_flag (bool): Whether to draw 3D coordinate axes.
        draw_models (bool): Whether to draw full model projections.
        use_full_frame (bool): Whether to also show full frame projection.
        save_dir (str): Directory to save visualizations (optional).
        save_prefix (str): Filename prefix for saved files.
        legend (bool): Whether to draw a color legend.
        rgb_mean (float or sequence): Mean used to normalize the RGB crop; defaults to
            the per-channel ImageNet mean.
        rgb_std (float or sequence): Std used to normalize the RGB crop; defaults to
            the per-channel ImageNet std.
    """
    # --- Model inputs (batch of one) ---
    rgb = obj['cropped_rgb'].unsqueeze(0).to(device)
    crop_depth = obj['cropped_depth'].unsqueeze(0).to(device)
    norm_bbox = obj['norm_bbox'].unsqueeze(0).to(device)
    cropped_K = obj['cropped_K'].unsqueeze(0).to(device)
    obj_tensor = torch.tensor([obj_id], dtype=torch.long, device=device)

    # Ground truth as NumPy, so it mixes cleanly with the NumPy predictions below.
    R_gt = np.asarray(obj['rotation'])
    t_gt = np.asarray(obj['translation']).reshape(-1)

    # --- Inference and angular error ---
    with torch.no_grad():
        t_pred_tensor, quat_pred = model(rgb, crop_depth, norm_bbox, cropped_K, obj_tensor)
        # The *_batch helpers receive (B, 3, 3) rotations in evaluate_model, so give
        # them a batch of one here as well.
        R_gt_batch = torch.as_tensor(R_gt, dtype=torch.float32, device=device).unsqueeze(0)
        quat_gt = matrix_to_quaternion_batch(R_gt_batch)
        ang_err = quaternion_angular_error(quat_pred, quat_gt).item()
        R_pred = quaternion_to_matrix_batch(quat_pred).squeeze().cpu().numpy()
    t_pred = t_pred_tensor.squeeze().cpu().numpy()

    # Load 3D model.
    original_id = dataset.idx_to_id[obj_id]
    model_points = dataset.load_3D_model(original_id)

    save_crop_path = save_full_path = None
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        save_crop_path = os.path.join(save_dir, f"{save_prefix}_cropped_obj{obj_id:02d}.png")
        save_full_path = os.path.join(save_dir, f"{save_prefix}_fullframe_obj{obj_id:02d}.png")

    overlay_options = dict(draw_axes_flag=draw_axes_flag, draw_models=draw_models, legend=legend)

    # === Visualize Cropped Image ===
    cropped_K_np = cropped_K.squeeze().cpu().numpy()
    vis_crop = _overlay_pose(_denormalize_crop(obj['cropped_rgb'], rgb_mean, rgb_std),
                             cropped_K_np, model_points, R_pred, t_pred, R_gt, t_gt,
                             **overlay_options)

    # Computing the ADD metric.
    add = computeADD(R_pred, t_pred, R_gt, t_gt, model_points)
    print(f"\n➡️ Mapped object ID {obj_id:02d}")
    print(f"ADD: {add:.4f} m | Angular Error: {ang_err:.2f}°")

    _show_and_save(vis_crop, f"[Cropped] Obj {original_id:02d}", save_crop_path)

    # === Visualize Full Image View ===
    if use_full_frame:
        full_image = cv2.cvtColor(np.array(obj['original_rgb']), cv2.COLOR_RGB2BGR)
        original_K = np.asarray(obj['original_K'])
        full_image = _overlay_pose(full_image, original_K, model_points,
                                   R_pred, t_pred, R_gt, t_gt, **overlay_options)
        _show_and_save(full_image, f"[Full Frame] Obj {original_id:02d}", save_full_path)

def run_visualization(model, dataset, device, target_obj_ids,
                      img_idx=0, save_dir=None, draw_axes=False,
                      draw_legend=False, draw_models=False):
    """
    Runs pose prediction visualization for selected objects in a given image.

    Every object in dataset[img_idx] whose mapped ID is in `target_obj_ids` is passed
    to `visualize_pose_prediction`. Requested IDs not found in the image are reported
    with both their mapped and original IDs.

    Args:
        model (nn.Module): Trained PoseNet model.
        dataset (Dataset): PoseEstimationDataset instance.
        device (torch.device): CUDA or CPU device.
        target_obj_ids (iterable): Internal (mapped) object IDs to visualize.
        img_idx (int): Index of the sample image to visualize.
        save_dir (str): Directory to save images (optional).
        draw_axes (bool): Whether to draw 3D coordinate axes.
        draw_legend (bool): Whether to draw GT/Pred legend.
        draw_models (bool): Whether to project and overlay 3D models.
    """
    model.eval()
    # Normalize to a set of plain ints: a tensor ID would never compare equal inside a set.
    wanted_ids = {int(i) for i in target_obj_ids}
    idx_to_id = {v: k for k, v in dataset.id_to_idx.items()}
    found_obj_ids = set()

    with torch.no_grad():
        data_item = dataset[img_idx]
        original_rgb = data_item['original_rgb']
        original_K = data_item['original_K']

        for obj in data_item['objects']:
            obj_id = int(obj['object_id'])
            if obj_id not in wanted_ids:
                continue
            found_obj_ids.add(obj_id)

            # Attach the full image and intrinsics on a copy, leaving the dataset's dict untouched.
            obj_with_frame = {**obj, 'original_rgb': original_rgb, 'original_K': original_K}

            visualize_pose_prediction(
                obj_with_frame, model, dataset, obj_id, device,
                draw_axes_flag=draw_axes,
                draw_models=draw_models,
                legend=draw_legend,
                save_dir=save_dir,
                save_prefix=f"img{img_idx:03d}_obj{obj_id:02d}"
            )

    # Warn if some requested objects weren't present in the image.
    missing_ids = sorted(wanted_ids - found_obj_ids)
    if missing_ids:
        missing_original_ids = [idx_to_id.get(mapped_id, "?") for mapped_id in missing_ids]
        print(f"⚠️ Note: These object IDs (mapped) were not present in image {img_idx}: {missing_ids} "
              f"(original IDs: {missing_original_ids})")

## Plotting the results

In [None]:
# --- Dataset & Paths ---
dataset_root = os.path.join(full_project_path, "dataset/LineMOD/Linemod_preprocessed/data")
models_root = os.path.join(full_project_path, "dataset/LineMOD/Linemod_preprocessed/models")

# Load only a subset of the LineMOD data folders: here folder 2 (e.g. [1, 2, 3] for more).
folders = [2]
dataset = PoseEstimationDataset(dataset_root, models_root, folders=folders)

In [None]:
# Print the available object IDs from dataset.
print("Object IDs in dataset:", dataset.object_ids)  # Original LineMOD object IDs.
all_mapped_ids, _ = dataset.getMappedIDs()
print("Mapped IDs:", all_mapped_ids)  # Internally mapped IDs used in training/evaluation.

# === Select which original object IDs to visualize ===
original_ids_to_visualize = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
# original_ids_to_visualize = [13]  # If you want to test a single object.

# === Load Trained Model ===
# NOTE(review): `num_objects` is not used below; it is kept only so the global
# has the same value as before for any later cell that reads it.
num_objects = len(original_ids_to_visualize)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The architecture must match the checkpoint. The training cell builds PoseNet_RGBD
# with embedding_dim=16 and img_size=(224, 224), so the same values are used here;
# a mismatch may make load_state_dict fail. NUM_TRAINED_OBJECTS must equal the
# `num_obj` value printed by the training cell.
NUM_TRAINED_OBJECTS = 13
model = PoseNet_RGBD(num_objects=NUM_TRAINED_OBJECTS, embedding_dim=16, img_size=(224, 224))

# Load the best weights written locally by train_and_evaluate.
model.load_state_dict(torch.load('/content/best_posenet_OP.pt', map_location=device))
model.to(device)

# === Prepare Visualization Parameters ===
# Convert original LineMOD IDs to internal mapped IDs.
mapped_ids, _ = dataset.getMappedIDs(original_ids_to_visualize)

img_idx = 20  # Dataset index of the image to visualize.
save_dir = "/content/visualizations"  # Folder to save visualization images.

# === Run Visualization ===
run_visualization(
    model=model,
    dataset=dataset,
    device=device,
    target_obj_ids=mapped_ids,  # Mapped object IDs to look for in the selected image.
    img_idx=img_idx,
    save_dir=save_dir,
    draw_axes=True,     # Draw 3D coordinate axes (Pred vs GT).
    draw_legend=False,  # Hide the GT/Pred legend.
    draw_models=False   # Set True to overlay projected 3D model points: GT (green), pred (red).
)