# Let's try to figure out how to develop RTMDet

## First, let's understand what the difference is between CSPNeXt and RTMPose

In [1]:
import torch
x = torch.randn(1, 3, 256, 256)  # dummy input: a batch of one 3-channel 256x256 image

# ----

# First, let's create a CSPNeXt object
from deeplabcut.pose_estimation_pytorch.models.backbones.cspnext import CSPNeXt

MODEL_VARIANTS = {
    "cspnext_s":  {"model_name": "cspnext_s", "freeze_bn_stats": False, "freeze_bn_weights": False, "deepen_factor": 0.33, "widen_factor": 0.5},
    "cspnext_m": {"model_name": "cspnext_m", "freeze_bn_stats": False, "freeze_bn_weights": False, "deepen_factor": 0.67, "widen_factor": 0.75},
    "cspnext_x":  {"model_name": "cspnext_x", "freeze_bn_stats": False, "freeze_bn_weights": False, "deepen_factor": 1.33, "widen_factor": 1.25},
}

my_variant = "cspnext_m" # choose from cspnext_s , cspnext_m , cspnext_x

# Folder used to exchange snapshots between the cells of this notebook.
snapshots_dir = "/home/max/tmp/rtmdet_dev"
# Everything below is derived from `my_variant` so that the backbone, the MMDetection
# checkpoint and the RTMPose config always have the same size (s / m / x).
mmdetection_variant = my_variant.replace("cspnext", "rtmdet")
net_type = my_variant.replace("cspnext", "rtmpose")  # assumes a matching rtmpose_<size>.yaml exists - TODO confirm for s and x

cspnext = CSPNeXt(**MODEL_VARIANTS[my_variant])

print(type(cspnext)) # cspnext is of type deeplabcut.pose_estimation_pytorch.models.backbones.cspnext.CSPNeXt , which is a nn.Module
print(f"Nb parameters: {sum(p.numel() for p in cspnext.parameters())}")
# print(cspnext) # This floods the output
y = cspnext(x)
print(f"Shape of output of dummy tensor: {y.shape}")

# Save the (randomly initialised) DLC backbone so its keys can be compared with MMDetection's.
torch.save(cspnext.state_dict(), f"{snapshots_dir}/dlc_{my_variant}.pth")

# NOTE: this file is written by the key-remapping cell at the end of the notebook, which
# therefore has to be run (with the MMDetection backbone snapshot available) before this line.
# weights_only=True restricts unpickling to tensors and plain containers, which is all a
# state dict needs.
mmdetection_state_dict = torch.load(
    f"{snapshots_dir}/remaped_mmdetection_{mmdetection_variant}_backbone.pth",
    map_location="cpu",
    weights_only=True,
)

cspnext.load_state_dict(mmdetection_state_dict)

print("-"* 80)

# Next, let's create a RTMPose nn.Module object
from deeplabcut.pose_estimation_pytorch.config.utils import get_config_folder_path, replace_default_values
from deeplabcut.core.config import read_config_as_dict
from deeplabcut.pose_estimation_pytorch.models import PoseModel

nb_bodyparts = 5 # for example

configs_dir = get_config_folder_path()
architecture = net_type.split("_")[0]
cfg_path = configs_dir / architecture / f"{net_type}.yaml"
model_cfg = read_config_as_dict(cfg_path)
model_cfg = replace_default_values(
    model_cfg,
    num_bodyparts=nb_bodyparts,
) # Interesting observation: the yaml file defines values that depend on nb_individuals and/or nb_bodyparts, and are updated with real values once known (when creating the actual pytorch_config.yaml)
rtmpose = PoseModel.build(model_cfg["model"]) # here, there might be some optional parameters, todo investigate
print(type(rtmpose))
print(f"Nb parameters: {sum(p.numel() for p in rtmpose.parameters())}")
# print(rtmpose) # This floods the output
z = rtmpose(x)
print(f"Shape of output of dummy tensor: {z['bodypart']['x'].shape} , {z['bodypart']['y'].shape}")
# Okay so basically, the first level keys are defined by the pytorch_config.yaml (in the heads block), and the second level keys are defined in the RTMCCHead forward method.
# The first level keys are certainly added when doing PoseModel.build
# When printing rtmpose, there are different submodules (backbone, head). They are very certainly created during PoseModel.build()
# The differences in sizes of rtmpose are the same ones as the difference in sizes of cspnext.
# The different parameters (deepen_factor, widen_factor, backbone_output_channels) are configured in the rtmpose yaml files.

