<a href="https://colab.research.google.com/github/Kushagra3219/CNN-Models/blob/main/Marine_Survelillance.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
"""
Marine Surveillance Pipeline using YOLOv8 (EO + IR fusion, tracking, size estimation)

- EO stream: visible camera / video file
- IR stream: thermal camera / video file
- Detection: YOLOv8
- Fusion: late fusion between EO & IR detections
- Tracking: simple greedy IoU-based tracker (no motion model; boxes are not predicted between frames)
- Size estimation: monocular geometry (needs camera calibration + tuning)

NOTE:
- This is a TEMPLATE / REFERENCE implementation.
- You MUST plug in your real camera parameters, paths, and possibly replace the tracker with DeepSORT/ByteTrack for production.
"""

import cv2
import numpy as np
from ultralytics import YOLO
import time
import math
from collections import deque

# =========================
# CONFIGURATION
# =========================
# Adjust these for your deployment; the geometry values are examples only.

# --- Inputs: video file paths, or integer camera indices such as 0 / 1 ---
EO_SOURCE, IR_SOURCE = "eo_video.mp4", "ir_video.mp4"  # visible / thermal (use 0 for a webcam)

# --- Model: pretrained weights or your own marine fine-tune ---
YOLO_MODEL_PATH = "yolov8n.pt"

# --- Detector thresholds (forwarded to YOLO predict as conf= / iou=) ---
CONF_THRES = 0.3
IOU_THRES_NMS = 0.5

# --- Tracker ---
MAX_TRACK_AGE = 20       # frames a track may go unmatched before it is dropped
IOU_MATCH_THRESH = 0.3   # minimum IoU to associate a detection with a track

# --- Working frame size (every EO/IR frame is resized to this) ---
FRAME_WIDTH, FRAME_HEIGHT = 1280, 720

# --- Camera geometry: example values, calibrate for your camera ---
CAMERA_HEIGHT_M = 15.0        # camera height above sea level, metres
FOCAL_LENGTH_PIXELS = 1200.0  # focal length in pixels (from calibration)
HORIZON_ROW = 200             # image row (y, pixels) of the horizon in the EO frame
PRINCIPAL_POINT = (640, 360)  # (cx, cy); image centre of a 1280x720 frame

# --- Class names: None -> labels fall back to "cls<id>" ---
CLASS_NAMES = None  # e.g. ["cargo", "tanker", "fishing", "small_boat", ...]

# --- Annotated EO video output ---
WRITE_OUTPUT = False
OUTPUT_PATH = "output_eo_annotated.mp4"


# =========================
# UTILS
# =========================

def iou(box1, box2):
    """
    Intersection-over-union of two axis-aligned boxes given as [x1, y1, x2, y2].

    Returns 0.0 when the boxes do not overlap; boxes that only touch along an
    edge count as non-overlapping. A 1e-6 term in the denominator guards
    against division by zero for degenerate boxes.
    """
    ix_lo, iy_lo = max(box1[0], box2[0]), max(box1[1], box2[1])
    ix_hi, iy_hi = min(box1[2], box2[2]), min(box1[3], box2[3])

    # Empty intersection rectangle -> no overlap.
    if ix_hi <= ix_lo or iy_hi <= iy_lo:
        return 0.0

    inter = (ix_hi - ix_lo) * (iy_hi - iy_lo)
    area_a = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area_b = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area_a + area_b - inter + 1e-6
    return inter / float(union)


