In this part, we build a YOLO-T model. The model uses a Swin Transformer (loaded from the timm library) as the backbone and a custom YOLO detection head that fuses features from three scales. The detection head follows the YOLOv3 design for multi-scale predictions.
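A quick sanity check on the shapes involved: the sketch computes the grid size and channel count of each head output for a 416x416 input. It assumes Swin-T stages 1-3 downsample by 8, 16 and 32 and that each scale uses three anchors. These values match the 52x52/26x26/13x13 grids noted next to the anchors. `head_output_shapes` is a hypothetical helper, not part of the notebook.

```python
# Shape arithmetic for YOLO-T at a 416x416 input (assumed Swin-T strides 8/16/32).
IMG_SIZE = 416
NUM_CLASSES = 6
ANCHORS_PER_SCALE = 3

# scale name: (stride, backbone channels) for Swin-T stages 1, 2 and 3
STAGES = {"large": (8, 192), "medium": (16, 384), "small": (32, 768)}


def head_output_shapes(img_size=IMG_SIZE, num_classes=NUM_CLASSES,
                       anchors_per_scale=ANCHORS_PER_SCALE):
    """Return {scale: (channels_out, grid, grid)} for the three YOLO head outputs."""
    # Per anchor: 4 box values + 1 objectness score + one score per class.
    channels_out = (num_classes + 5) * anchors_per_scale
    return {name: (channels_out, img_size // stride, img_size // stride)
            for name, (stride, _) in STAGES.items()}


if __name__ == "__main__":
    for name, shape in head_output_shapes().items():
        print(name, shape)
```

With six classes, each scale therefore outputs 33 channels per grid cell.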

---

Ways we could improve on this model:
- Use the latest YOLO version.
- Add the EAOD-Net modifications.
- Try using the features extracted by this model with another model, such as a neural network, a random forest, or another applicable classifier.
- Balance precision and accuracy: train one model with high accuracy and another with high precision, then combine the results from both.
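One simple way to combine two models' outputs, sketched under assumptions: pool both detection lists, optionally re-weight each model's confidences, and run class-aware NMS over the pool. Detections use the `[x1, y1, x2, y2, conf, cls]` layout that the inference code produces. `merge_detections`, `box_iou` and the weights are hypothetical. Weighted box fusion is another option that averages overlapping boxes instead of dropping them.

```python
def box_iou(a, b):
    """IoU of two boxes given as [x1, y1, x2, y2, ...] in pixels."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def merge_detections(dets_a, dets_b, iou_thresh=0.5, weight_a=1.0, weight_b=1.0):
    """Hypothetical ensemble step: pool two models' detections, re-weight their
    confidences, then keep the highest-scoring box among same-class boxes that
    overlap by at least iou_thresh (class-aware NMS)."""
    pooled = [list(d[:4]) + [float(d[4]) * weight_a, int(d[5])] for d in dets_a]
    pooled += [list(d[:4]) + [float(d[4]) * weight_b, int(d[5])] for d in dets_b]
    pooled.sort(key=lambda d: d[4], reverse=True)
    kept = []
    for det in pooled:
        # Keep det unless an already-kept box of the same class covers it.
        if all(k[5] != det[5] or box_iou(k, det) < iou_thresh for k in kept):
            kept.append(det)
    return kept


dets_precise = [[10, 10, 50, 50, 0.9, 0]]
dets_accurate = [[12, 12, 52, 52, 0.6, 0], [100, 100, 140, 140, 0.7, 1]]
merged = merge_detections(dets_precise, dets_accurate)
# The overlapping class-0 box from the second model is dropped; the class-1 box is kept.
```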

In [None]:
import os

import torch
from torch.utils.data import DataLoader

from datasets import get_dataloader, custom_collate_fn

# =============================================================================
# Adjust these paths according to your folder structure.
# =============================================================================
if __name__ == '__main__':
    # For the train_test_easy split
    base_dir = "/Users/jamesngugi/Desktop/Applied ML/ML-Project/test-data"
    
    # Use the CSV files in the easy split folder:
    csv_train = os.path.join(base_dir, "TestTrainSplits", "train_test_easy", "train-3000.csv")
    csv_test  = os.path.join(base_dir, "TestTrainSplits", "train_test_easy", "test-3000.csv")
    
    # Directory containing JPEG images.
    images_dir = os.path.join(base_dir, "JPEGImageFull", "dataset", "JPEGImage")
    # Directory containing positive XML annotations.
    annotations_dir = os.path.join(base_dir, "positive-Annotation")
    
    # DataLoaders for training and testing.
    train_loader = get_dataloader(csv_train, images_dir, annotations_dir, batch_size=32, train=True)
    test_loader  = get_dataloader(csv_test, images_dir, annotations_dir, batch_size=32, train=False)
    
    # Images contain different numbers of boxes, so the targets cannot be stacked into a
    # single tensor. Rebuild the training loader with custom_collate_fn, which returns a
    # stacked image batch plus a list of per-image target tensors. (Alternatively, pass
    # collate_fn=custom_collate_fn to the DataLoader created inside get_dataloader().)
    train_loader = DataLoader(
        train_loader.dataset,
        batch_size=32,
        shuffle=True,
        num_workers=4,
        collate_fn=custom_collate_fn,
    )
    
    # Simple test: iterate through one batch.
    for imgs, targets in train_loader:
        print("Train Images shape:", imgs.shape)  # Expected: [batch, 3, 416, 416]
        print("Train Targets:", targets)  # List of per-image tensors of shape [N, 5]: [class, cx, cy, w, h]
        break
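The loader above reads XML annotations from `positive-Annotation`, but the model and loss expect rows of `[class, x_center, y_center, w, h]` normalized to [0, 1]. The following is a minimal conversion sketch. It *assumes* Pascal-VOC-style XML (`<size><width>`, `<object><name>`, `<bndbox><xmin>...`) and case-insensitive class names matching `cls_names` from the inference cell. `voc_xml_to_yolo` is a hypothetical helper, so check it against the real files before use.

```python
import xml.etree.ElementTree as ET

# Class order copied from cls_names in the inference cell (assumed to match the XML names).
CLASS_NAMES = ['gun', 'knife', 'wrench', 'pliers', 'scissors', 'hammer']


def voc_xml_to_yolo(xml_text, class_names=CLASS_NAMES):
    """Hypothetical helper: parse a VOC-style XML string into YOLO rows
    [class_id, x_center, y_center, w, h], normalized by the image size."""
    root = ET.fromstring(xml_text)
    img_w = float(root.find('size/width').text)
    img_h = float(root.find('size/height').text)
    rows = []
    for obj in root.iter('object'):
        name = obj.find('name').text.strip().lower()
        if name not in class_names:
            continue  # skip labels outside the six classes
        box = obj.find('bndbox')
        x1, y1, x2, y2 = (float(box.find(t).text) for t in ('xmin', 'ymin', 'xmax', 'ymax'))
        rows.append([
            class_names.index(name),
            (x1 + x2) / 2 / img_w,   # x_center
            (y1 + y2) / 2 / img_h,   # y_center
            (x2 - x1) / img_w,       # width
            (y2 - y1) / img_h,       # height
        ])
    return rows


sample = """<annotation>
  <size><width>800</width><height>400</height></size>
  <object><name>Knife</name>
    <bndbox><xmin>100</xmin><ymin>50</ymin><xmax>300</xmax><ymax>150</ymax></bndbox>
  </object>
</annotation>"""
rows = voc_xml_to_yolo(sample)  # [[1, 0.25, 0.25, 0.25, 0.25]]
```

Because the rows are normalized, a plain (stretching) resize to 416x416 leaves them unchanged; letterbox resizing would need an extra offset and scale.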


In [None]:
import os
import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader

# custom_collate_fn is imported from the datasets module in the first cell.

# Rebuild your test loader with the custom collate function.
test_loader = DataLoader(
    test_loader.dataset,
    batch_size=32,
    shuffle=False,
    num_workers=4,
    collate_fn=custom_collate_fn
)

# ---------------------------
# Utility Functions for Evaluation
# ---------------------------
def compute_iou(box1, box2):
    """
    Computes Intersection over Union (IoU) for two boxes.
    Boxes are in the format [x1, y1, x2, y2].
    """
    x1, y1, x2, y2 = box1
    x1g, y1g, x2g, y2g = box2
    
    inter_x1 = max(x1, x1g)
    inter_y1 = max(y1, y1g)
    inter_x2 = min(x2, x2g)
    inter_y2 = min(y2, y2g)
    
    inter_area = max(inter_x2 - inter_x1, 0) * max(inter_y2 - inter_y1, 0)
    area1 = (x2 - x1) * (y2 - y1)
    area2 = (x2g - x1g) * (y2g - y1g)
    union_area = area1 + area2 - inter_area + 1e-6  # avoid division by zero
    return inter_area / union_area

def convert_gt_to_pixels(gt_list, img_size=416):
    """
    Converts a list of ground truth boxes in YOLO format 
    [class, cx, cy, w, h] (normalized) to pixel coordinates [x1, y1, x2, y2, class].
    """
    converted = []
    for gt in gt_list:
        cls, cx, cy, w, h = gt
        x1 = (cx - w/2) * img_size
        y1 = (cy - h/2) * img_size
        x2 = (cx + w/2) * img_size
        y2 = (cy + h/2) * img_size
        converted.append([x1, y1, x2, y2, int(cls)])
    return converted

def evaluate_detections(all_preds, all_gts, iou_threshold=0.5):
    """
    Compares predictions and ground truths across all images.
    For each image, predictions are matched greedily in the order given (highest
    confidence first after NMS): a prediction is a true positive (TP) if it overlaps a
    not-yet-matched ground truth (GT) of the same class with IoU >= iou_threshold.
    Otherwise, it is a false positive (FP).
    Ground truths with no matching prediction are counted as false negatives (FN).
    Returns a dictionary of overall metrics.
    """
    total_TP = 0
    total_FP = 0
    total_FN = 0
    total_iou = 0.0
    iou_count = 0
    
    # Loop over each image.
    for preds, gt in zip(all_preds, all_gts):
        # Convert ground truths to pixel coordinates.
        gts_pixels = convert_gt_to_pixels(gt, img_size=416)
        matched_gts = set()  # indices of ground truths that are already matched
        
        # Process each prediction.
        for pred in preds:
            # pred is [x1, y1, x2, y2, conf, cls] in pixel coordinates.
            cls_pred = int(pred[5])
            best_iou = 0.0
            best_gt_idx = -1
            for idx, gt_box in enumerate(gts_pixels):
                # Only consider unmatched ground truths of the same class.
                if idx in matched_gts or gt_box[4] != cls_pred:
                    continue
                iou_val = compute_iou(pred[:4], gt_box[:4])
                if iou_val > best_iou:
                    best_iou = iou_val
                    best_gt_idx = idx
            
            # A valid match needs an unmatched GT with IoU at or above the threshold.
            if best_gt_idx >= 0 and best_iou >= iou_threshold:
                total_TP += 1
                total_iou += best_iou
                iou_count += 1
                matched_gts.add(best_gt_idx)
            else:
                total_FP += 1
        
        # All GTs not matched are false negatives.
        total_FN += len(gts_pixels) - len(matched_gts)
    
    precision = total_TP / (total_TP + total_FP) if (total_TP + total_FP) > 0 else 0
    recall    = total_TP / (total_TP + total_FN) if (total_TP + total_FN) > 0 else 0
    f1_score  = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    avg_iou   = total_iou / iou_count if iou_count > 0 else 0
    
    return {
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
        "avg_iou": avg_iou,
        "true_positives": total_TP,
        "false_positives": total_FP,
        "false_negatives": total_FN,
    }

# ---------------------------
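# Decoding for the Simple Single-Scale Model
# ---------------------------
def decode_predictions(preds_tuple, conf_thresh=0.25):
    """
    Decodes the outputs of the simple single-scale model for ONE image.
    preds_tuple is (bbox_pred [1, H, W, 4], objectness_pred [1, H, W, 1],
    class_pred [1, H, W, num_classes]). A fixed 0.5 x 0.5 "dummy" anchor is used,
    so this does not decode the anchor-based multi-scale output of YOLOTModel.
    Returns a list of detections, each [x1, y1, x2, y2, conf, cls] in 416x416 pixels.
    """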
    bbox_pred, objectness_pred, class_pred = preds_tuple # Unpack the tuple of tensors
    bbox_pred = bbox_pred.detach().cpu() # [1, H, W, 4] - assuming batch size 1 passed here
    objectness_pred = objectness_pred.detach().cpu() # [1, H, W, 1]
    class_pred = class_pred.detach().cpu() # [1, H, W, num_classes]

    if bbox_pred.ndim != 4: # Check expected ndim for bbox
        raise ValueError(f"Expected bbox_pred to have 4 dimensions (B, H, W, 4), got {bbox_pred.ndim}")
    B, grid_h, grid_w, _ = bbox_pred.shape # Get grid size from bbox_pred
    if grid_h != grid_w:
        raise ValueError(f"Expected square grid but got {grid_h} and {grid_w}")
    grid_size = grid_h
    num_classes = class_pred.shape[-1] # Get num_classes from class_pred

    detections = []

    # Cell offsets: grid_x[0, i, j] = j (column index), grid_y[0, i, j] = i (row index).
    grid_x = torch.arange(grid_size).repeat(grid_size, 1).view(1, grid_size, grid_size).float()
    grid_y = torch.arange(grid_size).repeat(grid_size, 1).t().view(1, grid_size, grid_size).float()


    # Convert raw outputs to normalized (cx, cy, w, h) using the fixed dummy anchor.
    box = torch.zeros_like(bbox_pred[..., :4]) # Initialize box tensor [1, H, W, 4]
    box[..., 0] = (torch.sigmoid(bbox_pred[..., 0]) + grid_x) / grid_size # cx
    box[..., 1] = (torch.sigmoid(bbox_pred[..., 1]) + grid_y) / grid_size # cy

    dummy_anchor = torch.tensor([0.5, 0.5]).view(1, 1, 1, 2).type_as(bbox_pred) # Dummy anchor
    box[..., 2] = dummy_anchor[..., 0] * torch.exp(bbox_pred[..., 2]) # w
    box[..., 3] = dummy_anchor[..., 1] * torch.exp(bbox_pred[..., 3]) # h

    conf = torch.sigmoid(objectness_pred[..., 0]) # Objectness confidence
    cls_prob = torch.softmax(class_pred, dim=-1) # Class probabilities
    cls_conf, cls_pred = torch.max(cls_prob, dim=-1) # Max class prob and class index
    final_conf = conf * cls_conf # Final confidence score


    # Keep grid cells whose combined confidence exceeds the threshold.
    for i in range(grid_size):
        for j in range(grid_size):
            if final_conf[0, i, j] > conf_thresh: # Check confidence threshold
                x_center, y_center, w, h = box[0, i, j] # Get box params
                # Scale to the 416x416 network input.
                x1 = (x_center - w/2) * 416
                y1 = (y_center - h/2) * 416
                x2 = (x_center + w/2) * 416
                y2 = (y_center + h/2) * 416
                detections.append([
                    x1.item(), y1.item(), x2.item(), y2.item(),
                    final_conf[0, i, j].item(), cls_pred[0, i, j].item()
                ])
    return detections

# ---------------------------
# Inference Function with Metrics Collection
# ---------------------------
def inference(model, dataloader, device, output_dir='output', conf_thresh=0.3, iou_threshold=0.5):
    """
    Runs inference on the given dataloader, draws the detections on each image, and
    evaluates precision, recall, F1 and average IoU. Writing the annotated images to
    output_dir is disabled (the cv2.imwrite call is commented out).
    """
    model.eval()
    os.makedirs(output_dir, exist_ok=True)
    cls_names = ['gun', 'knife', 'wrench', 'pliers', 'scissors', 'hammer']
    
    all_preds = []  # List to store predictions for each image.
    all_gts = []    # List to store ground truth boxes for each image.
    
    with torch.no_grad():
        for idx, (images, targets) in enumerate(dataloader):
            images = images.to(device)
            preds_scales = model(images)  # Simple model output: (bbox_pred, objectness_pred, class_pred)

            for i in range(images.size(0)):
                # Prepare image for drawing (assumes pixel values in [0, 1]).
                img = images[i].cpu().permute(1, 2, 0).numpy()
                img = (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)
                img = np.ascontiguousarray(img)  # cv2 drawing needs a contiguous array.

                # Slice out the i-th image from each output tensor and decode it.
                bbox_pred, objectness_pred, class_pred = preds_scales
                img_preds = (bbox_pred[i:i+1], objectness_pred[i:i+1], class_pred[i:i+1])
                dets = decode_predictions(img_preds, conf_thresh=conf_thresh)
                # Perform Non-Maximum Suppression.
                dets = non_max_suppression(dets, conf_thresh=conf_thresh, iou_thresh=iou_threshold)
                
                # Save detections for this image.
                all_preds.append(dets)
                # Store ground truths. targets[i] is a tensor of shape [N, 5].
                gt_list = targets[i].cpu().numpy().tolist()
                all_gts.append(gt_list)
                
                # Draw detections on the image.
                for det in dets:
                    x1, y1, x2, y2, conf, cls = det
                    cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
                    cv2.putText(img, f"{cls_names[int(cls)]} {conf:.2f}", (int(x1), int(y1)-10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
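                # To save the annotated image, uncomment these lines (OpenCV writes BGR).
                # out_path = os.path.join(output_dir, f"result_{idx}_{i}.jpg")
                # cv2.imwrite(out_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
                # print(f"Saved detection result to {out_path}")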
    
    # After processing all images, evaluate detections.
    metrics = evaluate_detections(all_preds, all_gts, iou_threshold=iou_threshold)
    print("Evaluation Metrics:")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall:    {metrics['recall']:.4f}")
    print(f"F1 Score:  {metrics['f1_score']:.4f}")
    print(f"Avg IoU:   {metrics['avg_iou']:.4f}")
    print(f"TP: {metrics['true_positives']}  FP: {metrics['false_positives']}  FN: {metrics['false_negatives']}")
    return metrics

# ---------------------------
# Non-Maximum Suppression Function
# ---------------------------
def non_max_suppression(detections, conf_thresh=0.25, iou_thresh=0.5):
    """
    Applies class-agnostic Non-Maximum Suppression (NMS) to detections given as
    [x1, y1, x2, y2, conf, cls] rows. Returns the kept boxes, highest confidence first.
    """
    if len(detections) == 0:
        return []
    detections = np.array(detections)
    # Filter detections with confidence lower than the threshold.
    detections = detections[detections[:, 4] >= conf_thresh]
    if len(detections) == 0:
        return []
    # Sort detections by confidence (highest first).
    indices = np.argsort(-detections[:, 4])
    detections = detections[indices]
    final_dets = []
    while len(detections) > 0:
        best = detections[0]
        final_dets.append(best)
        if len(detections) == 1:
            break
        rest = detections[1:]
        x1 = np.maximum(best[0], rest[:, 0])
        y1 = np.maximum(best[1], rest[:, 1])
        x2 = np.minimum(best[2], rest[:, 2])
        y2 = np.minimum(best[3], rest[:, 3])
        inter_area = np.maximum(0, x2 - x1) * np.maximum(0, y2 - y1)
        best_area = (best[2] - best[0]) * (best[3] - best[1])
        rest_area = (rest[:, 2] - rest[:, 0]) * (rest[:, 3] - rest[:, 1])
        iou = inter_area / (best_area + rest_area - inter_area + 1e-6)
        detections = rest[iou < iou_thresh]
    return final_dets

# ---------------------------
# Main Inference Execution
# ---------------------------
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Assume model and test_loader are defined elsewhere. decode_predictions() expects the
    # simple single-scale model that returns (bbox_pred, objectness_pred, class_pred);
    # YOLOTModel returns three raw multi-scale maps and needs an anchor-based decoder.
    
    checkpoint_path = 'trained_models_simple/simple_model_state_3000.pth'
    if os.path.exists(checkpoint_path):
        # Load the saved state dict; weights_only=True avoids unpickling arbitrary objects.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
        print("Loaded checkpoint for inference.")
    else:
        print("Checkpoint not found. Using current model weights.")
    
    # Run inference and evaluation.
    inference(model, test_loader, device, output_dir='output')

Loaded checkpoint for inference.




Evaluation Metrics:
Precision: 0.0000
Recall:    0.0000
F1 Score:  0.0000
Avg IoU:   0.0000
TP: 0  FP: 13  FN: 3
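`decode_predictions` above handles the simple model's per-cell output with a fixed dummy anchor. The `YOLOTModel` heads in the next cell instead emit `(num_classes + 5) * 3` channels per cell. `YOLOLoss` reads these as `view(B, 3, num_classes + 5, grid, grid)` and decodes them with cx = (sigmoid(tx) + col) / grid and w = anchor_w * exp(tw). Here is a pure-Python sketch of decoding one anchor in one cell with that parametrization. `decode_anchor` is a hypothetical helper. For a real head output `p`, the raw vector would be `p.view(B, 3, 5 + C, g, g)[b, a, :, row, col]`.

```python
import math


def _sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def decode_anchor(raw, col, row, grid_size, anchor, img_size=416):
    """Hypothetical helper: decode one anchor's raw outputs in one grid cell.

    raw:    [tx, ty, tw, th, t_obj, *class_logits] for a single anchor.
    anchor: (w, h) normalized to the image, as in the ANCHORS dictionary.
    Returns [x1, y1, x2, y2, conf, cls] in img_size pixels, where
    conf = sigmoid(t_obj) * (max softmax class probability).
    """
    tx, ty, tw, th, t_obj = raw[:5]
    class_logits = raw[5:]
    # Same parametrization as YOLOLoss: offset inside the cell, anchor-scaled size.
    cx = (_sigmoid(tx) + col) / grid_size
    cy = (_sigmoid(ty) + row) / grid_size
    w = anchor[0] * math.exp(tw)
    h = anchor[1] * math.exp(th)
    # Numerically stable softmax over the class logits.
    m = max(class_logits)
    exps = [math.exp(c - m) for c in class_logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    cls = max(range(len(probs)), key=probs.__getitem__)
    conf = _sigmoid(t_obj) * probs[cls]
    return [(cx - w / 2) * img_size, (cy - h / 2) * img_size,
            (cx + w / 2) * img_size, (cy + h / 2) * img_size, conf, cls]


# All-zero raw outputs in the center cell of a 13x13 grid with a 0.25 x 0.25 anchor
# decode to a box from (156, 156) to (260, 260) with confidence 0.5 / 6.
box = decode_anchor([0.0] * 11, col=6, row=6, grid_size=13, anchor=(0.25, 0.25))
```

The resulting boxes can go straight into `non_max_suppression` and `evaluate_detections`, which use the same `[x1, y1, x2, y2, conf, cls]` layout.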


In [None]:
import math

import timm
import torch
import torch.nn as nn
import torch.optim as optim

# Anchors for each scale as (w, h), normalized to the 416x416 input like the targets.
ANCHORS = {
    'large':  [(0.10, 0.13), (0.16, 0.30), (0.33, 0.23)],  # for 52x52
    'medium': [(0.22, 0.27), (0.38, 0.56), (0.95, 0.80)],  # for 26x26
    'small':  [(0.90, 1.10), (1.87, 3.23), (4.42, 2.74)]   # for 13x13
}
# Note: these anchor values are placeholders. Because targets are normalized to [0, 1],
# anchors larger than 1.0 (as in 'small') describe boxes bigger than the image and
# should be replaced. In a production setup, we should run k-means on the SIXray boxes.

def sigmoid(x):
    return 1 / (1 + torch.exp(-x))

# The preprocessing stage must output images as tensors and targets as tensors of
# shape [N, 5], where each row is [class, x_center, y_center, w, h], all normalized.

class YOLOLoss(nn.Module):
    def __init__(self, anchors, num_classes, img_dim=416, ignore_thresh=0.5,
                 lambda_coord=5.0, lambda_noobj=0.5):
        """
        anchors: list of (w, h) for this scale (normalized)
        num_classes: number of classes
        img_dim: input image dimension (assumed square)
        ignore_thresh: IoU threshold for ignoring objectness loss in no-object cells
        lambda_coord: weight for coordinate loss
        lambda_noobj: weight for no-object confidence loss
        """
        super(YOLOLoss, self).__init__()
        self.anchors = anchors  # for one scale
        self.num_anchors = len(anchors)
        self.num_classes = num_classes
        self.img_dim = img_dim
        self.ignore_thresh = ignore_thresh
        self.lambda_coord = lambda_coord
        self.lambda_noobj = lambda_noobj
        self.mse_loss = nn.MSELoss(reduction='sum')
        self.bce_loss = nn.BCELoss(reduction='sum')
        self.ce_loss = nn.CrossEntropyLoss(reduction='sum')
    
    def forward(self, prediction, targets):
        """
        prediction: tensor of shape [batch, (5+num_classes)*num_anchors, grid, grid]
        targets: list of targets for each image; each target is a tensor of shape [N, 5],
                 with [cls, x_center, y_center, w, h] in normalized coordinates.
        """
        batch_size = prediction.size(0)
        grid_size = prediction.size(2)  # square grid
        stride = self.img_dim / grid_size
        
        prediction = prediction.view(batch_size, self.num_anchors, self.num_classes + 5, grid_size, grid_size)
        prediction = prediction.permute(0, 1, 3, 4, 2).contiguous()  # shape: [B, A, grid, grid, 5+num_classes]
        
        # Get outputs
        pred_tx = prediction[..., 0]  # center x
        pred_ty = prediction[..., 1]  # center y
        pred_tw = prediction[..., 2]  # width
        pred_th = prediction[..., 3]  # height
        pred_conf = prediction[..., 4]  # objectness
        pred_cls = prediction[..., 5:]  # class scores
        
        # Create grid offsets
        grid_x = torch.arange(grid_size).repeat(grid_size, 1).view([1, 1, grid_size, grid_size]).type_as(prediction)
        grid_y = torch.arange(grid_size).repeat(grid_size, 1).t().view([1, 1, grid_size, grid_size]).type_as(prediction)
        
        # Transform predictions to bounding box coordinates
        # According to YOLOv3: 
        # x = sigmoid(tx) + grid_x, similarly for y.
        # w = anchor_w * exp(tw), h = anchor_h * exp(th)
        pred_boxes = torch.zeros(prediction[..., :4].shape).type_as(prediction)
        pred_boxes[..., 0] = (sigmoid(pred_tx) + grid_x) / grid_size
        pred_boxes[..., 1] = (sigmoid(pred_ty) + grid_y) / grid_size
        # Prepare anchors tensor
        anchors_tensor = torch.tensor(self.anchors).type_as(prediction)  # shape: [num_anchors, 2]
        anchors_tensor = anchors_tensor.view(1, self.num_anchors, 1, 1, 2)
        pred_boxes[..., 2] = anchors_tensor[..., 0] * torch.exp(pred_tw)
        pred_boxes[..., 3] = anchors_tensor[..., 1] * torch.exp(pred_th)
        
        # Build training targets and masks, shape [B, A, grid, grid, 5 + num_classes].
        target_tensor = torch.zeros_like(prediction)
        obj_mask = torch.zeros(batch_size, self.num_anchors, grid_size, grid_size,
                               dtype=torch.bool, device=prediction.device)
        noobj_mask = torch.ones_like(obj_mask)
        anchor_wh = torch.tensor(self.anchors, dtype=prediction.dtype, device=prediction.device)
        pred_boxes_d = pred_boxes.detach()  # only used to decide which negatives to ignore
        
        for b in range(batch_size):
            if targets[b].nelement() == 0:
                continue
            for target in targets[b]:
                # target: [cls, x, y, w, h], normalized to [0, 1]
                cls = int(target[0])
                x, y, w, h = (float(v) for v in target[1:5])
                # Grid cell containing the box center (clamped so x or y == 1.0 stays in range).
                i = min(int(x * grid_size), grid_size - 1)
                j = min(int(y * grid_size), grid_size - 1)
                
                # Best anchor by shape-only IoU (target and anchors share a center).
                inter = torch.clamp(anchor_wh[:, 0], max=w) * torch.clamp(anchor_wh[:, 1], max=h)
                union = w * h + anchor_wh[:, 0] * anchor_wh[:, 1] - inter
                best_anchor = int(torch.argmax(inter / (union + 1e-6)))
                
                # Do not penalize as "no object" any prediction that already overlaps
                # this target by more than ignore_thresh.
                pb = pred_boxes_d[b]  # [A, grid, grid, 4] as (cx, cy, w, h)
                iw = (torch.clamp(pb[..., 0] + pb[..., 2] / 2, max=x + w / 2)
                      - torch.clamp(pb[..., 0] - pb[..., 2] / 2, min=x - w / 2)).clamp(min=0)
                ih = (torch.clamp(pb[..., 1] + pb[..., 3] / 2, max=y + h / 2)
                      - torch.clamp(pb[..., 1] - pb[..., 3] / 2, min=y - h / 2)).clamp(min=0)
                iou_pred = iw * ih / (pb[..., 2] * pb[..., 3] + w * h - iw * ih + 1e-6)
                noobj_mask[b][iou_pred > self.ignore_thresh] = False
                
                # Assign the target to this grid cell and anchor.
                obj_mask[b, best_anchor, j, i] = True
                noobj_mask[b, best_anchor, j, i] = False
                target_tensor[b, best_anchor, j, i, 0] = x * grid_size - i  # target for sigmoid(tx)
                target_tensor[b, best_anchor, j, i, 1] = y * grid_size - j  # target for sigmoid(ty)
                # Inverse of w = anchor_w * exp(tw): the target for tw is log(w / anchor_w).
                target_tensor[b, best_anchor, j, i, 2] = math.log(w / self.anchors[best_anchor][0] + 1e-6)
                target_tensor[b, best_anchor, j, i, 3] = math.log(h / self.anchors[best_anchor][1] + 1e-6)
                target_tensor[b, best_anchor, j, i, 4] = 1.0  # object exists
                target_tensor[b, best_anchor, j, i, 5 + cls] = 1.0  # one-hot class
        
        # Localization loss on assigned cells (YOLOv3-style: MSE on sigmoid(tx), sigmoid(ty), tw, th).
        loss_x = self.mse_loss(sigmoid(pred_tx[obj_mask]), target_tensor[..., 0][obj_mask])
        loss_y = self.mse_loss(sigmoid(pred_ty[obj_mask]), target_tensor[..., 1][obj_mask])
        loss_w = self.mse_loss(pred_tw[obj_mask], target_tensor[..., 2][obj_mask])
        loss_h = self.mse_loss(pred_th[obj_mask], target_tensor[..., 3][obj_mask])
        loss_coord = self.lambda_coord * (loss_x + loss_y + loss_w + loss_h)
        
        # Objectness loss: positives on assigned cells, down-weighted negatives elsewhere.
        conf_prob = sigmoid(pred_conf)
        loss_conf_obj = self.bce_loss(conf_prob[obj_mask], target_tensor[..., 4][obj_mask])
        loss_conf_noobj = self.bce_loss(conf_prob[noobj_mask], target_tensor[..., 4][noobj_mask])
        loss_conf = loss_conf_obj + self.lambda_noobj * loss_conf_noobj
        
        # Classification loss: cross-entropy on raw class logits for assigned cells.
        if obj_mask.any():
            target_cls_index = torch.argmax(target_tensor[..., 5:][obj_mask], dim=-1)
            loss_cls = self.ce_loss(pred_cls[obj_mask], target_cls_index)
        else:
            loss_cls = torch.tensor(0.0).type_as(prediction)
        
        total_loss = loss_coord + loss_conf + loss_cls
        return total_loss

# Define the YOLO Head and YOLO-T Model.
class YOLOHead(nn.Module):
    def __init__(self, num_classes=6, in_channels=[192, 384, 768]):
        super(YOLOHead, self).__init__()
        self.conv_small = nn.Sequential(
            nn.Conv2d(in_channels[2], 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1),
            nn.Conv2d(256, (num_classes + 5) * 3, kernel_size=1)
        )
        self.conv_medium_upsample = nn.Conv2d(in_channels[2], 128, kernel_size=1)
        self.conv_medium = nn.Sequential(
            nn.Conv2d(in_channels[1] + 128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1),
            nn.Conv2d(256, (num_classes + 5) * 3, kernel_size=1)
        )
        # fused_medium has in_channels[1] + 128 channels (384 + 128 = 512 with Swin-T).
        self.conv_large_upsample = nn.Conv2d(in_channels[1] + 128, 64, kernel_size=1)
        self.conv_large = nn.Sequential(
            nn.Conv2d(in_channels[0] + 64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1),
            nn.Conv2d(128, (num_classes + 5) * 3, kernel_size=1)
        )
    def forward(self, feats):
        large, medium, small = feats[0], feats[1], feats[2]
        pred_small = self.conv_small(small)
        up_small = nn.functional.interpolate(small, scale_factor=2, mode='nearest')
        up_small = self.conv_medium_upsample(up_small)
        fused_medium = torch.cat([up_small, medium], dim=1)
        pred_medium = self.conv_medium(fused_medium)
        up_medium = nn.functional.interpolate(fused_medium, scale_factor=2, mode='nearest')
        up_medium = self.conv_large_upsample(up_medium)
        fused_large = torch.cat([up_medium, large], dim=1)
        pred_large = self.conv_large(fused_large)
        return [pred_large, pred_medium, pred_small]

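class YOLOTModel(nn.Module):
    def __init__(self, num_classes=6, img_size=416):
        super(YOLOTModel, self).__init__()
        # Swin-T stages 1-3 (strides 8/16/32) give 52x52, 26x26 and 13x13 maps at 416x416.
        # Assumption: a recent timm release whose Swin accepts img_size != 224 and pads
        # feature maps that are not a multiple of the 7x7 window; older releases only
        # accept 224x224 inputs.
        self.backbone = timm.create_model('swin_tiny_patch4_window7_224', pretrained=True,
                                          features_only=True, out_indices=(1, 2, 3),
                                          img_size=img_size)
        self.channels = [192, 384, 768]
        self.head = YOLOHead(num_classes=num_classes, in_channels=self.channels)
    def forward(self, x):
        feats = self.backbone(x)
        # timm may return Swin features channels-last (NHWC); convert those to NCHW.
        feats = [f if f.shape[1] == c else f.permute(0, 3, 1, 2).contiguous()
                 for f, c in zip(feats, self.channels)]
        preds = self.head(feats)
        return preds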

# Instantiate model and loss functions for each scale.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = YOLOTModel(num_classes=6).to(device)

# Create one YOLOLoss instance per scale using corresponding anchors.
loss_large = YOLOLoss(ANCHORS['large'], num_classes=6, img_dim=416, ignore_thresh=0.5, 
                        lambda_coord=5.0, lambda_noobj=0.5)
loss_medium = YOLOLoss(ANCHORS['medium'], num_classes=6, img_dim=416, ignore_thresh=0.5, 
                         lambda_coord=5.0, lambda_noobj=0.5)
loss_small = YOLOLoss(ANCHORS['small'], num_classes=6, img_dim=416, ignore_thresh=0.5, 
                        lambda_coord=5.0, lambda_noobj=0.5)

# Training loop that combines losses from all scales.
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
def train_model(model, dataloader, optimizer, device, num_epochs=50):
    model.train()
    for epoch in range(num_epochs):
        total_loss_epoch = 0.0
        for batch_idx, (images, targets) in enumerate(dataloader):
            images = images.to(device)
            optimizer.zero_grad()
            # Get predictions from model; each element in preds is for one scale.
            preds = model(images)  # list of three tensors
            # Compute loss for each scale
            loss_l = loss_large(preds[0], targets)
            loss_m = loss_medium(preds[1], targets)
            loss_s = loss_small(preds[2], targets)
            loss = loss_l + loss_m + loss_s
            loss.backward()
            optimizer.step()
            total_loss_epoch += loss.item()
            if batch_idx % 10 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}] Batch {batch_idx}/{len(dataloader)} Loss: {loss.item():.4f}")
        avg_loss = total_loss_epoch / len(dataloader)
        print(f"Epoch [{epoch+1}/{num_epochs}] Average Loss: {avg_loss:.4f}")
        torch.save(model.state_dict(), f'yolot_epoch_{epoch+1}.pth')

print("Starting Training...")
train_model(model, train_loader, optimizer, device, num_epochs=50) # train_loader is the dataloader with the train dataset

# Outputs of this stage: bounding boxes, labels of the identified dangerous goods, and other detection data.
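The anchor cell notes that production anchors should come from k-means on the SIXray boxes. Here is a minimal sketch of YOLO-style anchor clustering on normalized `(w, h)` pairs, using 1 - IoU as the distance (boxes compared as if they shared a center) and the cluster mean as the update. `kmeans_anchors` and `shape_iou` are hypothetical helpers. Deterministic area-quantile initialization is a choice made here for reproducibility.

```python
def shape_iou(wh1, wh2):
    """IoU of two boxes that share the same center, given only (w, h)."""
    inter = min(wh1[0], wh2[0]) * min(wh1[1], wh2[1])
    return inter / (wh1[0] * wh1[1] + wh2[0] * wh2[1] - inter)


def kmeans_anchors(wh, k=9, iters=100):
    """Cluster normalized (w, h) pairs into k anchors using 1 - IoU as the distance.
    Initial centroids are evenly spaced boxes from the area-sorted list, so the
    result is deterministic. Returns anchors sorted by area (smallest first)."""
    boxes = sorted(wh, key=lambda b: b[0] * b[1])
    centroids = [tuple(boxes[int(i * (len(boxes) - 1) / max(k - 1, 1))]) for i in range(k)]
    for _ in range(iters):
        clusters = [[] for _ in range(k)]
        for b in boxes:
            # Nearest centroid = highest IoU (lowest 1 - IoU).
            best = max(range(k), key=lambda c: shape_iou(b, centroids[c]))
            clusters[best].append(b)
        new = []
        for c, members in enumerate(clusters):
            if members:  # mean width and height of the cluster
                new.append((sum(m[0] for m in members) / len(members),
                            sum(m[1] for m in members) / len(members)))
            else:        # keep an empty cluster's centroid where it was
                new.append(centroids[c])
        if new == centroids:
            break
        centroids = new
    return sorted(centroids, key=lambda a: a[0] * a[1])


demo = kmeans_anchors([(0.05, 0.05), (0.06, 0.06), (0.5, 0.5), (0.6, 0.6)], k=2)
```

With k=9, the three smallest anchors would go to `'large'` (52x52), the next three to `'medium'` (26x26) and the three largest to `'small'` (13x13), matching the layout of `ANCHORS`.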

### Code Documentation and Summary of What Each Stage Needs From My Code
### Requirements From the Data Preprocessing Stage

1. **Dataset Format:**
   - **Images:**  
     - Images should be read from a directory (e.g., `data/sixray/images/train` for training and a corresponding test folder).  
     - They are assumed to be grayscale X-ray images. (Your partner can store them as either grayscale or RGB; grayscale images are converted to 3-channel RGB.)
   - **Annotations:**  
     - Each image must have a corresponding text file (with the same base filename) in a dedicated labels directory (e.g., `data/sixray/labels/train`).  
     - The label file must be in **YOLO format** – each line should be:  
       ```
       class_id x_center y_center width height
       ```  
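       where all coordinates are normalized to the [0, 1] range. (The loader in the first cell currently reads XML annotations from `positive-Annotation` for the images listed in the CSV split files, so those boxes must be converted to this format before they reach the model.)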
   - **DataLoader Compatibility:**  
     - The preprocessing code should output a PyTorch DataLoader where each sample is a tuple:  
       - **Image:** a tensor of shape `[3, 416, 416]` (i.e., resized to 416×416, normalized to [0,1]).  
       - **Targets:** a tensor of shape `[N, 5]` per image (each row is `[class, x_center, y_center, w, h]` in normalized format).  
     - If no ground truth exists for an image, targets should be an empty tensor.

2. **Augmentations:**  
   - The partner’s preprocessing should apply augmentations such as resizing (to 416×416), horizontal flipping, brightness adjustments, etc.
   - The augmentation process must also appropriately transform the bounding boxes in YOLO format.
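For horizontal flipping, the box transform in YOLO format is a single change: x_center becomes 1 - x_center, while y_center, width and height stay the same. Here is a minimal sketch of that target update. `hflip_yolo_targets` is a hypothetical helper, and the image itself must be flipped in the same step.

```python
def hflip_yolo_targets(targets):
    """Mirror YOLO rows [class, x_center, y_center, w, h] for a horizontally flipped
    image: only x_center changes (x -> 1 - x); y_center, w and h stay the same."""
    return [[cls, 1.0 - cx, cy, w, h] for cls, cx, cy, w, h in targets]


# The image is flipped alongside, e.g. img[:, ::-1] for an H x W x C NumPy array
# or torch.flip(img, dims=[2]) for a C x H x W tensor.
rows = [[1, 0.25, 0.5, 0.1, 0.2]]
flipped = hflip_yolo_targets(rows)  # [[1, 0.75, 0.5, 0.1, 0.2]]
```

Brightness and contrast changes leave the boxes untouched. A plain resize to 416x416 also leaves normalized coordinates unchanged.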

---

### Expected Output from the Testing Stage

1. **Inference Output:**  
   - For each test image, the code will generate a list of detections.  
   - Each detection is a bounding box defined as `[x1, y1, x2, y2, confidence, class_id]`, where coordinates are in pixel space relative to the resized 416×416 input image (they must be rescaled if boxes on the original image are needed).

2. **Visualization:**  
   - The testing code draws the bounding boxes (with labels and confidence scores) on each test image. Saving these images to an output folder (e.g., `output/`) is currently commented out in `inference()`.

3. **Metrics Compatibility:**  
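   - The detection outputs (the list of boxes) use a standard format, so they can later be used to compute evaluation metrics such as mAP, either externally or in a subsequent evaluation stage. The inference code already reports precision, recall, F1 and average IoU at a single IoU threshold.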
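As a starting point for mAP, the following is a minimal sketch of per-class average precision with all-point interpolation (the area under the precision envelope). It takes per-image detection lists shaped like `all_preds` in `inference()` and per-image ground truths shaped like the output of `convert_gt_to_pixels`. `average_precision` and `iou_xyxy` are hypothetical helpers; mAP would be the mean of this value over the classes that have ground truths.

```python
def iou_xyxy(a, b):
    """IoU of two boxes given as [x1, y1, x2, y2, ...]."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def average_precision(dets, gts, cls, iou_thresh=0.5):
    """All-point interpolated AP for one class (hypothetical helper).

    dets: per-image lists of [x1, y1, x2, y2, conf, cls].
    gts:  per-image lists of [x1, y1, x2, y2, cls].
    """
    # All detections of this class across images, highest confidence first.
    scored = [(d[4], img, d) for img, ds in enumerate(dets) for d in ds if int(d[5]) == cls]
    scored.sort(key=lambda s: s[0], reverse=True)
    gt_by_img = [[g for g in gs if int(g[4]) == cls] for gs in gts]
    n_gt = sum(len(g) for g in gt_by_img)
    if n_gt == 0:
        return 0.0
    matched = [set() for _ in gt_by_img]
    tp = fp = 0
    precisions, recalls = [], []
    for _, img, d in scored:
        best_iou, best_j = 0.0, -1
        candidates = gt_by_img[img] if img < len(gt_by_img) else []
        for j, g in enumerate(candidates):
            if j in matched[img]:
                continue  # each ground truth can be matched only once
            iou = iou_xyxy(d, g)
            if iou > best_iou:
                best_iou, best_j = iou, j
        if best_j >= 0 and best_iou >= iou_thresh:
            matched[img].add(best_j)
            tp += 1
        else:
            fp += 1
        precisions.append(tp / (tp + fp))
        recalls.append(tp / n_gt)
    # Sum recall steps weighted by the best precision at that recall or beyond.
    ap, prev_recall = 0.0, 0.0
    for k in range(len(precisions)):
        ap += (recalls[k] - prev_recall) * max(precisions[k:])
        prev_recall = recalls[k]
    return ap


gts = [[[0, 0, 10, 10, 0]]]
ap_good = average_precision([[[0, 0, 10, 10, 0.9, 0], [20, 20, 30, 30, 0.8, 0]]], gts, cls=0)
ap_bad = average_precision([[[20, 20, 30, 30, 0.9, 0], [0, 0, 10, 10, 0.8, 0]]], gts, cls=0)
# ap_good is 1.0 (the correct box ranks first); ap_bad is 0.5 (a false positive ranks first).
```

The envelope step is quadratic in the number of detections, which is fine for a sketch. A running maximum computed from the end would make it linear.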

---

### How My Code Works and Integrates with the Preprocessing and Testing Stages

1. **Input from Preprocessing Stage:**
   - The training pipeline receives a DataLoader from the preprocessing stage.  
   - Each batch consists of images (tensors of shape `[B, 3, 416, 416]`) and targets (a list of tensors, with each target tensor of shape `[N, 5]` for that image).
   - The images and annotations are standardized (normalized and augmented) so they can be fed into the model.

2. **Training Flow:**
   - The YOLO-T model first passes the input image batch through the Swin Transformer backbone, which produces multi-scale feature maps.
   - These features go into the custom YOLO head, which fuses different scales and outputs three prediction tensors, each corresponding to a different grid size (large, medium, and small scales).
   - The prediction tensors are then processed by the YOLO loss (one `YOLOLoss` instance per scale, each with its own three anchors). For each scale, this loss:
     - Reshapes the predictions, applies sigmoid and exponential functions, and converts them into bounding box coordinates.
     - Assigns ground truth targets to specific grid cells and anchors.
     - Computes coordinate, objectness, and classification losses that are combined as the final loss.
   - The optimizer then updates the model weights based on this loss.

3. **Testing/Inference Flow:**
   - During inference, the same model (with the learned weights) takes a test image through the backbone and head.
   - The raw predictions are then decoded in the testing code. Decoding consists of:
     - Rearranging the output tensor, applying a sigmoid to the center and objectness predictions, and applying an exponential to the width/height predictions.
     - Converting these normalized values into absolute pixel coordinates.
   - The code applies Non-Maximum Suppression (NMS) to prune overlapping detections, ensuring that each object is represented by a single bounding box.
   - Finally, the resulting boxes (with their confidence scores and predicted class IDs) are drawn on the test image and collected for evaluation. These outputs use a standard format so they can later be used for metric evaluation.

4. **Output for Next Stage:**
   - For testing, the produced images (with drawn bounding boxes) and the detection results (lists of boxes) serve as input for the next code module that might compute evaluation metrics (such as mAP) or for further post-processing.
   - Consistency is maintained because the same coordinate conversion and NMS procedure are used both during evaluation and (if needed) in subsequent visualization or metric calculation stages.
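The target-assignment step in the training flow can be illustrated in isolation. Here is a pure-Python sketch that places one normalized ground-truth box in its grid cell and picks the anchor with the highest shape-only IoU. It then computes the regression targets in the same parametrization `YOLOLoss` uses: the in-cell offset for sigmoid(tx) and sigmoid(ty), and log(w / anchor_w) for tw. `assign_target` is a hypothetical helper, and the anchors in the example are the `'large'` values from the anchor cell.

```python
import math


def assign_target(box, grid_size, anchors):
    """Hypothetical helper mirroring the assignment step in YOLOLoss.

    box:     [cls, x_center, y_center, w, h], normalized to [0, 1].
    anchors: list of (w, h) for this scale, normalized the same way.
    Returns (anchor_idx, row, col, [tx_target, ty_target, tw_target, th_target]).
    """
    _, x, y, w, h = box
    col = min(int(x * grid_size), grid_size - 1)  # cell column containing the center
    row = min(int(y * grid_size), grid_size - 1)  # cell row containing the center

    def shape_iou(aw, ah):
        # IoU of the target and an anchor when both share the same center.
        inter = min(w, aw) * min(h, ah)
        return inter / (w * h + aw * ah - inter)

    best = max(range(len(anchors)), key=lambda a: shape_iou(*anchors[a]))
    aw, ah = anchors[best]
    targets = [x * grid_size - col,   # target for sigmoid(tx): offset inside the cell
               y * grid_size - row,   # target for sigmoid(ty)
               math.log(w / aw),      # target for tw, since w = anchor_w * exp(tw)
               math.log(h / ah)]      # target for th
    return best, row, col, targets


large_anchors = [(0.10, 0.13), (0.16, 0.30), (0.33, 0.23)]
anchor_idx, row, col, t = assign_target([0, 0.5, 0.5, 0.2, 0.3], 52, large_anchors)
# anchor 1 (0.16 x 0.30) fits best; the center lands exactly on the corner of cell (26, 26).
```

Decoding reverses these steps, which is why training and inference must share the same anchors and grid conventions.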
