# 6D Pose Estimation

## Set up the project

We will work with a portion of this dataset, which you can find here: https://drive.google.com/drive/folders/19ivHpaKm9dOrr12fzC8IDFczWRPFxho7

Set some variables to conditionally run parts of the code. First download the project and change the working directory to ```6D_pose_estimation```.

In [None]:
# When True, Google Drive is mounted (Colab) and the notebook moves into the project folder on it.
MOUNT_DRIVE = False
# When True, the notebook logs in to Weights & Biases and starts a run.
WANDB = False

In [None]:
# Colab only: mount Google Drive and switch to the project folder stored there.
if MOUNT_DRIVE:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    %cd /content/drive/MyDrive/6D_pose_estimation/

Install PyTorch and its companion packages

In [None]:
!pip install torch torchvision torchaudio

In [None]:
%%capture
import os
import torch

# Export the installed PyTorch version so the pip commands below pick the matching wheel index.
%env TORCH=$torch.__version__
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-cluster -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-spline-conv -f https://data.pyg.org/whl/torch-${TORCH}.html

Install all packages. You may need to restart the runtime before continuing.

In [None]:
!pip install -r ./requirements.txt
print("Restart runtime")

In [None]:
import os
import yaml
import torch
import torchvision
import open3d as o3d
import itertools
import shutil
from torch.utils.data import Dataset
from torch import nn, optim
import torch.nn.functional as F
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import wandb
from scipy.spatial.transform import Rotation as R
from torchvision import models
import cv2
from torch.optim import Adam
import quaternion
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast
from ultralytics import YOLO
from torchvision.transforms import v2
import trimesh

# install PyTorch Geometric after installation and restart
import torch_geometric
from torch import Tensor
from torch_geometric.nn import knn_interpolate, MessagePassing
from torch_geometric.nn.pool import fps, radius

from utils.data_exploration import load_image

IMG_WIDTH = 640
IMG_HEIGHT = 480

# Sanity check: confirm PyTorch Geometric is importable and report the installed versions.
try:
    from torch_geometric.nn.pool import fps
    print("PyTorch Geometric correctly installed")
    print(f"PyTorch version: {torch.__version__}")
    print(f"PyTorch Geometric version: {torch_geometric.__version__}")
except ImportError as e:
    print(f"Error: {e}")
    print("Check if you have restarted runtime after installation")

Set device

In [None]:
# Pick the compute backend: CUDA first, then Apple MPS, otherwise fall back to the CPU.
if torch.cuda.is_available():
    backend_name, backend_message = "cuda", "Cuda"
elif torch.backends.mps.is_available():
    backend_name, backend_message = "mps", "Cuda not available, use mps"
else:
    backend_name, backend_message = "cpu", "Use CPU"
print(backend_message)
device = torch.device(backend_name)

Connect to wandb

In [None]:
# Optional experiment tracking with Weights & Biases.
if WANDB:
    os.makedirs("./wandb", exist_ok=True)
    # Set the variable from Python: `%env NAME="value"` keeps the quotes as part of the value,
    # which would point wandb at a directory literally named "./wandb" (quotes included).
    os.environ["WANDB_DIR"] = "./wandb"
    wandb.login(key="<YOUR_KEY>")
    wandb.init(project="6D_pose_estimation")

## Download dataset

In [None]:
# Step 1: Download the dataset (LineMOD)
# Download LineMOD dataset
# create directory structure without errors
!mkdir -p datasets/linemod/
%cd datasets/linemod/

In [None]:
!mkdir -p DenseFusion/
%cd DenseFusion/

In [None]:
# Download dataset (which includes a portion of the LineMOD dataset)
!gdown --folder "https://drive.google.com/drive/folders/19ivHpaKm9dOrr12fzC8IDFczWRPFxho7"

In [None]:
# Extract the archive, delete it, then go back up the three directory levels entered above.
!unzip Linemod_preprocessed.zip
!rm Linemod_preprocessed.zip
%cd ../../../

Get working directory

In [None]:
# `!pwd` yields the command output as a list of lines; keep the first line as a plain string.
path = !pwd
path = path[0]

## Data Exploration

Load an image

In [None]:
# Load a sample through the helper imported from utils.data_exploration.
load_image(label=1, object=0)

In [None]:
from data.get_samples import get_samples

# Fetch the sample list of every split used to train YOLO and report its size.
train_samples = get_samples(split="train")
print("Training samples:", len(train_samples))

validation_samples = get_samples(split="validation")
print("Validation samples:", len(validation_samples))

# the test split is optional for training YOLO
test_samples = get_samples(split="test")
print("Testing samples:", len(test_samples))

