# 6D Pose Estimation

## Set up the project

We will work with a portion of the LineMOD dataset, which you can find here: https://drive.google.com/drive/folders/19ivHpaKm9dOrr12fzC8IDFczWRPFxho7

In [None]:
# Step 1: Prepare the directory that will hold the LineMOD download.
# Nothing is downloaded in this cell; it only creates the target directory.
# `mkdir -p` creates the nested directories and does not fail if they already exist.
!mkdir -p datasets/linemod/
# Switch the notebook's working directory so the following cells run inside it.
%cd datasets/linemod/

Check working directory

In [None]:
# Print the current working directory (expected to end in datasets/linemod
# when the preceding %cd cell has been run once from the project root).
!pwd

In [None]:
# Download the shared "DenseFusion" Google Drive folder, which includes a portion
# of the LineMOD dataset. Requires the `gdown` command-line tool.

!gdown --folder "https://drive.google.com/drive/folders/19ivHpaKm9dOrr12fzC8IDFczWRPFxho7"

In [None]:
# Enter the DenseFusion/ directory. `mkdir -p` is a no-op when the directory
# already exists (presumably it was created by the gdown download — verify the
# downloaded folder name if the next cell cannot find the zip file).
!mkdir -p DenseFusion/
%cd DenseFusion/

In [None]:
# Extract the archive into the current directory. Later cells read the data
# from ./Linemod_preprocessed, so the archive is assumed to unpack to that folder.
!unzip Linemod_preprocessed.zip

Install requirements

In [None]:
# Install the Python dependencies. The path is relative to the current working
# directory (two levels up from it).
# NOTE(review): confirm requirements.txt really lives two levels above the
# directory entered by the previous %cd cells.
! pip install -r ../../requirements.txt

Get working directory

In [None]:
# `!pwd` yields the command output as a list of lines; keep the first (only)
# line so that `path` is a plain string holding the current working directory.
path = !pwd
path = path[0]

In [None]:
import os
import yaml
import torch
import open3d as o3d
import itertools
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import matplotlib.patches as patches

## Data Exploration

In [None]:
# Preview the first RGB frame of object folder 01.
# `path` is the working directory captured after entering DenseFusion/ and the
# archive is unpacked there, so the dataset sits directly under `path` — the
# same location the later cells use via dataset_root = "./Linemod_preprocessed".
# NOTE(review): if the archive unpacks to a different layout, adjust this path
# and dataset_root together.
img_path = os.path.join(path, "Linemod_preprocessed", "data", "01", "rgb", "0000.png")
img = Image.open(img_path).convert("RGB")
plt.imshow(img)
plt.show()

