# Particle Tracking with Template Matching

This notebook demonstrates a two-stage pipeline:

1. **Simulation** -- Generate a synthetic video of coloured particles performing random walks.
2. **Tracking** -- Use OpenCV template matching with rolling template updates to track each particle across frames.

The heavy lifting lives in the `src/` package; this notebook orchestrates the workflow and visualises the results.

In [None]:
import os
import sys

# Assumes the notebook runs from a directory one level below the project root;
# put that root on sys.path so the `src` package is importable.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

import cv2
import numpy as np
import matplotlib.pyplot as plt

from src.simulation import simulate_particles, draw_frame
from src.tracker import (
    make_vanilla_circle_template,
    extract_particle_only_template,
    local_circle_template_match,
    mask_circle,
    near_edge,
    Track,
    run_tracker,
    DEFAULT_TEMPLATE_SIZE,
    DEFAULT_MASK_RADIUS,
    DEFAULT_SEARCH_RADIUS,
    DEFAULT_NEW_MIN_SCORE,
    DEFAULT_TRACK_MIN_SCORE,
    DEFAULT_MATCH_METHOD,
    DEFAULT_MAX_NEW_PER_FRAME,
)
from src.visualization import draw_tracks, overlay_bounding_boxes

## 1. Simulate Particles

Generate a synthetic video of particles that spawn at the edges of a 600×600-pixel canvas and perform directional random walks. The video is saved to the project root as `simulation_detection.mp4`.
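The internals of `simulate_particles` are not shown in this notebook. As a rough sketch under our own assumptions (constant speed, Gaussian noise on the heading, and the hypothetical helpers `spawn_on_edge` and `directional_random_walk`), one particle's directional random walk could be generated like this:

```python
import numpy as np


def spawn_on_edge(rng, width, height):
    """Hypothetical helper: (x, y, heading) for a particle entering from a random edge."""
    edge = rng.integers(4)
    if edge == 0:  # left edge, heading right
        return 0.0, rng.uniform(0, height), 0.0
    if edge == 1:  # right edge, heading left
        return float(width - 1), rng.uniform(0, height), np.pi
    if edge == 2:  # top edge, heading down (image y grows downward)
        return rng.uniform(0, width), 0.0, np.pi / 2
    return rng.uniform(0, width), float(height - 1), -np.pi / 2  # bottom edge, heading up


def directional_random_walk(n_steps, width=600, height=600, speed=3.0,
                            turn_sigma=0.3, seed=0):
    """Hypothetical sketch: constant speed, heading perturbed by Gaussian noise each step."""
    rng = np.random.default_rng(seed)
    x, y, heading = spawn_on_edge(rng, width, height)
    path = [(x, y)]
    for _ in range(n_steps):
        heading += rng.normal(0.0, turn_sigma)  # small random turn
        x += speed * np.cos(heading)
        y += speed * np.sin(heading)
        if not (0 <= x < width and 0 <= y < height):
            break  # the particle has left the canvas
        path.append((x, y))
    return np.array(path)


path = directional_random_walk(200)
print(path.shape)
```

Keeping the heading as state, rather than drawing every step independently, is what makes the walk directional: the particle tends to keep its current direction until accumulated turns carry it off the canvas.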

In [None]:
SIMULATION_VIDEO = os.path.join(PROJECT_ROOT, "simulation_detection.mp4")

total_data = simulate_particles(
    output_path=SIMULATION_VIDEO,
    frame_width=600,
    frame_height=600,
    max_particles=80,
    fps=15.0,
    show_preview=False,
)

print(f"Simulation complete: {len(total_data)} frames generated.")

### Preview a simulation frame

In [None]:
sample_idx = len(total_data) // 2
sample_frame = total_data[sample_idx]["frame"]

plt.figure(figsize=(6, 6))
plt.imshow(cv2.cvtColor(sample_frame, cv2.COLOR_BGR2RGB))
plt.title(f"Simulation frame #{sample_idx}")
plt.axis("off")
plt.show()

## 2. Create a Circle Template

A plain white circle on a black background serves as the **vanilla template** for initial detection of new particles.

In [None]:
# Create and display the vanilla circle template
size = 21
radius = 10
canvas = np.zeros((size, size, 3), dtype=np.uint8)
cv2.circle(canvas, (size // 2, size // 2), radius, (255, 255, 255), -1)

plt.figure(figsize=(2, 2))
plt.imshow(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))
plt.title("Circle template")
plt.axis("off")
plt.show()

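# Save to images/ (cv2.imwrite returns False instead of raising on failure,
# e.g. when the directory does not exist)
template_path = os.path.join(PROJECT_ROOT, "images", "circle_template.png")
os.makedirs(os.path.dirname(template_path), exist_ok=True)
if not cv2.imwrite(template_path, canvas):
    raise IOError(f"Could not write template to {template_path}")
print(f"Template saved to {template_path}")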

## 3. Run the Tracker

The tracker reads `simulation_detection.mp4` frame by frame. For each frame it:

1. **Updates existing tracks** using local template matching within a search radius.
2. **Masks tracked regions** so that already-tracked particles are not detected again as new ones.
3. **Detects new particles** using the vanilla template on the masked frame.
4. **Refreshes rolling templates** from the current frame for the next iteration.

In [None]:
TRACKING_VIDEO = os.path.join(PROJECT_ROOT, "simulation_tracking.mp4")

final_tracks = run_tracker(
    video_path=SIMULATION_VIDEO,
    output_path=TRACKING_VIDEO,
    template_size=30,
    mask_radius=15,
    search_radius=16,
    new_min_score=0.35,
    track_min_score=0.7,
    max_new_per_frame=1,
    fps=15.0,
    show_preview=False,
)

print(f"Tracking complete. {len(final_tracks)} tracks active at the end of the video.")
print(f"Output saved to {TRACKING_VIDEO}")

### Preview a tracked frame

Read back a frame from the tracking output to see the bounding boxes and IDs overlaid.

In [None]:
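cap = cv2.VideoCapture(TRACKING_VIDEO)
if not cap.isOpened():
    raise IOError(f"Could not open {TRACKING_VIDEO}")
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

# Seek to a frame in the middle of the video
# (frame-accurate seeking is not guaranteed for every backend/codec)
target = total_frames // 2
cap.set(cv2.CAP_PROP_POS_FRAMES, target)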
ok, frame = cap.read()
cap.release()

if ok:
    plt.figure(figsize=(6, 6))
    plt.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    plt.title(f"Tracked frame #{target}")
    plt.axis("off")
    plt.show()
else:
    print("Could not read frame from tracking video.")