## **CIS 5810 Project 8-2 - 3D Hand Pose Estimation**

**Introduction**

In Part 1, you saw how a Transformer can be used to build a simple yet effective hand pose estimation network. In Part 2, we extend the 3D hand pose estimation task to take an image as input, using a new Transformer-based network: [POTTER](https://github.com/zczcwh/POTTER/tree/main).

**3D hand pose estimation from image**

Recall that in Part 1 we performed 2D-to-3D hand pose lifting, taking ground-truth 2D hand joint coordinates as input. In this part, we take the image $I (H,W,3)$ as input (similar to Project 2) and predict the 3D hand pose $\hat{P}(21,3)$ directly.


**What to do**

All parts that require your implementation are marked as *TODOs*, including the model architecture, the training pipeline, and final inference and evaluation. The file structure and helper functions are listed below.

- `imgs/`: directory where sample images for display are stored.
- `dataset/`
    - `dataset.py`: main Dataset to load and preprocess Ego-Exo4D data for model training.
    - `dataset_vis.py`: helper functions to visualize 3D hand keypoints.
- `utils/`
    - `functions.py`: a list of utility functions to help model training and debugging.
    - `loss.py`: implementation of loss function and metrics to evaluate model performance.
- `model/`
    - `potter.py`: Implementation of POTTER architecture and parts.
    - `model.py`: Implementation of `PoolAttnHR_Pose_3D`, which is the model to do 3D hand pose estimation.
- `CIS_5810_Project_8-2.ipynb`: integrates all of the parts above, from dataset loading to model training and testing.


### 0 - Set-up

Feel free to reuse the environment you set up in Project 3 Part 1. Otherwise, run `pip install -r requirement.txt` to install packages needed.

In [None]:
import os
import time
import torch
import numpy as np
import torchvision.transforms as transforms
from dataset.dataset import ego4dDataset
from model.model import load_pretrained_weights, PoolAttnHR_Pose_3D
from tqdm import tqdm
from utils.functions import (
    AverageMeter,
    update_config,
)
from dataset.dataset_vis import vis_data_3d
from utils.loss import Pose3DLoss, mpjpe, p_mpjpe
import wandb

wandb.login()

### 1 - Load Ego4D dataset

Download the annotation JSON files from [here](https://drive.google.com/drive/folders/1a_rhSuq5LsJQyiUVubQFYh8XGRwvMJnP?usp=sharing) and put them under `<anno_dir>`. The images are the same as in Project 7, so feel free to reuse them by setting `<img_dir>` to the directory where you stored the Project 7 images.

Also, download the pretrained POTTER classification weights from [here](https://drive.google.com/file/d/14d8ky1d_oKXrZEqb3sM_atfpQM5Q6Ob0/view) for transfer learning: we first load the pretrained classification weights and then train on hand pose estimation. Set the `potter_cls_weight` variable below to the path where the weights are stored on your local machine.
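As a rough illustration of the transfer-learning idea, the sketch below copies only those pretrained tensors whose names and shapes match the target network, so a classification head is skipped while the shared backbone is reused. The helper `load_matching_weights` and the two toy networks are hypothetical and are not part of the project code; the provided `load_pretrained_weights` may behave differently.

```python
import torch
import torch.nn as nn

def load_matching_weights(model, state_dict):
    """Hypothetical helper: load only entries whose key and shape match."""
    own = model.state_dict()
    matched = {k: v for k, v in state_dict.items()
               if k in own and own[k].shape == v.shape}
    # strict=False tolerates the keys we deliberately left out
    model.load_state_dict(matched, strict=False)
    return sorted(matched)

# Toy networks sharing a first layer but ending in different heads
cls_net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 10))
pose_net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 63))

loaded_keys = load_matching_weights(pose_net, cls_net.state_dict())
print(loaded_keys)  # only the shared first layer is transferred
```

The pose head keeps its random initialization and is learned during training on the hand pose data.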

In [None]:
# TODO: Modify config as needed, e.g. annotation and image directory, training batch size etc.
cfg = {
        "anno_dir": ...,
        "img_dir": ...,
        "model_cfg": "configs/potter_pose_3d_ego4d.yaml",
        "potter_cls_weight": ...,
        "lr": 1e-4,
        "train_bs": 16,
        "val_bs": 16,
        "epochs": 15,
    }

# Define the transform for image preprocessing: here, just tensor conversion and normalization
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# TODO: Initialize the train, val and test Dataset
# Hint: take a look at the implementation of ego4dDataset to see how to initialize dataset
train_dataset = ...
val_dataset = ...
test_dataset = ...

# Check the dataset length
print("Train: ", len(train_dataset))
print("Val: ", len(val_dataset))
print("Test: ", len(test_dataset))

**Visualizing the dataset**

Take a look at the implementation of `dataset/dataset.py`, which preprocesses both images and annotations and returns four items when indexed:

- `input`: Image cropped from the original Aria ego image such that the hand is at the image center.
- `pose_3d_gt`: 3D hand keypoints in the camera coordinate system, offset by the wrist position and normalized (similar to Part 1).
- `vis_flag`: A boolean array indicating whether each hand joint has a valid 3D keypoint.
- `metadata`: A dictionary consisting of current frame info.
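To make the `pose_3d_gt` preprocessing concrete, here is a minimal NumPy sketch of wrist offsetting followed by per-joint standardization, together with its inverse. The helper names, the wrist index of 0, and the statistics are assumptions for illustration only; the actual logic lives in `dataset/dataset.py` and may differ.

```python
import numpy as np

WRIST_IDX = 0  # assumption: the wrist is joint 0

def normalize_3d_sketch(kpts, joint_mean, joint_std):
    """kpts: (21, 3) camera-space keypoints -> wrist-relative, standardized."""
    rel = kpts - kpts[WRIST_IDX]  # the wrist becomes the origin
    return (rel - joint_mean) / joint_std

def inv_normalize_3d_sketch(norm_kpts, joint_mean, joint_std):
    """Undo the standardization; the result is still wrist-relative."""
    return norm_kpts * joint_std + joint_mean

rng = np.random.default_rng(0)
kpts = rng.normal(size=(21, 3)) * 50.0   # made-up keypoints
joint_mean = rng.normal(size=(21, 3))    # made-up statistics
joint_std = np.full((21, 3), 20.0)

norm = normalize_3d_sketch(kpts, joint_mean, joint_std)
restored = inv_normalize_3d_sketch(norm, joint_mean, joint_std)
```

Note that the inverse recovers wrist-relative coordinates, not the original camera-space positions, because the wrist offset is discarded during normalization.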


In [None]:
## TODO: Modify as needed to take a look at the dataset
check_dataset = train_dataset
idx = 20000

## Get one dataset sample for visualization
input, pose_3d_gt, vis_flag, metadata = check_dataset[idx]

## Visualization of input image and 3d kpts
gt_3d = check_dataset.inv_normalize_3d(pose_3d_gt.numpy())
# Assign None to invalid kpts (so it won't be displayed)
gt_3d[~vis_flag] = None
vis_data_3d(gt_3d, title="GT 3D")

### 2 - Define model

**General Overview**

POTTER (POoling aTtention TransformER) is a Transformer-based architecture whose core design is the Pooling Attention Transformer (PAT) block, which replaces the original attention block (Part 1) with a pooling attention block. See Figure 1 for a comparison between different Transformer blocks.

<div style="display: center; justify-content: space-between; text-align: center;">
    <figure>
        <img src="imgs/transformer_blocks.png" alt="Different Transformer blocks" width ="450" height="350">
        <figcaption>Figure 1: Different Transformer blocks</figcaption>
    </figure>
</div>

The overall model architecture of POTTER is shown in Figure 2. It consists of two streams: the Basic stream and the HR stream. The Basic stream has a hierarchical structure with four stages, where the resolution of the feature map is gradually reduced to capture more global information. The global features from the Basic stream are fused with the local features by patch split blocks in the HR stream. Finally, the output of stage 4 in the HR stream is fed into the head block to predict the final pose.

<div style="display: center; justify-content: space-between; text-align: center;">
    <figure>
        <img src="imgs/POTTER_arch.png" alt="Overall architecture of POTTER" width ="800" height="310">
        <figcaption>Figure 2: Overall architecture of POTTER</figcaption>
    </figure>
</div>


**Pooling Attention in PAT Block**

See Figure 3 for the details of each component in the PAT block, which has a structure very similar to the original Transformer encoder block you implemented in Part 1.

<div style="display: center; justify-content: space-between; text-align: center;">
    <figure>
        <img src="imgs/PAT.png" alt="Pooling Attention Transformer Block" width ="580" height="330">
        <figcaption>Figure 3: Pooling Attention Transformer Block (PAT)</figcaption>
    </figure>
</div>

Given the input feature $X_{in}\in \R^{D \times h \times w}$, where $D$ is the embedding dimension and $h \times w$ is the spatial size of the feature map, the input is first normalized by Layer Norm to give $X_0$, which is then passed to the `PoolAttn` block to perform patch-wise and embed-wise pooling attention. The output of the `PoolAttn` block is added element-wise to $X_{in}$, and the sum is passed to the FFN block to generate the final PAT block output $X_{out}$.

- **Patch-wise Pooling Attention**

    In patch-wise pooling attention, each patch's spatial location is preserved while the correlation between all patches is captured. The input $X_0$ is first squeezed along the $h$ axis and the $w$ axis by two [adaptive average pooling](https://pytorch.org/docs/stable/generated/torch.nn.AdaptiveAvgPool2d.html) layers to get $X_{Ph}$ and $X_{Pw}$; the matrix multiplication of $X_{Ph}$ and $X_{Pw}$ then gives $X_1$.

    $$X_{Ph} = Pool_1(X_0), \quad X_{Ph}\in \R^{D\times h \times 4}$$
    $$X_{Pw}  =Pool_2(X_0), \quad X_{Pw}\in \R^{D \times 4 \times w}$$
    $$X_1 = MatMul(X_{Ph}, X_{Pw}), \quad X_1 \in \R^{D \times h \times w}$$

- **Embed-wise Pooling Attention**

    In embed-wise pooling attention, a similar cross attention is performed along the embedding dimension. $X_0 \in \R^{D \times h \times w}$ is first reshaped to $X_{0}^{'} \in \R^{N \times D_h \times D_w}$, where $N = h\times w$ and $D = D_h \times D_w$. Then $X_{0}^{'}$ is squeezed along the $D_h$ axis and the $D_w$ axis by two adaptive average pooling layers to give $X_{PDh}$ and $X_{PDw}$. The matrix multiplication of $X_{PDh}$ and $X_{PDw}$ then gives $X_2$, which is reshaped to $X_3\in \R^{D\times h\times w}$, the same shape as $X_{0}$.

    $$X_{PDh} = Pool_3(X_{0}^{'}), \quad X_{PDh} \in \R^{N\times D_h \times 4}$$
    $$X_{PDw} = Pool_4(X_{0}^{'}), \quad X_{PDw} \in \R^{N\times 4 \times D_w}$$
    $$X_2 = MatMul(X_{PDh}, X_{PDw}), \quad X_2 \in \R^{N\times D_h \times D_w}$$
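The shape bookkeeping in the two pooling branches can be checked with a short PyTorch sketch. It only traces tensor shapes for the equations above and adds a leading batch axis $B$; the variable names and the factorization $D_h = 8$ are assumptions for illustration, and it is not a drop-in solution for `model/potter.py`.

```python
import torch
import torch.nn as nn

B, D, h, w = 2, 64, 14, 14
D_h = 8            # assumption: any factorization D = D_h * D_w works here
D_w = D // D_h
x0 = torch.rand(B, D, h, w)

# None in the output size keeps that axis unchanged
pool_keep_first = nn.AdaptiveAvgPool2d((None, 4))   # (.., a, b) -> (.., a, 4)
pool_keep_second = nn.AdaptiveAvgPool2d((4, None))  # (.., a, b) -> (.., 4, b)

# Patch-wise branch
x_ph = pool_keep_first(x0)       # (B, D, h, 4)
x_pw = pool_keep_second(x0)      # (B, D, 4, w)
x1 = x_ph @ x_pw                 # (B, D, h, w)

# Embed-wise branch: move the embedding axis last and split it into D_h x D_w
x0p = x0.flatten(2).transpose(1, 2).reshape(B, h * w, D_h, D_w)  # (B, N, D_h, D_w)
x_pdh = pool_keep_first(x0p)     # (B, N, D_h, 4)
x_pdw = pool_keep_second(x0p)    # (B, N, 4, D_w)
x2 = x_pdh @ x_pdw               # (B, N, D_h, D_w)
x3 = x2.reshape(B, h * w, D).transpose(1, 2).reshape(B, D, h, w)  # (B, D, h, w)

print(x1.shape, x3.shape)
```

Both branches return a tensor with the same shape as $X_0$, which is what allows them to be summed afterwards.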

The outputs of patch-wise pooling and embed-wise pooling are then each projected by a separate convolutional layer, added together, and passed through a layer norm and a final convolutional projection to generate the `PoolAttn` block output.

$$X_{out} = Proj_3(LN(Proj_0(X_1) + Proj_1(X_3))), \quad X_{out} \in \R^{D\times h \times w}$$
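The fusion step can be sketched in the same way. The 3×3 kernel size and the use of `nn.GroupNorm(1, D)` as a stand-in for layer norm on a `(B, D, h, w)` tensor are assumptions made for this illustration, and the layer names are hypothetical; check `model/potter.py` for the layers the project actually defines.

```python
import torch
import torch.nn as nn

B, D, h, w = 2, 64, 14, 14
x1 = torch.rand(B, D, h, w)   # stand-in for the patch-wise branch output
x3 = torch.rand(B, D, h, w)   # stand-in for the embed-wise branch output

# Assumed layer choices: shape-preserving 3x3 convolutions
proj_patch = nn.Conv2d(D, D, kernel_size=3, padding=1)
proj_embed = nn.Conv2d(D, D, kernel_size=3, padding=1)
norm = nn.GroupNorm(1, D)     # one group normalizes over (D, h, w) per sample
proj_final = nn.Conv2d(D, D, kernel_size=3, padding=1)

# Project each branch, sum, normalize, then apply the final projection
x_fused = proj_final(norm(proj_patch(x1) + proj_embed(x3)))
print(x_fused.shape)
```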



**TODO**

Your task is to finish the implementation of the pooling attention module `PoolAttn` in the PAT block. The relevant parts are all marked as *TODOs* in `model/potter.py`. Once you are done, run the cell below to perform a simple check on the shape of the model's output tensor.

In [None]:
# Simple Test
model_cfg = update_config(cfg["model_cfg"])
model = PoolAttnHR_Pose_3D(**model_cfg.MODEL)

input = torch.rand(1,3,224,224)
output = model(input)
assert output.shape == (1,21,3), "Implementation is incorrect. Please check your PAT block"

### 3 - Train model

Define the model training and validation pipeline, similar to Project 2. Fill in the *TODOs* below with your implementation.

In [None]:
def train(train_loader, model, criterion, optimizer, device):
    total_loss = AverageMeter()

    # TODO: set model to training mode


    train_loader = tqdm(train_loader, dynamic_ncols=True)
    # Guard against a zero interval (and a modulo-by-zero below) on short loaders
    print_interval = max(1, len(train_loader) // 6)
    # Iterate over all training samples
    for i, (input, pose_3d_gt, vis_flag, _) in enumerate(train_loader):
        # TODO:
        # 1. Put all relevant data onto the same device
        # 2. Model forward (given cropped hand image, predict a set of 3d kpts)


        # TODO: Compute loss
        loss = ...
        total_loss.update(loss.item())

        # TODO:
        # 1. Clear the old parameter gradients
        # 2. Compute the derivative of loss w.r.t the model parameters
        # 3. Update the model parameters with optimizer


        # Log loss to wandb
        if (i+1) % print_interval == 0:
            wandb.log({"Loss/train": total_loss.avg})

    # Return average training loss
    return total_loss.avg


def validate(val_loader, model, criterion, device):
    total_loss = AverageMeter()

    # TODO: set model to evaluate mode


    with torch.no_grad():
        val_loader = tqdm(val_loader, dynamic_ncols=True)
        # Iterate over all validation samples
        for i, (input, pose_3d_gt, vis_flag, _) in enumerate(val_loader):
            # TODO:
            # 1. Put all relevant data onto the same device
            # 2. Model forward (given cropped hand image, predict a set of 3d kpts)

            # TODO: Compute loss
            loss = ...
            total_loss.update(loss.item())

        # Log loss to wandb
        wandb.log({"Loss/val": total_loss.avg})

    # Return average validation loss
    return total_loss.avg
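The three optimizer steps named in the training *TODOs* follow the standard PyTorch pattern. Below is a minimal, self-contained sketch on a toy linear regression problem; the data, the `toy_*` names, and the `run_epoch` helper are made up for illustration, and device handling is omitted. It is not the POTTER pipeline.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
x = torch.randn(64, 3)
y = x @ torch.tensor([[1.0], [-2.0], [0.5]])   # made-up linear target
loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

toy_model = nn.Linear(3, 1)
toy_criterion = nn.MSELoss()
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-2)

def run_epoch(training):
    toy_model.train(training)                  # train() vs. eval() mode
    losses = []
    with torch.set_grad_enabled(training):     # no gradients when validating
        for xb, yb in loader:
            loss = toy_criterion(toy_model(xb), yb)
            if training:
                toy_optimizer.zero_grad()      # 1. clear old gradients
                loss.backward()                # 2. backpropagate
                toy_optimizer.step()           # 3. update parameters
            losses.append(loss.item())
    return sum(losses) / len(losses)

initial_loss = run_epoch(training=False)
for _ in range(50):
    run_epoch(training=True)
final_loss = run_epoch(training=False)
print(f"loss before: {initial_loss:.4f}, after: {final_loss:.4f}")
```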

With the training and validation pipeline set up, we can then instantiate our model and define other essential parts for model training including:

- **optimizer:** We will use the Adam optimizer with a default learning rate of 1e-4.

- **loss function (criterion):** Use Pose3DLoss() as the loss function, which computes the MSE between predicted and ground truth 3D hand keypoints with filtering. The implementation has been given, please take a look at it before use.

- **dataloaders**: train and val dataloaders that return data in batches. The default batch size is 16; adjust it as needed.
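To give an idea of what "MSE with filtering" can look like, the sketch below masks out joints without a valid annotation before averaging the squared error. The function `masked_mse_sketch` is hypothetical and written only for illustration; read `utils/loss.py` for the behavior of the provided `Pose3DLoss`, which may differ.

```python
import torch

def masked_mse_sketch(pred, gt, vis_flag):
    """pred, gt: (B, 21, 3); vis_flag: (B, 21) bool, True for valid joints."""
    mask = vis_flag.unsqueeze(-1).float()        # (B, 21, 1), broadcast over xyz
    sq_err = (pred - gt) ** 2 * mask             # zero out invalid joints
    n_valid = mask.sum() * pred.shape[-1]        # number of valid coordinates
    return sq_err.sum() / n_valid.clamp(min=1.0) # avoid division by zero

pred = torch.zeros(1, 21, 3)
gt = torch.ones(1, 21, 3)
gt[0, 5:] = 100.0                                # garbage values on invalid joints
vis_flag = torch.zeros(1, 21, dtype=torch.bool)
vis_flag[0, :5] = True                           # only the first 5 joints are valid

toy_loss = masked_mse_sketch(pred, gt, vis_flag)
print(toy_loss.item())
```

Only the five valid joints contribute, so the garbage values on the remaining joints do not affect the loss.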

In [None]:
# Instantiate the model and define device for training to use GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_cfg = update_config(cfg["model_cfg"])
# Load in pretrained cls weight
model = PoolAttnHR_Pose_3D(**model_cfg.MODEL).to(device)
cls_weight = torch.load(cfg["potter_cls_weight"], map_location=device)  # also works on CPU-only machines
load_pretrained_weights(model.poolattnformer_pose.poolattn_cls, cls_weight)

# TODO: Define loss function (criterion) and optimizer
criterion = ...
optimizer = ...

# TODO: Define train and val dataloader
train_loader = ...
val_loader = ...

# TODO: Define current run name
current_run_name = time.strftime("%Y-%m-%d-%H-%M") # Modify as needed, e.g. "test_run_123"
wandb.init(project="CIS5810_project_8_2", name=current_run_name)

Finally, we can start training the model. We log the training and validation loss to wandb to monitor training progress. Run the cell below to start training.

In [None]:
# Define output directory; modify as needed (where model ckpt will be saved)
output_root = "output"
output_dir = os.path.join(output_root, current_run_name)
print("="*10 + f" Training started. Output will be saved at {output_dir} " + "="*10)
os.makedirs(output_dir, exist_ok=True)

# Default number of training epochs and best val loss
epochs = cfg["epochs"]
best_val_loss = np.inf

for epoch in range(epochs):
    print("="*10, f"Epoch [{epoch + 1}/{epochs}]", "="*10)
    # train for one epoch
    _ = train(train_loader, model, criterion, optimizer, device)

    # evaluate on validation set
    val_loss = validate(val_loader, model, criterion, device)

    # Save best model weight
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # Save model weight
        torch.save(model.state_dict(), os.path.join(output_dir, "best_model_weight.pth.tar"))
        print(f"Saving model weight with best val_loss={val_loss:.5f}")
    print()
print("="*10 + f" Training finished. Got best model with val_loss={best_val_loss:.5f} " + "="*10)
wandb.finish()

### 4 - Evaluate

To evaluate model performance on 3D hand pose estimation, we use the same metrics as in Part 1: MPJPE (mean per-joint position error) and PA-MPJPE (MPJPE computed after Procrustes alignment).

Finish the *TODOs* below to evaluate the model performance by reporting MPJPE and PA-MPJPE value. For reference, a good model should have MPJPE ~ 35mm, PA-MPJPE ~ 15mm.
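As a reminder of what the two metrics measure, here is a NumPy sketch for a single pose of shape `(J, 3)`. The `*_sketch` functions are illustrative re-implementations of the standard definitions, not the `mpjpe` and `p_mpjpe` functions from `utils/loss.py`, whose signatures and batching may differ.

```python
import numpy as np

def mpjpe_sketch(pred, gt):
    """Mean Euclidean distance between corresponding joints."""
    return np.linalg.norm(pred - gt, axis=-1).mean()

def pa_mpjpe_sketch(pred, gt):
    """MPJPE after aligning pred to gt with a similarity transform."""
    mu_gt, mu_pred = gt.mean(axis=0), pred.mean(axis=0)
    gt0, pred0 = gt - mu_gt, pred - mu_pred
    norm_gt, norm_pred = np.linalg.norm(gt0), np.linalg.norm(pred0)
    H = (gt0 / norm_gt).T @ (pred0 / norm_pred)   # 3x3 cross-covariance
    U, s, Vt = np.linalg.svd(H)
    V = Vt.T.copy()
    R = V @ U.T
    if np.linalg.det(R) < 0:                      # forbid reflections
        V[:, -1] *= -1
        s[-1] *= -1
        R = V @ U.T
    scale = s.sum() * norm_gt / norm_pred
    aligned = scale * (pred0 @ R) + mu_gt
    return mpjpe_sketch(aligned, gt)

# Made-up example: pred is a scaled, rotated and shifted copy of gt
rng = np.random.default_rng(0)
gt = rng.random((21, 3)) * 100.0
c, s_ = np.cos(np.pi / 6), np.sin(np.pi / 6)
Rz = np.array([[c, -s_, 0.0], [s_, c, 0.0], [0.0, 0.0, 1.0]])
pred = 1.2 * gt @ Rz + np.array([10.0, -5.0, 3.0])

raw_err = mpjpe_sketch(pred, gt)
aligned_err = pa_mpjpe_sketch(pred, gt)
```

In this example the prediction differs from the ground truth only by a similarity transform, so the aligned error is numerically zero while the raw MPJPE is not. This is why PA-MPJPE is typically lower than MPJPE: it ignores global rotation, scale, and translation errors.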

In [None]:
def evaluate(test_loader, model, device):
    epoch_loss_3d_pos = AverageMeter()
    epoch_loss_3d_pos_procrustes = AverageMeter()

    with torch.no_grad():
        test_loader = tqdm(test_loader, dynamic_ncols=True)
        for i, (input, pose_3d_gt, vis_flag, _) in enumerate(test_loader):
            # Pose 3D prediction
            input = input.to(device)
            pose_3d_pred = model(input)

            # Unnormalize predicted and GT pose 3D kpts
            pred_3d_pts = pose_3d_pred.cpu().detach().numpy()
            pred_3d_pts = pred_3d_pts * test_dataset.joint_std + test_dataset.joint_mean
            gt_3d_kpts = pose_3d_gt.cpu().detach().numpy()
            gt_3d_kpts = gt_3d_kpts * test_dataset.joint_std + test_dataset.joint_mean

            # Filter out invalid joints
            valid_pred_3d_kpts = torch.from_numpy(pred_3d_pts)
            valid_pred_3d_kpts = valid_pred_3d_kpts[vis_flag].view(1, -1, 3)
            valid_pose_3d_gt = torch.from_numpy(gt_3d_kpts)
            valid_pose_3d_gt = valid_pose_3d_gt[vis_flag].view(1, -1, 3)
            # Compute MPJPE and PA-MPJPE
            epoch_loss_3d_pos.update(mpjpe(valid_pred_3d_kpts, valid_pose_3d_gt).item(), 1)
            epoch_loss_3d_pos_procrustes.update(p_mpjpe(valid_pred_3d_kpts, valid_pose_3d_gt), 1)

    return epoch_loss_3d_pos.avg, epoch_loss_3d_pos_procrustes.avg

In [None]:
# TODO: Initialize model, device and load in pretrained weight. Remember to set model in eval() mode
model = ...
device = ...
load_pretrained_weights(model, torch.load("REPLACE_TO_BE_MODEL_CKPT_PATH", map_location=device))

# Evaluate model performance on test set
test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
    )
mpjpe_, pa_mpjpe_ = evaluate(test_loader, model, device)
print(f"Model performance on test set: MPJPE: {mpjpe_:.2f} (mm) PA-MPJPE: {pa_mpjpe_:.2f} (mm)")

We can also visualize the ground truth and predicted 3D hand poses side by side for a qualitative evaluation. Finish the *TODOs* below to generate the comparison plot. Select three random data samples and attach the comparison plots (GT on left, pred on right) in your final submission.

In [None]:
# TODO: Select random idx
vis_idx = ...
input, pose_3d_gt, vis_flag, _ = test_dataset[vis_idx]

# Visualize ground truth 3D hand kpts
gt_3d = test_dataset.inv_normalize_3d(pose_3d_gt.numpy())
gt_3d[~vis_flag] = None
vis_data_3d(gt_3d, title=f"GT - idx={vis_idx}")

# TODO: Visualize predicted 3D hand kpts
pred_kpts_3d = ...
pred_kpts_3d[~vis_flag] = None
vis_data_3d(pred_kpts_3d, title=f"Pred - idx={vis_idx}")

### 5 - Submission

Please submit materials specified below to Gradescope.

- Finished Notebook `CIS_5810_Project_8_2.ipynb`
- Finished `model/potter.py`
- Single PDF file, including:
    - Training and validation loss curve plot from wandb
    - MPJPE (mm) and PA-MPJPE (mm)
    - Three visualization plots, with GT on left and Pred on right