In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.transforms import v2
from torch.utils.data import DataLoader
import cv2
import numpy as np
import wandb

In [2]:
from focalLoss import FocalLoss
from models import RetinaNet
from dataset import PascalDataset

In [3]:
def compute_loss(cls_outputs, reg_outputs, targets, criterion_cls, criterion_reg):
    """Compute the mean classification and regression losses.

    The three sequences are walked in lockstep with ``zip`` (iteration stops at
    the shortest one). For each triple the predictions and targets are
    flattened, truncated to a common number of rows and handed to the criteria.

    Args:
        cls_outputs: Sequence of 4-D classification tensors. Dimension 1 is
            moved last and becomes the row width (presumably ``(N, C, H, W)``
            logits -- TODO confirm against ``RetinaNet``).
        reg_outputs: Sequence of tensors whose element count is a multiple of 4.
        targets: Sequence of dicts holding ``'labels'`` and ``'boxes'`` tensors.
        criterion_cls: Callable ``(predictions, labels) -> scalar tensor``.
        criterion_reg: Callable ``(predictions, boxes) -> scalar tensor``.

    Returns:
        Tuple ``(cls_loss, reg_loss)`` of scalar tensors, each the mean of the
        per-item losses.

    Raises:
        RuntimeError: If no triple is produced (``torch.stack`` of an empty list).
    """
    # NOTE(review): zip pairs the i-th output with the i-th target; verify that
    # the model's outputs are indexed per image rather than per feature level.
    cls_losses = []
    reg_losses = []

    for cls_output, reg_output, target in zip(cls_outputs, reg_outputs, targets):
        # Channels-last, then one row per remaining position. ``reshape`` copies
        # when needed, so it also covers the non-contiguous permuted tensor.
        num_channels = cls_output.size(1)
        cls_pred = cls_output.permute(0, 2, 3, 1).reshape(-1, num_channels)
        # ``reshape`` instead of ``view``: targets that are slices/transposes of
        # a larger tensor are non-contiguous and would make ``view`` raise.
        cls_target = target['labels'].reshape(-1)

        # NOTE(review): truncating to the shorter side pairs predictions and
        # targets purely by position and hides size mismatches -- confirm this
        # is the intended matching strategy.
        n_cls = min(cls_pred.size(0), cls_target.size(0))
        cls_pred = cls_pred[:n_cls]
        cls_target = cls_target[:n_cls]

        # NOTE(review): unlike the classification branch, reg_output is not
        # permuted before being split into rows of 4 -- verify its layout.
        reg_pred = reg_output.reshape(-1, 4)
        reg_target = target['boxes'].reshape(-1, 4)

        n_reg = min(reg_pred.size(0), reg_target.size(0))
        reg_pred = reg_pred[:n_reg]
        reg_target = reg_target[:n_reg]

        cls_losses.append(criterion_cls(cls_pred, cls_target))
        reg_losses.append(criterion_reg(reg_pred, reg_target))

    # Average the per-item losses into one scalar per loss type.
    return torch.stack(cls_losses).mean(), torch.stack(reg_losses).mean()

