In [10]:
import pandas as pd

# Define file path to the CSV file.
csv_path = '../test-data/TestTrainSplits/train_test_easy/train.csv'  # update with your CSV file path

# Read the CSV file into a DataFrame.
df = pd.read_csv(csv_path)

# Determine which images are positive or negative.
# Assume positive images start with 'P' and negatives with 'N'.
df['is_positive'] = df['name'].astype(str).str.startswith('P')

# Count the number of positives and negatives.
total_count = len(df)
count_pos = df['is_positive'].sum()
count_neg = total_count - count_pos

# Calculate percentage split.
percent_pos = count_pos / total_count * 100
percent_neg = count_neg / total_count * 100

print(f"Total images: {total_count}")
print(f"Positives: {count_pos} ({percent_pos:.2f}%)")
print(f"Negatives: {count_neg} ({percent_neg:.2f}%)")

# Split the DataFrame into positives and negatives.
positives = df[df['is_positive']]
negatives = df[~df['is_positive']]

# Define desired sample size.
num_sample = 1000

# Determine how many positives and negatives to sample according to the original split.
num_pos_sample = int(round(num_sample * (count_pos / total_count)))
num_neg_sample = num_sample - num_pos_sample

print(f"Sampling {num_pos_sample} positive images and {num_neg_sample} negative images.")

# Randomly sample the desired number from both groups.
# We set a random_state for reproducibility.
sampled_positives = positives.sample(n=num_pos_sample, random_state=42)
sampled_negatives = negatives.sample(n=num_neg_sample, random_state=42)

# Combine the sampled positives and negatives, then shuffle.
sampled_df = pd.concat([sampled_positives, sampled_negatives]).sample(frac=1, random_state=42)

# (Optional) Write the sampled CSV out for later reference.
sampled_csv_path = '../test-data/TestTrainSplits/train_test_easy/train-1000.csv'
sampled_df.to_csv(sampled_csv_path, index=False)
print(f"Saved the sampled dataframe to {sampled_csv_path}")

# Print a quick summary of the sampled set.
sampled_count_pos = sampled_df['is_positive'].sum()
sampled_count_neg = len(sampled_df) - sampled_count_pos
print(f"Sampled Set: {len(sampled_df)} images")
print(f"Positives: {sampled_count_pos} ({sampled_count_pos / len(sampled_df) * 100:.2f}%)")
print(f"Negatives: {sampled_count_neg} ({sampled_count_neg / len(sampled_df) * 100:.2f}%)")

# If needed, loop through the sampled_df rows.
for index, row in sampled_df.iterrows():
    image_name = row['name']
    # Do additional per-image processing if needed:
    print(f"Processing image: {image_name}")


Total images: 374800
Positives: 7496 (2.00%)
Negatives: 367304 (98.00%)
Sampling 20 positive images and 980 negative images.
Saved the sampled dataframe to ../test-data/TestTrainSplits/train_test_easy/train-1000.csv
Sampled Set: 1000 images
Positives: 20 (2.00%)
Negatives: 980 (98.00%)
Processing image: N0507555
Processing image: N0830884
Processing image: N0530829
Processing image: N0458875
Processing image: N0150190
Processing image: N0489370
Processing image: N0254623
Processing image: N0785581
Processing image: N0724860
Processing image: N0036275
Processing image: N1044430
Processing image: N1045703
Processing image: N1012865
Processing image: N0331470
Processing image: N0232399
Processing image: N0269855
Processing image: N0082507
Processing image: N0456783
Processing image: N0387842
Processing image: N0860318
Processing image: N0217403
Processing image: N0439370
Processing image: N0573658
Processing image: N0408663
Processing image: N0548019
Processing image: N0853750
Processing 

In this part, we are creating a YOLO-T model. The model uses a Swin Transformer as the backbone (loaded from the timm library) and a custom YOLO detection head that fuses features from three scales. The detection head mimics the YOLOv3 design for multi-scale predictions.

---

Ways we can try and improve from this model:
- Use the latest YOLO model.
- Add the EAOD-Net modifications.
- Try and use the information gained from the extraction with another model: like neural-net, random forest, and other applicable ones.
- Balance between precision and accuracy. [Training a model that has high accuracy, a model that has high precision, and then putting the results from both of those together.]

In [None]:
import os
import torch
from datasets import get_dataloader, custom_collate_fn

