# Thermal Counting Training

Rough training pipeline for gathering and labeling grouped bounding boxes.

> Note: If importing `tkinter` fails, install it with `sudo apt install python3-tk -y`.

In [9]:
import importlib  # Refreshing imports
import cv2
import supervision as sv
import numpy as np
from ultralytics import YOLO
from utils.thermal_frame_to_temp import result_to_temp_frame
import utils.group_bounding_boxes as gbb
import validate_bounding_box as vbb
import tkinter as tk
from tkinter.filedialog import askopenfilename, askdirectory
# Better exception handling and helpers
import traceback
import pprint
import datetime
import os
import time

# Re-import the local helper modules so edits made to them on disk are picked up
# without restarting the kernel (a plain `import` reuses the cached module object).
importlib.reload(gbb)
importlib.reload(vbb)  # last expression of the cell: the notebook echoes the reloaded module

<module 'validate_bounding_box' from '/mnt/c/Users/JJ/Desktop/Repos/Chick-Counting/thermal/validate_bounding_box.py'>

## Helpers for saving results (cropped bounding-box images)

In [10]:
SAVE_DIR = "grouped_bounding_box_crops"  # Directory to save the crops

def save_crop(frame: np.ndarray, box: tuple[int,int,int,int], frame_count: int, group_id: int,
              save_dir: str = SAVE_DIR) -> None:
    """Crop one (merged) bounding box out of ``frame`` and save it as a JPEG.

    The box is clipped to the frame first; if nothing is left after clipping,
    no file is written.

    Args:
        frame: Image array; only its first two dimensions (height, width) are
            used for clipping.
        box: ``(x1, y1, x2, y2)`` pixel coordinates with ``x2``/``y2`` treated as
            exclusive bounds. Float coordinates are accepted and rounded outward
            to whole pixels so the crop covers the entire box.
        frame_count: Frame number, embedded in the output file name.
        group_id: Group index within the frame, embedded in the output file name.
        save_dir: Output directory, created if missing. Defaults to ``SAVE_DIR``.
    """
    os.makedirs(save_dir, exist_ok=True)
    h, w = frame.shape[:2]

    # numpy slicing needs ints; boxes built from a float xyxy array would
    # otherwise raise. Round outward (floor start, ceil end).
    x1, y1 = (int(np.floor(v)) for v in box[:2])
    x2, y2 = (int(np.ceil(v)) for v in box[2:4])

    # Clip to the frame. End coordinates are exclusive slice bounds, so they may
    # equal w/h; clipping them to w-1/h-1 would drop the last column/row.
    x1 = min(max(x1, 0), w); x2 = min(max(x2, 0), w)
    y1 = min(max(y1, 0), h); y2 = min(max(y2, 0), h)

    # Nothing left to crop (degenerate box or entirely outside the frame)
    if x2 <= x1 or y2 <= y1:
        return

    # Perform the crop and save; the millisecond timestamp keeps names unique across runs
    crop = frame[y1:y2, x1:x2].copy()
    ts = int(time.time() * 1000)
    out_path = os.path.join(save_dir, f"f{frame_count}_g{group_id}_{ts}.jpg")
    if not cv2.imwrite(out_path, crop):
        print(f"Warning: failed to write crop to {out_path}")

## Run the YOLO model, track detections, and gather results (currently an unoptimized proof of concept)

In [15]:
FRAME_COUNT_EARLY_STOP = 1000  # Testing aid: only the first N frames of the video are processed

def get_line_from_video_frame(frame):
    """Return the endpoints of a horizontal counting line across the frame's middle row.

    Nothing is drawn here. The result is ``[start, end]``, where ``start`` lies on
    the right edge (x = frame width) and ``end`` on the left edge (x = 0), both at
    y = height // 2. Keep this order: the points are passed to
    ``sv.LineZone(start=..., end=...)``, whose "in"/"out" sides presumably depend
    on the line's direction.
    """
    height, width = frame.shape[:2]
    mid_y = height // 2
    right_edge = (width, mid_y)
    left_edge = (0, mid_y)
    return [right_edge, left_edge]