## Data Preprocessing

Structure the data such that
```
datasets/
├── data.yaml
│
├── train/
│   ├── images/
│   │
│   └── labels/
│  
├── val/
│
└── test/
```

In [None]:
# divide the dataset into training, validation and testing set
train_samples = train_dataset.get_samples_id()
validation_samples = val_dataset.get_samples_id()
test_samples = test_dataset.get_samples_id() # test folder is optional for training YOLO

Create a new folder containing all the info, we just need the rgb image and a text file with the label and bounding box.
The ```Linemod_preprocessed``` is not removed, as it contains info about translation and rotation that are needed for pose estimation, but not for object detection model.

At this point the working directory is ```DenseFusion```.

In [None]:
# create a folder to contain the dataset for YOLO model
os.makedirs("../YOLO/datasets", exist_ok=True)

# Class names are the numeric sub-folders of the LineMOD data directory; files and
# non-numeric folders are ignored. Sorting fixes the order used for the YOLO class indices.
class_names = sorted(
    entry.name
    for entry in os.scandir("./Linemod_preprocessed/data")
    if entry.is_dir() and entry.name.isdigit()
)
number_classes = len(class_names)

# list of quoted class names in the form ['01','02',...]
names = "[" + ",".join(f"'{name}'" for name in class_names) + "]"

# create data.yaml (as class names use ids of the folder)
content = f"""train: ./train/images
val: ./val/images
test: ./test/images

nc: {number_classes}
names: {names}"""
# write to file (the context manager closes it, no explicit close() is needed)
with open("../YOLO/datasets/data.yaml", "w") as fout:
    fout.write(content)

While creating the folder structure, we have to change the class id by using the index in the array written in the ```data.yaml```

In [None]:
# create a dictionary to have easily access to the index
index_dict = dict()
for index, el in enumerate(class_names):
    index_dict[int(el)] = index

Create the folders. Note that each image may contain multiple objects. For instance in ```data/02/gt.yml``` for one image there are multiple objects, but just consider the object of that class

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
import os

# create images and labels, one folder per split
folder_names = ["train", "val", "test"]
split_samples = [train_samples, validation_samples, test_samples]

# count also the number of instances of each class
classes = range(0, number_classes)
counter_df = pd.DataFrame()
for folder_name, dataset in zip(folder_names, split_samples):
    print(f"------------------------------{folder_name.upper()}------------------------------")
    images_dir = f"../YOLO/datasets/{folder_name}/images"
    labels_dir = f"../YOLO/datasets/{folder_name}/labels"
    os.makedirs(images_dir, exist_ok=True)
    os.makedirs(labels_dir, exist_ok=True)
    classCount = {label_object: 0 for label_object in index_dict} # initialize dictionary for counting
    total = 0 # used to normalize count
    for el in tqdm(dataset, desc="Moving..."):
        # el is (folderId, sampleId)
        # NOTE(review): train_dataset is used for every split and is called with two ids here;
        # presumably it only looks the pose up by id — confirm against the dataset class in use.
        _, _, bbox, obj_id = train_dataset.load_6d_pose(el[0], el[1])
        # prefix the file name with the folder id so images of different objects do not overwrite each other
        stem = f"{el[0]:02d}_{el[1]:04d}"
        shutil.copy(f"./Linemod_preprocessed/data/{el[0]:02d}/rgb/{el[1]:04d}.png", f"{images_dir}/{stem}.png")

        classCount[int(obj_id)] += 1
        total += 1
        # label file with the same name as the image: "<class index> <x_center> <y_center> <width> <height>"
        with open(f"{labels_dir}/{stem}.txt", "w") as fout:
            fout.write(f"{index_dict[int(obj_id)]} {bbox[0]} {bbox[1]} {bbox[2]} {bbox[3]}\n")

    # store the relative frequencies in the dataframe; max() avoids dividing by zero on an empty split
    values = pd.array(list(classCount.values())) / max(total, 1)
    counter_df[folder_name] = values.copy()

In [None]:
# plot distribution of labels in training, validation and test set
fig, axes = plt.subplots(1,3,figsize=(15,6),sharey=True)
for index, column in enumerate(counter_df.columns):
    axes[index].barh([str(el) for el in index_dict.keys()], counter_df[column],color="orange", edgecolor='gray')
    axes[index].set_title(column.capitalize())
    # add line that represents the uniform distribution of the labels
    axes[index].axvline(x=1/number_classes, color="blue")
    axes[index].text(x=1/number_classes,y=-0.5,s=f"{1/number_classes: .5f}", color="blue")