# =============================================================================
# Adjust these paths according to your folder structure.
# =============================================================================
if __name__ == '__main__':
    # For the train_test_easy split
    base_dir = "/Users/jamesngugi/Desktop/Applied ML/ML-Project/test-data"
    
    # Use the CSV files in the easy split folder:
    csv_train = os.path.join(base_dir, "TestTrainSplits", "train_test_easy", "train-1000.csv")
    csv_test  = os.path.join(base_dir, "TestTrainSplits", "train_test_easy", "test-100.csv")
    
    # Directory containing JPEG images.
    images_dir = os.path.join(base_dir, "JPEGImageFull", "dataset", "JPEGImage")
    # Directory containing positive XML annotations.
    annotations_dir = os.path.join(base_dir, "positive-Annotation")
    
    # DataLoaders for training and testing.
    # Pass the custom collate function here:
    train_loader = get_dataloader(csv_train, images_dir, annotations_dir, batch_size=32, train=True)
    test_loader  = get_dataloader(csv_test, images_dir, annotations_dir, batch_size=32, train=False)
    
    # When creating the DataLoader inside get_dataloader, set the collate_fn parameter
    
    # For our testing, we can either modify get_dataloader() or wrap it here:
    from torch.utils.data import DataLoader
    # Reconstruct using our custom_collate_fn for demonstration:
    train_loader = DataLoader(train_loader.dataset, batch_size=32, shuffle=True, num_workers=4, collate_fn=custom_collate_fn)
    
    # Simple test: iterate through one batch.
    for imgs, targets in train_loader:
        print("Train Images shape:", imgs.shape)  # Expected: [batch, 3, 416, 416]
        print("Train Targets:", targets)  # A list, each element a tensor of shape [N, 4] (or [N, 5] if you include classes)
        break


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Train Images shape: torch.Size([32, 3, 416, 416])
Train Targets: [tensor([], size=(0, 5)), tensor([], size=(0, 5)), tensor([], size=(0, 5)), tensor([], size=(0, 5)), tensor([], size=(0, 5)), tensor([], size=(0, 5)), tensor([], size=(0, 5)), tensor([], size=(0, 5)), tensor([], size=(0, 5)), tensor([], size=(0, 5)), tensor([], size=(0, 5)), tensor([], size=(0, 5)), tensor([], size=(0, 5)), tensor([], size=(0, 5)), tensor([], size=(0, 5)), tensor([], size=(0, 5)), tensor([], size=(0, 5)), tensor([], size=(0, 5)), tensor([], size=(0, 5)), tensor([[3.0000, 0.8011, 0.6599, 0.2122, 0.2514],
        [3.0000, 0.8177, 0.5290, 0.1614, 0.1495]]), tensor([], size=(0, 5)), tensor([], size=(0, 5)), tensor([], size=(0, 5)), tensor([], size=(0, 5)), tensor([], size=(0, 5)), tensor([], size=(0, 5)), tensor([], size=(0, 5)), tensor([], size=(0, 5)), tensor([], size=(0, 5)), tensor([], size=(0, 5)), tensor([], size=(0, 5)), tensor([], size=(0, 5))]


# Training

In [None]:
import os
import math
import torch
import torch.nn as nn
import torch.optim as optim
import timm

torch.set_num_threads(os.cpu_count())

# Define anchors for each scale (example values; normalize relative to input size 416)
ANCHORS = {
    'large':  [(0.10, 0.13), (0.16, 0.30), (0.33, 0.23)],  # for 52x52
    'medium': [(0.22, 0.27), (0.38, 0.56), (0.95, 0.80)],  # for 26x26
    'small':  [(0.90, 1.10), (1.87, 3.23), (4.42, 2.74)]   # for 13x13
}

def sigmoid(x):
    return 1 / (1 + torch.exp(-x))

