# Robust Human Detection in UAV Imagery
## HIT-UAV Infrared Thermal Dataset - Baseline vs Augmented Comparison

**Experiment Design:**
- **Model A**: Trained on clean/normal data only
- **Model B**: Trained with search-and-rescue (SAR) augmentations (snow, smoke/fire, thermal sensor artifacts)
- **Evaluation**: Compare both on clean and perturbed test sets

**Dataset**: HIT-UAV from Kaggle (thermal infrared UAV imagery)

---

## Cell 1: Environment Setup

In [None]:
# =============================================================================
# CELL 1: ENVIRONMENT SETUP
# =============================================================================

import subprocess
import sys
import os

def install_packages(packages=None):
    """Install the notebook's third-party dependencies with pip.

    All requirements are handed to a single ``pip install`` invocation so pip
    resolves their version constraints together (and only one interpreter
    process is spawned) instead of installing them one at a time.

    Args:
        packages: Optional list of pip requirement strings. Defaults to the
            packages this notebook uses.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    if packages is None:
        packages = [
            'torch', 'torchvision', 'albumentations>=1.3.0', 'pycocotools',
            'opencv-python-headless', 'matplotlib', 'numpy', 'Pillow',
            'tqdm', 'scipy', 'kaggle'
        ]
    if packages:
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', *packages])
    print("Packages installed")

install_packages()

# Keep downloads, checkpoints and outputs on Google Drive when running in
# Colab; anywhere else use a cache directory under the working directory.
try:
    from google.colab import drive
    drive.mount('/content/drive')
    DRIVE_ROOT = '/content/drive/MyDrive/uav_detection'
    IN_COLAB = True
except ImportError:
    DRIVE_ROOT = './uav_detection_cache'
    IN_COLAB = False

# Create the cache root and its fixed sub-directories (idempotent).
for _cache_dir in (DRIVE_ROOT,
                   f"{DRIVE_ROOT}/data",
                   f"{DRIVE_ROOT}/checkpoints",
                   f"{DRIVE_ROOT}/outputs"):
    os.makedirs(_cache_dir, exist_ok=True)

print(f"Cache directory: {DRIVE_ROOT}")

## Cell 2: Imports and Configuration

In [None]:
# =============================================================================
# CELL 2: IMPORTS AND CONFIGURATION
# =============================================================================

import torch
import torchvision
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torch.utils.data import Dataset, DataLoader
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import json
import os
import shutil
import zipfile
import copy
from pathlib import Path
from tqdm.auto import tqdm
import random
from typing import Dict, List, Tuple, Optional
from collections import defaultdict
import warnings
# Suppress every Python warning to keep notebook output short.
# NOTE(review): this also hides deprecation/user warnings from torch and
# torchvision that can be useful while debugging.
warnings.filterwarnings('ignore')

# Configuration
class Config:
    """Experiment-wide constants: cache paths, image size, training and detection settings."""

    # Storage layout, rooted at the cache directory chosen during setup.
    DATA_ROOT: str = f"{DRIVE_ROOT}/data/hit_uav"       # raw HIT-UAV download target
    CURATED_ROOT: str = f"{DRIVE_ROOT}/data/curated"    # converted images + COCO json
    CHECKPOINT_DIR: str = f"{DRIVE_ROOT}/checkpoints"
    OUTPUT_DIR: str = f"{DRIVE_ROOT}/outputs"

    # Side length (pixels) of the square images produced by the conversion step.
    IMG_SIZE: int = 512

    # Training hyper-parameters.
    BATCH_SIZE: int = 4
    NUM_EPOCHS: int = 6
    LR: float = 0.005
    LR_STEP_SIZE: int = 3
    LR_GAMMA: float = 0.1
    WEIGHT_DECAY: float = 0.0005

    # Detection / evaluation settings.
    NUM_CLASSES: int = 2          # background + person
    IOU_THRESHOLD: float = 0.5    # minimum IoU for a prediction to match a GT box
    CONF_THRESHOLD: float = 0.5   # minimum score for a prediction to be kept

    # Runtime: prefer the GPU when one is visible.
    DEVICE: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    SEED: int = 42

# Seed Python, NumPy and PyTorch (plus CUDA when available) for reproducibility.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(Config.SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(Config.SEED)

print(f"Device: {Config.DEVICE}")
print(f"Image size: {Config.IMG_SIZE}, Epochs: {Config.NUM_EPOCHS}, Batch: {Config.BATCH_SIZE}")

## Cell 3: Download HIT-UAV Dataset from Kaggle

In [None]:
# =============================================================================
# CELL 3: DOWNLOAD HIT-UAV DATASET FROM KAGGLE
# =============================================================================

def download_hituav_kaggle(data_root: str) -> bool:
    """
    Download and extract the HIT-UAV dataset from Kaggle into ``data_root``.

    A plain ``curl`` request is tried first. If it fails, the function falls back
    to the ``kaggle`` Python API, which needs credentials in
    ~/.kaggle/kaggle.json or the KAGGLE_USERNAME / KAGGLE_KEY env vars.

    Args:
        data_root: Directory to download into; created if missing.

    Returns:
        True if the dataset is already present, or was downloaded and extracted.
        False otherwise; a manual-download hint is printed when both download
        methods fail.
    """
    import subprocess

    data_root = Path(data_root)
    data_root.mkdir(parents=True, exist_ok=True)

    # Skip the download when a previous run already extracted the images.
    images_dir = data_root / "normal" / "images"
    if images_dir.exists() and len(list(images_dir.glob("*.jpg"))) > 100:
        print(f"Dataset already exists at {data_root}")
        return True

    dataset_slug = 'pandrii000/hituav-a-highaltitude-infrared-thermal-dataset'
    zip_path = data_root / "hituav.zip"

    # Method 1: direct curl request to the Kaggle download endpoint.
    print("Downloading HIT-UAV dataset from Kaggle...")
    try:
        result = subprocess.run([
            'curl', '-L', '-o', str(zip_path),
            f'https://www.kaggle.com/api/v1/datasets/download/{dataset_slug}'
        ], capture_output=True, timeout=600)

        # A non-zero exit status or a tiny file (presumably an error response
        # body) means no usable archive was fetched.
        if (result.returncode == 0 and zip_path.exists()
                and zip_path.stat().st_size > 1000000):  # >1MB
            print(f"Downloaded to {zip_path}")
        else:
            raise Exception("Download failed or file too small")

    except Exception as e:
        print(f"Curl download failed: {e}")

        # Remove any partial/error output so the fallback below cannot pick it
        # up as the dataset archive.
        if zip_path.exists():
            zip_path.unlink()

        # Method 2: official kaggle API (requires credentials).
        try:
            import kaggle
            kaggle.api.dataset_download_files(
                dataset_slug,
                path=str(data_root),
                unzip=False
            )
            # The archive's file name is not fixed here; take the newest .zip.
            downloaded = sorted(data_root.glob("*.zip"),
                                key=lambda p: p.stat().st_mtime, reverse=True)
            if downloaded:
                zip_path = downloaded[0]
        except Exception as e2:
            print(f"Kaggle API failed: {e2}")
            print("\nPlease download manually:")
            print("1. Go to: https://www.kaggle.com/datasets/pandrii000/hituav-a-highaltitude-infrared-thermal-dataset")
            print("2. Download and extract to:", data_root)
            return False

    if not zip_path.exists():
        return False

    print("Extracting dataset...")
    try:
        with zipfile.ZipFile(zip_path, 'r') as zf:
            zf.extractall(data_root)
        print(f"Extracted to {data_root}")
        zip_path.unlink()  # The extracted files are all later cells need.
        return True
    except Exception as e:
        print(f"Extraction failed: {e}")
        return False


def explore_dataset_structure(data_root: str):
    """
    Print an indented tree of every directory below ``data_root``.

    Each directory line shows how many direct children it has. When the
    directory contains at least one file, the first file name in sorted order
    is printed as a sample.
    """
    data_root = Path(data_root)
    print(f"\nDataset structure at {data_root}:")

    for item in sorted(p for p in data_root.rglob("*") if p.is_dir()):
        children = sorted(item.glob("*"))
        depth = len(item.relative_to(data_root).parts)
        indent = "  " * depth
        print(f"{indent}{item.name}/ ({len(children)} items)")
        # Pick the first *file*; the first entry may be a sub-directory, and
        # unsorted glob order would make the sample arbitrary.
        sample = next((c for c in children if c.is_file()), None)
        if sample is not None:
            print(f"{indent}  Sample: {sample.name}")


# Fetch HIT-UAV (skipped when already cached) and, on success, print its layout.
download_success = download_hituav_kaggle(Config.DATA_ROOT)

if download_success:
    explore_dataset_structure(data_root=Config.DATA_ROOT)

## Cell 4: Convert HIT-UAV to COCO Format (Person Only)

In [None]:
# =============================================================================
# CELL 4: CONVERT HIT-UAV TO COCO FORMAT
# HIT-UAV uses YOLO format: class_id cx cy w h (normalized)
# We convert to COCO: {images: [], annotations: [], categories: []}
# =============================================================================

def _find_hituav_dirs(data_root: Path) -> Tuple[Path, Path]:
    """Return (images_dir, labels_dir) for the first known HIT-UAV layout whose images dir exists."""
    # HIT-UAV structure: normal/images/, normal/labels/, or train/images, etc.
    possible_paths = [
        (data_root / "normal" / "images", data_root / "normal" / "labels"),
        (data_root / "train" / "images", data_root / "train" / "labels"),
        (data_root / "images", data_root / "labels"),
    ]
    for img_dir, lbl_dir in possible_paths:
        if img_dir.exists():
            print(f"Found images at: {img_dir}")
            print(f"Found labels at: {lbl_dir}")
            return img_dir, lbl_dir

    print("Could not find standard structure. Contents:")
    for item in data_root.iterdir():
        print(f"  {item}")
    raise FileNotFoundError(f"Cannot find images directory in {data_root}")


def _read_person_boxes(label_path: Path, orig_w: int, orig_h: int,
                       person_class: int = 0) -> List[List[float]]:
    """Parse one YOLO label file into clipped person boxes [x, y, w, h] in original pixels."""
    boxes: List[List[float]] = []
    if not label_path.exists():
        return boxes
    with open(label_path, 'r') as f:
        for line in f:
            parts = line.strip().split()
            # YOLO format: class cx cy w h (normalized 0-1); float() tolerates "0.0" ids.
            if len(parts) < 5 or int(float(parts[0])) != person_class:
                continue
            cx, cy, bw, bh = map(float, parts[1:5])
            # Clip the corners, then derive w/h. Clamping x/y alone would keep
            # the full width/height and push the right/bottom edge outward.
            x1 = max(0.0, (cx - bw / 2) * orig_w)
            y1 = max(0.0, (cy - bh / 2) * orig_h)
            x2 = min(float(orig_w), (cx + bw / 2) * orig_w)
            y2 = min(float(orig_h), (cy + bh / 2) * orig_h)
            w, h = x2 - x1, y2 - y1
            if w > 5 and h > 5:  # Skip tiny boxes
                boxes.append([x1, y1, w, h])
    return boxes


def convert_hituav_to_coco(data_root: str, output_path: str, target_size: int = 512) -> Tuple[Path, Path]:
    """
    Convert HIT-UAV YOLO labels into a person-only COCO dataset of resized images.

    HIT-UAV classes: 0=Person, 1=Car, 2=Bicycle, 3=OtherVehicle, 4=DontCare;
    only class 0 is kept. An image is written only if it still has at least
    one person box at least 8 px wide and high after resizing.
    - Kept images go to ``output_path/images/hituav_XXXXX.jpg``, resized to
      ``target_size`` x ``target_size`` (the aspect ratio is not preserved).
    - All annotations go to ``output_path/annotations.json``, with COCO
      [x, y, w, h] boxes in resized pixels.

    Args:
        data_root: Root of the extracted HIT-UAV download.
        output_path: Directory for the curated images and annotation file.
        target_size: Side length of the square output images.

    Returns:
        (images_dir, annotations_json_path)

    Raises:
        FileNotFoundError: if no known images directory exists under data_root.
    """
    data_root = Path(data_root)
    output_path = Path(output_path)
    output_path.mkdir(parents=True, exist_ok=True)

    images_out = output_path / "images"
    images_out.mkdir(exist_ok=True)

    images_dir, labels_dir = _find_hituav_dirs(data_root)
    if not labels_dir.exists():
        print(f"Warning: labels directory {labels_dir} does not exist; no annotations will be found")

    # COCO format
    coco = {
        "images": [],
        "annotations": [],
        "categories": [{"id": 1, "name": "person", "supercategory": "human"}]
    }

    PERSON_CLASS = 0  # HIT-UAV person class
    ann_id = 1
    img_id = 1

    image_files = sorted(list(images_dir.glob("*.jpg")) + list(images_dir.glob("*.png")))
    print(f"\nFound {len(image_files)} images")

    images_with_persons = 0
    total_persons = 0

    for img_path in tqdm(image_files, desc="Converting"):
        img = cv2.imread(str(img_path))
        if img is None:
            continue

        orig_h, orig_w = img.shape[:2]
        scale_x = target_size / orig_w
        scale_y = target_size / orig_h

        # Scale first, so an image whose people all become too small after
        # resizing is skipped instead of being saved with no annotations.
        scaled_boxes = []
        label_path = labels_dir / f"{img_path.stem}.txt"
        for (x, y, w, h) in _read_person_boxes(label_path, orig_w, orig_h, PERSON_CLASS):
            w_scaled = w * scale_x
            h_scaled = h * scale_y
            if w_scaled < 8 or h_scaled < 8:
                continue
            scaled_boxes.append([x * scale_x, y * scale_y, w_scaled, h_scaled])

        if not scaled_boxes:
            continue

        images_with_persons += 1

        img_resized = cv2.resize(img, (target_size, target_size))
        new_filename = f"hituav_{img_id:05d}.jpg"
        cv2.imwrite(str(images_out / new_filename), img_resized)

        coco["images"].append({
            "id": img_id,
            "file_name": new_filename,
            "width": target_size,
            "height": target_size,
            "original_file": img_path.name
        })

        for bbox in scaled_boxes:
            coco["annotations"].append({
                "id": ann_id,
                "image_id": img_id,
                "category_id": 1,
                "bbox": bbox,  # COCO format: x, y, w, h
                "area": bbox[2] * bbox[3],
                "iscrowd": 0
            })
            ann_id += 1
        total_persons += len(scaled_boxes)

        img_id += 1

    ann_path = output_path / "annotations.json"
    with open(ann_path, 'w') as f:
        json.dump(coco, f)

    print(f"\nConversion complete:")
    print(f"  Images with persons: {images_with_persons}")
    print(f"  Total person annotations: {total_persons}")
    print(f"  Avg persons per image: {total_persons/max(images_with_persons,1):.1f}")
    print(f"  Saved to: {output_path}")

    return images_out, ann_path


# Build the curated, person-only COCO dataset used by every later cell.
print("Converting HIT-UAV to COCO format...")
IMAGES_DIR, ANNOTATIONS_PATH = convert_hituav_to_coco(
    data_root=Config.DATA_ROOT,
    output_path=Config.CURATED_ROOT,
    target_size=Config.IMG_SIZE,
)

## Cell 5: SAR Augmentations

In [None]:
# =============================================================================
# CELL 5: SAR AUGMENTATIONS
# Snow, Smoke/Fire, Thermal artifacts for robustness
# =============================================================================

class SARaugmentations:
    """Synthetic weather/sensor corruptions for SAR drone imagery (BGR or grayscale uint8 input)."""

    @staticmethod
    def generate_perlin_noise(shape: Tuple[int, int], scale: float = 100.0) -> np.ndarray:
        """
        Return smooth noise in [0, 1] of the given (h, w) shape.

        Perlin-like: four octaves of low-resolution Gaussian noise are
        upscaled with cubic interpolation and summed with halving amplitude,
        then min-max normalised.
        """
        h, w = shape
        noise = np.zeros((h, w), dtype=np.float32)

        for octave in range(4):
            freq = 2 ** octave
            amplitude = 1.0 / freq
            small_h = max(2, int(h / (scale / freq)))
            small_w = max(2, int(w / (scale / freq)))
            small_noise = np.random.randn(small_h, small_w).astype(np.float32)
            upscaled = cv2.resize(small_noise, (w, h), interpolation=cv2.INTER_CUBIC)
            noise += amplitude * upscaled

        noise = (noise - noise.min()) / (noise.max() - noise.min() + 1e-8)
        return noise

    @staticmethod
    def apply_snow(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:
        """Blend a bright noise layer over a contrast-reduced image; keeps the input's channel count."""
        h, w = img.shape[:2]
        is_color = len(img.shape) == 3

        snow_noise = SARaugmentations.generate_perlin_noise((h, w), scale=50.0)
        fine_noise = np.random.rand(h, w).astype(np.float32)
        fine_noise = cv2.GaussianBlur(fine_noise, (5, 5), 0)

        snow_layer = 0.6 * snow_noise + 0.4 * fine_noise
        snow_layer = np.clip(snow_layer * intensity * 255, 0, 255).astype(np.uint8)

        if is_color:
            snow_layer = cv2.cvtColor(snow_layer, cv2.COLOR_GRAY2BGR)

        # Pull pixels 30% toward the global mean to mimic haze/low contrast.
        img_float = img.astype(np.float32)
        mean_val = np.mean(img_float)
        contrast_reduction = 0.3
        img_reduced = (1 - contrast_reduction) * img_float + contrast_reduction * mean_val

        alpha = intensity * 0.7
        result = (1 - alpha) * img_reduced + alpha * snow_layer.astype(np.float32)
        return np.clip(result, 0, 255).astype(np.uint8)

    @staticmethod
    def apply_smoke_fire(img: np.ndarray, smoke_intensity: float = 0.4,
                         fire_intensity: float = 0.3) -> np.ndarray:
        """
        Overlay grey smoke and, when ``fire_intensity > 0``, a radial orange fire glow.

        Grayscale input is promoted to 3-channel BGR; the result is always BGR uint8.
        """
        h, w = img.shape[:2]
        if len(img.shape) != 3:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        result = img.astype(np.float32)

        # Smoke: noise mask fading from full strength at the top row to 30% at the bottom.
        smoke_noise = SARaugmentations.generate_perlin_noise((h, w), scale=80.0)
        gradient = np.tile(np.linspace(1.0, 0.3, h).reshape(-1, 1), (1, w))
        smoke_mask = smoke_noise * gradient

        # The smoke colour is uniform, so broadcast it per pixel instead of
        # building (and blurring, a no-op) a full-size constant image.
        smoke_color = np.array([180, 180, 180], dtype=np.float32)
        smoke_alpha = smoke_mask[..., np.newaxis] * smoke_intensity
        result = result * (1 - smoke_alpha) + smoke_color * smoke_alpha

        # Fire: quadratic radial falloff around a random point in the lower half.
        if fire_intensity > 0:
            fx, fy = np.random.randint(w//4, 3*w//4), np.random.randint(h//2, h)
            y_coords, x_coords = np.ogrid[:h, :w]
            dist = np.sqrt((x_coords - fx)**2 + (y_coords - fy)**2)
            fire_radius = min(h, w) // 3
            fire_mask = np.clip(1 - dist / fire_radius, 0, 1) ** 2

            fire_color = np.array([30, 100, 255], dtype=np.float32)  # Orange BGR
            fire_alpha = fire_mask[..., np.newaxis] * fire_intensity
            result = result * (1 - fire_alpha) + fire_color * fire_alpha

        return np.clip(result, 0, 255).astype(np.uint8)

    @staticmethod
    def apply_thermal_artifacts(img: np.ndarray, intensity_scale: float = 1.0,
                                sensor_noise: float = 0.05) -> np.ndarray:
        """Rescale intensity and add Gaussian noise (30% of the time also row banding); returns BGR."""
        h, w = img.shape[:2]

        if len(img.shape) == 3:
            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        else:
            gray = img.copy()

        result = gray.astype(np.float32) * intensity_scale

        # Per-pixel sensor noise, optionally plus per-row offsets (horizontal stripes).
        noise = np.random.normal(0, sensor_noise * 255, (h, w)).astype(np.float32)
        if np.random.rand() < 0.3:
            line_noise = np.random.normal(0, sensor_noise * 50, (h, 1))
            noise += np.tile(line_noise, (1, w))

        result += noise
        result = np.clip(result, 0, 255).astype(np.uint8)
        return cv2.cvtColor(result, cv2.COLOR_GRAY2BGR)

    @staticmethod
    def apply_random(img: np.ndarray) -> Tuple[np.ndarray, str]:
        """Pick one of snow / fire / thermal / none uniformly; return (image, name of effect)."""
        aug_type = np.random.choice(['snow', 'fire', 'thermal', 'none'])

        if aug_type == 'snow':
            return SARaugmentations.apply_snow(img, np.random.uniform(0.3, 0.6)), 'snow'
        elif aug_type == 'fire':
            return SARaugmentations.apply_smoke_fire(img, np.random.uniform(0.2, 0.4),
                                                     np.random.uniform(0.2, 0.4)), 'fire'
        elif aug_type == 'thermal':
            return SARaugmentations.apply_thermal_artifacts(img, np.random.uniform(0.8, 1.2),
                                                            np.random.uniform(0.03, 0.08)), 'thermal'
        return img, 'none'


# Preview each augmentation on one curated image; the grid is saved and shown.
sample_imgs = list(IMAGES_DIR.glob("*.jpg"))[:1]
if sample_imgs:
    sample = cv2.imread(str(sample_imgs[0]))

    # (title, transform) pairs, drawn row-major into the 2x2 grid.
    preview_panels = [
        ('Original', lambda im: im),
        ('Snow', lambda im: SARaugmentations.apply_snow(im, 0.5)),
        ('Smoke/Fire', lambda im: SARaugmentations.apply_smoke_fire(im, 0.4, 0.4)),
        ('Thermal Artifacts', lambda im: SARaugmentations.apply_thermal_artifacts(im, 1.1, 0.06)),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(10, 10))
    for panel_ax, (panel_title, panel_fn) in zip(axes.flat, preview_panels):
        panel_ax.imshow(cv2.cvtColor(panel_fn(sample), cv2.COLOR_BGR2RGB))
        panel_ax.set_title(panel_title)
        panel_ax.axis('off')

    plt.tight_layout()
    plt.savefig(f"{Config.OUTPUT_DIR}/augmentations.png", dpi=150)
    plt.show()

## Cell 6: Dataset Class with Proper Box Handling

In [None]:
# =============================================================================
# CELL 6: DATASET CLASS
# Proper handling of box formats to avoid evaluation bugs
# =============================================================================

class UAVDetectionDataset(Dataset):
    """
    COCO-annotated UAV person-detection dataset for torchvision detectors.

    IMPORTANT: Boxes in COCO are [x, y, width, height];
    Faster R-CNN expects [x1, y1, x2, y2] (pascal_voc format), so
    ``__getitem__`` converts them.

    Args:
        images_dir: Directory containing the files named in the annotation json.
        annotations_path: COCO-style json with 'images' and 'annotations' lists.
        transforms: Optional callable applied to the float CHW image tensor.
        apply_sar_aug: If True, randomly apply a SARaugmentations effect.
        sar_aug_prob: Per-sample probability of calling
            ``SARaugmentations.apply_random`` (which may itself pick 'none').
    """

    def __init__(self, images_dir, annotations_path, transforms=None,
                 apply_sar_aug=False, sar_aug_prob=0.5):
        self.images_dir = Path(images_dir)
        self.transforms = transforms
        self.apply_sar_aug = apply_sar_aug
        self.sar_aug_prob = sar_aug_prob

        with open(annotations_path, 'r') as f:
            coco = json.load(f)

        self.images = {img['id']: img for img in coco['images']}

        # Group annotations by image id.
        self.img_to_anns = defaultdict(list)
        for ann in coco['annotations']:
            self.img_to_anns[ann['image_id']].append(ann)

        # Only keep images WITH annotations.
        self.img_ids = [img_id for img_id in self.images.keys()
                        if len(self.img_to_anns[img_id]) > 0]

        print(f"Dataset: {len(self.img_ids)} images with annotations")

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        """Return (image tensor in [0, 1], CHW RGB; target dict with pascal_voc boxes)."""
        img_id = self.img_ids[idx]
        img_info = self.images[img_id]

        img_path = self.images_dir / img_info['file_name']
        img = cv2.imread(str(img_path))  # BGR
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")

        # SAR augmentations work in BGR, so apply them before the single
        # BGR->RGB conversion (no RGB->BGR->RGB round trip).
        if self.apply_sar_aug and random.random() < self.sar_aug_prob:
            img, _ = SARaugmentations.apply_random(img)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        boxes = []
        labels = []
        areas = []

        for ann in self.img_to_anns[img_id]:
            # COCO [x, y, width, height] -> pascal_voc [x1, y1, x2, y2]
            x, y, w, h = ann['bbox']
            x1, y1, x2, y2 = x, y, x + w, y + h

            # Drop degenerate boxes.
            if x2 > x1 and y2 > y1:
                boxes.append([x1, y1, x2, y2])
                labels.append(1)  # Person class = 1
                areas.append(ann['area'])

        img_tensor = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0

        # Additional tensor-level transforms (normalization, etc.).
        if self.transforms:
            img_tensor = self.transforms(img_tensor)

        if len(boxes) > 0:
            boxes_tensor = torch.as_tensor(boxes, dtype=torch.float32)
            labels_tensor = torch.as_tensor(labels, dtype=torch.int64)
            areas_tensor = torch.as_tensor(areas, dtype=torch.float32)
        else:
            boxes_tensor = torch.zeros((0, 4), dtype=torch.float32)
            labels_tensor = torch.zeros((0,), dtype=torch.int64)
            areas_tensor = torch.zeros((0,), dtype=torch.float32)

        target = {
            'boxes': boxes_tensor,
            'labels': labels_tensor,
            'image_id': torch.tensor([img_id]),
            'area': areas_tensor,
            'iscrowd': torch.zeros(len(boxes), dtype=torch.int64)
        }

        return img_tensor, target


def collate_fn(batch):
    """
    Transpose a list of (image, target) samples into (images, targets) tuples.

    Per-sample tensors are kept separate rather than stacked, since each image
    can carry a different number of boxes.
    """
    return tuple(column for column in zip(*batch))


# Create a reproducible train/val/test split (70/15/15).
with open(ANNOTATIONS_PATH, 'r') as f:
    full_coco = json.load(f)

# Shuffle with a private RNG seeded from Config.SEED so that re-running this
# cell always yields the same split. The global `random` state keeps advancing
# between runs, so reusing it would reshuffle images across train/val/test
# while checkpoints cached on disk persist. Sorting by id first makes the
# result independent of the json's image order.
all_images = sorted(full_coco['images'], key=lambda im: im['id'])
random.Random(Config.SEED).shuffle(all_images)

n = len(all_images)
train_end = int(0.7 * n)
val_end = int(0.85 * n)

train_images = all_images[:train_end]
val_images = all_images[train_end:val_end]
test_images = all_images[val_end:]

# Image-id sets for O(1) annotation routing.
train_ids = {img['id'] for img in train_images}
val_ids = {img['id'] for img in val_images}
test_ids = {img['id'] for img in test_images}

train_anns = [a for a in full_coco['annotations'] if a['image_id'] in train_ids]
val_anns = [a for a in full_coco['annotations'] if a['image_id'] in val_ids]
test_anns = [a for a in full_coco['annotations'] if a['image_id'] in test_ids]

splits = {
    'train': (train_images, train_anns),
    'val': (val_images, val_anns),
    'test': (test_images, test_anns)
}

for split_name, (images, anns) in splits.items():
    split_coco = {
        'images': images,
        'annotations': anns,
        'categories': full_coco['categories']
    }
    with open(Path(Config.CURATED_ROOT) / f"{split_name}.json", 'w') as f:
        json.dump(split_coco, f)
    print(f"{split_name}: {len(images)} images, {len(anns)} annotations")

TRAIN_ANN = Path(Config.CURATED_ROOT) / "train.json"
VAL_ANN = Path(Config.CURATED_ROOT) / "val.json"
TEST_ANN = Path(Config.CURATED_ROOT) / "test.json"

## Cell 7: Corrected Evaluation Function

In [None]:
# =============================================================================
# CELL 7: CORRECTED EVALUATION FUNCTION
# Properly computes precision, recall, F1 with detailed debugging
# =============================================================================

def compute_iou(box1, box2):
    """
    Return the intersection-over-union of two boxes given as [x1, y1, x2, y2].

    Returns 0.0 when the union area is not positive (e.g. two zero-area boxes).
    """
    ax1, ay1, ax2, ay2 = box1[0], box1[1], box1[2], box1[3]
    bx1, by1, bx2, by2 = box2[0], box2[1], box2[2], box2[3]

    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    overlap = overlap_w * overlap_h

    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - overlap
    return 0.0 if union <= 0 else overlap / union


def _pairwise_iou(boxes_a: np.ndarray, boxes_b: np.ndarray) -> np.ndarray:
    """Return the (len(a), len(b)) IoU matrix of [x1, y1, x2, y2] boxes; 0 where the union is not positive."""
    a = boxes_a[:, None, :]
    b = boxes_b[None, :, :]
    inter_w = np.clip(np.minimum(a[..., 2], b[..., 2]) - np.maximum(a[..., 0], b[..., 0]), 0, None)
    inter_h = np.clip(np.minimum(a[..., 3], b[..., 3]) - np.maximum(a[..., 1], b[..., 1]), 0, None)
    inter = inter_w * inter_h
    area_a = (a[..., 2] - a[..., 0]) * (a[..., 3] - a[..., 1])
    area_b = (b[..., 2] - b[..., 0]) * (b[..., 3] - b[..., 1])
    union = area_a + area_b - inter
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.where(union > 0, inter / union, 0.0)


@torch.no_grad()
def evaluate_model(model, data_loader, device, iou_thresh=0.5, conf_thresh=0.5, verbose=False):
    """
    Evaluate a single-class (person) detector; return precision/recall/F1 and raw counts.

    Only predictions with label 1 and score >= ``conf_thresh`` are considered.
    They are matched greedily in descending score order. Each prediction takes
    the still-unmatched ground-truth box it overlaps most:
    - if that IoU is >= ``iou_thresh`` (and > 0), the prediction is a TP;
    - otherwise it is an FP.
    Ground-truth boxes left unmatched are FNs.

    Args:
        model: Detector called as ``model(list_of_image_tensors)``, returning one
            dict per image with 'boxes', 'scores' and 'labels'.
        data_loader: Iterable of (images, targets) batches; each target has
            'boxes' in [x1, y1, x2, y2].
        device: Device the images are moved to.
        iou_thresh: Minimum IoU for a match.
        conf_thresh: Minimum score for a prediction to be counted.
        verbose: Show a progress bar and print a summary.

    Returns:
        dict with 'precision', 'recall', 'f1', 'tp', 'fp', 'fn', 'total_gt', 'total_pred'.
    """
    model.eval()

    total_tp = 0
    total_fp = 0
    total_fn = 0
    total_gt = 0
    total_pred = 0

    # Only wrap in a progress bar when it will actually be displayed.
    batches = tqdm(data_loader, desc="Evaluating") if verbose else data_loader

    for images, targets in batches:
        images = [img.to(device) for img in images]
        outputs = model(images)

        for output, target in zip(outputs, targets):
            # Ground truth is already [x1, y1, x2, y2] (converted by the dataset).
            gt_boxes = target['boxes'].cpu().numpy()
            total_gt += len(gt_boxes)

            # Keep person predictions (label 1) above the confidence threshold.
            scores = output['scores'].cpu().numpy()
            keep = (scores >= conf_thresh) & (output['labels'].cpu().numpy() == 1)
            pred_boxes = output['boxes'].cpu().numpy()[keep]
            pred_scores = scores[keep]
            total_pred += len(pred_boxes)

            if len(gt_boxes) == 0:
                total_fp += len(pred_boxes)  # Nothing to match: all false positives.
                continue
            if len(pred_boxes) == 0:
                total_fn += len(gt_boxes)  # Nothing predicted: all GT missed.
                continue

            # All prediction-vs-GT IoUs at once (P x G) instead of per-pair calls.
            ious = _pairwise_iou(pred_boxes, gt_boxes)
            gt_taken = np.zeros(len(gt_boxes), dtype=bool)

            for pred_idx in np.argsort(-pred_scores):
                # Hide already-matched GT boxes; argmax picks the first best one.
                candidate = np.where(gt_taken, -1.0, ious[pred_idx])
                best_gt = int(np.argmax(candidate))
                best_iou = candidate[best_gt]
                if best_iou > 0 and best_iou >= iou_thresh:
                    total_tp += 1
                    gt_taken[best_gt] = True
                else:
                    total_fp += 1

            total_fn += int((~gt_taken).sum())

    precision = total_tp / max(total_tp + total_fp, 1)
    recall = total_tp / max(total_tp + total_fn, 1)
    f1 = 2 * precision * recall / max(precision + recall, 1e-8)

    metrics = {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'tp': total_tp,
        'fp': total_fp,
        'fn': total_fn,
        'total_gt': total_gt,
        'total_pred': total_pred
    }

    if verbose:
        print(f"\nEvaluation Results (IoU={iou_thresh}, Conf={conf_thresh}):")
        print(f"  GT boxes: {total_gt}, Predictions: {total_pred}")
        print(f"  TP: {total_tp}, FP: {total_fp}, FN: {total_fn}")
        print(f"  Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")

    return metrics


# Confirms this cell ran and the evaluation helpers above are defined.
print("Evaluation function defined")

## Cell 8: Model Creation Function

In [None]:
# =============================================================================
# CELL 8: MODEL CREATION
# =============================================================================

def create_detection_model(num_classes=2, pretrained=True, freeze_backbone=True):
    """
    Build a torchvision Faster R-CNN (ResNet-50 FPN) with a new box predictor.

    Args:
        num_classes: Number of classes including background (2 = background + person).
        pretrained: Load torchvision's default Faster R-CNN weights when True.
        freeze_backbone: Disable gradients for backbone parameters outside
            ``layer4`` and the FPN, to limit overfitting.

    Returns:
        The model. Trainable/total parameter counts are printed as a side effect.
    """
    # NOTE(review): with weights=None, torchvision's separate ``weights_backbone``
    # argument may still load ImageNet weights into the backbone — confirm that
    # is the intended meaning of pretrained=False.
    weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT if pretrained else None
    model = fasterrcnn_resnet50_fpn(weights=weights)

    # Swap the classification/regression head for one sized to num_classes.
    head_in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(head_in_features, num_classes)

    if freeze_backbone:
        to_freeze = (param for pname, param in model.named_parameters()
                     if 'backbone' in pname and 'layer4' not in pname and 'fpn' not in pname)
        for param in to_freeze:
            param.requires_grad = False

    counts = [(p.numel(), p.requires_grad) for p in model.parameters()]
    trainable = sum(numel for numel, grad in counts if grad)
    total = sum(numel for numel, _ in counts)
    print(f"Model: {trainable:,} / {total:,} trainable params ({100*trainable/total:.1f}%)")

    return model


# ============================================
# YOLOV8 ALTERNATIVE (commented out, not wired into this notebook; needs the
# `ultralytics` package and a `data.yaml` describing the dataset):
# from ultralytics import YOLO
# model = YOLO('yolov8n.pt')
# model.train(data='data.yaml', epochs=6, imgsz=512)
# ============================================

# Confirms this cell ran and create_detection_model is defined.
print("Model creation function defined")

## Cell 9: Training Function

In [None]:
# =============================================================================
# CELL 9: TRAINING FUNCTION
# =============================================================================

def train_model(model, train_loader, val_loader, device, num_epochs, lr,
                checkpoint_prefix="model", lr_step=3, lr_gamma=0.1):
    """
    Train a detection model with SGD + StepLR and track validation metrics.

    Each epoch does the following:
      - run one pass over ``train_loader``, skipping batches in which no image
        has any box;
      - evaluate on ``val_loader`` with ``Config.IOU_THRESHOLD`` and
        ``Config.CONF_THRESHOLD``;
      - step the LR scheduler;
      - write ``{Config.CHECKPOINT_DIR}/{checkpoint_prefix}_best.pth``
        whenever validation F1 improves.

    The first epoch always writes a checkpoint, so the file exists even if F1
    never rises above 0.

    Args:
        model: detector that, in train mode, returns a dict of loss tensors
            from ``model(images, targets)``.
        train_loader: loader yielding ``(images, targets)``; each target dict
            must contain a ``'boxes'`` tensor.
        val_loader: loader passed to ``evaluate_model`` after every epoch.
        device: device to train on.
        num_epochs: number of training epochs.
        lr: initial SGD learning rate.
        checkpoint_prefix: filename prefix of the best-F1 checkpoint.
        lr_step: StepLR step size, in epochs.
        lr_gamma: StepLR multiplicative decay factor.

    Returns:
        A tuple ``(model, history, best_f1)``:
          - ``model`` holds the LAST epoch's weights; reload the checkpoint
            for the best ones.
          - ``history`` maps ``'train_loss'``, ``'val_precision'``,
            ``'val_recall'`` and ``'val_f1'`` to per-epoch lists.
          - ``best_f1`` is the best validation F1 seen (0 if ``num_epochs``
            is 0).
    """
    model.to(device)

    # Optimise only parameters that still require gradients (frozen layers are skipped).
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=Config.WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_step, gamma=lr_gamma)

    history = {
        'train_loss': [],
        'val_precision': [],
        'val_recall': [],
        'val_f1': []
    }

    # None until the first evaluation so that epoch 1 always writes a checkpoint.
    # Starting at 0 with a strict '>' saved nothing when F1 stayed 0, and the
    # later load of "<prefix>_best.pth" then failed.
    best_f1 = None
    best_ckpt_path = f"{Config.CHECKPOINT_DIR}/{checkpoint_prefix}_best.pth"

    for epoch in range(num_epochs):
        # ---- Training ----
        model.train()
        epoch_loss = 0.0
        num_batches = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for images, targets in pbar:
            # Skip batches where no image has a box; checked before moving
            # tensors to `device` so skipped batches cost no transfer.
            if all(len(t['boxes']) == 0 for t in targets):
                continue

            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            optimizer.zero_grad()
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

            losses.backward()
            torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
            optimizer.step()

            loss_value = losses.item()
            epoch_loss += loss_value
            num_batches += 1
            pbar.set_postfix({'loss': f"{loss_value:.4f}"})

        avg_loss = epoch_loss / max(num_batches, 1)
        history['train_loss'].append(avg_loss)

        # ---- Validation ----
        metrics = evaluate_model(model, val_loader, device,
                                 iou_thresh=Config.IOU_THRESHOLD,
                                 conf_thresh=Config.CONF_THRESHOLD)

        history['val_precision'].append(metrics['precision'])
        history['val_recall'].append(metrics['recall'])
        history['val_f1'].append(metrics['f1'])

        scheduler.step()

        print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, "
              f"P={metrics['precision']:.4f}, R={metrics['recall']:.4f}, F1={metrics['f1']:.4f} "
              f"(TP={metrics['tp']}, FP={metrics['fp']}, FN={metrics['fn']})")

        # Save best model (always on the first epoch)
        if best_f1 is None or metrics['f1'] > best_f1:
            best_f1 = metrics['f1']
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'metrics': metrics,
                'history': history
            }, best_ckpt_path)
            print(f"  -> Saved best model (F1={best_f1:.4f})")

    # Keep the historical return value of 0 when no epoch ran.
    return model, history, (0 if best_f1 is None else best_f1)


print("Training function defined")

## Cell 10: Train Model A (Baseline - No SAR Augmentation)

In [None]:
# =============================================================================
# CELL 10: TRAIN MODEL A - BASELINE (NO AUGMENTATION)
# =============================================================================

print("="*60)
print("TRAINING MODEL A: BASELINE (No SAR Augmentation)")
print("="*60)

# Create datasets WITHOUT SAR augmentation
train_dataset_baseline = UAVDetectionDataset(
    IMAGES_DIR, TRAIN_ANN,
    apply_sar_aug=False
)

val_dataset = UAVDetectionDataset(
    IMAGES_DIR, VAL_ANN,
    apply_sar_aug=False
)

train_loader_baseline = DataLoader(
    train_dataset_baseline,
    batch_size=Config.BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn
)

val_loader = DataLoader(
    val_dataset,
    batch_size=Config.BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    collate_fn=collate_fn
)

# Fix torch's global RNG right before building the model. This makes the freshly
# initialised predictor head reproducible across runs. DataLoader shuffling and
# the per-worker seeds also draw from this RNG when no generator is passed.
# Seed the augmented model's run with the same value so both models start from
# the same head initialisation. CUDA kernels may still be nondeterministic.
torch.manual_seed(42)

# Create and train model A
model_A = create_detection_model(
    num_classes=Config.NUM_CLASSES,
    pretrained=True,
    freeze_backbone=True
)

model_A, history_A, best_f1_A = train_model(
    model_A, train_loader_baseline, val_loader, Config.DEVICE,
    num_epochs=Config.NUM_EPOCHS,
    lr=Config.LR,
    checkpoint_prefix="model_A_baseline",
    lr_step=Config.LR_STEP_SIZE,
    lr_gamma=Config.LR_GAMMA
)

print(f"\nModel A Best F1: {best_f1_A:.4f}")

## Cell 11: Train Model B (With SAR Augmentation)

In [None]:
# =============================================================================
# CELL 11: TRAIN MODEL B - WITH SAR AUGMENTATION
# =============================================================================

print("="*60)
print("TRAINING MODEL B: WITH SAR AUGMENTATION")
print("="*60)

# Create dataset WITH SAR augmentation
train_dataset_augmented = UAVDetectionDataset(
    IMAGES_DIR, TRAIN_ANN,
    apply_sar_aug=True,
    sar_aug_prob=0.5  # 50% chance of augmentation
)

train_loader_augmented = DataLoader(
    train_dataset_augmented,
    batch_size=Config.BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn
)

# Fix torch's global RNG right before building the model. This makes the freshly
# initialised predictor head reproducible across runs. DataLoader shuffling and
# the per-worker seeds (which drive the random SAR augmentations in workers)
# also draw from this RNG when no generator is passed. Seed the baseline run
# with the same value so both models start from the same head initialisation.
torch.manual_seed(42)

# Create and train model B (fresh model)
model_B = create_detection_model(
    num_classes=Config.NUM_CLASSES,
    pretrained=True,
    freeze_backbone=True
)

model_B, history_B, best_f1_B = train_model(
    model_B, train_loader_augmented, val_loader, Config.DEVICE,
    num_epochs=Config.NUM_EPOCHS,
    lr=Config.LR,
    checkpoint_prefix="model_B_augmented",
    lr_step=Config.LR_STEP_SIZE,
    lr_gamma=Config.LR_GAMMA
)

print(f"\nModel B Best F1: {best_f1_B:.4f}")

## Cell 12: Create Perturbed Test Set

In [None]:
# =============================================================================
# CELL 12: CREATE PERTURBED TEST SET
# =============================================================================

print("Creating perturbed test set...")

perturbed_dir = Path(Config.CURATED_ROOT) / "perturbed_test"
perturbed_dir.mkdir(parents=True, exist_ok=True)

with open(TEST_ANN, 'r') as f:
    test_coco = json.load(f)

# IDs of images that were successfully perturbed and written to disk.
written_ids = set()

for img_info in tqdm(test_coco['images'], desc="Perturbing test images"):
    img_path = IMAGES_DIR / img_info['file_name']
    img = cv2.imread(str(img_path))

    if img is None:
        print(f"  Warning: could not read {img_path}; excluded from perturbed set")
        continue

    # Apply random perturbation (snow or fire)
    aug_type = random.choice(['snow', 'fire'])
    if aug_type == 'snow':
        perturbed = SARaugmentations.apply_snow(img, random.uniform(0.4, 0.6))
    else:
        perturbed = SARaugmentations.apply_smoke_fire(img, random.uniform(0.3, 0.5),
                                                       random.uniform(0.3, 0.5))

    out_path = perturbed_dir / img_info['file_name']
    # file_name may include sub-directories. cv2.imwrite signals failure by
    # returning False rather than raising, so create parents and check the result.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    if not cv2.imwrite(str(out_path), perturbed):
        print(f"  Warning: could not write {out_path}; excluded from perturbed set")
        continue

    written_ids.add(img_info['id'])

# Save annotations restricted to the images actually written, so the perturbed
# dataset never references a file that does not exist. test_coco is left unchanged.
perturbed_coco = dict(test_coco)
perturbed_coco['images'] = [im for im in test_coco['images'] if im['id'] in written_ids]
perturbed_coco['annotations'] = [
    ann for ann in test_coco.get('annotations', []) if ann['image_id'] in written_ids
]

PERTURBED_TEST_ANN = perturbed_dir / "annotations.json"
with open(PERTURBED_TEST_ANN, 'w') as f:
    json.dump(perturbed_coco, f)

print(f"Perturbed test set saved to {perturbed_dir}")
print(f"  {len(perturbed_coco['images'])}/{len(test_coco['images'])} images written")

## Cell 13: Final Comparison - Both Models on Clean & Perturbed Test Sets

In [None]:
# =============================================================================
# CELL 13: FINAL COMPARISON
# Evaluate both models on clean and perturbed test sets
# =============================================================================

print("="*70)
print("FINAL COMPARISON: MODEL A (Baseline) vs MODEL B (Augmented)")
print("="*70)

# Load best checkpoints. train_model returns the last-epoch weights, so reload
# the best-validation-F1 weights before testing.
# NOTE(review): recent PyTorch versions default torch.load to weights_only=True.
# That works only if the stored 'metrics'/'history' hold plain Python types; confirm.
ckpt_A = torch.load(f"{Config.CHECKPOINT_DIR}/model_A_baseline_best.pth", map_location=Config.DEVICE)
model_A.load_state_dict(ckpt_A['model_state_dict'])

ckpt_B = torch.load(f"{Config.CHECKPOINT_DIR}/model_B_augmented_best.pth", map_location=Config.DEVICE)
model_B.load_state_dict(ckpt_B['model_state_dict'])

# Create test datasets
test_dataset_clean = UAVDetectionDataset(IMAGES_DIR, TEST_ANN, apply_sar_aug=False)
test_dataset_perturbed = UAVDetectionDataset(perturbed_dir, PERTURBED_TEST_ANN, apply_sar_aug=False)

test_loader_clean = DataLoader(test_dataset_clean, batch_size=Config.BATCH_SIZE,
                               shuffle=False, num_workers=2, collate_fn=collate_fn)
test_loader_perturbed = DataLoader(test_dataset_perturbed, batch_size=Config.BATCH_SIZE,
                                   shuffle=False, num_workers=2, collate_fn=collate_fn)

# Evaluate Model A
print("\n--- Model A (Baseline) ---")
print("On CLEAN test set:")
metrics_A_clean = evaluate_model(model_A, test_loader_clean, Config.DEVICE,
                                  iou_thresh=Config.IOU_THRESHOLD,
                                  conf_thresh=Config.CONF_THRESHOLD, verbose=True)

print("\nOn PERTURBED test set:")
metrics_A_perturbed = evaluate_model(model_A, test_loader_perturbed, Config.DEVICE,
                                      iou_thresh=Config.IOU_THRESHOLD,
                                      conf_thresh=Config.CONF_THRESHOLD, verbose=True)

# Evaluate Model B
print("\n--- Model B (Augmented) ---")
print("On CLEAN test set:")
metrics_B_clean = evaluate_model(model_B, test_loader_clean, Config.DEVICE,
                                  iou_thresh=Config.IOU_THRESHOLD,
                                  conf_thresh=Config.CONF_THRESHOLD, verbose=True)

print("\nOn PERTURBED test set:")
metrics_B_perturbed = evaluate_model(model_B, test_loader_perturbed, Config.DEVICE,
                                      iou_thresh=Config.IOU_THRESHOLD,
                                      conf_thresh=Config.CONF_THRESHOLD, verbose=True)

# Summary Table
print("\n" + "="*70)
print("SUMMARY TABLE")
print("="*70)
print(f"{'Model':<20} {'Test Set':<15} {'Precision':<12} {'Recall':<12} {'F1':<12}")
print("-"*70)
summary_rows = [
    ('A (Baseline)', 'Clean', metrics_A_clean),
    ('A (Baseline)', 'Perturbed', metrics_A_perturbed),
    ('B (Augmented)', 'Clean', metrics_B_clean),
    ('B (Augmented)', 'Perturbed', metrics_B_perturbed),
]
for row_model, row_split, m in summary_rows:
    # Column widths match the header above (12-wide columns plus one separating space).
    print(f"{row_model:<20} {row_split:<15} {m['precision']:<12.4f} {m['recall']:<12.4f} {m['f1']:.4f}")
print("="*70)

# Robustness analysis: relative F1 drop (in %) from the clean to the perturbed test set.
drop_A = (metrics_A_clean['f1'] - metrics_A_perturbed['f1']) / max(metrics_A_clean['f1'], 1e-8) * 100
drop_B = (metrics_B_clean['f1'] - metrics_B_perturbed['f1']) / max(metrics_B_clean['f1'], 1e-8) * 100

print("\nRobustness Analysis:")
print(f"  Model A F1 drop on perturbed: {drop_A:.1f}%")
print(f"  Model B F1 drop on perturbed: {drop_B:.1f}%")
# drop_A - drop_B is a difference of two percentages, i.e. percentage points
# (positive means Model B degrades less).
print(f"  Robustness improvement: {drop_A - drop_B:.1f} percentage points")

if drop_B < drop_A:
    print("\n-> Model B (with SAR augmentation) is MORE ROBUST to adverse conditions!")
else:
    print("\n-> Augmentation did not improve robustness (may need tuning)")

## Cell 14: Visualization

In [None]:
# =============================================================================
# CELL 14: VISUALIZATION
# =============================================================================

# Training curves comparison. Epochs are plotted 1-based to match the
# "Epoch N/M" numbering printed during training.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

curve_runs = [
    # (history dict, line format, legend label)
    (history_A, 'b-o', 'Model A (Baseline)'),
    (history_B, 'r-s', 'Model B (Augmented)'),
]
curve_panels = [
    # (axis, history key, y-label, title)
    (axes[0, 0], 'train_loss', 'Loss', 'Training Loss'),
    (axes[0, 1], 'val_f1', 'F1 Score', 'Validation F1'),
    (axes[1, 0], 'val_precision', 'Precision', 'Validation Precision'),
    (axes[1, 1], 'val_recall', 'Recall', 'Validation Recall'),
]

for panel_ax, key, ylabel, title in curve_panels:
    for run_history, fmt, label in curve_runs:
        values = run_history[key]
        panel_ax.plot(range(1, len(values) + 1), values, fmt, label=label)
    panel_ax.set_xlabel('Epoch')
    panel_ax.set_ylabel(ylabel)
    panel_ax.set_title(title)
    panel_ax.legend()
    panel_ax.grid(True)

plt.tight_layout()
plt.savefig(f"{Config.OUTPUT_DIR}/training_comparison.png", dpi=150)
plt.show()

# Final metrics bar chart: one group per metric, one bar per (model, test set).
fig, ax = plt.subplots(figsize=(10, 6))

metrics_labels = ['Precision', 'Recall', 'F1']
metric_keys = ['precision', 'recall', 'f1']
x = np.arange(len(metrics_labels))
width = 0.2

bar_series = [
    # (legend label, metrics dict, colour, alpha, offset in bar widths)
    ('A-Clean', metrics_A_clean, 'blue', 0.8, -1.5),
    ('A-Perturbed', metrics_A_perturbed, 'blue', 0.4, -0.5),
    ('B-Clean', metrics_B_clean, 'red', 0.8, 0.5),
    ('B-Perturbed', metrics_B_perturbed, 'red', 0.4, 1.5),
]
for label, m, color, alpha, offset in bar_series:
    ax.bar(x + offset * width, [m[k] for k in metric_keys], width,
           label=label, color=color, alpha=alpha)

ax.set_ylabel('Score')
ax.set_title('Model Comparison: Clean vs Perturbed Test Sets')
ax.set_xticks(x)
ax.set_xticklabels(metrics_labels)
ax.legend()
ax.set_ylim(0, 1.0)
ax.grid(True, axis='y', alpha=0.3)

plt.tight_layout()
plt.savefig(f"{Config.OUTPUT_DIR}/metrics_comparison.png", dpi=150)
plt.show()

print(f"\nVisualizations saved to {Config.OUTPUT_DIR}/")

## Cell 15: Sample Predictions Visualization

In [None]:
# =============================================================================
# CELL 15: SAMPLE PREDICTIONS
# =============================================================================

def visualize_detections(model, images_dir, annotations_path, title, num_samples=4):
    """
    Plot ground truth (top row) and model detections (bottom row) for a random
    sample of images from a COCO-style annotation file.

    Args:
        model: detector called as ``model([img_tensor])[0]`` in eval mode,
            returning a dict with ``'boxes'``, ``'scores'`` and ``'labels'``.
            It is expected to already live on ``Config.DEVICE``.
        images_dir: directory (str or Path) containing each image's ``file_name``.
        annotations_path: COCO JSON with ``images`` and (optionally) ``annotations``.
        title: figure title.
        num_samples: maximum number of images (columns) to draw.

    Returns:
        The matplotlib Figure. Only predictions with label 1 and score above
        ``Config.CONF_THRESHOLD`` are drawn.
    """
    model.eval()
    images_dir = Path(images_dir)  # accept plain strings as well as Path objects

    with open(annotations_path, 'r') as f:
        coco = json.load(f)

    img_to_anns = defaultdict(list)
    for ann in coco.get('annotations', []):
        img_to_anns[ann['image_id']].append(ann)

    samples = random.sample(coco['images'], min(num_samples, len(coco['images'])))

    # Use at least one column, because subplots() rejects ncols=0. squeeze=False
    # keeps `axes` 2-D even with one column, so axes[row, i] always works.
    n_cols = max(len(samples), 1)
    fig, axes = plt.subplots(2, n_cols, figsize=(4 * n_cols, 8), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')
    if not samples:
        axes[0, 0].set_title('No images in annotations')

    for i, img_info in enumerate(samples):
        img = cv2.imread(str(images_dir / img_info['file_name']))
        if img is None:
            # Missing/unreadable file: label the column instead of crashing in cvtColor.
            axes[0, i].set_title(f"Unreadable: {img_info['file_name']}", fontsize=8)
            continue
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        gt_anns = img_to_anns[img_info['id']]

        # Ground truth (COCO boxes are [x, y, w, h])
        gt_img = img_rgb.copy()
        for ann in gt_anns:
            x, y, w, h = map(int, ann['bbox'])
            cv2.rectangle(gt_img, (x, y), (x + w, y + h), (0, 255, 0), 2)
        axes[0, i].imshow(gt_img)
        axes[0, i].set_title(f'GT ({len(gt_anns)})')

        # Predictions: HWC uint8 RGB -> CHW float in [0, 1]
        img_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1).float() / 255.0
        img_tensor = img_tensor.to(Config.DEVICE)

        with torch.no_grad():
            output = model([img_tensor])[0]

        keep = (output['scores'] > Config.CONF_THRESHOLD) & (output['labels'] == 1)
        boxes = output['boxes'][keep].cpu().numpy()
        scores = output['scores'][keep].cpu().numpy()

        pred_img = img_rgb.copy()
        for box, score in zip(boxes, scores):
            x1, y1, x2, y2 = map(int, box)
            cv2.rectangle(pred_img, (x1, y1), (x2, y2), (255, 0, 0), 2)
            cv2.putText(pred_img, f'{score:.2f}', (x1, y1-5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)

        axes[1, i].imshow(pred_img)
        axes[1, i].set_title(f'Pred ({len(boxes)})')

    fig.suptitle(title, fontsize=14)
    fig.tight_layout()
    return fig


# Visualize Model B on clean and perturbed
print("Model B predictions on CLEAN test images:")
fig1 = visualize_detections(model_B, IMAGES_DIR, TEST_ANN, "Model B - Clean Test Set")
fig1.savefig(f"{Config.OUTPUT_DIR}/model_B_clean_predictions.png", dpi=150)
plt.show()

print("\nModel B predictions on PERTURBED test images:")
fig2 = visualize_detections(model_B, perturbed_dir, PERTURBED_TEST_ANN, "Model B - Perturbed Test Set")
fig2.savefig(f"{Config.OUTPUT_DIR}/model_B_perturbed_predictions.png", dpi=150)
plt.show()

## Cell 16: Save Summary

In [None]:
# =============================================================================
# CELL 16: SAVE EXPERIMENT SUMMARY
# =============================================================================

def _json_default(obj):
    """
    Fallback for json.dump: convert NumPy/PyTorch scalars and arrays (anything
    exposing ``.tolist()``) to plain Python values; reject everything else.

    The metric types returned by evaluate_model are not visible in this cell;
    this guards against e.g. np.bool_ / np.float32 values, which json cannot encode.
    """
    if hasattr(obj, 'tolist'):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


summary = {
    'config': {
        'image_size': Config.IMG_SIZE,
        'num_classes': Config.NUM_CLASSES,
        'num_epochs': Config.NUM_EPOCHS,
        'batch_size': Config.BATCH_SIZE,
        'learning_rate': Config.LR,
        'lr_step_size': Config.LR_STEP_SIZE,
        'lr_gamma': Config.LR_GAMMA,
        'weight_decay': Config.WEIGHT_DECAY,
        'iou_threshold': Config.IOU_THRESHOLD,
        'conf_threshold': Config.CONF_THRESHOLD
    },
    'model_A_baseline': {
        'training_history': history_A,
        'test_clean': metrics_A_clean,
        'test_perturbed': metrics_A_perturbed,
        'robustness_drop_percent': drop_A
    },
    'model_B_augmented': {
        'training_history': history_B,
        'test_clean': metrics_B_clean,
        'test_perturbed': metrics_B_perturbed,
        'robustness_drop_percent': drop_B
    },
    'conclusion': {
        'augmentation_improves_robustness': drop_B < drop_A,
        # Difference of two relative drops, i.e. percentage points
        # (key name kept for compatibility with existing readers).
        'robustness_improvement_percent': drop_A - drop_B
    }
}

with open(f"{Config.OUTPUT_DIR}/experiment_summary.json", 'w', encoding='utf-8') as f:
    json.dump(summary, f, indent=2, default=_json_default)

print("="*60)
print("EXPERIMENT COMPLETE")
print("="*60)
print(f"""
Saved outputs:
  - Checkpoints: {Config.CHECKPOINT_DIR}/
    - model_A_baseline_best.pth
    - model_B_augmented_best.pth
  - Visualizations: {Config.OUTPUT_DIR}/
    - training_comparison.png
    - metrics_comparison.png
    - model_B_*_predictions.png
  - Summary: {Config.OUTPUT_DIR}/experiment_summary.json

Key Results:
  - Model A (Baseline) F1 drop on perturbed: {drop_A:.1f}%
  - Model B (Augmented) F1 drop on perturbed: {drop_B:.1f}%
  - Reduction in F1 drop from SAR augmentation (A - B): {drop_A - drop_B:.1f} percentage points
""")

---

## Notes

### Key Fixes in This Version:
1. **Box Format**: Proper conversion from COCO [x,y,w,h] to Faster R-CNN [x1,y1,x2,y2]
2. **Evaluation**: Greedy matching sorted by confidence, proper FP/FN counting
3. **Class Filtering**: Only count predictions with label=1 (person)
4. **Debug Info**: TP/FP/FN printed each epoch for verification

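### Why Recall = 1.0 Was Observed in Earlier Versions:
- Likely cause: a box-format mismatch in the original code (COCO `[x, y, w, h]` boxes treated as Faster R-CNN `[x1, y1, x2, y2]` boxes)
- Alternative cause: ground-truth boxes being compared against predictions incorrectly during evaluation
- This version converts box formats explicitly and uses the greedy, confidence-sorted matching listed above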

### Swapping to YOLOv8:
```python
from ultralytics import YOLO
model = YOLO('yolov8n.pt')
model.train(data='data.yaml', epochs=6, imgsz=512, batch=4)
```