fig.supxlabel("Frequency")
fig.supylabel("Labels")
plt.subplots_adjust(left=0.07, wspace=0.1)
plt.suptitle("Labels Distribution over the Training, Validation and Test sets")
plt.savefig("../../../images/YOLO_dataset_distribution.png")
plt.show()

### Visualize data

Visualize depth image

In [None]:
# Open one depth image from folder 02 and display it; `img` is reused by the next cell.
img_path = "./Linemod_preprocessed/data/02/depth/0000.png"
img = Image.open(img_path)
plt.imshow(img)
plt.show()

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
# Draw the annotated bounding box of one object on top of the previously loaded image.

# Ground-truth annotations of folder 02
with open("./Linemod_preprocessed/data/02/gt.yml", 'r') as f:
  pose_data = yaml.load(f, Loader=yaml.FullLoader)
# image 0 (counting from 0), second object listed for that image
pose = pose_data[0][1]

bbox = np.array(pose['obj_bb'], dtype=np.float32) #[4]
obj_id = np.array(pose['obj_id'], dtype=np.float32) #[1]
box_left, box_top, box_width, box_height = bbox

fig, ax = plt.subplots()
ax.imshow(img)

# Rectangle anchored at (x, y) with the annotated width and height
rect = patches.Rectangle((box_left, box_top), box_width, box_height,
                         linewidth=2, edgecolor='red', facecolor='none')
ax.add_patch(rect)

# Object ID label, written slightly above the top left corner
ax.text(box_left, box_top - 10, f'ID: {int(obj_id)}', color='yellow', fontsize=12, backgroundcolor='black')

plt.axis('off')
plt.show()

In [None]:
# Batches of 4 samples; only the training loader is shuffled.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)
# report the number of batches per loader
for loader_label, loader in (("Training", train_loader), ("Validation", val_loader), ("Test", test_loader)):
    print(f"{loader_label} loader: {len(loader)}")

In [None]:
import itertools

# Keep just the first batch of every loader as a small subset.
train_subset_num_batches = val_subset_num_batches = test_subset_num_batches = 1
train_subset = list(itertools.islice(train_loader, train_subset_num_batches))
val_subset = list(itertools.islice(val_loader, val_subset_num_batches))
test_subset = list(itertools.islice(test_loader, test_subset_num_batches))

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Get one batch from the train loader (4 images)
batch = next(iter(train_loader)) # it uses load_6d_pose, so one pose per object

# Extract relevant data
rgb_images = batch["rgb"]         # (B, 3, H, W)
bboxes = batch["bbox"]            # (B, 4) normalized: x_center, y_center, width, height
obj_ids = batch["obj_id"]         # (B,)

# Convert to numpy and rearrange channels
rgb_images = rgb_images.permute(0, 2, 3, 1).numpy()  # (B, H, W, 3)
bboxes = bboxes.numpy()
obj_ids = obj_ids.numpy()

# Plot settings
batch_size = rgb_images.shape[0]
cols = min(4, batch_size)
rows = (batch_size + cols - 1) // cols

# squeeze=False always returns a 2D array of axes, so flatten() also works for a 1x1 grid
fig, axes = plt.subplots(rows, cols, figsize=(12, 3 * rows), squeeze=False)
axes = axes.flatten()

for i, (ax, img, bbox, obj_id) in enumerate(zip(axes, rgb_images, bboxes, obj_ids)):
    # each element is [x_center/IMG_WIDTH, y_center/IMG_HEIGHT, width/IMG_WIDTH, height/IMG_HEIGHT]
    x_center, y_center, width, height = bbox
    # remove normalization
    x_center = x_center*IMG_WIDTH
    y_center = y_center*IMG_HEIGHT
    width = width*IMG_WIDTH
    height = height*IMG_HEIGHT
    # Rectangle expects the top left corner, not the center
    x_min = x_center-(width/2)
    y_min = y_center-(height/2)

    ax.imshow(img)
    ax.axis('off')
    ax.set_title(f"Sample {i}")

    # Draw bounding box
    rect = patches.Rectangle(
        (x_min, y_min),
        width,
        height,
        linewidth=2,
        edgecolor='red',
        facecolor='none'
    )
    ax.add_patch(rect)

    # Add object ID as label, slightly above the box
    ax.text(
        x_min,
        y_min - 10,
        f'ID: {int(obj_id)}',
        color='yellow',
        fontsize=10,
        backgroundcolor='black'
    )

# Hide unused axes if batch_size < cols * rows
for j in range(batch_size, len(axes)):
    axes[j].axis('off')

plt.tight_layout()
plt.show()

## Training Object Detection model

Check if CUDA available, otherwise try with MPS and then CPU