class YOLOLoss(nn.Module):
    def __init__(self, anchors, num_classes, img_dim=416, ignore_thresh=0.5,
                 lambda_coord=5.0, lambda_noobj=0.5):
        """
        anchors: list of (w, h) for this scale (normalized)
        num_classes: number of classes
        img_dim: input image dimension (assumed square)
        ignore_thresh: IoU threshold for ignoring objectness loss in no-object cells
        lambda_coord: weight for coordinate loss
        lambda_noobj: weight for no-object confidence loss
        """
        super(YOLOLoss, self).__init__()
        self.anchors = anchors  # for one scale
        self.num_anchors = len(anchors)
        self.num_classes = num_classes
        self.img_dim = img_dim
        self.ignore_thresh = ignore_thresh
        self.lambda_coord = lambda_coord
        self.lambda_noobj = lambda_noobj
        self.mse_loss = nn.MSELoss(reduction='sum')
        self.bce_loss = nn.BCELoss(reduction='sum')
        self.ce_loss = nn.CrossEntropyLoss(reduction='sum')
    
    def forward(self, prediction, targets):
        """
        prediction: tensor of shape [batch, (5+num_classes)*num_anchors, grid, grid]
        targets: list of targets for each image; each target is a tensor of shape [N, 5],
                 with [cls, x_center, y_center, w, h] in normalized coordinates.
        """
        batch_size = prediction.size(0)
        grid_size = prediction.size(2)  # square grid
        stride = self.img_dim / grid_size
        
        prediction = prediction.view(batch_size, self.num_anchors, self.num_classes + 5, grid_size, grid_size)
        prediction = prediction.permute(0, 1, 3, 4, 2).contiguous()  # shape: [B, A, grid, grid, 5+num_classes]
        
        # Get outputs
        pred_tx = prediction[..., 0]  # center x
        pred_ty = prediction[..., 1]  # center y
        pred_tw = prediction[..., 2]  # width
        pred_th = prediction[..., 3]  # height
        pred_conf = prediction[..., 4]  # objectness
        pred_cls = prediction[..., 5:]  # class scores
        
        # Create grid offsets
        grid_x = torch.arange(grid_size).repeat(grid_size, 1).view([1, 1, grid_size, grid_size]).type_as(prediction)
        grid_y = torch.arange(grid_size).repeat(grid_size, 1).t().view([1, 1, grid_size, grid_size]).type_as(prediction)
        
        # Transform predictions to bounding box coordinates
        # According to YOLOv3: 
        # x = sigmoid(tx) + grid_x, similarly for y.
        # w = anchor_w * exp(tw), h = anchor_h * exp(th)
        pred_boxes = torch.zeros(prediction[..., :4].shape).type_as(prediction)
        pred_boxes[..., 0] = (sigmoid(pred_tx) + grid_x) / grid_size
        pred_boxes[..., 1] = (sigmoid(pred_ty) + grid_y) / grid_size
        # Prepare anchors tensor
        anchors_tensor = torch.tensor(self.anchors).type_as(prediction)  # shape: [num_anchors, 2]
        anchors_tensor = anchors_tensor.view(1, self.num_anchors, 1, 1, 2)
        pred_boxes[..., 2] = anchors_tensor[..., 0] * torch.exp(pred_tw)
        pred_boxes[..., 3] = anchors_tensor[..., 1] * torch.exp(pred_th)
        
        # Convert targets to tensor for matching.
        # For each image, create a target tensor of shape [batch, num_anchors, grid, grid, 5+num_classes]
        target_tensor = torch.zeros_like(prediction)
        # Also create object mask.
        obj_mask = torch.zeros(batch_size, self.num_anchors, grid_size, grid_size).type_as(prediction)
        noobj_mask = torch.ones(batch_size, self.num_anchors, grid_size, grid_size).type_as(prediction)
        class_mask = torch.zeros(batch_size, self.num_anchors, grid_size, grid_size).type_as(prediction)
        t_box = torch.zeros_like(pred_boxes)
        
        for b in range(batch_size):
            if targets[b].nelement() == 0:
                continue
            for target in targets[b]:
                # target: [cls, x, y, w, h]
                cls = target[0]
                x, y, w, h = target[1], target[2], target[3], target[4]
                i = int(x * grid_size)
                j = int(y * grid_size)
                # Find best anchor based on IoU between target and anchors (ignoring grid cell offset)
                gt_box = torch.tensor([0, 0, w, h]).unsqueeze(0)  # center not needed here
                anchor_shapes = torch.cat([torch.zeros((self.num_anchors,2)), anchors_tensor[0, :, 0, 0, :]], dim=1)
                # Compute IoU between gt_box and each anchor box
                inter = torch.min(gt_box[:,2:], anchor_shapes[:,2:]).prod(1)
                union = (gt_box[:,2:]*torch.ones_like(anchor_shapes[:,2:])).prod(1) + anchor_shapes[:,2:].prod(1) - inter
                ious = inter / (union + 1e-6)
                best_anchor = torch.argmax(ious)
                
                # assign ground truth to this grid cell and anchor
                obj_mask[b, best_anchor, j, i] = 1
                noobj_mask[b, best_anchor, j, i] = 0
                target_tensor[b, best_anchor, j, i, 0] = sigmoid(x * grid_size - i)  # target tx
                target_tensor[b, best_anchor, j, i, 1] = sigmoid(y * grid_size - j)  # target ty
                target_tensor[b, best_anchor, j, i, 2] = math.log(w / (self.anchors[best_anchor][0] + 1e-6) + 1e-6)
                target_tensor[b, best_anchor, j, i, 3] = math.log(h / (self.anchors[best_anchor][1] + 1e-6) + 1e-6)
                target_tensor[b, best_anchor, j, i, 4] = 1  # object exists
                # Class one-hot encoding
                target_tensor[b, best_anchor, j, i, 5 + int(cls)] = 1
        
        # Losses:
        # Localization loss (for x, y, w, h)
        loss_x = self.mse_loss(sigmoid(pred_tx) * obj_mask, target_tensor[...,0] * obj_mask)
        loss_y = self.mse_loss(sigmoid(pred_ty) * obj_mask, target_tensor[...,1] * obj_mask)
        loss_w = self.mse_loss(torch.sqrt(torch.abs(pred_boxes[...,2] + 1e-6)) * obj_mask,
                               torch.sqrt(torch.abs(torch.exp(target_tensor[...,2])) * obj_mask))
        loss_h = self.mse_loss(torch.sqrt(torch.abs(pred_boxes[...,3] + 1e-6)) * obj_mask,
                               torch.sqrt(torch.abs(torch.exp(target_tensor[...,3])) * obj_mask))
        loss_coord = self.lambda_coord * (loss_x + loss_y + loss_w + loss_h)
        
        # Confidence loss:
        loss_conf_obj = self.bce_loss(sigmoid(pred_conf) * obj_mask, target_tensor[...,4] * obj_mask)
        loss_conf_noobj = self.lambda_noobj * self.bce_loss(sigmoid(pred_conf) * noobj_mask, 
                                                            target_tensor[...,4] * noobj_mask)
        loss_conf = loss_conf_obj + loss_conf_noobj
        
        # Classification loss:
        # For each cell with an object, use cross entropy loss. Reshape predictions.
        pred_cls = pred_cls[obj_mask.bool()]
        target_cls = target_tensor[..., 5:][obj_mask.bool()]
        # target_cls is one-hot; get the index.
        if pred_cls.nelement() > 0:
            target_cls_index = torch.argmax(target_cls, dim=-1)
            loss_cls = self.ce_loss(pred_cls, target_cls_index)
        else:
            loss_cls = torch.tensor(0.0).type_as(prediction)
        
        total_loss = loss_coord + loss_conf + loss_cls
        return total_loss