def draw_box_with_label(img, box, label, color=(0, 255, 0), thickness=2):
    """
    Draw a box on ``img`` in place, optionally with a filled text tag.

    box: [x1, y1, x2, y2] in pixels (values are truncated to int).
    label: text drawn in black on a filled ``color`` background; nothing is
        drawn for an empty/None label.
    thickness: line thickness of the box outline only.

    The tag is sized from the rendered text (cv2.getTextSize) so long labels
    are not clipped. It sits just above the box, or just inside the box's top
    edge when there is no room above it (box touching the top of the image).
    """
    x1, y1, x2, y2 = [int(v) for v in box]
    cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness)
    if not label:
        return

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    text_thickness = 1
    pad = 2
    (text_w, text_h), baseline = cv2.getTextSize(label, font, font_scale, text_thickness)
    tag_h = text_h + baseline + 2 * pad

    # Prefer the tag above the box; fall back to inside it if it would leave the image.
    tag_top = y1 - tag_h if y1 - tag_h >= 0 else max(y1, 0)
    cv2.rectangle(img, (x1, tag_top), (x1 + text_w + 2 * pad, tag_top + tag_h), color, -1)
    # putText's origin is the text baseline's left end.
    cv2.putText(img, label, (x1 + pad, tag_top + pad + text_h),
                font, font_scale, (0, 0, 0), text_thickness)


# =========================
# SIMPLE IOU-BASED TRACKER
# =========================

class Track:
    """A single tracked object: latest box/class/confidence plus a short centre trail."""

    def __init__(self, track_id, box, cls, conf, frame_idx):
        self.track_id = track_id
        self.box = box  # [x1, y1, x2, y2]
        self.cls = cls
        self.conf = conf
        self.last_seen = frame_idx
        # Number of update() calls so far; construction alone does not count.
        self.age = 0
        # Up to the 30 most recent box centres (cx, cy), oldest first.
        self.history = deque(maxlen=30)

    def update(self, box, conf, frame_idx):
        """Store a new observation made at ``frame_idx`` and append its box centre."""
        self.box, self.conf, self.last_seen = box, conf, frame_idx
        self.age += 1
        left, _, right, _ = box[0], box[1], box[2], box[3]
        top, bottom = box[1], box[3]
        self.history.append(((left + right) * 0.5, (top + bottom) * 0.5))


class SimpleTracker:
    """
    Frame-by-frame IoU tracker with greedy one-to-one association.

    A track's box is simply its most recently matched detection box. There is
    no motion prediction, so a target that moves far enough between frames for
    its boxes to stop overlapping (IoU below ``iou_match_thresh``) gets a new
    track ID. Unmatched tracks are kept ("coasting") until more than
    ``max_age`` frames have passed since they were last seen.
    """

    def __init__(self, iou_match_thresh=0.3, max_age=20):
        self.iou_match_thresh = iou_match_thresh  # minimum IoU to associate
        self.max_age = max_age  # frames a track may go unmatched before removal
        self.tracks = []
        self.next_id = 1  # next track ID to hand out; only ever increments

    def update(self, detections, frame_idx):
        """
        Associate ``detections`` with existing tracks for frame ``frame_idx``.

        detections: list of dicts with keys "box" ([x1, y1, x2, y2]), "conf", "cls".
        frame_idx: current frame number (expected to increase between calls).

        Matching: every (track, detection) pair with IoU > 0 and
        IoU >= iou_match_thresh is a candidate; candidates are accepted greedily
        in descending IoU order, at most one detection per track. Unlike
        matching track by track, an earlier track cannot steal a detection that
        a later track overlaps much better.

        Returns the internal list of tracks still alive. This includes tracks
        updated or created this frame AND unmatched tracks last seen within
        ``max_age`` frames (``track.last_seen == frame_idx`` tells them apart).
        """
        # Step 1: collect every candidate pair that clears the threshold.
        candidates = []
        for ti, track in enumerate(self.tracks):
            for di, det in enumerate(detections):
                iou_val = iou(track.box, det["box"])
                if iou_val > 0.0 and iou_val >= self.iou_match_thresh:
                    candidates.append((iou_val, ti, di))
        # Highest IoU first; index tie-breaks keep the result deterministic.
        candidates.sort(key=lambda c: (-c[0], c[1], c[2]))

        # Step 2: accept pairs greedily and update the matched tracks.
        matched_tracks = set()
        matched_dets = set()
        for _, ti, di in candidates:
            if ti in matched_tracks or di in matched_dets:
                continue
            matched_tracks.add(ti)
            matched_dets.add(di)
            det = detections[di]
            self.tracks[ti].update(det["box"], det["conf"], frame_idx)
            self.tracks[ti].cls = det["cls"]

        # Step 3: start new tracks for unmatched detections, in detection order
        # so ID assignment is deterministic.
        for di, det in enumerate(detections):
            if di in matched_dets:
                continue
            new_track = Track(
                self.next_id, det["box"], det["cls"], det["conf"], frame_idx
            )
            # Record the first observation (bumps age, seeds the centre history).
            new_track.update(det["box"], det["conf"], frame_idx)
            self.tracks.append(new_track)
            self.next_id += 1

        # Step 4: drop tracks unseen for more than max_age frames.
        self.tracks = [
            t for t in self.tracks if frame_idx - t.last_seen <= self.max_age
        ]
        return self.tracks


