# MLHD Model Evaluation Notebook

Comprehensive evaluation of the YOLO-like detector with configurable hyperparameters and visualizations.

## 1. Setup and Imports

In [None]:
import sys
import os
from pathlib import Path
from glob import glob

import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Treat the current working directory as the project root and put it first on
# sys.path so the project-local packages imported below can be resolved.
# NOTE(review): this presumably expects the notebook to be launched from the
# repository root -- confirm.
ROOT = Path.cwd()
sys.path.insert(0, str(ROOT))

# Project-local imports; these must come after the sys.path tweak above.
from models.yolo_like import Model
from datasets.target_encoding import load_yolo_labels
from datasets.transforms import letterbox_image
from eval.evaluator import ObjectDetectionEvaluator, evaluate_predictions

# Record the environment so a saved run can be reproduced later.
print(f"Working directory: {ROOT}")
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")

## 2. Hyperparameters Configuration

In [None]:
# Model Configuration
CHECKPOINT_PATH = 'checkpoints/best.pt'
GRID_SIZE = 26   # S: the detector predicts an S x S grid of cells
IMG_SIZE = 416   # side length (pixels) of the square letterboxed model input

# Inference Hyperparameters
CONFIDENCE_THRESHOLD = 0.5  # Minimum confidence for detections
NMS_IOU_THRESHOLD = 0.5     # IoU threshold for Non-Maximum Suppression

# Evaluation Hyperparameters
EVAL_IOU_THRESHOLDS = [0.5, 0.75]  # IoU thresholds for metrics computation

# Dataset Paths
VAL_IMAGES_DIR = 'data/processed_training_3/images/val'
VAL_LABELS_DIR = 'data/processed_training_3/labels/val'

# Output Configuration
OUTPUT_DIR = 'outputs/evaluation'
MAX_IMAGES = None  # Set to a number to limit evaluation (None = all images)

# Create the output directory up front: several later cells call plt.savefig()
# into OUTPUT_DIR before the results-saving cell gets a chance to create it.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Device Configuration: prefer CUDA, then Apple MPS, then CPU.
# torch.backends.mps does not exist on older PyTorch builds, so look it up
# defensively instead of assuming the attribute is present.
_mps_backend = getattr(torch.backends, 'mps', None)
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
elif _mps_backend is not None and _mps_backend.is_available():
    DEVICE = torch.device('mps')
else:
    DEVICE = torch.device('cpu')

print("\n" + "="*70)
print("HYPERPARAMETERS")
print("="*70)
print(f"Checkpoint: {CHECKPOINT_PATH}")
print(f"Grid Size: {GRID_SIZE}")
print(f"Image Size: {IMG_SIZE}")
print(f"Confidence Threshold: {CONFIDENCE_THRESHOLD}")
print(f"NMS IoU Threshold: {NMS_IOU_THRESHOLD}")
print(f"Evaluation IoU Thresholds: {EVAL_IOU_THRESHOLDS}")
print(f"Device: {DEVICE}")
print(f"Max Images: {MAX_IMAGES if MAX_IMAGES else 'All'}")
print("="*70)

## 3. Load Model

In [None]:
print("\nLoading model...")
model = Model(S=GRID_SIZE)

# Load checkpoint to CPU first to avoid MPS compatibility issues.
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints from a trusted source.
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')
model.load_state_dict(checkpoint['model'])

# Then move to target device and switch to inference behaviour.
model.to(DEVICE)
model.eval()

print(f"Model loaded from {CHECKPOINT_PATH}")
print(f"Checkpoint epoch: {checkpoint.get('epoch', 'N/A')}")

# Apply the numeric format only when a loss is actually stored; formatting the
# 'N/A' fallback string with ':.4f' raises ValueError.
_val_loss = checkpoint.get('val_loss')
_val_loss_text = 'N/A' if _val_loss is None else f"{_val_loss:.4f}"
print(f"Checkpoint validation loss: {_val_loss_text}")

# Count model parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nModel Parameters:")
print(f"  Total: {total_params:,}")
print(f"  Trainable: {trainable_params:,}")

## 4. Helper Functions

In [None]:
def decode_predictions(pred_grid, conf_thresh, grid_size, img_size):
    """Decode grid predictions [S, S, 5] to bounding boxes.

    Each cell holds (tx, ty, tw, th, objectness). tx/ty are offsets inside
    the cell, tw/th are box sizes as a fraction of the whole image. Cells
    whose objectness is below ``conf_thresh`` are skipped.

    Args:
        pred_grid: Tensor or array indexable as [row, col, channel] with at
            least five channels; must provide ``.tolist()``.
        conf_thresh: Minimum objectness for a cell to produce a box.
        grid_size: Number of cells S along each axis.
        img_size: Side length in pixels of the square model input.

    Returns:
        List of (x1, y1, x2, y2, conf) tuples in model-input pixel
        coordinates, in row-major cell order.
    """
    S = grid_size

    # Convert the whole grid to nested Python floats in one call. Reading
    # cells one by one with .item() costs 5*S*S separate reads, each of which
    # is a device-to-host transfer when the tensor lives on GPU/MPS.
    cells = pred_grid.tolist()

    boxes = []
    for j in range(S):          # j: grid row    -> y direction
        row = cells[j]
        for i in range(S):      # i: grid column -> x direction
            tx, ty, tw, th, obj_conf = row[i][:5]

            if obj_conf < conf_thresh:
                continue

            # Cell index plus in-cell offset, normalised by S, then scaled
            # to pixels.
            cx_pix = (i + tx) / S * img_size
            cy_pix = (j + ty) / S * img_size
            w_pix = tw * img_size
            h_pix = th * img_size

            boxes.append((
                cx_pix - w_pix / 2,
                cy_pix - h_pix / 2,
                cx_pix + w_pix / 2,
                cy_pix + h_pix / 2,
                obj_conf,
            ))

    return boxes