# Define the YOLO Head and YOLO-T Model.
class YOLOHead(nn.Module):
    def __init__(self, num_classes=6, in_channels=[192, 384, 768]):
        super(YOLOHead, self).__init__()
        # Small-scale branch: from the smallest (in_channels[2] = 768)
        self.conv_small = nn.Sequential(
            nn.Conv2d(in_channels[2], 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1),
            nn.Conv2d(256, (num_classes + 5) * 3, kernel_size=1)
        )
        # Medium-scale branch: upsample small to fuse with medium.
        self.conv_medium_upsample = nn.Conv2d(in_channels[2], 128, kernel_size=1)
        self.conv_medium = nn.Sequential(
            nn.Conv2d(in_channels[1] + 128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1),
            nn.Conv2d(256, (num_classes + 5) * 3, kernel_size=1)
        )
        # Large-scale branch: upsample the fused medium feature.
        self.conv_large_upsample = nn.Conv2d(512, 64, kernel_size=1)
        self.conv_large = nn.Sequential(
            nn.Conv2d(in_channels[0] + 64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1),
            nn.Conv2d(128, (num_classes + 5) * 3, kernel_size=1)
        )

    def forward(self, feats):
        # feats is a list: [large (from backbone, 192 channels), medium (384), small (768)]
        large, medium, small = feats[0], feats[1], feats[2]
        # Small-scale prediction:
        pred_small = self.conv_small(small)
        # Upsample small feature and fuse it with medium feature:
        up_small = nn.functional.interpolate(small, scale_factor=2, mode='nearest')
        up_small = self.conv_medium_upsample(up_small)  # outputs 128 channels
        fused_medium = torch.cat([up_small, medium], dim=1)  # shape becomes [B, 128+384=512, ...]
        pred_medium = self.conv_medium(fused_medium)
        # For large-scale prediction: upsample the fused medium feature:
        up_medium = nn.functional.interpolate(fused_medium, scale_factor=2, mode='nearest')
        up_medium = self.conv_large_upsample(up_medium)  # outputs 64 channels
        fused_large = torch.cat([up_medium, large], dim=1)  # [B, 64+192=256, ...]
        pred_large = self.conv_large(fused_large)
        return [pred_large, pred_medium, pred_small]

class YOLOTModel(nn.Module):
    def __init__(self, num_classes=6):
        super(YOLOTModel, self).__init__()
        self.backbone = timm.create_model('swin_tiny_patch4_window7_224', 
                                           pretrained=True,
                                           img_size=416,
                                           features_only=True, 
                                           out_indices=(1, 2, 3))
        self.head = YOLOHead(num_classes=num_classes, in_channels=[192, 384, 768])
        
    def forward(self, x):
        feats = self.backbone(x)
        # Note: Ensure each feature map is in NCHW order.
        for i, feat in enumerate(feats):
            if feat.dim() == 4 and feat.shape[1] < feat.shape[-1]:
                feats[i] = feat.permute(0, 3, 1, 2)
        preds = self.head(feats)
        return preds

# Instantiate model and loss functions for each scale.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = YOLOTModel(num_classes=6).to(device)

loss_large = YOLOLoss(ANCHORS['large'], num_classes=6, img_dim=416, ignore_thresh=0.5, 
                        lambda_coord=5.0, lambda_noobj=0.5)
loss_medium = YOLOLoss(ANCHORS['medium'], num_classes=6, img_dim=416, ignore_thresh=0.5, 
                         lambda_coord=5.0, lambda_noobj=0.5)
loss_small = YOLOLoss(ANCHORS['small'], num_classes=6, img_dim=416, ignore_thresh=0.5, 
                        lambda_coord=5.0, lambda_noobj=0.5)

optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

def train_model(model, dataloader, optimizer, device, num_epochs=50):
    model.train()
    for epoch in range(num_epochs):
        total_loss_epoch = 0.0
        starting_loss = None  # record the starting loss for this epoch (from the first batch)
        for batch_idx, (images, targets) in enumerate(dataloader):
            images = images.to(device)
            optimizer.zero_grad()
            # Get predictions from model; each element in preds is for one scale.
            preds = model(images)  # list of three tensors
            # Compute loss for each scale
            loss_l = loss_large(preds[0], targets)
            loss_m = loss_medium(preds[1], targets)
            loss_s = loss_small(preds[2], targets)
            loss = loss_l + loss_m + loss_s
            loss.backward()
            optimizer.step()
            
            loss_val = loss.item()
            if batch_idx == 0:
                starting_loss = loss_val  # set the starting loss for the epoch
            total_loss_epoch += loss_val
            
            if batch_idx % 10 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}] Batch {batch_idx}/{len(dataloader)} Loss: {loss_val:.4f}")

            # if starting_loss is not None and loss_val <= starting_loss * 0.10:
            #     print(f"Early stopping triggered at epoch {epoch+1}. loss value: {loss_val:.4f} is <= 10% of starting loss {starting_loss:.4f}.")
            #     torch.save(model.state_dict(), f'yolot_epoch_{epoch+1}_early_stop.pth')
            #     return
        
        avg_loss = total_loss_epoch / len(dataloader)
        print(f"Epoch [{epoch+1}/{num_epochs}] Average Loss: {avg_loss:.4f}")
        
        # Early stopping: if the average loss for this epoch is <= 10% of the starting loss.
        # if starting_loss is not None and avg_loss <= starting_loss * 0.10:
        #     print(f"Early stopping triggered at epoch {epoch+1}. Average loss {avg_loss:.4f} is <= 10% of starting loss {starting_loss:.4f}.")
        #     torch.save(model.state_dict(), f'yolot_epoch_{epoch+1}_early_stop.pth')
        #     return
        
        # torch.save(model.state_dict(), f'yolot_epoch_{epoch+1}.pth')

print("Starting Training...")
# Note: Ensure that 'train_loader' is defined and provides batches of (images, targets)
train_model(model, train_loader, optimizer, device, num_epochs=30)

save_dir = 'trained_models'
os.makedirs(save_dir, exist_ok=True)

# Define the path to the file where you want to save the final model state.
model_path = os.path.join(save_dir, 'model_state_1000.pth')
torch.save(model.state_dict(), model_path)
print(f"Model state saved to {model_path}")