# Sensitivity for declaring a box as "nested":
# e.g. 0.9 means the inner box must have at least 90% of its area inside the outer box.
NESTED_THRESHOLD = 0.9


def _save_crossing_group_crops(frame, detections, crossed_in_flags, frame_count):
    """Save one merged crop per bounding-box group containing a box that crossed "in"."""
    crossed_i = {i for i, crossed in enumerate(crossed_in_flags) if crossed}
    if not crossed_i:
        return  # nothing crossed this frame, so no group can qualify

    xyxy_np = detections.xyxy.astype(float)
    groups = gbb.group_bounding_boxes(xyxy_np)  # Default, low threshold for now
    for gid, g in enumerate(groups):
        # Empty groups are disjoint from any index set, so they are skipped too.
        if not crossed_i.isdisjoint(g):
            merged_box = gbb.merge_group_bounding_box(xyxy_np, g)
            save_crop(frame, merged_box, frame_count, gid)


def _find_nested_indices(boxes, threshold=NESTED_THRESHOLD):
    """Return indices of boxes with at least ``threshold`` of their area inside another box."""
    boxes = np.asarray(boxes, dtype=float).reshape(-1, 4)
    x1, y1, x2, y2 = boxes.T
    areas = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)

    # Pairwise intersection areas: row i = outer box, column j = inner box.
    inter_w = np.clip(np.minimum(x2[:, None], x2[None, :]) - np.maximum(x1[:, None], x1[None, :]), 0, None)
    inter_h = np.clip(np.minimum(y2[:, None], y2[None, :]) - np.maximum(y1[:, None], y1[None, :]), 0, None)
    inter = inter_w * inter_h

    # Fraction of each inner box covered by each outer box (0 for zero-area inner boxes).
    inner_areas = np.broadcast_to(areas[None, :], inter.shape)
    ratio = np.divide(inter, inner_areas, out=np.zeros_like(inter), where=inner_areas > 0)
    nested = (inner_areas > 0) & (ratio >= threshold)
    np.fill_diagonal(nested, False)  # a box is never nested inside itself
    return {int(j) for j in np.flatnonzero(nested.any(axis=0))}


def _label_detections(tracker_ids, nested_indices):
    """Build per-detection labels and colors: red "nested" vs. green "chick"."""
    labels = []
    colors = []
    for i, tracker_id in enumerate(tracker_ids if tracker_ids is not None else []):
        if i in nested_indices:
            labels.append(f"#{tracker_id} nested")
            colors.append(sv.Color.RED)
        else:
            labels.append(f"#{tracker_id} chick")
            colors.append(sv.Color.GREEN)
    return labels, colors