In [None]:
# Same backend choice as at the top of the notebook: CUDA, then MPS, then CPU.
if not torch.cuda.is_available():
    if torch.backends.mps.is_available():
        print("Cuda not available, use mps")
        device = torch.device("mps")
    else:
        print("Use CPU")
        device = torch.device("cpu")
else:
    print("Cuda")
    device = torch.device("cuda")

In [None]:
# Move to the sibling YOLO/ folder, where the dataset cells above wrote ../YOLO/datasets.
%cd ../YOLO/

In [None]:
# Refresh `path` after the directory change; it is used to build the path to data.yaml.
path = !pwd
path = path[0]

In [None]:
from ultralytics import YOLO

# Training settings shared by the train / validation / test cells.
IMG_SIZE = 640
batch_size = 64
epochs = 50

# Load the starting checkpoint kept in the project's checkpoints folder.
model_path = "../../../checkpoints/yolo11n.pt"
model = YOLO(model_path)

In [None]:
# model will automatically scale the image and related bounding box according to imgsz
results = model.train(data=f"{path}/datasets/data.yaml", epochs=epochs, batch=batch_size, device=device,
        imgsz=IMG_SIZE,
        augment=True,
        flipud=0.5,
        fliplr=0.5,
        hsv_h=0.4,
        hsv_s=0.4,
        hsv_v=0.4,
        degrees=120,
        translate=0.1,
        scale=0.5,
        shear=20,
        perspective=0.0001
    )

Copy model file to ```checkpoints```

In [None]:
# NOTE(review): the run folder name "train" is hard-coded; verify it is the run that just finished.
shutil.copy("./runs/detect/train/weights/best.pt", "../../../checkpoints/best.pt")

Validate model

In [None]:
# Evaluate the saved best checkpoint with the same settings used for training.
model_path = "../../../checkpoints/best.pt"
model = YOLO(model_path)
validation_settings = dict(
    data=f"{path}/datasets/data.yaml",
    epochs=epochs,
    batch=batch_size,
    imgsz=IMG_SIZE,
    device=device,
)
results = model.val(**validation_settings)

Test model

In [None]:
# Same evaluation as the validation cell, but run on the "test" split.
model_path = "../../../checkpoints/best.pt"
model = YOLO(model_path)
test_settings = {
    "data": f"{path}/datasets/data.yaml",
    "epochs": epochs,
    "batch": batch_size,
    "imgsz": IMG_SIZE,
    "device": device,
    "split": "test",
}
results = model.val(**test_settings)

## Pose Estimator Module

In [None]:
# Go up one level; the cells below address ./YOLO/... and ./DenseFusion/... from here.
%cd ../

In [None]:
# Store the new working directory (first line of the `pwd` output) as a string.
path = !pwd
path = path[0]

In [None]:
from tqdm import tqdm

os.makedirs("./YOLO/datasets_cropped", exist_ok=True)

# input: path YOLO/datasets
# output: save images

def mask_outside_bbox(path="./YOLO/datasets", color=(0, 0, 0), output_path="./YOLO/datasets_cropped"):
    """Write a copy of every image in which all pixels outside its bounding box are set to `color`.

    Every sub-folder of `path` is expected to contain an `images` and a `labels` folder.
    The label of an image is the `.txt` file with the same base name; only its first line
    is used, in the form `<class> <x_center> <y_center> <width> <height>` with the box
    values normalized by the image size.

    Args:
        path (str): root folder that contains the split folders.
        color (tuple): fill value (one entry per channel) used outside the bounding box.
        output_path (str): root folder of the masked images; one sub-folder per split is created.

    Images without a label, unreadable images and empty or malformed label files are
    reported and skipped.
    """
    for folder in os.scandir(path):
        if not folder.is_dir():
            continue

        output_dir = os.path.join(output_path, folder.name)
        os.makedirs(output_dir, exist_ok=True)
        images_path = os.path.join(path, folder.name, "images")
        labels_path = os.path.join(path, folder.name, "labels")

        for img_name in sorted(os.listdir(images_path)):
            # pair image and label by name, so a missing or extra file cannot shift the pairing
            label_path = os.path.join(labels_path, os.path.splitext(img_name)[0] + ".txt")
            if not os.path.isfile(label_path):
                print(f"Skipping {img_name}: label file not found")
                continue

            # cv2.imread returns None instead of raising when the file cannot be decoded
            img = cv2.imread(os.path.join(images_path, img_name))
            if img is None:
                print(f"Skipping {img_name}: image cannot be read")
                continue

            with open(label_path, "r") as f:
                fields = f.readline().split()
            if len(fields) < 5:
                print(f"Skipping {img_name}: label file is empty or malformed")
                continue
            _, x_center, y_center, w, h = map(float, fields[:5])

            img_height, img_width = img.shape[:2]

            # Pixel corners clamped to the image: a negative index would otherwise
            # wrap around and select the wrong region when slicing.
            x1 = min(max(round((x_center - w / 2) * img_width), 0), img_width)
            y1 = min(max(round((y_center - h / 2) * img_height), 0), img_height)
            x2 = min(max(round((x_center + w / 2) * img_width), 0), img_width)
            y2 = min(max(round((y_center + h / 2) * img_height), 0), img_height)

            img_masked = np.full_like(img, color)
            img_masked[y1:y2, x1:x2] = img[y1:y2, x1:x2]

            cv2.imwrite(os.path.join(output_dir, img_name), img=img_masked)
             