In [None]:
class CustomDataset(Dataset):
    """Dataset that loads RGB, depth, point cloud and 6D pose for each frame.

    The methods below read the following layout under ``dataset_root``::

        data/<NN>/rgb/<NNNN>.png      colour frames
        data/<NN>/depth/<NNNN>.png    depth frames
        data/<NN>/gt.yml              per-frame pose annotations
        data/<NN>/info.yml            per-frame camera intrinsics
        models/models_info.yml        object model information

    where ``<NN>`` is a two-digit folder id (01..15) and ``<NNNN>`` a
    four-digit frame id.
    """

    def __init__(self, dataset_root, split='train', train_ratio=0.8, seed=42):
        """
        Args:
            dataset_root (str): Path to the dataset directory.
            split (str): 'train' selects the training split; any other value
                selects the held-out (test) split.
            train_ratio (float): Fraction of data used for training (default 80%).
            seed (int): Random seed for reproducibility.

        Raises:
            ValueError: If no samples are found under ``dataset_root``.
        """
        self.dataset_root = dataset_root
        self.split = split
        self.train_ratio = train_ratio
        self.seed = seed

        # Parsed YAML files keyed by path, so each file is read from disk only
        # once instead of on every __getitem__ call.
        self._yaml_cache = {}

        # Get list of all samples (folder_id, sample_id)
        self.samples = self.get_all_samples()

        # Check if samples were found
        if not self.samples:
            raise ValueError(f"No samples found in {self.dataset_root}. Check the dataset path and structure.")

        # Split into training and test sets (deterministic for a fixed seed)
        self.train_samples, self.test_samples = train_test_split(
            self.samples, train_size=self.train_ratio, random_state=self.seed
        )

        # Select the appropriate split
        self.samples = self.train_samples if split == 'train' else self.test_samples

        # Define image transformations
        self.transform = transforms.Compose([
            transforms.ToTensor(),
        ])

    def get_all_samples(self):
        """Retrieve the list of all available (folder_id, sample_id) pairs.

        Only files named ``<digits>.png`` are considered; any other file in an
        ``rgb`` folder (hidden files, notes, masks with extra suffixes, ...) is
        ignored instead of aborting the scan.

        Returns:
            list[tuple[int, int]]: Pairs ordered by folder id, then sample id.
        """
        samples = []
        for folder_id in range(1, 16):  # Assuming folders are named 01 to 15
            folder_path = os.path.join(self.dataset_root, 'data', f"{folder_id:02d}", "rgb")
            if not os.path.isdir(folder_path):
                continue
            sample_ids = sorted(
                int(stem)
                for stem, ext in map(os.path.splitext, os.listdir(folder_path))
                if ext == '.png' and stem.isascii() and stem.isdigit()
            )
            samples.extend((folder_id, sid) for sid in sample_ids)  # Store (folder_id, sample_id)
        return samples

    def _read_yaml(self, yaml_path):
        """Parse ``yaml_path`` once and return the memoised result on later calls."""
        # Go through __dict__ so the cache also works on instances that were
        # created without running __init__.
        cache = self.__dict__.setdefault('_yaml_cache', {})
        if yaml_path not in cache:
            # NOTE(review): FullLoader is kept from the original code; if these
            # files only hold plain data, yaml.safe_load would be the safer choice.
            with open(yaml_path, 'r') as f:
                cache[yaml_path] = yaml.load(f, Loader=yaml.FullLoader)
        return cache[yaml_path]

    def load_config(self, folder_id):
        """Load camera intrinsics and object info for a specific folder.

        Args:
            folder_id (int): Object folder id (formatted as two digits on disk).

        Returns:
            tuple: ``(camera_intrinsics, objects_info)`` as parsed from
            ``data/<NN>/info.yml`` and ``models/models_info.yml``. The parsed
            objects are cached and shared between calls; treat them as read-only.
        """
        camera_intrinsics_path = os.path.join(self.dataset_root, 'data', f"{folder_id:02d}", 'info.yml')
        objects_info_path = os.path.join(self.dataset_root, 'models', "models_info.yml")
        return self._read_yaml(camera_intrinsics_path), self._read_yaml(objects_info_path)

    # Helper functions to access the data
    def load_image(self, img_path):
        """Load an RGB image and convert it to a tensor via ``self.transform``."""
        with Image.open(img_path) as img:
            return self.transform(img.convert("RGB"))

    def load_depth(self, depth_path):
        """Load a depth image and return it as a float32 tensor of shape (H, W)."""
        # Convert to float32 in NumPy first: handing an unsigned 16-bit array
        # straight to torch is not supported by every torch version.
        with Image.open(depth_path) as img:
            depth = np.array(img, dtype=np.float32)
        return torch.from_numpy(depth)

    @staticmethod
    def _depth_to_points(depth, intrinsics):
        """Back-project an (H, W) depth map to an (H*W, 3) array of camera-frame points."""
        cam_k = intrinsics[0]['cam_K']  # take intrinsics of the first image
        # cam_K is a flattened 3x3 matrix: focal lengths and principal point
        fx, fy, cx, cy = cam_k[0], cam_k[4], cam_k[2], cam_k[5]
        h, w = depth.shape

        xmap, ymap = np.meshgrid(np.arange(w), np.arange(h))
        z = depth / 1000.0  # Convert to meters (assumes depth is stored in millimetres)
        x = (xmap - cx) * z / fx
        y = (ymap - cy) * z / fy
        return np.stack((x, y, z), axis=-1).reshape(-1, 3)

    def load_point_cloud(self, depth, intrinsics):
        """Convert a depth image to an Open3D point cloud.

        Args:
            depth (np.ndarray): Depth map of shape (H, W).
            intrinsics (dict): Parsed ``info.yml``; entry ``0`` must hold 'cam_K'.

        Returns:
            o3d.geometry.PointCloud: One point per pixel, in row-major order.
        """
        point_cloud = o3d.geometry.PointCloud()
        point_cloud.points = o3d.utility.Vector3dVector(self._depth_to_points(depth, intrinsics))
        return point_cloud

    def load_6d_pose(self, folder_id, sample_id):
        """Load the 6D pose (translation and rotation) for the object in this sample.

        Args:
            folder_id (int): Object folder id.
            sample_id (int): Frame id; must be a key of the folder's ``gt.yml``.

        Returns:
            tuple: ``(translation [3], rotation [3x3], bbox [4], obj_id)`` as
            float32 NumPy arrays; ``bbox`` is ``x_min, y_min, x_max, y_max``.

        Raises:
            KeyError: If ``sample_id`` has no entry in ``gt.yml``.
        """
        pose_file = os.path.join(self.dataset_root, 'data', f"{folder_id:02d}", "gt.yml")

        # The pose data is a dictionary where each key corresponds to a frame with pose info
        pose_data = self._read_yaml(pose_file)
        if sample_id not in pose_data:
            raise KeyError(f"Sample ID {sample_id} not found in gt.yml for folder {folder_id}.")

        # A frame may list several annotated objects. Prefer the annotation whose
        # obj_id matches this folder; fall back to the first entry otherwise
        # (which is also what happens when there is only one annotation).
        entries = pose_data[sample_id]
        pose = next((entry for entry in entries if entry.get('obj_id') == folder_id), entries[0])

        # Extract translation and rotation
        translation = np.array(pose['cam_t_m2c'], dtype=np.float32)  # [3] ---> (x,y,z)
        rotation = np.array(pose['cam_R_m2c'], dtype=np.float32).reshape(3, 3)  # [3x3] ---> rotation matrix
        obj_id = np.array(pose['obj_id'], dtype=np.float32)  # scalar ---> label

        # The file stores the top-left corner plus size; convert to corner format.
        x_min, y_min, width, height = pose['obj_bb']
        bbox = np.array([x_min, y_min, x_min + width, y_min + height], dtype=np.float32)

        return translation, rotation, bbox, obj_id

    def __len__(self):
        """Return the total number of samples in the selected split."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one dataset sample.

        Returns:
            dict: Keys "rgb", "depth", "point_cloud", "camera_intrinsics",
            "objects_info", "translation", "rotation", "bbox" and "obj_id".
        """
        folder_id, sample_id = self.samples[idx]

        # Load the correct camera intrinsics and object info for this folder
        camera_intrinsics, objects_info = self.load_config(folder_id)

        folder_dir = os.path.join(self.dataset_root, 'data', f"{folder_id:02d}")
        img_path = os.path.join(folder_dir, 'rgb', f"{sample_id:04d}.png")
        depth_path = os.path.join(folder_dir, 'depth', f"{sample_id:04d}.png")

        img = self.load_image(img_path)
        depth = self.load_depth(depth_path)
        # Build the points directly as an array; wrapping them in an Open3D
        # object only to convert them straight back would yield the same values.
        points = self._depth_to_points(depth.numpy(), camera_intrinsics)
        point_cloud = torch.tensor(points, dtype=torch.float32)
        translation, rotation, bbox, obj_id = self.load_6d_pose(folder_id, sample_id)

        # Dictionary with all the data
        return {
            "rgb": img,
            "depth": depth,  # already a float32 tensor; no extra copy needed
            "point_cloud": point_cloud,
            "camera_intrinsics": camera_intrinsics[0]['cam_K'],
            "objects_info": objects_info,
            "translation": torch.tensor(translation),
            "rotation": torch.tensor(rotation),
            "bbox": torch.tensor(bbox),
            "obj_id": torch.tensor(obj_id),
        }