Starting Training...


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [1/30] Batch 0/32 Loss: 125828.7188
Epoch [1/30] Batch 10/32 Loss: 97920.1875
Epoch [1/30] Batch 20/32 Loss: 82312.9531
Epoch [1/30] Batch 30/32 Loss: 87324.2031
Epoch [1/30] Average Loss: 93879.7899


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [2/30] Batch 0/32 Loss: 85378.1719
Epoch [2/30] Batch 10/32 Loss: 80953.7969
Epoch [2/30] Batch 20/32 Loss: 77775.0938
Epoch [2/30] Batch 30/32 Loss: 73291.5938
Epoch [2/30] Average Loss: 77510.7619


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [3/30] Batch 0/32 Loss: 72649.1172
Epoch [3/30] Batch 10/32 Loss: 69920.0312
Epoch [3/30] Batch 20/32 Loss: 66511.1641
Epoch [3/30] Batch 30/32 Loss: 64113.8867
Epoch [3/30] Average Loss: 66720.3625


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [4/30] Batch 0/32 Loss: 63550.6875
Epoch [4/30] Batch 10/32 Loss: 60980.2656
Epoch [4/30] Batch 20/32 Loss: 60640.2070
Epoch [4/30] Batch 30/32 Loss: 57663.7969
Epoch [4/30] Average Loss: 59547.0226


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [5/30] Batch 0/32 Loss: 56652.7344
Epoch [5/30] Batch 10/32 Loss: 56287.5586
Epoch [5/30] Batch 20/32 Loss: 54164.4961
Epoch [5/30] Batch 30/32 Loss: 51590.7500
Epoch [5/30] Average Loss: 53385.5023


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [6/30] Batch 0/32 Loss: 51194.3125
Epoch [6/30] Batch 10/32 Loss: 49680.5625
Epoch [6/30] Batch 20/32 Loss: 48468.3320
Epoch [6/30] Batch 30/32 Loss: 47150.7344
Epoch [6/30] Average Loss: 48085.0236


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [7/30] Batch 0/32 Loss: 46549.5352
Epoch [7/30] Batch 10/32 Loss: 49158.8242
Epoch [7/30] Batch 20/32 Loss: 57019.4336
Epoch [7/30] Batch 30/32 Loss: 48540.0430
Epoch [7/30] Average Loss: 51468.8383


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [8/30] Batch 0/32 Loss: 48147.2227
Epoch [8/30] Batch 10/32 Loss: 44873.4219
Epoch [8/30] Batch 20/32 Loss: 43332.4180
Epoch [8/30] Batch 30/32 Loss: 41524.8047
Epoch [8/30] Average Loss: 43415.4725


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [9/30] Batch 0/32 Loss: 41001.6211
Epoch [9/30] Batch 10/32 Loss: 40678.4727
Epoch [9/30] Batch 20/32 Loss: 38079.5352
Epoch [9/30] Batch 30/32 Loss: 36230.6367
Epoch [9/30] Average Loss: 38043.2921


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [10/30] Batch 0/32 Loss: 36373.3086
Epoch [10/30] Batch 10/32 Loss: 35494.8711
Epoch [10/30] Batch 20/32 Loss: 35351.5820
Epoch [10/30] Batch 30/32 Loss: 33185.5742
Epoch [10/30] Average Loss: 34165.9680


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [11/30] Batch 0/32 Loss: 33097.6094
Epoch [11/30] Batch 10/32 Loss: 31488.5410
Epoch [11/30] Batch 20/32 Loss: 31051.3047
Epoch [11/30] Batch 30/32 Loss: 30068.3945
Epoch [11/30] Average Loss: 30826.4614


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [12/30] Batch 0/32 Loss: 30227.8633
Epoch [12/30] Batch 10/32 Loss: 29638.9297
Epoch [12/30] Batch 20/32 Loss: 28335.9570
Epoch [12/30] Batch 30/32 Loss: 27265.9141
Epoch [12/30] Average Loss: 27986.3307


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [13/30] Batch 0/32 Loss: 27004.7812
Epoch [13/30] Batch 10/32 Loss: 26046.1426
Epoch [13/30] Batch 20/32 Loss: 25411.5840
Epoch [13/30] Batch 30/32 Loss: 25542.5410
Epoch [13/30] Average Loss: 25531.5480


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [14/30] Batch 0/32 Loss: 24491.3984
Epoch [14/30] Batch 10/32 Loss: 24208.3281
Epoch [14/30] Batch 20/32 Loss: 23285.5410
Epoch [14/30] Batch 30/32 Loss: 22426.0000
Epoch [14/30] Average Loss: 23261.6974


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [15/30] Batch 0/32 Loss: 22682.2812
Epoch [15/30] Batch 10/32 Loss: 21808.2891
Epoch [15/30] Batch 20/32 Loss: 21491.3086
Epoch [15/30] Batch 30/32 Loss: 21002.7070
Epoch [15/30] Average Loss: 21363.9009


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [16/30] Batch 0/32 Loss: 20652.8887
Epoch [16/30] Batch 10/32 Loss: 20209.7012
Epoch [16/30] Batch 20/32 Loss: 19605.8789
Epoch [16/30] Batch 30/32 Loss: 18912.0098
Epoch [16/30] Average Loss: 19736.9582


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [17/30] Batch 0/32 Loss: 19912.4199
Epoch [17/30] Batch 10/32 Loss: 18621.5371
Epoch [17/30] Batch 20/32 Loss: 17915.9219
Epoch [17/30] Batch 30/32 Loss: 17656.1797
Epoch [17/30] Average Loss: 18058.8857


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [18/30] Batch 0/32 Loss: 17371.9785
Epoch [18/30] Batch 10/32 Loss: 17369.2031
Epoch [18/30] Batch 20/32 Loss: 16996.0898
Epoch [18/30] Batch 30/32 Loss: 16832.0547
Epoch [18/30] Average Loss: 16734.9146


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [19/30] Batch 0/32 Loss: 16320.6289
Epoch [19/30] Batch 10/32 Loss: 16004.5967
Epoch [19/30] Batch 20/32 Loss: 15641.0225
Epoch [19/30] Batch 30/32 Loss: 15919.7256
Epoch [19/30] Average Loss: 15479.8463


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [20/30] Batch 0/32 Loss: 16784.0703
Epoch [20/30] Batch 10/32 Loss: 14578.4404
Epoch [20/30] Batch 20/32 Loss: 14651.3594
Epoch [20/30] Batch 30/32 Loss: 14033.8193
Epoch [20/30] Average Loss: 14475.3006


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [21/30] Batch 0/32 Loss: 14132.6250
Epoch [21/30] Batch 10/32 Loss: 13668.4053
Epoch [21/30] Batch 20/32 Loss: 13411.3545
Epoch [21/30] Batch 30/32 Loss: 12956.0918
Epoch [21/30] Average Loss: 13221.5596


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [22/30] Batch 0/32 Loss: 13026.9404
Epoch [22/30] Batch 10/32 Loss: 12652.2695
Epoch [22/30] Batch 20/32 Loss: 12576.6748
Epoch [22/30] Batch 30/32 Loss: 12207.3164
Epoch [22/30] Average Loss: 12349.1758


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [23/30] Batch 0/32 Loss: 12552.9863
Epoch [23/30] Batch 10/32 Loss: 11736.0234
Epoch [23/30] Batch 20/32 Loss: 11813.2188
Epoch [23/30] Batch 30/32 Loss: 11869.5361
Epoch [23/30] Average Loss: 11568.0940


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [24/30] Batch 0/32 Loss: 11316.7070
Epoch [24/30] Batch 10/32 Loss: 11445.5605
Epoch [24/30] Batch 20/32 Loss: 11754.9141
Epoch [24/30] Batch 30/32 Loss: 10868.7920
Epoch [24/30] Average Loss: 11154.3212


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [25/30] Batch 0/32 Loss: 10653.5586
Epoch [25/30] Batch 10/32 Loss: 10440.8086
Epoch [25/30] Batch 20/32 Loss: 10508.2393
Epoch [25/30] Batch 30/32 Loss: 9825.2119
Epoch [25/30] Average Loss: 10256.0155


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [26/30] Batch 0/32 Loss: 9852.1182
Epoch [26/30] Batch 10/32 Loss: 9853.0303
Epoch [26/30] Batch 20/32 Loss: 9396.4629
Epoch [26/30] Batch 30/32 Loss: 9214.8857
Epoch [26/30] Average Loss: 9425.1381


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [27/30] Batch 0/32 Loss: 9290.0547
Epoch [27/30] Batch 10/32 Loss: 9019.5625
Epoch [27/30] Batch 20/32 Loss: 9119.1250
Epoch [27/30] Batch 30/32 Loss: 9129.7695
Epoch [27/30] Average Loss: 8973.1868


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [28/30] Batch 0/32 Loss: 9261.6064
Epoch [28/30] Batch 10/32 Loss: 8866.3643
Epoch [28/30] Batch 20/32 Loss: 8528.5479
Epoch [28/30] Batch 30/32 Loss: 8744.6914
Epoch [28/30] Average Loss: 8423.3028


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [29/30] Batch 0/32 Loss: 8172.4219
Epoch [29/30] Batch 10/32 Loss: 8320.1416
Epoch [29/30] Batch 20/32 Loss: 8369.6152
Epoch [29/30] Batch 30/32 Loss: 7802.2485
Epoch [29/30] Average Loss: 7863.9409


  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Epoch [30/30] Batch 0/32 Loss: 8757.5225
