# Welcome to the PrecisionTrack Training and Testing Notebook.

![PrecisionTrack logo](https://raw.githubusercontent.com/VincentCoulombe/precision_track/main/assets/logo.png)

### In this notebook, you will:
- Train a PrecisionTracker on the MICE dataset.  
- Test your newly trained PrecisionTracker.
- Visualize pose-tracking results on unseen footage.

**Note:**  
To deploy, track and visualize your PrecisionTrack checkpoints, please refer to our [tutorials](https://github.com/VincentCoulombe/precision_track/tree/main) and [configuration documentation](https://github.com/VincentCoulombe/precision_track/tree/main/configs).


### Before you begin
Ensure your Colab runtime is connected to a GPU:

1. Click on **Runtime** in the menu bar.  
2. Select **Change runtime type**.  
3. Set the interpreter to **Python 3**.  
4. Under **Hardware accelerator**, select **GPU**. 

In [None]:
import subprocess

# First, determine whether the machine is CUDA-accelerated (this decides which version of the environment we are going to build).
try:
    CUDA_ACCELERATED = subprocess.run("nvidia-smi").returncode == 0
except FileNotFoundError:
    CUDA_ACCELERATED = False

# Warn whether nvidia-smi is missing or present but failing.
if not CUDA_ACCELERATED:
    print("Please follow the instructions in the cell above to CUDA-accelerate your instance.")
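A quieter variant of this check, offered as a sketch: it looks up `nvidia-smi` with `shutil.which` before running it, captures the command's output so the cell does not print the full GPU table, and treats any non-zero exit code as "no GPU". The helper name `has_nvidia_gpu` is hypothetical and not part of PrecisionTrack.

```python
import shutil
import subprocess


def has_nvidia_gpu() -> bool:
    """Return True if `nvidia-smi` is on PATH and exits successfully (hypothetical helper)."""
    # shutil.which returns None when the binary is missing, so no FileNotFoundError is raised.
    if shutil.which("nvidia-smi") is None:
        return False
    # Capture stdout/stderr so the GPU table is not printed in the notebook.
    result = subprocess.run(["nvidia-smi"], capture_output=True, text=True)
    return result.returncode == 0


CUDA_ACCELERATED = has_nvidia_gpu()
if not CUDA_ACCELERATED:
    print("No GPU detected. Please follow the 'Before you begin' instructions to CUDA-accelerate your instance.")
```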

In [None]:
%%capture
# Second, install PrecisionTrack and its dependencies.
!git clone https://github.com/VincentCoulombe/precision_track.git
%cd /content/precision_track/

if CUDA_ACCELERATED:
    !pip install torch==2.7.1 torchvision==0.22.1 --index-url https://download.pytorch.org/whl/cu128
    !pip install -e ".[cuda]"
else:
    !pip install torch==2.7.1 torchvision==0.22.1 --index-url https://download.pytorch.org/whl/cpu
    !pip install -e ".[cpu]"

!pip install gdown
# NOTE: This step took about 12 minutes when I tried it, because PyTorch and PyCUDA are both heavy dependencies. Since %%capture hides the installation output, the cell shows no progress while it runs.
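Because `%%capture` hides pip's output, a failed installation can go unnoticed until a later cell crashes. A minimal sketch that checks whether key modules can be found in the current environment; the module list is an assumption drawn from the install commands and imports in this notebook, and `missing_packages` is a hypothetical helper.

```python
import importlib.util


def missing_packages(module_names):
    """Return the module names that cannot be found in the current environment."""
    # find_spec returns None for a top-level module that is not installed (it does not import it).
    return [name for name in module_names if importlib.util.find_spec(name) is None]


# Assumed module names, taken from the pip commands and imports used in this notebook.
missing = missing_packages(["torch", "torchvision", "mmengine", "gdown"])
if missing:
    print("Missing packages (re-run the install cell without %%capture to see the errors):", missing)
else:
    print("All checked packages can be found.")
```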

In [None]:
# This cell contains a custom helper function for displaying videos in the browser.
from IPython.display import HTML
from base64 import b64encode
import subprocess


def to_h264(video_path: str) -> str:
    """Re-encode a video to browser-friendly H.264 (written to ./visualization_h264.mp4) and return it as a base64 data URL."""
    subprocess.run(
        [
            "ffmpeg",
            "-y",
            "-i",
            video_path,
            "-c:v",
            "libx264",
            "-profile:v",
            "baseline",
            "-level",
            "3.0",
            "-pix_fmt",
            "yuv420p",
            "-movflags",
            "+faststart",
            "./visualization_h264.mp4",
        ],
        check=True,
    )

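    # Embed the re-encoded video as a data URL for the HTML <video> tag; `with` closes the file.
    with open("./visualization_h264.mp4", "rb") as f:
        mp4 = f.read()
    return "data:video/mp4;base64," + b64encode(mp4).decode()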
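`to_h264` finishes by inlining the re-encoded file as a base64 `data:` URL, so the browser can play it without a separate file server. That encoding step on its own can be sketched as follows; `file_to_data_url` is a hypothetical helper, and using `Path.read_bytes` closes the file handle automatically.

```python
from base64 import b64encode
from pathlib import Path


def file_to_data_url(path: str, mime: str = "video/mp4") -> str:
    """Read a file and return it as a base64 data URL (hypothetical helper)."""
    raw = Path(path).read_bytes()
    # base64 output is pure ASCII, so decoding with "ascii" is safe.
    return f"data:{mime};base64," + b64encode(raw).decode("ascii")
```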

### Dataset preparation

We will **train and test on the MICE dataset**.  
Before proceeding, we need to **download the dataset** and the **Animal Pose checkpoint** used for transfer learning.

#### NOTE:
If you would like to create and label your own dataset instead, please refer to our  
[training workflow guide](https://github.com/VincentCoulombe/precision_track/tree/main).


In [None]:
!mkdir -p /content/precision_track/checkpoints/
%cd /content/precision_track/checkpoints/
!gdown 1OEsczExkQ38pyyrEmH9To9ZTmGPRyaBj # The Animal Pose checkpoint (for transfer learning)

# -p also creates the missing parent directory /content/datasets/.
!mkdir -p /content/datasets/MICE/
%cd /content/datasets/MICE/
!gdown 1swPZZxqF6L0wbEDY_EE-esG9GZNt42mU # The MICE dataset
!unzip /content/datasets/MICE/pose-estimation.zip
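Before moving on, it can help to confirm that the downloads and extraction produced the files later cells rely on. A minimal sketch; the two paths are copied from later cells of this notebook (the checkpoint name `model_ap.pth` is the one the settings cell expects), and `missing_paths` is a hypothetical helper.

```python
from pathlib import Path


def missing_paths(paths):
    """Return the paths that do not exist on disk (hypothetical helper)."""
    return [p for p in paths if not Path(p).exists()]


# Paths referenced later in this notebook.
expected = [
    "/content/precision_track/checkpoints/model_ap.pth",
    "/content/datasets/MICE/pose-estimation/benchmark/data/20mice.avi",
]
missing = missing_paths(expected)
print("Missing files:", missing if missing else "none")
```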

### OPTIONAL
PrecisionTrack offers the ability to log your training runs using the [Weights & Biases (W&B)](https://wandb.ai/) MLOps tool.  

Enabling this functionality allows PrecisionTrack to:

- **Track experiments**: Log training hyperparameters, metrics, and outputs.  
- **Visualize results**: Generate interactive plots and dashboards, useful for comparing training runs.  
- **Collaborate**: Share runs, reports, and insights with teammates.  
- **Manage models**: Version control for datasets, models, and pipelines.

To enable this functionality, please refer to our [Weights & Biases guide](https://github.com/VincentCoulombe/precision_track/tree/main/configs/wandb). You will then be able to set `WANDB_ENABLED = True` in the cell below.

In [None]:
# Before training, update the settings to load the downloaded Animal Pose checkpoint.
# NOTE: This step is done programmatically here, but I encourage users to do it manually instead (it's far more intuitive). Please refer to our settings and workflow guides for more details.

from pathlib import Path
import os
from mmengine import Config

SETTINGS_PATH = "/content/precision_track/configs/settings/mice.py"
settings = Config.fromfile(SETTINGS_PATH)

# Read the current values.
actual_training_ckpt_path = settings["training_checkpoint"]

# Point the training checkpoint to the downloaded Animal Pose weights.
new_training_ckpt_path = os.path.join("/content/precision_track/checkpoints/", "model_ap.pth")

# OPTIONAL: If you followed our Weights & Biases guide, have a functional wandb.py file, and want to log your training run results, change the following to 'True'.
WANDB_ENABLED = False

# OPTIONAL: Reduce the batch size so training does not overflow the Colab GPU's VRAM.
NEW_BATCH_SIZE = 24

# Rewrite the settings file in place (str.replace silently does nothing if a pattern is not found).
file = Path(os.path.abspath(SETTINGS_PATH))
text = file.read_text()
text = text.replace(actual_training_ckpt_path, new_training_ckpt_path)
text = text.replace("wandb_logging = False", f"wandb_logging = {WANDB_ENABLED}")
text = text.replace("batch_size = 38", f"batch_size = {NEW_BATCH_SIZE}")
file.write_text(text)
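`str.replace` returns the text unchanged when a pattern is absent, so a settings file whose defaults differ (for example a batch size other than 38) is left untouched without any warning; it also replaces every occurrence, so a checkpoint path that appears twice would be rewritten twice. A stricter sketch, assuming you prefer the cell to fail loudly; `replace_once` is a hypothetical helper, not part of PrecisionTrack or MMEngine.

```python
def replace_once(text: str, old: str, new: str) -> str:
    """Replace exactly one occurrence of `old`, raising if it is missing or ambiguous (hypothetical helper)."""
    count = text.count(old)
    if count == 0:
        raise ValueError(f"Pattern not found in settings: {old!r}")
    if count > 1:
        raise ValueError(f"Pattern is ambiguous ({count} matches): {old!r}")
    return text.replace(old, new)


# Example on an in-memory stand-in for the settings file (not the real mice.py).
sample = "batch_size = 38\nwandb_logging = False\n"
sample = replace_once(sample, "batch_size = 38", "batch_size = 24")
```

One consequence of this design: re-running the cell after a successful edit raises, because the original value is gone. That is usually the signal you want, but it means the cell is no longer idempotent.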

In [None]:
# First, train the network on the MICE dataset.
# NOTE: This step takes a few hours; the process finishes after 300 epochs.
# NOTE: If your instance is not CUDA-accelerated, skip this step: it would take far too long to complete.
# NOTE: Our training engine is very verbose. This level of detail is intentional and helpful for debugging, though it may appear overwhelming at first. Do not worry: unless an exception is explicitly raised, the process is functioning as expected.
# NOTE: Earlier cells changed the working directory with %cd, so make sure this cell runs from the directory that contains train.py.
!python train.py
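The next cell points the testing checkpoint at `epoch_300.pth`, which assumes that training ran all 300 epochs and that checkpoints are saved as `epoch_N.pth` in `training_work_dir`. If a run was interrupted, a sketch like the hypothetical `latest_epoch_checkpoint` below finds the highest-numbered checkpoint instead. Which epochs are actually kept on disk depends on how the checkpoint hook is configured, so treat this as an assumption to verify against your work directory.

```python
import re
from pathlib import Path
from typing import Optional


def latest_epoch_checkpoint(work_dir: str) -> Optional[str]:
    """Return the highest-numbered epoch_N.pth in work_dir, or None if there is none (hypothetical helper)."""
    directory = Path(work_dir)
    if not directory.is_dir():
        return None
    pattern = re.compile(r"^epoch_(\d+)\.pth$")
    best_epoch, best_path = -1, None
    for path in directory.glob("epoch_*.pth"):
        match = pattern.match(path.name)
        # Compare epochs numerically so epoch_100 ranks above epoch_99.
        if match and int(match.group(1)) > best_epoch:
            best_epoch, best_path = int(match.group(1)), str(path)
    return best_path
```

Its return value could then stand in for `os.path.join(actual_training_work_dir, "epoch_300.pth")`.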

In [None]:
# Second, change the settings to load the newly trained checkpoint.

SETTINGS_PATH = "/content/precision_track/configs/settings/mice.py"
settings = Config.fromfile(SETTINGS_PATH)

# Read the current values.
actual_testing_ckpt_path = settings["testing_checkpoint"]
actual_training_work_dir = settings["training_work_dir"]

# Point the testing checkpoint to the final checkpoint of the training run (epoch 300).
new_testing_ckpt_path = os.path.join(actual_training_work_dir, "epoch_300.pth")

file = Path(os.path.abspath(SETTINGS_PATH))
text = file.read_text()
text = text.replace(actual_testing_ckpt_path, new_testing_ckpt_path)
file.write_text(text)

# Third, test the system's detection and pose-estimation capabilities.
# NOTE: You should obtain results similar to those reported in Figure 2e.
# NOTE: Since training is stochastic, you will most likely not reproduce the exact results reported in the article, but your results should be comparable.
!python test.py ../configs/tasks/testing_detection.py

In [None]:
# Fourth, test the system's tracking capabilities.
# NOTE: You should obtain results similar to those reported in Figure 5d.
# NOTE: Since training is stochastic, you will most likely not reproduce the exact results reported in the article, but your results should be comparable.
!python test.py ../configs/tasks/testing_tracking.py

### Tracking using your newly trained PrecisionTracker

In [None]:
# Ensure the tracking configuration points to our current model and disables action recognition (possibly redundant, but it prevents future changes from breaking this cell).
file = Path(os.path.abspath("/content/precision_track/configs/tasks/tracking.py"))
text = file.read_text()
text = text.replace("../models/yolox-pose.py", "../models/rtmdet-pose.py")
text = text.replace("with_action_recognition = True", "with_action_recognition = False")
file.write_text(text)

# Fifth, track using our current checkpoint.
# NOTE: The logger prints the path of the checkpoint in use, confirming that you are tracking with the newly trained one.
!python track.py /content/datasets/MICE/pose-estimation/benchmark/data/20mice.avi

In [None]:
# Sixth, generate and display the visualization.
!python ./visualize.py /content/datasets/MICE/pose-estimation/benchmark/data/20mice.avi /content/datasets/MICE/pose-estimation/benchmark/data/20mice_visualization.mp4

# Display.
data_url = to_h264("/content/datasets/MICE/pose-estimation/benchmark/data/20mice_visualization.mp4")

HTML(
    f"""
<video width=800 controls>
    <source src="{data_url}" type="video/mp4">
</video>
"""
)
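Because the visualization is inlined as base64, the notebook stores roughly 4/3 of the re-encoded file's size in this output cell, which can make long videos slow to save or share. A sketch that separates the `<video>` markup from a size estimate, so you can decide whether to inline a file at all; `video_html` and `base64_size` are hypothetical helpers.

```python
import html


def video_html(data_url: str, width: int = 800) -> str:
    """Build the <video> markup used to display an inline data URL (hypothetical helper)."""
    # Defensive escaping; a base64 data URL contains no characters that actually need it.
    src = html.escape(data_url, quote=True)
    return (
        f'<video width="{width}" controls>\n'
        f'    <source src="{src}" type="video/mp4">\n'
        f"</video>"
    )


def base64_size(num_bytes: int) -> int:
    """Number of characters base64 produces for num_bytes bytes: 4 per (padded) 3-byte group."""
    return 4 * ((num_bytes + 2) // 3)
```

For example, passing `os.path.getsize("./visualization_h264.mp4")` to `base64_size` estimates how many characters the cell output will hold.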