In [None]:
# Reload edited modules (e.g. the local `dycheck` package) before each cell runs.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [None]:
import dataclasses
import functools
import os.path as osp
from typing import Callable, List, Literal, Union

import cv2
import gin
import jax
import mediapy as media
import numpy as np

from dycheck import core
from dycheck.datasets import Parser
from dycheck.utils import annotation, common, image, io, visuals

# Let gin accept re-registration of configurables (such as the Config class
# defined in a later cell) when notebook cells are executed more than once.
gin.enter_interactive_mode()

In [None]:
def _unconnected_skeleton(num_kps: int) -> List[str]:
    """Gin bindings selecting an ``UnconnectedSkeleton`` with ``num_kps`` keypoints."""
    return [
        "SKELETON=@UnconnectedSkeleton",
        f"UnconnectedSkeleton.num_kps={num_kps}",
    ]


def _named_skeleton(name: str) -> List[str]:
    """Gin bindings selecting the skeleton class ``name`` with no extra parameters."""
    return [f"SKELETON=@{name}"]


# Per-dataset, per-sequence gin bindings choosing the skeleton to annotate.
# Each helper call returns a fresh list, so no two entries share an object.
PREDEFINED_BINDINGS_MAP = {
    "nerfies": {
        "broom": _unconnected_skeleton(8),
        "curls": _unconnected_skeleton(8),
        "toby-sit": _named_skeleton("QuadrupedSkeleton"),
        "tail": _unconnected_skeleton(4),
    },
    "hypernerf": {
        "vrig-3dprinter": _unconnected_skeleton(8),
        "vrig-chicken": _unconnected_skeleton(7),
        "vrig-peel-banana": _unconnected_skeleton(8),
    },
    "iphone": {
        "teddy": _unconnected_skeleton(8),
        "block": _unconnected_skeleton(16),
        "wheel": _unconnected_skeleton(8),
        "apple": _unconnected_skeleton(7),
        "paper-windmill": _unconnected_skeleton(6),
        "space-out": _named_skeleton("HumanSkeleton"),
        "spin": _named_skeleton("HumanSkeleton"),
        "creeper": _unconnected_skeleton(9),
        "backpack": _unconnected_skeleton(10),
        "pillow": _unconnected_skeleton(8),
        "handwavy": _unconnected_skeleton(10),
        "mochi-high-five": _named_skeleton("QuadrupedSkeleton"),
        "haru-sit": _named_skeleton("QuadrupedSkeleton"),
        "sriracha-tree": _named_skeleton("QuadrupedSkeleton"),
    },
}

In [None]:
@gin.configurable(module="annotate_keypoints")
@dataclasses.dataclass
class Config(object):
    """Settings for this annotation notebook, supplied through gin.

    Every field defaults to ``gin.REQUIRED``, so each one must be bound by the
    parsed gin config files / bindings before ``Config()`` is constructed.
    """

    # Parser factory; called with no arguments to build the dataset parser
    # that provides the split and the RGBA frames.
    parser_cls: Callable[..., Parser] = gin.REQUIRED
    # Skeleton factory; called with no arguments to build the skeleton used
    # for annotation. NOTE(review): the annotation also admits a Skeleton
    # instance, but the notebook calls this value, so an instance only works
    # if it is itself callable — confirm intended usage.
    skeleton_cls: Union[
        visuals.Skeleton, Callable[..., visuals.Skeleton]
    ] = gin.REQUIRED
    # Dataset split to annotate; passed to `load_split` and also used as a
    # directory component of the saved keypoint files.
    split: str = gin.REQUIRED

In [None]:
# Choose the "<dataset>/<sequence>" to annotate. Each choice needs an entry in
# PREDEFINED_BINDINGS_MAP. Available choices:
#   iphone:    teddy, mochi-high-five, block, wheel, apple, sriracha-tree,
#              haru-sit, creeper, backpack, pillow, handwavy, paper-windmill,
#              space-out, spin
#   nerfies:   broom, curls, toby-sit, tail
#   hypernerf: vrig-3dprinter, vrig-chicken, vrig-peel-banana
SEQUENCE_ID = "hypernerf/vrig-peel-banana"

