# Demo - Transfer Learning from a COCO dataset

In [None]:
import copy
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import deeplabcut
import deeplabcut.pose_estimation_pytorch.utils as utils
import deeplabcut.pose_estimation_pytorch.config.utils as config_utils

from deeplabcut.pose_estimation_pytorch.models import PoseModel
from deeplabcut.pose_estimation_pytorch import COCOLoader
from deeplabcut.pose_estimation_pytorch.data import build_transforms
from deeplabcut.pose_estimation_pytorch.data.collate import COLLATE_FUNCTIONS
from deeplabcut.pose_estimation_pytorch.modelzoo.inference import _parse_model_snapshot
from deeplabcut.pose_estimation_pytorch.modelzoo.utils import (
    _get_config_model_paths,
    _update_config,
)
from deeplabcut.pose_estimation_pytorch.runners import build_training_runner
from deeplabcut.pose_estimation_pytorch.task import Task

## Data & Configuration 

In [None]:
# ---------------------------------------------------------------------------
# Experiment layout:  <experiment_path>/train  and  <experiment_path>/test.
# The model configuration for this experiment is written to
# train/pytorch_config.yaml.
# NOTE: the absolute paths in this cell are machine-specific - edit them
# before running the notebook.
# ---------------------------------------------------------------------------
experiment_path = Path("/Users/niels/Desktop/coco_transfer_experiments") / "experiment_1"
train_dir = experiment_path / "train"
test_dir = experiment_path / "test"
model_config_path = train_dir / "pytorch_config.yaml"

# Create the folder structure (no-op for folders that already exist)
experiment_path.mkdir(parents=True, exist_ok=True)
for subdir in (train_dir, test_dir):
    subdir.mkdir(exist_ok=True)

# Root folder of the COCO-format dataset, laid out as:
#   quadruped80k/
#     annotations/
#     images/
dataset_path = Path("/Users/niels/Documents/upamathis/dlc/benchmarks/modelzoo/quadruped80k")
train_file, test_file = "train.json", "test.json"

# SuperAnimal project + architecture whose pretrained weights we start from
project_name, model_name = "superanimal_topviewmouse", "hrnetw32"

max_individuals = 10  # only needed for the detector
num_bodyparts = 17  # number of bodyparts in the target (transfer-learning) dataset

device = "cpu"

## Transfer Learning

### Creating the Experiment Configuration File

In [None]:
# Get paths to the SuperAnimal configs and weights
model_cfg, project_cfg, pose_model_path, detector_model_path = _get_config_model_paths(
    project_name, model_name
)
pose_model_path = _parse_model_snapshot(Path(pose_model_path), device)
detector_model_path = _parse_model_snapshot(Path(detector_model_path), device)

# Size the model's outputs for the *target* dataset. Only the backbone weights
# of the SuperAnimal model are transferred (the pose snapshot is not given to
# the training runner), so the heads must predict `num_bodyparts` joints - not
# the number of SuperAnimal bodyparts (len(project_cfg["bodyparts"])).
model_cfg = config_utils.replace_default_values(
    model_cfg,
    num_bodyparts=num_bodyparts,
    num_individuals=max_individuals,
    backbone_output_channels=model_cfg["model"]["backbone_output_channels"],
)
model_cfg["device"] = device

# print results
print(pose_model_path)
print(detector_model_path)
print("Model Config")
print("------------")
config_utils.pretty_print(model_cfg)

# save config
print("------------")
print(f"Saving Config to {model_config_path}")
config_utils.write_config(model_config_path, model_cfg, overwrite=True)

### Loading the dataset

In [None]:
# Create the COCO dataset loader. It is given the model configuration file
# written above (any manual edits to that file must be made before this cell
# runs) and the train/test annotation files of the dataset at `dataset_path`.
loader = COCOLoader(
    model_config_path=model_config_path,
    project_root=dataset_path,
    test_json_filename=test_file,
    train_json_filename=train_file,
)

### Training 

In [None]:
# Training-schedule overrides, applied below through `loader.update_model_cfg`.
# Set a value to None to keep the value from the model configuration.
# (Alternatively, edit pytorch_config.yaml directly - but that must be done
# *before* the COCOLoader is created.)
epochs = 4
save_epochs = 2
detector_epochs = None  # if 0, will not train the detector
detector_save_epochs = None