mask_outside_bbox()

Copy the ```gt.yml``` and ```info.yml``` files of every class

In [None]:
# Put the pose (gt) and camera (info) files next to the masked images, prefixed with the class id.
linemod_data_dir = "./DenseFusion/Linemod_preprocessed/data"
for class_name in class_names:
    for kind in ("gt", "info"):
        shutil.copy(f"{linemod_data_dir}/{class_name}/{kind}.yml", f"./YOLO/datasets_cropped/{class_name}_{kind}.yml")

In [None]:
class CustomDatasetPose(Dataset): # used to load and preprocess data
    """Dataset of RGB images paired with the 6D pose annotation of their object.

    Layout read by the methods below, relative to `dataset_root`:
    `<split>/<folder id>_<image id>.png` for the images and `<folder id>_gt.yml`
    for the pose annotations.
    """

    def __init__(self, dataset_root, split='train'):
        """
        Args:
            dataset_root (str): Path to the dataset directory.
            split (str): 'train', 'validation' or 'test'.

        Raises:
            ValueError: if the split folder is missing or contains no `.png` file.
        """
        self.dataset_root = dataset_root
        if split == "train":
            self.split = split
        elif split == "validation":
            self.split = "val"
        else:
            self.split = "test"

        # parsed gt.yml files, keyed by folder id (filled lazily)
        self._gt_cache = {}

        # Get list of all samples ("<folder id>_<image id>")
        self.samples = self.get_all_samples()

        # Check if samples were found
        if not self.samples:
            raise ValueError(f"No samples found in {self.dataset_root}. Check the dataset path and structure.")

        # Define image transformations
        self.transform = transforms.Compose([
            transforms.ToTensor(),
        ])

    def get_samples_id(self):
        """Return the list of sample names of the selected split."""
        return self.samples

    def get_all_samples(self):
        """Return the sorted sample names (file name without extension) of the selected split.

        An empty list is returned when the split folder does not exist, so that the
        constructor can report the problem with its own error message.
        """
        folder_path = os.path.join(self.dataset_root, f"{self.split}")
        if not os.path.isdir(folder_path):
            return []
        # file names have the form <folder id>_<image>.png
        return sorted(f.split('.')[0] for f in os.listdir(folder_path) if f.endswith('.png'))

    # Define here some useful functions to access the data
    def load_image(self, img_path):
        """Load an RGB image and convert to tensor."""
        img = Image.open(img_path).convert("RGB")
        return self.transform(img)

    def _load_gt(self, label):
        """Return the parsed `<label>_gt.yml`, reading the file only the first time it is needed."""
        cache = getattr(self, "_gt_cache", None)
        if cache is None:
            cache = self._gt_cache = {}
        if label not in cache:
            pose_file = os.path.join(self.dataset_root, f"{label:02d}_gt.yml")
            with open(pose_file, 'r') as f:
                cache[label] = yaml.load(f, Loader=yaml.FullLoader)
        return cache[label]

    def load_6d_pose(self, sample_id):
        """Load the 6D pose (translation and rotation) for the object in this sample.

        Args:
            sample_id (str): sample name in the form `<folder id>_<image id>`.

        Returns:
            tuple: translation [3], rotation matrix [3x3], bounding box
            [x_center, y_center, width, height] normalized by IMG_WIDTH / IMG_HEIGHT,
            and the object id, all as float32 arrays.

        Raises:
            KeyError: if the image id is not in the gt file, or none of its poses
                belongs to the object of the folder.
        """
        folder_id, image_id = sample_id.split("_")[:2]
        label = int(folder_id)
        objectId = int(image_id)

        # The pose data is a dictionary where each key corresponds to a frame with pose info
        pose_data = self._load_gt(label)
        if objectId not in pose_data:
            raise KeyError(f"Sample ID {objectId} not found in {label:02d}_gt.yml.")

        # There can be more than one pose per sample, take the one of label=folder_id
        for pose in pose_data[objectId]:
            if int(pose['obj_id']) != label:
                continue

            translation = np.array(pose['cam_t_m2c'], dtype=np.float32)  # [3] ---> (x,y,z)
            rotation = np.array(pose['cam_R_m2c'], dtype=np.float32).reshape(3, 3)  # [3x3] ---> rotation matrix
            # bbox is top left corner and width and height info, YOLO needs center coordinates and width and height
            x_min, y_min, width, height = np.array(pose['obj_bb'], dtype=np.float32)
            x_center = x_min + width/2
            y_center = y_min + height/2

            # center, width and height of the bounding box normalized by the image size
            bbox = np.array([x_center/IMG_WIDTH, y_center/IMG_HEIGHT, width/IMG_WIDTH, height/IMG_HEIGHT], dtype=np.float32)

            obj_id = np.array(pose['obj_id'], dtype=np.float32) # [1] ---> label
            return translation, rotation, bbox, obj_id

        raise KeyError(f"No pose of object {label:02d} found for sample ID {objectId} in {label:02d}_gt.yml.")

    def __len__(self):
        """Return the total number of samples in the selected split."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load the image and the pose annotation of the sample at position `idx`."""
        sample = self.samples[idx]

        img_path = os.path.join(self.dataset_root, f"{self.split}", f"{sample}.png")

        img = self.load_image(img_path)
        translation, rotation, bbox, obj_id = self.load_6d_pose(sample)

        # Dictionary with all the data
        return {
            "rgb": img,
            "translation": torch.tensor(translation),
            "rotation": torch.tensor(rotation),
            "bbox": torch.tensor(bbox),
            "obj_id": torch.tensor(obj_id),
        }