def compute_iou(box1, box2):
    """Compute IoU between two axis-aligned boxes.

    Only the first four entries (x1, y1, x2, y2) of each box are used, so
    both 5-element prediction boxes (with a trailing confidence) and
    4-element ground-truth boxes are accepted.

    Returns:
        Intersection-over-union as a float in [0, 1]; 0.0 when the boxes do
        not overlap or the union area is not positive.
    """
    x1_1, y1_1, x2_1, y2_1 = box1[:4]
    x1_2, y1_2, x2_2, y2_2 = box2[:4]

    # Corners of the intersection rectangle.
    x1_i = max(x1_1, x1_2)
    y1_i = max(y1_1, y1_2)
    x2_i = min(x2_1, x2_2)
    y2_i = min(y2_1, y2_2)

    # Boxes that only touch or are disjoint have no overlap.
    if x2_i <= x1_i or y2_i <= y1_i:
        return 0.0

    inter_area = (x2_i - x1_i) * (y2_i - y1_i)
    area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
    area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
    union_area = area1 + area2 - inter_area

    return inter_area / union_area if union_area > 0 else 0.0


def nms(boxes, iou_thresh):
    """Greedy Non-Maximum Suppression over (x1, y1, x2, y2, conf) boxes.

    Repeatedly keeps the most confident remaining box and discards every
    other box whose IoU with it is at or above ``iou_thresh``.

    Returns:
        The kept boxes, ordered by descending confidence.
    """
    survivors = []
    candidates = sorted(boxes, key=lambda b: b[4], reverse=True)

    while candidates:
        top, *rest = candidates
        survivors.append(top)
        # Only boxes that overlap the winner by less than the threshold stay
        # in the running for the next round.
        candidates = [b for b in rest if compute_iou(top, b) < iou_thresh]

    return survivors


def transform_boxes_to_original(boxes, params, img_size):
    """Map boxes from letterbox coordinates back onto the original image.

    For each coordinate the letterbox padding is subtracted, the result is
    divided by the resize scale, and the value is clamped to the original
    image bounds. Confidence values pass through untouched.

    Args:
        boxes: Iterable of (x1, y1, x2, y2, conf) in letterbox pixels.
        params: Dict with 'scale', 'pad_w', 'pad_h' and 'orig_wh' (w, h).
        img_size: Unused; kept for interface compatibility.

    Returns:
        List of (x1, y1, x2, y2, conf) in original-image pixels.
    """
    if len(boxes) == 0:
        return []

    ratio = params['scale']
    offset_x = params['pad_w']
    offset_y = params['pad_h']
    limit_x, limit_y = params['orig_wh']

    def _restore(value, offset, upper):
        # Undo padding and scaling, then clamp into [0, upper].
        return max(0, min((value - offset) / ratio, upper))

    return [
        (
            _restore(left, offset_x, limit_x),
            _restore(top, offset_y, limit_y),
            _restore(right, offset_x, limit_x),
            _restore(bottom, offset_y, limit_y),
            score,
        )
        for (left, top, right, bottom, score) in boxes
    ]


def infer_image(model, image_path, device, img_size, grid_size, conf_thresh, iou_thresh):
    """Run inference on a single image.

    Letterboxes the image to ``img_size`` x ``img_size``, runs the model
    without gradient tracking, decodes the first grid of the output into
    boxes and applies NMS.

    Returns:
        (boxes, params): boxes are (x1, y1, x2, y2, conf) in letterbox pixel
        coordinates; params is whatever ``letterbox_image`` returned, needed
        to map the boxes back onto the original image.
    """
    img_tensor, params = letterbox_image(image_path, (img_size, img_size))
    batch = img_tensor.unsqueeze(0).to(device)

    with torch.no_grad():
        pred = model(batch)

    # Bring the grid to host memory once before decoding. The decoder reads
    # the grid cell by cell, which would otherwise mean one device-to-host
    # transfer per value on GPU/MPS.
    grid = pred[0].cpu()

    boxes = decode_predictions(grid, conf_thresh, grid_size, img_size)
    boxes = nms(boxes, iou_thresh)

    return boxes, params

print("Helper functions loaded.")

## 5. Run Evaluation on Dataset

In [None]:
# Collect the validation images (only *.jpg files are picked up).
image_paths = sorted(glob(os.path.join(VAL_IMAGES_DIR, "*.jpg")))

if MAX_IMAGES:
    image_paths = image_paths[:MAX_IMAGES]

# Parallel lists: entry k of each one describes the same successfully
# processed image. Images that fail are left out of all three.
all_predictions = []
all_ground_truth = []
image_ids = []
stats = {
    'total': 0,
    'processed': 0,
    'errors': 0,
    'images_with_objects': 0,
    'total_pred_objects': 0,
    'total_gt_objects': 0
}

print(f"\nEvaluating {len(image_paths)} images...")
print("-" * 70)