def _non_null(**values) -> dict:
    """Return only the keyword arguments whose value is not None."""
    return {key: value for key, value in values.items() if value is not None}


pose_train_updates = _non_null(epochs=epochs, save_epochs=save_epochs)
detector_train_updates = _non_null(
    epochs=detector_epochs, save_epochs=detector_save_epochs
)

# Only include sections that actually carry a value. Always sending
# {"detector": {"train_settings": {}}} could add an empty detector section to
# a config that has none.
# NOTE(review): this depends on how update_model_cfg merges missing keys - confirm.
updates = {}
if pose_train_updates:
    updates["train_settings"] = pose_train_updates
if detector_train_updates:
    updates["detector"] = {"train_settings": detector_train_updates}

loader.update_model_cfg(updates)

In [None]:
# Build the pose model from the loader's configuration and wrap it in a
# training runner (adapted from apis/train.py). The runner gets no snapshot:
# rather than restoring the whole 'pose_model_path' checkpoint, only its
# backbone weights are loaded into the model further below.
loader_cfg = loader.model_cfg
pose_task = Task(loader_cfg["method"])
model = PoseModel.build(loader_cfg["model"])
runner = build_training_runner(
    model=model,
    task=pose_task,
    runner_config=loader_cfg["runner"],
    model_folder=loader.model_folder,
    device=device,
    logger=loader_cfg.get("logger"),
    snapshot_path=None,
)

In [None]:
def load_backbone_weights(
    snapshot_path: Path,
    map_location=None,
    prefix: str = "backbone.",
) -> dict:
    """Extract the sub-module (by default: backbone) weights from a snapshot.

    The snapshot is expected to be a dict holding the full model state dict
    under its ``"model"`` key. Only entries whose key starts with ``prefix``
    are kept, and the prefix is stripped from each key, so the result can be
    passed directly to ``model.backbone.load_state_dict``.

    Args:
        snapshot_path: Path to the snapshot file, as accepted by ``torch.load``.
        map_location: Device/location the loaded tensors are mapped to. When
            ``None``, the notebook-level ``device`` is used.
        prefix: Key prefix identifying the sub-module whose weights to keep.

    Returns:
        A state dict containing only the matching weights, prefix removed.
    """
    if map_location is None:
        map_location = device  # notebook-level device from the configuration cell
    snapshot = torch.load(snapshot_path, map_location=map_location)
    state_dict = {
        key[len(prefix):]: value  # strip e.g. 'backbone.' from the key
        for key, value in snapshot["model"].items()
        if key.startswith(prefix)
    }
    print(f"Kept {len(state_dict)} weights")
    return state_dict


# Copy the pretrained SuperAnimal backbone weights into the freshly built
# model; the heads are not loaded from the snapshot.
backbone_state_dict = load_backbone_weights(snapshot_path=pose_model_path)
runner.model.backbone.load_state_dict(backbone_state_dict, strict=True)

In [None]:
# Build the train/validation datasets and dataloaders, then train the model.
data_cfg = loader.model_cfg["data"]
train_settings = loader.model_cfg["train_settings"]

# Augmenting transforms for training, inference transforms for validation
transform = build_transforms(data_cfg["train"])
inf_transform = build_transforms(data_cfg["inference"])

train_dataset = loader.create_dataset(transform=transform, mode="train", task=pose_task)
valid_dataset = loader.create_dataset(transform=inf_transform, mode="test", task=pose_task)

# Optional custom collate function for the training batches
collate_fn = None
collate_fn_cfg = data_cfg["train"].get("collate")
if collate_fn_cfg:
    collate_fn = COLLATE_FUNCTIONS.build(collate_fn_cfg)
    print(f"Using custom collate function: {collate_fn_cfg}")

batch_size = train_settings["batch_size"]
num_workers = train_settings["dataloader_workers"]
pin_memory = train_settings["dataloader_pin_memory"]
worker_kwargs = {"num_workers": num_workers, "pin_memory": pin_memory}

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    **worker_kwargs,
)
# Validation runs one image at a time, in a fixed order
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=1,
    shuffle=False,
    **worker_kwargs,
)

# Train the model
runner.fit(
    train_dataloader,
    valid_dataloader,
    epochs=train_settings["epochs"],
    display_iters=train_settings["display_iters"],
)