
# 🐾 Wildlife Speed Tracking

Welcome! This notebook helps you **detect, classify, and estimate 2D speeds** of animals in videos using **PyTorchWildlife**.  

> **What you’ll get**: Annotated videos saved to an output folder, plus a `speed.csv` summarizing speeds per tracked object.


## ✅ Requirements

- Python 3.9+ recommended
- GPU optional (CUDA speeds things up, but CPU works too)
- Videos placed in a folder (default: `./demo_data/speed_tracking_videos`)

### 📦 Install dependencies

> If you already have the packages, you can **skip** this cell. Otherwise, uncomment the lines you need and run them (one at a time if you hit errors).
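Before installing anything, a quick check can show which modules are already importable. This is a minimal sketch: `REQUIRED_MODULES` and `missing_modules` are names introduced here, and the list assumes the import names used in the Imports cell below.

```python
import importlib.util
import sys
from typing import Iterable, List

# Top-level import names used later in this notebook (assumed to match the pip packages below).
REQUIRED_MODULES = ["torch", "supervision", "ipywidgets", "tqdm",
                    "pandas", "matplotlib", "PytorchWildlife"]

def missing_modules(names: Iterable[str]) -> List[str]:
    """Return the top-level module names that cannot be found on the current path."""
    return [name for name in names if importlib.util.find_spec(name) is None]

if sys.version_info < (3, 9):
    print("Warning: Python 3.9+ is recommended.")
print("Missing:", missing_modules(REQUIRED_MODULES) or "none")
```

Only the packages reported as missing need to be uncommented in the install cell.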

In [None]:

# %pip install --upgrade pip
# %pip install supervision
# %pip install ipywidgets tqdm pandas matplotlib
# %pip install PytorchWildlife
# %pip install torch torchvision torchaudio


## 📥 Imports
This cell imports everything we need and prints the Torch version and whether CUDA is available.


In [1]:
import os
import sys
from pathlib import Path
from typing import Tuple, List

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm
import torch
import supervision as sv

from PytorchWildlife.models import detection as pw_detection
from PytorchWildlife.models import classification as pw_classification
from PytorchWildlife import utils as pw_utils

print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

Torch version: 2.5.1+cu124
CUDA available: True



## 🎛️ Configure Inputs

- **Height (optional):** Use the toggle below to set `animal_height_m` (in meters) and convert speeds from **px/s → m/s** with `m/s = (px/s * animal_height_m) / image_width_px`, i.e. a scale of `animal_height_m / image_width_px` meters per pixel.  
  Leave it unset to keep speeds in **px/s**.
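As a worked example of the conversion, the helper below applies the formula above. The function name is hypothetical (not part of PyTorchWildlife) and the numbers are illustrative only.

```python
def px_s_to_m_s(speed_px_s: float, animal_height_m: float, image_width_px: int) -> float:
    """Apply the notebook's conversion: m/s = (px/s * animal_height_m) / image_width_px."""
    if animal_height_m <= 0 or image_width_px <= 0:
        raise ValueError("animal_height_m and image_width_px must be positive")
    meters_per_px = animal_height_m / image_width_px  # implied scale factor
    return speed_px_s * meters_per_px

# Illustrative: 192 px/s in a 1920-px-wide frame with animal_height_m = 0.5
print(px_s_to_m_s(192.0, 0.5, 1920))
```

With these inputs the scale is 0.5 / 1920 m per pixel, so 192 px/s becomes roughly 0.05 m/s.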

In [2]:
import ipywidgets as widgets
from IPython.display import display

animal_height_m = None

mode = widgets.ToggleButtons(
    options=[("None (px/s)", "none"), ("Use height (m)", "custom")],
    value="none",
    description="Scale by:",
)

height_input = widgets.BoundedFloatText(
    value=1.00, min=0.01, max=100.0, step=0.01,
    description="Height (m):",
    layout=widgets.Layout(width="220px"),
)

status = widgets.HTML()

def _update(*_):
    global animal_height_m
    if mode.value == "custom":
        height_input.layout.display = "flex"
        animal_height_m = float(height_input.value) if (height_input.value and height_input.value > 0) else None
        status.value = (
            f"<span style='color:#0a0;'>Using height → {animal_height_m:.3f} m</span>"
            if animal_height_m else
            "<span style='color:#a00;'>Enter a positive height in meters.</span>"
        )
    else:
        height_input.layout.display = "none"
        animal_height_m = None
        status.value = "<span style='color:#555;'>No height selected → speeds will be in px/s.</span>"

mode.observe(_update, names="value")
height_input.observe(_update, names="value")

_update()
display(widgets.VBox([mode, height_input, status]))


- **Video content:** Use the toggle below to specify what your videos contain. This choice controls how tracks are selected:

    - *One individual* → keep only the single longest track (best when a single animal is present).
    - *Group of animals* → keep all tracks (best when multiple animals may appear).
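The selection rule can be sketched as follows. This is an assumption about the behavior, not the code inside `pw_utils.speed_in_video`: here "longest" is taken to mean the track with the most points, whereas the library might measure length by duration instead. `select_tracks` and the track layout (`id → [(t, x, y), ...]`) are hypothetical.

```python
from typing import Dict, List, Tuple

Point = Tuple[float, float, float]  # (t seconds, x px, y px)

def select_tracks(tracks: Dict[int, List[Point]], single_individual: bool) -> Dict[int, List[Point]]:
    """Keep every track, or only the one with the most points (hypothetical rule)."""
    if not single_individual or not tracks:
        return dict(tracks)
    # On ties, max() keeps the first track id encountered.
    longest_id = max(tracks, key=lambda tid: len(tracks[tid]))
    return {longest_id: tracks[longest_id]}
```

For example, with a 2-point track `1` and a 3-point track `2`, `select_tracks(tracks, True)` keeps only track `2`, while `select_tracks(tracks, False)` keeps both.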

In [3]:

assume_single_individual = True

# Use distinct names so this cell does not overwrite the `mode`/`status` widgets
# that the height cell's `_update` callback still reads when the height changes.
content_mode = widgets.ToggleButtons(
    options=[("One individual", True), ("Group of animals", False)],
    value=True,
    description="Video content:",
)

content_status = widgets.HTML()

def _update_content(change=None):
    global assume_single_individual
    assume_single_individual = content_mode.value
    if assume_single_individual:
        content_status.value = "<span style='color:#0a0;'>Assuming ONE individual → will keep only the longest track.</span>"
    else:
        content_status.value = "<span style='color:#555;'>Assuming a GROUP → will keep ALL tracks.</span>"

content_mode.observe(_update_content, names="value")
_update_content()
display(widgets.VBox([content_mode, content_status]))




- **Folders**: where videos live and where outputs should go.

In [4]:

SOURCE_FOLDER_PATH = os.path.join(".", "demo_data", "speed_tracking_videos")
OUTPUT_FOLDER = os.path.join(".", "speed_tracking_output")

Path(OUTPUT_FOLDER).mkdir(parents=True, exist_ok=True)

print(f"SOURCE_FOLDER_PATH = {SOURCE_FOLDER_PATH}")
print(f"OUTPUT_FOLDER      = {OUTPUT_FOLDER}")


SOURCE_FOLDER_PATH = ./demo_data/speed_tracking_videos
OUTPUT_FOLDER      = ./speed_tracking_output



## 🖥️ Device
Selects GPU (CUDA) if available, otherwise CPU.


In [5]:

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if DEVICE == "cuda":
    try:
        dev_name = torch.cuda.get_device_name(0)
    except Exception:
        dev_name = "CUDA device"
    print(f"Using GPU: {dev_name}")
else:
    print("Using CPU (this may be slower).")


Using GPU: Tesla V100-SXM2-32GB



## 🧠 Load Models
- **Detector**: MegaDetector V6 (YOLOv9-c backbone)
- **Classifier**: AI4G Amazon Rainforest (v2)

> First run may download weights. If downloads fail, check internet/firewall settings.


In [6]:

# You can switch versions here if needed.
DETECTION_VERSION = "MDV6-yolov9-c"
CLASSIFICATION_VERSION = "v2"

try:
    detection_model = pw_detection.MegaDetectorV6(device=DEVICE, pretrained=True, version=DETECTION_VERSION)
    classification_model = pw_classification.AI4GAmazonRainforest(device=DEVICE, version=CLASSIFICATION_VERSION)
    print("✅ Models loaded")
except Exception as e:
    raise RuntimeError(
        "Failed to load models. Verify your PyTorchWildlife install and network access for weights."
    ) from e


Ultralytics 8.3.55 🚀 Python-3.10.16 torch-2.5.1+cu124 CUDA:0 (Tesla V100-SXM2-32GB, 32501MiB)
YOLOv9c summary (fused): 384 layers, 25,321,561 parameters, 0 gradients, 102.3 GFLOPs
✅ Models loaded



## 🖍️ Annotators
Configure bounding boxes and labels for the output videos.


In [7]:

box_annotator = sv.BoxAnnotator(thickness=4)
lab_annotator = sv.LabelAnnotator(text_color=sv.Color.BLACK, text_thickness=4, text_scale=2)
print("Annotators ready.")


Annotators ready.



## 🔁 Detection & Classification Callback

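This function:
1. Detects animals in the frame  
2. Classifies each detection (cropped from the frame before anything is drawn on it)  
3. Draws the boxes and the detector's labels on the frame (the classifier's `(prediction, confidence)` pairs are returned, not drawn)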
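Steps 1 and 2 can be mirrored with plain NumPy to make the data flow explicit. This is a simplified sketch: `crop_xyxy` is a hypothetical stand-in for `sv.crop_image`, and `classify` stands in for `classification_model.single_image_classification`; neither is the library's actual implementation.

```python
from typing import Callable, List, Tuple

import numpy as np

def crop_xyxy(frame: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
    """Crop an (H, W, 3) frame to a box given as [x_min, y_min, x_max, y_max] in pixels."""
    x_min, y_min, x_max, y_max = np.round(xyxy).astype(int)
    return frame[y_min:y_max, x_min:x_max]

def classify_boxes(
    frame: np.ndarray,
    boxes: np.ndarray,
    classify: Callable[[np.ndarray], Tuple[str, float]],
) -> List[Tuple[str, float]]:
    """Classify each box in detection order, cropping from the unannotated frame."""
    return [classify(crop_xyxy(frame, box)) for box in boxes]
```

Because the crops are taken before any annotator draws on the frame, box outlines and label text never leak into the classifier input; the returned list lines up index by index with the detections.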


In [8]:

from typing import Dict

def callback(frame: np.ndarray, index: int) -> Tuple[np.ndarray, sv.Detections, List[Tuple[str, float]]]:
    """
    Args:
        frame: Current video frame (H,W,3)
        index: Frame index or identifier (passed to detector for metadata)
    Returns:
        annotated_frame: Frame with boxes+labels
        detections: Supervision Detections object
        clf_labels: List of (prediction, confidence) per detection in the same order
    """
    results_det: Dict = detection_model.single_image_detection(frame, img_path=index)

    clf_labels: List[Tuple[str, float]] = []
    for xyxy in results_det["detections"].xyxy:
        cropped_image = sv.crop_image(image=frame, xyxy=xyxy)
        results_clf = classification_model.single_image_classification(cropped_image)
        clf_labels.append((results_clf["prediction"], results_clf["confidence"]))

    annotated_frame = lab_annotator.annotate(
        scene=box_annotator.annotate(scene=frame, detections=results_det["detections"]),
        detections=results_det["detections"],
        labels=results_det["labels"]
    )

    return annotated_frame, results_det["detections"], clf_labels

print("Callback ready.")


Callback ready.



## 📊 Prepare Speed Table
We’ll create a DataFrame that stores **t1/x1/y1 → t2/x2/y2** for each track plus a computed speed column:
- If an animal height was set above, speeds are converted to **m/s**
- Otherwise, speeds remain in **px/s**
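To sanity-check a row of this table, the straight-line speed between its two points can be recomputed by hand. This assumes the reported speed is the Euclidean displacement between the stored points divided by the elapsed time; `pw_utils.speed_in_video` may compute it differently (for example along the full path), so treat mismatches as a hint rather than an error. `straight_line_speed` is a hypothetical helper.

```python
import math
from typing import Tuple

Point = Tuple[float, float, float]  # (t seconds, x px, y px)

def straight_line_speed(p1: Point, p2: Point) -> float:
    """Euclidean displacement between two (t, x, y) points divided by elapsed time, in px/s."""
    t1, x1, y1 = p1
    t2, x2, y2 = p2
    dt = t2 - t1
    if dt <= 0:
        raise ValueError("t2 must be later than t1")
    return math.hypot(x2 - x1, y2 - y1) / dt

print(straight_line_speed((0.0, 0.0, 0.0), (2.0, 300.0, 400.0)))  # 500 px in 2 s → 250.0
```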


In [9]:

from typing import Optional

def init_speed_df(height_m: Optional[float]):
    """Return (empty speed table, height used for conversion or None, whether speeds are in m/s)."""
    base_cols = ["Video", "Image Width (px)", "t1 (s)", "x1 (px)", "y1 (px)", "t2 (s)", "x2 (px)", "y2 (px)"]
    if height_m:
        print(f"Using height ~ {height_m} m for conversion.")
        return pd.DataFrame(columns=base_cols + ["speed (m/s)"]), height_m, True
    print("No animal height specified. Speed will be in pixels/second.")
    return pd.DataFrame(columns=base_cols + ["speed (px/s)"]), None, False

df, animal_height_m, using_meters = init_speed_df(animal_height_m)


No animal height specified. Speed will be in pixels/second.



## ▶️ Run Tracking on Your Videos

- Place videos in `SOURCE_FOLDER_PATH`; only `.mp4`, `.avi`, and `.mov` files are picked up (case-insensitive).
- Annotated videos will be written into `OUTPUT_FOLDER`.
- A running **speed table** is built and saved as `speed.csv` at the end.
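The file discovery and output naming used in the next cell can also be written with `pathlib`. This is a sketch with hypothetical helper names; unlike `os.listdir`, it sorts the files so videos are processed in a reproducible order, and it skips directories whose names happen to end in a video extension.

```python
from pathlib import Path
from typing import List

VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov"}

def find_videos(folder: str) -> List[Path]:
    """Return video files directly inside `folder`, sorted, matching extensions case-insensitively."""
    root = Path(folder)
    if not root.is_dir():
        raise FileNotFoundError(f"Source video folder not found: {root}")
    return sorted(p for p in root.iterdir() if p.is_file() and p.suffix.lower() in VIDEO_EXTENSIONS)

def tracked_output_path(video: Path, output_folder: str) -> Path:
    """Mirror the notebook's naming: <stem>_tracked.mp4 inside the output folder."""
    return Path(output_folder) / f"{video.stem}_tracked.mp4"
```

For instance, `04280240.MOV` maps to `speed_tracking_output/04280240_tracked.mp4` when `OUTPUT_FOLDER` is `./speed_tracking_output`.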


In [None]:
import logging
logging.getLogger("ultralytics").setLevel(logging.CRITICAL)

if not os.path.exists(SOURCE_FOLDER_PATH):
    raise FileNotFoundError(f"Source video folder not found at {SOURCE_FOLDER_PATH}. Please create it and add videos.")

tracks = 0
video_files = [f for f in os.listdir(SOURCE_FOLDER_PATH) if f.lower().endswith((".mp4", ".avi", ".mov"))]

if not video_files:
    print("No video files found. Add videos to the source folder and re-run this cell.")
else:
    iterator = tqdm(video_files, desc="Processing videos")

    for video_name in iterator:
        SOURCE_VIDEO_PATH = os.path.join(SOURCE_FOLDER_PATH, video_name)
        TARGET_VIDEO_PATH = os.path.join(OUTPUT_FOLDER, f"{os.path.splitext(video_name)[0]}_tracked.mp4")
        print(f"\nProcessing: {video_name}")

        try:
            image_width_px, track_summaries = pw_utils.speed_in_video(
                source_path=SOURCE_VIDEO_PATH,
                target_path=TARGET_VIDEO_PATH,
                callback=callback,
                target_fps=10,
                codec="mp4v",
                longest=assume_single_individual,
                min_points=6,
                min_duration_s=0.5,
                min_displacement_px=20,
                suppress_subtracks=True,
                subtrack_radius_px=50,
            )

            # Each track summary holds two points (t1, x1, y1), (t2, x2, y2) and a speed in px/s
            for summary in track_summaries.values():
                t1, x1, y1 = summary["points"][0]
                t2, x2, y2 = summary["points"][1]
                speed_px_s = summary["speed"]

                if using_meters and animal_height_m:
                    # Convert px/s to m/s using width-scale (height_m / image_width_px)
                    speed_val = (speed_px_s * animal_height_m) / image_width_px
                else:
                    speed_val = speed_px_s

                df.loc[tracks] = [video_name, image_width_px, t1, x1, y1, t2, x2, y2, speed_val]
                tracks += 1

        except Exception as e:
            print(f"⚠️ Error processing {video_name}: {e}")
            continue

print("\nDone.")


Processing videos:   0%|          | 0/5 [00:00<?, ?it/s]


Processing: 04280240.MOV


Processing videos:  20%|██        | 1/5 [00:42<02:51, 42.90s/it]


Processing: 04230105.MOV


Processing videos:  40%|████      | 2/5 [01:18<01:55, 38.43s/it]


Processing: 03090004.MP4


Processing videos:  60%|██████    | 3/5 [01:33<00:55, 27.70s/it]


Processing: 03210076.MP4


Processing videos:  80%|████████  | 4/5 [01:39<00:19, 19.28s/it]


Processing: 05100016.MOV


Processing videos: 100%|██████████| 5/5 [02:27<00:00, 29.56s/it]


Done.
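The `min_points`, `min_duration_s`, and `min_displacement_px` arguments passed to `pw_utils.speed_in_video` suggest a track filter. Their exact semantics are not documented in this notebook, so the function below is one plausible, hypothetical reading (first-to-last point duration and displacement); it ignores `suppress_subtracks` and `subtrack_radius_px` entirely.

```python
import math
from typing import List, Tuple

Point = Tuple[float, float, float]  # (t seconds, x px, y px)

def passes_filters(points: List[Point], min_points: int = 6,
                   min_duration_s: float = 0.5, min_displacement_px: float = 20.0) -> bool:
    """Hypothetical track filter mirroring the thresholds used in the tracking cell."""
    if len(points) < min_points:
        return False  # too few observations
    (t1, x1, y1), (t2, x2, y2) = points[0], points[-1]
    if t2 - t1 < min_duration_s:
        return False  # track too short in time
    # Reject near-stationary tracks, which would otherwise produce tiny, noisy speeds.
    return math.hypot(x2 - x1, y2 - y1) >= min_displacement_px
```

Raising these thresholds trades fewer spurious tracks for the risk of discarding brief but genuine movements.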






## 💾 Save Results
This writes a `speed.csv` into your output folder.
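Once `df` is populated, a per-video summary can make the table easier to scan. This is a sketch; `summarize_speeds` is a hypothetical helper. It converts the speed column with `pd.to_numeric` because rows added one at a time via `df.loc` on an initially empty DataFrame may leave that column with `object` dtype.

```python
import pandas as pd

def summarize_speeds(df: pd.DataFrame) -> pd.DataFrame:
    """Per-video track count, mean and max speed, using whichever speed column is present."""
    speed_col = "speed (m/s)" if "speed (m/s)" in df.columns else "speed (px/s)"
    speeds = pd.to_numeric(df[speed_col], errors="coerce")
    return speeds.groupby(df["Video"]).agg(["count", "mean", "max"])
```

In the notebook, `display(summarize_speeds(df))` after the save cell shows one row per video.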


In [None]:

csv_path = os.path.join(OUTPUT_FOLDER, "speed.csv")
if len(df) > 0:
    df.to_csv(csv_path, index=False, float_format="%.3f")
    print(f"Saved: {csv_path}")
    display(df)
else:
    print("Speed table is empty—nothing to save yet.")
