# Create a custom keypoints table

In this tutorial, we will create a custom keypoints table from the [Animal Pose Dataset](https://sites.google.com/view/animal-pose/), originally introduced in the paper [Cross-Domain Adaptation for Animal Pose Estimation](https://doi.org/10.48550/arXiv.1908.05806).

We will use a version of the dataset hosted on Kaggle. The annotations are stored in a COCO-like JSON format, which we will parse manually and convert to the 3LC keypoint format.

Cao, Jinkun, Hongyang Tang, Hao-Shu Fang, Xiaoyong Shen, Cewu Lu, and Yu-Wing Tai. *Cross-Domain Adaptation for Animal Pose Estimation*. arXiv preprint [arXiv:1908.05806](https://arxiv.org/abs/1908.05806), 2019.
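To make the conversion concrete, here is a minimal sketch on a hand-written annotation. It assumes, as the loading code later in this tutorial does, that each annotation stores `keypoints` as a list of `[x, y, visibility]` triplets. The values are made up for illustration, and real Animal Pose annotations have 20 keypoints rather than 3.

```python
import numpy as np

# Hypothetical annotation with 3 keypoints; each row is [x, y, visibility].
annotation = {
    "image_id": 1,
    "category_id": 2,
    "keypoints": [[15, 30, 1], [40, 35, 1], [0, 0, 0]],
}

kpts = np.array(annotation["keypoints"])
# Flatten the (x, y) pairs into [x0, y0, x1, y1, ...], the layout used for vertices_2d.
vertices_2d = kpts[:, :2].reshape(-1).tolist()
# Keep the third column separately as per-point visibilities.
visibilities = kpts[:, 2].tolist()

print(vertices_2d)   # [15, 30, 40, 35, 0, 0]
print(visibilities)  # [1, 1, 0]
```

The flat coordinate list always has twice as many entries as the visibility list. The same relationship is checked for every instance when the full table is built.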

## Project setup

In [None]:
PROJECT_NAME = "3LC Tutorials"
DATASET_NAME = "AnimalPose"

## Imports

In [None]:
import json
from pathlib import Path

import numpy as np
import tlc
from PIL import Image
from tlc.core import KeypointHelper
from tlc.core.builtins.schemas import ImageUrlSchema, Keypoints2DSchema
from tqdm import tqdm

## Prepare data

The following cell downloads the dataset from Kaggle. This requires 350 MB of disk space and a [Kaggle account](https://www.kaggle.com/docs/api#authentication).

In [None]:
import kagglehub

DATASET_ROOT = kagglehub.dataset_download("bloodaxe/animal-pose-dataset")
DATASET_ROOT = Path(DATASET_ROOT)

print("Path to dataset files:", DATASET_ROOT)

In [None]:
ANNOTATIONS_FILE = DATASET_ROOT / "keypoints.json"
IMAGE_ROOT = DATASET_ROOT / "images" / "images"

## Load annotations / metadata

In [None]:
# Register the dataset root as a project URL alias. This makes it easy to share the table or move the source data.
tlc.register_project_url_alias("ANIMAL_POSE_DATA", DATASET_ROOT, project=PROJECT_NAME)

In [None]:
with open(ANNOTATIONS_FILE) as f:
    data = json.load(f)

# Load metadata from the annotations file
KEYPOINT_NAMES = data["categories"][0]["keypoints"]
NUM_KEYPOINTS = len(KEYPOINT_NAMES)  # 20 keypoints per animal
CLASSES = {cat["id"]: cat["name"] for cat in data["categories"]}
# Flatten the [[start, end], ...] edge pairs into [start, end, start, end, ...]
SKELETON = np.array(data["categories"][0]["skeleton"]).reshape(-1).tolist()
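The skeleton in the annotations file is a list of `[start, end]` keypoint-index pairs. The `reshape(-1)` call flattens it into a single list where every two consecutive entries form one edge, which is the form passed to `lines=` later. The sketch below uses a made-up three-edge skeleton to show the round trip, plus a sanity check that every index refers to a valid 0-based keypoint. If a file used 1-based indices, as the original COCO person skeleton does, this check would flag it.

```python
import numpy as np

NUM_KEYPOINTS = 20
skeleton_pairs = [[0, 1], [0, 2], [1, 3]]  # hypothetical edges for illustration

# Flatten pairs into [start0, end0, start1, end1, ...]
flat = np.array(skeleton_pairs).reshape(-1).tolist()
print(flat)  # [0, 1, 0, 2, 1, 3]

# Recover (start, end) pairs from the flat list
pairs = list(zip(flat[0::2], flat[1::2]))
print(pairs)  # [(0, 1), (0, 2), (1, 3)]

# Every endpoint should be a valid 0-based keypoint index
assert all(0 <= i < NUM_KEYPOINTS for i in flat)
```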

Some metadata is not stored in the annotations file, so we define it manually. These values were taken from the SuperGradients example notebook [YoloNAS_Pose_Fine_Tuning_Animals_Pose_Dataset](https://github.com/Deci-AI/super-gradients/blob/master/notebooks/YoloNAS_Pose_Fine_Tuning_Animals_Pose_Dataset.ipynb).

In [None]:
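# One OKS sigma per keypoint (the same value for all keypoints here)
OKS_SIGMAS = [0.07] * NUM_KEYPOINTS
# FLIP_INDEXES[i] is the index of the keypoint that keypoint i becomes under a horizontal flip
FLIP_INDEXES = [1, 0, 2, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15, 17, 18, 19]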

KEYPOINT_COLORS = [
    [148, 0, 211],
    [75, 0, 130],
    [0, 0, 255],
    [0, 255, 0],
    [255, 255, 0],
    [255, 165, 0],
    [255, 69, 0],
    [255, 0, 0],
    [139, 0, 0],
    [128, 0, 128],
    [238, 130, 238],
    [186, 85, 211],
    [148, 0, 211],
    [0, 255, 255],
    [0, 128, 128],
    [0, 0, 139],
    [0, 0, 255],
    [0, 255, 0],
    [255, 69, 0],
    [255, 20, 147],
]

EDGE_COLORS = [
    [127, 0, 255],
    [91, 56, 253],
    [55, 109, 248],
    [19, 157, 241],
    [18, 199, 229],
    [54, 229, 215],
    [90, 248, 199],
    [128, 254, 179],
    [164, 248, 158],
    [200, 229, 135],
    [236, 199, 110],
    [255, 157, 83],
    [255, 109, 56],
    [255, 56, 28],
    [255, 0, 0],
]
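Because `FLIP_INDEXES` was copied by hand, it is worth checking before use. A valid flip mapping should be a permutation of `0..NUM_KEYPOINTS-1` that is its own inverse, since flipping twice is the identity. It should also map each left keypoint to its right counterpart, assuming the keypoint names use `left_`/`right_` prefixes. The helper `check_flip_indices` below is hypothetical and not part of 3LC, and the example names are illustrative.

```python
def check_flip_indices(flip, names=None):
    """Raise ValueError if `flip` is not a valid left/right flip permutation."""
    n = len(flip)
    if sorted(flip) != list(range(n)):
        raise ValueError("flip indices are not a permutation of 0..n-1")
    for i, j in enumerate(flip):
        # Flipping twice must return every keypoint to its original index
        if flip[j] != i:
            raise ValueError(f"flip is not its own inverse at index {i}")
    if names is None:
        return
    if len(names) != n:
        raise ValueError("names and flip indices differ in length")
    for i, j in enumerate(flip):
        # Expected partner: swap a leading "left_"/"right_" prefix, otherwise itself
        name = names[i]
        if name.startswith("left_"):
            expected = "right_" + name[len("left_"):]
        elif name.startswith("right_"):
            expected = "left_" + name[len("right_"):]
        else:
            expected = name
        if names[j] != expected:
            raise ValueError(f"{name!r} maps to {names[j]!r}, expected {expected!r}")


# Illustrative names only
names = ["left_eye", "right_eye", "nose"]
check_flip_indices([1, 0, 2], names)  # passes silently
```

In the notebook, `check_flip_indices(FLIP_INDEXES, KEYPOINT_NAMES)` should pass silently if the hand-copied indices match the keypoint names in the annotations file.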

In [None]:
annotations = data["annotations"]
images = data["images"]

row_data = {
    "image": [],
    "keypoints_2d": [],
    "image_id": [],
}

# Pre-compute mapping from image_id to annotations for faster lookup
image_id_2_anns = {}
for ann in annotations:
    image_id_2_anns.setdefault(str(ann["image_id"]), []).append(ann)

for image_id, image_path in tqdm(images.items(), total=len(images), desc="Loading annotations"):
    image_path = Path(IMAGE_ROOT) / image_path
    if not image_path.exists():
        print(f"Image {image_path} does not exist")
        continue

    # Append only after the existence check so that all columns stay the same length
    row_data["image_id"].append(image_id)

    with Image.open(image_path) as img:
        width, height = img.size

    # Images without any annotations get an empty list of instances
    anns = image_id_2_anns.get(image_id, [])
    keypoints = {
        "x_max": width,
        "y_max": height,
        "instances": [],
        "instances_additional_data": {
            "label": [],
        },
    }

    for ann in anns:
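        kpts_arr = np.array(ann["keypoints"])  # shape (NUM_KEYPOINTS, 3): x, y, visibility
        kpts = kpts_arr[:, :2].reshape(-1).tolist()
        visibilities = kpts_arr[:, 2].tolist()
        # bbox is interpreted as [x_min, y_min, x_max, y_max], unlike COCO's [x, y, width, height]
        x_min, y_min, x_max, y_max = ann["bbox"]
        bb = {"x_min": x_min, "y_min": y_min, "x_max": x_max, "y_max": y_max}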
        label = ann["category_id"]

        keypoints["instances"].append(
            {
                "vertices_2d": kpts,
                "vertices_2d_additional_data": {
                    "visibilities": visibilities,
                },
                "bbs_2d": [bb],
            }
        )
        keypoints["instances_additional_data"]["label"].append(label)

    row_data["image"].append(tlc.Url(image_path).to_relative().to_str())  # Url.to_relative applies aliases
    row_data["keypoints_2d"].append(keypoints)
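Each list in `row_data` is one column of the table, so all lists should have the same length before they are written. A missing image that is skipped partway through the loop could otherwise leave the columns out of sync. Each instance should also carry exactly `2 * NUM_KEYPOINTS` coordinates and `NUM_KEYPOINTS` visibilities. The helper below is a hypothetical sanity check, not a 3LC API, shown on a tiny toy row.

```python
def validate_row_data(row_data, num_keypoints):
    """Hypothetical sanity check for a row_data dict like the one built above."""
    lengths = {name: len(col) for name, col in row_data.items()}
    if len(set(lengths.values())) > 1:
        raise ValueError(f"column lengths differ: {lengths}")
    for row_idx, kp in enumerate(row_data["keypoints_2d"]):
        instances = kp["instances"]
        labels = kp["instances_additional_data"]["label"]
        if len(labels) != len(instances):
            raise ValueError(f"row {row_idx}: {len(labels)} labels for {len(instances)} instances")
        for inst in instances:
            if len(inst["vertices_2d"]) != 2 * num_keypoints:
                raise ValueError(f"row {row_idx}: expected {2 * num_keypoints} coordinates")
            if len(inst["vertices_2d_additional_data"]["visibilities"]) != num_keypoints:
                raise ValueError(f"row {row_idx}: expected {num_keypoints} visibilities")
    return lengths


# Toy row with one instance of 2 keypoints (values are made up)
toy = {
    "image": ["a.jpg"],
    "image_id": ["1"],
    "keypoints_2d": [
        {
            "x_max": 100,
            "y_max": 80,
            "instances": [
                {
                    "vertices_2d": [1, 2, 3, 4],
                    "vertices_2d_additional_data": {"visibilities": [1, 0]},
                    "bbs_2d": [{"x_min": 0, "y_min": 0, "x_max": 10, "y_max": 10}],
                }
            ],
            "instances_additional_data": {"label": [1]},
        }
    ],
}
print(validate_row_data(toy, num_keypoints=2))  # {'image': 1, 'image_id': 1, 'keypoints_2d': 1}
```

In the notebook, `validate_row_data(row_data, NUM_KEYPOINTS)` could be run before creating the table.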

## Create table

In [None]:
# The Keypoints2DSchema accepts attributes per point/line. These will be used
# when visualizing the table in the 3LC Dashboard.


def rgb_tuple_to_hex(rgb) -> str:
    return "#" + "".join(f"{c:02X}" for c in rgb)


LINE_ATTRIBUTES = [tlc.MapElement(internal_name="edge", display_color=rgb_tuple_to_hex(color)) for color in EDGE_COLORS]

KEYPOINT_ATTRIBUTES = [
    tlc.MapElement(internal_name=kpt_name, display_color=rgb_tuple_to_hex(color))
    for kpt_name, color in zip(KEYPOINT_NAMES, KEYPOINT_COLORS)
]

In [None]:
# For convenience, the schema for the keypoints column also stores metadata
# relevant to training: oks_sigmas and flip_indices.

keypoints_schema = Keypoints2DSchema(
    num_keypoints=NUM_KEYPOINTS,
    classes=CLASSES,
    lines=SKELETON,
    line_attributes=LINE_ATTRIBUTES,
    point_attributes=KEYPOINT_ATTRIBUTES,
    include_per_point_visibilities=True,
    flip_indices=FLIP_INDEXES,
    oks_sigmas=OKS_SIGMAS,
)
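The `oks_sigmas` stored in the schema are the per-keypoint constants of Object Keypoint Similarity (OKS), the COCO metric for keypoint detection. The sketch below follows the COCO evaluation definition, where the falloff constant is `k_i = 2 * sigma_i` and the squared object scale `s**2` is the object area. Only keypoints visible in the ground truth contribute. The function name and argument layout are our own, not a 3LC API.

```python
import numpy as np


def oks(pred_xy, gt_xy, gt_vis, area, sigmas):
    """COCO-style Object Keypoint Similarity for one prediction/ground-truth pair.

    pred_xy, gt_xy: (K, 2) keypoint coordinates.
    gt_vis: (K,) ground-truth visibility flags; only entries > 0 are scored.
    area: ground-truth object area (s**2 in the OKS formula).
    sigmas: (K,) per-keypoint sigmas, e.g. OKS_SIGMAS.
    """
    pred_xy = np.asarray(pred_xy, dtype=float)
    gt_xy = np.asarray(gt_xy, dtype=float)
    visible = np.asarray(gt_vis) > 0
    if not visible.any():
        return 0.0  # convention for this sketch; COCO handles this case differently
    k2 = (2.0 * np.asarray(sigmas, dtype=float)) ** 2
    d2 = np.sum((pred_xy - gt_xy) ** 2, axis=1)  # squared distance per keypoint
    e = d2 / (2.0 * area * k2 + np.spacing(1))
    return float(np.mean(np.exp(-e[visible])))


gt = [[10, 10], [20, 20]]
print(oks(gt, gt, [1, 1], area=100.0, sigmas=[0.07, 0.07]))  # 1.0
```

Smaller sigmas make the score fall off faster with distance, so a per-keypoint sigma lets precise keypoints such as eyes be judged more strictly than ambiguous ones.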

In [None]:
tw = tlc.TableWriter(
    table_name="initial",
    dataset_name=DATASET_NAME,
    project_name=PROJECT_NAME,
    column_schemas={"image": ImageUrlSchema(), "keypoints_2d": keypoints_schema},
    if_exists="rename",
)
tw.add_batch(row_data)
table = tw.finalize()

## Inspect the table

We can use the `KeypointHelper` class to retrieve the keypoint metadata stored in the table's schema, such as the OKS sigmas, flip indices, skeleton, and per-point and per-line attributes.

In [None]:
table

In [None]:
# Get the OKS sigmas from the table
KeypointHelper.get_oks_sigmas_from_table(table)

In [None]:
# Get the flip indices from the table
KeypointHelper.get_flip_indices_from_table(table)
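Flip indices are what make horizontal-flip augmentation correct for keypoints. Mirroring the image swaps left and right, so after reflecting the x coordinates the keypoints must also be reordered. That way, for example, the point that now appears on the left is still labelled `left_eye`. Here is a minimal NumPy sketch, assuming continuous pixel coordinates where a point at `x` maps to `width - x` (some libraries use `width - 1 - x`). `hflip_keypoints` is a hypothetical helper, not a 3LC function.

```python
import numpy as np


def hflip_keypoints(kpts, width, flip_indices):
    """Horizontally flip a (K, 2) or (K, 3) keypoint array (x, y[, visibility]).

    Reflects x about the image's vertical center line, then reorders rows so that
    new[i] = mirrored old[flip_indices[i]]. Coordinates of invisible keypoints are
    reflected too; some pipelines reset them afterwards.
    """
    out = np.array(kpts, dtype=float)
    out[:, 0] = width - out[:, 0]
    return out[flip_indices]


# Illustrative rows: left_eye, right_eye, nose on a 100-pixel-wide image
kpts = np.array([[30, 40, 1], [70, 40, 1], [50, 60, 1]])
flipped = hflip_keypoints(kpts, width=100, flip_indices=[1, 0, 2])
# The toy face is symmetric, so left_eye is still at x=30 after the flip
print(flipped)
```

Because a valid flip mapping is its own inverse, applying the function twice with the same indices returns the original keypoints. This makes a convenient test for any flip augmentation.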

In [None]:
# Get the skeleton from the table
KeypointHelper.get_lines_from_table(table)

In [None]:
# Get the keypoint attributes from the table
KeypointHelper.get_keypoint_attributes_from_table(table)

In [None]:
# Get the line attributes from the table
KeypointHelper.get_line_attributes_from_table(table)