Epoch [30/30] Batch 10/32 Loss: 7733.4941
Epoch [30/30] Batch 20/32 Loss: 7478.1514
Epoch [30/30] Batch 30/32 Loss: 7386.3745
Epoch [30/30] Average Loss: 7469.8052
Model state saved to trained_models/model_state_1000.pth


In [None]:
# class model():
#   import a pretrained -> gives outputs
#   take pretrained outputs -> classify using a neural net


# Testing

In [12]:
import os
import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader

# Make sure you have imported your custom_collate_fn from your datasets module
# from datasets import custom_collate_fn

# Rebuild your test loader with the custom collate function.
test_loader = DataLoader(
    test_loader.dataset,
    batch_size=32,
    shuffle=False,
    num_workers=4,
    collate_fn=custom_collate_fn
)

# ---------------------------
# Utility Functions for Evaluation
# ---------------------------
def compute_iou(box1, box2):
    """
    Computes Intersection over Union (IoU) for two boxes.
    Boxes are in the format [x1, y1, x2, y2].
    """
    x1, y1, x2, y2 = box1
    x1g, y1g, x2g, y2g = box2
    
    inter_x1 = max(x1, x1g)
    inter_y1 = max(y1, y1g)
    inter_x2 = min(x2, x2g)
    inter_y2 = min(y2, y2g)
    
    inter_area = max(inter_x2 - inter_x1, 0) * max(inter_y2 - inter_y1, 0)
    area1 = (x2 - x1) * (y2 - y1)
    area2 = (x2g - x1g) * (y2g - y1g)
    union_area = area1 + area2 - inter_area + 1e-6  # avoid division by zero
    return inter_area / union_area

