## Install dependencies
Run the cell below to install Python packages (non-PyTorch).
For PyTorch + CUDA, please follow the official installation command that matches your CUDA version (examples provided).

In [None]:
# Install general Python packages used by the notebook.
# NOTE: PyTorch installation depends on your CUDA version and is not installed by this single pip command.
# If you use conda and want GPU-enabled PyTorch (recommended), run one of the example commands below in a terminal.

# Example conda (CUDA 11.8) - run in a terminal:
# conda install -y pytorch torchvision torchaudio cudatoolkit=11.8 -c pytorch

# Example pip (CUDA 11.8) - run in a terminal or here if appropriate:
# pip install --index-url https://download.pytorch.org/whl/cu118 torch torchvision torchaudio --upgrade

# Install the rest of the requirements (this installs matplotlib, tqdm, pillow, etc.).
!pip install -r requirements.txt

print('Requirements installation attempted.
Please ensure PyTorch with CUDA is installed separately if you need GPU support.')

## 1. Setup and Imports

In [None]:
import os
import json
import random
import numpy as np
from PIL import Image
from collections import defaultdict
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import torch
import torch.utils.data
from torch.cuda.amp import GradScaler, autocast
import torchvision
from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights
from torchvision.models.detection.retinanet import RetinaNetHead
import torchvision.transforms.functional as F

# Set seeds for reproducibility
def set_seed(seed=42):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU generator and all
    CUDA devices) with ``seed``, then puts cuDNN in deterministic mode with
    its benchmark mode switched off.

    Args:
        seed: Integer seed shared by all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed(42)

# Environment report: library versions first, then the GPU name when one is visible.
print(
    f"PyTorch version: {torch.__version__}",
    f"Torchvision version: {torchvision.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
    sep="\n",
)
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Quick GPU test: run a few small tensor operations on the GPU when one is available.
import torch
print('torch version:', torch.__version__)
print('cuda available:', torch.cuda.is_available())
if not torch.cuda.is_available():
    print('CUDA not available in this kernel. If you expect GPU support, install PyTorch with CUDA and restart the kernel.')
else:
    # Allocate directly on the device, then reduce to a scalar so .item()
    # forces the result back to the host.
    dev = torch.device('cuda')
    x = torch.randn((1024, 1024), device=dev)
    y = x * 2.0
    z = (x + y).sum()
    print('Tensor op result on', x.device, '->', z.item())

## 2. Configuration

In [None]:
class Config:
    """Every tunable setting of the notebook, stored as class attributes."""

    # Paths
    # Dataset root; the loading cell expects train/, valid/ and test/ sub-folders,
    # each holding its images and an _annotations.coco.json file.
    # Set the CPLID_DATA_PATH environment variable to point at a different copy
    # of the dataset; the original location remains the default.
    DATA_PATH = os.environ.get('CPLID_DATA_PATH', r'D:\CLID\IDD-CPLID.v3-cplid_new.coco')
    # Directory that receives checkpoints and the training-curve image.
    OUTPUT_DIR = 'checkpoints'

    # Model
    # Class count handed to get_model(); labels are the COCO category ids.
    NUM_CLASSES = 3  # background (0 is used internally), defect (1), insulator (2)

    # Training
    BATCH_SIZE = 4
    NUM_EPOCHS = 30
    # NOTE(review): worker processes combined with classes defined inside a
    # notebook can be problematic on Windows - confirm, or set to 0 if loading hangs.
    NUM_WORKERS = 4

    # Optimizer (SGD)
    LEARNING_RATE = 0.001
    MOMENTUM = 0.9
    WEIGHT_DECAY = 0.0005

    # Learning Rate Scheduler
    LR_SCHEDULER = 'cosine'  # 'step', 'cosine', 'onecycle'
    # NOTE(review): no code visible in this notebook reads WARMUP_EPOCHS.
    WARMUP_EPOCHS = 3

    # Mixed Precision
    USE_AMP = True

    # Gradient Clipping (maximum gradient norm)
    GRAD_CLIP = 1.0

    # Early Stopping (epochs without a validation-F1 improvement before stopping)
    EARLY_STOPPING_PATIENCE = 10

    # Data Augmentation
    USE_AUGMENTATION = True

    # Evaluation
    SCORE_THRESHOLD = 0.5  # predictions at or below this score are discarded
    IOU_THRESHOLD = 0.5    # minimum IoU for a prediction to count as a match

    # Class names
    CLASS_NAMES = {0: 'background', 1: 'defect', 2: 'insulator'}
    # Box colours used by the visualisation helpers.
    CLASS_COLORS = {1: 'red', 2: 'green'}

# Instantiate the settings object and make sure the output directory exists
# before anything tries to write checkpoints or figures into it.
config = Config()
os.makedirs(config.OUTPUT_DIR, exist_ok=True)

## 3. Dataset Definition with Augmentation