# =========================
# SIZE & DISTANCE ESTIMATION
# =========================

def estimate_distance_from_horizon(y_bottom, camera_height=CAMERA_HEIGHT_M,
                                   focal_pixels=FOCAL_LENGTH_PIXELS,
                                   horizon_row=HORIZON_ROW):
    """
    Very rough range to a sea-surface point from its image row.

    Flat-sea pinhole model (no Earth curvature, no lens distortion): a point
    on the water imaged at row ``v`` below the horizon row ``v_h`` lies at

        Z = (h * f) / (v - v_h)

    y_bottom: image row v in pixels from the top (callers pass the bottom edge
        of the target's bounding box).
    camera_height: h, camera height above the sea surface (same unit as Z).
    focal_pixels: f, focal length in pixels.
    horizon_row: v_h, image row of the horizon.

    Returns Z (metres when camera_height is in metres), or None when y_bottom
    is at or above the horizon row, where the model has no valid solution.
    Unstable close to the horizon, where one pixel of error in v spans a large
    range interval; requires calibration.
    """
    rows_below_horizon = y_bottom - horizon_row
    if rows_below_horizon <= 0:
        return None
    # Denominator is strictly positive here, so no epsilon is needed and Z > 0.
    return (camera_height * focal_pixels) / float(rows_below_horizon)


def estimate_length_from_box(box, distance_m,
                             focal_pixels=FOCAL_LENGTH_PIXELS,
                             principal_point=PRINCIPAL_POINT):
    """
    Approximate a target's physical horizontal extent from its box width.

    Pinhole back-projection: L_real = (box_width_pixels * Z) / f, with the box
    width floored at 1 px. This is the projected width across the image, not
    necessarily the hull length; a segmentation mask and better calibration
    would improve it.

    box: [x1, y1, x2, y2] in pixels.
    distance_m: range Z to the target, or None.
    principal_point: accepted for interface compatibility; currently unused.

    Returns the estimate in the unit of ``distance_m``, or None when
    ``distance_m`` is None.
    """
    if distance_m is None:
        return None
    x_left, _, x_right, _ = box
    span_px = max(1.0, (x_right - x_left))
    return span_px * distance_m / float(focal_pixels)


# =========================
# YOLOv8 DETECTION WRAPPER
# =========================

class YOLODetector:
    """Thin wrapper around an ultralytics YOLO model that returns plain-Python detections."""

    def __init__(self, model_path, conf_thres=0.3, iou_thres=0.5):
        self.model = YOLO(model_path)
        # Forwarded to predict() as conf= / iou= on every call.
        self.conf_thres = conf_thres
        self.iou_thres = iou_thres

    def detect(self, frame_bgr):
        """
        Run the model on one frame.

        The frame is passed to predict() unchanged, as the BGR array OpenCV reads.
        Returns a list of {"box": [x1, y1, x2, y2], "conf": float, "cls": int}
        dicts built from the first result; an empty list when there are no boxes.
        """
        results = self.model.predict(
            source=frame_bgr,
            conf=self.conf_thres,
            iou=self.iou_thres,
            verbose=False,
        )
        if not results:
            return []

        first = results[0]
        if first.boxes is None or len(first.boxes) == 0:
            return []

        xyxy = first.boxes.xyxy.cpu().numpy()
        confs = first.boxes.conf.cpu().numpy()
        cls_ids = first.boxes.cls.cpu().numpy().astype(int)

        return [
            {"box": xy.tolist(), "conf": float(c), "cls": int(k)}
            for xy, c, k in zip(xyxy, confs, cls_ids)
        ]