def convert_gt_to_pixels(gt_list, img_size=416):
    """
    Converts a list of ground truth boxes in YOLO format 
    [class, cx, cy, w, h] (normalized) to pixel coordinates [x1, y1, x2, y2, class].
    """
    converted = []
    for gt in gt_list:
        cls, cx, cy, w, h = gt
        x1 = (cx - w/2) * img_size
        y1 = (cy - h/2) * img_size
        x2 = (cx + w/2) * img_size
        y2 = (cy + h/2) * img_size
        converted.append([x1, y1, x2, y2, int(cls)])
    return converted

def evaluate_detections(all_preds, all_gts, iou_threshold=0.01):
    """
    Compares predictions and ground truths across all images.
    For each image, a prediction is considered a true positive (TP) if it matches a ground truth (GT)
    with the same class and IoU >= iou_threshold. Otherwise, it is a false positive (FP).
    Ground truths with no matching prediction are counted as false negatives (FN).
    Returns a dictionary of overall metrics.
    """
    total_TP = 0
    total_FP = 0
    total_FN = 0
    total_iou = 0.0
    iou_count = 0
    
    # Loop over each image.
    for preds, gt in zip(all_preds, all_gts):
        # Convert ground truths to pixel coordinates.
        gts_pixels = convert_gt_to_pixels(gt, img_size=416)
        matched_gts = set()  # to keep track of ground truths that are already matched
        
        # Process each prediction.
        for pred in preds:
            # pred is [x1, y1, x2, y2, conf, cls] with pixel coordinates
            x1p, y1p, x2p, y2p, conf, cls_pred = pred
            best_iou = 0.0
            best_gt_idx = -1
            for idx, gt_box in enumerate(gts_pixels):
                # Only consider ground truths of the same class.
                if gt_box[4] != int(cls_pred):
                    continue
                iou_val = compute_iou(pred[:4], gt_box[:4])
                if iou_val > best_iou:
                    best_iou = iou_val
                    best_gt_idx = idx
            
            # A valid match is found if IoU is above threshold and the GT hasn't been matched.
            if best_iou >= iou_threshold and best_gt_idx not in matched_gts:
                total_TP += 1
                total_iou += best_iou
                iou_count += 1
                matched_gts.add(best_gt_idx)
            else:
                total_FP += 1
        
        # All GTs not matched are false negatives.
        total_FN += len(gts_pixels) - len(matched_gts)
    
    precision = total_TP / (total_TP + total_FP) if (total_TP + total_FP) > 0 else 0
    recall    = total_TP / (total_TP + total_FN) if (total_TP + total_FN) > 0 else 0
    f1_score  = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    avg_iou   = total_iou / iou_count if iou_count > 0 else 0
    
    return {
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
        "avg_iou": avg_iou,
        "true_positives": total_TP,
        "false_positives": total_FP,
        "false_negatives": total_FN,
    }

