# Project 2

In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random

# Global configuration shared by the cells below.
# NOTE(review): RANDOM_SEED is not applied in the cells visible here (no
# torch.manual_seed / random.seed call) -- confirm it is used for seeding elsewhere.
RANDOM_SEED = 265
# Number of passes over train_loader performed by train_localizer.
EPOCH_COUNT = 10
# Presumably the DataLoader batch size; not referenced in the cells shown -- verify.
BATCH_SIZE = 256

# Object localization
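Classify and locate a single digit within each image. Some images contain no digit, so the network predicts three things: a confidence that a digit is present, the digit's class, and its bounding box.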

In [None]:
class LocalizationNetwork(nn.Module):
    """CNN predicting object presence, digit class and bounding box.

    Three convolution + max-pool stages shrink a single-channel 48x60 image to
    a (128, 6, 5) feature map. The map is flattened and fed to three heads:

    * ``confidence`` -- one sigmoid score per image (trained as "object present").
    * ``classifier`` -- softmax *probabilities* over 10 classes. These are not
      logits, so ``F.cross_entropy`` (which applies its own log-softmax) must
      not be used on them directly.
    * ``bbox`` -- four sigmoid-squashed values in (0, 1); presumably normalised
      box coordinates -- confirm the ordering against the label format.
    """

    def __init__(self):
        super().__init__()

        # Input = (1, 48, 60)
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        # Input = (32, 24, 30)
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        # Input = (64, 12, 15)
        self.layer3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(),
            # Asymmetric pooling turns 12x15 into 6x5.
            nn.MaxPool2d((2, 3))
        )

        # Output of layer3 = (128, 6, 5)
        self.cnn_size = 128 * 6 * 5

        self.confidence = nn.Sequential(
            nn.Linear(self.cnn_size, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

        self.classifier = nn.Sequential(
            nn.Linear(self.cnn_size, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.Softmax(dim=1)
        )

        self.bbox = nn.Sequential(
            nn.Linear(self.cnn_size, 512),
            nn.ReLU(),
            nn.Linear(512, 4),
            nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the convolutional backbone and all three heads.

        Args:
            x: Image batch shaped (batch, 1, 48, 60).

        Returns:
            ``(confidence, class_probs, bbox)`` shaped (batch, 1), (batch, 10)
            and (batch, 4).
        """
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        # flatten(start_dim=1) keeps the batch axis intact. A mis-sized image
        # then fails in the first Linear layer instead of being silently split
        # into extra "batch" rows, as view(-1, cnn_size) would do.
        out = out.flatten(start_dim=1)

        return self.confidence(out), self.classifier(out), self.bbox(out)

In [None]:
def _localization_loss(pred_confidence, pred_class, pred_bbox, labels):
    """Return the summed confidence + bbox + class loss for one batch.

    ``labels`` columns are read as: 0 = presence target, 1-4 = box target,
    last = class index. Box and class terms only use rows whose presence
    target exceeds 0.5.
    """
    true_confidence = labels[:, 0]
    true_bbox = labels[:, 1:5]
    true_class = labels[:, -1].long()

    # reshape(-1) rather than squeeze(): squeeze() collapses a final batch of
    # size 1 to a 0-d tensor, which no longer matches true_confidence's shape.
    loss = F.binary_cross_entropy(pred_confidence.reshape(-1), true_confidence)

    has_object_mask = true_confidence > 0.5
    # With no objects in the batch the masked mean losses would be NaN and
    # poison the weights, so only add them when at least one object exists.
    if has_object_mask.any():
        loss = loss + F.mse_loss(pred_bbox[has_object_mask], true_bbox[has_object_mask])
        # The classifier head already ends in Softmax and outputs probabilities.
        # F.cross_entropy would apply log-softmax a second time, so take the log
        # here (clamped to avoid log(0)) and use the negative log-likelihood.
        log_probs = torch.log(pred_class[has_object_mask].clamp_min(1e-12))
        loss = loss + F.nll_loss(log_probs, true_class[has_object_mask])
    return loss


def train_localizer(model: nn.Module, optimizer, name: str):
    """Train ``model`` on the global ``train_loader`` for ``EPOCH_COUNT`` epochs.

    Writes a checkpoint after every epoch to ``models/train/{name}_e{epoch}.pt``,
    the final weights to ``models/{name}.pt`` and the per-epoch loss curve to
    ``assets/loss_{name}.png``.

    Args:
        model: Network returning ``(confidence, class_probs, bbox)``, e.g.
            ``LocalizationNetwork``.
        optimizer: Optimizer already bound to ``model.parameters()``.
        name: Run identifier used in log messages and output file names.
    """
    from pathlib import Path

    print(f"Training: {name}")

    # torch.save and savefig do not create missing directories.
    for directory in ("models/train", "assets"):
        Path(directory).mkdir(parents=True, exist_ok=True)

    losses = []

    model.train()
    for epoch in range(EPOCH_COUNT):

        total_loss = 0.0
        total_size = 0

        for images, labels in train_loader:
            optimizer.zero_grad()

            pred_confidence, pred_class, pred_bbox = model(images)
            loss = _localization_loss(pred_confidence, pred_class, pred_bbox, labels)
            loss.backward()

            optimizer.step()

            # `loss` is a per-batch mean. Weight it by batch size so the epoch
            # value is a sample-weighted average rather than sum(means) / N.
            batch_size = images.size(0)
            total_size += batch_size
            total_loss += loss.item() * batch_size

        epoch_loss = total_loss / total_size
        losses.append(epoch_loss)
        torch.save(model.state_dict(), f'models/train/{name}_e{epoch}.pt')
        print(f'Epoch {epoch + 1}/{EPOCH_COUNT}, \tLoss: {epoch_loss}')

    torch.save(model.state_dict(), f'models/{name}.pt')

    # Use a fresh figure per run. Drawing on the implicit current figure would
    # overlay every earlier run's curve when this is called several times.
    fig, ax = plt.subplots()
    ax.plot(losses)
    ax.set_title(f"Loss for {name}")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    fig.savefig(f"assets/loss_{name}.png")

def train_or_load_localizer(model: nn.Module, optimizer, name: str):
    """Load ``models/{name}.pt`` into ``model``, or train it if the file is missing.

    Both paths leave the model in eval mode, so callers see the same state
    whichever path ran.

    Args:
        model: Network to populate.
        optimizer: Used only when training is needed.
        name: Run identifier; selects the checkpoint file.
    """
    try:
        # map_location="cpu" lets a checkpoint saved on a GPU load on a
        # CPU-only machine; load_state_dict then copies onto the model's device.
        state_dict = torch.load(f'models/{name}.pt', map_location="cpu")
    except FileNotFoundError:
        train_localizer(model, optimizer, name)
    else:
        model.load_state_dict(state_dict)
        print(f"Loaded: {name}")
    model.eval()

In [None]:
def _build_and_fit(lr, name, **adam_options):
    """Create a fresh LocalizationNetwork and Adam optimizer, then train or load it."""
    network = LocalizationNetwork()
    adam = torch.optim.Adam(network.parameters(), lr=lr, **adam_options)
    train_or_load_localizer(network, adam, name)
    return network, adam


# Three runs that differ only in Adam hyperparameters; each run name encodes them.
v1, v1_optimizer = _build_and_fit(0.0001, "v1_adam_lr0.0001")
v2, v2_optimizer = _build_and_fit(0.001, "v2_adam_lr0.001")
v3, v3_optimizer = _build_and_fit(0.001, "v3_adam_lr0.001_wd0.001", weight_decay=0.001)