for idx, img_path in enumerate(image_paths):
    stats['total'] += 1

    try:
        basename = os.path.basename(img_path)
        label_name = os.path.splitext(basename)[0] + '.txt'
        label_path = os.path.join(VAL_LABELS_DIR, label_name)

        gt_boxes = load_yolo_labels(label_path)

        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None; raise a clear
            # error instead of tripping over img.shape below.
            raise ValueError("image could not be read (missing or corrupt file)")
        h, w = img.shape[:2]

        # Labels are normalised (cx, cy, w, h); convert to pixel corners.
        gt_pixel = [
            ((cx - bw / 2) * w, (cy - bh / 2) * h,
             (cx + bw / 2) * w, (cy + bh / 2) * h)
            for (cx, cy, bw, bh) in gt_boxes
        ]

        preds, params = infer_image(
            model, img_path, DEVICE, IMG_SIZE, GRID_SIZE,
            CONFIDENCE_THRESHOLD, NMS_IOU_THRESHOLD
        )

        preds_orig = transform_boxes_to_original(preds, params, IMG_SIZE)

    except Exception as e:
        stats['errors'] += 1
        print(f"Error processing {img_path}: {e}")
        continue

    # Update results and counters only after the image was fully processed,
    # so the ground-truth totals always agree with all_ground_truth even when
    # some images fail.
    all_predictions.append(preds_orig)
    all_ground_truth.append(gt_pixel)
    image_ids.append(basename)

    stats['processed'] += 1
    stats['total_pred_objects'] += len(preds_orig)
    stats['total_gt_objects'] += len(gt_boxes)
    if len(gt_boxes) > 0:
        stats['images_with_objects'] += 1

    if (idx + 1) % 50 == 0:
        print(f"Processed {idx + 1}/{len(image_paths)} images...")

print("-" * 70)
print(f"Evaluation complete: {stats['processed']} images processed, {stats['errors']} errors")

print("\n" + "="*70)
print("DATASET STATISTICS")
print("="*70)
print(f"Total images evaluated: {stats['processed']}")
print(f"Images with objects: {stats['images_with_objects']}")
print(f"Total ground truth objects: {stats['total_gt_objects']}")
print(f"Total predicted objects: {stats['total_pred_objects']}")

if stats['total_gt_objects'] > 0:
    # Ratio of predicted to ground-truth box counts; it can exceed 100% when
    # the model predicts more boxes than there are objects.
    recall_ceiling = stats['total_pred_objects'] / stats['total_gt_objects']
    print(f"Recall ceiling (pred/gt): {recall_ceiling:.2%}")

print(f"Average objects per image (GT): {stats['total_gt_objects'] / max(1, stats['processed']):.2f}")
print(f"Average objects per image (Pred): {stats['total_pred_objects'] / max(1, stats['processed']):.2f}")

## 6. Compute Evaluation Metrics

In [None]:
# Section banner for the aggregate metrics.
_rule = "=" * 70
print(f"\n{_rule}")
print("EVALUATION METRICS")
print(_rule)

# Aggregate metrics at every configured IoU threshold (verbose output is
# printed by the evaluator itself).
results = evaluate_predictions(
    all_predictions,
    all_ground_truth,
    iou_thresholds=EVAL_IOU_THRESHOLDS,
    verbose=True,
)

print(f"\n{_rule}")
print("MEAN AVERAGE PRECISION (mAP)")
print(_rule)

# A separate evaluator instance is kept around; later cells reuse it for
# per-image matching.
evaluator = ObjectDetectionEvaluator(EVAL_IOU_THRESHOLDS)
map_scores = evaluator.compute_map(all_predictions, all_ground_truth)

print("\nmAP scores:")
for threshold, score in map_scores.items():
    print(f"  IoU {threshold}: {score:.4f}")

## 7. Detailed Results Summary

In [None]:
# Print one metrics paragraph per evaluated IoU threshold, in ascending order.
print("\n" + "=" * 70)
print("DETAILED RESULTS SUMMARY")
print("=" * 70)

for threshold in sorted(results.keys()):
    m = results[threshold]
    print("\n".join([
        f"\nIoU Threshold: {threshold}",
        f"  Precision: {m.precision:.4f}",
        f"  Recall:    {m.recall:.4f}",
        f"  F1 Score:  {m.f1:.4f}",
        f"  AP:        {m.ap:.4f}",
        f"  TP: {m.tp}, FP: {m.fp}, FN: {m.fn}",
    ]))

## 8. Per-Image Performance Analysis

In [None]:
print("\n" + "=" * 70)
print("PER-IMAGE PERFORMANCE")
print("=" * 70)

per_image_stats = []
iou_05_threshold = 0.5

# Match predictions to ground truth image by image and record the counts.
for preds, gts, img_id in zip(all_predictions, all_ground_truth, image_ids):
    tp, fp, fn, _ = evaluator.match_predictions_to_ground_truth(
        preds, gts, iou_05_threshold
    )
    precision, recall, f1 = evaluator.compute_precision_recall_f1(tp, fp, fn)
    per_image_stats.append(dict(
        image_id=img_id,
        pred=len(preds),
        gt=len(gts),
        tp=tp,
        fp=fp,
        fn=fn,
        precision=precision,
        recall=recall,
        f1=f1,
    ))

# Best images first. NOTE: after this sort the list order no longer matches
# the order of image_ids / all_predictions.
per_image_stats.sort(key=lambda entry: entry['f1'], reverse=True)


