### Tracking videos using YOLO + SORT (update the input/output directory and model paths in `main()` before running)

In [None]:
import os
import cv2
import numpy as np
import csv
import glob
import pandas as pd
from ultralytics import YOLO
from filterpy.kalman import KalmanFilter
from scipy.optimize import linear_sum_assignment

# SORT tracker (restricted to max 2 IDs, since we only track two marmosets)
class SORT:
    """Simple Online and Realtime Tracking (SORT) limited to two identities.

    Each call to :meth:`update` predicts every live track one frame ahead,
    matches detections to the predictions by IoU with the Hungarian
    algorithm, corrects matched tracks, and creates new tracks for unmatched
    detections only while fewer than two tracks exist. Track IDs are always
    1 or 2 and are released for reuse when a track is dropped.

    Args:
        max_age: A track is dropped once it has gone more than this many
            consecutive frames without a matched detection.
        iou_threshold: Minimum IoU for a detection/track pair to be accepted
            as a match.
    """

    def __init__(self, max_age=30, iou_threshold=0.3):
        self.trackers = []
        self.used_ids = set()
        self.max_age = max_age
        self.iou_threshold = iou_threshold

    # Each tracked object is represented by a KalmanBoxTracker
    class KalmanBoxTracker:
        """One track: a constant-velocity Kalman filter over a bounding box.

        State x = [cx, cy, s, r, vcx, vcy, vs]: box centre, area, width/height
        ratio, and velocities of centre and area (the ratio has no velocity).
        Measurement z = [cx, cy, s, r].
        """

        def __init__(self, bbox, assigned_id):
            self.kf = KalmanFilter(dim_x=7, dim_z=4)
            # State transition: centre and area advance by their velocities.
            self.kf.F = np.array([
                [1,0,0,0,1,0,0],
                [0,1,0,0,0,1,0],
                [0,0,1,0,0,0,1],
                [0,0,0,1,0,0,0],
                [0,0,0,0,1,0,0],
                [0,0,0,0,0,1,0],
                [0,0,0,0,0,0,1]])
            # Measurement matrix: only the first four state entries are observed.
            self.kf.H = np.array([
                [1,0,0,0,0,0,0],
                [0,1,0,0,0,0,0],
                [0,0,1,0,0,0,0],
                [0,0,0,1,0,0,0]])
            # Noisier area/ratio measurements, high initial uncertainty on the
            # unobserved velocities, and small process noise on velocities.
            self.kf.R[2:, 2:] *= 10.
            self.kf.P[4:, 4:] *= 1000.
            self.kf.P *= 10.
            self.kf.Q[-1, -1] *= 0.01
            self.kf.Q[4:, 4:] *= 0.01
            self.kf.x[:4] = self._bbox_to_z(bbox)
            self.id = assigned_id
            self.time_since_update = 0
            self.hits = 0
            self.hit_streak = 0

        @staticmethod
        def _bbox_to_z(bbox):
            """Convert ``[x1, y1, x2, y2]`` to a (4, 1) column ``[cx, cy, s, r]``."""
            cx, cy = (bbox[0]+bbox[2])/2, (bbox[1]+bbox[3])/2
            s = (bbox[2]-bbox[0])*(bbox[3]-bbox[1])   # area
            r = (bbox[2]-bbox[0]) / (bbox[3]-bbox[1]+1e-6)  # aspect ratio (w / h)
            return np.array([cx, cy, s, r]).reshape((4, 1))

        def update(self, bbox):
            """Correct the filter with a matched detection ``[x1, y1, x2, y2]``."""
            self.kf.update(self._bbox_to_z(bbox))
            self.time_since_update = 0
            self.hits += 1
            self.hit_streak += 1

        def predict(self):
            """Advance the filter one frame and return the predicted box."""
            # Stop the area from being predicted negative; get_state() would
            # otherwise take sqrt of a negative number and yield NaN.
            if self.kf.x[6, 0] + self.kf.x[2, 0] <= 0:
                self.kf.x[6, 0] = 0.0
            self.kf.predict()
            # The streak only breaks if the previous frame was already missed;
            # check before incrementing, otherwise it would always reset.
            if self.time_since_update > 0:
                self.hit_streak = 0
            self.time_since_update += 1
            return self.get_state()

        def get_state(self):
            """Return the current state as a box ``[x1, y1, x2, y2]``."""
            cx, cy, s, r = self.kf.x[:4].reshape(-1)
            w = np.sqrt(s*r)
            h = s / (w + 1e-6)
            return [cx-w/2, cy-h/2, cx+w/2, cy+h/2]

    def assign_id(self):
        """Return the lowest free ID among 1 and 2, or None if both are in use."""
        for i in [1, 2]:
            if i not in self.used_ids:
                return i
        return None

    def update(self, detections):
        """Advance all tracks by one frame and associate them with *detections*.

        Args:
            detections: 2-D array whose first four columns are
                ``x1, y1, x2, y2``; may have zero rows.

        Returns:
            A list of ``[x1, y1, x2, y2, id]`` for every track that was matched
            or created in this frame. Tracks coasting on prediction alone are
            not reported.
        """
        trks = []
        to_del = []
        # Predict positions of existing trackers; mark stale/broken ones.
        for t, trk in enumerate(self.trackers):
            pos = trk.predict()
            if np.any(np.isnan(pos)) or trk.time_since_update > self.max_age:
                to_del.append(t)
            else:
                trks.append(pos)
        # Delete in reverse so earlier indices stay valid; free their IDs.
        for t in reversed(to_del):
            self.used_ids.discard(self.trackers[t].id)
            self.trackers.pop(t)

        trks = np.array(trks)
        dets = detections[:, :4] if len(detections) > 0 else np.empty((0, 4))
        matched, unmatched_dets, _ = self.associate_detections_to_trackers(dets, trks)

        for d, t in matched:
            self.trackers[t].update(detections[d, :4])
        # Unmatched tracks need no extra bookkeeping: predict() has already
        # advanced their time_since_update for this frame.

        # Create new trackers for unmatched detections (at most two in total).
        for i in unmatched_dets:
            if len(self.trackers) >= 2:
                break
            new_id = self.assign_id()
            if new_id is not None:
                self.trackers.append(SORT.KalmanBoxTracker(detections[i, :4], new_id))
                self.used_ids.add(new_id)

        # With at most two tracks there is no minimum-hits warm-up: every
        # track updated this frame is reported immediately.
        return [[*trk.get_state(), trk.id] for trk in self.trackers
                if trk.time_since_update == 0]

    def iou(self, bb_test, bb_gt):
        """Intersection-over-Union of two ``[x1, y1, x2, y2]`` boxes.

        1e-6 is added to the union so degenerate boxes do not divide by zero.
        """
        xx1 = np.maximum(bb_test[0], bb_gt[0])
        yy1 = np.maximum(bb_test[1], bb_gt[1])
        xx2 = np.minimum(bb_test[2], bb_gt[2])
        yy2 = np.minimum(bb_test[3], bb_gt[3])
        w = np.maximum(0., xx2 - xx1)
        h = np.maximum(0., yy2 - yy1)
        wh = w*h
        return wh / ((bb_test[2]-bb_test[0])*(bb_test[3]-bb_test[1]) +
                     (bb_gt[2]-bb_gt[0])*(bb_gt[3]-bb_gt[1]) - wh + 1e-6)

    def associate_detections_to_trackers(self, detections, trackers, iou_threshold=None):
        """Match detections to predicted track boxes by maximum total IoU.

        Args:
            detections: (D, 4) boxes.
            trackers: (T, 4) predicted boxes.
            iou_threshold: Pairs with IoU below this are rejected; defaults
                to ``self.iou_threshold``.

        Returns:
            ``(matches, unmatched_detections, unmatched_trackers)`` where
            ``matches`` is an int array of shape (K, 2) holding
            ``[detection_index, tracker_index]`` rows, and the other two are
            1-D int index arrays.
        """
        if iou_threshold is None:
            iou_threshold = self.iou_threshold
        if len(trackers) == 0:
            return np.empty((0, 2), dtype=int), np.arange(len(detections)), np.empty((0,), dtype=int)

        iou_matrix = np.zeros((len(detections), len(trackers)), dtype=np.float32)
        for d, det in enumerate(detections):
            for t, trk in enumerate(trackers):
                iou_matrix[d, t] = self.iou(det, trk)

        # Hungarian algorithm minimises cost, so negate IoU to maximise it.
        det_idx, trk_idx = linear_sum_assignment(-iou_matrix)
        assigned_dets = set(det_idx.tolist())
        assigned_trks = set(trk_idx.tolist())
        unmatched_detections = [d for d in range(len(detections)) if d not in assigned_dets]
        unmatched_trackers = [t for t in range(len(trackers)) if t not in assigned_trks]

        matches = []
        for d, t in zip(det_idx, trk_idx):
            if iou_matrix[d, t] < iou_threshold:  # reject weak matches
                unmatched_detections.append(int(d))
                unmatched_trackers.append(int(t))
            else:
                matches.append((int(d), int(t)))
        return (np.array(matches, dtype=int).reshape(-1, 2),
                np.array(unmatched_detections, dtype=int),
                np.array(unmatched_trackers, dtype=int))

