In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import cv2
import numpy as np

In [4]:
from focalLoss import FocalLoss
from models import RetinaNet

In [5]:
def compute_loss(cls_outputs, reg_outputs, targets, criterion_cls, criterion_reg):
    """Average the classification and regression losses over a batch.

    Each element of ``cls_outputs`` / ``reg_outputs`` is paired positionally with
    the corresponding entry of ``targets`` and scored with the given criteria.

    Args:
        cls_outputs: Iterable of classification predictions, one per target.
        reg_outputs: Iterable of box-regression predictions, one per target.
        targets: Iterable of dicts; each must contain ``'label'`` (passed to
            ``criterion_cls``) and ``'boxes'`` (passed to ``criterion_reg``).
        criterion_cls: Callable ``(prediction, target) -> loss`` for classification.
        criterion_reg: Callable ``(prediction, target) -> loss`` for regression.

    Returns:
        Tuple ``(mean_cls_loss, mean_reg_loss)`` averaged over all paired items.

    Raises:
        ValueError: if there are no (output, target) pairs to score.
    """
    cls_losses = []
    reg_losses = []

    # zip stops at the shortest input. NOTE(review): assumes the three sequences are
    # aligned one-to-one (e.g. per image); confirm against the model's output format.
    for cls_output, reg_output, target in zip(cls_outputs, reg_outputs, targets):
        cls_losses.append(criterion_cls(cls_output, target['label']))
        reg_losses.append(criterion_reg(reg_output, target['boxes']))

    if not cls_losses:
        raise ValueError("compute_loss received no (output, target) pairs")

    # Average only after the loop so that every item in the batch contributes.
    return sum(cls_losses) / len(cls_losses), sum(reg_losses) / len(reg_losses)


In [6]:
def load_image(img_path):
    """Read an image from disk and return it as a normalised tensor with a batch dim.

    The image is converted from OpenCV's BGR order to RGB, passed through
    ``ToTensor`` and normalised with the commonly used ImageNet channel mean/std,
    then given a leading batch dimension of size 1 via ``unsqueeze(0)``.

    Args:
        img_path: Path to an image file readable by ``cv2.imread``.

    Returns:
        The preprocessed image tensor with a leading batch dimension of 1.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])
    img = transform(img)
    return img.unsqueeze(0)

In [7]:
def train_model(model, dataloader, num_epochs, device, patience=5, lr=0.001):
    """Train a detector with focal (classification) + smooth-L1 (regression) loss.

    Uses Adam and stops early once the summed average epoch loss has not improved
    for ``patience`` consecutive epochs. Early stopping ends training normally;
    it does not raise.

    Args:
        model: Module whose forward returns ``(cls_outputs, reg_outputs)``.
        dataloader: Iterable yielding ``(images, targets)`` where ``targets`` is a
            sequence of dicts of tensors (moved to ``device`` per batch).
        num_epochs: Maximum number of epochs to run.
        device: Device the model and data are moved to.
        patience: Epochs without improvement tolerated before stopping.
        lr: Adam learning rate (default keeps the previous hard-coded 0.001).

    Returns:
        The trained ``model`` (the same object that was passed in).

    Raises:
        ValueError: if ``dataloader`` yields no batches in an epoch.
    """
    model.to(device)
    optimiser = optim.Adam(model.parameters(), lr=lr)
    criterion_cls = FocalLoss()
    criterion_reg = nn.SmoothL1Loss()
    model.train()
    best_loss = np.inf
    counter = 0  # consecutive epochs without improvement

    for epoch in range(num_epochs):
        running_loss_cls = 0.0
        running_loss_reg = 0.0
        # Count batches while iterating so loaders without __len__ also work.
        num_batches = 0
        for images, targets in dataloader:
            images = images.to(device)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            optimiser.zero_grad()
            cls_outputs, reg_outputs = model(images)
            loss_cls, loss_reg = compute_loss(cls_outputs, reg_outputs, targets, criterion_cls, criterion_reg)
            loss = loss_cls + loss_reg
            loss.backward()
            optimiser.step()
            running_loss_cls += loss_cls.item()
            running_loss_reg += loss_reg.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("dataloader yielded no batches")

        avg_loss_cls = running_loss_cls / num_batches
        avg_loss_reg = running_loss_reg / num_batches
        print(f"Epoch {epoch + 1}/{num_epochs}, Classification Loss: {avg_loss_cls}",
              f"Regression Loss: {avg_loss_reg}")
        epoch_loss = avg_loss_cls + avg_loss_reg

        if epoch_loss < best_loss:
            best_loss = epoch_loss
            counter = 0
        else:
            counter += 1

        if counter >= patience:
            # Early stopping is a normal end of training, not an error condition.
            print(f"Early stopping at epoch {epoch + 1}: no improvement for {patience} epochs")
            break

    return model

In [8]:
def visualise_detections(img_path, detections, threshold=0.5, show=True):
    """Draw detections scoring above ``threshold`` onto an image and optionally show it.

    Args:
        img_path: Path to the image file, read with ``cv2.imread``.
        detections: Iterable of dicts with keys ``'bbox'`` (``x1, y1, x2, y2``),
            ``'score'`` and ``'class'``.
        threshold: Only detections with ``score > threshold`` are drawn.
        show: If True (default), open an OpenCV window and block until a key is
            pressed. Pass False in headless/notebook environments and display the
            returned image another way.

    Returns:
        The annotated image array.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {img_path}")

    colour = (0, 255, 0)
    for detection in detections:
        score = float(detection['score'])
        if score > threshold:
            # OpenCV drawing functions require integer pixel coordinates; boxes from
            # a model may be floats (or tensors), so round them explicitly.
            x1, y1, x2, y2 = (int(round(float(c))) for c in detection['bbox'])
            cv2.rectangle(img, (x1, y1), (x2, y2), colour, 2)
            label = f"{detection['class']}:{score:.2f}"
            # Keep the label on-screen for boxes touching the top edge.
            text_y = max(y1 - 10, 10)
            cv2.putText(img, label, (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, colour, 2)

    if show:
        cv2.imshow("Detections", img)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
    return img