In [None]:
# Build the three splits of the pose dataset from the masked images and report their sizes.
dataset_root_pose = "./YOLO/datasets_cropped"

train_dataset = CustomDatasetPose(dataset_root_pose, split="train")
print("Training samples:", len(train_dataset))

val_dataset = CustomDatasetPose(dataset_root_pose, split="validation")
print("Validation samples:", len(val_dataset))

test_dataset = CustomDatasetPose(dataset_root_pose, split="test")
print("Testing samples:", len(test_dataset))

In [None]:
class PoseEstimatorQuat6D(nn.Module):
    """Backbone followed by a small head that regresses a unit quaternion and a translation.

    Args:
        backbone_name (str): name of a constructor in `torchvision.models`; the model it
            builds must expose an `fc` layer (its input size feeds the head).
        pretrained (bool): load the ImageNet weights of the backbone when True.
    """

    def __init__(self, backbone_name="resnet18", pretrained=True):
        super().__init__()

        # `pretrained=` is deprecated in torchvision; "IMAGENET1K_V1" is the set of weights
        # that `pretrained=True` used to select.
        weights = "IMAGENET1K_V1" if pretrained else None
        backbone = getattr(models, backbone_name)(weights=weights)
        # keep everything except the last fully connected layer
        self.backbone = nn.Sequential(*list(backbone.children())[:-1])
        in_features = backbone.fc.in_features

        # 7 output (4 quaternion, 3 translation)
        self.fc = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Linear(256, 7)
        )

    def forward(self, x):
        """Return `(quat, trans)`: quaternions of shape (B, 4) and translations of shape (B, 3)."""
        x = self.backbone(x)
        x = torch.flatten(x, 1)
        output = self.fc(x)

        # normalize() clamps the norm with a small epsilon, so an all-zero output
        # gives zeros instead of NaN
        quat = nn.functional.normalize(output[:, :4], p=2, dim=1)

        trans = output[:, 4:]
        return quat, trans

In [None]:
def criterion(outputs, translation, rotation):
    """Pose loss: quaternion error plus translation mean squared error.

    Args:
        outputs (tuple): `(quat, trans)` predicted by the model, shapes (B, 4) and (B, 3).
        translation (Tensor): ground-truth translations, shape (B, 3).
        rotation: ground-truth rotations, either a sequence of quaternion objects
            exposing `.components` or a tensor of shape (B, 4).

    Returns:
        Tensor: scalar loss.
    """
    quat, trans = outputs
    # Build the target on the device / dtype of the prediction (float32 instead of the
    # float64 components) rather than relying on a global `device`.
    if isinstance(rotation, torch.Tensor):
        target = rotation.to(device=quat.device, dtype=quat.dtype)
    else:
        target = torch.as_tensor(
            np.stack([el.components for el in rotation]), dtype=quat.dtype, device=quat.device
        )

    # q and -q describe the same rotation: score each sample against the closer of the two
    # signs, so that an equivalent prediction is not penalized.
    same_sign = ((quat - target) ** 2).mean(dim=1)
    flipped_sign = ((quat + target) ** 2).mean(dim=1)
    loss_quat = torch.minimum(same_sign, flipped_sign).mean()

    loss_trans = F.mse_loss(trans, translation)

    return loss_quat + loss_trans