def _print_ranking(heading, rows):
    """Print a fixed-width table of per-image statistics under a heading."""
    print(heading)
    print(f"{'Image':<30} {'Pred':<5} {'GT':<5} {'TP':<4} {'FP':<4} {'FN':<4} {'F1':<8}")
    print("-" * 70)
    for entry in rows:
        print(f"{entry['image_id']:<30} {entry['pred']:<5} {entry['gt']:<5} "
              f"{entry['tp']:<4} {entry['fp']:<4} {entry['fn']:<4} "
              f"{entry['f1']:<8.4f}")


_print_ranking("\nTop 10 best performing images (by F1 score):", per_image_stats[:10])
_print_ranking("\nBottom 10 worst performing images (by F1 score):", per_image_stats[-10:])

## 9. Visualize Metrics

In [None]:
# The figure is written into OUTPUT_DIR below; make sure the directory exists
# (this cell runs before the cell that saves the text summary).
os.makedirs(OUTPUT_DIR, exist_ok=True)

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

iou_thresholds = sorted(results.keys())
precisions = [results[t].precision for t in iou_thresholds]
recalls = [results[t].recall for t in iou_thresholds]
f1_scores = [results[t].f1 for t in iou_thresholds]
aps = [results[t].ap for t in iou_thresholds]

# One bar chart per metric; all four panels share the same layout, so they
# are described as data and drawn in a single loop.
bar_positions = range(len(iou_thresholds))
tick_labels = [f"{t:.2f}" for t in iou_thresholds]
panels = [
    (axes[0, 0], precisions, 'blue', 'Precision', 'Precision vs IoU Threshold'),
    (axes[0, 1], recalls, 'green', 'Recall', 'Recall vs IoU Threshold'),
    (axes[1, 0], f1_scores, 'orange', 'F1 Score', 'F1 Score vs IoU Threshold'),
    (axes[1, 1], aps, 'red', 'Average Precision', 'AP vs IoU Threshold'),
]

for ax, values, color, ylabel, title in panels:
    ax.bar(bar_positions, values, color=color, alpha=0.7)
    ax.set_xticks(bar_positions)
    ax.set_xticklabels(tick_labels)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('IoU Threshold')
    ax.set_title(title)
    ax.set_ylim([0, 1])
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'metrics_visualization.png'), dpi=150, bbox_inches='tight')
plt.show()

print(f"\nVisualization saved to {OUTPUT_DIR}/metrics_visualization.png")

## 10. F1 Score Distribution

In [None]:
# Histogram of per-image F1 scores with mean and median markers.
f1_values = [entry['f1'] for entry in per_image_stats]
f1_mean = np.mean(f1_values)
f1_median = np.median(f1_values)

plt.figure(figsize=(10, 6))
plt.hist(f1_values, bins=20, color='purple', alpha=0.7, edgecolor='black')
plt.xlabel('F1 Score')
plt.ylabel('Number of Images')
plt.title('Distribution of F1 Scores Across Images')
plt.axvline(f1_mean, color='red', linestyle='--', linewidth=2, label=f'Mean: {f1_mean:.3f}')
plt.axvline(f1_median, color='green', linestyle='--', linewidth=2, label=f'Median: {f1_median:.3f}')
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig(os.path.join(OUTPUT_DIR, 'f1_distribution.png'), dpi=150, bbox_inches='tight')
plt.show()

# Text summary of the same distribution.
print(f"F1 Score Statistics:")
print(f"  Mean: {f1_mean:.4f}")
print(f"  Median: {f1_median:.4f}")
print(f"  Std Dev: {np.std(f1_values):.4f}")
print(f"  Min: {np.min(f1_values):.4f}")
print(f"  Max: {np.max(f1_values):.4f}")

## 11. Visualize Sample Predictions

In [None]:
# Pick up to six images spread evenly across the evaluated set.
num_samples = min(6, len(image_ids))
sample_indices = np.linspace(0, len(image_ids) - 1, num_samples, dtype=int)

# per_image_stats has been re-sorted by F1, so its positions no longer line up
# with image_ids / all_predictions. Look the F1 score up by image id instead
# of by position.
f1_by_image = {entry['image_id']: entry['f1'] for entry in per_image_stats}

fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.flatten()

for ax, sample_idx in zip(axes, sample_indices):
    img_id = image_ids[sample_idx]
    img_path = os.path.join(VAL_IMAGES_DIR, img_id)

    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread returns None on failure; label the panel and move on.
        ax.set_title(f"{img_id}\n(image could not be read)")
        continue
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    ax.imshow(img)

    # Ground truth: solid green boxes.
    for (x1, y1, x2, y2) in all_ground_truth[sample_idx]:
        rect = Rectangle((x1, y1), x2-x1, y2-y1, linewidth=2, edgecolor='green', facecolor='none', label='GT')
        ax.add_patch(rect)

    # Predictions: dashed red boxes.
    for (x1, y1, x2, y2, conf) in all_predictions[sample_idx]:
        rect = Rectangle((x1, y1), x2-x1, y2-y1, linewidth=2, edgecolor='red', facecolor='none', linestyle='--', label='Pred')
        ax.add_patch(rect)

    ax.set_title(f"{img_id}\nF1: {f1_by_image[img_id]:.3f} | P: {len(all_predictions[sample_idx])} GT: {len(all_ground_truth[sample_idx])}")

