In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets import VOCDetection
import numpy as np
from torchvision import transforms
from torchvision.datasets import CIFAR10
import matplotlib.pyplot as plt
import torchvision.transforms as T
from torch import nn, optim
from tqdm import tqdm

In [16]:
class CIFARDataset(Dataset):
    def __init__(self, root, train=True, S=7, transform=None, download=True):
        self.dataset = CIFAR10(root=root, train=train, transform=transform, download=download)
        self.S = S
        self.classes = self.dataset.classes
        self.transform = transform
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, label = self.dataset[idx]

        x_center = 0.5
        y_center = 0.5
        bw = 1.0
        bh = 1.0

        target_tensor = torch.zeros(self.S, self.S, 5 + len(self.classes))
        i = int(y_center * self.S)
        j = int(x_center * self.S)
        target_tensor[i, j, 0:4] = torch.tensor([x_center, y_center, bw, bh])
        target_tensor[i, j, 4] = 1
        target_tensor[i, j, 5 + label] = 1

        return img, target_tensor   

In [17]:
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor()
])

train_dataset = CIFARDataset(root='./data', train=True, S=7, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

In [18]:
class YOLOModel(nn.Module):
    def __init__(self, S: int = 7, B: int = 1, C: int = 10):
        super().__init__()
        self.S = S
        self.B = B
        self.C = C

        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), # 16x16 -> 16x16
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), # 8x8 -> 8x8
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), # 4x4
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((S, S))
        )

        self.pred = nn.Conv2d(256, B * (5 + C), kernel_size=1)

    def forward(self, x):
        x = self.features(x)
        x = self.pred(x)
        x = x.permute(0, 2, 3, 1)
        return x

In [19]:
class YOLOLoss(nn.Module):
    def __init__(self, S=7, B=1, C=10, lambda_coord=5, lambda_noobj=0.5):
        super().__init__()
        self.S = S
        self.B = B
        self.C = C
        self.lambda_coord = lambda_coord
        self.lambda_noobj = lambda_noobj

    def forward(self, predictions, target):
        obj_mask = target[..., 4] == 1

        coord_loss = F.mse_loss(predictions[obj_mask][..., 0:2], target[obj_mask][..., 0:2], reduction='sum') + \
                     F.mse_loss(predictions[obj_mask][..., 2:4], target[obj_mask][..., 2:4], reduction='sum')

        obj_loss = F.binary_cross_entropy_with_logits(predictions[..., 4], target[..., 4], reduction='sum')

        class_loss = F.cross_entropy(predictions[obj_mask][..., 5:], 
                                     target[obj_mask][..., 5:].argmax(dim=-1), reduction='sum') \
                     if obj_mask.any() else torch.tensor(0.0, device=predictions.device)

        loss = self.lambda_coord * coord_loss + obj_loss + class_loss
        return loss / predictions.size(0)

In [20]:
def bbox_iou(bbox1, bbox2):
    b1_x1 = box1[..., 0] - box1[..., 2] / 2
    b1_y1 = box1[..., 1] - box1[..., 3] / 2
    b1_x2 = box1[..., 0] + box1[..., 2] / 2
    b1_y2 = box1[..., 1] + box1[..., 3] / 2

    b2_x1 = box2[..., 0] - box2[..., 2] / 2
    b2_y1 = box2[..., 1] - box2[..., 3] / 2
    b2_x2 = box2[..., 0] + box2[..., 2] / 2
    b2_y2 = box2[..., 1] + box2[..., 3] / 2

    inter_x1 = torch.max(b1_x1, b2_x1)
    inter_y1 = torch.max(b1_y1, b2_y1)
    inter_x2 = torch.min(b1_x2, b2_x2)
    inter_y2 = torch.min(b1_y2, b2_y2)

    inter_area = torch.clamp(inter_x2 - inter_x1, min=0) * \
                 torch.clamp(inter_y2 - inter_y1, min=0)

    b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
    b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)

    iou = inter_area / (b1_area + b2_area - inter_area + 1e-6)
    return iou

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = YOLOModel(S=7, B=1, C=10).to(device)
criterion = YOLOLoss(S=7, B=1, C=10)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

EPOCHS = 5

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    for imgs, targets in tqdm(train_loader):
        imgs, targets = imgs.to(device), targets.to(device)

        preds = model(imgs)
        loss = criterion(preds, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch [{epoch+1}/{EPOCHS}] Loss: {running_loss/len(train_loader):.4f}")

100%|████████████████████████████████████████████████████████████████████████| 12500/12500 [35:33<00:00,  5.86it/s]


Epoch [1/5] Loss: 2.7718


  5%|███▍                                                                      | 590/12500 [01:32<31:37,  6.28it/s]

In [None]:
model.eval()
ious = []
correct = 0
total = 0

with torch.no_grad():
    for imgs, targets in DataLoader(train_dataset, batch_size=32, shuffle=False):
        imgs, targets = imgs.to(device), targets.to(device)
        preds = model(imgs)

        pred_obj = torch.sigmoid(preds[..., 4])
        best_cell = pred_obj.view(imgs.size(0), -1).argmax(dim=1)

        for i in range(imgs.size(0)):
            row = best_cell[i] // preds.size(2)
            col = best_cell[i] % preds.size(2)

            pred_box = preds[i, row, col, 0:4]
            true_box = targets[i, row, col, 0:4]

            iou = bbox_iou(pred_box.unsqueeze(0), true_box.unsqueeze(0))
            ious.append(iou.item())

            pred_class = preds[i, row, col, 5:].argmax()
            true_class = targets[i, row, col, 5:].argmax()

            correct += int(pred_class == true_class)
            total += 1

print(f"Mean IoU: {sum(ious)/len(ious):.4f}")
print(f"Classification Accuracy: {correct/total:.4f}")

In [None]:
inv_transform = T.ToPILImage()

def show_predictions(model, dataset, num_images=5):
    model.eval()
    fig, axs = plt.subplots(1, num_images, figsize=(15, 4))

    for i in range(num_images):
        img, target = dataset[i]
        img_tensor = img.unsqueeze(0).to(device)

        with torch.no_grad():
            pred = model(img_tensor)

        pred = pred.squeeze(0)  # (S, S, 5+C)
        pred_obj = torch.sigmoid(pred[..., 4])
        best_idx = pred_obj.view(-1).argmax()

        row = best_idx // pred.size(0)
        col = best_idx % pred.size(1)

        pred_box = pred[row, col, 0:4].cpu()
        pred_class = pred[row, col, 5:].argmax().item()

        true_box = target[row, col, 0:4]
        true_class = target[row, col, 5:].argmax().item()

        img_show = inv_transform(img.cpu())
        axs[i].imshow(img_show)

        def draw_box(ax, box, color, label):
            x, y, w, h = box
            H, W = img_show.size[1], img_show.size[0]
            x1 = (x - w/2) * W
            y1 = (y - h/2) * H
            rect = plt.Rectangle((x1, y1), w*W, h*H, 
                                 fill=False, color=color, linewidth=2)
            ax.add_patch(rect)
            ax.text(x1, y1, label, color=color, fontsize=8, backgroundcolor="white")

        draw_box(axs[i], true_box, "green", dataset.classes[true_class])
        draw_box(axs[i], pred_box, "red", dataset.classes[pred_class])

        axs[i].axis("off")
        axs[i].set_title(f"GT: {dataset.classes[true_class]}\nPred: {dataset.classes[pred_class]}")

    plt.show()

show_predictions(model, test_dataset, num_images=5)