In [None]:
from torch.optim import Adam
import quaternion

def train(model, epoch, dataloader, criterion, optimizer=Adam, device=device):
    """Train `model` for one epoch and return the average batch loss.

    Args:
        model: network returning `(quat, trans)` for a batch of images.
        epoch (int): epoch number, used only in the progress messages.
        dataloader: yields dictionaries with the keys "rgb", "translation" and "rotation".
        criterion: callable `criterion(outputs, translation, rotation)` returning the loss.
        optimizer: optimizer instance, or an optimizer class (the default) that is then
            instantiated on `model.parameters()`. Pass an instance to keep the optimizer
            state across epochs.
        device: device the images and translations are moved to.

    Returns:
        float: mean loss over the batches (0.0 if the dataloader is empty).
    """
    # the default value is the Adam class itself, which has to be instantiated before use
    if isinstance(optimizer, type):
        optimizer = optimizer(model.parameters())

    model.train()
    running_loss = 0.0
    num_batches = 0
    samples_seen = 0

    for batch_idx, data in enumerate(dataloader):
        img = data["rgb"].to(device)
        translation = data["translation"].to(device)
        # rotation matrices -> quaternions; the conversion works on NumPy arrays
        rotation = quaternion.from_rotation_matrix(data["rotation"].cpu().numpy())

        outputs = model(img)
        loss = criterion(outputs, translation, rotation)
        optimizer.zero_grad()

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

        if batch_idx % 2 == 0:
            # `data` is a dictionary, so its length is not the batch size: count real samples instead
            print(f'Train Epoch: {epoch} [{samples_seen}/{len(dataloader.dataset)} '
                f'({100. * batch_idx / len(dataloader):.0f}%)]\t Loss: {loss.item():.6f}')
        samples_seen += img.size(0)

    return running_loss / num_batches if num_batches else 0.0

In [None]:
# Grid of hyper-parameters to try (a single configuration for now)
learning_rate=[0.001]
batch_size=[2]
num_epoch=10
for lr in learning_rate:
    for batch in batch_size:
        train_dataloader = DataLoader(train_dataset, batch_size=batch, shuffle=True)
        val_dataloader = DataLoader(val_dataset, batch_size=batch, shuffle=False)
        model = PoseEstimatorQuat6D().to(device)
        print(model)
        # one optimizer per configuration: re-creating Adam at every epoch would discard its state
        optimizer = Adam(model.parameters(), lr=lr)
        # epochs are numbered 1..num_epoch, so that num_epoch epochs are actually run
        for epoch in range(1, num_epoch + 1):
            train(model, epoch, train_dataloader, criterion, optimizer=optimizer, device=device)

Plot

In [None]:
import quaternion
import trimesh

In [None]:
# Edges of a box whose 8 corners are listed bottom face first (0-3), then top face (4-7)
_BOX_EDGES = ((0, 1), (1, 2), (2, 3), (3, 0),  # bottom
              (0, 4), (1, 5), (2, 6), (3, 7),  # vertical
              (4, 5), (5, 6), (6, 7), (7, 4))  # top
# Arrow colors in BGR format: X is red, Y is green, Z is blue
_AXIS_COLORS = ((0, 0, 255), (0, 255, 0), (255, 0, 0))


def _pose_to_numpy(value):
    """Return `value` as a NumPy array; tensors are detached and moved to the CPU first."""
    if hasattr(value, "detach"):
        value = value.detach().cpu().numpy()
    return np.asarray(value)


def _pose_rotation_matrix(rotation):
    """Return a 3x3 rotation matrix from either a quaternion (4 values) or a matrix (9 values)."""
    rotation = _pose_to_numpy(rotation)
    if rotation.size == 4:
        return quaternion.as_rotation_matrix(np.quaternion(*rotation.ravel()))
    return rotation.reshape(3, 3)