# Hide the axis frames on every panel, including panels left empty when fewer
# than six images are available.
for ax in axes:
    ax.axis('off')

handles = [plt.Line2D([0], [0], color='green', linewidth=2, label='Ground Truth'),
           plt.Line2D([0], [0], color='red', linewidth=2, linestyle='--', label='Prediction')]
fig.legend(handles=handles, loc='upper center', ncol=2, bbox_to_anchor=(0.5, 0.98))

plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.savefig(os.path.join(OUTPUT_DIR, 'sample_predictions.png'), dpi=150, bbox_inches='tight')
plt.show()

print(f"\nSample predictions visualization saved to {OUTPUT_DIR}/sample_predictions.png")

## 12. Save Results to File

In [None]:
# Write a plain-text summary of the configuration, metrics and dataset stats.
os.makedirs(OUTPUT_DIR, exist_ok=True)
summary_path = os.path.join(OUTPUT_DIR, 'evaluation_summary.txt')

with open(summary_path, 'w') as f:
    # Header and configuration.
    f.writelines([
        "=" * 70 + "\n",
        "MLHD MODEL EVALUATION SUMMARY\n",
        "=" * 70 + "\n\n",
        "Configuration:\n",
        f"  Checkpoint: {CHECKPOINT_PATH}\n",
        f"  Confidence threshold: {CONFIDENCE_THRESHOLD}\n",
        f"  NMS IoU threshold: {NMS_IOU_THRESHOLD}\n",
        f"  Evaluation IoU thresholds: {EVAL_IOU_THRESHOLDS}\n\n",
        "Results:\n",
    ])

    # One paragraph per IoU threshold, ascending.
    for threshold in sorted(results.keys()):
        m = results[threshold]
        f.writelines([
            f"\nIoU {threshold}:\n",
            f"  Precision: {m.precision:.4f}\n",
            f"  Recall:    {m.recall:.4f}\n",
            f"  F1 Score:  {m.f1:.4f}\n",
            f"  AP:        {m.ap:.4f}\n",
            f"  TP: {m.tp}, FP: {m.fp}, FN: {m.fn}\n",
        ])

    f.write(f"\nmean Average Precision (mAP):\n")
    f.writelines(
        f"  IoU {threshold}: {score:.4f}\n"
        for threshold, score in map_scores.items()
    )

    f.writelines([
        f"\nDataset Statistics:\n",
        f"  Total images: {stats['processed']}\n",
        f"  Images with objects: {stats['images_with_objects']}\n",
        f"  Total GT objects: {stats['total_gt_objects']}\n",
        f"  Total predicted objects: {stats['total_pred_objects']}\n",
    ])

print(f"\nResults saved to {OUTPUT_DIR}/evaluation_summary.txt")
print(f"\nEvaluation complete!")

## 13. Training/Validation Loss Curves

In [None]:
# Plot the loss history only when the checkpoint actually carries it.
_has_history = 'train_losses' in checkpoint and 'val_losses' in checkpoint

if not _has_history:
    print("Training history not found in checkpoint. Skipping loss curves.")
    print("Note: To save training history, modify your training script to include:")
    print("  checkpoint['train_losses'] = train_losses")
    print("  checkpoint['val_losses'] = val_losses")
else:
    train_losses = checkpoint['train_losses']
    val_losses = checkpoint['val_losses']
    epochs = range(1, len(train_losses) + 1)

    plt.figure(figsize=(12, 5))

    # Same two curves drawn twice: linear y-axis on the left, log on the right.
    _panels = (
        (1, 'Loss', 'Training and Validation Loss over Epochs', False),
        (2, 'Loss (log scale)', 'Training and Validation Loss (Log Scale)', True),
    )
    for _position, _ylabel, _title, _log_scale in _panels:
        plt.subplot(1, 2, _position)
        plt.plot(epochs, train_losses, 'b-', label='Training Loss', linewidth=2)
        plt.plot(epochs, val_losses, 'r-', label='Validation Loss', linewidth=2)
        plt.xlabel('Epoch')
        plt.ylabel(_ylabel)
        if _log_scale:
            plt.yscale('log')
        plt.title(_title)
        plt.legend()
        plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, 'loss_curves.png'), dpi=150, bbox_inches='tight')
    plt.show()

    # Epoch numbers are 1-based, hence the +1 on the list index.
    _best_val = min(val_losses)
    _best_epoch = val_losses.index(_best_val) + 1
    print(f"Final Training Loss: {train_losses[-1]:.4f}")
    print(f"Final Validation Loss: {val_losses[-1]:.4f}")
    print(f"Best Validation Loss: {_best_val:.4f} (Epoch {_best_epoch})")

## 14. Precision-Recall Curves

In [None]:
# Compute precision-recall curves by sweeping the confidence threshold.
conf_thresholds = np.linspace(0.1, 0.9, 20)
pr_iou_thresholds = [0.5, 0.75]
pr_data = {str(t): {'precisions': [], 'recalls': []} for t in pr_iou_thresholds}

print("Computing Precision-Recall curves...")
print("This may take a few minutes...")