# - state_dict() returns a dictionary containing all the model's learnable parameters and buffers,
# while load_state_dict() restores those values into a model with the same architecture.

rtmpose.backbone.load_state_dict(mmdetection_state_dict)

Loading DLC 3.0.0rc13...
<class 'deeplabcut.pose_estimation_pytorch.models.backbones.cspnext.CSPNeXt'>
Nb parameters: 12279432
Shape of output of dummy tensor: torch.Size([1, 768, 8, 8])
--------------------------------------------------------------------------------
<class 'deeplabcut.pose_estimation_pytorch.models.model.PoseModel'>
Nb parameters: 13172879
Shape of output of dummy tensor: torch.Size([1, 5, 512]) , torch.Size([1, 5, 512])


<All keys matched successfully>

## Then, let's dive into how the existing detectors work

In [None]:
from deeplabcut.pose_estimation_pytorch.models.detectors.fasterRCNN import FasterRCNN

# Settings for a pretrained Faster R-CNN with a MobileNetV3 backbone.
detector_kwargs = dict(
    variant="fasterrcnn_mobilenet_v3_large_fpn",  # alternative: "fasterrcnn_resnet50_fpn"
    pretrained=True,                              # start from COCO pretrained weights
    box_score_thresh=0.05,                        # drop weak detections
)
detector = FasterRCNN(**detector_kwargs)
# Author's note: this throws away the head and loads another one, so the head is not pretrained anymore.

# Dummy input: two RGB images of size 3x224x224.
images = [torch.rand(3, 224, 224) for _ in range(2)]

# Inference only: no targets are passed and gradients are not tracked.
detector.eval()
with torch.no_grad():
    losses, detections = detector(images)

print(detections)


## Finally, let's instantiate the official RTMDet network for comparison

In [1]:
# this must be performed in the openmmlab environment.
from mmdet.utils import register_all_modules
register_all_modules()

import torch
from mmengine import Config
from mmdet.registry import MODELS

#x = torch.randn(1, 3, 256, 256)

my_variant = "rtmdet_x" # choose from rtmdet_tiny , rtmdet_s , rtmdet_m , rtmdet_l , rtmdet_x
rtmdet_cfg = Config.fromfile(f"/home/max/Work/mmdetection/configs/rtmdet/{my_variant}_8xb32-300e_coco.py")

# Build the model
rtmdet = MODELS.build(rtmdet_cfg.model)

# Put it in evaluation mode. Note that eval() only switches the behaviour of layers such as
# batch norm and dropout; it does NOT disable gradient tracking (use torch.no_grad() for that).
rtmdet.eval()

print(type(rtmdet))

print(f"Nb parameters: {sum(p.numel() for p in rtmdet.parameters())}")

#print(rtmdet) # This floods the output
print(type(rtmdet.backbone))
print(type(rtmdet.neck))
print(type(rtmdet.bbox_head))

# Export the MMDetection backbone so its keys can be remapped for DLC.
torch.save(rtmdet.backbone.state_dict(), f"/home/max/tmp/rtmdet_dev/mmdetection_{my_variant}_backbone.pth")

# Load the DLC snapshot of the SAME size as the selected RTMDet variant, otherwise
# load_state_dict fails on shape mismatches. Only s / m / x have a remapped DLC snapshot;
# the file is written by the key-remapping cell of this notebook.
dlc_variant = my_variant.replace("rtmdet", "cspnext")
# weights_only=True restricts unpickling to tensors and plain containers (and silences the
# FutureWarning emitted by recent PyTorch versions).
dlc_state_dict = torch.load(
    f"/home/max/tmp/rtmdet_dev/remaped_dlc_{dlc_variant}.pth",
    map_location="cpu",
    weights_only=True,
)

rtmdet.backbone.load_state_dict(dlc_state_dict)

  from torch.distributed.optim import \