def _project_to_pixels(points_3d, rotation, translation, camera_intrinsics):
    """Move object-frame points into the camera frame and project them to integer pixel tuples."""
    # rotation is 3x3 and points_3d is Nx3, hence the transposes
    points_cam = (rotation @ points_3d.T).T + translation
    projected = (camera_intrinsics @ points_cam.T).T
    # take first 2 columns and normalize by depth
    pixels = (projected[:, :2] / projected[:, 2:3]).astype(int)
    # plain Python ints, the point format OpenCV drawing functions expect
    return [tuple(int(v) for v in point) for point in pixels]


def _draw_box_and_axes(box_image, axes_image, box_2d, axes_2d, box_color):
    """Draw the 12 box edges on `box_image` and the three axis arrows on `axes_image`."""
    for start, end in _BOX_EDGES:
        cv2.line(box_image, box_2d[start], box_2d[end], box_color, 5)
    origin = axes_2d[0]
    for tip, color in zip(axes_2d[1:], _AXIS_COLORS):
        cv2.arrowedLine(axes_image, origin, tip, color, 2)  # thickness=2


def plotPose(pathImage, translation_gt, rotation_gt, translation_pred, rotation_pred):
    '''
        Draw the ground-truth and the predicted pose (3D bounding box and axes) on an image.

        Input:
            path for image, its file name has the form <label>_<image id>.<extension>
            ground truth translation tensor (in millimeters)
            ground truth rotation tensor (either matrix or quaternion)
            predicted translation tensor (in millimeters)
            predicted rotation tensor (either matrix or quaternion)

        Tensors may live on any device and may require grad; NumPy arrays are accepted too.
        The figure is shown with matplotlib, nothing is returned.
    '''

    # read image
    image = cv2.imread(pathImage)
    transparent_image = image.copy() # copy of the image to work on a transparent image (for the second reference system)

    # read translation (converted to meters) and rotation
    trans_gt = _pose_to_numpy(translation_gt).reshape(3)/1000
    trans_pred = _pose_to_numpy(translation_pred).reshape(3)/1000
    rotat_gt = _pose_rotation_matrix(rotation_gt)
    rotat_pred = _pose_rotation_matrix(rotation_pred)

    # read camera intrinsics
    file_stem = pathImage.split("/")[-1].split(".")[0]
    label = file_stem.split("_")[0]
    image_id = file_stem.split("_")[1]
    with open(f"./YOLO/datasets_cropped/{label}_info.yml", 'r') as f:
        camera_info = yaml.load(f, Loader=yaml.FullLoader)
    camera_intrinsics = np.array(camera_info[int(image_id)]["cam_K"]).reshape(3,3)

    # read 3D model
    meshModel = trimesh.load(f"./DenseFusion/Linemod_preprocessed/models/obj_{label}.ply")
    vertices = meshModel.vertices/1000 # it has 3 columns, for X, Y, Z, use unit of measurement of translation
    # axis-aligned corners of the model: smallest and largest value of each column
    (x0, y0, z0), (x1, y1, z1) = vertices.min(axis=0), vertices.max(axis=0)
    bounding_box_3d = np.array([[x0, y0, z0], [x1, y0, z0], [x1, y1, z0], [x0, y1, z0],
                                [x0, y0, z1], [x1, y0, z1], [x1, y1, z1], [x0, y1, z1]])

    # 3D axes in the object coordinate system, same unit of measurement of translation, so in meters
    axes_3d = np.array([
        [0, 0, 0],      # origin, in the object coordinate system
        [0.15, 0, 0],   # how long the arrow should be in the X coordinate
        [0, 0.15, 0],   # how long the arrow should be in the Y coordinate
        [0, 0, 0.15]    # how long the arrow should be in the Z coordinate
    ])

    # ground truth: red box and axes, both drawn on the main image
    box_2d_gt = _project_to_pixels(bounding_box_3d, rotat_gt, trans_gt, camera_intrinsics)
    axes_2d_gt = _project_to_pixels(axes_3d, rotat_gt, trans_gt, camera_intrinsics)
    _draw_box_and_axes(image, image, box_2d_gt, axes_2d_gt, (0, 0, 255))

    # prediction: blue box on the main image, axes on the copy that is blended in below
    box_2d_pred = _project_to_pixels(bounding_box_3d, rotat_pred, trans_pred, camera_intrinsics)
    axes_2d_pred = _project_to_pixels(axes_3d, rotat_pred, trans_pred, camera_intrinsics)
    _draw_box_and_axes(image, transparent_image, box_2d_pred, axes_2d_pred, (255, 0, 0))

    # show image after merging the two images
    overlapImage = cv2.addWeighted(transparent_image, 0.5, image, 1, 0)
    plt.imshow(cv2.cvtColor(overlapImage, cv2.COLOR_BGR2RGB))
    plt.title("Object Pose Estimation (prediction is transparent)")
    plt.show()