# Run inference ONCE, at the lowest threshold of the sweep. Greedy NMS only
# lets a box be suppressed by a box of equal or higher confidence, so the
# result for any higher threshold is exactly this set with the low-confidence
# boxes removed -- there is no need to re-run the model per threshold.
#
# Paths are rebuilt from image_ids (not sliced from image_paths) so that the
# predictions stay aligned with all_ground_truth even if some images failed
# during the main evaluation and were left out.
_sweep_floor = float(conf_thresholds.min())
_base_predictions = []
for img_id in image_ids:
    img_path = os.path.join(VAL_IMAGES_DIR, img_id)
    try:
        preds, params = infer_image(
            model, img_path, DEVICE, IMG_SIZE, GRID_SIZE,
            _sweep_floor, NMS_IOU_THRESHOLD
        )
        _base_predictions.append(transform_boxes_to_original(preds, params, IMG_SIZE))
    except Exception as e:
        # Keep the slot (with no predictions) so indices stay aligned with
        # all_ground_truth, but report the failure instead of hiding it.
        print(f"Error processing {img_path}: {e}")
        _base_predictions.append([])

for conf_thresh in conf_thresholds:
    # Same keep-rule as the decoder: drop a box only when conf < threshold.
    temp_predictions = [
        [box for box in preds if not (box[4] < conf_thresh)]
        for preds in _base_predictions
    ]

    for iou_thresh in pr_iou_thresholds:
        total_tp = 0
        total_fp = 0
        total_fn = 0

        for preds, gts in zip(temp_predictions, all_ground_truth):
            tp, fp, fn, _ = evaluator.match_predictions_to_ground_truth(preds, gts, iou_thresh)
            total_tp += tp
            total_fp += fp
            total_fn += fn

        precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
        recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0

        pr_data[str(iou_thresh)]['precisions'].append(precision)
        pr_data[str(iou_thresh)]['recalls'].append(recall)

# Plot P-R curves, one panel per IoU threshold.
plt.figure(figsize=(12, 5))

for _position, (_key, _line_style) in enumerate((('0.5', 'b-o'), ('0.75', 'r-o')), start=1):
    plt.subplot(1, 2, _position)
    plt.plot(pr_data[_key]['recalls'], pr_data[_key]['precisions'], _line_style, linewidth=2, markersize=4)
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title(f'Precision-Recall Curve (IoU={_key})')
    plt.grid(True, alpha=0.3)
    plt.xlim([0, 1])
    plt.ylim([0, 1])

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'pr_curves.png'), dpi=150, bbox_inches='tight')
plt.show()

print("Precision-Recall curves saved.")

## 15. Confidence Score Distribution Analysis

In [None]:
# Analyze confidence scores of predictions, split into true/false positives.


def _pred_gt_iou(pred_box, gt_box):
    """IoU of two boxes given as (x1, y1, x2, y2, ...); 0.0 if they do not overlap."""
    x1_i = max(pred_box[0], gt_box[0])
    y1_i = max(pred_box[1], gt_box[1])
    x2_i = min(pred_box[2], gt_box[2])
    y2_i = min(pred_box[3], gt_box[3])

    if x2_i <= x1_i or y2_i <= y1_i:
        return 0.0

    inter_area = (x2_i - x1_i) * (y2_i - y1_i)
    pred_area = (pred_box[2] - pred_box[0]) * (pred_box[3] - pred_box[1])
    gt_area = (gt_box[2] - gt_box[0]) * (gt_box[3] - gt_box[1])
    union_area = pred_area + gt_area - inter_area
    return inter_area / union_area if union_area > 0 else 0.0


def _mean_median_text(values):
    """Format mean/median of a list, or 'n/a' when the list is empty."""
    if len(values) == 0:
        # np.mean / np.median of an empty list yield NaN plus a warning.
        return "Mean: n/a, Median: n/a"
    return f"Mean: {np.mean(values):.3f}, Median: {np.median(values):.3f}"


all_confidences = []
tp_confidences = []
fp_confidences = []

iou_thresh = 0.5

# A prediction counts as a true positive here when it overlaps ANY ground
# truth box with IoU >= iou_thresh. Several predictions may therefore claim
# the same ground-truth box.
# NOTE(review): this may differ from the evaluator's matching rule -- confirm.
for preds, gts in zip(all_predictions, all_ground_truth):
    for pred in preds:
        conf = pred[4]
        all_confidences.append(conf)

        if any(_pred_gt_iou(pred, gt) >= iou_thresh for gt in gts):
            tp_confidences.append(conf)
        else:
            fp_confidences.append(conf)

# Plot confidence distributions: all predictions, TPs only, FPs only.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

_histograms = (
    (all_confidences, 'blue', 'All Predictions - Confidence Distribution'),
    (tp_confidences, 'green', f'True Positives - Confidence Distribution (IoU≥{iou_thresh})'),
    (fp_confidences, 'red', f'False Positives - Confidence Distribution (IoU<{iou_thresh})'),
)
for ax, (_values, _color, _title) in zip(axes, _histograms):
    ax.hist(_values, bins=30, color=_color, alpha=0.7, edgecolor='black')
    ax.axvline(CONFIDENCE_THRESHOLD, color='red', linestyle='--', linewidth=2, label=f'Threshold: {CONFIDENCE_THRESHOLD}')
    ax.set_xlabel('Confidence Score')
    ax.set_ylabel('Number of Predictions')
    ax.set_title(_title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'confidence_distribution.png'), dpi=150, bbox_inches='tight')
plt.show()

# Guard the percentage against an empty prediction set (division by zero).
_total_preds = max(1, len(all_confidences))