# =========================
# FUSION (EO + IR LATE FUSION)
# =========================

def fuse_detections_late(EO_dets, IR_dets, iou_thresh=0.5):
    """
    Late fusion of EO and IR detections (dicts with "box", "conf", "cls").

    - An EO/IR pair whose boxes overlap with IoU > 0 and IoU >= iou_thresh is
      merged one-to-one: the EO box and class are kept, and the confidence is
      the mean of the two confidences.
    - Pairs are accepted greedily in descending IoU over all EO x IR
      combinations. An earlier EO detection therefore cannot claim an IR box
      that a later EO detection overlaps better.
    - Unmatched EO detections are passed through unchanged, and unmatched IR
      detections are appended after them.

    Output order: one entry per EO detection (in EO order), then the IR-only
    detections (in IR order).

    NOTE(review): boxes are compared directly, which assumes the EO and IR
    frames are pixel-registered after resizing; confirm for your camera rig.
    In practice you might also weight EO more by day and IR more at night,
    and enforce class consistency.
    """
    # Collect every EO/IR pair that clears the threshold.
    candidates = []
    for ei, eo in enumerate(EO_dets):
        for ii, ir in enumerate(IR_dets):
            val = iou(eo["box"], ir["box"])
            if val > 0.0 and val >= iou_thresh:
                candidates.append((val, ei, ii))
    # Highest IoU first; index tie-breaks keep the result deterministic.
    candidates.sort(key=lambda c: (-c[0], c[1], c[2]))

    ir_for_eo = {}
    used_ir = set()
    for _, ei, ii in candidates:
        if ei in ir_for_eo or ii in used_ir:
            continue
        ir_for_eo[ei] = ii
        used_ir.add(ii)

    fused = []
    for ei, eo in enumerate(EO_dets):
        ii = ir_for_eo.get(ei)
        if ii is None:
            fused.append(eo)
            continue
        fused.append({
            "box": eo["box"],  # or average coordinates
            "conf": 0.5 * (eo["conf"] + IR_dets[ii]["conf"]),
            "cls": eo["cls"],  # or choose by higher conf
        })

    # IR-only detections (not matched to any EO detection).
    fused.extend(ir for ii, ir in enumerate(IR_dets) if ii not in used_ir)
    return fused


# =========================
# MAIN LOOP
# =========================

def _annotate_track(frame, track):
    """Draw one track's box, label (class, conf, range, size) and centre trail onto ``frame``."""
    box = track.box
    # Range uses the bottom edge of the box, which the flat-sea model assumes
    # touches the water surface.
    distance_m = estimate_distance_from_horizon(
        int(box[3]),
        camera_height=CAMERA_HEIGHT_M,
        focal_pixels=FOCAL_LENGTH_PIXELS,
        horizon_row=HORIZON_ROW,
    )
    length_m = estimate_length_from_box(
        box,
        distance_m,
        focal_pixels=FOCAL_LENGTH_PIXELS,
        principal_point=PRINCIPAL_POINT,
    )

    cls_id = track.cls
    if CLASS_NAMES is not None and 0 <= cls_id < len(CLASS_NAMES):
        cls_name = CLASS_NAMES[cls_id]
    else:
        cls_name = f"cls{cls_id}"

    dist_str = f"{distance_m:.1f}m" if distance_m is not None else "?"
    len_str = f"{length_m:.1f}m" if length_m is not None else "?"
    label = f"ID {track.track_id} | {cls_name} | {track.conf:.2f} | D={dist_str} | L={len_str}"
    draw_box_with_label(frame, box, label, color=(0, 255, 0))

    # Trajectory: connect consecutive stored centres.
    points = [(int(cx), int(cy)) for cx, cy in track.history]
    for start, end in zip(points, points[1:]):
        cv2.line(frame, start, end, (255, 0, 0), 2)


