
# Environment Setup for SLEAP Evaluation

## Introduction
This notebook is dedicated to evaluating the inference performance of SLEAP in terms of both tracking quality and latency. Below are the steps to set up the necessary environment on Ubuntu 22.04.

## Installation

### Create Conda Environment
To begin, create a Conda environment named `sleap` with the necessary packages. Run the following command in your terminal:

```bash
mamba create -y -n sleap -c conda-forge -c nvidia -c sleap -c anaconda sleap=1.3.3
```

### Additional Libraries
After setting up the Conda environment, install the following Python libraries using pip:

- **tqdm** for progress bars:
  ```bash
  pip install tqdm==4.66.4
  ```
- **tabulate** for table creation and formatting:
  ```bash
  pip install tabulate==0.9.0
  ```
- **motmetrics** for computing metrics for object tracking:
  ```bash
  pip install motmetrics==1.4.0
  ```
- **OpenCV** for image processing tasks:
  ```bash
  pip install opencv-python==4.1.2.30
  ```
- **mmengine** for the registry and configuration loading:
  ```bash
  pip install mmengine-lite
  ```
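
Once the packages are installed, it can be useful to confirm that the pinned versions are the ones actually active in the environment. The following is a minimal sketch using only the standard library; the `check_versions` helper and the `EXPECTED` dictionary are illustrative names and are not part of the evaluation code.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the pip commands above; mmengine-lite is installed unpinned.
EXPECTED = {
    "tqdm": "4.66.4",
    "tabulate": "0.9.0",
    "motmetrics": "1.4.0",
    "opencv-python": "4.1.2.30",
    "mmengine-lite": None,
}


def check_versions(expected):
    """Return {package: (installed_version_or_None, matches_pin)}."""
    report = {}
    for name, pinned in expected.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None
        ok = installed is not None and (pinned is None or installed == pinned)
        report[name] = (installed, ok)
    return report


for name, (installed, ok) in check_versions(EXPECTED).items():
    status = "OK" if ok else "CHECK"
    print(f"{name:15s} installed={installed} [{status}]")
```

Any package flagged `CHECK` is either missing or installed at a version other than the one pinned above.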

## Next Steps
Proceed to the next section once the environment setup is complete to start the evaluation process.



In [8]:
import sleap
import numpy as np
from tqdm import tqdm
import os.path as osp
from time import perf_counter
import sys
import os
import pandas as pd
import re
import abc
from collections import OrderedDict
from typing import Any, List
from tabulate import tabulate

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

from mmengine.config import Config

from utils.sleap133_inference_w_timer import load_model
from utils.downsize_video import downsize_video
from methods.mot import MOTEvaluation

In [9]:
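def display_latency(times: np.ndarray, title, buffer_size=5, precision=3):
    """Print latency statistics, ignoring the first `buffer_size` warm-up measurements."""
    # At least one measurement must remain once the warm-up measurements are discarded.
    assert len(times) > buffer_size
    times = times[buffer_size:]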
    mean = np.mean(times)
    table = [
        ["Mean", mean],
        ["Std", np.std(times)],
        ["Min", np.min(times)],
        ["Median", np.median(times)],
        ["Max", np.max(times)],
    ]
    print(
        "\n"
        + tabulate(
            table,
            headers=[title, "Inference Time (s)"],
            tablefmt="github",
            floatfmt=f".{precision}f",
            stralign="left",
        )
    )
    return mean


def display_mot_results(evaluation: dict, precision=3):
    table = []
    for cls, metrics in evaluation.items():
        table.append([f"MOTA on {cls}", metrics["mota"]])
        table.append([f"IDF1 on {cls}", metrics["idf1"]])
        table.append([f"IDP on {cls}", metrics["idp"]])
        table.append([f"IDR on {cls}", metrics["idr"]])
        table.append([f"Precision on {cls}", metrics["precision"]])
        table.append([f"Recall on {cls}", metrics["recall"]])
        table.append([f"IDFP on {cls}", int(metrics["idfp"])])
        table.append([f"IDFN on {cls}", int(metrics["idfn"])])
        table.append([f"IDTP on {cls}", int(metrics["idtp"])])
        table.append([f"Num Switches on {cls}", int(metrics["num_switches"])])
        table.append([f"Num Detections on {cls}", int(metrics["num_detections"])])
    print(
        "\n"
        + tabulate(
            table,
            headers=["Metric", "Score"],
            tablefmt="github",
            floatfmt=f".{precision}f",
            stralign="left",
        )
    )

In [10]:
def keypoints_cxcywh(keypoints: np.ndarray) -> np.ndarray:
    mask = ~np.isnan(keypoints).any(1)
    if not mask.any():
        return np.array([0.0, 0.0, 0.0, 0.0], dtype=np.float32)
    keypoints = keypoints[mask]
    x = keypoints[:, 0]
    y = keypoints[:, 1]
    xmin, xmax = np.min(x), np.max(x)
    ymin, ymax = np.min(y), np.max(y)
    w, h = xmax - xmin, ymax - ymin
    cx, cy = (xmin + xmax) / 2, (ymin + ymax) / 2
    return np.array([cx, cy, w, h], dtype=np.float32)


def cxcywh_xywh(cxcywh: np.ndarray) -> np.ndarray:
    if np.isnan(cxcywh).any():
        return np.array([0.0, 0.0, 0.0, 0.0], dtype=np.float32)
    cx, cy, w, h = cxcywh
    return np.array([cx - w / 2, cy - h / 2, w, h], dtype=np.float32)


def to_numpy(x):
    if isinstance(x, list):
        return np.array(x)
    elif isinstance(x, (np.ndarray, int, float, str, np.generic)):
        return x
    else:
        raise TypeError(f"{type(x)} not yet supported.")


class BaseCsvOutput(abc.ABC):
    SUPPORTED_PRECISION = {32: "float32", 64: "float64"}
    EXTENSION = ".csv"
    MAPPING_EXTENSION = ".npy"

    def __init__(
        self,
        path: str,
        instance_data: str,
        columns: List[str],
        confidence_threshold: float = 0.5,
        precision: int = 32,
    ) -> None:
        """A BaseCsvOutput iteratively stores the relevant instance data from
        each data sample dict and can save it to a csv file afterward.

        Args:
            path (str): The path used either to save the content of a BaseCsvOutput to a csv or to read a csv into a BaseCsvOutput.
            instance_data (str): The name of the relevant instance data inside the dict.
            columns (List[str]): The names of the csv's columns.
            confidence_threshold (float, optional): The minimum confidence score for data to be retained. Defaults to 0.5.
            precision (int, optional): The floating-point precision, in bits (32 or 64), used when saving the data. Defaults to 32.
        """
        self.frame_id_mapping = OrderedDict()
        self.results = []
        self.curr_frame_idx = 0
        if precision not in self.SUPPORTED_PRECISION:
            raise ValueError(f"Precision {precision} not supported. Supported precisions are {list(self.SUPPORTED_PRECISION.keys())}")
        self.precision = precision
        self.columns = columns
        self.confidence_threshold = confidence_threshold
        self.supported_instance_data = ["pred_track_instances", "pred_instances"]
        self.instance_data = instance_data
        self._setup_path(path)

    def _setup_path(self, path: str):
        raw_path, _ = os.path.splitext(path)
        path = f"{raw_path}{self.EXTENSION}"
        self.path = os.path.abspath(path)
        self.mapping_path = os.path.abspath(f"{raw_path}_mapping{self.MAPPING_EXTENSION}")
        os.makedirs(os.path.dirname(self.path), exist_ok=True)

    @abc.abstractmethod
    def __call__(self, data: Any) -> None:
        pass

    def __len__(self):
        # Note: returns the id of the last recorded frame, not the number of stored rows.
        return next(reversed(self.frame_id_mapping))

    def __getitem__(self, frame_id: int) -> List[Any]:
        idx_range = self.frame_id_mapping.get(frame_id, None)
        if idx_range is None:
            return
        return self.results[idx_range[0] : idx_range[1]]

    def save(self) -> None:
        """Save the data to csv and also a mapping of the data (for faster
        __getitem__).

        Saves only one frame_id_mapping
        """
        formatted_results = pd.DataFrame(self.results, columns=["frame_id", "class_id", "instance_id"] + self.columns)
        formatted_results["frame_id"] = formatted_results["frame_id"].astype("uint32")
        formatted_results["class_id"] = formatted_results["class_id"].astype("uint16")
        formatted_results["instance_id"] = formatted_results["instance_id"].astype("int16")
        for col in self.columns:
            formatted_results[col] = formatted_results[col].astype(self.SUPPORTED_PRECISION[self.precision])
        formatted_results.to_csv(self.path, index=False)
        np.save(self.mapping_path, self.frame_id_mapping, allow_pickle=True)

    def read(self) -> None:
        """Load a csv and a mapping (for faster __getitem__)"""
        assert os.path.exists(self.path), f"{self.path} does not exist."
        assert os.path.exists(self.mapping_path), (
            f"To ensure fast iteration through the {os.path.basename(self.path)} file, "
            "you should keep its provided mapping, which is expected to be at "
            f"{self.mapping_path} but does not exist."
        )
        self.results = pd.read_csv(self.path).values.tolist()
        self.frame_id_mapping = np.load(self.mapping_path, allow_pickle=True).item()
        # Row ranges are half-open (start, end), so the next free row index is the last range's end.
        self.curr_frame_idx = self.frame_id_mapping[self.__len__()][1]

    def _add_row(self, *args) -> None:
        self.results.append(list(args))

    def _update_frame_id_mapping(self, frame_id: int, increment: int):
        if frame_id not in self.frame_id_mapping:
            curr_frame_idx = self.curr_frame_idx + increment
            self.frame_id_mapping[frame_id] = (
                self.curr_frame_idx,
                curr_frame_idx,
            )
            self.curr_frame_idx = curr_frame_idx

    def _set_ids(self, instance_data: dict):
        return (
            np.zeros_like(instance_data["labels"]) - 1
            if self.instance_data
            not in [
                "pred_track_instances",
                "validation_instances",
                "correction_instances",
            ]
            else instance_data["instances_id"]
        )

    def _get_ds_info(self, data_sample: dict):
        instance_data = data_sample.get(self.instance_data, None)
        if instance_data is None:
            raise ValueError(f"The provided data sample does not contain the expected instance data ({self.instance_data}).")
        return instance_data, data_sample["img_id"]


class CsvBoundingBoxes(BaseCsvOutput):

    def __init__(
        self,
        path: str,
        bbox_format: str = "xywh",
        instance_data: str = "pred_instances",
        confidence_threshold: float = 0.5,
        precision: int = 32,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(
            path=path,
            precision=precision,
            confidence_threshold=confidence_threshold,
            columns=["x", "y", "w", "h", "scores"],
            instance_data=instance_data,
        )
        self.bbox_format = bbox_format
        assert self.instance_data in self.supported_instance_data, f"The provided instance_data must be one of {self.supported_instance_data}"

    def __call__(self, det_data_sample: dict):
        instance_data, frame_id = self._get_ds_info(det_data_sample)
        ids = self._set_ids(instance_data)
        i = 0
        for id_, label, bbox, score in zip(
            ids,
            instance_data["labels"],
            instance_data["bboxes"],
            instance_data["scores"],
        ):
            if (score >= self.confidence_threshold and self.instance_data == "pred_instances") or (id_ >= 0 and self.instance_data == "pred_track_instances"):
                self._add_row(frame_id, label, id_, *bbox, score)
                i += 1
        self._update_frame_id_mapping(frame_id, i)


class CsvKeypoints(BaseCsvOutput):

    def __init__(
        self,
        path: str,
        instance_data: str = "pred_instances",
        confidence_threshold: float = 0.5,
        precision: int = 32,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(
            path,
            precision=precision,
            confidence_threshold=confidence_threshold,
            columns=[],
            instance_data=instance_data,
        )
        assert self.instance_data in self.supported_instance_data, f"The provided instance_data must be one of {self.supported_instance_data}"

    def __call__(self, det_data_sample):
        instance_data, frame_id = self._get_ds_info(det_data_sample)
        ids = self._set_ids(instance_data)
        i = 0
        for id_, label, keypoints, scores, score in zip(
            ids,
            instance_data["labels"],
            instance_data["keypoints"],
            instance_data["keypoint_scores"],
            instance_data["scores"],
        ):
            label = to_numpy(label)
            keypoints = to_numpy(keypoints)
            keypoint_scores = to_numpy(scores)
            score = to_numpy(score)
            if (score >= self.confidence_threshold and self.instance_data in ["pred_instances", "gt_instances"]) or (
                id_ >= 0 and self.instance_data == "pred_track_instances"
            ):
                poses = np.concatenate((keypoints.reshape(1, -1, 2), keypoint_scores.reshape(1, -1, 1)), axis=2)
                poses = np.nan_to_num(poses, nan=0.0).flatten().tolist()
                self._add_row(frame_id, label, id_, poses)
                i += 1
        self._update_frame_id_mapping(frame_id, i)

    def _add_row(self, frame_id, class_id, object_id, keypoints):
        self._set_columns(frame_id, keypoints)
        super()._add_row(frame_id, class_id, object_id, *keypoints)

    def _set_columns(self, frame_id: int, keypoints: list):
        if not self.columns:
            # Each keypoint contributes an (x, y, score) triple to the flattened row.
            self.columns = [f"{name}{i}" for i in range(len(keypoints) // 3) for name in ("x", "y", "score")]
        else:
            assert len(keypoints) == len(self.columns), f"Inconsistent number of keypoints: {len(keypoints)}, expected: {len(self.columns)} at frame {frame_id}"


class SLEAPOutput:
    def __init__(self, save_path_bboxes: str, save_path_kpts: str, precision: int = 32):
        self.outputs = [
            CsvBoundingBoxes(save_path_bboxes, precision=precision, instance_data="pred_track_instances"),
            CsvKeypoints(save_path_kpts, precision=precision, instance_data="pred_track_instances"),
        ]

    def __call__(self, labeled_frame, x_scale=1, y_scale=1):
        frame_idx = labeled_frame.frame_idx
        data_sample = {
            "pred_track_instances": {
                "labels": [],
                "bboxes": [],
                "instances_id": [],
                "keypoints": [],
                "keypoint_scores": [],
                "scores": [],
            },
            "img_id": frame_idx,
        }
        for inst in labeled_frame.instances:
            inst_id = int(re.search(r"(\d+)", inst.track.name).group(0))
            inst_kpts = inst.get_points_array(invisible_as_nan=False)
            inst_kpts[:, 0] = inst_kpts[:, 0] * x_scale
            inst_kpts[:, 1] = inst_kpts[:, 1] * y_scale
            inst_bbox = keypoints_cxcywh(inst_kpts)
            inst_bbox = cxcywh_xywh(inst_bbox)
            class_id = 0
            if not np.all(np.isnan(inst_bbox)) and not np.any(inst_bbox == 0):
                data_sample["pred_track_instances"]["labels"].append(class_id)
                data_sample["pred_track_instances"]["instances_id"].append(inst_id)
                data_sample["pred_track_instances"]["bboxes"].append(inst_bbox)
                data_sample["pred_track_instances"]["scores"].append(1)
                inst_kpts = np.nan_to_num(np.concatenate(inst_kpts), nan=0.0).astype(int)
                data_sample["pred_track_instances"]["keypoints"].append(inst_kpts.tolist())
                data_sample["pred_track_instances"]["keypoint_scores"].append(np.ones(inst_kpts.shape[0] // 2).tolist())

        for o in self.outputs:
            o(data_sample)

    def save(self):
        for o in self.outputs:
            o.save()
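
To illustrate how a pose becomes the bounding box written to the CSV files, here is a condensed, standalone sketch of the keypoints-to-`xywh` conversion that `keypoints_cxcywh` and `cxcywh_xywh` perform together. It is written in plain Python for brevity; the function name `keypoints_to_xywh` and the sample coordinates are made up for the example.

```python
import math


def keypoints_to_xywh(keypoints):
    """Tight box (top-left x, top-left y, width, height) around the visible keypoints."""
    # Keypoints with a NaN coordinate are treated as not detected and ignored.
    visible = [(x, y) for x, y in keypoints if not (math.isnan(x) or math.isnan(y))]
    if not visible:
        return (0.0, 0.0, 0.0, 0.0)
    xs = [x for x, _ in visible]
    ys = [y for _, y in visible]
    xmin, ymin = min(xs), min(ys)
    return (xmin, ymin, max(xs) - xmin, max(ys) - ymin)


# Hypothetical pose with three keypoints, one of them not detected (NaN).
pose = [(10.0, 20.0), (30.0, 60.0), (math.nan, math.nan)]
box = keypoints_to_xywh(pose)
print(box)  # (10.0, 20.0, 20.0, 40.0)
```

In `SLEAPOutput`, any box containing a zero component is skipped, which filters out the all-zero box produced when no keypoint is visible.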

### Load the SLEAP config

Make sure you have a valid SLEAP checkpoint trained on MICE. For more information, please visit SLEAP's documentation: https://sleap.ai/tutorials/initial-training.html

In [16]:
cfg_path = "../../configs/sleap/evaluation.py"
cfg = Config.fromfile(cfg_path)

## Latency Comparison Standards

Since SLEAP does not inherently adjust the image resolution during tracking, the only way to ensure a fair comparison is to resize the recording beforehand, as the MICE dataset contains images with a resolution of 640x640 pixels.

For this reason, we need to format the video so that its resolution matches the one used for the PrecisionTrack and DLC evaluations.
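
Because inference runs on the resized video, predicted coordinates have to be mapped back to the original resolution before they are compared with the ground truth. Below is a minimal sketch of that mapping, assuming sizes are ordered `(height, width)`; the helper names `scale_factors` and `rescale_point` are illustrative, and the 1536x1536 source size is the one reported by this notebook's resizing step.

```python
def scale_factors(ori_hw, new_hw):
    """Return (x_scale, y_scale) mapping resized-video pixels to original pixels."""
    ori_h, ori_w = ori_hw
    new_h, new_w = new_hw
    return ori_w / new_w, ori_h / new_h


def rescale_point(x, y, x_scale, y_scale):
    """Map a point predicted on the resized video back to the original frame."""
    return x * x_scale, y * y_scale


x_scale, y_scale = scale_factors((1536, 1536), (640, 640))
print(x_scale, y_scale)  # 2.4 2.4
print(rescale_point(320, 160, x_scale, y_scale))
```

For a square video the two factors are equal, but keeping them separate avoids silently swapping axes when the aspect ratio changes.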


In [17]:
ori_video = sleap.load_video(cfg.video_path)
ori_width = ori_video.backend.width
ori_height = ori_video.backend.height

if ori_height != cfg.img_size[0] or ori_width != cfg.img_size[1]:
    print(
        f"Rescaling video from {ori_height}x{ori_width} to {cfg.img_size[0]}x{cfg.img_size[1]}"
    )
    split = osp.splitext(cfg.video_path)
    rescaled_video_path = f"{split[0]}_{cfg.img_size[0]}x{cfg.img_size[1]}{split[1]}"
    video_path = downsize_video(
        cfg.video_path,
        rescaled_video_path,
        height=cfg.img_size[0],
        width=cfg.img_size[1],
    )
    video = sleap.load_video(rescaled_video_path)
    # SLEAPOutput expects (x_scale, y_scale); cfg.img_size is ordered (height, width).
    scale = (ori_width / cfg.img_size[1], ori_height / cfg.img_size[0])
else:
    video = ori_video
    scale = (1, 1)
video

Rescaling video from 1536x1536 to 640x640


Video(backend=MediaVideo(filename='../../assets/20mice_640x640.avi', grayscale=False, bgr=True, dataset='', input_format='channels_last'))

## Speed Benchmarking

SLEAP splits its inference process into two parts:
1. Running the neural network model to obtain the poses on all the frames in the video.
2. Tracking instances over the video by assigning poses over all the frames iteratively.

We report the latency of each step as well as their cumulative latency. The cumulative figure is the one you will actually experience in practice.
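
The cumulative figure is the sum of the two per-frame mean delays, and its inverse is the throughput in frames per second. The sketch below illustrates that arithmetic, assuming the first few prediction timings are discarded as warm-up (as `display_latency` does with `buffer_size`); the helper names and the timing values are invented for illustration.

```python
from statistics import mean


def mean_after_warmup(times, warmup=5):
    """Mean of the timings once the first `warmup` measurements are dropped."""
    if len(times) <= warmup:
        raise ValueError("need more timings than warm-up measurements")
    return mean(times[warmup:])


def throughput_fps(mean_pred_delay, mean_track_delay):
    """Frames per second implied by the summed per-frame delays (in seconds)."""
    return 1.0 / (mean_pred_delay + mean_track_delay)


# Invented timings: the first five predictions include model warm-up.
pred_times = [3.0, 0.5, 0.1, 0.05, 0.03] + [0.02] * 10
track_times = [0.005] * 15

pred = mean_after_warmup(pred_times)
track = mean(track_times)
print(f"{pred + track:.3f} s/frame -> {throughput_fps(pred, track):.1f} FPS")
# 0.025 s/frame -> 40.0 FPS
```

Applied to the total mean delay of 0.10159 s printed by the benchmark in this notebook, the same formula gives 1 / 0.10159 ≈ 9.84 FPS.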



In [19]:
tracker = sleap.nn.tracking.Tracker.make_tracker_by_name(
    tracker="simplemaxtracks",
    track_window=5,
    similarity="iou",
    match="hungarian",
    min_new_track_points=1,
    min_match_points=1,
    target_instance_count=cfg.max_tracks,
    pre_cull_to_target=True,
    pre_cull_iou_threshold=0.8,
    post_connect_single_breaks=True,
    clean_instance_count=0,
    clean_iou_threshold=None,
    max_tracking=True,
    max_tracks=cfg.max_tracks,
    candidate_maker="SimpleMaxTracksCandidateMaker",
)
predictor = load_model(cfg.load_from, batch_size=30)


predictions, predictions_delay = predictor.predict(video)

mean_pred_delay = display_latency(predictions_delay, "SLEAP: Prediction step")
results = SLEAPOutput(save_path_bboxes=cfg.save_path_bboxes, save_path_kpts=cfg.save_path_kpts, precision=32)

lfs = []
times = []
for i, lf in tqdm(enumerate(predictions)):
    t0 = perf_counter()
    lf.instances = tracker.track(lf.instances, img=lf.image, t=i, img_hw=lf.image.shape[:2])
    lfs.append(lf)
    times.append(perf_counter() - t0)

times = np.array(times) / 2  # Simulates two threads perfectly running concurrently
mean_track_delay = np.mean(times)

t1 = perf_counter()
tracker.final_pass(lfs)
mean_track_delay += perf_counter() - t1

for i, lf in tqdm(enumerate(lfs)):
    results(lf, *scale)
results.save()

print(f"The total mean delay is {mean_pred_delay + mean_track_delay:.5f} seconds which gives us a total latency of {1/(mean_pred_delay+mean_track_delay):.2f} FPS")



| SLEAP: Prediction step   |   Inference Time (s) |
|--------------------------|----------------------|
| Mean                     |                0.085 |
| Std                      |                0.442 |
| Min                      |                0.014 |
| Median                   |                0.018 |
| Max                      |                3.016 |


1471it [00:41, 35.72it/s]
1471it [00:06, 210.74it/s]


The total mean delay is 0.10159 seconds which gives us a total latency of 9.84 FPS


## Quantitative Evaluation

Here are the metrics regarding the quality of the tracking relative to the ground truth. Check the paper for more details.


In [21]:
results = pd.read_csv(cfg.save_path_bboxes)
results.drop(columns=["class_id"], inplace=True)
results = results.values
gt = pd.read_csv(cfg.gt_path).values
evaluator = MOTEvaluation(classes=["mouse"])
unique_frames = np.unique(gt[:, 0])

for frame in unique_frames:
    frame_gt = gt[gt[:, 0] == frame]
    frame_gt = {"mouse": frame_gt[:, 2:-1]}
    frame_results = results[results[:, 0] == frame]
    frame_results = {"mouse": frame_results[:, 1:-1]}
    evaluator.update(frame_results, frame_gt)

ev = evaluator.evaluate()
display_mot_results(ev)

INFO:root:partials: 0.054 seconds.
INFO:root:mergeOverall: 0.055 seconds.

| Metric                  |     Score |
|-------------------------|-----------|
| MOTA on mouse           |     0.725 |
| IDF1 on mouse           |     0.717 |
| IDP on mouse            |     0.722 |
| IDR on mouse            |     0.711 |
| Precision on mouse      |     0.868 |
| Recall on mouse         |     0.856 |
| IDFP on mouse           |  8008.000 |
| IDFN on mouse           |  8436.000 |
| IDTP on mouse           | 20780.000 |
| Num Switches on mouse   |    23.000 |
| Num Detections on mouse | 25000.000 |
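
The identity metrics in this table can be cross-checked from the reported counts using the standard definitions IDP = IDTP / (IDTP + IDFP), IDR = IDTP / (IDTP + IDFN), and IDF1 = 2·IDTP / (2·IDTP + IDFP + IDFN). The snippet below is only a sanity check on the numbers above; it does not reproduce how `MOTEvaluation` computes them internally, and the `identity_metrics` helper is a name chosen for this example.

```python
def identity_metrics(idtp, idfp, idfn):
    """Identity precision, recall, and F1 from the identity-level counts."""
    idp = idtp / (idtp + idfp)
    idr = idtp / (idtp + idfn)
    idf1 = 2 * idtp / (2 * idtp + idfp + idfn)
    return {"idp": idp, "idr": idr, "idf1": idf1}


# Counts taken from the results table above.
scores = identity_metrics(idtp=20780, idfp=8008, idfn=8436)
for name, value in scores.items():
    print(f"{name.upper()}: {value:.3f}")
# IDP: 0.722
# IDR: 0.711
# IDF1: 0.717
```

The recomputed values agree with the IDP, IDR, and IDF1 rows of the table to three decimal places.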


## Qualitative evaluation of the tracking results

Your PrecisionTrack-formatted outputs are available at `cfg.save_dir_mot`. You can use PrecisionTrack's visualization tool to inspect them.