In [1]:
score_threshold: float = 0.50  # detections must score strictly above this to be kept during validation

# Utility Functions to compute F2 scores for validation

In [None]:
#Utility Functions to compute F2 scores for validation
import torch
from tqdm import tqdm

def compute_iou(boxA, boxB):
    """
    Return the Intersection-over-Union (IoU) of two axis-aligned boxes.

    Args:
        boxA (array-like): Corner coordinates [x1, y1, x2, y2].
        boxB (array-like): Corner coordinates [x1, y1, x2, y2].

    Returns:
        float: Overlap area divided by union area, or 0.0 when the boxes
        share no positive-area overlap.
    """
    # Corners of the candidate overlap rectangle; when the boxes are disjoint
    # these come out "inverted" and the clamps below zero the extent.
    left = max(boxA[0], boxB[0])
    top = max(boxA[1], boxB[1])
    right = min(boxA[2], boxB[2])
    bottom = min(boxA[3], boxB[3])

    overlap = max(0, right - left) * max(0, bottom - top)
    if overlap == 0:
        # Disjoint or merely touching boxes: short-circuit before computing areas.
        return 0.0

    area_a = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    area_b = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
    union = area_a + area_b - overlap
    return overlap / float(union)


def compute_f2_score(pred_boxes, gt_boxes, iou_threshold=0.5, beta=2.0):
    """
    Compute the F-beta score (F2 by default) for a single image.

    Predicted boxes are processed in the order given. Each one is matched to
    the still-unmatched ground-truth box with which it has the highest IoU,
    provided that IoU is at least ``iou_threshold``; each ground-truth box can
    be matched at most once. Matched predictions are true positives,
    unmatched predictions are false positives, and unmatched ground-truth
    boxes are false negatives.

    Args:
        pred_boxes (np.ndarray): Predicted boxes [N_pred, 4] as
            [x1, y1, x2, y2]. Earlier predictions get first pick of the
            ground truth, so pass them sorted by descending confidence.
        gt_boxes (np.ndarray): Ground-truth boxes [N_gt, 4].
        iou_threshold (float): Minimum IoU for a prediction to count as a
            true positive.
        beta (float): Weight of recall relative to precision; 2.0 gives F2.

    Returns:
        float: F-beta score for this image. An image with no ground truth and
        no predictions is a perfect result and scores 1.0; any other image
        without a true positive scores 0.0.
    """
    n_pred = len(pred_boxes)
    n_gt = len(gt_boxes)

    # Nothing to find and nothing predicted: precision/recall are 0/0, but the
    # prediction is exactly right. Scoring it 0.0 would drag down a per-image
    # average for every correctly-empty image.
    if n_pred == 0 and n_gt == 0:
        return 1.0

    matched_gt = set()
    tp = 0  # True positives

    for pred_box in pred_boxes:
        # Pick the best-overlapping unmatched ground truth rather than the
        # first one over the threshold, so a prediction doesn't steal a box
        # another prediction needed while its own best match goes unused.
        best_idx = None
        best_iou = -1.0
        for i, gt_box in enumerate(gt_boxes):
            if i in matched_gt:
                continue
            iou = compute_iou(pred_box, gt_box)
            if iou >= iou_threshold and iou > best_iou:
                best_idx = i
                best_iou = iou
        if best_idx is not None:
            matched_gt.add(best_idx)
            tp += 1

    # tp / (tp + fp) == tp / n_pred and tp / (tp + fn) == tp / n_gt.
    precision = tp / n_pred if n_pred > 0 else 0.0
    recall = tp / n_gt if n_gt > 0 else 0.0
    if precision + recall == 0:
        return 0.0

    beta2 = beta ** 2
    return (1 + beta2) * (precision * recall) / (beta2 * precision + recall)

# Modified Training / Validation Loop

In [None]:
# Training hyperparameters
num_epochs = 10
lr = 0.005

# Optimizer setup. Only trainable parameters are handed to SGD, so any frozen
# parameters are not tracked by the optimizer.
optimizer = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad],
    lr=lr,
    momentum=0.9,
    weight_decay=0.0005,
)

for epoch in range(num_epochs):
    # ---- Training step ----
    model.train()
    total_loss = 0.0
    num_batches = 0
    pbar = tqdm(data_loader, desc=f"Epoch {epoch + 1}/{num_epochs}", leave=False)

    for images, targets in pbar:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # With targets supplied, the model returns a dict of loss tensors;
        # their sum is the objective being minimised.
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        loss_value = losses.item()
        total_loss += loss_value
        num_batches += 1
        pbar.set_postfix(loss=f"{loss_value:.4f}")

    # total_loss grows with the number of batches, so also report the
    # per-batch mean, which is comparable across epochs and dataset sizes.
    avg_loss = total_loss / num_batches if num_batches else 0.0
    print(
        f"[Epoch {epoch + 1}/{num_epochs}] Loss: {total_loss:.4f} "
        f"(mean per batch: {avg_loss:.4f})"
    )

    # ---- Validation step: mean per-image F2 score ----
    model.eval()
    val_f2s = []
    with torch.no_grad():
        for images, targets in tqdm(val_loader, desc="Validation", leave=False):
            images = [img.to(device) for img in images]

            # Ground-truth boxes as NumPy arrays, one per image in the batch.
            gt_boxes_batch = [t["boxes"].cpu().numpy() for t in targets]

            # Inference: one output dict per image; "boxes" and (if present)
            # "scores" are read below.
            outputs = model(images)

            for pred, gt_boxes in zip(outputs, gt_boxes_batch):
                pred_boxes = pred["boxes"].cpu().numpy()

                # Keep only detections with score > score_threshold.
                if "scores" in pred:
                    keep = pred["scores"].cpu().numpy() > score_threshold
                    pred_boxes = pred_boxes[keep]

                f2 = compute_f2_score(pred_boxes, gt_boxes)
                val_f2s.append(f2)

    # Average F2 across all validation images for this epoch.
    avg_f2 = sum(val_f2s) / len(val_f2s) if val_f2s else 0.0
    print(f"[Epoch {epoch + 1}/{num_epochs}] Validation F2: {avg_f2:.4f}")