def main():
    """
    Run the EO(+IR) detect -> fuse -> track -> annotate loop.

    The loop stops when the EO stream ends or the user presses 'q' / Esc.
    IR is optional: if its source cannot be opened, or stops delivering
    frames, detection continues on EO only.

    Capture devices, the optional video writer and OpenCV windows are released
    even if an exception is raised inside the loop.

    NOTE(review): cv2.imshow needs a GUI display; hosted notebooks such as
    Colab typically do not provide one. Confirm for your runtime.
    """
    eo_cap = cv2.VideoCapture(EO_SOURCE)
    ir_cap = cv2.VideoCapture(IR_SOURCE)

    if not eo_cap.isOpened():
        print("Error: Cannot open EO source")
        ir_cap.release()
        return
    if not ir_cap.isOpened():
        print("Warning: Cannot open IR source (will run EO only)")

    writer = None
    try:
        # Optional: request the working resolution from the EO device (frames
        # are resized below regardless).
        eo_cap.set(cv2.CAP_PROP_FRAME_WIDTH, FRAME_WIDTH)
        eo_cap.set(cv2.CAP_PROP_FRAME_HEIGHT, FRAME_HEIGHT)

        detector_eo = YOLODetector(YOLO_MODEL_PATH, CONF_THRES, IOU_THRES_NMS)
        detector_ir = YOLODetector(YOLO_MODEL_PATH, CONF_THRES, IOU_THRES_NMS)

        tracker = SimpleTracker(IOU_MATCH_THRESH, MAX_TRACK_AGE)

        if WRITE_OUTPUT:
            # Record at the source frame rate so playback speed matches the
            # input. Fall back to 20 fps when the backend reports 0 or NaN.
            src_fps = eo_cap.get(cv2.CAP_PROP_FPS)
            out_fps = src_fps if src_fps > 0 else 20.0
            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            writer = cv2.VideoWriter(OUTPUT_PATH, fourcc, out_fps,
                                     (FRAME_WIDTH, FRAME_HEIGHT))

        frame_idx = 0
        # perf_counter is monotonic and high-resolution, unlike time.time().
        prev_t = time.perf_counter()

        while True:
            ret_eo, frame_eo = eo_cap.read()
            ret_ir, frame_ir = ir_cap.read() if ir_cap.isOpened() else (False, None)

            if not ret_eo:
                print("EO stream ended.")
                break

            frame_idx += 1

            frame_eo = cv2.resize(frame_eo, (FRAME_WIDTH, FRAME_HEIGHT))
            have_ir = ret_ir and frame_ir is not None
            if have_ir:
                frame_ir = cv2.resize(frame_ir, (FRAME_WIDTH, FRAME_HEIGHT))

            eo_dets = detector_eo.detect(frame_eo)
            ir_dets = detector_ir.detect(frame_ir) if have_ir else []

            fused_dets = fuse_detections_late(eo_dets, ir_dets, iou_thresh=0.5)
            tracks = tracker.update(fused_dets, frame_idx)

            for t in tracks:
                _annotate_track(frame_eo, t)

            # Instantaneous FPS; guard dt == 0 so a coarse clock cannot raise
            # ZeroDivisionError.
            now = time.perf_counter()
            dt = now - prev_t
            prev_t = now
            fps = 1.0 / dt if dt > 0 else 0.0
            cv2.putText(frame_eo, f"FPS: {fps:.1f}", (10, 25),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 255), 2)

            cv2.imshow("Marine Surveillance - EO (Fused & Tracked)", frame_eo)
            if writer is not None:
                writer.write(frame_eo)

            key = cv2.waitKey(1) & 0xFF
            if key == 27 or key == ord('q'):
                break
    finally:
        # release() on a capture that never opened is a no-op.
        eo_cap.release()
        ir_cap.release()
        if writer is not None:
            writer.release()
        cv2.destroyAllWindows()


if __name__ == "__main__":
    main()