print(f"\nConfidence Statistics:")
print(f"  All Predictions - {_mean_median_text(all_confidences)}")
print(f"  True Positives  - {_mean_median_text(tp_confidences)}")
print(f"  False Positives - {_mean_median_text(fp_confidences)}")
print(f"  Total Predictions: {len(all_confidences)}")
print(f"  True Positives: {len(tp_confidences)} ({100*len(tp_confidences)/_total_preds:.1f}%)")
print(f"  False Positives: {len(fp_confidences)} ({100*len(fp_confidences)/_total_preds:.1f}%)")

## 16. Failure Case Analysis

In [None]:
# Sort every prediction / ground-truth box into failure categories.
failure_cases = {
    'missed_detections': [],  # FN: Objects in GT but not detected
    'false_positives': [],     # FP: Predictions with no matching GT
    'low_confidence_tp': [],   # TP with low confidence
    'high_confidence_fp': []   # FP with high confidence
}

iou_threshold = 0.5
# NOTE(review): predictions were already filtered at CONFIDENCE_THRESHOLD, so
# if low_conf_thresh is below that value the low-confidence bucket will stay
# empty -- confirm the intended cut-off.
low_conf_thresh = 0.3
high_conf_thresh = 0.7


def _overlap_iou(pred_box, gt_box):
    """Return the IoU of two (x1, y1, x2, y2, ...) boxes, or None if they do not overlap."""
    left = max(pred_box[0], gt_box[0])
    top = max(pred_box[1], gt_box[1])
    right = min(pred_box[2], gt_box[2])
    bottom = min(pred_box[3], gt_box[3])

    if not (right > left and bottom > top):
        return None

    inter_area = (right - left) * (bottom - top)
    pred_area = (pred_box[2] - pred_box[0]) * (pred_box[3] - pred_box[1])
    gt_area = (gt_box[2] - gt_box[0]) * (gt_box[3] - gt_box[1])
    union_area = pred_area + gt_area - inter_area
    return inter_area / union_area if union_area > 0 else 0


for preds, gts, img_id in zip(all_predictions, all_ground_truth, image_ids):
    matched_gts = set()

    for pred in preds:
        score = pred[4]
        hit_any_gt = False
        best_iou = 0

        # Single pass over the ground truth: collect matches, low-confidence
        # hits and the best overlap (only reported when nothing matched).
        for gt_idx, gt in enumerate(gts):
            iou = _overlap_iou(pred, gt)
            if iou is None:
                continue

            best_iou = max(best_iou, iou)

            if iou >= iou_threshold:
                hit_any_gt = True
                matched_gts.add(gt_idx)

                if score < low_conf_thresh:
                    failure_cases['low_confidence_tp'].append({
                        'image': img_id,
                        'confidence': score,
                        'iou': iou
                    })

        if hit_any_gt:
            continue

        # No ground-truth box reached the IoU threshold: false positive.
        failure_cases['false_positives'].append({
            'image': img_id,
            'confidence': score,
            'best_iou': best_iou
        })

        if score >= high_conf_thresh:
            failure_cases['high_confidence_fp'].append({
                'image': img_id,
                'confidence': score,
                'best_iou': best_iou
            })

    # Ground-truth boxes that no prediction matched are missed detections.
    for gt_idx, gt in enumerate(gts):
        if gt_idx not in matched_gts:
            failure_cases['missed_detections'].append({
                'image': img_id,
                'gt_box': gt
            })

print("="*70)
print("FAILURE CASE ANALYSIS")
print("="*70)
print(f"\nTotal Missed Detections (FN): {len(failure_cases['missed_detections'])}")
print(f"Total False Positives (FP): {len(failure_cases['false_positives'])}")
print(f"Low Confidence True Positives (conf < {low_conf_thresh}): {len(failure_cases['low_confidence_tp'])}")
print(f"High Confidence False Positives (conf ≥ {high_conf_thresh}): {len(failure_cases['high_confidence_fp'])}")

print(f"\nMissed Detections Rate: {100*len(failure_cases['missed_detections'])/max(1,stats['total_gt_objects']):.2f}%")
print(f"False Positive Rate: {100*len(failure_cases['false_positives'])/max(1,stats['total_pred_objects']):.2f}%")

# Show a handful of examples from the two most informative buckets.
if len(failure_cases['high_confidence_fp']) > 0:
    print(f"\nSample High-Confidence False Positives:")
    for case in failure_cases['high_confidence_fp'][:5]:
        print(f"  {case['image']}: conf={case['confidence']:.3f}, best_IoU={case['best_iou']:.3f}")

if len(failure_cases['low_confidence_tp']) > 0:
    print(f"\nSample Low-Confidence True Positives:")
    for case in failure_cases['low_confidence_tp'][:5]:
        print(f"  {case['image']}: conf={case['confidence']:.3f}, IoU={case['iou']:.3f}")

## 17. Visualize Best and Worst Cases with GT Comparison

In [None]:
# Show the three best and three worst images (by F1) with GT and predictions.
# per_image_stats is sorted best-first, so the ends of the list are the extremes.
best_images = per_image_stats[:3]
worst_images = per_image_stats[-3:]

fig, axes = plt.subplots(2, 3, figsize=(20, 14))

