In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image, ImageDraw
import numpy as np
from tqdm import tqdm
import glob
from collections import Counter
import matplotlib.pyplot as plt


In [None]:
# --- Global configuration / hyperparameters shared by the cells below ---
# Input images are resized to IMG_SIZE x IMG_SIZE before being fed to the network.
IMG_SIZE = 128
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
EPOCHS = 20 # Number of full passes over the training loader
NUM_CLASSES = 1
# The backbone defined further down applies four 2x max-pools, i.e. a total stride of 16.
S = IMG_SIZE // 16 # Grid size for YOLO (128 // 16 = 8)
# Anchor box priors. Format: (width, height), expressed as fractions of the
# image size (not pixels, not grid cells).
ANCHORS = [
    (0.1, 0.15),
    (0.2, 0.3),
    (0.4, 0.5),
]
# Number of anchor boxes predicted per grid cell.
NUM_ANCHORS = len(ANCHORS)

In [None]:
def iou_width_height(box1_wh, box2_wh):
    """
    Pairwise IoU between two sets of boxes that share the same centre.

    Only widths and heights are compared, so the overlap of two boxes is
    simply min(width) * min(height).

    Args:
        box1_wh (torch.Tensor): (N, 2) tensor of (width, height) pairs.
        box2_wh (torch.Tensor): (M, 2) tensor of (width, height) pairs.
    Returns:
        torch.Tensor: (N, M) matrix where entry [n, m] is the IoU of
        box1_wh[n] with box2_wh[m].
    """
    # Column vectors (N, 1) for the first set, row vectors (1, M) for the
    # second, so every elementwise op below broadcasts to (N, M).
    w_a, h_a = box1_wh[:, 0:1], box1_wh[:, 1:2]
    w_b, h_b = box2_wh[:, 0:1].T, box2_wh[:, 1:2].T

    overlap = torch.min(w_a, w_b) * torch.min(h_a, h_b)
    combined = (w_a * h_a) + (w_b * h_b) - overlap
    # The small constant guards against division by zero for degenerate boxes.
    return overlap / (combined + 1e-6)

In [None]:
class RadarDataset(Dataset):
    """
    Dataset of PNG images with YOLO-style text labels.

    Every ``*.png`` file in ``image_dir`` is one sample. Its label file is
    looked up in ``label_dir`` under the same base name with a ``.txt``
    extension; each non-empty line is parsed as five floats
    ``class x y w h``. ``x``/``y`` are multiplied by the grid size to find
    the responsible cell, so they are treated as image-relative coordinates.

    Args:
        image_dir (str): Directory that is globbed for ``*.png`` images.
        label_dir (str): Directory holding the matching ``.txt`` label files.
        anchors (sequence): Anchor (width, height) pairs, in the same units
            as the label ``w``/``h`` values.
        S (int): Number of grid cells per image side.
        C (int): Number of classes (stored; not used when building targets).
    """

    def __init__(self, image_dir, label_dir, anchors, S=8, C=1):
        self.image_paths = sorted(glob.glob(os.path.join(image_dir, '*.png')))
        self.label_dir = label_dir
        self.S = S
        self.C = C
        self.anchors = torch.tensor(anchors)
        self.num_anchors = len(anchors)
        # Single-channel, fixed-size tensor in [0, 1].
        self.transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.ToTensor()
        ])

    def __len__(self):
        """Return the number of images found in ``image_dir``."""
        return len(self.image_paths)

    def _cell_index(self, coord):
        """Map a relative coordinate to a valid grid index in [0, S - 1]."""
        # Without the clamp a coordinate of exactly 1.0 yields index S
        # (IndexError) and a negative one silently wraps around.
        return min(max(int(self.S * coord), 0), self.S - 1)

    def __getitem__(self, idx):
        """
        Load one sample.

        Returns:
            tuple: ``(image, label_matrix)`` where ``image`` is the
            transformed image tensor and ``label_matrix`` has shape
            ``(S, S, num_anchors, 6)`` with the last axis laid out as
            ``(objectness, x_cell, y_cell, w, h, class)``.
        """
        image_path = self.image_paths[idx]
        # Swap only the extension; str.replace would also rewrite any other
        # '.png' occurrence inside the file name.
        stem = os.path.splitext(os.path.basename(image_path))[0]
        label_path = os.path.join(self.label_dir, stem + '.txt')

        # The transform materialises a tensor, so the file handle can be
        # released as soon as it has run.
        with Image.open(image_path) as raw_image:
            image = self.transform(raw_image)

        # Target tensor: [Grid_S, Grid_S, Num_Anchors, 6 (p_o, x, y, w, h, class)]
        label_matrix = torch.zeros((self.S, self.S, self.num_anchors, 6))

        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                for line in f:
                    fields = line.split()
                    if not fields:
                        # Tolerate blank lines (e.g. a trailing newline).
                        continue
                    cls, x, y, w, h = map(float, fields)

                    # Grid cell responsible for this object (row i, column j).
                    i, j = self._cell_index(y), self._cell_index(x)

                    # Anchor whose shape best matches the box.
                    box_wh = torch.tensor([w, h])
                    ious = iou_width_height(box_wh.unsqueeze(0), self.anchors)
                    best_anchor_idx = int(ious.argmax())

                    # First object wins: a later box that maps to an already
                    # occupied (cell, anchor) slot is dropped.
                    if label_matrix[i, j, best_anchor_idx, 0] == 0:
                        label_matrix[i, j, best_anchor_idx, 0] = 1.0
                        # Centre is stored relative to the cell's top-left corner;
                        # width/height stay relative to the whole image.
                        x_cell, y_cell = self.S * x - j, self.S * y - i
                        label_matrix[i, j, best_anchor_idx, 1:5] = torch.tensor(
                            [x_cell, y_cell, w, h]
                        )
                        label_matrix[i, j, best_anchor_idx, 5] = int(cls)

        return image, label_matrix