In [4]:
def load_image(img_path):
    """Load an image file as a normalised float tensor with a batch dimension.

    The image is read with OpenCV, converted from BGR to RGB, turned into a
    tensor image, scaled, normalised with the same mean/std as the notebook's
    training transform, and given a leading batch dimension.

    Args:
        img_path: Path of the image file to read.

    Returns:
        The transformed image with an extra leading dimension of size 1.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``img_path``.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    transform = v2.Compose([
        # Convert the HWC ndarray into a CHW tensor image; the v2 transforms
        # below leave plain numpy arrays untouched, so this step is required.
        v2.ToImage(),
        # scale=True maps uint8 [0, 255] to float [0, 1], matching the training
        # transform and the value range the mean/std below are expressed in.
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])
    img = transform(img)
    return img.unsqueeze(0)

In [5]:
def train_model(model, dataloader, num_epochs, device, patience=5, lr=3e-4):
    """Train ``model`` with AdamW, focal + smooth-L1 losses and early stopping.

    Each epoch's summed classification, regression and total losses are logged
    with ``wandb.log`` (presumably requires an active wandb run -- confirm),
    and the per-batch averages are printed.

    Args:
        model: Module called as ``model(images)`` and returning
            ``(cls_outputs, reg_outputs)``.
        dataloader: Iterable of ``(images, targets)`` batches. ``images`` is a
            sequence of tensors that ``torch.stack`` can combine; ``targets``
            is a sequence of dicts of tensors.
        num_epochs: Maximum number of epochs to run.
        device: Device the model and every batch are moved to.
        patience: Number of consecutive epochs without a new best epoch loss
            after which training stops.
        lr: Learning rate for AdamW. The default keeps the previously
            hard-coded value.

    Returns:
        None.
    """
    model.to(device)
    optimiser = optim.AdamW(model.parameters(), lr=lr)
    criterion_cls = FocalLoss()
    criterion_reg = nn.SmoothL1Loss()
    model.train()
    best_loss = np.inf
    counter = 0
    # Guard the averages printed below against an empty dataloader.
    num_batches = max(len(dataloader), 1)

    for epoch in range(num_epochs):
        running_loss_cls = 0.0
        running_loss_reg = 0.0
        for images, targets in dataloader:
            images = torch.stack(images).to(device)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            # Force gradients on even if the caller is inside a no-grad context.
            with torch.set_grad_enabled(True):
                optimiser.zero_grad(set_to_none=True)
                cls_outputs, reg_outputs = model(images)
                loss_cls, loss_reg = compute_loss(
                    cls_outputs, reg_outputs, targets, criterion_cls, criterion_reg
                )
                loss = loss_cls + loss_reg

                loss.backward()
                optimiser.step()

            running_loss_cls += loss_cls.item()
            running_loss_reg += loss_reg.item()

        # NOTE(review): the logged values are sums over batches, whereas the
        # printed values are per-batch averages -- confirm which is wanted.
        epoch_loss = running_loss_cls + running_loss_reg
        wandb.log({
            'epoch': epoch,
            'classification_loss': running_loss_cls,
            'regression_loss': running_loss_reg,
            'epoch_loss': epoch_loss
        })

        # Report before the early-stopping check so the final epoch is shown too.
        print(f"Epoch {epoch + 1}/{num_epochs}, Classification Loss: {running_loss_cls / num_batches}",
              f"Regression Loss: {running_loss_reg / num_batches}")

        # Early stopping: count consecutive epochs without a new best loss.
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            counter = 0
        else:
            counter += 1

        if counter >= patience:
            # Stopping early is a normal outcome, so leave the loop instead of
            # raising an exception.
            print(f"Early stopping after epoch {epoch + 1}: no improvement for {patience} epochs")
            break


In [6]:
def visualise_detections(img_path, detections, threshold=0.5):
    """Draw detections scoring above ``threshold`` and show them in a window.

    Args:
        img_path: Path of the image to annotate.
        detections: Iterable of dicts with the keys ``'score'``, ``'bbox'``
            (four values unpacked as ``x1, y1, x2, y2``) and ``'class'``.
        threshold: Detections whose score is not strictly greater than this
            value are skipped.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``img_path``.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {img_path}")

    for detection in detections:
        if detection['score'] > threshold:
            # OpenCV drawing functions need integer pixel coordinates; box
            # values may arrive as floats or 0-dim tensors.
            x1, y1, x2, y2 = (int(round(float(c))) for c in detection['bbox'])
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
            label = f"{detection['class']}:{detection['score']:.2f}"
            # Keep the label inside the image when the box touches the top edge.
            text_y = max(y1 - 10, 12)
            cv2.putText(img, label, (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

    # Blocks until a key is pressed, then closes the window.
    cv2.imshow("Detections", img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

In [7]:
def collate_fn(batch):
    """Transpose a batch of samples into a tuple of per-field tuples.

    A batch such as ``[(a1, b1), (a2, b2)]`` becomes ``((a1, a2), (b1, b2))``;
    an empty batch yields an empty tuple.
    """
    fields = zip(*batch)
    return tuple(fields)

In [8]:
# Preprocessing handed to the training dataset: convert to float32 with value
# scaling, then normalise each channel with the mean/std below.
transform = v2.Compose(
    [
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [9]:
# Dataset root handed to PascalDataset.
trainpath = r'VOCtrainval_06-Nov-2007'
# 'train' image set, preprocessed with the transform defined above.
trainset = PascalDataset(trainpath, img_set='train', transform=transform)
# One sample per batch, reshuffled every epoch, loaded in the main process.
trainloader = DataLoader(
    trainset,
    batch_size=1,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn,
    pin_memory=True,
)

In [10]:
wandb.init(
    # set the wandb project where this run will be logged
    project="retinanet",

    # track hyperparameters and run metadata
    config={
    # Must match the learning rate the optimiser in train_model actually
    # uses (3e-4); the previous 0.01 mislabelled the run.
    "learning_rate": 3e-4,
    "architecture": "RetinaNet",
    "dataset": "PASCAL VOC 2007",
    "epochs": 10,
    }
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mahaan1984[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [11]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# One output class per entry in the dataset's class list.
model = RetinaNet(num_classes=len(trainset.classes))
train_model(model, trainloader, num_epochs=10, device=device)