# --- Main video processing ---
def process_video(video_path, model, output_dir, *, conf_threshold=0.4, target_class=0):
    """Detect and track the two marmosets in one video and save the results.

    Runs the YOLO *model* on every frame, keeps detections of class
    *target_class* whose confidence is above *conf_threshold*, feeds them to
    a fresh SORT tracker, and writes into *output_dir*:

    * ``<video_name>_tracked.mp4`` - the frames with each reported track drawn
      (BGR green for ID 1, blue otherwise) and labelled ``ID <n>``;
    * ``<video_name>_tracking.csv`` - one ``frame,id,x1,y1,x2,y2`` row per
      reported track per frame, in integer pixels.

    If the video cannot be opened, a message is printed and nothing else is
    written.

    Args:
        video_path: Input video file.
        model: Loaded ultralytics ``YOLO`` model.
        output_dir: Destination folder (created if needed).
        conf_threshold: Detections must have confidence strictly above this.
        target_class: YOLO class index to track (0 = marmoset in this model).
    """
    os.makedirs(output_dir, exist_ok=True)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        # Without this check an unreadable file silently yields empty outputs.
        print(f"Could not open {video_path}, skipping.")
        cap.release()
        return

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    video_name = os.path.splitext(os.path.basename(video_path))[0]
    out_vid_path = os.path.join(output_dir, f"{video_name}_tracked.mp4")
    csv_path = os.path.join(output_dir, f"{video_name}_tracking.csv")

    out = cv2.VideoWriter(out_vid_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    tracker = SORT()
    try:
        with open(csv_path, 'w', newline='') as csv_file:
            writer = csv.writer(csv_file)
            writer.writerow(['frame', 'id', 'x1', 'y1', 'x2', 'y2'])

            frame_id = 0
            while True:
                ok, frame = cap.read()
                if not ok:
                    break

                # Run YOLO and keep confident detections of the target class.
                results = model(frame)[0]
                dets = [[x1, y1, x2, y2, conf]
                        for x1, y1, x2, y2, conf, cls in results.boxes.data.tolist()
                        if conf > conf_threshold and int(cls) == target_class]
                dets = np.array(dets).reshape(-1, 5) if dets else np.empty((0, 5))

                # Draw boxes + IDs and log them to the CSV.
                for *bbox, tid in tracker.update(dets):
                    x1, y1, x2, y2 = map(int, bbox)
                    color = (0, 255, 0) if tid == 1 else (255, 0, 0)
                    cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
                    cv2.putText(frame, f'ID {tid}', (x1, y1-10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2)
                    writer.writerow([frame_id, tid, x1, y1, x2, y2])

                out.write(frame)
                frame_id += 1
    finally:
        # Release capture and writer even if detection/tracking raises, so
        # the output video file is closed and handles are not leaked.
        cap.release()
        out.release()

def main(video_dir=r"D:\DeepSORT_ML2025\video_inputs",
         output_base_dir=r"D:\DeepSORT_ML2025\video_outputs",
         model_path=r"D:/DeepSORT_ML2025/yolo_project/marmoset_yolo/weights/best.pt"):
    """Track every ``*.mp4`` file directly inside *video_dir*.

    Each video's outputs go to ``<output_base_dir>/<video name>/`` via
    ``process_video``. Videos are processed in sorted filename order so
    repeated runs behave the same.

    Args:
        video_dir: Folder containing the input videos (not searched recursively).
        output_base_dir: Root folder for the per-video output folders.
        model_path: Trained YOLO weights file.
    """
    model = YOLO(model_path)
    os.makedirs(output_base_dir, exist_ok=True)

    # glob() returns files in arbitrary order; sort for reproducible runs.
    for vid in sorted(glob.glob(os.path.join(video_dir, "*.mp4"))):
        print(f"Processing {vid} ...")
        video_name = os.path.splitext(os.path.basename(vid))[0]
        process_video(vid, model, os.path.join(output_base_dir, video_name))
        print(f"Finished {video_name}")


if __name__ == "__main__":
    main()


## The next few code blocks handle pre- and post-processing for a specific camera setup
### Used when the arena is captured by four different cameras (see the readme.md file in this repository for details)



In [None]:
import os
import cv2
import numpy as np
import csv
import math
import glob
import pandas as pd
from ultralytics import YOLO
from filterpy.kalman import KalmanFilter
from scipy.optimize import linear_sum_assignment

# Inverse transform for right half boxes
# 3x3 homogeneous transform applied to right-half box corners; hard-coded for
# this specific camera setup.
RIGHT_HALF_INVERSE_MATRIX = np.array([
    [ 1.05403131,  0.12941877, -51.0827529],
    [-0.14624321,  1.19105538, 136.228942],
    [ 0.0,         0.0,          1.0     ]
])


def inverse_transform_right_half_box(x1, y1, x2, y2, matrix=None):
    """Map the two corners of a right-half box through a 3x3 transform.

    Args:
        x1, y1, x2, y2: Box corners in coordinates relative to the right
            half's origin.
        matrix: Optional 3x3 homogeneous transform; defaults to
            ``RIGHT_HALF_INVERSE_MATRIX``.

    Returns:
        ``[x1', y1', x2', y2']``: the transformed ``(x1, y1)`` and
        ``(x2, y2)`` corners. Only these two corners are mapped. When the
        matrix has shear or rotation terms (the default does), the result is
        therefore not the exact bounding box of the transformed rectangle.
    """
    m = RIGHT_HALF_INVERSE_MATRIX if matrix is None else np.asarray(matrix, dtype=float)

    def apply_inverse(x, y):
        px, py, pw = m @ np.array([x, y, 1.0])
        # Normalise by the homogeneous coordinate so projective matrices also
        # work; for the default affine matrix pw == 1 and this is a no-op.
        return px / pw, py / pw

    x1n, y1n = apply_inverse(x1, y1)
    x2n, y2n = apply_inverse(x2, y2)
    return [x1n, y1n, x2n, y2n]

# SORT tracker class
class SORT:
    """Simple Online and Realtime Tracking (SORT) limited to two identities.

    Each call to :meth:`update` predicts every live track one frame ahead,
    matches detections to the predictions by IoU with the Hungarian
    algorithm, corrects matched tracks, and creates new tracks for unmatched
    detections only while fewer than two tracks exist. Track IDs are always
    1 or 2 and are released for reuse when a track is dropped.

    Args:
        max_age: A track is dropped once it has gone more than this many
            consecutive frames without a matched detection.
        iou_threshold: Minimum IoU for a detection/track pair to be accepted
            as a match.
    """

    def __init__(self, max_age=30, iou_threshold=0.3):
        self.trackers = []
        self.used_ids = set()
        self.max_age = max_age
        self.iou_threshold = iou_threshold

    class KalmanBoxTracker:
        """One track: a constant-velocity Kalman filter over a bounding box.

        State x = [cx, cy, s, r, vcx, vcy, vs]: box centre, area, width/height
        ratio, and velocities of centre and area (the ratio has no velocity).
        Measurement z = [cx, cy, s, r].
        """

        def __init__(self, bbox, assigned_id):
            self.kf = KalmanFilter(dim_x=7, dim_z=4)
            # State transition: centre and area advance by their velocities.
            self.kf.F = np.array([
                [1,0,0,0,1,0,0],
                [0,1,0,0,0,1,0],
                [0,0,1,0,0,0,1],
                [0,0,0,1,0,0,0],
                [0,0,0,0,1,0,0],
                [0,0,0,0,0,1,0],
                [0,0,0,0,0,0,1]])
            # Measurement matrix: only the first four state entries are observed.
            self.kf.H = np.array([
                [1,0,0,0,0,0,0],
                [0,1,0,0,0,0,0],
                [0,0,1,0,0,0,0],
                [0,0,0,1,0,0,0]])
            # Noisier area/ratio measurements, high initial uncertainty on the
            # unobserved velocities, and small process noise on velocities.
            self.kf.R[2:, 2:] *= 10.
            self.kf.P[4:, 4:] *= 1000.
            self.kf.P *= 10.
            self.kf.Q[-1, -1] *= 0.01
            self.kf.Q[4:, 4:] *= 0.01
            self.kf.x[:4] = self._bbox_to_z(bbox)
            self.id = assigned_id
            self.time_since_update = 0
            self.hits = 0
            self.hit_streak = 0

        @staticmethod
        def _bbox_to_z(bbox):
            """Convert ``[x1, y1, x2, y2]`` to a (4, 1) column ``[cx, cy, s, r]``."""
            cx, cy = (bbox[0]+bbox[2])/2, (bbox[1]+bbox[3])/2
            s = (bbox[2]-bbox[0])*(bbox[3]-bbox[1])   # area
            r = (bbox[2]-bbox[0]) / (bbox[3]-bbox[1]+1e-6)  # aspect ratio (w / h)
            return np.array([cx, cy, s, r]).reshape((4, 1))

        def update(self, bbox):
            """Correct the filter with a matched detection ``[x1, y1, x2, y2]``."""
            self.kf.update(self._bbox_to_z(bbox))
            self.time_since_update = 0
            self.hits += 1
            self.hit_streak += 1

        def predict(self):
            """Advance the filter one frame and return the predicted box."""
            # Stop the area from being predicted negative; get_state() would
            # otherwise take sqrt of a negative number and yield NaN.
            if self.kf.x[6, 0] + self.kf.x[2, 0] <= 0:
                self.kf.x[6, 0] = 0.0
            self.kf.predict()
            # The streak only breaks if the previous frame was already missed;
            # check before incrementing, otherwise it would always reset.
            if self.time_since_update > 0:
                self.hit_streak = 0
            self.time_since_update += 1
            return self.get_state()

        def get_state(self):
            """Return the current state as a box ``[x1, y1, x2, y2]``."""
            cx, cy, s, r = self.kf.x[:4].reshape(-1)
            w = np.sqrt(s*r)
            h = s / (w + 1e-6)
            return [cx-w/2, cy-h/2, cx+w/2, cy+h/2]

    def assign_id(self):
        """Return the lowest free ID among 1 and 2, or None if both are in use."""
        for i in [1, 2]:
            if i not in self.used_ids:
                return i
        return None

    def update(self, detections):
        """Advance all tracks by one frame and associate them with *detections*.

        Args:
            detections: 2-D array whose first four columns are
                ``x1, y1, x2, y2``; may have zero rows.

        Returns:
            A list of ``[x1, y1, x2, y2, id]`` for every track that was matched
            or created in this frame. Tracks coasting on prediction alone are
            not reported.
        """
        trks = []
        to_del = []
        # Predict new locations of existing trackers; mark stale/broken ones.
        for t, trk in enumerate(self.trackers):
            pos = trk.predict()
            if np.any(np.isnan(pos)) or trk.time_since_update > self.max_age:
                to_del.append(t)
            else:
                trks.append(pos)
        # Delete in reverse so earlier indices stay valid; free their IDs.
        for t in reversed(to_del):
            self.used_ids.discard(self.trackers[t].id)
            self.trackers.pop(t)

        trks = np.array(trks)
        dets = detections[:, :4] if len(detections) > 0 else np.empty((0, 4))
        matched, unmatched_dets, _ = self.associate_detections_to_trackers(dets, trks)

        for d, t in matched:
            self.trackers[t].update(detections[d, :4])
        # Unmatched tracks need no extra bookkeeping: predict() has already
        # advanced their time_since_update for this frame.

        # Create new trackers for unmatched detections (at most two in total).
        for i in unmatched_dets:
            if len(self.trackers) >= 2:
                break
            new_id = self.assign_id()
            if new_id is not None:
                self.trackers.append(SORT.KalmanBoxTracker(detections[i, :4], new_id))
                self.used_ids.add(new_id)

        # With at most two tracks there is no minimum-hits warm-up: every
        # track updated this frame is reported immediately.
        return [[*trk.get_state(), trk.id] for trk in self.trackers
                if trk.time_since_update == 0]

    def iou(self, bb_test, bb_gt):
        """Intersection-over-Union of two ``[x1, y1, x2, y2]`` boxes.

        1e-6 is added to the union so degenerate boxes do not divide by zero.
        """
        xx1 = np.maximum(bb_test[0], bb_gt[0])
        yy1 = np.maximum(bb_test[1], bb_gt[1])
        xx2 = np.minimum(bb_test[2], bb_gt[2])
        yy2 = np.minimum(bb_test[3], bb_gt[3])
        w = np.maximum(0., xx2 - xx1)
        h = np.maximum(0., yy2 - yy1)
        wh = w*h
        return wh / ((bb_test[2]-bb_test[0])*(bb_test[3]-bb_test[1]) +
                     (bb_gt[2]-bb_gt[0])*(bb_gt[3]-bb_gt[1]) - wh + 1e-6)

    def associate_detections_to_trackers(self, detections, trackers, iou_threshold=None):
        """Match detections to predicted track boxes by maximum total IoU.

        Args:
            detections: (D, 4) boxes.
            trackers: (T, 4) predicted boxes.
            iou_threshold: Pairs with IoU below this are rejected; defaults
                to ``self.iou_threshold``.

        Returns:
            ``(matches, unmatched_detections, unmatched_trackers)`` where
            ``matches`` is an int array of shape (K, 2) holding
            ``[detection_index, tracker_index]`` rows, and the other two are
            1-D int index arrays.
        """
        if iou_threshold is None:
            iou_threshold = self.iou_threshold
        if len(trackers) == 0:
            return np.empty((0, 2), dtype=int), np.arange(len(detections)), np.empty((0,), dtype=int)

        iou_matrix = np.zeros((len(detections), len(trackers)), dtype=np.float32)
        for d, det in enumerate(detections):
            for t, trk in enumerate(trackers):
                iou_matrix[d, t] = self.iou(det, trk)

        # Hungarian algorithm minimises cost, so negate IoU to maximise it.
        det_idx, trk_idx = linear_sum_assignment(-iou_matrix)
        assigned_dets = set(det_idx.tolist())
        assigned_trks = set(trk_idx.tolist())
        unmatched_detections = [d for d in range(len(detections)) if d not in assigned_dets]
        unmatched_trackers = [t for t in range(len(trackers)) if t not in assigned_trks]

        matches = []
        for d, t in zip(det_idx, trk_idx):
            if iou_matrix[d, t] < iou_threshold:  # reject weak matches
                unmatched_detections.append(int(d))
                unmatched_trackers.append(int(t))
            else:
                matches.append((int(d), int(t)))
        return (np.array(matches, dtype=int).reshape(-1, 2),
                np.array(unmatched_detections, dtype=int),
                np.array(unmatched_trackers, dtype=int))

# Log helper for ID 1 left/right position by column (x1 or x2)
def log_id1_position(df, colname, txt_path, width):
    """Write a per-frame log of whether track ID 1 is left or right of ID 2.

    For every frame, the first row (in table order) with a non-NaN *colname*
    is taken for each of IDs 1 and 2. Values from rows whose ``side`` is
    ``'right'`` are shifted by ``width // 2`` so both IDs are compared on one
    horizontal axis. A line ``Frame N: ID 1 is on the left|right`` is emitted
    only when both IDs have a value; a tie counts as "right". Lines are
    joined with newlines (no trailing newline) and *txt_path* is overwritten.

    Args:
        df: Combined tracking table with ``frame``, ``side``, ``id`` and
            *colname* columns.
        colname: Coordinate column to compare (e.g. ``'x1'`` or ``'x2'``).
        txt_path: Output text file.
        width: Full frame width in pixels.
    """
    right_offset = width // 2
    entries = []
    for frame_no, rows in df.groupby("frame"):
        has_value = rows[colname].notna()
        x_by_id = {}
        for track_id in (1, 2):
            hits = rows[has_value & (rows['id'] == track_id)]
            if hits.empty:
                continue
            offset = right_offset if hits['side'].iloc[0] == 'right' else 0
            x_by_id[track_id] = hits[colname].iloc[0] + offset
        if len(x_by_id) != 2:
            continue
        side_of_id1 = 'left' if x_by_id[1] < x_by_id[2] else 'right'
        entries.append(f"Frame {frame_no}: ID 1 is on the {side_of_id1}")

    with open(txt_path, 'w') as f:
        f.write('\n'.join(entries))

# Main video processing function
def process_video(video_path, model, output_dir, *, conf_threshold=0.4, target_class=0):
    """Track two marmosets in a stitched two-camera video and split boxes by side.

    The frame is split at ``width // 2`` into a left half (written to the
    ``_cam10`` CSV) and a right half (``_cam11`` CSV). Each tracked box is
    assigned to the half it lies in, or to both halves (clipped at the seam)
    when it straddles the centre line. Left boxes are stored in
    stitched-frame pixels. Right boxes are shifted to right-half-relative x
    and passed through ``inverse_transform_right_half_box``. The video
    overlay always shows the raw, untransformed boxes.

    Outputs written to *output_dir*:

    * ``<name>_tracked.mp4`` - annotated video;
    * ``<name>_tracking.csv`` - ``frame, side, id, x1, y1, x2, y2`` for both
      sides and both IDs every frame (NaN when that ID has no box there);
    * ``<name>_cam10.csv`` / ``<name>_cam11.csv`` - the left / right rows only;
    * ``<name>_id1_x1_position.txt`` / ``<name>_id1_x2_position.txt`` -
      left/right logs produced by ``log_id1_position``.

    If the video cannot be opened, a message is printed and nothing else is
    written.

    Args:
        video_path: Input (stitched) video file.
        model: Loaded ultralytics ``YOLO`` model.
        output_dir: Destination folder (created if needed).
        conf_threshold: Detections must have confidence strictly above this.
        target_class: YOLO class index to track.
    """
    os.makedirs(output_dir, exist_ok=True)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Could not open {video_path}, skipping.")
        cap.release()
        return

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    center_x = width // 2

    video_name = os.path.splitext(os.path.basename(video_path))[0]
    out_vid_path = os.path.join(output_dir, f"{video_name}_tracked.mp4")
    combined_csv_path = os.path.join(output_dir, f"{video_name}_tracking.csv")
    left_csv_path = os.path.join(output_dir, f"{video_name}_cam10.csv")
    right_csv_path = os.path.join(output_dir, f"{video_name}_cam11.csv")
    txt_x1_path = os.path.join(output_dir, f"{video_name}_id1_x1_position.txt")
    txt_x2_path = os.path.join(output_dir, f"{video_name}_id1_x2_position.txt")

    out = cv2.VideoWriter(out_vid_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    tracker = SORT()
    try:
        with open(combined_csv_path, 'w', newline='') as combined_csv_file, \
                open(left_csv_path, 'w', newline='') as left_csv_file, \
                open(right_csv_path, 'w', newline='') as right_csv_file:
            combined_writer = csv.writer(combined_csv_file)
            left_writer = csv.writer(left_csv_file)
            right_writer = csv.writer(right_csv_file)
            combined_writer.writerow(['frame', 'side', 'id', 'x1', 'y1', 'x2', 'y2'])
            left_writer.writerow(['frame', 'id', 'x1', 'y1', 'x2', 'y2'])
            right_writer.writerow(['frame', 'id', 'x1', 'y1', 'x2', 'y2'])
            side_writers = (('left', left_writer), ('right', right_writer))

            frame_id = 0
            while True:
                ok, frame = cap.read()
                if not ok:
                    break

                results = model(frame)[0]
                dets = [[x1, y1, x2, y2, conf]
                        for x1, y1, x2, y2, conf, cls in results.boxes.data.tolist()
                        if conf > conf_threshold and int(cls) == target_class]
                dets = np.array(dets).reshape(-1, 5) if dets else np.empty((0, 5))

                # Box per (side, ID) for this frame; None = not visible there.
                sides = {'left': {1: None, 2: None}, 'right': {1: None, 2: None}}

                for *bbox, tid in tracker.update(dets):
                    x1, y1, x2, y2 = map(int, bbox)
                    left_box, right_box = _split_box_by_side(x1, y1, x2, y2, center_x)
                    if left_box is not None:
                        sides['left'][tid] = left_box
                    if right_box is not None:
                        sides['right'][tid] = right_box

                    if left_box is not None and right_box is not None:
                        # Straddling box: draw one raw rectangle per half.
                        _draw_labeled_box(frame, (x1, y1), (center_x, y2), tid, (x1, y1-10))
                        _draw_labeled_box(frame, (center_x, y1), (x2, y2), tid, (center_x, y1-10))
                    else:
                        _draw_labeled_box(frame, (x1, y1), (x2, y2), tid, (x1, y1-10))

                # One row per (side, ID) every frame so the CSVs stay aligned.
                for side, side_writer in side_writers:
                    for tid in (1, 2):
                        box = sides[side][tid]
                        coords = [np.nan] * 4 if box is None else list(box)
                        combined_writer.writerow([frame_id, side, tid, *coords])
                        side_writer.writerow([frame_id, tid, *coords])

                out.write(frame)
                frame_id += 1
    finally:
        # Release capture and writer even if detection/tracking raises.
        cap.release()
        out.release()

    # Load combined CSV to generate TXT logs for x1 and x2
    df = pd.read_csv(combined_csv_path)
    log_id1_position(df, 'x1', txt_x1_path, width)
    log_id1_position(df, 'x2', txt_x2_path, width)


# --- helpers for process_video ---

def _split_box_by_side(x1, y1, x2, y2, center_x):
    """Return ``(left_box, right_box)`` for an integer box in the stitched frame.

    A box strictly straddling *center_x* is clipped into one part per side.
    The left part is kept in stitched-frame pixels. The right part is shifted
    to right-half-relative x and passed through
    ``inverse_transform_right_half_box``. A side the box does not touch is
    returned as None.
    """
    if x1 < center_x < x2:
        left_box = (x1, y1, center_x, y2)
        right_box = inverse_transform_right_half_box(0, y1, x2 - center_x, y2)
        return left_box, right_box
    if x2 <= center_x:
        return (x1, y1, x2, y2), None
    return None, inverse_transform_right_half_box(x1 - center_x, y1, x2 - center_x, y2)


def _draw_labeled_box(frame, top_left, bottom_right, tid, label_org):
    """Draw one rectangle and its ``ID <tid>`` label on *frame* in place.

    Colour (BGR): green for ID 1, blue for any other ID.
    """
    color = (0, 255, 0) if tid == 1 else (255, 0, 0)
    cv2.rectangle(frame, top_left, bottom_right, color, 2)
    cv2.putText(frame, f'ID {tid}', label_org, cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2)

# Batch processing for all videos in folder
def main(video_dir=r"D:\DeepSORT_ML2025\video_stitch\stitched_outputs",
         output_base_dir=r"D:\DeepSORT_ML2025\video_stitch\cam10_cam11_tracked",
         model_path=r"D:/DeepSORT_ML2025/yolo_project/marmoset_yolo_20250619_172321/weights/best.pt"):
    """Track every stitched ``*.mp4`` file directly inside *video_dir*.

    Each video's outputs go to ``<output_base_dir>/<video name>/`` via
    ``process_video``. Videos are processed in sorted filename order so
    repeated runs behave the same.

    Args:
        video_dir: Folder containing the stitched input videos.
        output_base_dir: Root folder for the per-video output folders.
        model_path: Trained YOLO weights file.
    """
    model = YOLO(model_path)
    os.makedirs(output_base_dir, exist_ok=True)

    # glob() returns files in arbitrary order; sort for reproducible runs.
    for vid in sorted(glob.glob(os.path.join(video_dir, "*.mp4"))):
        print(f"Processing {vid} ...")
        video_name = os.path.splitext(os.path.basename(vid))[0]
        process_video(vid, model, os.path.join(output_base_dir, video_name))
        print(f"Finished {video_name}")


if __name__ == "__main__":
    main()


### Create an empty close_frames.txt in every output subfolder

In [None]:
import os

# === Path to your target root folder ===
target_root = r"D:\DeepSORT_ML2025\video_stitch\cam10_cam11_tracked"

# Give every immediate subfolder of target_root a close_frames.txt, leaving
# any existing file untouched. Plain files at the root level are ignored.
for entry_name in os.listdir(target_root):
    entry_path = os.path.join(target_root, entry_name)
    if not os.path.isdir(entry_path):
        continue

    close_file = os.path.join(entry_path, 'close_frames.txt')
    if os.path.exists(close_file):
        print(f'Already exists: {close_file}')
        continue

    # Keep the file empty: the close_frames.txt parsers in this notebook
    # convert every non-blank line with float(), so a header line would
    # make them raise.
    with open(close_file, 'w') as f:
        pass
    print(f'Created: {close_file}')


### Switching IDs using the manual guide (close_frames.txt)

In [None]:
import os
import pandas as pd
import re

# Convert close_frames.txt to a list of (start_frame, end_frame)
def parse_swap_file(txt_path, fps=24):
    """Read close_frames.txt into a list of inclusive ``(start_frame, end_frame)`` ranges.

    Each non-blank line is either ``start-end`` or a single ``start``, both in
    seconds. Seconds become frames via ``int(seconds * fps)``. A single value
    means "from here on" and gets ``float('inf')`` as its end.
    """
    def seconds_to_frame(seconds):
        return int(seconds * fps)

    ranges = []
    with open(txt_path, 'r') as f:
        for raw in f:
            entry = raw.strip()
            if not entry:
                continue
            if '-' not in entry:
                ranges.append((seconds_to_frame(float(entry)), float('inf')))
                continue
            start_sec, end_sec = (float(part) for part in entry.split('-'))
            ranges.append((seconds_to_frame(start_sec), seconds_to_frame(end_sec)))
    return ranges

# Check if a given frame is in any swap range
def should_swap(frame, swap_ranges):
    """Return True when *frame* lies inside any inclusive ``(start, end)`` range."""
    for lo, hi in swap_ranges:
        if lo <= frame <= hi:
            return True
    return False

# Swap ID 1 and 2 in the tracking CSV
def correct_csv(csv_path, swap_ranges):
    """Load a tracking CSV and swap IDs 1 <-> 2 on every frame inside *swap_ranges*.

    The swap is ``id -> 3 - id``, applied once per row whose ``frame`` lies in
    any inclusive ``(start, end)`` range, even if the ranges overlap. Rows
    with a missing (NaN) id are left untouched. Box coordinates are not
    changed.

    Args:
        csv_path: Tracking CSV with ``frame`` and ``id`` columns.
        swap_ranges: Iterable of ``(start_frame, end_frame)`` pairs; ``end``
            may be ``float('inf')``.

    Returns:
        The corrected DataFrame. The file on disk is not modified.
    """
    df = pd.read_csv(csv_path)
    # Vectorised "frame is in any range" mask instead of a per-row loop.
    in_swap = pd.Series(False, index=df.index)
    for start, end in swap_ranges:
        in_swap |= (df['frame'] >= start) & (df['frame'] <= end)
    to_flip = in_swap & df['id'].notna()
    df.loc[to_flip, 'id'] = 3 - df.loc[to_flip, 'id']  # Swap 1 <-> 2
    return df

# Swap "ID 1 is on the left/right" lines in log files
def correct_log(log_path, swap_ranges):
    """Return the lines of an ID-1 side log with sides flipped inside *swap_ranges*.

    Lines containing ``Frame N: ID 1 is on the <word>`` are rebuilt from the
    captured frame number and side word; anything else on such a line is
    dropped. When N lies in an inclusive ``(start, end)`` range, ``left``
    becomes ``right`` and any other word becomes ``left``. Non-matching lines
    are returned stripped. Returned lines have no line terminators.
    """
    with open(log_path, 'r') as f:
        raw_lines = f.readlines()

    result = []
    for raw in raw_lines:
        hit = re.search(r'Frame (\d+): ID 1 is on the (\w+)', raw)
        if hit is None:
            result.append(raw.strip())
            continue
        frame_no = int(hit.group(1))
        side = hit.group(2)
        if any(lo <= frame_no <= hi for lo, hi in swap_ranges):
            side = 'left' if side != 'left' else 'right'
        result.append(f"Frame {frame_no}: ID 1 is on the {side}")
    return result

# Process one folder
def process_folder(folder, fps=24):
    """Write ID-swap-corrected copies of one folder's tracking CSV and side logs.

    Requires ``close_frames.txt`` and a ``<base>_tracking.csv`` in *folder*
    (otherwise a message is printed and nothing is written). It writes
    ``<base>_tracking_corrected.csv`` and, for each existing
    ``<base>_id1_{x1,x2}_position.txt``, a ``..._corrected.txt`` sibling.

    Args:
        folder: One video's output folder.
        fps: Frame rate used to turn close_frames.txt seconds into frames.
    """
    print(f"Processing: {folder}")
    close_path = os.path.join(folder, "close_frames.txt")
    if not os.path.exists(close_path):
        print("  Skipping (no close_frames.txt found)")
        return

    tracking_csvs = (name for name in os.listdir(folder) if name.endswith("_tracking.csv"))
    first_csv = next(tracking_csvs, None)
    base = None if first_csv is None else first_csv.replace("_tracking.csv", "")
    if not base:
        print("  Skipping (no tracking CSV found)")
        return

    swap_ranges = parse_swap_file(close_path, fps)

    # Corrected copy of the tracking CSV.
    corrected_df = correct_csv(os.path.join(folder, f"{base}_tracking.csv"), swap_ranges)
    corrected_df.to_csv(os.path.join(folder, f"{base}_tracking_corrected.csv"), index=False)

    # Corrected copies of whichever side logs exist.
    for axis in ("x1", "x2"):
        txt_path = os.path.join(folder, f"{base}_id1_{axis}_position.txt")
        if not os.path.exists(txt_path):
            continue
        fixed_lines = correct_log(txt_path, swap_ranges)
        out_path = os.path.join(folder, f"{base}_id1_{axis}_position_corrected.txt")
        with open(out_path, 'w') as out_f:
            out_f.write("\n".join(fixed_lines))

# Process all video folders
# NOTE(review): the close_frames.txt creation cell targets
# D:\DeepSORT_ML2025\video_stitch\cam10_cam11_tracked — confirm this default
# root points at the same tree.
def main(root_dir=r"D:\tracked_output", fps=24):
    """Run ``process_folder`` on every immediate subfolder of *root_dir*.

    Subfolders are visited in sorted order so runs are reproducible.

    Args:
        root_dir: Folder whose subfolders each hold one video's outputs.
        fps: Frame rate used to convert close_frames.txt seconds to frames.
    """
    for subfolder in sorted(os.listdir(root_dir)):
        full_path = os.path.join(root_dir, subfolder)
        if os.path.isdir(full_path):
            process_folder(full_path, fps)


if __name__ == "__main__":
    main()


### Creating corrected output files (swapping IDs in the CSVs and side logs)

In [None]:
import os
import pandas as pd
import numpy as np
import cv2
import re

def load_switch_intervals(txt_path, total_frames, fps=24):
    """Parse close_frames.txt into inclusive ``(start_frame, end_frame)`` intervals.

    Each non-blank line is ``start-end`` or a single ``start`` (seconds; a
    single value runs to the last frame). Seconds become frames via
    ``int(seconds * fps)``. Starts are clamped to >= 0 and ends to
    <= ``total_frames - 1``. Intervals whose start ends up past their end
    are kept as-is.
    """
    last_frame = total_frames - 1
    with open(txt_path, 'r') as handle:
        entries = [text for text in (raw.strip() for raw in handle) if text]

    intervals = []
    for entry in entries:
        if '-' in entry:
            lo_text, hi_text = entry.split('-')
            lo = int(float(lo_text) * fps)
            hi = int(float(hi_text) * fps)
        else:
            lo = int(float(entry) * fps)
            hi = last_frame
        intervals.append((max(0, lo), min(last_frame, hi)))
    return intervals

def get_fps_and_framecount(video_path):
    """Return ``(fps, frame_count)`` as reported by OpenCV for *video_path*."""
    capture = cv2.VideoCapture(video_path)
    frame_rate = capture.get(cv2.CAP_PROP_FPS)
    frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    capture.release()
    return frame_rate, frame_count

def swap_ids_in_df(df, intervals):
    """Return a copy of *df* with track IDs 1 and 2 exchanged inside *intervals*.

    Relabelling the ``id`` column is the whole correction: every row keeps its
    own box, so the box that was labelled ID 1 becomes ID 2 and vice versa.
    (Additionally swapping the box coordinates between the two rows would
    undo the relabel.) Rows with other IDs are left alone.

    Args:
        df: Tracking table with at least ``frame`` and ``id`` columns.
        intervals: Inclusive ``(start_frame, end_frame)`` pairs. Intervals are
            applied in turn, so a frame covered by two intervals is swapped
            twice, i.e. ends up unchanged.

    Returns:
        The corrected copy; *df* itself is not modified.
    """
    df = df.copy()
    for start, end in intervals:
        in_range = (df['frame'] >= start) & (df['frame'] <= end)
        # Build both masks before assigning so the second assignment does not
        # see rows already relabelled by the first.
        was_one = in_range & (df['id'] == 1)
        was_two = in_range & (df['id'] == 2)
        df.loc[was_one, 'id'] = 2
        df.loc[was_two, 'id'] = 1
    return df

def swap_sides_in_txt(lines, intervals, fps=24):
    """Flip left/right on ``Frame N: ID 1 is on the <side>`` lines inside *intervals*.

    A line is changed only if it starts with that pattern and N lies in an
    inclusive ``(start, end)`` interval. In that case every ``left``/``right``
    substring in the line is replaced with the flipped side. All other lines,
    and line endings, are kept unchanged. *fps* is accepted but not used.
    """
    pattern = re.compile(r"Frame\s+(\d+):\s+ID\s+1\s+is\s+on\s+the\s+(left|right)")

    def flip(text):
        found = pattern.match(text)
        if not found:
            return text
        frame_num = int(found.group(1))
        if not any(start <= frame_num <= end for start, end in intervals):
            return text
        opposite = 'right' if found.group(2) == 'left' else 'left'
        return re.sub(r"(left|right)", opposite, text)

    return [flip(line) for line in lines]

def process_folder(sub_path):
    """Write ``*_corrected`` copies of a folder's CSVs and TXT logs with IDs 1/2 swapped.

    Steps:

    1. Read fps and frame count from the first ``.mp4`` in *sub_path*, in
       sorted name order.
    2. Convert ``close_frames.txt`` into frame intervals.
    3. For every CSV with ``frame, id, x1, y1, x2, y2`` columns, write
       ``<name>_corrected.csv`` via ``swap_ids_in_df``.
    4. For every other TXT, write ``<name>_corrected.txt`` via
       ``swap_sides_in_txt``.

    ``close_frames.txt`` and files already ending in ``_corrected`` are never
    processed. On any missing prerequisite a message is printed and the folder
    is skipped.
    """
    print(f"\nProcessing folder: {sub_path}")
    # Snapshot and sort the listing so the chosen video is deterministic.
    entries = sorted(os.listdir(sub_path))

    videos = [f for f in entries if f.endswith('.mp4')]
    if not videos:
        print("  No video file found, skipping.")
        return
    video_file = videos[0]
    fps, total_frames = get_fps_and_framecount(os.path.join(sub_path, video_file))
    # OpenCV may report 0 when it cannot read the metadata. With fps == 0
    # every timestamp maps to frame 0, and an open-ended entry would swap the
    # whole video, so refuse to guess.
    if fps <= 0 or total_frames <= 0:
        print(f"  Could not read fps/frame count from {video_file}, skipping.")
        return

    close_frames_path = os.path.join(sub_path, 'close_frames.txt')
    if not os.path.exists(close_frames_path):
        print("  No close_frames.txt found, skipping.")
        return

    intervals = load_switch_intervals(close_frames_path, total_frames, fps)
    if not intervals:
        print("  No intervals to switch found.")
        return
    print(f"  Switch intervals (frames): {intervals}")

    # Process CSV files
    required_cols = {'frame', 'id', 'x1', 'y1', 'x2', 'y2'}
    csv_files = [f for f in entries if f.endswith('.csv') and not f.endswith('_corrected.csv')]
    for csv_file in csv_files:
        df = pd.read_csv(os.path.join(sub_path, csv_file))
        if not required_cols.issubset(df.columns):
            print(f"  Skipping {csv_file}, missing required columns.")
            continue
        print(f"  Processing CSV: {csv_file}")
        df_corrected = swap_ids_in_df(df, intervals)
        corrected_path = os.path.join(sub_path, _corrected_name(csv_file))
        df_corrected.to_csv(corrected_path, index=False)
        print(f"    Saved corrected CSV: {corrected_path}")

    # Process TXT files
    txt_files = [f for f in entries
                 if f.endswith('.txt') and not f.endswith('_corrected.txt') and 'close_frames' not in f]
    for txt_file in txt_files:
        with open(os.path.join(sub_path, txt_file), 'r') as f:
            lines = f.readlines()
        print(f"  Processing TXT: {txt_file}")
        lines_corrected = swap_sides_in_txt(lines, intervals, fps)
        corrected_txt_path = os.path.join(sub_path, _corrected_name(txt_file))
        with open(corrected_txt_path, 'w') as f:
            f.writelines(lines_corrected)
        print(f"    Saved corrected TXT: {corrected_txt_path}")


def _corrected_name(filename):
    """Return *filename* with ``_corrected`` inserted before its extension only.

    Unlike ``str.replace``, this never touches an interior ``.csv``/``.txt``.
    """
    stem, ext = os.path.splitext(filename)
    return f"{stem}_corrected{ext}"

def main():
    """Write ID-swap-corrected copies of the tracking files in every sub-folder of the root.

    Sub-folders are visited in sorted name order; plain files directly inside the
    root are ignored. Each folder is handed to process_folder.
    """
    # Change this to your root folder
    tracked_root = r"D:\DeepSORT_ML2025\video_stitch\cam10_cam11_tracked"
    candidate_paths = (os.path.join(tracked_root, entry) for entry in sorted(os.listdir(tracked_root)))
    for folder_path in candidate_paths:
        if not os.path.isdir(folder_path):
            continue
        process_folder(folder_path)

# Entry point: run the ID-swap correction batch when this cell/module is executed directly.
if __name__ == "__main__":
    main()


### Tracking cam18 using the cam10/cam11 ID guide

In [None]:
import os
import cv2
import numpy as np
import csv
from ultralytics import YOLO
from filterpy.kalman import KalmanFilter
from scipy.optimize import linear_sum_assignment


def load_guide_txt(txt_path):
    """Parse a side-guide text file into ``{frame_number: side}``.

    Only lines containing both ``"Frame"`` and ``"ID 1 is on the"`` are used. The
    frame number is the last whitespace-separated token before the first ``:``;
    the side is the last token between the first and second ``:``, lower-cased.
    For example ``"Frame 12: ID 1 is on the Left"`` yields ``{12: 'left'}``.
    A later line for the same frame overwrites an earlier one.
    """
    guide = {}
    with open(txt_path, 'r') as handle:
        for raw in handle:
            if "Frame" not in raw or "ID 1 is on the" not in raw:
                continue
            fields = raw.strip().split(":")
            frame_no = int(fields[0].split()[-1])
            guide[frame_no] = fields[1].strip().split()[-1].lower()
    return guide


class KalmanBoxTracker:
    """Kalman filter wrapper for one bounding box carrying a fixed ID.

    State (7 x 1): ``[cx, cy, s, r, v_cx, v_cy, v_s]`` -- box centre, area ``s``,
    aspect ratio ``r = w / h`` and per-step velocities of cx, cy and s (``r`` has
    no velocity term). The measurement is ``[cx, cy, s, r]``.
    """

    @staticmethod
    def _bbox_to_z(bbox):
        """Return ``[cx, cy, area, w/h]`` as a (4, 1) column for an ``[x1, y1, x2, y2, ...]`` box."""
        box_w = bbox[2] - bbox[0]
        box_h = bbox[3] - bbox[1]
        centre_x = (bbox[0] + bbox[2]) / 2
        centre_y = (bbox[1] + bbox[3]) / 2
        # 1e-6 guards against a zero-height box.
        return np.array([centre_x, centre_y, box_w * box_h, box_w / (box_h + 1e-6)]).reshape((4, 1))

    def __init__(self, bbox, assigned_id):
        kf = self.kf = KalmanFilter(dim_x=7, dim_z=4)

        transition = np.eye(7)
        for pos_idx, vel_idx in ((0, 4), (1, 5), (2, 6)):
            transition[pos_idx, vel_idx] = 1.0  # position += velocity each step
        kf.F = transition
        kf.H = np.eye(4, 7)  # the first four state entries are observed directly

        kf.P[4:, 4:] *= 1000.  # velocities start out very uncertain
        kf.P *= 10.
        kf.Q *= 0.001

        kf.x[:4] = self._bbox_to_z(bbox)

        self.id = assigned_id
        self.hits = 1
        self.time_since_update = 0
        self.last_yolo_box = bbox

    def predict(self):
        """Advance the filter one step and return the predicted ``[x1, y1, x2, y2]``."""
        self.kf.predict()
        self.time_since_update += 1
        return self.get_state()

    def update(self, bbox):
        """Correct the filter with a detected box and remember it as the latest YOLO box."""
        self.time_since_update = 0
        self.hits += 1
        self.kf.update(self._bbox_to_z(bbox))
        self.last_yolo_box = bbox

    def get_state(self):
        """Return the current estimate as ``[x1, y1, x2, y2]``; ``[0, 0, 0, 0]`` if ``s * r <= 0``."""
        cx, cy, s, r = self.kf.x[:4].reshape(-1)
        if s * r <= 0:
            return [0, 0, 0, 0]
        w = np.sqrt(abs(s * r))
        h = abs(s) / (w + 1e-6)
        half_w, half_h = w / 2, h / 2
        return [cx - half_w, cy - half_h, cx + half_w, cy + half_h]


class SORT:
    """Guide-driven ID labeller for up to two animals (cam18 variant).

    Despite the name, no frame-to-frame association is done. On every frame that
    has a guide entry and at least one detection, the trackers are rebuilt from
    scratch. The (at most two) most confident detections are then labelled from
    their left-to-right order (by x1) and the side the guide gives for ID 1.
    """

    def __init__(self, guide):
        self.trackers = []
        self.frame_id = 0
        self.guide = guide

    def update(self, detections, width):
        """Label this frame's detections.

        Args:
            detections: sequence of ``[x1, y1, x2, y2, conf, ...]``.
            width: frame width.
                NOTE(review): currently unused, so a lone detection gets ID 1
                whenever the guide says 'left', regardless of where it is.

        Returns:
            List of ``[x1, y1, x2, y2, id]``. Empty when the frame has no guide
            entry or no detections.
        """
        self.frame_id += 1

        # Keep the two most confident detections only.
        top_two = sorted(detections, key=lambda det: det[4], reverse=True)[:2]
        boxes = [det[:4] for det in top_two]

        # After the increment, frame_id - 1 is the 0-based index of this frame.
        side = self.guide.get(self.frame_id - 1)
        if not (side and boxes):
            return []

        self.trackers = []
        labelled = []
        if len(boxes) == 1:
            track_id = 1 if side == 'left' else 2
            self.trackers = [KalmanBoxTracker(boxes[0], track_id)]
            labelled.append([*boxes[0], track_id])
        elif len(boxes) == 2:
            left_to_right = sorted(boxes, key=lambda b: b[0])
            id_order = [1, 2] if side == 'left' else [2, 1]
            for box, track_id in zip(left_to_right, id_order):
                self.trackers.append(KalmanBoxTracker(box, track_id))
                labelled.append([*box, track_id])

        return labelled


def track_video(video_path, guide_path, output_dir, model):
    """Detect marmosets with YOLO and label them ID 1 / ID 2 from a side guide (cam18).

    Args:
        video_path: input video file.
        guide_path: guide TXT parsed by load_guide_txt (frame -> side of ID 1).
        output_dir: created if missing. Receives two files:
            - ``<video stem>_tracked.mp4``: boxes drawn in OpenCV BGR, (255, 0, 0)
              for ID 1 and (0, 255, 0) for ID 2.
            - ``<video stem>_tracking.csv``: columns frame, id, x1, y1, x2, y2,
              with frames numbered from 0.
        model: YOLO model called once per frame. Only class-0 detections with
            confidence > 0.7 are kept.

    If the video cannot be opened, a message is printed and nothing is written.
    The capture, the video writer and the CSV file are released/closed even if
    detection raises part-way through.
    """
    os.makedirs(output_dir, exist_ok=True)
    guide = load_guide_txt(guide_path)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Cannot open video: {video_path}")
        return

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    base = os.path.splitext(os.path.basename(video_path))[0]
    out_path = os.path.join(output_dir, f"{base}_tracked.mp4")
    csv_path = os.path.join(output_dir, f"{base}_tracking.csv")

    out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    tracker = SORT(guide)

    try:
        with open(csv_path, 'w', newline='') as csv_file:
            csv_writer = csv.writer(csv_file)
            csv_writer.writerow(['frame', 'id', 'x1', 'y1', 'x2', 'y2'])

            frame_idx = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                results = model(frame, verbose=False)[0]
                # Each row of boxes.data is unpacked as [*xyxy, conf, cls].
                dets = [[*xyxy, conf] for *xyxy, conf, cls in results.boxes.data.tolist()
                        if conf > 0.7 and int(cls) == 0]

                tracks = tracker.update(dets, width)

                for x1, y1, x2, y2, tid in tracks:
                    x1i, y1i, x2i, y2i = map(int, [x1, y1, x2, y2])
                    color = (255, 0, 0) if tid == 1 else (0, 255, 0)
                    cv2.rectangle(frame, (x1i, y1i), (x2i, y2i), color, 2)
                    cv2.putText(frame, f'ID {tid}', (x1i, y1i - 10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
                    csv_writer.writerow([frame_idx, tid, x1, y1, x2, y2])

                out.write(frame)
                frame_idx += 1
    finally:
        cap.release()
        out.release()

    print(f"✓ Finished: {video_path}")


def main():
    """Track cam15/cam18 videos in the cam18 folder, using the cam10/cam11 x1 ID guides.

    Only ``.mp4`` files whose name contains "cam15" or "cam18" are processed. The
    guide is ``<guide root>/<base>/<base>_id1_x1_position_corrected.txt``; videos
    without a guide are reported and skipped. Output goes to
    ``<output dir>/<base up to "_from_">/``.
    """
    video_dir = r"D:\DeepSORT_ML2025\video_stitch\cam18"
    guide_root_dir = r"D:\DeepSORT_ML2025\video_stitch\cam10_cam11_tracked"
    output_dir = r"D:\DeepSORT_ML2025\video_stitch\cam18_tracked"
    weights_path = r"D:\DeepSORT_ML2025\yolo_project\marmoset_yolo_20250708_105031\weights\best.pt"
    model = YOLO(weights_path)

    # Applied in this order to derive the guide folder/base name from the video file name.
    name_rewrites = (("_cam15_", ""), ("_cam18_", ""), ("__", "_"), ("--", "-"))

    for fname in os.listdir(video_dir):
        wanted = fname.endswith(".mp4") and ("cam15" in fname or "cam18" in fname)
        if not wanted:
            continue

        full_path = os.path.join(video_dir, fname)
        stripped = fname
        for old, new in name_rewrites:
            stripped = stripped.replace(old, new)
        base_name = os.path.splitext(stripped)[0]

        guide_path = os.path.join(guide_root_dir, base_name, f"{base_name}_id1_x1_position_corrected.txt")
        if not os.path.exists(guide_path):
            print(f"⚠️ Guide not found: {guide_path}")
            continue

        # Drop suffixes such as "_from_18_pre" / "_from_18_post" for the output folder name.
        video_output_dir = os.path.join(output_dir, base_name.split("_from_")[0])
        track_video(full_path, guide_path, video_output_dir, model)

# Entry point: run the cam18 tracking batch when this cell/module is executed directly.
if __name__ == "__main__":
    main()


### Tracking cam15 using the cam10/cam11 ID guide

In [None]:
import os
import cv2
import numpy as np
import csv
from ultralytics import YOLO
from filterpy.kalman import KalmanFilter
from scipy.optimize import linear_sum_assignment

# Load guide TXT 
def load_guide_txt(txt_path):
    """Read a side-guide text file into ``{frame_number: side}``.

    A line counts only if it contains both ``"Frame"`` and ``"ID 1 is on the"``.
    The frame number is the last token before the first ``:``. The side is the
    last token of the next ``:``-separated field, lower-cased. For example
    ``"Frame 7: ID 1 is on the Right"`` gives ``{7: 'right'}``. Duplicate frames
    keep the last value.
    """
    def parse(line):
        fields = line.strip().split(":")
        return int(fields[0].split()[-1]), fields[1].strip().split()[-1].lower()

    with open(txt_path, 'r') as handle:
        return dict(parse(line) for line in handle
                    if "Frame" in line and "ID 1 is on the" in line)

# Kalman Tracker 
class KalmanBoxTracker:
    """Kalman filter wrapper for one bounding box carrying a fixed ID.

    State (7 x 1): ``[cx, cy, s, r, v_cx, v_cy, v_s]`` -- box centre, area ``s``,
    aspect ratio ``r = w / h`` and per-step velocities of cx, cy and s (``r`` has
    no velocity term). The measurement is ``[cx, cy, s, r]``.
    """

    @staticmethod
    def _bbox_to_z(bbox):
        """Return ``[cx, cy, area, w/h]`` as a (4, 1) column for an ``[x1, y1, x2, y2, ...]`` box."""
        box_w = bbox[2] - bbox[0]
        box_h = bbox[3] - bbox[1]
        centre_x = (bbox[0] + bbox[2]) / 2
        centre_y = (bbox[1] + bbox[3]) / 2
        # 1e-6 guards against a zero-height box.
        return np.array([centre_x, centre_y, box_w * box_h, box_w / (box_h + 1e-6)]).reshape((4, 1))

    def __init__(self, bbox, assigned_id):
        kf = self.kf = KalmanFilter(dim_x=7, dim_z=4)

        transition = np.eye(7)
        for pos_idx, vel_idx in ((0, 4), (1, 5), (2, 6)):
            transition[pos_idx, vel_idx] = 1.0  # position += velocity each step
        kf.F = transition
        kf.H = np.eye(4, 7)  # the first four state entries are observed directly

        kf.P[4:, 4:] *= 1000.  # velocities start out very uncertain
        kf.P *= 10.
        kf.Q *= 0.001

        kf.x[:4] = self._bbox_to_z(bbox)

        self.id = assigned_id
        self.hits = 1
        self.time_since_update = 0
        self.last_yolo_box = bbox

    def predict(self):
        """Advance the filter one step and return the predicted ``[x1, y1, x2, y2]``."""
        self.kf.predict()
        self.time_since_update += 1
        return self.get_state()

    def update(self, bbox):
        """Correct the filter with a detected box and remember it as the latest YOLO box."""
        self.time_since_update = 0
        self.hits += 1
        self.kf.update(self._bbox_to_z(bbox))
        self.last_yolo_box = bbox

    def get_state(self):
        """Return the current estimate as ``[x1, y1, x2, y2]``; ``[0, 0, 0, 0]`` if ``s * r <= 0``."""
        cx, cy, s, r = self.kf.x[:4].reshape(-1)
        if s * r <= 0:
            return [0, 0, 0, 0]
        w = np.sqrt(abs(s * r))
        h = abs(s) / (w + 1e-6)
        half_w, half_h = w / 2, h / 2
        return [cx - half_w, cy - half_h, cx + half_w, cy + half_h]

#SORT Tracker 
class SORT:
    """Guide-driven ID labeller for up to two animals (cam15 variant).

    No frame-to-frame association is done. On every frame with a guide entry and
    at least one detection, the trackers are rebuilt from scratch. The (at most
    two) most confident detections are labelled from their right-to-left order
    (by x2, largest first) and the side the guide gives for ID 1.
    """

    def __init__(self, guide):
        self.trackers = []
        self.frame_id = 0
        self.guide = guide

    def update(self, detections):
        """Label this frame's detections.

        Args:
            detections: sequence of ``[x1, y1, x2, y2, conf, ...]``.

        Returns:
            List of ``[x1, y1, x2, y2, id]``. Empty when the frame has no guide
            entry or no detections. A lone detection gets ID 1 iff the guide says
            'right' (its own position is not consulted).
        """
        self.frame_id += 1

        # Keep the two most confident detections only.
        top_two = sorted(detections, key=lambda det: det[4], reverse=True)[:2]
        boxes = [det[:4] for det in top_two]

        # After the increment, frame_id - 1 is the 0-based index of this frame.
        side = self.guide.get(self.frame_id - 1)
        if not (side and boxes):
            return []

        self.trackers = []
        labelled = []
        if len(boxes) == 1:
            track_id = 1 if side == 'right' else 2
            self.trackers = [KalmanBoxTracker(boxes[0], track_id)]
            labelled.append([*boxes[0], track_id])
        elif len(boxes) == 2:
            right_to_left = sorted(boxes, key=lambda b: b[2], reverse=True)  # by x2
            id_order = [1, 2] if side == 'right' else [2, 1]
            for box, track_id in zip(right_to_left, id_order):
                self.trackers.append(KalmanBoxTracker(box, track_id))
                labelled.append([*box, track_id])

        return labelled

# Run tracking on one video
def track_video(video_path, guide_path, output_dir, model):
    """Detect marmosets with YOLO and label them ID 1 / ID 2 from a side guide (cam15).

    Args:
        video_path: input video file.
        guide_path: guide TXT parsed by load_guide_txt (frame -> side of ID 1).
        output_dir: created if missing. Receives two files:
            - ``<video stem>_tracked.mp4``: boxes drawn in OpenCV BGR, (255, 0, 0)
              for ID 1 and (0, 255, 0) for ID 2.
            - ``<video stem>_tracking.csv``: columns frame, id, x1, y1, x2, y2,
              with frames numbered from 0.
        model: YOLO model called once per frame. Only class-0 detections with
            confidence > 0.7 are kept.

    If the video cannot be opened, a message is printed and nothing is written.
    The capture, the video writer and the CSV file are released/closed even if
    detection raises part-way through.
    """
    os.makedirs(output_dir, exist_ok=True)
    guide = load_guide_txt(guide_path)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Cannot open video: {video_path}")
        return

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    base = os.path.splitext(os.path.basename(video_path))[0]
    out_path = os.path.join(output_dir, f"{base}_tracked.mp4")
    csv_path = os.path.join(output_dir, f"{base}_tracking.csv")

    out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    tracker = SORT(guide)

    try:
        with open(csv_path, 'w', newline='') as csv_file:
            csv_writer = csv.writer(csv_file)
            csv_writer.writerow(['frame', 'id', 'x1', 'y1', 'x2', 'y2'])

            frame_idx = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                results = model(frame, verbose=False)[0]
                # Each row of boxes.data is unpacked as [*xyxy, conf, cls].
                dets = [[*xyxy, conf] for *xyxy, conf, cls in results.boxes.data.tolist()
                        if conf > 0.7 and int(cls) == 0]

                tracks = tracker.update(dets)

                for x1, y1, x2, y2, tid in tracks:
                    x1i, y1i, x2i, y2i = map(int, [x1, y1, x2, y2])
                    color = (255, 0, 0) if tid == 1 else (0, 255, 0)
                    cv2.rectangle(frame, (x1i, y1i), (x2i, y2i), color, 2)
                    cv2.putText(frame, f'ID {tid}', (x1i, y1i - 10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
                    csv_writer.writerow([frame_idx, tid, x1, y1, x2, y2])

                out.write(frame)
                frame_idx += 1
    finally:
        cap.release()
        out.release()

    print(f"✓ Finished: {video_path}")

#Batch runner 
def main():
    """Track cam15/cam18 videos in the cam15 folder, using the cam10/cam11 x2 ID guides.

    Only ``.mp4`` files whose name contains "cam15" or "cam18" are processed. The
    guide is ``<guide root>/<base>/<base>_id1_x2_position_corrected.txt``; videos
    without a guide are reported and skipped. Output goes to
    ``<output dir>/<base up to "_from_">/``.
    """
    video_dir = r"D:\DeepSORT_ML2025\video_stitch\cam15"  # <-- replace with your video path
    guide_root_dir = r"D:\DeepSORT_ML2025\video_stitch\cam10_cam11_tracked"  # <-- base folder with guide subfolders
    output_dir = r"D:\DeepSORT_ML2025\video_stitch\cam15_tracked"
    weights_path = r"D:\DeepSORT_ML2025\yolo_project\marmoset_yolo_20250708_105031\weights\best.pt"  # <-- replace

    model = YOLO(weights_path)

    # Applied in this order to derive the guide folder/base name from the video file name.
    name_rewrites = (("_cam15_", ""), ("_cam18_", ""), ("__", "_"), ("--", "-"))

    for fname in os.listdir(video_dir):
        wanted = fname.endswith(".mp4") and ("cam15" in fname or "cam18" in fname)
        if not wanted:
            continue

        full_path = os.path.join(video_dir, fname)
        stripped = fname
        for old, new in name_rewrites:
            stripped = stripped.replace(old, new)
        base_name = os.path.splitext(stripped)[0]

        guide_path = os.path.join(guide_root_dir, base_name, f"{base_name}_id1_x2_position_corrected.txt")
        if not os.path.exists(guide_path):
            print(f"⚠️ Guide not found: {guide_path}")
            continue

        # Drop suffixes such as "_from_18_pre" / "_from_18_post" for the output folder name.
        video_output_dir = os.path.join(output_dir, base_name.split("_from_")[0])
        track_video(full_path, guide_path, video_output_dir, model)

# Entry point: run the cam15 tracking batch when this cell/module is executed directly.
if __name__ == "__main__":
    main()


### Re-transforming cam11 box coordinates to correct the stitching transform

In [None]:
import os
import pandas as pd
import numpy as np


# Root folder searched recursively for "*cam11.csv" files to correct.
root_dir = r"D:\DeepSORT_ML2025\video_stitch\cam10_cam11_tracked"  # CHANGE this to your root folder
H, W = 1080, 1920  # Video height and width


# 3x3 homogeneous affine transform (last row [0, 0, 1]). apply_combined_transform
# maps box corners through the inverse of this matrix to undo the transform that
# was previously applied to the stored cam11 coordinates.
inverse_transform_matrix = np.array([
    [ 1.05403131,  0.12941877, -51.0827529],
    [-0.14624321,  1.19105538, 136.228942],
    [ 0.0,         0.0,          1.0     ]
])


def get_correct_transform_matrix(shape):
    """Build the corrected 2x3 affine transform for an ``(H, W)`` frame.

    Three reference points of the right half of the frame are mapped to
    hand-tuned shifted positions:
        - top-middle    ``(W//2, 0)``   -> ``(W//2 + 30, 5)``
        - top-right     ``(W-1, 0)``    -> ``(W + 45, 0)``
        - bottom-middle ``(W//2, H-1)`` -> ``(W//2 + 45, H)``

    Args:
        shape: ``(H, W)``. These are local values; the module-level H/W are not used.

    Returns:
        The 2x3 matrix produced by ``cv2.getAffineTransform``.
    """
    # This cell imports only os/pandas/numpy, so import cv2 locally to let the
    # cell run on its own (e.g. in a fresh kernel).
    import cv2

    H, W = shape
    src_pts = np.float32([[W//2, 0], [W-1, 0], [W//2, H-1]])
    dst_pts = np.float32([[W//2 + 30, 5], [W + 45, 0], [W//2 + 45, H]])
    return cv2.getAffineTransform(src_pts, dst_pts)

# 2x3 affine matrix for the 1080x1920 frame, built once at import time and
# applied by apply_combined_transform.
correct_transform_matrix = get_correct_transform_matrix((H, W))


def apply_combined_transform(x1, y1, x2, y2):
    """Move a box's two stored corners from the old stitching transform to the corrected one.

    Each corner ``(x, y)`` is first mapped through ``inv(inverse_transform_matrix)``
    (undoing the earlier transform) and then through the 2x3
    ``correct_transform_matrix``. Both matrices are module-level globals.

    Returns:
        ``[x1c, y1c, x2c, y2c]``: the transformed (x1, y1) and (x2, y2) corners.
        Only these two corners are mapped, so under shear the result is not
        necessarily the axis-aligned bounding box of the transformed rectangle.
    """
    # Invert once per call instead of once per corner; the result is identical.
    undo_matrix = np.linalg.inv(inverse_transform_matrix)

    def apply(M, x, y):
        out = M @ np.array([x, y, 1.0])
        return out[0], out[1]

    # Undo previous inverse transform
    x1u, y1u = apply(undo_matrix, x1, y1)
    x2u, y2u = apply(undo_matrix, x2, y2)

    # Apply correct transform
    x1c, y1c = apply(correct_transform_matrix, x1u, y1u)
    x2c, y2c = apply(correct_transform_matrix, x2u, y2u)

    return [x1c, y1c, x2c, y2c]


def correct_csv(filepath):
    """Re-transform the x1/y1/x2/y2 columns of one CSV and overwrite the file.

    Rows with any missing box coordinate get NaN in all four columns. Files
    without all four box columns are left untouched. Any error is printed rather
    than raised, so a batch over many files keeps going.
    """
    box_cols = ['x1', 'y1', 'x2', 'y2']
    try:
        df = pd.read_csv(filepath)
        if not set(box_cols) <= set(df.columns):
            print(f"Skipping {filepath} — missing box columns.")
            return

        fixed_rows = [
            [np.nan] * 4
            if any(pd.isna(row[col]) for col in box_cols)
            else apply_combined_transform(row['x1'], row['y1'], row['x2'], row['y2'])
            for _, row in df.iterrows()
        ]
        df[box_cols] = np.array(fixed_rows)
        df.to_csv(filepath, index=False)
        print(f"✔ Corrected: {filepath}")
    except Exception as e:
        print(f"❌ Error in {filepath}: {e}")


# Re-transform every file ending in "cam11.csv" anywhere under root_dir.
# WARNING: correct_csv overwrites each file in place with transformed values, so
# running this cell a second time applies the transform twice. Run it once, or
# keep a backup of the original CSVs.
for dirpath, _, filenames in os.walk(root_dir):
    for filename in filenames:
        if filename.endswith("cam11.csv"):
            csv_path = os.path.join(dirpath, filename)
            correct_csv(csv_path)


✔ Corrected: D:\DeepSORT_ML2025\video_stitch\cam10_cam11_tracked\JM_ind1post\JM_ind1post_cam11.csv
✔ Corrected: D:\DeepSORT_ML2025\video_stitch\cam10_cam11_tracked\JM_ind1pre\JM_ind1pre_cam11.csv
✔ Corrected: D:\DeepSORT_ML2025\video_stitch\cam10_cam11_tracked\JM_ind2post\JM_ind2post_cam11.csv
✔ Corrected: D:\DeepSORT_ML2025\video_stitch\cam10_cam11_tracked\JM_ind2pre\JM_ind2pre_cam11.csv
✔ Corrected: D:\DeepSORT_ML2025\video_stitch\cam10_cam11_tracked\JM_ind4post\JM_ind4post_cam11.csv
✔ Corrected: D:\DeepSORT_ML2025\video_stitch\cam10_cam11_tracked\JM_ind4pre\JM_ind4pre_cam11.csv
✔ Corrected: D:\DeepSORT_ML2025\video_stitch\cam10_cam11_tracked\JM_pro1post\JM_pro1post_cam11.csv
✔ Corrected: D:\DeepSORT_ML2025\video_stitch\cam10_cam11_tracked\JM_pro1pre\JM_pro1pre_cam11.csv
✔ Corrected: D:\DeepSORT_ML2025\video_stitch\cam10_cam11_tracked\JM_pro2post\JM_pro2post_cam11.csv
✔ Corrected: D:\DeepSORT_ML2025\video_stitch\cam10_cam11_tracked\JM_pro2pre\JM_pro2pre_cam11.csv
✔ Corrected: D:\Deep

### Masking Videos for DLC with padding

In [None]:
import os
import cv2
import numpy as np
import csv

def _load_mask_boxes(csv_path, width, height, wiggle):
    """Read per-frame, per-ID mask rectangles from a tracking CSV.

    Two layouts are supported:
        - Box columns (x1, y1, x2, y2): each box is padded by *wiggle* on every side.
        - Centre columns (cx, cy): each centre becomes a square of side 50 + 2*wiggle.

    Rows with an empty coordinate, and blank lines, are skipped. Frame/ID values
    such as "12.0" are accepted. Rectangles are clipped to the frame.

    Returns:
        ``{frame: {id: [[x1, y1, x2, y2], ...]}}``. The right/bottom edges are
        exclusive NumPy slice ends.
    """
    def clip(x1, y1, x2, y2):
        # Upper bounds are exclusive slice ends, so clip to width/height (not -1);
        # otherwise boxes touching the right/bottom edge lose their last pixel column/row.
        return [max(0, x1), max(0, y1), min(width, x2), min(height, y2)]

    boxes = {}

    def add(row, idx_frame, idx_id, rect):
        frame_i = int(float(row[idx_frame]))
        id_i = int(float(row[idx_id]))
        boxes.setdefault(frame_i, {}).setdefault(id_i, []).append(rect)

    with open(csv_path, 'r', newline='') as f:
        reader = csv.reader(f)
        header = next(reader)
        idx_frame = header.index('frame')
        idx_id = header.index('id')

        if all(x in header for x in ['x1', 'y1', 'x2', 'y2']):
            box_idx = [header.index(c) for c in ('x1', 'y1', 'x2', 'y2')]
            for row in reader:
                if not row or any(row[i] == '' for i in box_idx):
                    continue
                x1, y1, x2, y2 = (int(float(row[i])) for i in box_idx)
                add(row, idx_frame, idx_id, clip(x1 - wiggle, y1 - wiggle, x2 + wiggle, y2 + wiggle))
        else:
            idx_cx = header.index('cx')
            idx_cy = header.index('cy')
            half = (50 + 2 * wiggle) // 2
            for row in reader:
                if not row or row[idx_cx] == '' or row[idx_cy] == '':
                    continue
                cx = float(row[idx_cx])
                cy = float(row[idx_cy])
                add(row, idx_frame, idx_id,
                    clip(int(cx - half), int(cy - half), int(cx + half), int(cy + half)))
    return boxes


def draw_masked_boxes(video_path, csv_path, output_path_prefix, wiggle=7):
    """Write one masked copy of the video for each of ID 1 and ID 2.

    Every frame of ``<output_path_prefix>_ID1.mp4`` / ``_ID2.mp4`` is white,
    except inside that ID's rectangles for the frame, where the original pixels
    are kept. Rectangles are read from *csv_path* by _load_mask_boxes (box or
    centre layout, padded by *wiggle*).

    Prints a message and returns without writing anything if the video cannot be
    opened. The capture and both writers are released even if an error occurs.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Cannot open video: {video_path}")
        return
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    writers = []
    try:
        bboxes_by_frame_id = _load_mask_boxes(csv_path, width, height, wiggle)

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        for track_id, path in ((1, f'{output_path_prefix}_ID1.mp4'), (2, f'{output_path_prefix}_ID2.mp4')):
            writers.append((track_id, cv2.VideoWriter(path, fourcc, fps, (width, height))))

        frame_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            per_id = bboxes_by_frame_id.get(frame_idx, {})
            for track_id, writer in writers:
                masked = np.full_like(frame, 255)
                for x1, y1, x2, y2 in per_id.get(track_id, []):
                    masked[y1:y2, x1:x2] = frame[y1:y2, x1:x2]
                writer.write(masked)

            frame_idx += 1
    finally:
        cap.release()
        for _, writer in writers:
            writer.release()

    print(f"Finished masked videos for {os.path.basename(video_path)}")

def video_name_to_csv_folder_cam15_18(video_name, cam_num):
    """Map a cam15/cam18 video name (without extension) to its tracking-CSV folder name.

    Steps:
        1. When cam_num is 15 or 18, the matching ``_cam15_`` / ``_cam18_`` tag
           collapses to ``_``.
        2. A trailing ``_pre`` / ``_post`` loses its underscore.
        3. ``_from_<cam_num>_<suffix>`` is appended; the suffix is empty (leaving
           a trailing ``_``) when neither pre nor post is present.

    >>> video_name_to_csv_folder_cam15_18('JM_cam15_ind1_pre', 15)
    'JM_ind1pre_from_15_pre'
    """
    cam_tag = next((tag for num, tag in ((15, '_cam15_'), (18, '_cam18_')) if cam_num == num), None)
    base = video_name if cam_tag is None else video_name.replace(cam_tag, '_')

    suffix = ''
    for candidate in ('pre', 'post'):
        marker = '_' + candidate
        if base.endswith(marker):
            suffix = candidate
            base = base[:-len(marker)] + candidate
            break

    return f"{base}_from_{cam_num}_{suffix}"

def find_csv_folder_for_cam15_18(csv_root, video_name_wo_ext, cam_num):
    """Return ``<csv_root>/<expected folder>`` for a video if that path exists, else None.

    The folder name is computed by video_name_to_csv_folder_cam15_18.
    """
    candidate = os.path.join(csv_root, video_name_to_csv_folder_cam15_18(video_name_wo_ext, cam_num))
    if not os.path.exists(candidate):
        return None
    return candidate

def process_cam15_or_18(cam_dir, csv_root, output_root, cam_num):
    """Create per-ID masked videos for every ``.mp4`` in *cam_dir*.

    For each video, the tracking CSV is expected at
    ``<csv folder>/<video name>_tracking.csv``, where the folder comes from
    find_csv_folder_for_cam15_18. Outputs go to
    ``<output_root>/cam<cam_num>/<folder>/<folder>_ID{1,2}.mp4``.

    Videos are visited in sorted name order. A video with no matching folder or
    CSV is reported and skipped, and its output folder is not created.
    """
    for video_filename in sorted(os.listdir(cam_dir)):
        if not video_filename.endswith('.mp4'):
            continue
        name_wo_ext = os.path.splitext(video_filename)[0]
        video_path = os.path.join(cam_dir, video_filename)

        csv_folder = find_csv_folder_for_cam15_18(csv_root, name_wo_ext, cam_num)
        if csv_folder is None:
            print(f"No CSV folder found for {name_wo_ext} in {csv_root} for cam{cam_num}")
            continue

        folder_name = os.path.basename(csv_folder)
        csv_path = os.path.join(csv_folder, f"{name_wo_ext}_tracking.csv")

        if not (os.path.exists(video_path) and os.path.exists(csv_path)):
            print(f"Missing cam{cam_num} video or CSV for {folder_name}")
            print(f"  Expected video: {video_path}")
            print(f"  Expected CSV:   {csv_path}")
            continue

        # Create the output folder only once we know there is something to write,
        # so missing inputs no longer leave empty folders behind.
        out_folder = os.path.join(output_root, f"cam{cam_num}", folder_name)
        os.makedirs(out_folder, exist_ok=True)

        print(f"Processing cam{cam_num} video: {video_path}")
        draw_masked_boxes(video_path, csv_path, os.path.join(out_folder, folder_name))

def main():
    """Build padded per-ID masked videos for cam15 and then cam18."""
    root_videos = r"D:\DeepSORT_ML2025\video_stitch"
    output_root = r"D:\MaskedVideos_PerID_padded"

    # (banner, video sub-folder, tracking-CSV root, camera number)
    jobs = (
        ("Processing cam15...", 'cam15', r"D:\cam15_tracked_with_guide", 15),
        ("Processing cam18...", 'cam18', r"D:\cam18_tracked_with_guide", 18),
    )
    for banner, video_sub_dir, csv_root, cam in jobs:
        print(banner)
        process_cam15_or_18(os.path.join(root_videos, video_sub_dir), csv_root, output_root, cam_num=cam)

# Entry point: build the masked per-ID videos when this cell/module is executed directly.
if __name__ == "__main__":
    main()


### Cubic-spline interpolation of bounding-box centres

In [None]:
import os
import re
import pandas as pd
import numpy as np
from scipy.interpolate import CubicSpline

def interpolate_centers(csv_path, output_path):
    """Fill gaps in each ID's box-centre track with a natural cubic spline.

    Reads a tracking CSV with columns frame, id, x1, y1, x2, y2. Box centres
    ``cx``/``cy`` are computed, and ``frame, id, cx, cy`` is written to
    *output_path* (parent folders are created). Cases:

    * A required column is missing: the file is skipped and nothing is written.
    * The file contains exactly one ID: centres are written without interpolation.
    * Otherwise, each ID with at least 4 distinct frames having finite centres is
      resampled at unit frame steps between its first and last such frame.
      Centres of repeated frames are averaged first, because CubicSpline needs
      strictly increasing x. IDs with fewer such frames are written unchanged.
    * An input with no rows produces an output containing only the header.
    """
    df = pd.read_csv(csv_path)

    required_cols = ['frame', 'id', 'x1', 'y1', 'x2', 'y2']
    for c in required_cols:
        if c not in df.columns:
            print(f"Skipping {csv_path} - missing column {c}")
            return

    df['cx'] = (df['x1'] + df['x2']) / 2
    df['cy'] = (df['y1'] + df['y2']) / 2
    out_cols = ['frame', 'id', 'cx', 'cy']

    def save(result):
        out_dir = os.path.dirname(output_path)
        if out_dir:  # os.makedirs('') raises; a bare filename means the current directory
            os.makedirs(out_dir, exist_ok=True)
        result.to_csv(output_path, index=False)

    if len(df['id'].unique()) == 1:
        print(f"Only one ID in {csv_path}, skipping interpolation.")
        save(df[out_cols].copy())
        print(f"Saved center coordinates (no interpolation) to {output_path}")
        return

    interpolated_rows = []
    for obj_id, group in df.groupby('id'):
        group = group.sort_values('frame')

        frames = group['frame'].values
        cx_vals = group['cx'].values
        cy_vals = group['cy'].values

        valid_mask = np.isfinite(cx_vals) & np.isfinite(cy_vals) & np.isfinite(frames)

        # Collapse repeated frames to their mean centre. Without duplicates this
        # leaves the values unchanged (each count is 1).
        frames_valid, inverse = np.unique(frames[valid_mask], return_inverse=True)
        n_unique = len(frames_valid)
        counts = np.bincount(inverse, minlength=n_unique)
        cx_valid = np.bincount(inverse, weights=cx_vals[valid_mask], minlength=n_unique) / counts
        cy_valid = np.bincount(inverse, weights=cy_vals[valid_mask], minlength=n_unique) / counts

        if n_unique < 4:
            interpolated_rows.append(group[out_cols])
            continue

        full_frames = np.arange(frames_valid.min(), frames_valid.max() + 1)

        cs_x = CubicSpline(frames_valid, cx_valid, bc_type='natural')
        cs_y = CubicSpline(frames_valid, cy_valid, bc_type='natural')

        interpolated_rows.append(pd.DataFrame({
            'frame': full_frames,
            'id': obj_id,
            'cx': cs_x(full_frames),
            'cy': cs_y(full_frames),
        }))

    if interpolated_rows:
        result_df = pd.concat(interpolated_rows, ignore_index=True).sort_values(['frame', 'id'])
    else:
        # e.g. a header-only CSV: pd.concat([]) would raise, so write an empty table.
        result_df = pd.DataFrame(columns=out_cols)

    save(result_df)
    print(f"Saved interpolated centers to {output_path}")

def process_corrected_folder(corrected_folder_path):
    """Interpolate every CSV in a ``corrected`` folder into a sibling ``corrected_interpolated`` folder.

    The target folder is created if needed. Each ``<stem>.csv`` (extension matched
    case-insensitively) becomes ``<stem>_centre_interpol.csv`` via
    interpolate_centers. Other files are ignored.
    """
    print(f"Processing folder: {corrected_folder_path}")

    target_dir = os.path.join(os.path.dirname(corrected_folder_path), 'corrected_interpolated')
    os.makedirs(target_dir, exist_ok=True)

    csv_names = [entry for entry in os.listdir(corrected_folder_path) if entry.lower().endswith('.csv')]
    for entry in csv_names:
        stem, _ext = os.path.splitext(entry)
        interpolate_centers(
            os.path.join(corrected_folder_path, entry),
            os.path.join(target_dir, f"{stem}_centre_interpol.csv"),
        )

def walk_and_process(root_dir):
    """Process every directory named ``corrected`` (case-insensitive) anywhere under *root_dir*.

    os.walk lists each directory's children before yielding it, so a
    ``corrected_interpolated`` folder created next to a ``corrected`` folder is
    not visited on this pass, unless it already existed.
    """
    for parent, subdirs, _files in os.walk(root_dir):
        corrected_paths = (os.path.join(parent, name) for name in subdirs if name.lower() == 'corrected')
        for corrected_path in corrected_paths:
            process_corrected_folder(corrected_path)

def main():
    """Interpolate centres for every ``corrected`` folder under the sorted-CSV root.

    Prints an error and returns if the root is not a directory.
    """
    root_dir = r"D:\DeepSORT_ML2025\video_stitch\sorted_csv"
    if os.path.isdir(root_dir):
        walk_and_process(root_dir)
        print("Processing complete.")
    else:
        print(f"Error: The path '{root_dir}' is not a valid directory.")

# Entry point: run the centre interpolation when this cell/module is executed directly.
if __name__ == '__main__':
    main()


### Triangulation + Distance calculation

In [None]:
import os
import re
import numpy as np
import pandas as pd
import cv2
from filterpy.kalman import KalmanFilter
import plotly.graph_objects as go
import plotly.io as pio
from scipy.interpolate import CubicSpline
from scipy.signal import savgol_filter
import scipy.stats

# Intrinsic matrices K (3x3): focal lengths on the diagonal, skew at [0, 1],
# principal point in the last column.
# NOTE(review): K_cam10 and K_cam18 have a negative principal-point x
# (-68.165, -718.479), i.e. outside the image. Confirm this against the calibration output.
K_cam10 = np.array([[2711.6, 19.5268, -68.165],
                    [0, 2668.7, 158.146],
                    [0, 0, 1]])

K_cam11 = np.array([[2609.9, -18.9653, 799.908],
                    [0, 2471.6, 420.996],
                    [0, 0, 1]])

K_cam18 = np.array([[3056.7, 83.7254, -718.479],
                    [0, 2412.1, 307.2416],
                    [0, 0, 1]])

K_cam15 = np.array([[2809.8, -27.3469, 1163.2],
                    [0, 2448.5, 364.079],
                    [0, 0, 1]])

# Extrinsics: rotation R and translation T of the second camera of each pair
# relative to the first. triangulate_with_reasons uses them as P2 = K2 [R | T],
# with the first camera at the origin.
# NOTE(review): R_10_18 is not orthonormal. Its first row has norm ~0.55 (the
# other rows ~1.0), and row1 . row3 ~ -0.41. Changing R_10_18[0, 0] from 0.0839
# to 0.8390 would make it orthonormal to ~1e-3, so this looks like a
# transcription error. Verify before trusting cam10/cam18 triangulation.
R_10_18 = np.array([
    [0.0839, -0.0321, -0.5421],
    [0.0013, 0.9984, -0.0571],
    [0.5431, 0.0472, 0.8384]
])
T_10_18 = np.array([[1288.5], [-15.879], [148.983]])

R_11_15 = np.array([
    [0.8115, -0.2139, 0.5438],
    [0.1884, 0.9767, 0.1030],
    [-0.5531, 0.0189, 0.8329]
])
T_11_15 = np.array([[-1.143], [-119.637], [288.81]])

# Plausibility limits for triangulated points, presumably in the same units as T
# (TODO confirm). Points beyond them are marked invalid in triangulate_with_reasons.
MAX_DISTANCE_FROM_ORIGIN = 5000
MAX_DEPTH = 10000

def kalman_smooth(df):
    """Smooth ``cx``/``cy`` of each (camera, id) track with a constant-velocity Kalman filter.

    Rows of each track are filtered in ``frame`` order. Consecutive rows are
    treated as one time step apart (``F`` uses dt = 1), regardless of gaps in
    ``frame``. Each output row is the input row with ``cx``/``cy`` replaced by
    the filtered position.

    Rows whose ``cx`` or ``cy`` is not finite are emitted with NaN positions and
    do not update the filter; it only predicts through them. A single missing
    measurement therefore does not turn every later estimate of the track into NaN.
    The filter is initialised from the first row with a finite measurement.

    Returns:
        DataFrame of the smoothed rows, grouped by (camera, id) and in frame
        order within each group.
    """
    smoothed = []
    for _key, group in df.groupby(['camera', 'id']):
        kf = KalmanFilter(dim_x=4, dim_z=2)
        kf.F = np.array([[1, 0, 1, 0],
                         [0, 1, 0, 1],
                         [0, 0, 1, 0],
                         [0, 0, 0, 1]])
        kf.H = np.array([[1, 0, 0, 0],
                         [0, 1, 0, 0]])
        kf.P *= 10.
        kf.R *= 1.
        kf.Q *= 0.01

        initialized = False
        for _, row in group.sort_values('frame').iterrows():
            cx, cy = row['cx'], row['cy']
            has_measurement = bool(np.isfinite(cx) and np.isfinite(cy))

            if not initialized:
                if not has_measurement:
                    smoothed.append({**row.to_dict(), 'cx': np.nan, 'cy': np.nan})
                    continue
                kf.x[:2] = np.array([[cx], [cy]])
                initialized = True

            kf.predict()
            if has_measurement:
                kf.update([cx, cy])
                smoothed.append({**row.to_dict(), 'cx': kf.x[0, 0], 'cy': kf.x[1, 0]})
            else:
                # Feeding NaN to update() would contaminate the state for the rest of the track.
                smoothed.append({**row.to_dict(), 'cx': np.nan, 'cy': np.nan})
    return pd.DataFrame(smoothed)

def triangulate_with_reasons(df1, df2, K1, K2, R, T):
    """Triangulate matched 2-D centres from two cameras into 3-D points, with a reason per point.

    Args:
        df1, df2: per-camera tables with at least ``frame``, ``id``, ``cx``, ``cy``.
            A ``camera`` column, if present, is dropped. Rows are paired on
            (frame, id) with an inner merge.
        K1, K2: 3x3 intrinsic matrices of the first and second camera.
        R, T: rotation (3x3) and translation (3x1) of camera 2. Camera 1 is the
            origin: P1 = K1 [I | 0], P2 = K2 [R | T].

    Returns:
        DataFrame with columns frame, id, X, Y, Z, swapped, reason. There is one
        row per (frame, id), using the first merged row for that pair. Invalid
        points always have NaN coordinates, never inf. ``reason`` is one of:
        "Valid", "NaN or infinite coordinate", "ID swap applied due to invalid
        coordinate", "Coordinate out of physical bounds", "Negative or too large
        depth". If nothing merges, the returned frame has no columns.
    """
    df1 = df1.drop(columns=['camera'], errors='ignore')
    df2 = df2.drop(columns=['camera'], errors='ignore')
    df1 = df1.rename(columns={'cx': 'cx1', 'cy': 'cy1'})
    df2 = df2.rename(columns={'cx': 'cx2', 'cy': 'cy2'})
    merged = pd.merge(df1, df2, on=['frame', 'id'])
    P1 = K1 @ np.hstack((np.eye(3), np.zeros((3, 1))))
    P2 = K2 @ np.hstack((R, T))

    points_data = []
    for frame, group in merged.groupby("frame"):
        ids_in_frame = group["id"].unique()
        for obj_id in ids_in_frame:
            row = group[group["id"] == obj_id].iloc[0]
            pt1 = np.array([[row["cx1"]], [row["cy1"]]])
            pt2 = np.array([[row["cx2"]], [row["cy2"]]])
            point_4d = cv2.triangulatePoints(P1, P2, pt1, pt2)
            point_3d = (point_4d / point_4d[3])[:3].flatten()

            swapped = False
            reason = "Valid"
            if not np.isfinite(point_3d).all():
                reason = "NaN or infinite coordinate"
                swap_candidate = None
                if len(ids_in_frame) == 2:
                    # Retry with the other ID's camera-1 centre, in case the two
                    # IDs are exchanged between the cameras in this frame.
                    other_id = [i for i in ids_in_frame if i != obj_id][0]
                    other_row = group[group["id"] == other_id].iloc[0]
                    point_4d_swap = cv2.triangulatePoints(
                        P1, P2,
                        np.array([[other_row["cx1"]], [other_row["cy1"]]]),
                        np.array([[row["cx2"]], [row["cy2"]]])
                    )
                    point_3d_swap = (point_4d_swap / point_4d_swap[3])[:3].flatten()
                    if np.isfinite(point_3d_swap).all() and np.all(np.abs(point_3d_swap) < MAX_DISTANCE_FROM_ORIGIN) and 0 <= point_3d_swap[2] <= MAX_DEPTH:
                        swap_candidate = point_3d_swap

                if swap_candidate is not None:
                    swapped = True
                    point_3d = swap_candidate
                    reason = "ID swap applied due to invalid coordinate"
                else:
                    # Store NaN (never inf) for any point left invalid, including
                    # frames that do not contain exactly two IDs.
                    point_3d = np.array([np.nan, np.nan, np.nan])
            elif np.any(np.abs(point_3d) > MAX_DISTANCE_FROM_ORIGIN):
                reason = "Coordinate out of physical bounds"
                point_3d = np.array([np.nan, np.nan, np.nan])
            elif point_3d[2] < 0 or point_3d[2] > MAX_DEPTH:
                reason = "Negative or too large depth"
                point_3d = np.array([np.nan, np.nan, np.nan])

            points_data.append({
                "frame": frame,
                "id": obj_id,
                "X": point_3d[0],
                "Y": point_3d[1],
                "Z": point_3d[2],
                "swapped": swapped,
                "reason": reason
            })

    return pd.DataFrame(points_data)

import os
import re
from difflib import get_close_matches

def extract_base_cam10_11(fname):
    """Return the lower-cased basename text before its last ``_cam10_``/``_cam11_`` tag, or None.

    >>> extract_base_cam10_11('/data/JM_ind1pre_cam10_centre_interpol.csv')
    'jm_ind1pre'
    """
    # Greedy (.*) means the last tag wins.
    found = re.match(r'(.*)_cam1[01]_.*', os.path.basename(fname).lower())
    return found.group(1) if found else None

def extract_base_cam15_18(fname):
    """Return the lower-cased basename text before its last ``_cam15_``/``_cam18_`` tag, or None.

    >>> extract_base_cam15_18('/v/JM_Ind1_cam18_pre.csv')
    'jm_ind1'
    """
    # Greedy (.*) means the last tag wins.
    found = re.match(r'(.*)_cam1[58]_.*', os.path.basename(fname).lower())
    return found.group(1) if found else None

def find_matched_sets(dir_cam10, dir_cam11, dir_cam15, dir_cam18):
    files10 = os.listdir(dir_cam10)
    files11 = os.listdir(dir_cam11)
    files15 = os.listdir(dir_cam15)
    files18 = os.listdir(dir_cam18)

    bases_10 = {extract_base_cam10_11(f): f for f in files10 if extract_base_cam10_11(f)}
    bases_11 = {extract_base_cam10_11(f): f for f in files11 if extract_base_cam10_11(f)}
    bases_15 = {extract_base_cam15_18(f): f for f in files15 if extract_base_cam15_18(f)}
    bases_18 = {extract_base_cam15_18(f): f for f in files18 if extract_base_cam15_18(f)}

    matched = []

    # Match cam10 & cam11 bases exactly
    base_10_11 = set(bases_10.keys()) & set(bases_11.keys())

    # For each base in cam10_11, find closest match in cam15 and cam18 bases using fuzzy matching
    for base in base_10_11:
        match_15 = get_close_matches(base, bases_15.keys(), n=1, cutoff=0.7)
        match_18 = get_close_matches(base, bases_18.keys(), n=1, cutoff=0.7)

        if match_15 and match_18:
            matched.append((
                os.path.join(dir_cam10, bases_10[base]),
                os.path.join(dir_cam11, bases_11[base]),
                os.path.join(dir_cam15, bases_15[match_15[0]]),
                os.path.join(dir_cam18, bases_18[match_18[0]]),
                base
            ))

    return matched

def _savgol_window(n_points, max_window=15, polyorder=2):
    """Return an odd Savitzky-Golay window length usable for ``n_points`` samples.

    ``savgol_filter(..., mode='interp')`` needs a window no longer than the
    data and strictly greater than ``polyorder``. An odd length is always
    accepted. Returns ``None`` when no such window exists (too few points to
    smooth).
    """
    window = min(max_window, n_points)
    if window % 2 == 0:
        window -= 1
    return window if window > polyorder else None


def _interindividual_distances(pivoted):
    """Compute the per-frame Euclidean distance between ID 1 and ID 2.

    ``pivoted`` holds one row per ``frame`` plus columns ``X_<id>``,
    ``Y_<id>``, ``Z_<id>`` and ``reason_<id>`` for ids 1 and 2 (any of them
    may be absent). An ID is usable in a frame when its three coordinates are
    present and its reason is ``'Valid'`` or
    ``'ID swap applied due to invalid coordinate'``.

    Returns:
        A DataFrame with columns ``frame``, ``distance_1_2`` (NaN when either
        ID is unusable) and ``distance_reason``.
    """
    accepted = ('Valid', 'ID swap applied due to invalid coordinate')

    def usable(row, track_id):
        coords_present = all(pd.notna(row.get(f'{axis}_{track_id}')) for axis in 'XYZ')
        return coords_present and row.get(f'reason_{track_id}', '') in accepted

    records = []
    for _, row in pivoted.iterrows():
        frame = row['frame']
        problems = [
            f"ID {track_id} invalid or missing ({row.get(f'reason_{track_id}', 'NA')})"
            for track_id in (1, 2)
            if not usable(row, track_id)
        ]
        if problems:
            records.append([frame, np.nan, "; ".join(problems)])
            continue

        p1 = np.array([row['X_1'], row['Y_1'], row['Z_1']])
        p2 = np.array([row['X_2'], row['Y_2'], row['Z_2']])
        dist = np.linalg.norm(p1 - p2)
        # An exactly-zero distance usually means both IDs got the same point.
        reason = "Distance zero (suspicious)" if dist == 0 else "Valid"
        records.append([frame, dist, reason])

    return pd.DataFrame(records, columns=['frame', 'distance_1_2', 'distance_reason'])


def _distance_ci_by_frame(combined, confidence=0.95):
    """Aggregate distances over trials into a per-frame mean and t-based CI.

    Frames with a single value get a zero-width interval, and frames with no
    non-NaN value get NaN. Returns four parallel lists:
    ``(frames, means, lowers, uppers)``, ordered by frame.
    """
    import scipy.stats

    frames_list, means, lowers, uppers = [], [], [], []
    for frame, values in combined.groupby('frame')['distance_1_2']:
        vals = values.dropna().values
        if len(vals) > 1:
            mean = np.mean(vals)
            sem = scipy.stats.sem(vals)
            ci_range = sem * scipy.stats.t.ppf((1 + confidence) / 2., len(vals) - 1)
            lower, upper = mean - ci_range, mean + ci_range
        elif len(vals) == 1:
            mean = lower = upper = vals[0]
        else:
            mean = lower = upper = np.nan

        frames_list.append(frame)
        means.append(mean)
        lowers.append(lower)
        uppers.append(upper)

    return frames_list, means, lowers, uppers


def _plot_distance_summary(comb, frames_list, means, lowers, uppers):
    """Plot the smoothed per-frame mean distance with its 95% CI band."""
    from scipy.signal import savgol_filter

    polyorder = 2
    window = _savgol_window(len(means), polyorder=polyorder)
    if window is None:
        # Too few frames for Savitzky-Golay smoothing: plot the raw values.
        smoothed_mean, smoothed_lower, smoothed_upper = (
            np.asarray(series, dtype=float) for series in (means, lowers, uppers))
    else:
        smoothed_mean, smoothed_lower, smoothed_upper = (
            savgol_filter(series, window_length=window, polyorder=polyorder, mode='interp')
            for series in (means, lowers, uppers))

    fig = go.Figure()

    # CI band: trace the upper bound forward, then the lower bound back, and fill.
    fig.add_trace(go.Scatter(
        x=frames_list + frames_list[::-1],
        y=list(smoothed_upper) + list(smoothed_lower[::-1]),
        fill='toself',
        fillcolor='rgba(0,0,0,0.2)',
        line=dict(color='rgba(255,255,255,0)'),
        hoverinfo="skip",
        showlegend=True,
        name="95% Confidence Interval"
    ))

    fig.add_trace(go.Scatter(
        x=frames_list,
        y=smoothed_mean,
        mode='lines',
        line=dict(color='black', width=3),
        name='Smoothed Mean Distance',
        hovertemplate="Frame %{x}<br>Distance %{y:.2f}<extra></extra>"
    ))

    fig.update_layout(
        title=f"Inter-individual Distance Across Trials (Smoothed Mean ± 95% CI) - {comb}",
        xaxis_title="Frame",
        yaxis_title="Distance between ID 1 and ID 2",
        template="simple_white",
        height=500,
        width=950,
        showlegend=True,
        yaxis=dict(range=[-500, 3000])  # fixed y-axis range: -500 to 3000
    )

    pio.show(fig)


def main(base_dir=r"D:\DeepSORT_ML2025\video_stitch\sorted_csv",
         output_dir_root=r"D:\DeepSORT_ML2025\video_stitch\3d_transform",
         combinations=("ind_pre", "ind_post", "rew_pre", "rew_post", "pro_pre", "pro_post")):
    """Triangulate matched 4-camera trials and summarise inter-individual distance.

    For each condition in ``combinations``:

    * The four per-camera CSV files of each trial are paired
      (``find_matched_sets``).
    * The left pair (cam10 + cam18) and the right pair (cam11 + cam15) are
      triangulated.
    * ``triangulated_left_*.csv``, ``triangulated_right_*.csv`` and
      ``interindividual_distances_with_reasons_*.csv`` are written to
      ``output_dir_root/<condition>``.
    * The per-frame mean distance across trials is plotted.

    A condition whose input directories are missing is skipped with a
    warning, so one missing folder does not abort the remaining conditions.

    Relies on module-level names defined elsewhere in this file:
    ``triangulate_with_reasons``, ``K_cam10/11/15/18``, ``R_10_18``,
    ``T_10_18``, ``R_11_15``, ``T_11_15``, ``go`` and ``pio``.
    """
    for comb in combinations:
        print(f"\nProcessing condition: {comb}")

        dir_cam10, dir_cam11, dir_cam15, dir_cam18 = (
            os.path.join(base_dir, comb, cam, "corrected_interpolated")
            for cam in ("cam10", "cam11", "cam15", "cam18"))
        missing = [d for d in (dir_cam10, dir_cam11, dir_cam15, dir_cam18)
                   if not os.path.isdir(d)]
        if missing:
            print(f"Warning: skipping {comb}; missing input directories: {missing}")
            continue

        output_dir = os.path.join(output_dir_root, comb)
        os.makedirs(output_dir, exist_ok=True)

        matched_sets = find_matched_sets(dir_cam10, dir_cam11, dir_cam15, dir_cam18)
        print(f"Found {len(matched_sets)} matched sets for {comb}.")

        all_valid_dfs = []

        for file10, file11, file15, file18, base_name in matched_sets:
            print(f"Processing set: {base_name}")

            df_cam10 = pd.read_csv(file10)
            df_cam11 = pd.read_csv(file11)
            df_cam15 = pd.read_csv(file15)
            df_cam18 = pd.read_csv(file18)

            df_cam10['camera'] = 10
            df_cam11['camera'] = 11
            df_cam15['camera'] = 15
            df_cam18['camera'] = 18

            tri_left = triangulate_with_reasons(df_cam10, df_cam18, K_cam10, K_cam18, R_10_18, T_10_18)
            if tri_left.empty or 'X' not in tri_left.columns:
                print(f"Warning: triangulation left returned empty or no 'X' column for {base_name}")
                continue

            tri_right = triangulate_with_reasons(df_cam11, df_cam15, K_cam11, K_cam15, R_11_15, T_11_15)
            if tri_right.empty or 'X' not in tri_right.columns:
                print(f"Warning: triangulation right returned empty or no 'X' column for {base_name}")
                continue

            # NOTE(review): the right-pair points are shifted along X only, by the
            # first component of the cam10->cam18 translation. Confirm that this is
            # the intended left/right rig offset (no rotation is applied).
            tri_right["X"] += T_10_18[0][0]

            tri_left.to_csv(os.path.join(output_dir, f"triangulated_left_{base_name}.csv"), index=False)
            tri_right.to_csv(os.path.join(output_dir, f"triangulated_right_{base_name}.csv"), index=False)

            # One row per frame with columns like X_1, Y_2, reason_1, ...; with
            # aggfunc='first', the left-pair value wins when both pairs supply one.
            all_points = pd.concat([tri_left, tri_right], axis=0)
            pivoted = all_points.pivot_table(
                index='frame', columns='id', values=['X', 'Y', 'Z', 'reason', 'swapped'], aggfunc='first')
            pivoted.columns = ['{}_{}'.format(col[0], int(col[1])) for col in pivoted.columns]
            pivoted = pivoted.reset_index()

            dist_df = _interindividual_distances(pivoted)
            dist_df.to_csv(os.path.join(output_dir, f'interindividual_distances_with_reasons_{base_name}.csv'), index=False)

            valid = dist_df[dist_df["distance_reason"] == "Valid"][["frame", "distance_1_2"]].copy()
            valid["trial"] = base_name
            all_valid_dfs.append(valid)

        # Plot summary for this condition
        frames_list = []
        if all_valid_dfs:
            frames_list, means, lowers, uppers = _distance_ci_by_frame(pd.concat(all_valid_dfs))

        if frames_list:
            _plot_distance_summary(comb, frames_list, means, lowers, uppers)
        else:
            print(f"No valid data to plot for {comb}.")

if __name__ == "__main__":
    # Script entry point: process every configured condition.
    main()