def chick_counting(video_path, line_points, model_path="models/new_iron.pt"):
    """Count tracked chicks crossing a line and save crops of grouped boxes that cross.

    Processes at most ``FRAME_COUNT_EARLY_STOP`` frames. For each frame it:
    runs YOLO, tracks detections with ByteTrack, triggers the line zone, saves a
    merged crop (via ``save_crop``) for every box group containing a box that
    crossed "in", and counts each tracker ID the first time it crosses "in".

    Args:
        video_path: Path to the input video.
        line_points: ``[start, end]`` pixel points of the counting line.
        model_path: YOLO weights to load (defaults to the original hard-coded path).

    Returns:
        Number of unique tracker IDs counted as crossing "in". Exceptions raised
        while processing frames are logged and swallowed; the count reached so far
        is still returned.
    """
    # Init tracker and the counting line
    byte_tracker = sv.ByteTrack()
    line_zone = sv.LineZone(start=sv.Point(*line_points[0]), end=sv.Point(*line_points[1]))

    # Load custom YOLO model (trained on chicks only)
    model = YOLO(model_path)

    frame_count = 0
    total_count = 0
    all_counted_ids = set()  # keep track of already-counted trackers

    try:
        for frame in sv.get_video_frames_generator(video_path):
            # Check the limit before incrementing so frame_count equals the number
            # of frames actually processed (checking after over-reported by one).
            if frame_count >= FRAME_COUNT_EARLY_STOP:
                break
            frame_count += 1

            print(f"Processing frame {frame_count}")

            # Run YOLO, convert to supervision Detections, and update the tracker
            results = model(frame)[0]
            detections = sv.Detections.from_ultralytics(results)
            detections = byte_tracker.update_with_detections(detections)
            print("Tracker IDs this frame:", detections.tracker_id)

            # See if any trackers crossed the line
            crossed_in_flags, _crossed_out_flags = line_zone.trigger(detections)

            # Training data: save crops of groups that contain a box crossing "in"
            _save_crossing_group_crops(frame, detections, crossed_in_flags, frame_count)

            # Only count new IDs that cross "in"
            tracker_ids = detections.tracker_id
            if tracker_ids is not None:
                for i, crossed in enumerate(crossed_in_flags):
                    if not crossed:
                        continue
                    tracker_id = tracker_ids[i]
                    if tracker_id is not None and tracker_id not in all_counted_ids:
                        total_count += 1
                        all_counted_ids.add(tracker_id)
                        print(f"New Chick crossed the line! ID {tracker_id}, Total count: {total_count}")

            # Labels + colors depending on nesting.
            # NOTE: not consumed yet - no annotator/video writer is wired up.
            nested_indices = _find_nested_indices(detections.xyxy, NESTED_THRESHOLD)
            labels, colors = _label_detections(tracker_ids, nested_indices)

    except Exception as e:
        # Detailed exception logging
        print("=== Exception while processing video frames ===")
        print("Time:", datetime.datetime.now().isoformat())
        print("Exception type:", type(e).__name__)
        print("Exception message:", str(e))
        print("Full traceback:")
        print(traceback.format_exc())

    finally:
        print(f"Processing complete. Processed {frame_count} frames.")
        print(f"Final total count: {total_count}")

    return total_count

if __name__ == "__main__":
    # Hide the empty Tk root window; only the file dialog is needed.
    tk.Tk().withdraw()

    # Pick the input video with a file dialog
    SOURCE_VIDEO_PATH = askopenfilename()
    print("User chose:", SOURCE_VIDEO_PATH)
    if not SOURCE_VIDEO_PATH:
        # A cancelled dialog yields an empty value rather than a path.
        print("No video selected")
        raise SystemExit(1)

    # Grab a frame to define the line; release the capture on every path.
    cap = cv2.VideoCapture(SOURCE_VIDEO_PATH)
    try:
        ret, frame = cap.read()
    finally:
        cap.release()
    if not ret:
        print("Failed to read the video")
        # Non-zero status (bare exit() reported success) and no reliance on the site-module builtin
        raise SystemExit(1)

    line_points = get_line_from_video_frame(frame)

    chick_counting(SOURCE_VIDEO_PATH, line_points)

    print(f"Completed attempted processing of {FRAME_COUNT_EARLY_STOP} frames.")

User chose: /mnt/c/Users/JJ/Desktop/Repos/Chick-Counting/data/Brennen's-Thermal-Video/3-Part Mid Belt(Iron) 02.mp4
Processing frame 1

0: 640x480 21 Chicks, 12.7ms
Speed: 1.8ms preprocess, 12.7ms inference, 22.8ms postprocess per image at shape (1, 3, 640, 480)
Tracker IDs this frame: [ 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20]
Processing frame 2

0: 640x480 23 Chicks, 10.8ms
Speed: 1.4ms preprocess, 10.8ms inference, 26.3ms postprocess per image at shape (1, 3, 640, 480)
Tracker IDs this frame: [ 1  3  6 10  9  7  4  2 15 14  8 13 19 12 20]
Processing frame 3

0: 640x480 20 Chicks, 10.8ms
Speed: 1.4ms preprocess, 10.8ms inference, 33.4ms postprocess per image at shape (1, 3, 640, 480)
Tracker IDs this frame: [ 1  3 10  6  9  4 15  7 12 26 19 11]
Processing frame 4

0: 640x480 19 Chicks, 10.8ms
Speed: 1.5ms preprocess, 10.8ms inference, 20.6ms postprocess per image at shape (1, 3, 640, 480)
Tracker IDs this frame: [10  1  3  4  7 13 15  9 12 27  6 28]
Processing frame