# The combined id lives in its own variable, so the split below can be re-run
# on its own; SEQUENCE names only the sequence within DATASET.
DATASET, SEQUENCE = SEQUENCE_ID.split("/")

In [None]:
# Dataset-level gin file (relative path — presumably resolved against the
# notebook's working directory) plus the sequence name and its skeleton
# bindings.
GIN_CONFIGS = [f"../configs/{DATASET}/annotate_keypoints.gin"]
GIN_BINDINGS = [
    f'SEQUENCE="{SEQUENCE}"',
    *PREDEFINED_BINDINGS_MAP[DATASET][SEQUENCE],
]

In [None]:
# Parse the gin file(s) and bindings while gin's config is unlocked (gin
# rejects modifications while its config is locked).
with gin.unlock_config():
    core.parse_config_files_and_bindings(
        master=False,
        skip_unknown=True,
        bindings=GIN_BINDINGS,
        config_files=GIN_CONFIGS,
    )

config_str = gin.config_str()
print("*** Configuration:\n" + config_str)

# Fields are expected to be supplied by the bindings parsed above.
config = Config()

In [None]:
# Number of frames to annotate; passed as the count to `strided_subset`.
NUM_ANNOTATED_FRAMES = 10

parser = config.parser_cls()
# `load_split` yields (frame_names, time_ids, camera_ids). Subsample each
# array with identical arguments so the subsets stay index-aligned (assuming
# equal lengths). A plain generator replaces the deprecated `jax.tree_map`,
# which would also recurse into any nested container instead of treating
# each array as one unit.
frame_names, time_ids, camera_ids = (
    common.strided_subset(values, NUM_ANNOTATED_FRAMES)
    for values in parser.load_split(config.split)
)
# Load the selected RGBA frames in parallel and keep only the first three
# (RGB) channels.
rgbs = np.array(
    common.parallel_map(parser.load_rgba, time_ids, camera_ids)
)[..., :3]

In [None]:
# Preview the frames to annotate, labelled with the frame names that their
# keypoint JSON files are saved under.
media.show_images(
    rgbs, titles=[str(name) for name in frame_names], height=256
)

In [None]:
# Instantiate the configured skeleton (presumably selected through the
# SKELETON gin macro set in GIN_BINDINGS; the .gin file is not visible here).
skeleton = config.skeleton_cls()
# Start keypoint annotation on the subsampled RGB frames. NOTE(review): the
# "Annotation not finished yet." asserts in later cells suggest `keypoints`
# is filled in interactively after this call returns — confirm against
# dycheck.utils.annotation.annotate_keypoints.
keypoints = annotation.annotate_keypoints(rgbs, skeleton, kp_radius=6)

In [None]:
# Preview only once every loaded frame has been annotated. Compare against
# the actual number of loaded frames rather than a hard-coded 10, which
# could never be satisfied if the subset had a different size.
assert len(keypoints) == len(rgbs), "Annotation not finished yet."

# Overlay each frame's keypoints (and skeleton) on its image. `zip` pairs the
# frames directly, unlike the deprecated `jax.tree_map`, which would recurse
# into any nested per-frame keypoint containers.
media.show_images(
    [
        visuals.visualize_kps(kps, rgb, skeleton=skeleton, kp_radius=6)
        for kps, rgb in zip(keypoints, rgbs)
    ],
    height=256,
)

In [None]:
# Refuse to write partial results: every selected frame needs keypoints.
# The expected count follows the loaded frames instead of a hard-coded 10.
assert len(keypoints) == len(frame_names), "Annotation not finished yet."

# Output layout: <data_dir>/keypoint/<factor>x/<split>/{skeleton,<frame>}.json
keypoint_dir = osp.join(
    parser.data_dir, "keypoint", f"{parser.factor}x", config.split
)
io.dump(osp.join(keypoint_dir, "skeleton.json"), skeleton.asdict())
# The assert above guarantees equal lengths, so zip drops nothing.
for frame_name, frame_kps in zip(frame_names, keypoints):
    io.dump(osp.join(keypoint_dir, f"{frame_name}.json"), frame_kps)