In [None]:
class CNNBlock(nn.Module):
    """
    Basic building block: bias-free convolution, batch norm, LeakyReLU(0.1).

    Args:
        in_channels (int): Channels of the incoming feature map.
        out_channels (int): Channels produced by the convolution.
        **kwargs: Forwarded unchanged to ``nn.Conv2d`` (kernel_size, padding, ...).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # The conv bias is disabled because the batch norm that follows
        # carries its own learnable shift.
        stages = [
            nn.Conv2d(in_channels, out_channels, bias=False, **kwargs),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.1),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through conv -> batch norm -> activation."""
        activated = self.conv(x)
        return activated

class YoloV3Tiny(nn.Module):
    """
    Small single-scale YOLO-style detector.

    Five ``CNNBlock`` stages interleaved with four 2x max-pools (total
    stride 16) are followed by a 1x1 convolution that emits
    ``num_anchors * (5 + num_classes)`` channels per grid cell.

    Args:
        in_channels (int): Channels of the input image.
        num_classes (int): Number of class logits per anchor.
        num_anchors (int): Number of anchors predicted per grid cell.
    """

    def __init__(self, in_channels=1, num_classes=1, num_anchors=3):
        super().__init__()
        self.num_classes = num_classes
        self.num_anchors = num_anchors
        self.model = nn.Sequential(
            CNNBlock(in_channels, 16, kernel_size=3, padding=1),
            nn.MaxPool2d(2),
            CNNBlock(16, 32, kernel_size=3, padding=1),
            nn.MaxPool2d(2),
            CNNBlock(32, 64, kernel_size=3, padding=1),
            nn.MaxPool2d(2),
            CNNBlock(64, 128, kernel_size=3, padding=1),
            nn.MaxPool2d(2),
            CNNBlock(128, 256, kernel_size=3, padding=1),
            # No final MaxPool: a 128x128 input ends up as an 8x8 grid.
            nn.Conv2d(256, num_anchors * (5 + num_classes), kernel_size=1)
        )

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Image batch of shape (B, in_channels, H, W).
        Returns:
            torch.Tensor: Raw predictions of shape
            (B, H // 16, W // 16, num_anchors, 5 + num_classes).
        """
        out = self.model(x)
        # Channels-last so that each cell's predictions are contiguous in
        # the final axis, then split that axis into (anchor, attribute).
        out = out.permute(0, 2, 3, 1)
        batch, grid_h, grid_w, _ = out.shape
        # Height and width are read separately so non-square inputs work;
        # for square inputs this is the same (B, S, S, A, 5 + C) layout.
        return out.reshape(
            batch, grid_h, grid_w, self.num_anchors, 5 + self.num_classes
        )

In [None]:
class YoloLoss(nn.Module):
    """
    YOLO-style loss: objectness + no-object + box regression + class terms.

    Layout of the last axis:
      * ``preds``:   (objectness logit, x logit, y logit, w, h, class logits...)
      * ``targets``: (objectness, x_cell, y_cell, w, h, class index)
    """

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss()  # Numerically stable sigmoid + BCE
        # Relative weights of the four loss terms.
        self.lambda_class = 1
        self.lambda_noobj = 10
        self.lambda_obj = 1
        self.lambda_box = 10

    @staticmethod
    def _signed_sqrt(values, eps=1e-6):
        """Square root that is defined (with finite gradient) for any sign."""
        # A plain sqrt of a raw network output is NaN as soon as the output
        # is negative, which poisons the whole loss.
        return torch.sign(values) * torch.sqrt(torch.abs(values) + eps)

    def forward(self, preds, targets, anchors):
        """
        Args:
            preds (torch.Tensor): Raw network output, shape (..., 5 + C).
                It is not modified.
            targets (torch.Tensor): Target tensor, shape (..., 6).
            anchors: Accepted for interface compatibility; not used because
                width/height are regressed directly.
        Returns:
            torch.Tensor: Scalar weighted loss.
        Raises:
            RuntimeError: If a target class index is outside [0, C)
                (raised by ``one_hot``).
        """
        obj_mask = targets[..., 0] == 1
        noobj_mask = targets[..., 0] == 0

        # Graph-connected zero used for any term that has no cells to
        # average over (mean over an empty tensor would be NaN).
        zero = preds.sum() * 0.0

        # === NO OBJECT LOSS ===
        noobj_loss = zero
        if noobj_mask.any():
            noobj_loss = self.bce(
                preds[..., 0:1][noobj_mask], targets[..., 0:1][noobj_mask]
            )

        obj_loss = box_loss = class_loss = zero
        if obj_mask.any():
            # === OBJECT LOSS ===
            obj_loss = self.bce(
                preds[..., 0:1][obj_mask], targets[..., 0:1][obj_mask]
            )

            # === BOX COORDINATE LOSS ===
            # The sigmoid is applied out-of-place so the caller's tensor is
            # left untouched (an in-place write also breaks autograd for
            # leaf tensors).
            pred_xy = torch.sigmoid(preds[..., 1:3][obj_mask])
            pred_wh = preds[..., 3:5][obj_mask]
            target_box = targets[..., 1:5][obj_mask]

            box_loss_xy = self.mse(pred_xy, target_box[..., 0:2])
            # sqrt dampens the influence of large boxes; the same transform
            # is applied to both sides so a perfect prediction gives 0.
            box_loss_wh = self.mse(
                self._signed_sqrt(pred_wh),
                self._signed_sqrt(target_box[..., 2:4]),
            )
            box_loss = box_loss_xy + box_loss_wh

            # === CLASS LOSS ===
            num_classes = preds.shape[-1] - 5
            if num_classes > 0:
                # Targets store a class index; BCE needs one column per
                # class, so expand the index to a one-hot row.
                class_index = targets[..., 5][obj_mask].long()
                class_target = nn.functional.one_hot(
                    class_index, num_classes=num_classes
                ).to(preds.dtype)
                class_loss = self.bce(preds[..., 5:][obj_mask], class_target)

        total_loss = (
            self.lambda_box * box_loss
            + self.lambda_obj * obj_loss
            + self.lambda_noobj * noobj_loss
            + self.lambda_class * class_loss
        )
        return total_loss

In [None]:
def non_max_suppression(bboxes, iou_threshold, confidence_threshold):
    """
    Greedy Non-Maximum Suppression.

    Boxes are visited from most to least confident; each kept box removes
    every remaining box of the same class whose IoU with it is not below
    ``iou_threshold``.

    Args:
        bboxes (list): Boxes as ``[class, conf, x, y, w, h]`` lists.
        iou_threshold (float): Same-class boxes must overlap a kept box by
            less than this to survive.
        confidence_threshold (float): Boxes must score strictly above this
            to be considered at all.
    Returns:
        list: The surviving boxes, most confident first.
    """
    remaining = sorted(
        (candidate for candidate in bboxes if candidate[1] > confidence_threshold),
        key=lambda candidate: candidate[1],
        reverse=True,
    )
    kept = []

    while remaining:
        winner, challengers = remaining[0], remaining[1:]
        kept.append(winner)

        survivors = []
        for challenger in challengers:
            if challenger[0] == winner[0]:
                # Only same-class boxes compete with each other.
                overlap = iou_boxes(
                    torch.tensor(winner[2:]), torch.tensor(challenger[2:])
                )
                if not (overlap < iou_threshold):
                    continue
            survivors.append(challenger)
        remaining = survivors

    return kept

def iou_boxes(box1, box2):
    """
    Intersection-over-union of two boxes given as centre/size.

    Args:
        box1, box2: Indexable ``[x_center, y_center, width, height]``.
    Returns:
        IoU of the two boxes (a small constant in the denominator prevents
        division by zero).
    """
    def to_corners(box):
        # Centre/size -> (left, top, right, bottom).
        return (
            box[0] - box[2] / 2,
            box[1] - box[3] / 2,
            box[0] + box[2] / 2,
            box[1] + box[3] / 2,
        )

    a_left, a_top, a_right, a_bottom = to_corners(box1)
    b_left, b_top, b_right, b_bottom = to_corners(box2)

    # Edges of the overlapping rectangle (if any).
    left = max(a_left, b_left)
    top = max(a_top, b_top)
    right = min(a_right, b_right)
    bottom = min(a_bottom, b_bottom)

    # A negative extent means the boxes do not overlap along that axis.
    overlap = max(0, right - left) * max(0, bottom - top)
    area_a = (a_right - a_left) * (a_bottom - a_top)
    area_b = (b_right - b_left) * (b_bottom - b_top)
    total = area_a + area_b - overlap + 1e-6

    return overlap / total

In [None]:
def _decode_prediction_grid(prediction, confidence_threshold, num_anchors):
    """Turn one image's raw (rows, cols, A, 5 + C) output into box lists."""
    grid_rows, grid_cols = prediction.shape[0], prediction.shape[1]
    prediction = prediction[:, :, :num_anchors]
    obj_conf = torch.sigmoid(prediction[..., 0])

    boxes = []
    # nonzero() yields indices in row-major (row, col, anchor) order.
    for i, j, a in torch.nonzero(obj_conf > confidence_threshold).tolist():
        cell = prediction[i, j, a]
        offset_x, offset_y = torch.sigmoid(cell[1:3]).tolist()
        # The loss in this file regresses channels 3:5 directly against the
        # image-relative width/height, so they are used as-is (clamped to be
        # non-negative) rather than being rescaled by the anchor.
        w, h = cell[3:5].clamp(min=0).tolist()
        class_label = int(torch.argmax(cell[5:]).item())
        boxes.append([
            class_label,
            obj_conf[i, j, a].item(),
            (offset_x + j) / grid_cols,
            (offset_y + i) / grid_rows,
            w,
            h,
        ])
    return boxes


def _decode_target_grid(target, num_anchors):
    """Turn one image's (rows, cols, A, 6) target tensor into box lists."""
    grid_rows, grid_cols = target.shape[0], target.shape[1]
    boxes = []
    for i, j, a in torch.nonzero(target[:, :, :num_anchors, 0] == 1).tolist():
        _, x_cell, y_cell, w, h, cls = target[i, j, a, :6].tolist()
        # The constant 1 is the ground-truth "confidence".
        boxes.append([
            int(cls),
            1,
            (x_cell + j) / grid_cols,
            (y_cell + i) / grid_rows,
            w,
            h,
        ])
    return boxes


def get_all_bboxes(loader, model, iou_threshold, confidence_threshold, anchors, device="cpu"):
    """
    Collect NMS-filtered predictions and ground-truth boxes for a loader.

    Args:
        loader: Iterable of ``(images, targets)`` batches; ``targets`` has
            shape (B, rows, cols, A, 6).
        model: Callable returning raw predictions of shape
            (B, rows, cols, A, 5 + C). It is put in eval mode for the pass
            and switched to train mode before returning.
        iou_threshold (float): IoU threshold passed to NMS.
        confidence_threshold (float): Minimum objectness for a prediction.
        anchors: Anchor list; only its length (anchors per cell) is used.
        device: Device the images are moved to before the forward pass.
    Returns:
        tuple: ``(all_pred_boxes, all_true_boxes)``. Predictions are
        ``[image_idx, class, conf, x, y, w, h]``; ground truths use the same
        layout with ``conf`` fixed to 1. ``image_idx`` counts samples across
        the whole loader.
    """
    model.eval()
    train_idx = 0
    all_pred_boxes = []
    all_true_boxes = []
    num_anchors = len(anchors)

    for images, targets in tqdm(loader, desc="Getting BBoxes"):
        images = images.to(device)
        with torch.no_grad():
            # Decoding is scalar Python work, so do it on the CPU copy.
            predictions = model(images).cpu()
        targets = targets.cpu()

        # Index by position inside the batch, not by the loader's batch number.
        for sample_idx in range(images.shape[0]):
            candidates = _decode_prediction_grid(
                predictions[sample_idx], confidence_threshold, num_anchors
            )
            for nms_box in non_max_suppression(candidates, iou_threshold, confidence_threshold):
                all_pred_boxes.append([train_idx] + nms_box)

            for box in _decode_target_grid(targets[sample_idx], num_anchors):
                all_true_boxes.append([train_idx] + box)

            train_idx += 1

    model.train()
    return all_pred_boxes, all_true_boxes


In [None]:
def mean_average_precision(pred_boxes, true_boxes, iou_threshold=0.5, num_classes=1):
    """
    Calculates mean average precision (mAP) at a single IoU threshold.

    Args:
        pred_boxes (list): ``[[image_idx, class, conf, x, y, w, h], ...]``.
        true_boxes (list): Either ``[[image_idx, class, conf, x, y, w, h], ...]``
            (the layout ``get_all_bboxes`` emits, with conf == 1) or
            ``[[image_idx, class, x, y, w, h], ...]``. In both layouts the
            box is the last four entries.
        iou_threshold (float): A detection counts as a true positive when its
            best IoU with a not-yet-matched ground truth exceeds this.
        num_classes (int): Classes ``0 .. num_classes - 1`` are evaluated.
    Returns:
        Mean of the per-class average precisions over classes that have at
        least one ground-truth box (0 when there is none).
    """
    average_precisions = []
    epsilon = 1e-6

    for c in range(num_classes):
        detections = sorted(
            (d for d in pred_boxes if d[1] == c), key=lambda d: d[2], reverse=True
        )

        # Group this class's ground truths per image once instead of
        # re-filtering the full list for every detection.
        gts_by_image = {}
        for gt in true_boxes:
            if gt[1] == c:
                gts_by_image.setdefault(gt[0], []).append(gt)

        total_true_bboxes = sum(len(gts) for gts in gts_by_image.values())
        if total_true_bboxes == 0:
            continue

        # already_matched[image][k] is True once ground truth k was claimed.
        already_matched = {
            image: [False] * len(gts) for image, gts in gts_by_image.items()
        }

        TP = torch.zeros(len(detections))
        FP = torch.zeros(len(detections))

        for detection_idx, detection in enumerate(detections):
            image_gts = gts_by_image.get(detection[0], [])
            det_box = torch.tensor([float(v) for v in detection[-4:]])

            best_iou = 0.0
            best_gt_idx = -1
            for idx, gt in enumerate(image_gts):
                # Take the last four fields so an optional confidence entry
                # in the ground-truth row is never mistaken for a coordinate.
                gt_box = torch.tensor([float(v) for v in gt[-4:]])
                iou = float(iou_boxes(det_box, gt_box))
                if iou > best_iou:
                    best_iou = iou
                    best_gt_idx = idx

            if best_iou > iou_threshold and not already_matched[detection[0]][best_gt_idx]:
                TP[detection_idx] = 1
                already_matched[detection[0]][best_gt_idx] = True
            else:
                # Either too little overlap or a duplicate of a matched box.
                FP[detection_idx] = 1

        TP_cumsum = torch.cumsum(TP, dim=0)
        FP_cumsum = torch.cumsum(FP, dim=0)
        recalls = TP_cumsum / (total_true_bboxes + epsilon)
        precisions = TP_cumsum / (TP_cumsum + FP_cumsum + epsilon)
        # Anchor the curve at (recall=0, precision=1) before integrating.
        precisions = torch.cat((torch.tensor([1.0]), precisions))
        recalls = torch.cat((torch.tensor([0.0]), recalls))
        average_precisions.append(torch.trapz(precisions, recalls))

    return sum(average_precisions) / (len(average_precisions) + 1e-6)

In [None]:
def plot_image(image, boxes):
    """
    Show an image with bounding boxes and their confidences drawn on top.

    Args:
        image (torch.Tensor): Channel-first image tensor (C, H, W) on the CPU.
        boxes (list): Boxes as ``[class, conf, x, y, w, h]`` with coordinates
            relative to the image size (centre x/y, width, height).
    """
    im = np.array(image.permute(1, 2, 0))
    height, width, _ = im.shape
    fig, ax = plt.subplots(1)
    if im.shape[2] == 1:
        # imshow documents (M, N), (M, N, 3) and (M, N, 4) inputs only, so a
        # single-channel image is squeezed and drawn in grayscale.
        ax.imshow(im[:, :, 0], cmap="gray")
    else:
        ax.imshow(im)

    for box in boxes:
        # box format: [class, conf, x, y, w, h]
        x, y, w, h = box[2], box[3], box[4], box[5]

        # Rectangle wants the top-left corner in pixel coordinates.
        upper_left_x = (x - w / 2) * width
        upper_left_y = (y - h / 2) * height
        rect = plt.Rectangle(
            (upper_left_x, upper_left_y),
            w * width,
            h * height,
            linewidth=2,
            edgecolor="red",
            facecolor="none",
        )
        ax.add_patch(rect)
        # Confidence label just above the box.
        ax.text(upper_left_x, upper_left_y-5, f"{box[1]:.2f}", color='white', backgroundcolor='red')

    plt.show()

In [None]:
def train_model(model, loader, optimizer, criterion, device, scaled_anchors):
    """
    Train ``model`` for ``EPOCHS`` passes over ``loader``.

    Each batch is moved to ``device``, scored with
    ``criterion(preds, labels, scaled_anchors)`` and used for one optimizer
    step. The mean batch loss is printed after every epoch.
    """
    model.train()
    for epoch_number in range(1, EPOCHS + 1):
        progress = tqdm(loader, leave=True)
        running_loss = 0
        for batch_imgs, batch_labels in progress:
            batch_imgs = batch_imgs.to(device)
            batch_labels = batch_labels.to(device)

            batch_loss = criterion(model(batch_imgs), batch_labels, scaled_anchors)

            # Standard step: clear stale gradients, backprop, update.
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            loss_value = batch_loss.item()
            running_loss += loss_value
            progress.set_description(f"Epoch {epoch_number}/{EPOCHS}")
            progress.set_postfix(loss=loss_value)

        print(f"Epoch {epoch_number} Average Loss: {running_loss / len(loader)}")

In [None]:
def check_accuracy(loader, model, device):
    """
    Compute, print and return the mAP of ``model`` on ``loader``.

    Boxes are collected with an objectness threshold of 0.5 and an NMS IoU
    threshold of 0.5; mAP is evaluated at IoU 0.5 over ``NUM_CLASSES`` classes.
    """
    print("\n--- Calculating mAP on dataset ---")
    model.to(device)

    collection_settings = dict(
        iou_threshold=0.5,
        confidence_threshold=0.5,
        anchors=ANCHORS,
        device=device,
    )
    predicted, ground_truth = get_all_bboxes(loader, model, **collection_settings)

    map_val = mean_average_precision(
        predicted, ground_truth, iou_threshold=0.5, num_classes=NUM_CLASSES
    )
    print(f"mAP: {map_val:.4f}")
    return map_val


In [None]:
def generate_dummy_data(img_dir="images", label_dir="labels", num_samples=50, seed=None):
    """
    Generates fake images and labels for testing.

    Each sample is a black ``IMG_SIZE`` x ``IMG_SIZE`` grayscale PNG with one
    white rectangle, plus a text label ``"0 x y w h"`` (class 0, centre and
    size relative to the image). Files are named ``img_000``, ``img_001``, ...
    and overwrite existing files with the same name.

    Args:
        img_dir (str): Output directory for PNG images (created if missing).
        label_dir (str): Output directory for label files (created if missing).
        num_samples (int): Number of image/label pairs to write.
        seed (int | None): When given, rectangles are drawn from a private
            generator seeded with this value, so the data set is reproducible.
            When None, NumPy's global random state is used as before.
    """
    os.makedirs(img_dir, exist_ok=True)
    os.makedirs(label_dir, exist_ok=True)
    # np.random and RandomState expose the same uniform(); falling back to
    # the module keeps honouring any global np.random.seed() set by the caller.
    rng = np.random if seed is None else np.random.RandomState(seed)
    print("Generating dummy data...")
    for i in range(num_samples):
        # Create a black image with a white rectangle (our "object")
        img = Image.new('L', (IMG_SIZE, IMG_SIZE), color='black')
        draw = ImageDraw.Draw(img)

        # Size first, then a centre chosen so the box stays inside the image.
        w = rng.uniform(0.1, 0.4)
        h = rng.uniform(0.1, 0.4)
        x = rng.uniform(w/2, 1-w/2)
        y = rng.uniform(h/2, 1-h/2)

        x1 = (x - w/2) * IMG_SIZE
        y1 = (y - h/2) * IMG_SIZE
        x2 = (x + w/2) * IMG_SIZE
        y2 = (y + h/2) * IMG_SIZE
        draw.rectangle([x1, y1, x2, y2], fill='white')

        img.save(os.path.join(img_dir, f"img_{i:03d}.png"))

        with open(os.path.join(label_dir, f"img_{i:03d}.txt"), 'w') as f:
            f.write(f"0 {x} {y} {w} {h}\n") # Class 0

In [None]:
if __name__ == "__main__":
    # --- Paths and synthetic data ---
    image_dir = "images"
    label_dir = "labels"
    generate_dummy_data(image_dir, label_dir)  # writes the test images/labels

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Anchors expressed in grid-cell units instead of image fractions.
    scaled_anchors = (
        torch.tensor(ANCHORS) * torch.tensor([S, S]).unsqueeze(1).T
    ).to(device)

    # --- Dataset and loader ---
    dataset = RadarDataset(image_dir, label_dir, anchors=ANCHORS, S=S, C=NUM_CLASSES)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # --- Network, optimizer, loss ---
    model = YoloV3Tiny(num_classes=NUM_CLASSES, num_anchors=NUM_ANCHORS).to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = YoloLoss()

    # --- Fit ---
    train_model(model, loader, optimizer, criterion, device, scaled_anchors)

    # --- Score on the same loader ---
    check_accuracy(loader, model, device)

    # --- Visualize predictions for one image ---
    model.eval()
    x, y = next(iter(loader))
    x = x.to(device)
    with torch.no_grad():
        out = model(x)

    # Decode the first image of the batch, visiting cells row by row.
    first_pred, first_target = out[0], y[0]
    bboxes = []
    cell_slots = (
        (row, col, slot)
        for row in range(S)
        for col in range(S)
        for slot in range(NUM_ANCHORS)
    )
    for row, col, slot in cell_slots:
        conf = torch.sigmoid(first_pred[row, col, slot, 0])
        if not conf > 0.5:
            continue
        offsets = torch.sigmoid(first_pred[row, col, slot, 1:3])
        x_center = (offsets[0] + col) / S
        y_center = (offsets[1] + row) / S
        # Width/height are borrowed from the target tensor rather than
        # decoded from the prediction (a shortcut for visualization only).
        w = first_target[row, col, slot, 3]
        h = first_target[row, col, slot, 4]
        bboxes.append([0, conf.item(), x_center.item(), y_center.item(), w.item(), h.item()])

    nms_boxes = non_max_suppression(bboxes, iou_threshold=0.5, confidence_threshold=0.5)
    print("\nVisualizing a sample prediction:")
    plot_image(x[0].cpu(), nms_boxes)