In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image, ImageDraw
import numpy as np
from tqdm import tqdm
import glob
from collections import Counter
import matplotlib.pyplot as plt

In [None]:
# --- Input and optimisation hyper-parameters ---
IMG_SIZE = 416
BATCH_SIZE = 8
LEARNING_RATE = 1e-5  # Lower LR for a more complex model
EPOCHS = 10
NUM_CLASSES = 11

# --- Post-processing thresholds (used for filtering, NMS and mAP matching) ---
CONFIDENCE_THRESHOLD = 0.6
IOU_THRESHOLD = 0.5

# Anchor (width, height) pairs in pixels of the IMG_SIZE input, grouped by scale.
# Downstream code divides them by IMG_SIZE to get normalized anchors.
ANCHORS = [
    [(116, 90), (156, 198), (373, 326)],  # 13x13 grid -> large objects
    [(30, 61), (62, 45), (59, 119)],      # 26x26 grid -> medium objects
    [(10, 13), (16, 30), (33, 23)],       # 52x52 grid -> small objects
]

# Grid size of each scale, derived from the network strides 32 / 16 / 8 -> [13, 26, 52]
S = [IMG_SIZE // stride for stride in (32, 16, 8)]

In [None]:
class CNNBlock(nn.Module):
    """Conv2d -> (BatchNorm2d) -> (LeakyReLU(0.1)) building block.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        use_bn: If True, apply BatchNorm after the conv and drop the (redundant)
            conv bias. If False, the conv keeps its bias.
        use_act: Whether to apply LeakyReLU(0.1). Defaults to ``use_bn``.
            A block built with ``use_bn=False`` is therefore linear; in this file
            that is the final prediction conv, whose raw logits must not be
            rectified. Pass ``use_act=True`` to keep the activation without BN.
        **kwargs: Forwarded to ``nn.Conv2d`` (kernel_size, padding, stride, ...).
    """

    def __init__(self, in_channels, out_channels, use_bn=True, use_act=None, **kwargs):
        super().__init__()
        if use_act is None:
            use_act = use_bn
        self.conv = nn.Conv2d(in_channels, out_channels, bias=not use_bn, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels) if use_bn else nn.Identity()
        # LeakyReLU has no parameters, so swapping it for Identity keeps state_dict keys unchanged.
        self.leaky = nn.LeakyReLU(0.1) if use_act else nn.Identity()

    def forward(self, x):
        return self.leaky(self.bn(self.conv(x)))

class ResidualBlock(nn.Module):
    """Stack of Darknet bottlenecks, each added back onto its own input.

    Each repeat is a 1x1 conv that halves the channels followed by a 3x3 conv
    that restores them. The spatial size and channel count are preserved.

    Args:
        channels: Channel count of the input and the output.
        num_repeats: Number of bottleneck units applied in sequence.
    """

    def __init__(self, channels, num_repeats=1):
        super().__init__()
        hidden = channels // 2
        self.layers = nn.ModuleList(
            nn.Sequential(
                CNNBlock(channels, hidden, kernel_size=1),
                CNNBlock(hidden, channels, kernel_size=3, padding=1),
            )
            for _ in range(num_repeats)
        )
        self.num_repeats = num_repeats

    def forward(self, x):
        out = x
        for bottleneck in self.layers:
            out = out + bottleneck(out)  # skip connection
        return out

class PredictionHead(nn.Module):
    """Per-scale detection head producing raw (pre-sigmoid) predictions.

    Args:
        in_channels: Channels of the incoming feature map.
        num_classes: Number of class logits per anchor.
        anchors: Anchor (w, h) pairs for this scale. Only its length is used here.

    Output shape: [Batch, Num_Anchors, Grid_H, Grid_W, 5 + Num_Classes].
    """

    def __init__(self, in_channels, num_classes, anchors):
        super().__init__()
        self.num_classes = num_classes
        self.anchors = anchors
        self.num_anchors = len(anchors)

        hidden = in_channels * 2
        outputs_per_cell = self.num_anchors * (5 + num_classes)
        self.head = nn.Sequential(
            CNNBlock(in_channels, hidden, kernel_size=3, padding=1),
            CNNBlock(hidden, outputs_per_cell, use_bn=False, kernel_size=1),
        )

    def forward(self, x):
        batch, _, height, width = x.shape
        raw = self.head(x)
        # The channel axis is laid out as (anchor, attribute): split it, then move the attributes last.
        raw = raw.view(batch, self.num_anchors, 5 + self.num_classes, height, width)
        return raw.permute(0, 1, 3, 4, 2)

class Darknet53(nn.Module):
    """Darknet-53 style backbone returning three feature maps.

    forward() returns (stride-8, stride-16, stride-32) features with
    256, 512 and 1024 channels respectively (52x52, 26x26 and 13x13 for a
    416x416 input).
    """

    def __init__(self, in_channels=1):
        super().__init__()
        self.in_channels = in_channels

        # (output channels, residual repeats) for each stride-2 downsampling stage
        stages = [(64, 1), (128, 2), (256, 8), (512, 8), (1024, 4)]

        modules = [CNNBlock(in_channels, 32, kernel_size=3, padding=1)]
        prev_channels = 32
        for out_channels, repeats in stages:
            modules.append(CNNBlock(prev_channels, out_channels, kernel_size=3, padding=1, stride=2))
            modules.append(ResidualBlock(out_channels, num_repeats=repeats))
            prev_channels = out_channels
        self.layers = nn.ModuleList(modules)

    def forward(self, x):
        routes = []
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            # Indices 6 and 8 are the 256- and 512-channel residual stages.
            if idx in (6, 8):
                routes.append(x)
        return routes[0], routes[1], x  # 52x52, 26x26, 13x13

In [None]:
class YOLOv3(nn.Module):
    """YOLOv3-style detector: Darknet53 backbone, a top-down neck and three heads.

    Args:
        in_channels: Image channels (1 for the grayscale radar images).
        num_classes: Number of classes predicted per anchor.

    forward() returns (out1, out2, out3) for the 13x13, 26x26 and 52x52 grids
    (for a 416 input). Each has shape [B, 3, S, S, 5 + num_classes].
    """

    def __init__(self, in_channels=1, num_classes=1):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.backbone = Darknet53(in_channels=in_channels)

        # Prediction Heads for each scale
        self.head1 = PredictionHead(1024, num_classes, ANCHORS[0])  # Large scale
        self.head2 = PredictionHead(512, num_classes, ANCHORS[1])   # Medium scale
        self.head3 = PredictionHead(256, num_classes, ANCHORS[2])   # Small scale

        # Lateral 1x1 convs applied before upsampling
        self.conv1 = CNNBlock(1024, 512, kernel_size=1)
        self.conv2 = CNNBlock(512, 256, kernel_size=1)
        # Concatenation doubles the channel count (upsampled lateral + backbone route).
        # These 1x1 convs bring it back to what the next head/lateral conv expects.
        self.fuse2 = CNNBlock(512 + 512, 512, kernel_size=1)
        self.fuse3 = CNNBlock(256 + 256, 256, kernel_size=1)
        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")

    def forward(self, x):
        route3, route2, route1 = self.backbone(x)  # 256ch@52, 512ch@26, 1024ch@13

        # Scale 1 prediction (large objects)
        out1 = self.head1(route1)

        # Scale 2 prediction (medium objects): 512 upsampled + 512 route -> fuse to 512
        x = self.upsample(self.conv1(route1))
        x = self.fuse2(torch.cat([x, route2], dim=1))
        out2 = self.head2(x)

        # Scale 3 prediction (small objects): 256 upsampled + 256 route -> fuse to 256
        x = self.upsample(self.conv2(x))
        x = self.fuse3(torch.cat([x, route3], dim=1))
        out3 = self.head3(x)

        return out1, out2, out3

In [None]:
def iou_width_height(box1_wh, box2_wh):
    """
    Calculates IoU based on width and height, assuming boxes are centered.
    This is used for matching ground truth boxes to the best anchor box.

    Args:
        box1_wh (torch.Tensor or sequence): (N, 2) widths/heights. A single box
            given as shape (2,) is treated as (1, 2).
        box2_wh (torch.Tensor or sequence): (M, 2) widths/heights, same rules.

    Returns:
        torch.Tensor: IoU of shape (N, M).
    """
    box1_wh = torch.as_tensor(box1_wh)
    box2_wh = torch.as_tensor(box2_wh)
    # Mixed int/float inputs (e.g. pixel anchors given as ints) are promoted to one float dtype.
    dtype = torch.promote_types(box1_wh.dtype, box2_wh.dtype)
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()
    box1_wh = box1_wh.to(dtype).reshape(-1, 2)
    box2_wh = box2_wh.to(dtype).reshape(-1, 2)

    w1, h1 = box1_wh[:, 0:1], box1_wh[:, 1:2]      # (N, 1)
    w2, h2 = box2_wh[:, 0:1].T, box2_wh[:, 1:2].T  # (1, M)

    intersection = torch.min(w1, w2) * torch.min(h1, h2)
    union = w1 * h1 + w2 * h2 - intersection

    return intersection / (union + 1e-6)

In [None]:
def iou_boxes(box1, box2, box_format="xywh"):
    """
    Calculates Intersection over Union (IoU) between two bounding boxes.
    This is used during evaluation (NMS and mAP).

    Args:
        box1 (torch.Tensor): Bounding box 1, coordinates in the last dim.
        box2 (torch.Tensor): Bounding box 2, coordinates in the last dim.
        box_format (str): "xywh" (center_x, center_y, width, height) or "xyxy" (x1, y1, x2, y2).

    Returns:
        torch.Tensor: IoU value(s) with a trailing dimension of size 1.

    Raises:
        ValueError: If box_format is not "xywh" or "xyxy".
    """
    if box_format == "xywh":
        box1_x1 = box1[..., 0:1] - box1[..., 2:3] / 2
        box1_y1 = box1[..., 1:2] - box1[..., 3:4] / 2
        box1_x2 = box1[..., 0:1] + box1[..., 2:3] / 2
        box1_y2 = box1[..., 1:2] + box1[..., 3:4] / 2
        box2_x1 = box2[..., 0:1] - box2[..., 2:3] / 2
        box2_y1 = box2[..., 1:2] - box2[..., 3:4] / 2
        box2_x2 = box2[..., 0:1] + box2[..., 2:3] / 2
        box2_y2 = box2[..., 1:2] + box2[..., 3:4] / 2
    elif box_format == "xyxy":
        box1_x1, box1_y1, box1_x2, box1_y2 = box1[..., 0:1], box1[..., 1:2], box1[..., 2:3], box1[..., 3:4]
        box2_x1, box2_y1, box2_x2, box2_y2 = box2[..., 0:1], box2[..., 1:2], box2[..., 2:3], box2[..., 3:4]
    else:
        # Previously fell through to a confusing NameError on box1_x1.
        raise ValueError(f"Unsupported box_format {box_format!r}; expected 'xywh' or 'xyxy'")

    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # clamp(0) makes non-overlapping boxes contribute zero intersection
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
    area1 = ((box1_x2 - box1_x1) * (box1_y2 - box1_y1)).abs()
    area2 = ((box2_x2 - box2_x1) * (box2_y2 - box2_y1)).abs()
    union = area1 + area2 - intersection + 1e-6

    return intersection / union

In [None]:
def non_max_suppression(bboxes, iou_threshold, confidence_threshold):
    """
    Greedy per-class Non-Maximum Suppression.

    Boxes with confidence <= confidence_threshold are discarded. The rest are
    visited from most to least confident. Each kept box removes every later box
    of the same class whose IoU with it is >= iou_threshold.

    Args:
        bboxes (list): List of lists: [[class, conf, x, y, w, h], ...]
        iou_threshold (float): IoU threshold for suppressing boxes.
        confidence_threshold (float): Confidence threshold for filtering boxes.

    Returns:
        list: Surviving boxes, in descending confidence order.
    """
    candidates = sorted(
        (box for box in bboxes if box[1] > confidence_threshold),
        key=lambda box: box[1],
        reverse=True,
    )
    kept = []

    while candidates:
        best = candidates.pop(0)
        kept.append(best)

        survivors = []
        for box in candidates:
            if box[0] != best[0]:
                # Different class: never suppressed by this box
                survivors.append(box)
                continue
            overlap = iou_boxes(torch.tensor(best[2:]), torch.tensor(box[2:]))
            if overlap < iou_threshold:
                survivors.append(box)
        candidates = survivors

    return kept

In [None]:
class RadarDataset(Dataset):
    """Grayscale PNG radar images paired with YOLO-format label files.

    Each label line is ``class x_center y_center width height``. The coordinates
    are expected to be normalized to [0, 1]; the grid-cell math below relies on
    this. Each box is assigned to the single anchor, over all scales, whose shape
    best matches it (width/height IoU).

    Args:
        image_dir: Directory containing ``*.png`` images.
        label_dir: Directory with one ``<image stem>.txt`` per image. A missing
            label file means the image has no objects.
        anchors: Three lists of (w, h) anchors in pixels of the IMG_SIZE input.
        S: Grid size of each scale, e.g. [13, 26, 52].
        C: Number of classes (stored for reference).

    ``__getitem__`` returns ``(image, targets)``:
        - image: tensor of shape [1, IMG_SIZE, IMG_SIZE].
        - targets: tuple with one tensor per scale, each
          (anchors_per_scale, s, s, 6) holding [obj, x_cell, y_cell, w, h, class].
    """

    def __init__(self, image_dir, label_dir, anchors, S, C=1):
        self.image_paths = sorted(glob.glob(os.path.join(image_dir, '*.png')))
        self.label_dir = label_dir
        self.S = S
        self.C = C
        # Anchors are given in pixels, but label w/h are normalized. Normalize the
        # anchors so the anchor-matching IoU compares like with like.
        self.anchors = torch.tensor(anchors[0] + anchors[1] + anchors[2], dtype=torch.float32) / IMG_SIZE
        self.num_anchors = self.anchors.shape[0]
        self.num_anchors_per_scale = self.num_anchors // 3
        self.ignore_iou_thresh = 0.5  # NOTE(review): currently unused

        self.transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        # splitext is robust to upper-case extensions (glob is case-insensitive on Windows)
        stem = os.path.splitext(os.path.basename(image_path))[0]
        label_path = os.path.join(self.label_dir, stem + '.txt')

        # Context manager closes the underlying file; the transform returns a new image.
        with Image.open(image_path) as img:
            image = self.transform(img)

        targets = [torch.zeros((self.num_anchors_per_scale, s, s, 6)) for s in self.S]

        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                for line in f:
                    fields = line.split()
                    if not fields:
                        continue  # tolerate blank / trailing lines
                    cls, x, y, w, h = map(float, fields)

                    # Find the best anchor for this bounding box across ALL anchors (shape (1, 9))
                    ious = iou_width_height(torch.tensor([[w, h]]), self.anchors)
                    best_anchor_idx = int(ious.argmax())

                    # Determine which scale and which anchor on that scale it belongs to
                    scale_idx = best_anchor_idx // self.num_anchors_per_scale
                    anchor_on_scale_idx = best_anchor_idx % self.num_anchors_per_scale

                    s = self.S[scale_idx]
                    # Clamp so a coordinate of exactly 1.0 maps to the last cell, not out of bounds
                    i = min(int(s * y), s - 1)
                    j = min(int(s * x), s - 1)

                    target = targets[scale_idx]
                    # First box claiming this (anchor, cell) slot wins
                    if target[anchor_on_scale_idx, i, j, 0] == 0:
                        target[anchor_on_scale_idx, i, j, 0] = 1
                        x_cell, y_cell = s * x - j, s * y - i
                        target[anchor_on_scale_idx, i, j, 1:5] = torch.tensor([x_cell, y_cell, w, h])
                        target[anchor_on_scale_idx, i, j, 5] = int(cls)

        return image, tuple(targets)


In [None]:
class YoloLoss(nn.Module):
    """Multi-scale YOLOv3 loss.

    For each scale the loss is a weighted sum of:
      - objectness BCE on cells without an object (lambda_noobj),
      - objectness BCE on cells with an object (lambda_obj),
      - box MSE on (sigmoid(tx), sigmoid(ty), tw, th) against
        (x_cell, y_cell, log(w / anchor_w), log(h / anchor_h)) (lambda_box),
      - class BCE against a one-hot encoding of the target class (lambda_class).

    Neither ``predictions`` nor ``targets`` is modified.
    """

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.lambda_class = 1
        self.lambda_noobj = 10
        self.lambda_obj = 1
        self.lambda_box = 10

    def forward(self, predictions, targets, anchors):
        """
        Args:
            predictions: Per-scale raw outputs, each [B, A, S, S, 5 + C].
            targets: Per-scale targets, each [B, A, S, S, 6] holding
                [obj, x_cell, y_cell, w, h, class_index].
            anchors: Per-scale anchors, indexable as anchors[i] -> (A, 2),
                in the same (normalized) units as the target w/h.

        Returns:
            Scalar loss tensor.
        """
        total_loss = 0

        for i, (pred, target) in enumerate(zip(predictions, targets)):
            # (A, 2) -> (1, A, 1, 1, 2) so anchors broadcast over batch and grid
            scale_anchors = torch.as_tensor(anchors[i], dtype=pred.dtype, device=pred.device).reshape(1, -1, 1, 1, 2)
            zero = pred.new_zeros(())

            obj_mask = target[..., 0] == 1
            noobj_mask = target[..., 0] == 0

            # No Object Loss (guarded: the mean of an empty selection is NaN)
            noobj_loss = zero
            if noobj_mask.any():
                noobj_loss = self.bce(pred[..., 0][noobj_mask], target[..., 0][noobj_mask])

            obj_loss = box_loss = class_loss = zero
            if obj_mask.any():
                obj_pred = pred[obj_mask]      # (K, 5 + C)
                obj_target = target[obj_mask]  # (K, 6)

                # Object Loss
                obj_loss = self.bce(obj_pred[:, 0], obj_target[:, 0])

                # Box Coordinate Loss: computed out-of-place to avoid mutating caller tensors
                anchor_wh = scale_anchors.expand(*target.shape[:-1], 2)[obj_mask]  # (K, 2)
                pred_box = torch.cat([torch.sigmoid(obj_pred[:, 1:3]), obj_pred[:, 3:5]], dim=1)
                target_box = torch.cat(
                    [obj_target[:, 1:3], torch.log(1e-16 + obj_target[:, 3:5] / anchor_wh)], dim=1
                )
                box_loss = self.mse(pred_box, target_box)

                # Class Loss: targets store a class index, logits have one entry per class
                num_classes = obj_pred.shape[1] - 5
                class_target = nn.functional.one_hot(obj_target[:, 5].long(), num_classes).to(obj_pred.dtype)
                class_loss = self.bce(obj_pred[:, 5:], class_target)

            total_loss = total_loss + (
                self.lambda_box * box_loss
                + self.lambda_obj * obj_loss
                + self.lambda_noobj * noobj_loss
                + self.lambda_class * class_loss
            )

        return total_loss

In [None]:
def _decode_image_predictions(predictions, image_idx, scaled_anchors, confidence_threshold):
    """Decode one image's raw multi-scale outputs into [class, conf, x, y, w, h] boxes."""
    boxes = []
    for scale_idx, scale_pred in enumerate(predictions):
        grid_size = scale_pred.shape[2]
        for anchor_idx in range(scale_pred.shape[1]):
            anchor_pred = scale_pred[image_idx, anchor_idx]  # (S, S, 5 + C)

            # Keep only cells whose objectness is above threshold
            conf_mask = torch.sigmoid(anchor_pred[..., 0]) > confidence_threshold
            if not conf_mask.any():
                continue

            preds = anchor_pred[conf_mask]  # (K, 5 + C), row-major like torch.where below
            grid_y, grid_x = torch.where(conf_mask)

            # Cell offsets -> normalized image centers
            xy = torch.sigmoid(preds[:, 1:3])
            x_center = (xy[:, 0] + grid_x) / grid_size
            y_center = (xy[:, 1] + grid_y) / grid_size

            # Anchors are in pixels, so divide by IMG_SIZE to normalize w/h
            wh = torch.exp(preds[:, 3:5]) * scaled_anchors[scale_idx, anchor_idx] / IMG_SIZE

            class_conf, class_label = torch.max(torch.sigmoid(preds[:, 5:]), dim=1)
            final_conf = (torch.sigmoid(preds[:, 0]) * class_conf).float()

            # Filter again by the combined (objectness * class) confidence
            keep = final_conf > confidence_threshold
            if not keep.any():
                continue

            decoded = torch.stack([
                class_label[keep].float(),
                final_conf[keep],
                x_center[keep],
                y_center[keep],
                wh[keep, 0],
                wh[keep, 1],
            ], dim=1)
            boxes.extend(decoded.tolist())
    return boxes


def _extract_image_ground_truth(targets, image_idx):
    """Convert one image's per-scale target tensors into [class, 1, x, y, w, h] boxes."""
    boxes = []
    for scale_target in targets:
        grid_size = scale_target.shape[2]
        for anchor_idx in range(scale_target.shape[1]):
            anchor_target = scale_target[image_idx, anchor_idx]  # (S, S, 6)
            obj_mask = anchor_target[..., 0] == 1
            if not obj_mask.any():
                continue

            gts = anchor_target[obj_mask]
            grid_y, grid_x = torch.where(obj_mask)
            class_label = gts[:, 5]

            decoded = torch.stack([
                class_label,
                torch.ones_like(class_label),  # Confidence is 1 for true boxes
                (gts[:, 1] + grid_x) / grid_size,
                (gts[:, 2] + grid_y) / grid_size,
                gts[:, 3],
                gts[:, 4],
            ], dim=1)
            boxes.extend(decoded.tolist())
    return boxes


def get_all_bboxes(loader, model, iou_threshold, confidence_threshold, anchors, device="cpu"):
    """
    Gets all predictions and ground truths from a data loader for the multi-scale model.

    The model is put in eval mode for inference and restored to its previous
    train/eval mode afterwards, even if an error occurs.

    Args:
        loader: Iterable of (images, (target_s1, target_s2, target_s3)) batches.
        model: The YOLOv3 model.
        iou_threshold (float): IoU threshold for NMS.
        confidence_threshold (float): Confidence threshold for filtering predictions.
        anchors (list): Anchor boxes in pixels, reshapeable to (3, 3, 2).
        device (str): The device to run on ('cuda' or 'cpu').

    Returns:
        tuple: A tuple containing two lists:
               - all_pred_boxes: [[train_idx, class, conf, x, y, w, h], ...]
               - all_true_boxes: [[train_idx, class, 1, x, y, w, h], ...]
               where train_idx is a running image index across the loader and
               coordinates are normalized center format.
    """
    was_training = model.training
    model.eval()

    scaled_anchors = torch.tensor(anchors).reshape((3, 3, 2)).to(device)
    all_pred_boxes = []
    all_true_boxes = []
    train_idx = 0

    try:
        for x, y in tqdm(loader, desc="Getting BBoxes"):
            x = x.to(device)
            y = (y[0].to(device), y[1].to(device), y[2].to(device))

            with torch.no_grad():
                predictions = model(x)

            for i in range(x.shape[0]):
                candidates = _decode_image_predictions(predictions, i, scaled_anchors, confidence_threshold)
                # Run NMS on all scales' boxes for this image together
                for nms_box in non_max_suppression(candidates, iou_threshold, confidence_threshold):
                    all_pred_boxes.append([train_idx] + nms_box)

                for true_box in _extract_image_ground_truth(y, i):
                    all_true_boxes.append([train_idx] + true_box)

                train_idx += 1
    finally:
        model.train(was_training)

    return all_pred_boxes, all_true_boxes

In [None]:
def train_model(model, loader, optimizer, criterion, device, scaled_anchors, epochs=None):
    """Train ``model`` for several epochs and print the mean loss of each epoch.

    Args:
        model: Network returning the three per-scale prediction tensors.
        loader: Iterable of (images, (target_s1, target_s2, target_s3)) batches.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Callable ``criterion(preds, labels, scaled_anchors) -> loss``.
        device: Device to move batches to.
        scaled_anchors: Anchors forwarded unchanged to ``criterion``.
        epochs: Number of epochs to run; ``None`` means the global EPOCHS.
    """
    num_epochs = EPOCHS if epochs is None else epochs

    for epoch in range(num_epochs):
        # Re-enable training mode each epoch in case evaluation switched it off.
        model.train()
        loop = tqdm(loader, leave=True)
        loop.set_description(f"Epoch {epoch+1}/{num_epochs}")
        total_loss = 0.0
        num_batches = 0

        for imgs, labels in loop:
            imgs = imgs.to(device)
            labels = tuple(label.to(device) for label in labels)

            preds = model(imgs)
            loss = criterion(preds, labels, scaled_anchors)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1
            loop.set_postfix(loss=batch_loss)

        # max(..., 1) avoids ZeroDivisionError on an empty loader
        print(f"Epoch {epoch+1} Average Loss: {total_loss / max(num_batches, 1)}")

In [None]:
def mean_average_precision(pred_boxes, true_boxes, iou_threshold=0.5, num_classes=1):
    """
    Calculates mean Average Precision (mAP), the standard metric for object detection.

    For each class, detections are sorted by confidence and greedily matched to
    unmatched ground truths in the same image. A detection is a true positive
    when its IoU exceeds ``iou_threshold``. AP is the trapezoidal area under the
    precision-recall curve, with (recall=0, precision=1) prepended. Classes with
    no ground truth boxes are skipped.

    Args:
        pred_boxes (list): [[train_idx, class, conf, x, y, w, h], ...]
        true_boxes (list): [[train_idx, class, conf, x, y, w, h], ...]
        iou_threshold (float): Threshold for a detection to be a True Positive.
        num_classes (int): Number of classes in the dataset.

    Returns:
        float: Mean of per-class AP values, or 0.0 if no class has ground truth.
    """
    average_precisions = []
    epsilon = 1e-6

    for c in range(num_classes):
        detections = [d for d in pred_boxes if d[1] == c]
        ground_truths = [gt for gt in true_boxes if gt[1] == c]
        total_true_bboxes = len(ground_truths)
        if total_true_bboxes == 0:
            continue

        # Group ground truths by image once instead of rescanning per detection
        gts_by_image = {}
        for gt in ground_truths:
            gts_by_image.setdefault(gt[0], []).append(gt)
        # matched[img][k] == 1 once the k-th ground truth of that image is claimed
        matched = {img: torch.zeros(len(gts)) for img, gts in gts_by_image.items()}

        # Sort detections by confidence
        detections.sort(key=lambda d: d[2], reverse=True)
        TP = torch.zeros(len(detections))
        FP = torch.zeros(len(detections))

        for detection_idx, detection in enumerate(detections):
            ground_truth_img = gts_by_image.get(detection[0], [])

            best_iou = 0
            best_gt_idx = -1
            for idx, gt in enumerate(ground_truth_img):
                iou = iou_boxes(torch.tensor(detection[3:]), torch.tensor(gt[3:]))
                if iou > best_iou:
                    best_iou = iou
                    best_gt_idx = idx

            if best_iou > iou_threshold and matched[detection[0]][best_gt_idx] == 0:
                TP[detection_idx] = 1  # True Positive
                matched[detection[0]][best_gt_idx] = 1
            else:
                FP[detection_idx] = 1  # Below IoU threshold, or a duplicate detection

        # Calculate precision and recall
        TP_cumsum = torch.cumsum(TP, dim=0)
        FP_cumsum = torch.cumsum(FP, dim=0)
        recalls = TP_cumsum / (total_true_bboxes + epsilon)
        precisions = TP_cumsum / (TP_cumsum + FP_cumsum + epsilon)

        # Integrate under the precision-recall curve (explicit float start points)
        precisions = torch.cat((torch.tensor([1.0]), precisions))
        recalls = torch.cat((torch.tensor([0.0]), recalls))
        average_precisions.append(torch.trapz(precisions, recalls).item())

    if not average_precisions:
        return 0.0
    # Plain mean: the old "+ epsilon" in the denominator biased mAP low.
    return sum(average_precisions) / len(average_precisions)

In [None]:
def check_accuracy(loader, model, device):
    """
    Evaluate ``model`` on ``loader``: decode and NMS its predictions, then print and return mAP.

    Thresholds, anchors and the class count come from the global config
    (IOU_THRESHOLD, CONFIDENCE_THRESHOLD, ANCHORS, NUM_CLASSES).
    """
    print("\nCalculating mAP on dataset ")
    model.to(device)

    bbox_settings = dict(
        iou_threshold=IOU_THRESHOLD,
        confidence_threshold=CONFIDENCE_THRESHOLD,
        anchors=ANCHORS,
        device=device,
    )
    pred_boxes, true_boxes = get_all_bboxes(loader, model, **bbox_settings)

    map_val = mean_average_precision(
        pred_boxes,
        true_boxes,
        iou_threshold=IOU_THRESHOLD,
        num_classes=NUM_CLASSES,
    )
    print(f"mAP: {map_val:.4f}")
    return map_val

In [None]:
def plot_image(image_tensor, boxes):
    """
    Draw labelled bounding boxes on an image and show it with matplotlib.

    Args:
        image_tensor (torch.Tensor): A single image tensor of shape [C, H, W].
        boxes (list): Boxes as [class, conf, x, y, w, h], with x/y the box center
            and all four coordinates normalized to the image size.
    """
    canvas = transforms.ToPILImage()(image_tensor.cpu())
    if canvas.mode != "RGB":
        canvas = canvas.convert("RGB")

    painter = ImageDraw.Draw(canvas)
    img_w, img_h = canvas.size

    for class_pred, conf, cx, cy, bw, bh in boxes:
        # Normalized center/size -> pixel corner coordinates
        left = (cx - bw / 2) * img_w
        top = (cy - bh / 2) * img_h
        right = (cx + bw / 2) * img_w
        bottom = (cy + bh / 2) * img_h

        painter.rectangle([left, top, right, bottom], outline="red", width=2)

        # Label on a filled background anchored at the box's top-left corner
        label = f"Class {int(class_pred)}: {conf:.2f}"
        painter.rectangle(painter.textbbox((left, top), label), fill="red")
        painter.text((left, top), label, fill="white")

    plt.imshow(canvas)
    plt.axis('off')
    plt.show()

In [None]:
if __name__ == "__main__":
    # Dataset location; override via environment variables instead of editing machine-specific paths.
    image_dir = os.environ.get(
        "RADDET_IMAGE_DIR",
        "C:/Users/lenevo/OneDrive/Desktop/IP/RadDet-1T-128/RadDet40k128HW001Tv2/images",
    )
    label_dir = os.environ.get(
        "RADDET_LABEL_DIR",
        "C:/Users/lenevo/OneDrive/Desktop/IP/RadDet-1T-128/RadDet40k128HW001Tv2/labels",
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Anchors normalized by the input size, matching the normalized label w/h used by the loss
    scaled_anchors = (
        torch.tensor(ANCHORS) / torch.tensor([IMG_SIZE, IMG_SIZE]).view(1, 1, 2)
    ).to(device)

    dataset = RadarDataset(image_dir, label_dir, anchors=ANCHORS, S=S, C=NUM_CLASSES)
    # On Windows, DataLoader workers are spawned and cannot import classes defined
    # inside a notebook, so load in the main process there.
    num_workers = 0 if os.name == "nt" else 2
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=device.type == "cuda",  # pinning only helps host->GPU copies
    )

    model = YOLOv3(in_channels=1, num_classes=NUM_CLASSES).to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
    criterion = YoloLoss()

    # train_model already runs EPOCHS epochs. Wrapping it in another EPOCHS loop
    # trained for EPOCHS**2 epochs.
    train_model(model, loader, optimizer, criterion, device, scaled_anchors)

    # NOTE(review): this evaluates on the training loader; use a held-out split for a real estimate.
    print("\nFinal Evaluation: ")
    check_accuracy(loader, model, device)

    print("\nVisualizing a sample prediction: ")
    x, y = next(iter(loader))
    x = x.to(device)

    # get_all_bboxes switches the model to eval mode for inference itself
    all_preds, _ = get_all_bboxes(
        [(x, y)], model, IOU_THRESHOLD, CONFIDENCE_THRESHOLD, ANCHORS, device
    )

    # Filter boxes for the first image in the batch
    boxes_for_image_0 = [box[1:] for box in all_preds if box[0] == 0]

    plot_image(x[0], boxes_for_image_0)