# Row 0 holds the best cases, row 1 the worst; both rows are drawn the same way.
for row, (tag, group) in enumerate((("BEST", best_images), ("WORST", worst_images))):
    for idx, stat in enumerate(group):
        ax = axes[row, idx]
        img_id = stat['image_id']
        # Map the image id back to its position in the unsorted result lists.
        img_idx = image_ids.index(img_id)
        img_path = os.path.join(VAL_IMAGES_DIR, img_id)

        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        ax.imshow(img)

        # Ground truth in solid green.
        for (x1, y1, x2, y2) in all_ground_truth[img_idx]:
            rect = Rectangle((x1, y1), x2-x1, y2-y1, linewidth=3, edgecolor='green', facecolor='none', label='GT')
            ax.add_patch(rect)

        # Predictions in dashed red, annotated with their confidence.
        for (x1, y1, x2, y2, conf) in all_predictions[img_idx]:
            rect = Rectangle((x1, y1), x2-x1, y2-y1, linewidth=2, edgecolor='red', facecolor='none', linestyle='--', label='Pred')
            ax.add_patch(rect)
            ax.text(x1, y1-5, f'{conf:.2f}', color='red', fontsize=10, weight='bold',
                    bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.7))

        ax.set_title(f"{tag} #{idx+1}: {img_id}\nF1: {stat['f1']:.3f} | TP:{stat['tp']} FP:{stat['fp']} FN:{stat['fn']}",
                     fontsize=11, weight='bold')
        ax.axis('off')

handles = [plt.Line2D([0], [0], color='green', linewidth=3, label='Ground Truth'),
           plt.Line2D([0], [0], color='red', linewidth=2, linestyle='--', label='Prediction')]
fig.legend(handles=handles, loc='upper center', ncol=2, bbox_to_anchor=(0.5, 0.99), fontsize=12)

plt.tight_layout(rect=[0, 0, 1, 0.97])
plt.savefig(os.path.join(OUTPUT_DIR, 'best_worst_comparison.png'), dpi=150, bbox_inches='tight')
plt.show()

print(f"\nBest/Worst case comparison saved to {OUTPUT_DIR}/best_worst_comparison.png")

## 18. Grid Cell Activation Heatmap

In [None]:
# Visualize grid cell activations for sample images: the best, the median and
# the worst image by F1. Despite its name, sample_indices holds image ids.
sample_indices = [per_image_stats[0]['image_id'], per_image_stats[len(per_image_stats)//2]['image_id'], 
                  per_image_stats[-1]['image_id']]

fig, axes = plt.subplots(3, 3, figsize=(18, 18))

for row, img_id in enumerate(sample_indices[:3]):
    img_idx = image_ids.index(img_id)
    img_path = os.path.join(VAL_IMAGES_DIR, img_id)
    
    # Load and preprocess image
    img_tensor, params = letterbox_image(img_path, (IMG_SIZE, IMG_SIZE))
    img_tensor_batch = img_tensor.unsqueeze(0).to(DEVICE)
    
    # Get model prediction
    with torch.no_grad():
        pred_grid = model(img_tensor_batch)[0].cpu().numpy()
    
    # Original image
    img = cv2.imread(img_path)
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_h, img_w = img_rgb.shape[:2]
    axes[row, 0].imshow(img_rgb)
    
    # Draw GT boxes
    for (x1, y1, x2, y2) in all_ground_truth[img_idx]:
        rect = Rectangle((x1, y1), x2-x1, y2-y1, linewidth=2, edgecolor='green', facecolor='none')
        axes[row, 0].add_patch(rect)
    
    axes[row, 0].set_title(f"Original: {img_id}", fontsize=10)
    axes[row, 0].axis('off')
    
    # Objectness heatmap (channel 4 of every grid cell)
    objectness_map = pred_grid[:, :, 4]
    im1 = axes[row, 1].imshow(objectness_map, cmap='hot', interpolation='nearest')
    axes[row, 1].set_title(f"Objectness Heatmap\nMax: {objectness_map.max():.3f}, Mean: {objectness_map.mean():.3f}", 
                          fontsize=10)
    axes[row, 1].axis('off')
    plt.colorbar(im1, ax=axes[row, 1], fraction=0.046, pad=0.04)
    
    # Overlay heatmap on image.
    # The grid lives in letterboxed coordinates (padded square of IMG_SIZE),
    # so stretching the original image to GRID_SIZE x GRID_SIZE misaligns the
    # two for non-square images. Instead, place the heatmap in original-image
    # coordinates using the same mapping that converts predicted boxes back:
    #   original = (letterbox - pad) / scale
    scale = params['scale']
    pad_w = params['pad_w']
    pad_h = params['pad_h']
    heatmap_extent = (
        (0 - pad_w) / scale,          # left
        (IMG_SIZE - pad_w) / scale,   # right
        (IMG_SIZE - pad_h) / scale,   # bottom (image y grows downwards)
        (0 - pad_h) / scale,          # top
    )
    axes[row, 2].imshow(img_rgb, alpha=0.6)
    im2 = axes[row, 2].imshow(objectness_map, cmap='hot', alpha=0.4, interpolation='nearest',
                              extent=heatmap_extent)
    # Crop the view to the image itself; the letterbox padding falls outside.
    axes[row, 2].set_xlim(0, img_w)
    axes[row, 2].set_ylim(img_h, 0)
    axes[row, 2].set_title(f"Overlay (Grid {GRID_SIZE}x{GRID_SIZE})", fontsize=10)
    axes[row, 2].axis('off')

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'grid_activation_heatmaps.png'), dpi=150, bbox_inches='tight')
plt.show()

print(f"\nGrid activation heatmaps saved to {OUTPUT_DIR}/grid_activation_heatmaps.png")