In [None]:
class CustomCocoDataset(torch.utils.data.Dataset):
    """Object-detection dataset backed by a COCO-format annotation file.

    Each item is an ``(image, target)`` pair. ``target`` is a dict holding the
    tensors ``boxes`` (N x 4, ``[x1, y1, x2, y2]``), ``labels``, ``image_id``,
    ``area`` and ``iscrowd``. Images without any valid box yield empty tensors.

    Args:
        root: Directory containing the image files named in the annotation file.
        annotation_file: Path to the COCO JSON file.
        transforms: Optional callable ``(image, target) -> (image, target)``
            applied after augmentation.
        augment: When True, ``apply_augmentations`` runs before ``transforms``.
    """

    def __init__(self, root, annotation_file, transforms=None, augment=False):
        self.root = root
        self.transforms = transforms
        self.augment = augment

        # Load annotations
        with open(annotation_file, 'r') as f:
            self.coco_data = json.load(f)

        # Map image_id to filename and other info
        self.images = {img['id']: img for img in self.coco_data['images']}

        # Map image_id to list of annotations; annotations that reference an
        # unknown image id are ignored.
        self.img_to_anns = {img['id']: [] for img in self.coco_data['images']}
        for ann in self.coco_data['annotations']:
            if ann['image_id'] in self.img_to_anns:
                self.img_to_anns[ann['image_id']].append(ann)

        # List of image IDs for indexing
        self.ids = list(self.images.keys())

        # Categories mapping
        self.categories = {cat['id']: cat['name'] for cat in self.coco_data['categories']}

    def __getitem__(self, index):
        """Load the image at ``index`` and build its detection target."""
        img_id = self.ids[index]
        img_info = self.images[img_id]
        file_name = img_info['file_name']

        # Load Image; the context manager closes the underlying file as soon as
        # the RGB copy has been made.
        img_path = os.path.join(self.root, file_name)
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")

        # Get Annotations
        anns = self.img_to_anns.get(img_id, [])

        boxes = []
        labels = []
        areas = []
        iscrowd = []

        for ann in anns:
            # COCO bbox: [x, y, w, h] -> PyTorch: [x1, y1, x2, y2]
            x, y, w, h = ann['bbox']
            # Skip invalid boxes
            if w <= 0 or h <= 0:
                continue
            boxes.append([x, y, x + w, y + h])
            labels.append(ann['category_id'])
            # 'area' is not guaranteed to be present; fall back to the box area.
            areas.append(ann.get('area', w * h))
            iscrowd.append(ann.get('iscrowd', 0))

        if len(boxes) > 0:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            labels = torch.as_tensor(labels, dtype=torch.int64)
            areas = torch.as_tensor(areas, dtype=torch.float32)
            iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
        else:
            # Negative example (no objects)
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)
            areas = torch.zeros((0,), dtype=torch.float32)
            iscrowd = torch.zeros((0,), dtype=torch.int64)

        image_id = torch.tensor([img_id])

        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": image_id,
            "area": areas,
            "iscrowd": iscrowd
        }

        # Apply augmentations
        if self.augment:
            img, target = self.apply_augmentations(img, target)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def apply_augmentations(self, img, target):
        """Randomly flip and colour-jitter ``img``, keeping the boxes in sync.

        Each augmentation is applied independently with probability 0.5.
        Returns the (possibly modified) image and the same ``target`` dict.
        """
        boxes = target['boxes']

        # Random Horizontal Flip
        if random.random() > 0.5:
            img = F.hflip(img)
            if len(boxes) > 0:
                w = img.width
                # Mirror the x coordinates and swap them so x1 <= x2 still holds.
                boxes[:, [0, 2]] = w - boxes[:, [2, 0]]

        # Color Jittering
        if random.random() > 0.5:
            brightness = random.uniform(0.8, 1.2)
            contrast = random.uniform(0.8, 1.2)
            saturation = random.uniform(0.8, 1.2)
            img = F.adjust_brightness(img, brightness)
            img = F.adjust_contrast(img, contrast)
            img = F.adjust_saturation(img, saturation)

        target['boxes'] = boxes
        return img, target

    def __len__(self):
        """Number of images listed in the annotation file."""
        return len(self.ids)