In [None]:
# Root of the unpacked dataset, relative to the current working directory.
dataset_root = "./Linemod_preprocessed"

# Training split: shuffled so that each epoch sees the samples in a new order.
train_dataset = CustomDataset(dataset_root, split="train")
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

# len(dataset) counts samples, len(loader) counts batches (samples / batch_size, rounded up).
print(f"Training samples: {len(train_dataset)}") # 12640
print(f"Training Loader batches: {len(train_loader)}") # 12640/batch_size

# Test split: kept in a fixed order so evaluation is repeatable.
test_dataset = CustomDataset(dataset_root, split="test")
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

print(f"Testing samples: {len(test_dataset)}") # 3160
print(f"Test Loader batches: {len(test_loader)}") # 3160/batch_size

### Visualize data

In [None]:
# Number of batches to keep from each loader (each batch holds 4 samples).
train_subset_num_batches = 10
test_subset_num_batches = 5

# Pair every batch with a counter so that iteration stops after the requested
# number of batches without pulling an extra one from the loader.
train_subset = [
    loaded_batch
    for _, loaded_batch in zip(range(train_subset_num_batches), train_loader)
]
test_subset = [
    loaded_batch
    for _, loaded_batch in zip(range(test_subset_num_batches), test_loader)
]

In [None]:
# Get one batch from the train loader (4 images with the loader settings above)
batch = next(iter(train_loader))

# Extract the relevant data as NumPy arrays
rgb_images = batch["rgb"].permute(0, 2, 3, 1).numpy()  # (B, 3, H, W) -> (B, H, W, 3) for imshow
bboxes = batch["bbox"].numpy()                         # (B, 4) in pixel coords: x_min, y_min, x_max, y_max
obj_ids = batch["obj_id"].numpy()                      # (B,)

# Grid layout: at most 4 columns, as many rows as needed
batch_size = rgb_images.shape[0]
cols = min(4, batch_size)
rows = (batch_size + cols - 1) // cols

# squeeze=False always returns a 2D array of axes, so flatten() also works
# when the grid is a single cell (a batch containing one image).
fig, axes = plt.subplots(rows, cols, figsize=(12, 3 * rows), squeeze=False)
axes = axes.flatten()

for i, (ax, img, bbox, obj_id) in enumerate(zip(axes, rgb_images, bboxes, obj_ids)):
    x_min, y_min, x_max, y_max = bbox
    # Rectangle expects the top-left corner plus width and height
    width = x_max - x_min
    height = y_max - y_min

    ax.imshow(img)
    ax.axis('off')
    ax.set_title(f"Sample {i}")

    # Draw bounding box
    rect = patches.Rectangle(
        (x_min, y_min),
        width,
        height,
        linewidth=2,
        edgecolor='red',
        facecolor='none'
    )
    ax.add_patch(rect)

    # Add object ID as label, slightly above the box
    ax.text(
        x_min,
        y_min - 10,
        f'ID: {int(obj_id)}',
        color='yellow',
        fontsize=10,
        backgroundcolor='black'
    )

# Hide unused axes if batch_size < cols * rows
for unused_ax in axes[batch_size:]:
    unused_ax.axis('off')

plt.tight_layout()
plt.show()