<class 'mmdet.models.detectors.rtmdet.RTMDet'>
Nb parameters: 94855572
<class 'mmdet.models.backbones.cspnext.CSPNeXt'>
<class 'mmdet.models.necks.cspnext_pafpn.CSPNeXtPAFPN'>
<class 'mmdet.models.dense_heads.rtmdet_head.RTMDetSepBNHead'>


  dlc_state_dict = torch.load("/home/max/tmp/rtmdet_dev/remaped_dlc_cspnext_x.pth", map_location="cpu")


<All keys matched successfully>

## This is a DetInferencer from MMDetection (used by the demo script)

In [None]:
from mmdet.apis import DetInferencer

# Config file, checkpoint and device used to build the inferencer.
inferencer_kwargs = {
    "model": "/home/max/Work/mmdetection/rtmdet_tiny_8xb32-300e_coco.py",
    "weights": "/home/max/Work/mmdetection/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth",
    "device": "cpu",
}
inferencer = DetInferencer(**inferencer_kwargs)

# NOTE(review): presumably turns off chunked processing in the test config - confirm against MMDetection.
inferencer.model.test_cfg.chunked_size = -1

# Run the demo image through the model; results are written below `out_dir`.
demo_image = "/home/max/Work/mmdetection/demo/demo.jpg"
output_dir = "/home/max/Work/mmdetection/outputs"
inferencer(inputs=demo_image, out_dir=output_dir, no_save_pred=False)

## Let's try to load one saved state dict of CSPNeXt into the other, and vice-versa

In [1]:
from pathlib import Path
import torch
from collections import OrderedDict


def remap_cspnext_keys(state_dict, direction):
    """Rename CSPNeXt state-dict keys between the DLC and MMDetection conventions.

    The two implementations differ in how the normalisation sub-module is named:
    DLC keys contain ".norm" where MMDetection keys contain ".bn".

    Args:
        state_dict: mapping from parameter/buffer name to value. It is not modified.
        direction: one of
            - 'dlc_to_mmdet': makes DLC keys compatible with MMDet
            - 'mmdet_to_dlc': makes MMDet keys compatible with DLC

    Returns:
        A new OrderedDict with the renamed keys, in the original order, holding the
        same value objects.

    Raises:
        ValueError: if `direction` is not one of the two supported values. (A typo
            used to silently return an unmodified copy.)
    """
    # (substring to look for, substring to put instead) for each supported direction.
    substitutions = {
        "dlc_to_mmdet": (".norm", ".bn"),
        "mmdet_to_dlc": (".bn", ".norm"),
    }
    if direction not in substitutions:
        raise ValueError(
            f"Unknown direction {direction!r}; expected one of {sorted(substitutions)}"
        )

    old, new = substitutions[direction]
    return OrderedDict(
        (key.replace(old, new), value) for key, value in state_dict.items()
    )


# Backbone sizes that exist in both code bases, listed in matching order.
dlc_cspnext_variants = ("cspnext_s", "cspnext_m", "cspnext_x")
mmdetection_cspnext_variants = ("rtmdet_s", "rtmdet_m", "rtmdet_x")

# Folder where the other cells save their snapshots (same for every variant).
snapshots_dir = Path("/home/max/tmp/rtmdet_dev/")

for dlc_cspnext_variant, mmdetection_cspnext_variant in zip(dlc_cspnext_variants, mmdetection_cspnext_variants):
    # (snapshot file name, remapping direction) for both sides of this backbone size.
    conversions = (
        (f"dlc_{dlc_cspnext_variant}.pth", "dlc_to_mmdet"),
        (f"mmdetection_{mmdetection_cspnext_variant}_backbone.pth", "mmdet_to_dlc"),
    )
    for snapshot_name, direction in conversions:
        snapshot_path = snapshots_dir / snapshot_name
        # The snapshots are produced one variant at a time by other cells (and in two
        # different environments), so convert whatever is available instead of aborting
        # on the first missing file.
        if not snapshot_path.exists():
            print(f"Skipping {snapshot_path}: snapshot not found")
            continue

        # weights_only=True restricts unpickling to tensors and plain containers.
        state_dict = torch.load(snapshot_path, map_location="cpu", weights_only=True)
        state_dict_remaped = remap_cspnext_keys(state_dict, direction=direction)
        torch.save(state_dict_remaped, snapshots_dir / f"remaped_{snapshot_name}")
