## **CIS 5810 Project 8 - Part I - Hand Pose Estimation with Transformer**

**Introduction**

In Project 7, we followed a top-down approach to 2D hand pose estimation from a single RGB image, using UNet, one of the most popular CNN-based networks. In this project, we extend hand pose estimation to 3D, focusing on building a Transformer-based network, which is the core design of many state-of-the-art methods.

**Lifting 2D to 3D hand pose estimation**

The task is to lift a 2D hand pose to a 3D hand pose. In other words, the input to our model is the 2D hand keypoints $P$ with shape $(J,2)$, and the output is the 3D hand keypoints $\hat{P}$ with shape $(J,3)$ in the camera coordinate system. Since no images are used as input, training time is greatly reduced in this project.

**What to do**

All parts that need your implementation are marked as *TODOs*, including the model architecture, the training pipeline, and the final inference & evaluation. The file structure and helper functions are listed below.

- `imgs/`: directory where sample images for display are stored.
- `dataset/`
    - `dataset.py`: main Dataset to load and preprocess Ego-Exo4D data for model training.
    - `data_vis.py`: some helper functions to visualize 2D & 3D hand keypoints.
- `utils/`
    - `utils.py`: a list of utility functions to help model training and debugging.
    - `loss.py`: implementation of loss function and metrics to evaluate model performance.
- `model/`
    - `model.py`: Implementation of PoseTransformer, which is the model for pose estimation
- `CIS_5810_Project_8-1.ipynb`: integrates all of the parts above, from loading the dataset to model training and testing.


### 0 - Set-up

In [None]:
import argparse
import os
import cv2
import numpy as np
import torch
import time
from easydict import EasyDict as edict
import torchvision.transforms as transforms
from tqdm import tqdm
import matplotlib.pyplot as plt
from dataset.dataset import ego4dDataset
from dataset.data_vis import *
from model.model import PoseTransformer
from utils.utils import *
from utils.loss import Pose3DLoss, mpjpe, p_mpjpe
import wandb

### 1 - Load Ego-Exo4D Dataset

Download the data from [here](https://drive.google.com/drive/folders/1g9SpvjDyndOGQg70AqqkKimZRPwvgV5_?usp=sharing) and put it under `{anno_dir}` as shown below.

```
{anno_dir}
    ├── ego_pose_gt_anno_test.json
    ├── ego_pose_gt_anno_train.json
    └── ego_pose_gt_anno_val.json
```

The JSON files are similar to the one you used in Project 7, but include additional 3D hand pose annotations. Run the cell below to get a better idea of the dataset.

*NOTE: Ego-Exo4D dataset annotation is still in progress, thus there might be some bad/missing annotations. However, the general quality of the dataset should be good enough for you to train a hand pose estimation model that gives reasonable predictions.*

In [None]:
# TODO: Modify data_root_path as {anno_dir}
data_root_path = ...

# The data are hand keypoints only (no images), so we simply convert them to tensors
transform = transforms.Compose([
    transforms.ToTensor()
])

# TODO: Initialize the train, val and test Dataset
# Hint: take a look at the implementation of ego4dDataset for initialization requirement
train_dataset = ...
val_dataset = ...
test_dataset = ...

# Check the dataset length
print("Train: ", len(train_dataset))
print("Val: ", len(val_dataset))
print("Test: ", len(test_dataset))

**Visualizing the dataset**

We use the same hand joint indexing as in Project 7 (see Figure 1 below). The relationship between 2D and 3D hand keypoints is given by

$$\lambda \begin{bmatrix}
u\\
v\\
1
\end{bmatrix} = K \begin{bmatrix}
X_c\\
Y_c\\
Z_c
\end{bmatrix}$$

where $[u,v]^{T}$ are the 2D hand keypoints on the image plane, $[X_c, Y_c, Z_c]^{T}$ are the 3D hand keypoints in the camera coordinate system, $\lambda$ is the scale factor (for a standard intrinsic matrix whose last row is $[0,0,1]$, $\lambda = Z_c$, the depth of the point), and $K$ is the camera intrinsic matrix. If you look at the implementation of the provided `ego4dDataset`, two preprocessing steps are applied to help training:

1. The 3D hand keypoints are offset by the hand wrist so that $\hat{P}[0]=[0,0,0]$. Recovering 3D coordinates from 2D coordinates is not trivial because the scale information ($\lambda$) is unknown. Therefore, we do not predict the absolute 3D hand pose in the camera coordinate system, but rather the location of each hand joint relative to the wrist.

2. The 2D and offset 3D hand keypoints are then normalized by subtracting the mean and dividing by the standard deviation, similar to what we did for image preprocessing in Project 7, to help stabilize model training and convergence.
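To make the scale ambiguity and the two preprocessing steps concrete, here is a small NumPy sketch. The intrinsic matrix `K`, the random toy hand, and the helpers `project`, `wrist_offset`, `normalize` and `inv_normalize` are hypothetical; in particular, per-coordinate mean/std statistics are an assumption, so check `ego4dDataset` for the statistics it actually uses.

```python
import numpy as np

def project(K, pts_3d):
    """Pinhole projection: lambda * [u, v, 1]^T = K [X_c, Y_c, Z_c]^T."""
    uvw = pts_3d @ K.T                # (J, 3); each row is lambda * [u, v, 1]
    return uvw[:, :2] / uvw[:, 2:3]   # divide by lambda (= Z_c for this K)

def wrist_offset(pts_3d):
    """Step 1: express every joint relative to the wrist (joint 0)."""
    return pts_3d - pts_3d[0:1]

def normalize(x, mean, std):
    """Step 2: z-score normalization."""
    return (x - mean) / std

def inv_normalize(x, mean, std):
    """Undo step 2."""
    return x * std + mean

# Hypothetical intrinsics: focal length 500 px, principal point (320, 240)
K = np.array([[500.0, 0.0, 320.0],
              [0.0, 500.0, 240.0],
              [0.0, 0.0, 1.0]])

rng = np.random.default_rng(0)
# Toy hand: 21 joints about 0.5 m in front of the camera (units: meters)
hand = rng.normal(scale=0.05, size=(21, 3)) + np.array([0.0, 0.0, 0.5])

# Scale ambiguity: a hand twice as large and twice as far away projects identically
uv_near = project(K, hand)
uv_far = project(K, 2.0 * hand)
print(np.allclose(uv_near, uv_far))  # True

# Preprocessing round trip
rel = wrist_offset(hand)
mean, std = rel.mean(axis=0), rel.std(axis=0)
rel_norm = normalize(rel, mean, std)
print(np.allclose(inv_normalize(rel_norm, mean, std), rel))  # True
```

Because 2D keypoints alone cannot distinguish `hand` from `2.0 * hand`, predicting wrist-relative coordinates (and relying on the network to learn a typical hand size) is what keeps the lifting problem well posed.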

<div style="display: center; justify-content: space-between; text-align: center;">
    <figure>
        <img src="imgs/hand_index.png" alt="Hand index" width ="300" height="350">
        <figcaption>Figure 1: Hand joints index</figcaption>
    </figure>
</div>

**Dataset items**

The dataset returns four objects per sample: `kpts_2d`, `kpts_3d`, `weight` and `metadata`. `kpts_2d` and `kpts_3d` correspond to the processed 2D and 3D hand keypoints; `metadata` contains the frame number and take info for quick debugging; `weight` is a boolean array indicating whether each joint is valid. If `False`, the joint is excluded during model training.


In the cell below, you can select samples from any of the datasets and visualize their ground truth poses. You can also try overlaying the 2D hand keypoints on the real images from Project 7 to get a better understanding.

In [None]:
## TODO: Modify as needed to take a look at the dataset
check_dataset = ...
idx = ...

## Get one sample for visualization
kpts_2d, kpts_3d, weight, metadata = check_dataset[idx]

# Visualize 2D hand kpts in image plane
kpts_2d = check_dataset.inv_normalize_2d(kpts_2d.numpy())
# Assign None to invalid kpts (so it won't be displayed)
kpts_2d[~weight] = None
vis_data_2d(kpts_2d, title="GT 2D")

# Visualize 3D hand kpts in camera coordinate system
# Hint: Change the 3D plot view angle to get a better visualization:
# https://matplotlib.org/stable/api/toolkits/mplot3d/view_angles.html
gt_3d = check_dataset.inv_normalize_3d(kpts_3d.numpy())
# Assign None to invalid kpts (so it won't be displayed)
gt_3d[~weight] = None
vis_data_3d(gt_3d, title="GT 3D")

### 2 - Define model

**General overview**

The 2D-to-3D hand pose lifting model is a Transformer with the architecture shown in Figure 2. Its design is similar to [ViT](https://arxiv.org/pdf/2010.11929.pdf), but it takes the 2D hand keypoints $(21,2)$ as input and treats each joint as a patch. Each 2D keypoint is first projected into the embedding space by a linear layer, giving a $(21,D)$ sequence, and a positional encoding is then added to retain positional information. The resulting embedded feature vectors are fed into the Transformer Encoder, which consists of alternating multihead self-attention (MSA) and MLP layers together with layer normalization (LN). The encoder block is repeated several times, and the output of the final block is fed into the MLP head, which consists of layer normalization and a single linear layer, to project the embedded vectors $(21,D)$ back to 3D coordinates $(21,3)$ as the final model output.
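The data flow described above can be sketched in PyTorch as follows. This is an illustrative stand-in, not the provided `PoseTransformer` skeleton: it swaps in PyTorch's built-in `nn.TransformerEncoderLayer` (with ViT-style pre-LN, `norm_first=True`) for the custom Encoder and Attention blocks you implement in `model/model.py`. The class name `LiftingTransformerSketch`, the hyperparameters, the learnable positional encoding, and the use of `weight` as a key padding mask are all assumptions.

```python
import torch
import torch.nn as nn

class LiftingTransformerSketch(nn.Module):
    """Hypothetical 2D-to-3D lifting model following Figure 2's data flow."""

    def __init__(self, num_joints=21, embed_dim=64, depth=4, num_heads=4,
                 mlp_ratio=2.0, dropout=0.0):
        super().__init__()
        # Joint embedding: each (x, y) keypoint becomes a D-dim token
        self.joint_embed = nn.Linear(2, embed_dim)
        # Learnable positional encoding, one vector per joint (an assumption)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_joints, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        # Encoder block: x + MSA(LN(x)), then x + MLP(LN(x)), repeated `depth` times
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads,
            dim_feedforward=int(embed_dim * mlp_ratio),
            dropout=dropout, activation="gelu",
            batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth,
                                             enable_nested_tensor=False)
        # MLP head: LayerNorm + single Linear layer, (J, D) -> (J, 3)
        self.head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 3))

    def forward(self, kpts_2d, weight=None):
        # kpts_2d: (B, J, 2); weight: optional (B, J) bool, True = valid joint
        x = self.joint_embed(kpts_2d) + self.pos_embed      # (B, J, D)
        mask = None
        if weight is not None:
            # src_key_padding_mask expects True for positions to IGNORE
            mask = ~weight.bool()
            # If every joint of a sample is invalid, attend to all joints to avoid NaNs
            mask[mask.all(dim=1)] = False
        x = self.encoder(x, src_key_padding_mask=mask)      # (B, J, D)
        return self.head(x)                                 # (B, J, 3)

model = LiftingTransformerSketch()
out = model(torch.randn(8, 21, 2), torch.ones(8, 21, dtype=torch.bool))
print(out.shape)  # torch.Size([8, 21, 3])
```

Masking invalid joints as attention keys is only one way to use `weight`; follow whatever the skeleton in `model/model.py` expects.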

<div style="display: center; justify-content: space-between; text-align: center;">
    <figure>
        <img src="imgs/model.png" alt="Model architecture" width ="700" height="530">
        <figcaption>Figure 2: Model architecture</figcaption>
    </figure>
</div>

**Multihead Self-Attention (MSA)**

The core of the Transformer is the multihead self-attention mechanism. For a single head, given an input sequence $z \in \mathbb{R}^{N \times D}$, we first compute three matrices $q,k,v \in \mathbb{R}^{N \times D_h}$, namely the Query, Key and Value, via a linear projection:

$$[q,k,v] = zU_{qkv} \quad\quad\quad U_{qkv} \in \mathbb{R}^{D \times 3D_h}$$

We then compute the attention weights $A$, where $A_{ij}$ is the pairwise similarity between elements $z_i$ and $z_j$, computed from the query $q_i$ of $z_i$ and the key $k_j$ of $z_j$:

$$A = softmax(\frac{qk^T}{\sqrt{D_h}}) \quad\quad\quad A \in \mathbb{R}^{N \times N}$$

Then we compute the weighted sum over all values $v$ in the sequence via self-attention (SA):

$$SA(z)=Av$$
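The single-head equations above can be checked numerically with a short NumPy sketch. The function names and toy dimensions are illustrative only; the projection matrix `U_qkv` is random rather than learned.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(z, U_qkv):
    """Single-head self-attention. z: (N, D), U_qkv: (D, 3*D_h)."""
    D_h = U_qkv.shape[1] // 3
    qkv = z @ U_qkv                                         # [q, k, v] = z U_qkv, (N, 3*D_h)
    q, k, v = qkv[:, :D_h], qkv[:, D_h:2 * D_h], qkv[:, 2 * D_h:]
    A = softmax(q @ k.T / np.sqrt(D_h), axis=-1)            # (N, N); row i weights z_i over all z_j
    return A @ v, A                                         # SA(z) = A v, (N, D_h)

rng = np.random.default_rng(0)
N, D, D_h = 21, 16, 8          # 21 joint tokens; toy embedding sizes
z = rng.normal(size=(N, D))
U_qkv = rng.normal(size=(D, 3 * D_h)) / np.sqrt(D)
out, A = self_attention(z, U_qkv)
print(out.shape, A.shape)  # (21, 8) (21, 21)
```

Each row of `A` is a probability distribution over the 21 joints, so every output token is a convex combination of the value vectors.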

Multihead self-attention (MSA) extends self-attention (SA) by running $h$ self-attention operations, called *heads*, in parallel and projecting their concatenated outputs to form the MSA output. The head embedding dimension is set to $D_h = \frac{D}{h}$ so that the concatenated vector keeps the same embedding dimension $D$. (We write $h$ for the number of heads to avoid confusion with the keys $k$.)

$$MSA(z) = [SA_1(z), SA_2(z), ..., SA_h(z)]U_{msa} \quad\quad\quad U_{msa} \in \mathbb{R}^{D \times D}$$
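A batched PyTorch sketch of MSA is shown below. It computes the projections for all heads with one `nn.Linear(D, 3D)`, splits them into `num_heads` heads of size `D // num_heads`, applies scaled dot-product attention per head, concatenates the heads and applies the output projection $U_{msa}$. The class name is hypothetical, and the skeleton in `model/model.py` may organize its layers differently.

```python
import torch
import torch.nn as nn

class MultiHeadSelfAttentionSketch(nn.Module):
    """MSA(z) = [SA_1(z), ..., SA_H(z)] U_msa with D_h = D / H (illustrative sketch)."""

    def __init__(self, dim, num_heads):
        super().__init__()
        assert dim % num_heads == 0, "D must be divisible by the number of heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, 3 * dim)   # U_qkv for all heads at once
        self.proj = nn.Linear(dim, dim)      # U_msa

    def forward(self, z):
        B, N, D = z.shape
        # (B, N, 3D) -> (B, N, 3, H, D_h) -> (3, B, H, N, D_h)
        qkv = self.qkv(z).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) / self.head_dim ** 0.5   # (B, H, N, N)
        attn = attn.softmax(dim=-1)
        out = attn @ v                                    # (B, H, N, D_h): one SA per head
        out = out.transpose(1, 2).reshape(B, N, D)        # concatenate heads -> (B, N, D)
        return self.proj(out)

msa = MultiHeadSelfAttentionSketch(dim=32, num_heads=4)
print(msa(torch.randn(2, 21, 32)).shape)  # torch.Size([2, 21, 32])
```

Packing Q, K and V into one linear layer is a common efficiency choice; it is mathematically identical to three separate $D \times D_h$ projections per head.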


**TODO**

Your task is to finish the implementation of the `forward()` methods of the Encoder block and the Attention block, and part of the main model in `model/model.py` (skeleton code is provided), as well as the training and evaluation pipeline in this notebook. All of these are marked as *TODO*.

### 3 - Train model

Define the training and validation pipeline, similar to Project 7. Fill in your implementation at the TODOs below.

In [None]:
def train(train_loader, model, criterion, optimizer, epoch, device):
    total_loss = AverageMeter()

    # TODO: set model to training mode


    train_loader = tqdm(train_loader, dynamic_ncols=True)
    print_interval = max(1, len(train_loader) // 6)  # avoid modulo by zero with fewer than 6 batches
    # Iterate over all training samples
    for i, (kpts_2d, gt_kpts_3d, weight, _) in enumerate(train_loader):
        # TODO:
        # 1. Put all relevant data onto the same device
        # 2. Model forward (given 2d kpts and weight, return 3d kpts)


        # TODO: Compute loss (remember to pass weight)
        loss = ...
        # Record current batch's loss
        total_loss.update(loss.item())


        # TODO:
        # 1. Clear the old parameter gradients
        # 2. Compute the derivative of loss w.r.t the model parameters
        # 3. Update the model parameters with optimizer


        # Log loss to wandb
        if (i+1) % print_interval == 0:
            wandb.log({"Loss/train": total_loss.avg})

    return total_loss.avg


def validate(val_loader, model, criterion, epoch, device):
    total_loss = AverageMeter()

    # TODO: set model to evaluate mode


    with torch.no_grad():
        val_loader = tqdm(val_loader, dynamic_ncols=True)
        # Iterate over all validation samples
        for _, (kpts_2d, gt_kpts_3d, weight, _) in enumerate(val_loader):
            # TODO:
            # 1. Put all relevant data onto the same device
            # 2. Model forward (given 2d kpts and weight, return 3d kpts)


            # TODO: Compute loss (remember to pass weight)
            loss = ...
            # Record current batch's loss
            total_loss.update(loss.item())

        # Log loss to wandb
        wandb.log({"Loss/val": total_loss.avg})

    return total_loss.avg

With the training and validation pipeline set up, we can then instantiate our model and define other essential parts for model training including:

- **optimizer:** We will be using Adam optimizer with learning rate 2e-4.

- **loss function (criterion):** Use `Pose3DLoss()` as the loss function. It computes the MSE between the predicted and ground truth 3D hand keypoints, filtering out invalid joints using `weight`. The implementation is provided; please read it before use.

- **dataloaders**: train and val dataloaders that return data in batches. The default batch size is 64; adjust it as needed.
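The pieces listed above fit together roughly as in the sketch below, which runs on random tensors. `masked_mse` is a hypothetical stand-in that only mimics the idea of ignoring invalid joints; it is not the provided `Pose3DLoss`, and the tiny `nn.Sequential` model and `TensorDataset` replace `PoseTransformer` and `ego4dDataset` purely for illustration.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def masked_mse(pred, gt, weight):
    """MSE over valid joints only. pred/gt: (B, J, 3), weight: (B, J) bool."""
    w = weight.unsqueeze(-1).float()                     # (B, J, 1)
    sq_err = ((pred - gt) ** 2) * w                      # zero out invalid joints
    # Average over valid joints and coordinates; clamp avoids division by zero
    return sq_err.sum() / (w.sum() * pred.shape[-1]).clamp(min=1.0)

torch.manual_seed(0)
B, J = 256, 21
kpts_2d = torch.randn(B, J, 2)
kpts_3d = torch.randn(B, J, 3)
weight = torch.rand(B, J) > 0.1                          # roughly 90% of joints valid
loader = DataLoader(TensorDataset(kpts_2d, kpts_3d, weight), batch_size=64, shuffle=True)

# Placeholder per-joint model standing in for PoseTransformer
model = nn.Sequential(nn.Linear(2, 64), nn.GELU(), nn.Linear(64, 3))
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

model.train()
for x, y, w in loader:
    pred = model(x)                   # (64, J, 3)
    loss = masked_mse(pred, y, w)
    optimizer.zero_grad()             # 1. clear the old parameter gradients
    loss.backward()                   # 2. compute d(loss)/d(parameters)
    optimizer.step()                  # 3. update the parameters
print(f"last batch loss: {loss.item():.4f}")
```

Normalizing by the number of valid entries (rather than all 21 joints) keeps the loss scale independent of how many joints happen to be missing in a batch.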

In [None]:
# TODO: Instantiate the model and define device for training to use GPU if available
device = ...
model = ...

# TODO: Define criterion
criterion = ...
# TODO: Define optimizer
optimizer = ...

# Define train and val dataloader
train_loader = ...
val_loader = ...

Finally, we can start training the model. We log the training and validation losses to Weights & Biases (wandb) to monitor training. Run the cell below to initialize wandb and start training. For wandb usage, please refer to the [wandb GitHub repository](https://github.com/wandb/wandb?tab=readme-ov-file).

In [None]:
wandb.login()
# TODO: Define current run name
current_run_name = ...
wandb.init(project="CIS5810_project_8_1", name=current_run_name)

In [None]:
# Define output directory; modify as needed (where model ckpt will be saved)
output_root = "output"
output_dir = os.path.join(output_root, current_run_name)
print("="*10 + f" Training started. Output will be saved at {output_dir} " + "="*10)
os.makedirs(output_dir, exist_ok=True)

# Default number of training epochs and best val loss
epochs = 30
best_val_loss = np.inf

for epoch in range(epochs):
    print("="*10, f"Epoch [{epoch}/{epochs}]", "="*10)
    # train for one epoch
    _ = train(train_loader, model, criterion, optimizer, epoch, device)

    # evaluate on validation set
    val_loss = validate(val_loader, model, criterion, epoch, device)

    # Save best model weight
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # Save model weight
        torch.save(model.state_dict(), os.path.join(output_dir, f"final_state.pth.tar"))
        print(f"Saving model weight with best val_loss={val_loss:.5f}")
    print()
print("="*10 + f" Training finished. Got best model with val_loss={best_val_loss:.5f} " + "="*10)
wandb.finish()

### 4 - Evaluate

To evaluate the model on lifting 2D to 3D hand pose, we use two metrics: MPJPE and PA-MPJPE. Mean Per-Joint Position Error (MPJPE) computes the average Euclidean distance between each predicted joint and the corresponding ground truth joint. PA-MPJPE first aligns the predicted pose to the ground truth with Procrustes analysis (optimal rotation, translation and scale) and then computes MPJPE. Because global misalignment is removed, PA-MPJPE is typically lower (better) than MPJPE and mainly reflects how well the shape of the hand is predicted.
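Both metrics can be sketched in NumPy as follows, using a common SVD-based Procrustes formulation that solves for rotation, uniform scale and translation. The names `mpjpe_np` and `p_mpjpe_np` are hypothetical, and the provided `mpjpe`/`p_mpjpe` in `utils/loss.py` may differ in detail.

```python
import numpy as np

def mpjpe_np(pred, gt):
    """Mean Euclidean distance between corresponding joints. pred, gt: (N, J, 3)."""
    return np.mean(np.linalg.norm(pred - gt, axis=-1))

def p_mpjpe_np(pred, gt):
    """MPJPE after aligning pred to gt with a similarity transform."""
    mu_gt = gt.mean(axis=1, keepdims=True)
    mu_pred = pred.mean(axis=1, keepdims=True)
    x0 = gt - mu_gt                                   # center both poses
    y0 = pred - mu_pred
    norm_x = np.sqrt((x0 ** 2).sum(axis=(1, 2), keepdims=True))
    norm_y = np.sqrt((y0 ** 2).sum(axis=(1, 2), keepdims=True))
    x0 = x0 / norm_x                                  # unit Frobenius norm
    y0 = y0 / norm_y

    H = np.matmul(x0.transpose(0, 2, 1), y0)          # (N, 3, 3) cross-covariance
    U, s, Vt = np.linalg.svd(H)
    V = Vt.transpose(0, 2, 1)
    R = np.matmul(V, U.transpose(0, 2, 1))            # optimal rotation (row-vector convention)

    # Prevent reflections: if det(R) < 0, flip the last singular direction
    sign_det = np.sign(np.linalg.det(R))[:, None]     # (N, 1)
    V[:, :, -1] *= sign_det
    s[:, -1] *= sign_det[:, 0]
    R = np.matmul(V, U.transpose(0, 2, 1))

    scale = s.sum(axis=1)[:, None, None] * norm_x / norm_y   # (N, 1, 1)
    t = mu_gt - scale * np.matmul(mu_pred, R)
    aligned = scale * np.matmul(pred, R) + t
    return mpjpe_np(aligned, gt)

rng = np.random.default_rng(1)
gt = rng.normal(size=(1, 21, 3))
theta = 0.7
Rz = np.array([[np.cos(theta), -np.sin(theta), 0.0],
               [np.sin(theta), np.cos(theta), 0.0],
               [0.0, 0.0, 1.0]])
pred = 0.5 * gt @ Rz + np.array([1.0, 2.0, 3.0])   # rotated, scaled, shifted copy of gt
print(mpjpe_np(pred, gt) > 0.1, p_mpjpe_np(pred, gt) < 1e-6)  # True True
```

In the example, the prediction has the correct hand shape but the wrong global pose, so MPJPE is large while PA-MPJPE is essentially zero.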

Finish the *TODOs* below to evaluate the model performance by reporting MPJPE and PA-MPJPE value. For reference, a good model should have MPJPE ~ 25mm, PA-MPJPE ~ 10mm.

In [None]:
def evaluate(testset, model, device):
    epoch_loss_3d_pos = AverageMeter()
    epoch_loss_3d_pos_procrustes = AverageMeter()

    with torch.no_grad():
        for kpts_2d, gt_kpts_3d, weight, _ in tqdm(testset, total=len(testset)):
            # Pose 3D prediction (add a batch dimension of 1)
            kpts_2d = kpts_2d.unsqueeze(0).to(device)
            weight = weight.unsqueeze(0).to(device)
            pred_kpts_3d = model(kpts_2d, weight)

            # Un-normalize the full (J, 3) poses first, as in the visualization cells,
            # so that any per-joint statistics stay aligned with their joints
            pred_3d = testset.inv_normalize_3d(pred_kpts_3d[0].cpu().numpy())
            gt_3d = testset.inv_normalize_3d(gt_kpts_3d.numpy())

            # Keep valid kpts only and restore the batch dimension: (1, n_valid, 3)
            valid = weight[0].cpu().numpy().astype(bool)
            valid_pred_kpts_3d = pred_3d[valid][None]
            valid_gt_kpts_3d = gt_3d[valid][None]

            # m to mm (applied after un-normalization)
            valid_pred_kpts_3d = valid_pred_kpts_3d * 1000.0
            valid_gt_kpts_3d = valid_gt_kpts_3d * 1000.0

            # Compute MPJPE and PA-MPJPE
            epoch_loss_3d_pos.update(mpjpe(valid_pred_kpts_3d, valid_gt_kpts_3d).item())
            epoch_loss_3d_pos_procrustes.update(p_mpjpe(valid_pred_kpts_3d, valid_gt_kpts_3d).item())

    return epoch_loss_3d_pos.avg, epoch_loss_3d_pos_procrustes.avg

In [None]:
# TODO: Initialize model, device and load in pretrained weight. Remember to set model in eval() mode
model = ...

# Evaluate model performance
mpjpe_, pa_mpjpe_ = evaluate(test_dataset, model, device)
print(f"Model performance on test set: MPJPE: {mpjpe_:.2f} (mm) PA-MPJPE: {pa_mpjpe_:.2f} (mm)")

We can also visualize the ground truth and predicted 3D hand poses for a qualitative evaluation. Finish the *TODOs* below to generate the comparison plot. **Select three random data samples and attach the comparison plots (GT on left, pred on right) in your final submission.**

In [None]:
# TODO: Select random idx
vis_idx = ...
kpts_2d, gt_kpts_3d, weight, _ = test_dataset[vis_idx]

# Visualize ground truth 3D hand kpts
gt_3d = test_dataset.inv_normalize_3d(gt_kpts_3d.numpy())
gt_3d[~weight] = None
vis_data_3d(gt_3d, title=f"GT - idx={vis_idx}")

# TODO: Visualize predicted 3D hand kpts
pred_kpts_3d = ...
vis_data_3d(pred_kpts_3d, title=f"Pred - idx={vis_idx}")

### 5 - Submission

Please submit materials specified below to Gradescope.

- Finished Notebook `CIS_5810_Project_8-1.ipynb`
- Finished `model.py`
- Single PDF file, including:
    - Training and validation loss curve plot from wandb
    - Three comparison plots
    - Model performance on the test set: MPJPE and PA-MPJPE.