# Pose Estimation Fine-Tuning with SuperGradients

In this tutorial notebook, we demonstrate how to fine-tune a pose estimation model using SuperGradients. We recommend going over the [Pose Estimation tutorial](https://docs.deci.ai/super-gradients/documentation/source/PoseEstimation.html) first to get familiar with the terminology and concepts used here.

From this tutorial you will learn:
* How to implement a custom dataset class for a pose estimation task
* How to instantiate a pre-trained pose estimation model and change the number of joints it predicts to fit your dataset
* How to fine-tune a pose estimation model using SuperGradients




In [2]:
!pip install super_gradients==3.2



## Dataset

The first thing we need is a dataset. For this tutorial, we will use the [Animals Pose](https://sites.google.com/view/animal-pose/) dataset. It is a relatively small dataset of 6K+ animal instances, each annotated with 20 keypoints.

![](https://lh6.googleusercontent.com/hehW9yRzdcniQ2i1Ts65ceGERa70cBbaLlRixxu7HlUMHabt8HdgcxutG4vmVOas-U1h6g=w16383)

You need to download the dataset files from these locations and update the path below accordingly:

* https://drive.google.com/drive/folders/1xxm6ZjfsDSmv6C9JvbgiGrmHktrUjV5x?usp=sharing
* https://drive.google.com/drive/folders/1-yOSGWts2ZDYFx29u7vPcX4CdGJkPx1w?usp=sharing

In [3]:
ANIMALS_POSE_DATA_DIR = "e:/animalpose"

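The Animals Pose dataset uses a COCO-style annotation format.
Unfortunately, it is not 100% compatible with the COCO parser from pycocotools: for example, `images` is a dictionary mapping image ids to file names rather than a list of image records, and keypoints are stored as `[x, y, visibility]` triplets rather than a flat list.
So we have to write our own parser. That is not a big issue, since the format is very simple: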

```json
{
    "images": {
        "1": "2007_000063.jpg",
        "2": "2007_000175.jpg",
        "6": "2007_000491.jpg",
        ...
        "4606": "sh97.jpg",
        "4607": "sh98.jpeg",
        "4608": "sh99.jpeg"
    },
    "annotations": [
        {
            "image_id": 1,
            "bbox": [
                123,
                115,
                379,
                275
            ],
            "keypoints": [
                [193, 216, 1],
                [160, 217, 1],
                [174, 261, 1],
                [204, 186, 1],
                [152, 182, 1],
                [0, 0, 0],
                [0, 0, 0],
                [273, 168, 1],
                [0, 0, 0],
                [0, 0, 0],
                [0, 0, 0],
                [266, 225, 1],
                [0, 0, 0],
                [0, 0, 0],
                [0, 0, 0],
                [0, 0, 0],
                [0, 0, 0],
                [0, 0, 0],
                [190, 145, 1],
                [351, 238, 1]
            ],
            "num_keypoints": 20,
            "category_id": 1
        },
        ...
    ],
    "categories": [
        {
            "supercategory": "animal",
            "id": 1,
            "name": "dog",
            "keypoints": [
                "left_eye",
                "right_eye",
                "nose",
                "left_ear",
                "right_ear",
                "left_front_elbow",
                "right_front_elbow",
                "left_back_elbow",
                "right_back_elbow",
                "left_front_knee",
                "right_front_knee",
                "left_back_knee",
                "right_back_knee",
                "left_front_paw",
                "right_front_paw",
                "left_back_paw",
                "right_back_paw",
                "throat",
                "withers",
                "tailbase"
            ],
            "skeleton": [
                [0, 1],
                [0, 2],
                [1, 2],
                [0, 3],
                [1, 4],
                [2, 17],
                [18, 19],
                [5, 9],
                [6, 10],
                [7, 11],
                [8, 12],
                [9, 13],
                [10, 14],
                [11, 15],
                [12, 16]
            ]
        },
        ...
    ]
}
```
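A minimal sketch of parsing this format with only the standard library and NumPy, assuming the structure shown above: an `images` mapping from string ids to file names, `annotations` that reference images by integer `image_id`, and one `[x, y, visibility]` triplet per joint. The tiny inline sample (2 joints instead of 20) and the helper name `group_keypoints_by_image` are hypothetical.

```python
import json

import numpy as np

# Hypothetical miniature sample in the same layout as keypoints.json (2 joints instead of 20)
SAMPLE = """
{
  "images": {"1": "a.jpg", "2": "b.jpg"},
  "annotations": [
    {"image_id": 1, "bbox": [10, 20, 50, 80], "keypoints": [[15, 25, 1], [0, 0, 0]], "category_id": 1},
    {"image_id": 1, "bbox": [60, 10, 90, 40], "keypoints": [[70, 15, 1], [80, 30, 1]], "category_id": 1}
  ],
  "categories": [{"id": 1, "name": "dog", "keypoints": ["left_eye", "right_eye"]}]
}
"""


def group_keypoints_by_image(data: dict) -> dict:
    """Return {image_id: (keypoints [N, J, 3], bboxes [N, 4])} for every image, including empty ones."""
    num_joints = len(data["categories"][0]["keypoints"])
    grouped = {image_id: ([], []) for image_id in data["images"]}
    for ann in data["annotations"]:
        # Image ids are strings in "images" but integers in "annotations"
        keypoints, bboxes = grouped[str(ann["image_id"])]
        keypoints.append(ann["keypoints"])
        bboxes.append(ann["bbox"])
    return {
        image_id: (
            np.array(kpts, dtype=np.float32).reshape(-1, num_joints, 3),
            np.array(boxes, dtype=np.float32).reshape(-1, 4),
        )
        for image_id, (kpts, boxes) in grouped.items()
    }


parsed = group_keypoints_by_image(json.loads(SAMPLE))
print(parsed["1"][0].shape, parsed["2"][0].shape)  # (2, 2, 3) (0, 2, 3)
```

Images without annotations still get correctly shaped empty arrays, which keeps downstream code free of special cases.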

To train a pose estimation model with SuperGradients, we need to implement a custom dataset class that parses this format and returns images and targets for the model.
Fortunately, SuperGradients provides a base class for pose estimation datasets (`BaseKeypointsDataset`) that handles most of the boilerplate code for us.

All we need to implement is the parsing logic that returns annotations in the format SuperGradients expects.
The dataset class is expected to return a tuple of the following objects:

```python
class AnimalPoseKeypointsDataset:
    def __getitem__(self, index):
        ...
        return image, targets, {"gt_joints": gt_joints, "gt_bboxes": gt_bboxes, "gt_iscrowd": gt_iscrowd, "gt_areas": gt_areas}
```
Return values are:
* image - torch tensor of [C,H,W] shape that represents an input image to the model
* targets - model-specific targets used to train the model. SuperGradients generates these targets for us via the target generator; our job is to provide the keypoints as an array of [Num Instances, Num Joints, 3] shape.
* extras - additional pose information used for metric computation. Must be a dictionary with the following keys:
    * gt_joints - Array of keypoints for all poses in the image. Numpy array of [Num Instances, Num Joints, 3] shape
    * gt_bboxes - Array of bounding boxes for each pose in the image. Numpy array of [Num Instances, 4] shape
    * gt_iscrowd - Array of iscrowd flags for each pose in the image. Numpy array of [Num Instances] shape
    * gt_areas - Array of areas for each skeleton in the image. Numpy array of [Num Instances] shape
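The extras dictionary can be assembled from per-image keypoints and boxes with plain NumPy. This is a minimal sketch under a few assumptions: boxes are in XYXY pixel coordinates (as in the annotation sample above), the bounding-box area stands in for the instance area because the dataset has no segmentation masks, and no instance is a crowd. The helper name `make_extras` is hypothetical.

```python
import numpy as np


def make_extras(keypoints: np.ndarray, bboxes_xyxy: np.ndarray) -> dict:
    """Build the extras dict for one image from [N, J, 3] keypoints and [N, 4] XYXY boxes."""
    widths = bboxes_xyxy[:, 2] - bboxes_xyxy[:, 0]
    heights = bboxes_xyxy[:, 3] - bboxes_xyxy[:, 1]
    return {
        "gt_joints": keypoints.astype(np.float32),  # [N, J, 3]
        "gt_bboxes": bboxes_xyxy.astype(np.float32),  # [N, 4]
        "gt_iscrowd": np.zeros(len(keypoints), dtype=bool),  # [N], no crowd instances
        "gt_areas": (widths * heights).astype(np.float32),  # [N], box area as instance area
    }


# First annotation from the JSON sample: 20 joints, only the left eye filled in here
keypoints = np.zeros((1, 20, 3), dtype=np.float32)
keypoints[0, 0] = [193, 216, 1]
extras = make_extras(keypoints, np.array([[123, 115, 379, 275]], dtype=np.float32))
print(extras["gt_areas"])  # [40960.]
```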

In [4]:
import json
import os
from typing import List, Mapping, Any, Tuple

import cv2
import numpy as np
from torch import Tensor

from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.common.factories.target_generator_factory import TargetGeneratorsFactory
from super_gradients.common.factories.transforms_factory import TransformsFactory
from super_gradients.training.datasets.pose_estimation_datasets import BaseKeypointsDataset
from super_gradients.training.transforms.keypoint_transforms import KeypointTransform


class AnimalPoseKeypointsDataset(BaseKeypointsDataset):
    """
    Dataset class for training pose estimation models on the Animal Pose dataset.
    The user should pass a target generator class that is model-specific and generates the targets for the model.
    """

    @resolve_param("transforms", TransformsFactory())
    @resolve_param("target_generator", TargetGeneratorsFactory())
    def __init__(
        self,
        data_dir: str,
        images_dir: str,
        json_file: str,
        include_empty_samples: bool,
        target_generator,
        transforms: List[KeypointTransform],
        min_instance_area: float,
    ):
        """

        :param data_dir: Root directory of the Animal Pose dataset
        :param images_dir: Path suffix to the images directory inside data_dir
        :param json_file: Path suffix to the JSON annotation file inside data_dir
        :param include_empty_samples: If True, images without any annotations will be included in the dataset.
            Otherwise, they will be filtered out.
        :param target_generator: Target generator that will be used to generate the targets for the model.
            See DEKRTargetsGenerator for an example.
        :param transforms: Transforms to be applied to the image & keypoints
        :param min_instance_area: Minimum area of an instance to be included in the dataset
        """
        super().__init__(transforms=transforms, target_generator=target_generator, min_instance_area=min_instance_area)

        with open(os.path.join(data_dir, json_file), "r") as f:
            json_annotations = json.load(f)

        # Joint names are taken from the first category; this assumes all categories share the same keypoint layout
        self.joints = json_annotations["categories"][0]["keypoints"]
        self.num_joints = len(self.joints)

        ids_and_images = [(image_id, os.path.join(data_dir, images_dir, image_path)) for image_id, image_path in json_annotations["images"].items()]
        self.image_ids, self.image_files = zip(*ids_and_images)

        # Group annotations by image id once, instead of scanning the whole annotation list for every image
        annotations_per_image = {}
        for ann in json_annotations["annotations"]:
            annotations_per_image.setdefault(str(ann["image_id"]), []).append(ann)

        self.annotations = []
        for image_id in self.image_ids:
            image_annotations = annotations_per_image.get(str(image_id), [])
            keypoints_per_image = [np.array(ann["keypoints"]).reshape(self.num_joints, 3) for ann in image_annotations]
            bboxes_per_image = [np.array(ann["bbox"]) for ann in image_annotations]

            keypoints_per_image = np.array(keypoints_per_image, dtype=np.float32).reshape(-1, self.num_joints, 3)
            bboxes_per_image = np.array(bboxes_per_image, dtype=np.float32).reshape(-1, 4)  # XYXY format
            self.annotations.append((keypoints_per_image, bboxes_per_image))

        if not include_empty_samples:
            # Drop images without annotated instances, keeping ids, files and annotations aligned
            keep = [i for i, (keypoints, _) in enumerate(self.annotations) if len(keypoints) > 0]
            self.image_ids = [self.image_ids[i] for i in keep]
            self.image_files = [self.image_files[i] for i in keep]
            self.annotations = [self.annotations[i] for i in keep]

    def __len__(self):
        return len(self.image_ids)

    def load_sample(self, index):
        file_path = self.image_files[index]
        keypoints, boxes = self.annotations[index]  # boxes are in XYXY format

        # There are no segmentation masks, so the instance area is taken from the bounding box
        gt_areas = np.array([(xmax - xmin) * (ymax - ymin) for (xmin, ymin, xmax, ymax) in boxes], dtype=np.float32)
        # All instances are treated as non-crowd
        gt_iscrowd = np.zeros(len(keypoints), dtype=bool)

        image = cv2.imread(file_path, cv2.IMREAD_COLOR)
        if image is None:
            raise FileNotFoundError(f"Could not read image {file_path}")
        mask = np.zeros(image.shape[:2], dtype=np.float32)

        return image, mask, keypoints, gt_areas, boxes, gt_iscrowd

    def __getitem__(self, index: int) -> Tuple[Tensor, Any, Mapping[str, Any]]:
        img, mask, gt_joints, gt_areas, gt_bboxes, gt_iscrowd = self.load_sample(index)
        img, mask, gt_joints, gt_areas, gt_bboxes = self.transforms(img, mask, gt_joints, areas=gt_areas, bboxes=gt_bboxes)

        image_shape = img.size(1), img.size(2)
        gt_joints, gt_areas, gt_bboxes, gt_iscrowd = self.filter_joints(image_shape, gt_joints, gt_areas, gt_bboxes, gt_iscrowd)

        targets = self.target_generator(img, gt_joints, mask)
        return img, targets, {"gt_joints": gt_joints, "gt_bboxes": gt_bboxes, "gt_iscrowd": gt_iscrowd, "gt_areas": gt_areas}


    def filter_joints(
        self,
        image_shape,
        joints: np.ndarray,
        areas: np.ndarray,
        bboxes: np.ndarray,
        is_crowd: np.ndarray,
    ):
        """
        Filter instances that are either too small or do not have visible keypoints.

        :param image: Image if [H,W,C] shape. Used to infer image boundaries
        :param joints: Array of shape [Num Instances, Num Joints, 3]
        :param areas: Array of shape [Num Instances] with area of each instance.
                      Instance area comes from segmentation mask from COCO annotation file.
        :param bboxes: Array of shape [Num Instances, 4] for bounding boxes in XYWH format.
                       Bounding boxes comes from segmentation mask from COCO annotation file.
        :param: is_crowd: Array of shape [Num Instances] indicating whether an instance is a crowd target.
        :return: [New Num Instances, Num Joints, 3], New Num Instances <= Num Instances
        """

        # Update visibility of joints for those that are outside the image
        outside_image_mask = (joints[:, :, 0] < 0) | (joints[:, :, 1] < 0) | (joints[:, :, 0] >= image_shape[1]) | (joints[:, :, 1] >= image_shape[0])
        joints[outside_image_mask, 2] = 0

        # Filter instances with all invisible keypoints
        instances_with_visible_joints = np.count_nonzero(joints[:, :, 2], axis=-1) > 0
        instances_with_good_area = areas > self.min_instance_area

        keep_mask = instances_with_visible_joints & instances_with_good_area

        joints = joints[keep_mask]
        areas = areas[keep_mask]
        bboxes = bboxes[keep_mask]
        is_crowd = is_crowd[keep_mask]

        return joints, areas, bboxes, is_crowd
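To sanity-check the filtering rules without loading any images, the same logic can be run on toy arrays. This is a standalone re-implementation for illustration only (the function name `filter_instances` and the toy values are hypothetical). It mirrors `filter_joints` above: joints outside the image become invisible, then instances with no visible joints or with an area not greater than the minimum are dropped. Unlike the method, it copies the input instead of modifying it in place.

```python
import numpy as np


def filter_instances(image_shape, joints, areas, min_instance_area):
    """Return (kept joints, keep mask), mirroring AnimalPoseKeypointsDataset.filter_joints."""
    height, width = image_shape
    joints = joints.copy()
    outside = (joints[:, :, 0] < 0) | (joints[:, :, 1] < 0) | (joints[:, :, 0] >= width) | (joints[:, :, 1] >= height)
    joints[outside, 2] = 0  # joints outside the image are marked invisible
    has_visible = np.count_nonzero(joints[:, :, 2], axis=-1) > 0
    keep = has_visible & (areas > min_instance_area)
    return joints[keep], keep


# Three toy instances with two joints each on a 100x100 image
joints = np.array(
    [
        [[10, 10, 1], [150, 10, 1]],  # one joint inside, one outside -> kept
        [[-5, 20, 1], [20, 120, 1]],  # both joints outside -> dropped
        [[30, 30, 1], [40, 40, 1]],  # visible, but area too small -> dropped
    ],
    dtype=np.float32,
)
areas = np.array([500.0, 500.0, 10.0], dtype=np.float32)
kept, keep = filter_instances((100, 100), joints, areas, min_instance_area=64)
print(keep)  # [ True False False]
print(kept[0, 1])  # the out-of-image joint of the kept instance now has visibility 0
```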




Now that we have the dataset class, we can reuse the existing COCO dataset config (`coco_pose_estimation_dekr_dataset_params`) and take the transforms and target generator from it.

In [5]:
from typing import Dict
from torch.utils.data import DataLoader
from super_gradients.training.dataloaders import get_data_loader

def animalpose_pose_train(dataset_params: Dict = None, dataloader_params: Dict = None) -> DataLoader:
    return get_data_loader(
        config_name="coco_pose_estimation_dekr_dataset_params",
        dataset_cls=AnimalPoseKeypointsDataset,
        train=True,
        dataset_params=dataset_params,
        dataloader_params=dataloader_params,
    )


def animalpose_pose_val(dataset_params: Dict = None, dataloader_params: Dict = None) -> DataLoader:
    return get_data_loader(
        config_name="coco_pose_estimation_dekr_dataset_params",
        dataset_cls=AnimalPoseKeypointsDataset,
        train=False,
        dataset_params=dataset_params,
        dataloader_params=dataloader_params,
    )

We are almost ready to instantiate our data loaders, but there is one important nuance left to cover. Let's instantiate our train dataset and inspect the transformations applied to our samples:



In [6]:
from pprint import pprint

train_data = animalpose_pose_train(dataset_params=dict(data_dir=ANIMALS_POSE_DATA_DIR,
                                                       images_dir="images",
                                                       json_file="keypoints.json"),
                                   dataloader_params=dict(num_workers=0, batch_size=8))
train_data.dataset.transforms.transforms

[KeypointsLongestMaxSize(max_height=640, max_width=640, interpolation=1, prob=1.0),
 KeypointsPadIfNeeded(min_height=640, min_width=640, image_pad_value=(127, 127, 127), mask_pad_value=1),
 KeypointsRandomHorizontalFlip(flip_index=[0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15], prob=0.5),
 KeypointsRandomAffineTransform(max_rotation=30, min_scale=0.5, max_scale=2, max_translate=0.2, image_pad_value=(127, 127, 127), mask_pad_value=1, prob=0.75),
 KeypointsImageToTensor(),
 KeypointsImageNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]

During training, we apply several augmentations to our samples. One of them, however, needs our special attention: `KeypointsRandomHorizontalFlip`.
In object detection and semantic segmentation, we can safely ignore the fact that left and right swap sides in a flipped image.
However, in pose estimation of objects with left-right symmetry, like animals or humans, we have to take this into account.
So when we flip the image, we also have to swap the left and right keypoints. To do this correctly, the `KeypointsRandomHorizontalFlip` transform must know the permutation of keypoint indices. Note that the `flip_index` inherited from the COCO config above has only 17 entries (one per COCO joint), so it does not match our 20 Animal Pose joints and has to be replaced.
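Conceptually, a horizontal flip of keypoints mirrors the x coordinate and then reorders the joints so that each left joint takes the place of its right counterpart. A minimal NumPy sketch; the function name is hypothetical, and the `width - 1 - x` mirroring convention is an assumption (the SuperGradients transform may use a slightly different one).

```python
import numpy as np


def flip_keypoints_horizontally(joints: np.ndarray, image_width: int, flip_index: list) -> np.ndarray:
    """Flip [N, J, 3] keypoints (x, y, visibility) around the vertical axis of the image."""
    flipped = joints.copy()
    flipped[:, :, 0] = image_width - 1 - flipped[:, :, 0]  # mirror x coordinates
    # new joint j takes the data of old joint flip_index[j], so left/right names stay correct
    return flipped[:, flip_index]


# One instance with a left eye at x=10 and a right eye at x=30 in a 100 px wide image
joints = np.array([[[10, 50, 1], [30, 50, 1]]], dtype=np.float32)
flipped = flip_keypoints_horizontally(joints, image_width=100, flip_index=[1, 0])
print(flipped[0, :, 0])  # [69. 89.]
```

Without the reordering step, the joint labeled `left_eye` would end up on the animal's right side after the flip, which is exactly the label noise the flip indexes prevent.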

If we look at the JSON data of our annotation file, we can see the following keypoint order, together with each joint's mirrored counterpart:
```
# 0 "left_eye" -> "right_eye": 1
# 1 "right_eye" -> "left_eye": 0
# 2 "nose" -> "nose": 2
# 3 "left_ear" -> "right_ear": 4
# 4 "right_ear" -> "left_ear": 3
# 5 "left_front_elbow" -> "right_front_elbow": 6
# 6 "right_front_elbow" -> "left_front_elbow": 5
# 7 "left_back_elbow" -> "right_back_elbow": 8
# 8 "right_back_elbow" -> "left_back_elbow": 7
# 9 "left_front_knee" -> "right_front_knee": 10
# 10 "right_front_knee" -> "left_front_knee": 9
# 11 "left_back_knee" -> "right_back_knee": 12
# 12 "right_back_knee" -> "left_back_knee": 11
# 13 "left_front_paw" -> "right_front_paw": 14
# 14 "right_front_paw" -> "left_front_paw": 13
# 15 "left_back_paw" -> "right_back_paw": 16
# 16 "right_back_paw" -> "left_back_paw": 15
# 17 "throat" -> "throat": 17
# 18 "withers" -> "withers": 18
# 19 "tailbase" -> "tailbase": 19
```

So our array of indexes will look like this:


In [7]:
ANIMALS_POSE_FLIP_INDEXES = [1,0,2,4,3,6,5,8,7,10,9,12,11,14,13,16,15,17,18,19]
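Writing this list by hand is error-prone. A minimal sketch that derives it from the joint names instead (the same names the dataset class stores in `self.joints`), assuming every left/right pair follows the `left_` / `right_` prefix convention used in this dataset. The helper name and the `ANIMALS_POSE_JOINTS` list (copied from the JSON `categories` above) are illustrative.

```python
def flip_indexes_from_names(joint_names):
    """For each joint, return the index of its mirrored counterpart (itself if it has none)."""
    index_of = {name: i for i, name in enumerate(joint_names)}
    flip_index = []
    for name in joint_names:
        if name.startswith("left_"):
            mirrored = "right_" + name[len("left_"):]
        elif name.startswith("right_"):
            mirrored = "left_" + name[len("right_"):]
        else:
            mirrored = name  # center-line joints such as nose or throat map to themselves
        flip_index.append(index_of[mirrored])
    return flip_index


ANIMALS_POSE_JOINTS = [
    "left_eye", "right_eye", "nose", "left_ear", "right_ear",
    "left_front_elbow", "right_front_elbow", "left_back_elbow", "right_back_elbow",
    "left_front_knee", "right_front_knee", "left_back_knee", "right_back_knee",
    "left_front_paw", "right_front_paw", "left_back_paw", "right_back_paw",
    "throat", "withers", "tailbase",
]
print(flip_indexes_from_names(ANIMALS_POSE_JOINTS))
# [1, 0, 2, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15, 17, 18, 19]
```

A missing counterpart (for example a `left_` joint without a matching `right_` one) raises a `KeyError`, which surfaces annotation inconsistencies early instead of silently producing a wrong permutation.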

Now we are ready to instantiate our data loaders. For the training data loader, we pass the transforms explicitly, using our own flip indexes:

In [8]:
from super_gradients.training.transforms.keypoint_transforms import KeypointsLongestMaxSize, KeypointsPadIfNeeded, \
    KeypointsRandomHorizontalFlip, KeypointsRandomAffineTransform, KeypointsImageNormalize, KeypointsImageToTensor

train_transforms = [
    KeypointsLongestMaxSize(640,640),
    KeypointsPadIfNeeded(640,640, image_pad_value=(127,127,127), mask_pad_value=1),
    KeypointsRandomHorizontalFlip(ANIMALS_POSE_FLIP_INDEXES),
    KeypointsRandomAffineTransform(max_rotation=30, min_scale=0.5, max_scale=2, max_translate=0.2, image_pad_value=(127,127,127), mask_pad_value=1, prob=0.5),
    KeypointsImageToTensor(),
    KeypointsImageNormalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
]

train_data = animalpose_pose_train(dataset_params=dict(data_dir=ANIMALS_POSE_DATA_DIR,
                                                       images_dir="images",
                                                       json_file="keypoints.json",
                                                       transforms=train_transforms),
                                   dataloader_params=dict(num_workers=0, batch_size=8))

val_data = animalpose_pose_val(dataset_params=dict(data_dir=ANIMALS_POSE_DATA_DIR,
                                                   images_dir="images",
                                                   json_file="keypoints.json"),
                               dataloader_params=dict(num_workers=0, batch_size=8))


In [9]:
val_data.dataset[0]


(tensor([[[ 0.3309,  1.4269,  1.9920,  ...,  1.8893,  1.5982,  1.3413],
          [-0.2342,  1.2385,  1.9578,  ...,  2.0263,  1.9064,  1.7352],
          [-1.0390,  0.9817,  1.9578,  ...,  2.0777,  2.0777,  2.0263],
          ...,
          [ 0.0569,  0.0569,  0.0569,  ...,  0.0569,  0.0569,  0.0569],
          [ 0.0569,  0.0569,  0.0569,  ...,  0.0569,  0.0569,  0.0569],
          [ 0.0569,  0.0569,  0.0569,  ...,  0.0569,  0.0569,  0.0569]],
 
         [[ 0.6429,  1.7458,  2.2710,  ...,  2.2185,  1.9209,  1.6583],
          [ 0.0651,  1.5707,  2.2710,  ...,  2.3410,  2.2185,  2.0434],
          [-0.7402,  1.3256,  2.2885,  ...,  2.3410,  2.3410,  2.3060],
          ...,
          [ 0.1877,  0.1877,  0.1877,  ...,  0.1877,  0.1877,  0.1877],
          [ 0.1877,  0.1877,  0.1877,  ...,  0.1877,  0.1877,  0.1877],
          [ 0.1877,  0.1877,  0.1877,  ...,  0.1877,  0.1877,  0.1877]],
 
         [[ 1.1585,  2.1520,  2.6226,  ...,  2.4657,  2.1868,  1.9254],
          [ 0.5834,  1.9603,

In [10]:
train_data.dataset[0]

(tensor([[[0.0569, 0.0569, 0.0569,  ..., 0.0569, 0.0569, 0.0569],
          [0.0569, 0.0569, 0.0569,  ..., 0.0569, 0.0569, 0.0569],
          [0.0569, 0.0569, 0.0569,  ..., 0.0569, 0.0569, 0.0569],
          ...,
          [0.0569, 0.0569, 0.0569,  ..., 0.0569, 0.0569, 0.0569],
          [0.0569, 0.0569, 0.0569,  ..., 0.0569, 0.0569, 0.0569],
          [0.0569, 0.0569, 0.0569,  ..., 0.0569, 0.0569, 0.0569]],
 
         [[0.1877, 0.1877, 0.1877,  ..., 0.1877, 0.1877, 0.1877],
          [0.1877, 0.1877, 0.1877,  ..., 0.1877, 0.1877, 0.1877],
          [0.1877, 0.1877, 0.1877,  ..., 0.1877, 0.1877, 0.1877],
          ...,
          [0.1877, 0.1877, 0.1877,  ..., 0.1877, 0.1877, 0.1877],
          [0.1877, 0.1877, 0.1877,  ..., 0.1877, 0.1877, 0.1877],
          [0.1877, 0.1877, 0.1877,  ..., 0.1877, 0.1877, 0.1877]],
 
         [[0.4091, 0.4091, 0.4091,  ..., 0.4091, 0.4091, 0.4091],
          [0.4091, 0.4091, 0.4091,  ..., 0.4091, 0.4091, 0.4091],
          [0.4091, 0.4091, 0.4091,  ...,

In [11]:
from super_gradients.common import StrictLoad
from super_gradients.common.object_names import Models
from super_gradients.training import models

yolo_nas_pose_l = models.get(Models.YOLO_NAS_POSE_L, num_classes=20, checkpoint_path="../pretrained/yolo_nas_pose_l_coco.pth", strict_load=StrictLoad.KEY_MATCHING)

In [12]:

from super_gradients.training.metrics import PoseEstimationMetrics

train_params = {
    "average_best_models":True,
    "warmup_mode": "linear_epoch_step",
    "warmup_initial_lr": 1e-6,
    "lr_warmup_epochs": 3,
    "initial_lr": 5e-4,
    "lr_mode": "cosine",
    "cosine_final_lr_ratio": 0.1,
    "optimizer": "Adam",
    "optimizer_params": {"weight_decay": 0.0001},
    "zero_weight_decay_on_bias_and_bn": True,
    "ema": False,
    "ema_params": {"decay": 0.9, "decay_type": "threshold"},
    # ONLY TRAINING FOR 10 EPOCHS FOR THIS EXAMPLE NOTEBOOK
    "max_epochs": 10,
    "mixed_precision": True,
    "loss": "dekr_loss",
    "criterion_params": {
        "heatmap_loss": "qfl",
        "heatmap_loss_factor": 1.0,
        "offset_loss_factor": 0.1,
    },
    "valid_metrics_list": [
        PoseEstimationMetrics(
            num_joints=20,
            oks_sigmas=None,
            max_objects_per_image=30,
            post_prediction_callback=yolo_nas_pose_l.get_post_prediction_callback(conf=0.05, iou=0.05),
        )
    ],
    "metric_to_watch": 'AP'
}
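Keypoint AP is conventionally computed by matching predicted and ground-truth poses with Object Keypoint Similarity (OKS), which is what the `oks_sigmas` argument of `PoseEstimationMetrics` parameterizes. A minimal sketch of the COCO-style OKS for a single pair of poses, following the pycocotools definition; the uniform per-joint sigma of 0.1 is an illustrative assumption, not the value SuperGradients uses when `oks_sigmas=None`.

```python
import numpy as np


def object_keypoint_similarity(pred: np.ndarray, gt: np.ndarray, area: float, sigmas: np.ndarray) -> float:
    """COCO-style OKS between one predicted and one ground-truth pose, both [J, 3] (x, y, visibility)."""
    visible = gt[:, 2] > 0
    if not visible.any():
        return 0.0
    squared_dist = (pred[:, 0] - gt[:, 0]) ** 2 + (pred[:, 1] - gt[:, 1]) ** 2
    variances = (2 * sigmas) ** 2
    # Distance normalized by object scale (area) and per-joint tolerance, as in pycocotools
    e = squared_dist / variances / (area + np.spacing(1)) / 2
    return float(np.mean(np.exp(-e[visible])))


num_joints = 20
sigmas = np.full(num_joints, 0.1)  # assumption: the same tolerance for every joint
gt = np.zeros((num_joints, 3))
gt[0] = [193, 216, 1]
gt[2] = [174, 261, 1]
pred = gt.copy()
print(object_keypoint_similarity(pred, gt, area=40960.0, sigmas=sigmas))  # 1.0
pred[0, 0] += 20  # shift one visible joint by 20 px
print(round(object_keypoint_similarity(pred, gt, area=40960.0, sigmas=sigmas), 3))  # 0.943
```

Only labeled joints (visibility > 0) contribute, so the many `[0, 0, 0]` entries in Animal Pose annotations do not penalize predictions.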


In [None]:
from super_gradients.training import Trainer

CHECKPOINT_DIR = 'checkpoints'
trainer = Trainer(experiment_name='animal_pose_fine_tuning', ckpt_root_dir=CHECKPOINT_DIR)

trainer.train(model=yolo_nas_pose_l,
              training_params=train_params,
              train_loader=train_data,
              valid_loader=val_data)




Train epoch 0: 100%|██████████| 576/576 [10:39<00:00,  1.11s/it, DEKRLoss/heatmap=0.201, DEKRLoss/offset=0.00401, DEKRLoss/total=0.205, gpu_mem=7.05]
Validation epoch 0: 100%|██████████| 576/576 [07:24<00:00,  1.30it/s]


SUMMARY OF EPOCH 0
├── Training
│   ├── Dekrloss/heatmap = 0.201
│   ├── Dekrloss/offset = 0.004
│   └── Dekrloss/total = 0.205
└── Validation
    ├── Ap = 0.0
    ├── Ar = 0.0
    ├── Dekrloss/heatmap = 0.1877
    ├── Dekrloss/offset = 0.0041
    └── Dekrloss/total = 0.1917



Train epoch 1: 100%|██████████| 576/576 [10:11<00:00,  1.06s/it, DEKRLoss/heatmap=0.0133, DEKRLoss/offset=0.0033, DEKRLoss/total=0.0166, gpu_mem=7.42] 
Validation epoch 1:  22%|██▏       | 125/576 [01:09<04:07,  1.82it/s]