# ---------------------------
# Modified decode_predictions Function (with debug prints removed)
# ---------------------------
def decode_predictions(preds, conf_thresh=0.25):
    """
    Decodes predictions from a list of tensors (one per scale).
    Returns a list of detections for one image.
    Each detection is [x1, y1, x2, y2, conf, cls].
    """
    detections = []
    for pred in preds:
        pred = pred.detach().cpu()
        if pred.ndim == 4:
            B, C, grid_h, grid_w = pred.shape
            if grid_h != grid_w:
                raise ValueError(f"Expected square grid but got {grid_h} and {grid_w}")
            grid_size = grid_h
            num_anchors = 3  # Adjust if needed.
            num_classes = (C // num_anchors) - 5
            pred = pred.view(B, num_anchors, 5 + num_classes, grid_size, grid_size)
            pred = pred.permute(0, 1, 3, 4, 2).contiguous()
        elif pred.ndim != 5:
            raise ValueError(f"Unexpected tensor shape: {pred.shape}")
        
        B, A, grid_size, _, channels = pred.shape
        # Process only the first image in the batch.
        pred = pred[0]
        num_anchors = pred.size(0)
        
        # Create grids for offset calculation.
        grid_x = torch.arange(grid_size).repeat(grid_size, 1).view(1, grid_size, grid_size).float()
        grid_y = torch.arange(grid_size).repeat(grid_size, 1).t().view(1, grid_size, grid_size).float()
        
        box = torch.zeros_like(pred[..., :4])
        box[..., 0] = (torch.sigmoid(pred[..., 0]) + grid_x) / grid_size
        box[..., 1] = (torch.sigmoid(pred[..., 1]) + grid_y) / grid_size
        
        dummy_anchor = torch.tensor([0.5, 0.5]).view(1, 1, 1, 2).type_as(pred)
        box[..., 2] = dummy_anchor[..., 0] * torch.exp(pred[..., 2])
        box[..., 3] = dummy_anchor[..., 1] * torch.exp(pred[..., 3])
        
        conf = torch.sigmoid(pred[..., 4])
        cls_score = pred[..., 5:]
        cls_prob = torch.softmax(cls_score, dim=-1)
        cls_conf, cls_pred = torch.max(cls_prob, dim=-1)
        final_conf = conf * cls_conf
        
        for a in range(num_anchors):
            for i in range(grid_size):
                for j in range(grid_size):
                    if final_conf[a, i, j] > conf_thresh:
                        x_center, y_center, w, h = box[a, i, j]
                        # Scale to original image dimensions (assumed to be 416x416).
                        x1 = (x_center - w/2) * 416
                        y1 = (y_center - h/2) * 416
                        x2 = (x_center + w/2) * 416
                        y2 = (y_center + h/2) * 416
                        detections.append([
                            x1.item(), y1.item(), x2.item(), y2.item(),
                            final_conf[a, i, j].item(), cls_pred[a, i, j].item()
                        ])
    return detections

# ---------------------------
# Updated Inference Function with Metrics Collection
# ---------------------------
def inference(model, dataloader, device, output_dir='output', conf_thresh=0.3, iou_threshold=0.5):
    """
    Runs inference on the given dataloader, saves the detection images, and collects results to evaluate metrics.
    """
    model.eval()
    os.makedirs(output_dir, exist_ok=True)
    cls_names = ['gun', 'knife', 'wrench', 'pliers', 'scissors', 'hammer']
    
    all_preds = []  # List to store predictions for each image.
    all_gts = []    # List to store ground truth boxes for each image.
    
    with torch.no_grad():
        for idx, (images, targets) in enumerate(dataloader):
            images = images.to(device)
            preds_scales = model(images)  # List of predictions for each scale.
            
            for i in range(images.size(0)):
                # Prepare image for drawing.
                img = images[i].cpu().permute(1, 2, 0).numpy()
                img = (img * 255).astype(np.uint8)
                img = np.ascontiguousarray(img)  # Ensure contiguous layout.
                
                # Decode predictions from each scale.
                img_preds = [p[i:i+1] for p in preds_scales]
                dets = decode_predictions(img_preds, conf_thresh=conf_thresh)
                # Perform Non-Maximum Suppression.
                dets = non_max_suppression(dets, conf_thresh=conf_thresh, iou_thresh=iou_threshold)
                
                # Save detections for this image.
                all_preds.append(dets)
                # Store ground truths. targets[i] is a tensor of shape [N, 5].
                gt_list = targets[i].cpu().numpy().tolist()
                all_gts.append(gt_list)
                
                # Draw detections on the image.
                for det in dets:
                    x1, y1, x2, y2, conf, cls = det
                    cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
                    cv2.putText(img, f"{cls_names[int(cls)]} {conf:.2f}", (int(x1), int(y1)-10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
                # out_path = os.path.join(output_dir, f"result_{idx}_{i}.jpg")
                # cv2.imwrite(out_path, img)
                # print(f"Saved detection result to {out_path}")
    
    # After processing all images, evaluate detections.
    metrics = evaluate_detections(all_preds, all_gts, iou_threshold=iou_threshold)
    print("Evaluation Metrics:")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall:    {metrics['recall']:.4f}")
    print(f"F1 Score:  {metrics['f1_score']:.4f}")
    print(f"Avg IoU:   {metrics['avg_iou']:.4f}")
    print(f"TP: {metrics['true_positives']}  FP: {metrics['false_positives']}  FN: {metrics['false_negatives']}")

# ---------------------------
# Non-Maximum Suppression Function (unchanged)
# ---------------------------
def non_max_suppression(detections, conf_thresh=0.25, iou_thresh=0.5):
    """
    Applies Non-Maximum Suppression (NMS) on the detections.
    Returns the final list of detection boxes.
    """
    if len(detections) == 0:
        return []
    detections = np.array(detections)
    # Filter detections with confidence lower than the threshold.
    detections = detections[detections[:, 4] >= conf_thresh]
    if len(detections) == 0:
        return []
    # Sort detections by confidence (highest first).
    indices = np.argsort(-detections[:, 4])
    detections = detections[indices]
    final_dets = []
    while len(detections) > 0:
        best = detections[0]
        final_dets.append(best)
        if len(detections) == 1:
            break
        rest = detections[1:]
        x1 = np.maximum(best[0], rest[:, 0])
        y1 = np.maximum(best[1], rest[:, 1])
        x2 = np.minimum(best[2], rest[:, 2])
        y2 = np.minimum(best[3], rest[:, 3])
        inter_area = np.maximum(0, x2 - x1) * np.maximum(0, y2 - y1)
        best_area = (best[2] - best[0]) * (best[3] - best[1])
        rest_area = (rest[:, 2] - rest[:, 0]) * (rest[:, 3] - rest[:, 1])
        iou = inter_area / (best_area + rest_area - inter_area + 1e-6)
        detections = rest[iou < iou_thresh]
    return final_dets

# ---------------------------
# Main Inference Execution
# ---------------------------
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Assume model and test_loader are defined elsewhere.
    # For example:
    #   from model_definition import YOLOTModel
    #   model = YOLOTModel(num_classes=6).to(device)
    
    checkpoint_path = 'trained_models/model_state_1000.pth'
    if os.path.exists(checkpoint_path):
        # Loading the saved checkpoint.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        print("Loaded checkpoint for inference.")
    else:
        print("Checkpoint not found. Using current model weights.")
    
    # Run inference and evaluation.
    inference(model, test_loader, device, output_dir='output')

Loaded checkpoint for inference.


  model.load_state_dict(torch.load(checkpoint_path, map_location=device))
  check_for_updates()
  check_for_updates()
  check_for_updates()
  check_for_updates()


Evaluation Metrics:
Precision: 0.0000
Recall:    0.0000
F1 Score:  0.0000
Avg IoU:   0.0000
TP: 0  FP: 0  FN: 3
