# 🐘 Automated Wildlife Tracking, Re-Identification & Analysis Pipeline

This notebook implements an end-to-end pipeline for tracking wildlife in drone footage. The detection model used here was trained specifically on elephants.

**The Pipeline consists of four main steps:**
1.  **Initial Tracking:** Uses **YOLO11** with the **BoT-SORT** tracker to detect and track animals frame by frame.
2.  **Re-Identification (Post-Processing):** Analyzes "broken" tracks using **SIFT features** and **SSIM** to stitch trajectories of the same individual.
3.  **Video Visualization:** Generates a final video overlaying the corrected IDs.
4.  **Data Analysis:** Generates movement statistics, heatmaps, trajectory plots, and interaction (overlap/clustering) data.

In [None]:
# 1. Install Dependencies
!pip install ultralytics opencv-python pandas scikit-image tqdm seaborn --quiet

# 2. Imports
import cv2
import numpy as np
import pandas as pd
import os
import math
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict
from copy import deepcopy
from ultralytics import YOLO
from skimage.metrics import structural_similarity as ssim
from tqdm.notebook import tqdm
from google.colab import drive

# 3. Mount Drive
drive.mount('/content/drive')

# --- CONFIGURATION: UPDATE THIS PATH ---
WORKING_DIR = ''                   # Project folder; leave '' to stay in the current directory
INPUT_VIDEO = 'DJI_0395.MP4'       # Your video file name
MODEL_WEIGHTS = 'best_xl.pt'       # Your trained YOLO model
TRACKER_YAML = 'botsortV4.yaml'    # Your tracker config (optional)
CONF_THRESHOLD = 0.50              # Confidence threshold
OUTPUT_FOLDER = 'pipeline_results' # Where to save results

# 4. Setup Environment
# os.path.exists('') is False, so an empty WORKING_DIR has to be handled
# before the directory check; otherwise the default configuration always
# reports a missing path even though the current directory is perfectly usable.
if not WORKING_DIR:
    print(f"‚úÖ Working directory set to: {os.getcwd()}")
elif os.path.isdir(WORKING_DIR):
    # isdir (not exists): os.chdir() raises if the path is a regular file.
    os.chdir(WORKING_DIR)
    print(f"‚úÖ Working directory set to: {os.getcwd()}")
else:
    print(f"‚ùå Error: Path not found: {WORKING_DIR}")

# Create output folders (relative to whatever directory is current now)
os.makedirs(os.path.join(OUTPUT_FOLDER, 'plots'), exist_ok=True)
os.makedirs(os.path.join(OUTPUT_FOLDER, 'stats'), exist_ok=True)
print("‚úÖ Environment Ready.")

## Step 1: Initial Tracking (YOLO + BotSort)
We use YOLO11 to detect animals and BoT-SORT to assign initial IDs. The resulting detections are saved as `tracking_raw.csv`.

In [None]:
def run_initial_tracking(video_path, model_path, tracker_config, output_csv):
    """Detect and track objects in every frame of a video and save the result.

    Args:
        video_path: Path of the video to process.
        model_path: Path of the YOLO weights handed to ``YOLO(...)``.
        tracker_config: Tracker YAML file; if it does not exist on disk the
            stock "botsort.yaml" name is passed to the tracker instead.
        output_csv: Destination CSV file for the raw detections.

    Returns:
        A DataFrame with one row per tracked detection and the columns
        frame, id, xmin, ymin, xmax, ymax, confidence. The columns are
        present even when nothing was tracked.

    Detections are filtered with the module-level CONF_THRESHOLD.
    """
    print(f"üöÄ Starting tracking on {video_path}...")
    columns = ['frame', 'id', 'xmin', 'ymin', 'xmax', 'ymax', 'confidence']
    model = YOLO(model_path)
    tracker_arg = tracker_config if os.path.exists(tracker_config) else "botsort.yaml"

    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    rows = []

    pbar = tqdm(total=total_frames, desc="Tracking")
    frame_num = 0

    # try/finally so the capture handle and the progress bar are released
    # even if inference raises part-way through the video.
    try:
        while cap.isOpened():
            success, frame = cap.read()
            if not success:
                break

            results = model.track(frame, persist=True, tracker=tracker_arg, conf=CONF_THRESHOLD, verbose=False)
            detections = results[0].boxes

            # `id` is None on frames where the tracker assigned no IDs.
            if detections.id is not None:
                boxes = detections.xyxy.cpu().numpy()
                ids = detections.id.cpu().numpy()
                confs = detections.conf.cpu().numpy()

                for box, obj_id, conf in zip(boxes, ids, confs):
                    rows.append({
                        'frame': frame_num,
                        'id': int(obj_id),
                        'xmin': float(box[0]),
                        'ymin': float(box[1]),
                        'xmax': float(box[2]),
                        'ymax': float(box[3]),
                        'confidence': float(conf)
                    })
            frame_num += 1
            pbar.update(1)
    finally:
        cap.release()
        pbar.close()

    # Explicit columns and dtypes: without them an empty `rows` list yields a
    # column-less DataFrame and every later `df['id']` lookup raises KeyError.
    df = pd.DataFrame(rows, columns=columns).astype({
        'frame': 'int64',
        'id': 'int64',
        'xmin': 'float64',
        'ymin': 'float64',
        'xmax': 'float64',
        'ymax': 'float64',
        'confidence': 'float64',
    })
    df.to_csv(output_csv, index=False)
    print(f"‚úÖ Raw data saved to {output_csv}")
    return df

# Run Step 1: track only when the input video is actually present.
raw_csv_path = os.path.join(OUTPUT_FOLDER, 'tracking_raw.csv')
if not os.path.exists(INPUT_VIDEO):
    print(f"‚ùå Video not found: {INPUT_VIDEO}")
else:
    df_raw = run_initial_tracking(INPUT_VIDEO, MODEL_WEIGHTS, TRACKER_YAML, raw_csv_path)

## Step 2: Re-Identification (Re-ID)
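This step fixes "broken" tracks. For every pair of IDs where one track ends before the other begins, it extracts the last frame of the earlier track and the first frame of the later one, aligns them using a SIFT-based **homography**, and compares the aligned frames using **SSIM**. If the frames match visually and the jump between the two positions is physically plausible (within a speed limit derived from the earlier track), the IDs are merged.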

In [None]:
# --- Re-ID Helper Functions ---
def extract_frame_img(video_path, frame_number):
    """Read a single frame from a video.

    Opens the video, seeks to ``frame_number``, reads one frame and closes
    the video again.

    Returns:
        The decoded frame, or None when the read did not succeed.
    """
    capture = cv2.VideoCapture(video_path)
    capture.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
    ok, image = capture.read()
    capture.release()
    if not ok:
        return None
    return image

def get_similarity_and_homography(img1, img2):
    """Align ``img1`` onto ``img2`` and score how similar they are.

    SIFT keypoints are matched with FLANN and filtered with a 0.7 ratio
    test; a RANSAC homography is fitted to the surviving matches, ``img1``
    is warped with it, and the warped image is compared with ``img2`` using
    SSIM on grayscale versions.

    Args:
        img1: Source image (BGR).
        img2: Target image (BGR); the warp uses its width and height.

    Returns:
        ``(score, H)`` where ``score`` is the SSIM value and ``H`` the 3x3
        homography mapping ``img1`` coordinates to ``img2`` coordinates, or
        ``(0.0, None)`` when no usable alignment could be computed.
    """
    sift = cv2.SIFT_create()
    kp1, des1 = sift.detectAndCompute(img1, None)
    kp2, des2 = sift.detectAndCompute(img2, None)
    # k=2 matching needs at least two candidate descriptors on the train side.
    if des1 is None or des2 is None or len(des2) < 2:
        return 0.0, None

    flann = cv2.FlannBasedMatcher(dict(algorithm=1, trees=5), dict(checks=50))
    matches = flann.knnMatch(des1, des2, k=2)
    # A query can come back with fewer than two neighbours; unpacking such an
    # entry as `m, n` would raise, so only complete pairs take the ratio test.
    good = [
        pair[0]
        for pair in matches
        if len(pair) == 2 and pair[0].distance < 0.7 * pair[1].distance
    ]

    # A homography has 8 degrees of freedom -> at least 4 correspondences.
    if len(good) < 4:
        return 0.0, None
    pts1 = np.float32([kp1[m.queryIdx].pt for m in good])
    pts2 = np.float32([kp2[m.trainIdx].pt for m in good])

    H, _ = cv2.findHomography(pts1, pts2, cv2.RANSAC)
    if H is None:
        return 0.0, None

    try:
        h, w = img2.shape[:2]
        warped = cv2.warpPerspective(img1, H, (w, h))
        score = ssim(
            cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY),
            cv2.cvtColor(warped, cv2.COLOR_BGR2GRAY),
        )
    except (cv2.error, ValueError):
        # Best effort: an image that cannot be warped/converted/compared is
        # reported as "no match". Unlike a bare `except:`, this no longer
        # swallows KeyboardInterrupt or SystemExit.
        return 0.0, None
    return float(score), H

def transform_point(point, H):
    """Map a 2-D point through the 3x3 homography ``H``.

    The point is lifted to homogeneous coordinates, multiplied by ``H`` and
    divided by the resulting third component. When that component is zero
    the first two components are returned without normalisation.

    Returns:
        A 1-D array holding the two mapped coordinates.
    """
    homogeneous = np.array([point[0], point[1], 1.0])
    mapped = np.dot(H, homogeneous)
    scale = mapped[2]
    if scale != 0:
        return mapped[:2] / scale
    return mapped[:2]

def calculate_distance(p1, p2):
    """Return the Euclidean distance between two points of equal dimension."""
    offset = np.asarray(p1) - np.asarray(p2)
    return np.linalg.norm(offset)

def get_max_speed(df, obj_id, fps=30.0):
    """Return the highest observed speed of one track, in pixels per second.

    The position of a detection is taken as the top-left corner of its box
    (``xmin``, ``ymin``). Each displacement between consecutive detections
    is divided by the number of frames separating them, so a track that was
    briefly lost and re-acquired is not credited with the whole jump as a
    single frame of movement.

    Args:
        df: Tracking table with ``id``, ``frame``, ``xmin`` and ``ymin``.
        obj_id: Track ID to evaluate.
        fps: Frame rate used to convert px/frame into px/second.

    Returns:
        The maximum speed as a float, or the permissive default ``1000.0``
        when the track has fewer than two detections.
    """
    track = df[df['id'] == obj_id].sort_values('frame')
    if len(track) < 2:
        return 1000.0

    dx = np.diff(track['xmin'].to_numpy(dtype=float))
    dy = np.diff(track['ymin'].to_numpy(dtype=float))
    # Clamp to one frame so duplicated frame numbers cannot divide by zero.
    frame_gap = np.maximum(np.diff(track['frame'].to_numpy(dtype=float)), 1.0)

    return float(np.max(np.hypot(dx, dy) / frame_gap * fps))

# --- Re-ID Execution ---
def process_reid(df_raw, video_path, output_csv):
    """Stitch broken tracks together and save the corrected tracking table.

    Steps:
      1. Drop every ID whose mean confidence is below 0.5.
      2. Build candidate pairs (A, B) where track B starts after track A ends.
      3. For each pair, align A's last frame with B's first frame and keep the
         pair when the SSIM score exceeds 0.50 and the distance between the
         (homography-mapped) box centres is within 1.5x what A's maximum
         speed allows over the gap (30 fps assumed).
      4. Apply the accepted merges from best score to worst.

    Args:
        df_raw: Raw tracking table (frame, id, xmin, ymin, xmax, ymax,
            confidence).
        video_path: Video the table was produced from.
        output_csv: Destination CSV for the processed table.

    Returns:
        The processed DataFrame with merged IDs.

    Note:
        Every frame needed by a candidate pair is held in memory at once.
    """
    print("üß† Processing Re-ID...")

    # Nothing tracked (possibly a column-less frame): there is nothing to
    # merge, and the column lookups below would raise KeyError.
    if df_raw.empty:
        df_final = df_raw.copy()
        count = 0
        print(f"üîó Merged {count} IDs.")
        df_final.to_csv(output_csv, index=False)
        return df_final

    # Filter low confidence
    mean_conf = df_raw.groupby('id')['confidence'].mean()
    df_clean = df_raw[df_raw['id'].isin(mean_conf[mean_conf >= 0.5].index)].copy()

    # Get candidates
    lifespans = df_clean.groupby('id').agg(start=('frame', 'min'), end=('frame', 'max')).reset_index()
    candidates = [
        (a.id, b.id, a.end, b.start)
        for a in lifespans.itertuples()
        for b in lifespans.itertuples()
        if a.id != b.id and b.start > a.end
    ]

    # Cache frames
    print(f"Checking {len(candidates)} pairs...")
    needed_frames = {c[2] for c in candidates} | {c[3] for c in candidates}
    frames = {f: extract_frame_img(video_path, f) for f in tqdm(needed_frames)}

    def box_center(track_id, frame_no):
        """Centre of the box of `track_id` on `frame_no`."""
        row = df_clean[(df_clean.id == track_id) & (df_clean.frame == frame_no)].iloc[0]
        return [(row.xmin + row.xmax) / 2, (row.ymin + row.ymax) / 2]

    # The alignment depends only on the two frames and the max speed only on
    # the earlier track, so both are computed once and reused across pairs.
    alignment_cache = {}
    speed_cache = {}

    merges = []
    for id_a, id_b, f_end, f_start in tqdm(candidates):
        im_a, im_b = frames.get(f_end), frames.get(f_start)
        if im_a is None or im_b is None:
            continue

        frame_pair = (f_end, f_start)
        if frame_pair not in alignment_cache:
            alignment_cache[frame_pair] = get_similarity_and_homography(im_a, im_b)
        score, H = alignment_cache[frame_pair]
        if H is None or not score > 0.50:
            continue

        # Physics check
        dist = calculate_distance(
            transform_point(box_center(id_a, f_end), H),
            box_center(id_b, f_start),
        )
        if id_a not in speed_cache:
            speed_cache[id_a] = get_max_speed(df_clean, id_a)
        if dist < (speed_cache[id_a] * (f_start - f_end) / 30 * 1.5):
            merges.append((id_a, id_b, score))

    merges.sort(key=lambda m: m[2], reverse=True)

    # merged_into maps an absorbed ID to the ID that absorbed it. Following it
    # to the end keeps chains intact: if B was folded into A, a later merge of
    # C into B must land on A, not on the no-longer-existing B.
    merged_into = {}
    frames_of = {tid: set(frs) for tid, frs in df_clean.groupby('id')['frame']}

    def resolve(track_id):
        """Follow the merge chain to the ID that finally represents `track_id`."""
        while track_id in merged_into:
            track_id = merged_into[track_id]
        return track_id

    count = 0
    for old, new, _ in merges:
        if new in merged_into:
            continue  # each track is absorbed at most once (best score wins)
        target = resolve(old)
        if target == new:
            continue
        # One individual cannot be detected twice on the same frame, so two
        # tracks that overlap in time must not end up under the same ID.
        if frames_of[target] & frames_of[new]:
            continue
        merged_into[new] = target
        frames_of[target] |= frames_of.pop(new)
        count += 1

    df_final = df_clean.copy()
    if merged_into:
        df_final['id'] = df_final['id'].map(resolve)

    print(f"üîó Merged {count} IDs.")
    df_final.to_csv(output_csv, index=False)
    return df_final

# Run Step 2
# Step 1 only defines `df_raw` when the input video was found, so guard
# against a NameError the same way the Step 4 cell guards `df_reid`.
reid_csv_path = os.path.join(OUTPUT_FOLDER, 'tracking_processed.csv')
if 'df_raw' in locals():
    df_reid = process_reid(df_raw, INPUT_VIDEO, reid_csv_path)
else:
    print("‚ö†Ô∏è Please run Step 1 first to generate 'df_raw'.")

## Step 3: Video Visualization
Generates `Final_Output_ReID.mp4` with the corrected IDs overlaid on the video.

In [None]:
def create_annotated_video(video_path, tracking_df, output_path):
    """Render the video with a labelled bounding box for every tracked detection.

    Args:
        video_path: Source video.
        tracking_df: Table with frame, id, xmin, ymin, xmax, ymax and
            confidence columns; ``frame`` is the 0-based frame index.
        output_path: Destination file, written with the 'mp4v' codec.
    """
    print(f"üé• Generating video: {output_path}")
    cap = cv2.VideoCapture(video_path)
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # Keep the fractional frame rate: truncating e.g. 29.97 to 29 changes the
    # duration of the output. Fall back to 30 when the container reports none.
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps or fps <= 0:
        fps = 30.0

    writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
    purple = (128, 0, 128)

    # Group once instead of filtering the whole table for every frame.
    detections_by_frame = {frame_idx: grp for frame_idx, grp in tracking_df.groupby('frame')}

    pbar = tqdm(total=total, desc="Rendering")
    frame_num = 0

    # try/finally so the reader, the writer and the progress bar are always
    # closed; an unreleased writer can leave a truncated output file.
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            detections = detections_by_frame.get(frame_num)
            if detections is not None:
                for det in detections.itertuples(index=False):
                    x1, y1, x2, y2 = int(det.xmin), int(det.ymin), int(det.xmax), int(det.ymax)
                    cv2.rectangle(frame, (x1, y1), (x2, y2), purple, 3)
                    # Clamp the baseline so labels of boxes touching the top
                    # edge stay inside the image.
                    label_y = max(y1 - 10, 20)
                    cv2.putText(frame, f"ID {int(det.id)} ({det.confidence:.2f})", (x1, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.8, purple, 2)

            writer.write(frame)
            frame_num += 1
            pbar.update(1)
    finally:
        cap.release()
        writer.release()
        pbar.close()
    print("‚ú® Video generation complete!")

# Run Step 3
# `df_reid` only exists once Step 2 has run; guard against a NameError the
# same way the Step 4 cell does.
final_vid_path = os.path.join(OUTPUT_FOLDER, 'Final_Output_ReID.mp4')
if 'df_reid' in locals():
    create_annotated_video(INPUT_VIDEO, df_reid, final_vid_path)
else:
    print("‚ö†Ô∏è Please run Step 2 first to generate 'df_reid'.")

## Step 4: Analysis & Statistics
Calculates distances, overlaps, clusters, and density heatmaps.
Results are saved in `pipeline_results/plots` and `pipeline_results/stats`.

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
import os
import pandas as pd
import numpy as np

# --- STATISTICS FUNCTIONS ---

def save_id_statistics(df, output_dir):
    """Write per-ID appearance statistics to ``id_statistics.txt``.

    For each ID the number of rows (frames) is counted and converted to
    seconds by dividing by 30. The file starts with the two averages over
    all IDs, followed by one "ID | Frame Counts | Seconds" line per ID.
    """
    frames_per_id = df.groupby('id')['frame'].count()
    seconds_per_id = frames_per_id / 30

    lines = [
        f"Average number of frames each ID appears in: {frames_per_id.mean():.2f}\n",
        f"Average number of seconds each ID appears in: {seconds_per_id.mean():.2f}\n\n",
        "Frame counts and seconds for each ID:\n",
        "ID | Frame Counts | Seconds\n",
        "-" * 30 + "\n",
    ]
    # IDs and frame counts are deliberately rendered as floats ("1.0 | 2.0")
    # to keep the established file format.
    for track_id, n_frames in frames_per_id.items():
        lines.append(f"{float(track_id)} | {float(n_frames)} | {seconds_per_id[track_id]:.2f}\n")

    with open(os.path.join(output_dir, "id_statistics.txt"), "w") as f:
        f.writelines(lines)

def save_overlap_statistics(df, output_dir):
    """Write, per ID, the share of frames in which its box overlaps another box.

    Two boxes overlap when their x-ranges and y-ranges both intersect
    strictly (boxes that merely touch do not count). The percentage is taken
    over the number of distinct frames present in ``df``. The result is
    saved as the tab-separated file ``overlap.csv``, highest percentage first.
    """
    n_frames = df['frame'].nunique()
    frames_with_overlap = defaultdict(set)

    for frame_idx, dets in df.groupby('frame'):
        if len(dets) < 2:
            continue
        track_ids = dets['id'].to_numpy()
        left, top = dets['xmin'].to_numpy(), dets['ymin'].to_numpy()
        right, bottom = dets['xmax'].to_numpy(), dets['ymax'].to_numpy()

        # hits[i, j] is True when box i and box j intersect; the test is
        # symmetric, so a row-wise `any` finds every box that is involved.
        hits = (
            (left[:, None] < right[None, :]) & (right[:, None] > left[None, :])
            & (top[:, None] < bottom[None, :]) & (bottom[:, None] > top[None, :])
        )
        np.fill_diagonal(hits, False)  # a box always intersects itself

        for track_id in track_ids[hits.any(axis=1)]:
            frames_with_overlap[track_id].add(frame_idx)

    rows = [
        {
            'id': float(track_id),
            'overlap_percentage': (len(frames_with_overlap[track_id]) / n_frames) * 100,
        }
        for track_id in df['id'].unique()
    ]

    table = pd.DataFrame(rows).sort_values('overlap_percentage', ascending=False)
    table.to_csv(os.path.join(output_dir, "overlap.csv"), sep='\t', index=False)

def find_connected_components(adj):
    """Group the nodes of an undirected graph into connected components.

    Args:
        adj: Mapping of node -> set of neighbouring nodes. Every neighbour
            is expected to be a key of the mapping as well.

    Returns:
        A list of components, each a sorted list of nodes, in the order in
        which their first node appears in ``adj``.
    """
    seen = set()
    components = []
    for start in adj:
        if start in seen:
            continue
        seen.add(start)
        members = [start]
        frontier = [start]
        # Depth-first expansion; a node is marked when it is discovered so
        # it is never queued twice.
        while frontier:
            current = frontier.pop()
            for neighbour in adj[current]:
                if neighbour not in seen:
                    seen.add(neighbour)
                    members.append(neighbour)
                    frontier.append(neighbour)
        components.append(sorted(members))
    return components

def save_cluster_ranges(df, output_dir):
    """Write the frame ranges during which each cluster configuration holds.

    On every frame, two animals are linked when the distance between their
    centres is below 1.5x the mean box diagonal of that frame; linked
    animals form clusters (connected components). Consecutive frames with
    an identical set of clusters are collapsed into "start-end" ranges and
    saved to ``cluster_frame_ranges.csv``. Nothing is written when ``df``
    has no frames.

    Args:
        df: Tracking table with frame, id, xmin, ymin, xmax, ymax. Centre
            columns ``x``/``y`` are used when present and derived from the
            box otherwise. The caller's DataFrame is not modified.
        output_dir: Directory receiving the CSV file.
    """
    # Work on a derived frame: adding the helper column to `df` itself would
    # silently change the caller's table.
    diagonal = np.sqrt((df.xmax - df.xmin)**2 + (df.ymax - df.ymin)**2)
    work = df.assign(diagonal=diagonal)
    if 'x' not in work.columns:
        work['x'] = (work.xmin + work.xmax) / 2
    if 'y' not in work.columns:
        work['y'] = (work.ymin + work.ymax) / 2

    history = []
    # groupby yields the frames in ascending order, one pass over the table.
    for frame, curr in work.groupby('frame'):
        thresh = curr['diagonal'].mean() * 1.5
        ids = curr['id'].to_numpy()
        coords = curr[['x', 'y']].to_numpy()

        adj = {uid: set() for uid in ids}
        # Pairwise centre distances; only the upper triangle (i < j) is
        # inspected so each pair is linked once.
        gaps = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=2)
        for i, j in zip(*np.nonzero(np.triu(gaps < thresh, k=1))):
            adj[ids[i]].add(ids[j])
            adj[ids[j]].add(ids[i])

        formatted = sorted(
            "{" + " ,".join(f"{float(x)}" for x in cluster) + " }"
            for cluster in find_connected_components(adj)
        )
        history.append((frame, str(formatted)))

    if not history:
        return

    # Run-length encode: extend the current range while the state is
    # unchanged and the frame numbers are contiguous.
    grouped = defaultdict(list)
    start, last_state = history[0]
    end = start
    for f, state in history[1:]:
        if state == last_state and f == end + 1:
            end = f
        else:
            grouped[last_state].append(f"{start}-{end}")
            last_state = state
            start = f
            end = f
    grouped[last_state].append(f"{start}-{end}")

    rows = [{'Clusters': k, 'Frame Ranges': ", ".join(v)} for k, v in grouped.items()]
    pd.DataFrame(rows).to_csv(os.path.join(output_dir, "cluster_frame_ranges.csv"), index=False)

# --- PLOTTING FUNCTIONS ---

def save_avg_distance_plot(df, output_dir):
    """Plot the mean per-detection movement distance for every second.

    Reads the ``second`` and ``distance`` columns of ``df`` and saves the
    line chart as ``avg_distance_plot.png`` in ``output_dir``.
    """
    mean_per_second = df.groupby('second')['distance'].mean()

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(mean_per_second.index, mean_per_second.values, label='Avg Distance')
    ax.set_xlabel('Seconds')
    ax.set_ylabel('Px / Frame')
    ax.set_title('Average Movement Distance Over Time')
    ax.grid(True, alpha=0.3)
    fig.savefig(os.path.join(output_dir, "avg_distance_plot.png"))
    plt.close(fig)

def save_individual_distance_plots(df, output_dir):
    """Save one distance-over-time chart per ID.

    For each ID the ``distance`` values are summed per ``second`` and drawn
    as a line; the figure is stored as ``distance_plot_id_<id>.png``.
    """
    for track_id in df.id.unique():
        fig, ax = plt.subplots(figsize=(10, 6))
        # Summing per second smooths the frame-to-frame jitter a little.
        per_second = df[df.id == track_id].groupby('second')['distance'].sum()
        ax.plot(per_second.index, per_second.values)
        ax.set_title(f'Distance Moved Per Second - ID {track_id}')
        ax.set_xlabel('Second')
        ax.set_ylabel('Total Distance (px)')
        fig.savefig(os.path.join(output_dir, f"distance_plot_id_{track_id}.png"))
        plt.close(fig)

def save_trajectories_plot(df, output_dir):
    """Draw the (x, y) path of every ID into ``animal_trajectories.png``.

    Each path is a line labelled with its ID; the first position is marked
    with a circle and the last one with a cross.
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    for track_id in df.id.unique():
        path = df[df.id == track_id]
        ax.plot(path.x, path.y, label=f'ID {track_id}', alpha=0.7)
        # Start ('o') and end ('x') markers
        ax.scatter(path.x.iloc[0], path.y.iloc[0], marker='o', s=30)
        ax.scatter(path.x.iloc[-1], path.y.iloc[-1], marker='x', s=30)

    ax.set_title('Animal Trajectories')
    ax.set_xlabel('X Position')
    ax.set_ylabel('Y Position')
    ax.legend()
    ax.grid(True, alpha=0.3)
    # Image coordinates grow downwards from the top-left corner.
    ax.invert_yaxis()
    fig.savefig(os.path.join(output_dir, "animal_trajectories.png"))
    plt.close(fig)

def save_heatmap(df, output_dir):
    """Save a KDE density map of all positions as ``density_heatmap.png``.

    Best effort: if the plot cannot be produced, the error is printed and
    no file is written.
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        sns.kdeplot(x=df.x, y=df.y, cmap='viridis', fill=True, thresh=0.05, alpha=0.7, ax=ax)
        ax.invert_yaxis()  # image coordinates: y grows downwards
        ax.set_title('Herd Density Heatmap')
        fig.savefig(os.path.join(output_dir, "density_heatmap.png"))
    except Exception as e:
        print(f"‚ö†Ô∏è Could not generate heatmap (not enough variance?): {e}")
    plt.close(fig)

# --- MAIN CONTROLLER ---

def run_analysis_final(df_in, output_folder):
    """Run the complete analysis on a tracking table.

    Derives box centres (``x``, ``y``), the second each frame belongs to
    (frame // 30) and the per-ID distance moved since the previous
    detection, then writes every statistics file to ``<output_folder>/stats``
    and every plot to ``<output_folder>/plots``. ``df_in`` is left untouched.
    """
    print("üìä Starting Final Analysis...")
    stats_dir = os.path.join(output_folder, 'stats')
    plots_dir = os.path.join(output_folder, 'plots')
    for folder in (stats_dir, plots_dir):
        os.makedirs(folder, exist_ok=True)

    # 1. Preprocessing (assign returns a new frame, so df_in stays as it was)
    df = df_in.assign(
        x=(df_in.xmin + df_in.xmax) / 2,
        y=(df_in.ymin + df_in.ymax) / 2,
        second=df_in.frame // 30,
    )

    # 2. Distance Calculation: offset to the previous detection of the same
    #    ID; the first detection of every ID has no predecessor and gets 0.
    df = df.sort_values(['id', 'frame'])
    step = df[['x', 'y']] - df.groupby('id')[['x', 'y']].shift(1)
    df['distance'] = np.sqrt(step.x**2 + step.y**2).fillna(0)

    # 3. Generate Stats Files
    save_id_statistics(df, stats_dir)
    save_overlap_statistics(df, stats_dir)
    save_cluster_ranges(df, stats_dir)
    total_per_id = df.groupby('id')['distance'].sum().reset_index()
    total_per_id.to_csv(os.path.join(stats_dir, "sum_distance_per_id.csv"), index=False)

    # 4. Generate All Plots
    print("   üé® Generating plots...")
    save_avg_distance_plot(df, plots_dir)
    save_individual_distance_plots(df, plots_dir)
    save_trajectories_plot(df, plots_dir)
    save_heatmap(df, plots_dir)

    print(f"‚úÖ Analysis Complete. Files saved to: {output_folder}")

# Run Step 4: the analysis needs the Re-ID output of Step 2.
if 'df_reid' not in locals():
    print("‚ö†Ô∏è Please run Step 2 first to generate 'df_reid'.")
else:
    run_analysis_final(df_reid, OUTPUT_FOLDER)