class Compose:
    """Chain several ``(image, target)`` transforms into a single callable."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        """Run every transform in order, feeding each one the previous result."""
        pair = image, target
        for step in self.transforms:
            image, target = step(*pair)
            pair = image, target
        return pair


class ToTensor:
    """Convert the image to a tensor; the target passes through untouched."""

    def __call__(self, image, target):
        return F.to_tensor(image), target


class Normalize:
    """Normalize the image with a fixed mean / std; the target is returned as is."""

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, image, target):
        normalized = F.normalize(image, mean=self.mean, std=self.std)
        return normalized, target


def get_transform(train=False):
    """Build the transform pipeline applied to every sample.

    The pipeline only converts the image to a tensor. ``train`` is accepted
    for API symmetry but does not change the result.
    """
    # Note: RetinaNet handles normalization internally with pretrained weights
    return Compose([ToTensor()])


def collate_fn(batch):
    """Transpose a batch of samples into one tuple per field.

    A list of ``(image, target)`` pairs becomes ``(images, targets)``; no
    stacking takes place.
    """
    columns = zip(*batch)
    return tuple(columns)

## 4. Load Datasets

In [None]:
# Dataset paths
# Each split lives in its own sub-folder of DATA_PATH together with its
# COCO annotation file.
train_dir, val_dir, test_dir = (
    os.path.join(config.DATA_PATH, split) for split in ('train', 'valid', 'test')
)
train_ann, val_ann, test_ann = (
    os.path.join(split_dir, '_annotations.coco.json')
    for split_dir in (train_dir, val_dir, test_dir)
)

# Create datasets (augmentation is only ever enabled for the training split)
dataset_train = CustomCocoDataset(
    train_dir,
    train_ann,
    transforms=get_transform(train=True),
    augment=config.USE_AUGMENTATION,
)
dataset_val = CustomCocoDataset(
    val_dir,
    val_ann,
    transforms=get_transform(train=False),
    augment=False,
)
dataset_test = CustomCocoDataset(
    test_dir,
    test_ann,
    transforms=get_transform(train=False),
    augment=False,
)

print(
    f"Training images: {len(dataset_train)}",
    f"Validation images: {len(dataset_val)}",
    f"Test images: {len(dataset_test)}",
    sep="\n",
)

# Create data loaders
# Settings shared by all three loaders; only the dataset and shuffling differ.
loader_options = dict(
    batch_size=config.BATCH_SIZE,
    num_workers=config.NUM_WORKERS,
    collate_fn=collate_fn,
    pin_memory=True,
)

# Only the training loader shuffles.
data_loader_train = torch.utils.data.DataLoader(dataset_train, shuffle=True, **loader_options)

data_loader_val = torch.utils.data.DataLoader(dataset_val, shuffle=False, **loader_options)

data_loader_test = torch.utils.data.DataLoader(dataset_test, shuffle=False, **loader_options)

## 5. Visualize Sample Data

In [None]:
def visualize_sample(dataset, idx=0, class_names=None, class_colors=None):
    """Show one dataset sample with its ground-truth boxes drawn on top.

    Args:
        dataset: Indexable object returning ``(image, target)``; ``target``
            must provide ``'boxes'`` and ``'labels'``.
        idx: Index of the sample to display.
        class_names: Optional dict mapping label id to display name; the id
            itself is printed when omitted or when the id is missing.
        class_colors: Optional dict mapping label id to a matplotlib colour;
            ids missing from it are drawn blue, and everything is red when the
            dict is omitted.
    """
    img, target = dataset[idx]

    # Tensors are channel-first; matplotlib wants channel-last arrays.
    if isinstance(img, torch.Tensor):
        img = img.permute(1, 2, 0).numpy()

    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    ax.imshow(img)

    boxes, labels = target['boxes'], target['labels']

    for box, label in zip(boxes, labels):
        if isinstance(box, torch.Tensor):
            box = box.numpy()
        x1, y1, x2, y2 = box

        label_id = label.item() if isinstance(label, torch.Tensor) else label

        if class_colors:
            color = class_colors.get(label_id, 'blue')
        else:
            color = 'red'

        if class_names:
            label_name = class_names.get(label_id, str(label_id))
        else:
            label_name = str(label_id)

        ax.add_patch(
            patches.Rectangle(
                (x1, y1), x2 - x1, y2 - y1,
                linewidth=2, edgecolor=color, facecolor='none'
            )
        )
        # Caption sits just above the top-left corner of the box.
        ax.text(x1, y1 - 5, label_name, color=color, fontsize=10, fontweight='bold')

    ax.axis('off')
    plt.title(f'Sample {idx}')
    plt.tight_layout()
    plt.show()

# Visualize a few samples
# Indices beyond the end of a small training set are skipped instead of
# raising IndexError.
for i in [0, 10, 50]:
    if i < len(dataset_train):
        visualize_sample(dataset_train, i, config.CLASS_NAMES, config.CLASS_COLORS)

## 6. Model Definition

In [None]:
def get_model(num_classes, pretrained=True):
    """
    Load RetinaNet (V2) with a ResNet50 FPN backbone and replace the
    classification head for a custom number of classes.

    Only the classification head is rebuilt, because it is the only part
    whose output size depends on the number of classes. The box-regression
    head is class-agnostic and is kept as built by torchvision, so its
    pretrained weights and loss configuration are preserved.

    Args:
        num_classes: Number of classes the new classification head predicts.
        pretrained: Load the COCO_V1 weights when True; otherwise the model is
            created with ``weights=None``.

    Returns:
        The RetinaNet model with the new, freshly initialised classification head.
    """
    # Local imports: neither name is part of the notebook's import cell.
    from functools import partial
    from torchvision.models.detection.retinanet import RetinaNetClassificationHead

    weights = RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1 if pretrained else None
    model = retinanet_resnet50_fpn_v2(weights=weights)

    # Get model parameters
    in_channels = model.backbone.out_channels
    num_anchors = model.head.classification_head.num_anchors

    # GroupNorm(32) is the normalisation the V2 head is built with in
    # torchvision; unlike BatchNorm it does not depend on the batch size.
    model.head.classification_head = RetinaNetClassificationHead(
        in_channels,
        num_anchors,
        num_classes,
        norm_layer=partial(torch.nn.GroupNorm, 32),
    )

    return model

# Create model
# Prefer the GPU when one is visible to this kernel.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = get_model(config.NUM_CLASSES)
model.to(device)

# Parameter statistics: everything vs. what the optimizer will actually update.
total_params = sum(map(torch.Tensor.numel, model.parameters()))
trainable_params = sum(
    p.numel() for p in filter(lambda p: p.requires_grad, model.parameters())
)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## 7. Optimizer and Scheduler

In [None]:
# Optimizer
# Only parameters that require gradients are handed to the optimizer.
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=config.LEARNING_RATE,
    momentum=config.MOMENTUM,
    weight_decay=config.WEIGHT_DECAY,
)

# Learning rate scheduler: one lazy builder per supported name, so only the
# selected scheduler is ever constructed. Unknown names yield no scheduler.
# NOTE(review): OneCycleLR is configured per batch (steps_per_epoch), but the
# training loop visible in this notebook never calls step() for 'onecycle'.
scheduler_builders = {
    'step': lambda: torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1),
    'cosine': lambda: torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config.NUM_EPOCHS, eta_min=1e-6
    ),
    'onecycle': lambda: torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=config.LEARNING_RATE * 10,
        epochs=config.NUM_EPOCHS,
        steps_per_epoch=len(data_loader_train),
    ),
}
lr_scheduler = scheduler_builders.get(config.LR_SCHEDULER, lambda: None)()

# Mixed precision scaler (None means the plain full-precision path is used)
scaler = None
if config.USE_AMP:
    scaler = GradScaler()

print(
    f"Optimizer: SGD (lr={config.LEARNING_RATE}, momentum={config.MOMENTUM})",
    f"LR Scheduler: {config.LR_SCHEDULER}",
    f"Mixed Precision: {config.USE_AMP}",
    sep="\n",
)

## 8. Training Functions

In [None]:
def train_one_epoch(model, optimizer, data_loader, device, epoch, scaler=None, grad_clip=None):
    """Train for one epoch with optional mixed precision and gradient clipping.

    Args:
        model: Detection model that returns a dict of losses when called with
            ``(images, targets)`` in training mode.
        optimizer: Optimizer updating the model parameters.
        data_loader: Iterable yielding ``(images, targets)`` batches.
        device: Device the images and target tensors are moved to.
        epoch: Epoch number, used only for the progress-bar caption.
        scaler: Optional gradient scaler; when given, the forward pass runs
            under ``autocast`` and the scaler drives backward / step.
        grad_clip: Optional maximum gradient norm; ``None`` disables clipping.

    Returns:
        Dict with the per-batch averages ``total_loss``, ``cls_loss`` and
        ``box_loss``. All three are 0.0 when the loader yields no batch.
    """
    model.train()

    total_loss = 0.0
    loss_classifier = 0.0
    loss_box_reg = 0.0
    num_batches = 0

    pbar = tqdm(data_loader, desc=f'Epoch {epoch}')

    for images, targets in pbar:
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()

        if scaler is not None:
            with autocast():
                loss_dict = model(images, targets)
                losses = sum(loss for loss in loss_dict.values())

            scaler.scale(losses).backward()

            if grad_clip is not None:
                # Gradients must be unscaled before their norm is measured.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            scaler.step(optimizer)
            scaler.update()
        else:
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

            losses.backward()

            if grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            optimizer.step()

        # Read each loss back to the host once and reuse the Python floats for
        # both the running totals and the progress bar.
        batch_total = losses.item()
        batch_cls = loss_dict['classification'].item() if 'classification' in loss_dict else 0.0
        batch_box = loss_dict['bbox_regression'].item() if 'bbox_regression' in loss_dict else 0.0

        total_loss += batch_total
        loss_classifier += batch_cls
        loss_box_reg += batch_box
        num_batches += 1

        pbar.set_postfix({
            'loss': f'{batch_total:.4f}',
            'cls': f'{batch_cls:.4f}',
            'box': f'{batch_box:.4f}'
        })

    # An empty loader would otherwise cause a division by zero.
    n = max(num_batches, 1)
    return {
        'total_loss': total_loss / n,
        'cls_loss': loss_classifier / n,
        'box_loss': loss_box_reg / n
    }

## 9. Evaluation Functions

In [None]:
def compute_iou(box1, box2):
    """Intersection-over-union of two boxes given as ``[x1, y1, x2, y2]``.

    Returns 0 when the union area is not positive.
    """
    # Corners of the overlap rectangle (it may be empty).
    left, top = max(box1[0], box2[0]), max(box1[1], box2[1])
    right, bottom = min(box1[2], box2[2]), min(box1[3], box2[3])

    overlap_w = max(0, right - left)
    overlap_h = max(0, bottom - top)
    inter_area = overlap_w * overlap_h

    area_first = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area_second = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union_area = area_first + area_second - inter_area

    if union_area > 0:
        return inter_area / union_area
    return 0


@torch.no_grad()
def evaluate(model, data_loader, device, score_threshold=0.5, iou_threshold=0.5):
    """Evaluate the model and compute precision / recall / F1 per class.

    Predictions scoring above ``score_threshold`` are matched greedily, in the
    order the model returns them, to the not-yet-matched ground-truth box of
    the same class with the highest IoU. A match with IoU of at least
    ``iou_threshold`` is a true positive, any other prediction a false
    positive, and every ground-truth box left unmatched a false negative.

    Args:
        model: Detection model returning, per image, a dict with ``boxes``,
            ``labels`` and ``scores``.
        data_loader: Iterable yielding ``(images, targets)`` batches.
        device: Device the images are moved to before inference.
        score_threshold: Predictions at or below this score are ignored.
        iou_threshold: Minimum IoU for a prediction to count as a match.

    Returns:
        Dict keyed by class label (plus the key ``'overall'``). Each class
        entry holds ``precision``, ``recall``, ``f1``, ``avg_iou`` and the raw
        ``tp`` / ``fp`` / ``fn`` counts; ``'overall'`` holds the same fields
        except ``avg_iou``, computed from the summed counts.
    """
    model.eval()

    # Metrics per class
    class_tp = defaultdict(int)
    class_fp = defaultdict(int)
    class_fn = defaultdict(int)
    class_iou_sum = defaultdict(float)
    class_iou_count = defaultdict(int)

    pbar = tqdm(data_loader, desc='Evaluating')

    for images, targets in pbar:
        images = [img.to(device) for img in images]
        outputs = model(images)

        for output, target in zip(outputs, targets):
            # Convert once per image instead of once per (prediction, GT) pair.
            gt_boxes = target['boxes'].cpu().numpy()
            gt_labels = target['labels'].cpu().tolist()

            # Filter by score
            keep = output['scores'] > score_threshold
            pred_boxes = output['boxes'][keep].cpu().numpy()
            pred_labels = output['labels'][keep].cpu().tolist()

            # Track matched ground truth
            gt_matched = [False] * len(gt_labels)

            # For each prediction
            for pb, label in zip(pred_boxes, pred_labels):
                best_iou = 0
                best_idx = -1

                # Find best matching GT of the same class that is still free
                for i, (gb, gl) in enumerate(zip(gt_boxes, gt_labels)):
                    if gt_matched[i] or gl != label:
                        continue

                    iou = compute_iou(pb, gb)
                    if iou > best_iou:
                        best_iou = iou
                        best_idx = i

                if best_iou >= iou_threshold:
                    class_tp[label] += 1
                    class_iou_sum[label] += float(best_iou)
                    class_iou_count[label] += 1
                    gt_matched[best_idx] = True
                else:
                    class_fp[label] += 1

            # Count false negatives (unmatched GT)
            for matched, gl in zip(gt_matched, gt_labels):
                if not matched:
                    class_fn[gl] += 1

    # Compute metrics
    results = {}
    all_tp, all_fp, all_fn = 0, 0, 0

    # Include classes that only ever produced false positives; leaving them out
    # would drop their FPs from the overall totals and inflate precision.
    seen_labels = sorted(set(class_tp) | set(class_fp) | set(class_fn))

    for label in seen_labels:
        tp = class_tp[label]
        fp = class_fp[label]
        fn = class_fn[label]

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
        avg_iou = class_iou_sum[label] / class_iou_count[label] if class_iou_count[label] > 0 else 0

        results[label] = {
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'avg_iou': avg_iou,
            'tp': tp, 'fp': fp, 'fn': fn
        }

        all_tp += tp
        all_fp += fp
        all_fn += fn

    # Overall metrics
    overall_precision = all_tp / (all_tp + all_fp) if (all_tp + all_fp) > 0 else 0
    overall_recall = all_tp / (all_tp + all_fn) if (all_tp + all_fn) > 0 else 0
    overall_f1 = 2 * overall_precision * overall_recall / (overall_precision + overall_recall) if (overall_precision + overall_recall) > 0 else 0

    results['overall'] = {
        'precision': overall_precision,
        'recall': overall_recall,
        'f1': overall_f1,
        'tp': all_tp, 'fp': all_fp, 'fn': all_fn
    }

    return results


def print_metrics(results, class_names):
    """Print the metrics returned by ``evaluate`` as a fixed-width table.

    One row per class (named via ``class_names``, falling back to the label
    itself), followed by a separator and the ``'overall'`` row.
    """
    heavy_rule = "=" * 70

    def cells(values):
        # Every number is right-aligned in a 10-character column, 4 decimals.
        return " ".join(f"{value:>10.4f}" for value in values)

    print("\n" + heavy_rule)
    print(f"{'Class':<15} {'Precision':>10} {'Recall':>10} {'F1':>10} {'Avg IoU':>10}")
    print(heavy_rule)

    per_class = ((label, m) for label, m in results.items() if label != 'overall')
    for label, metrics in per_class:
        name = class_names.get(label, str(label))
        row = [metrics['precision'], metrics['recall'], metrics['f1'], metrics.get('avg_iou', 0)]
        print(f"{name:<15} " + cells(row))

    print("-" * 70)
    overall = results['overall']
    summary = [overall['precision'], overall['recall'], overall['f1']]
    print(f"{'Overall':<15} " + cells(summary))
    print(heavy_rule + "\n")

## 10. Training Loop

In [None]:
# Training history
# One list per tracked quantity; the training loop appends one value per epoch.
history = {
    key: []
    for key in (
        'train_loss',
        'train_cls_loss',
        'train_box_loss',
        'val_f1',
        'val_precision',
        'val_recall',
        'lr',
    )
}

# Best validation F1 seen so far and the number of epochs since it improved.
best_f1 = 0.0
patience_counter = 0

print(
    f"Starting training for {config.NUM_EPOCHS} epochs...",
    f"Device: {device}",
    f"Batch size: {config.BATCH_SIZE}",
    f"Learning rate: {config.LEARNING_RATE}",
    "="*50,
    sep="\n",
)

In [None]:
# Config keeps its settings as class attributes, so vars(config) on the
# instance is empty. Snapshot the class-level settings (plus any instance
# overrides) so the checkpoints really record the configuration.
config_snapshot = {
    name: value
    for name, value in {**vars(type(config)), **vars(config)}.items()
    if not name.startswith('__')
}

for epoch in range(config.NUM_EPOCHS):
    # Train
    train_metrics = train_one_epoch(
        model, optimizer, data_loader_train, device, epoch,
        scaler=scaler, grad_clip=config.GRAD_CLIP
    )

    # Record the learning rate used during this epoch, then advance the
    # epoch-based schedulers.
    # NOTE(review): 'onecycle' is skipped here and train_one_epoch does not
    # step it either, so that option currently leaves the learning rate untouched.
    current_lr = optimizer.param_groups[0]['lr']
    if lr_scheduler is not None and config.LR_SCHEDULER != 'onecycle':
        lr_scheduler.step()

    # Evaluate
    val_results = evaluate(
        model, data_loader_val, device,
        score_threshold=config.SCORE_THRESHOLD,
        iou_threshold=config.IOU_THRESHOLD
    )

    # Record history
    history['train_loss'].append(train_metrics['total_loss'])
    history['train_cls_loss'].append(train_metrics['cls_loss'])
    history['train_box_loss'].append(train_metrics['box_loss'])
    history['val_f1'].append(val_results['overall']['f1'])
    history['val_precision'].append(val_results['overall']['precision'])
    history['val_recall'].append(val_results['overall']['recall'])
    history['lr'].append(current_lr)

    # Print epoch summary
    print(f"\nEpoch {epoch}/{config.NUM_EPOCHS-1}")
    print(f"  Train Loss: {train_metrics['total_loss']:.4f} "
          f"(cls: {train_metrics['cls_loss']:.4f}, box: {train_metrics['box_loss']:.4f})")
    print(f"  Val F1: {val_results['overall']['f1']:.4f} "
          f"(P: {val_results['overall']['precision']:.4f}, R: {val_results['overall']['recall']:.4f})")
    print(f"  LR: {current_lr:.6f}")

    # Print per-class metrics
    print_metrics(val_results, config.CLASS_NAMES)

    # Update the best score and the early-stopping counter
    val_f1 = val_results['overall']['f1']
    improved = val_f1 > best_f1
    if improved:
        best_f1 = val_f1
        patience_counter = 0
    else:
        patience_counter += 1

    # Build the checkpoint once; it is written as "last" every epoch and also
    # as "best" whenever the validation F1 improved.
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_f1': best_f1,
        'config': config_snapshot
    }

    # Save best model
    if improved:
        torch.save(checkpoint, os.path.join(config.OUTPUT_DIR, 'best_retinanet.pth'))
        print(f"  ✓ Saved new best model (F1: {best_f1:.4f})")

    # Save last model
    torch.save(checkpoint, os.path.join(config.OUTPUT_DIR, 'last_retinanet.pth'))

    # Early stopping
    if patience_counter >= config.EARLY_STOPPING_PATIENCE:
        print(f"\nEarly stopping triggered after {epoch+1} epochs")
        break

print(f"\nTraining complete! Best F1: {best_f1:.4f}")

## 11. Training Curves

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
(ax1, ax2), (ax3, ax4) = axes

# Top-left: the three training-loss curves
ax1.plot(history['train_loss'], label='Total Loss', linewidth=2)
ax1.plot(history['train_cls_loss'], label='Classification Loss', linewidth=2)
ax1.plot(history['train_box_loss'], label='Box Regression Loss', linewidth=2)
ax1.set(xlabel='Epoch', ylabel='Loss', title='Training Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Top-right: validation precision / recall / F1
ax2.plot(history['val_f1'], label='F1 Score', linewidth=2)
ax2.plot(history['val_precision'], label='Precision', linewidth=2)
ax2.plot(history['val_recall'], label='Recall', linewidth=2)
ax2.set(xlabel='Epoch', ylabel='Score', title='Validation Metrics')
ax2.legend()
ax2.grid(True, alpha=0.3)

# Bottom-left: learning rate recorded at each epoch
ax3.plot(history['lr'], linewidth=2, color='green')
ax3.set(xlabel='Epoch', ylabel='Learning Rate', title='Learning Rate Schedule')
ax3.grid(True, alpha=0.3)

# Bottom-right: F1 (main metric) with the best value marked
ax4.plot(history['val_f1'], linewidth=2, color='blue')
ax4.axhline(y=best_f1, color='r', linestyle='--', label=f'Best F1: {best_f1:.4f}')
ax4.set(xlabel='Epoch', ylabel='F1 Score', title='Validation F1 Score')
ax4.legend()
ax4.grid(True, alpha=0.3)

plt.tight_layout()
# Save before show() so the file receives the fully laid-out figure.
plt.savefig(os.path.join(config.OUTPUT_DIR, 'training_curves.png'), dpi=150)
plt.show()

## 12. Test Set Evaluation

In [None]:
# Load best model
# map_location lets a checkpoint written on a GPU machine be restored on
# whatever device this kernel is using (including CPU-only).
checkpoint = torch.load(
    os.path.join(config.OUTPUT_DIR, 'best_retinanet.pth'),
    map_location=device
)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded best model from epoch {checkpoint['epoch']} (F1: {checkpoint['best_f1']:.4f})")

# Evaluate on test set, using the same thresholds as during validation
print("\nEvaluating on Test Set...")
test_results = evaluate(
    model, data_loader_test, device,
    score_threshold=config.SCORE_THRESHOLD,
    iou_threshold=config.IOU_THRESHOLD
)

print("\n" + "="*50)
print("TEST SET RESULTS")
print("="*50)
print_metrics(test_results, config.CLASS_NAMES)

## 13. Visualize Predictions

In [None]:
@torch.no_grad()
def visualize_predictions(model, dataset, device, indices=None, num_samples=5, 
                          score_threshold=0.5, class_names=None, class_colors=None):
    """Show ground truth and model predictions side by side for sample images.

    Args:
        model: Detection model; it is switched to eval mode.
        dataset: Indexable object returning ``(image_tensor, target)``.
        device: Device the image is moved to for inference.
        indices: Sample indices to show; a random selection of up to
            ``num_samples`` indices is drawn when omitted.
        num_samples: Number of random samples when ``indices`` is None.
        score_threshold: Predictions at or below this score are not drawn.
        class_names: Optional dict mapping label id to display name.
        class_colors: Optional dict mapping label id to a matplotlib colour.
    """
    model.eval()

    def draw_box(ax, box, label, caption):
        # Draw one rectangle plus its caption; ``caption`` turns the class
        # name into the text placed above the box.
        x1, y1, x2, y2 = box.numpy()
        label_id = label.item()
        color = class_colors.get(label_id, 'blue') if class_colors else 'red'
        name = class_names.get(label_id, str(label_id)) if class_names else str(label_id)

        ax.add_patch(
            patches.Rectangle((x1, y1), x2-x1, y2-y1,
                              linewidth=2, edgecolor=color, facecolor='none')
        )
        ax.text(x1, y1-5, caption(name), color=color, fontsize=10, fontweight='bold')

    if indices is None:
        indices = random.sample(range(len(dataset)), min(num_samples, len(dataset)))

    for idx in indices:
        img, target = dataset[idx]

        # Get prediction
        prediction = model([img.to(device)])[0]

        # Channel-first tensor -> channel-last array for matplotlib
        img_np = img.permute(1, 2, 0).numpy()

        fig, axes = plt.subplots(1, 2, figsize=(16, 6))
        gt_ax, pred_ax = axes[0], axes[1]

        # Left panel: ground truth
        gt_ax.imshow(img_np)
        gt_ax.set_title('Ground Truth')
        for box, label in zip(target['boxes'], target['labels']):
            draw_box(gt_ax, box, label, lambda name: name)
        gt_ax.axis('off')

        # Right panel: predictions above the score threshold
        pred_ax.imshow(img_np)
        pred_ax.set_title('Predictions')

        keep = prediction['scores'] > score_threshold
        pred_boxes = prediction['boxes'][keep].cpu()
        pred_labels = prediction['labels'][keep].cpu()
        pred_scores = prediction['scores'][keep].cpu()

        for box, label, score in zip(pred_boxes, pred_labels, pred_scores):
            draw_box(pred_ax, box, label, lambda name: f'{name}: {score:.2f}')
        pred_ax.axis('off')

        plt.tight_layout()
        plt.show()

# Visualize predictions on test set
# Show ground truth next to predictions for six randomly chosen test images
# (no explicit indices are passed), using the configured score threshold.
visualize_predictions(
    model, dataset_test, device, 
    num_samples=6,
    score_threshold=config.SCORE_THRESHOLD,
    class_names=config.CLASS_NAMES,
    class_colors=config.CLASS_COLORS
)

## 14. Inference Function

In [None]:
@torch.no_grad()
def predict(model, image_path, device, score_threshold=0.5, class_names=None):
    """
    Detect objects in a single image file.

    Args:
        model: Trained model; it is switched to eval mode.
        image_path: Path to image
        device: Device to run on
        score_threshold: Detections at or below this confidence are dropped
        class_names: Optional dict mapping class IDs to names; the ID itself
            (as a string) is used when omitted or when the ID is missing

    Returns:
        List of detections:
        [{'box': [x1,y1,x2,y2], 'label': str, 'label_id': int, 'score': float}, ...]
    """
    model.eval()

    # Image file -> RGB -> tensor with a leading batch dimension, on the device
    rgb = Image.open(image_path).convert('RGB')
    batch = F.to_tensor(rgb).unsqueeze(0).to(device)

    # Single image in, so only the first result is relevant
    result = model(batch)[0]

    # Keep confident detections only and bring them back as NumPy arrays
    confident = result['scores'] > score_threshold
    boxes, labels, scores = (
        result[field][confident].cpu().numpy()
        for field in ('boxes', 'labels', 'scores')
    )

    def label_text(label):
        if class_names:
            return class_names.get(label, str(label))
        return str(label)

    return [
        {
            'box': box.tolist(),
            'label': label_text(label),
            'label_id': int(label),
            'score': float(score),
        }
        for box, label, score in zip(boxes, labels, scores)
    ]

# Example usage
# Pick the first image file in the test folder. The folder also contains the
# annotation JSON, so the listing is filtered by extension (case-insensitive)
# and sorted for a reproducible choice.
image_files = sorted(
    name for name in os.listdir(test_dir)
    if name.lower().endswith(('.jpg', '.jpeg', '.png'))
)
if image_files:
    sample_image = os.path.join(test_dir, image_files[0])
    detections = predict(model, sample_image, device, 
                        score_threshold=config.SCORE_THRESHOLD, 
                        class_names=config.CLASS_NAMES)
    print(f"Detections for {os.path.basename(sample_image)}:")
    for det in detections:
        print(f"  {det['label']}: {det['score']:.3f} at {det['box']}")
else:
    print(f"No .jpg/.jpeg/.png images found in {test_dir}")

## 15. Export Model (Optional)

In [None]:
# Save model for inference
# Switch to inference mode before exporting
model.eval()

# Weights only: the smaller file
torch.save(
    model.state_dict(),
    os.path.join(config.OUTPUT_DIR, 'retinanet_weights.pth'),
)

# Whole model object: easier to load
torch.save(
    model,
    os.path.join(config.OUTPUT_DIR, 'retinanet_complete.pth'),
)

print(
    f"Models saved to {config.OUTPUT_DIR}",
    f"  - retinanet_weights.pth (weights only)",
    f"  - retinanet_complete.pth (complete model)",
    f"  - best_retinanet.pth (checkpoint with optimizer state)",
    sep="\n",
)

## Summary

### Optimizations Made:
1. **Mixed Precision Training (AMP)** - Faster training with lower memory usage
2. **Data Augmentation** - Random horizontal flip and color jittering
3. **Cosine Annealing LR** - Smoother learning rate decay
4. **Gradient Clipping** - Training stability
5. **Early Stopping** - Prevent overfitting
6. **Per-class Metrics** - Better evaluation
7. **Visualization Tools** - Sample predictions
8. **Clean Inference API** - Easy to use for deployment

### Files Saved:
- `checkpoints/best_retinanet.pth` - Best model checkpoint
- `checkpoints/last_retinanet.pth` - Latest checkpoint
- `checkpoints/training_curves.png` - Training visualization
- `checkpoints/retinanet_weights.pth` - Model weights only
- `checkpoints/retinanet_complete.pth` - Complete model