# Common

### Imports

In [464]:
import torch
import torch.nn as nn

### Config

In [465]:
BATCH_SIZE = 64
EPOCH_COUNT = 5

torch.manual_seed(420)

CLASS_COUNT = 10 # One for each digit
PARAMS_PER_PRED = 5 + CLASS_COUNT # confidence, x, y, width, height, class probabilities

CONFIDENCE_THRESHOLD = 0.5
ACCURACY_WEIGHT = 0.5
IOU_WEIGHT = 1 - ACCURACY_WEIGHT

### Utilities

In [466]:
def load_dataset(name:str):
    train = torch.load(f'data/{name}_train.pt')
    val = torch.load(f'data/{name}_val.pt')
    test = torch.load(f'data/{name}_test.pt')

    print(f"Dataset '{name}'")
    print(f"Training size:   {len(train)}")
    print(f"Validation size: {len(val)}")
    print(f"Test size:       {len(test)}")

    train = torch.utils.data.DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)

    return train, val, test

def split_label(label_tensor):
    confidence = label_tensor[:, 0:1]
    box = label_tensor[:, 1:5]
    classes = label_tensor[:, 5:]

    return confidence, box, classes

def get_iou(box1:[float], box2:[float]):
    X, Y, W, H = 0, 1, 2, 3

    # Calculate area of intersection rectangle
    intersection_width = min(box1[X] + box1[W], box2[X] + box2[W]) - max(box1[X], box2[X])
    intersection_height = min(box1[Y] + box1[H], box2[Y] + box2[H]) - max(box1[Y], box2[Y])
    intersection_area = intersection_width * intersection_height

    # Calculate areas of both boxes
    box1_area = box1[W] * box1[H]
    box2_area = box2[W] * box2[H]

    # Calculate union area
    union_area = box1_area + box2_area - intersection_area

    if union_area < 0:
        return 0

    # Calculate IoU
    return intersection_area / union_area
    

# Object localization
Localize and classify images of digits.
- Image dimensions are `height=48`, `wight=60` and `channels=1`
- Each image contains exactly **one** digit

## Load localization datasets

In [467]:
loc_train, loc_val, loc_test = load_dataset('localization')

Dataset 'localization'
Training size:   59400
Validation size: 6600
Test size:       11000


## Define networks

In [468]:
class CnnV1(nn.Module):
    """
    Expected input image to be 48x60x3.
    """

    def __init__(self):
        super(CnnV1, self).__init__()
        
        # Data = 48x60x3

        self.l1_conv = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=3, stride=1, padding=1)
        self.l2_pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        # Data = 24x30x10

        self.l3_conv = nn.Conv2d(in_channels=10, out_channels=10, kernel_size=3, stride=1, padding=1)
        self.l4_pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        
        # Data = 12x15x10

        self.l5_conv = nn.Conv2d(in_channels=10, out_channels=10, kernel_size=3, stride=1, padding=1)
        self.l6_pool = nn.MaxPool2d(kernel_size=(2, 3), stride=(2, 3), padding=0)

        # Data = 6x5x10

        self.l7_conv = nn.Conv2d(in_channels=10, out_channels=10, kernel_size=3, stride=1, padding=1)
        
        # Data = 6x5x10

        self.l8_fc = nn.Linear(in_features=6*5*10, out_features=100)
        self.l9_fc = nn.Linear(in_features=100, out_features=PARAMS_PER_PRED)

    def forward(self, x:torch.Tensor) -> torch.Tensor:

        x = torch.relu(self.l1_conv(x))
        x = self.l2_pool(x)

        x = torch.relu(self.l3_conv(x))
        x = self.l4_pool(x)

        x = torch.relu(self.l5_conv(x))
        x = self.l6_pool(x)

        x = torch.relu(self.l7_conv(x))

        x = x.view(-1, 6*5*10)

        x = torch.relu(self.l8_fc(x))

        x = torch.relu(self.l9_fc(x))

        return x

## Define training

In [469]:
loss_fn_confidence = nn.BCEWithLogitsLoss()
loss_fn_class = nn.CrossEntropyLoss()
loss_fn_box = nn.MSELoss()

def loss_fn(y_true, y_pred):
    # Extract true values
    y_true_confidence, y_true_box, y_true_classes = split_label(y_true)

    # Convert class labels to one-hot
    y_true_classes = torch.nn.functional.one_hot(y_true_classes.long(), num_classes=CLASS_COUNT).float()
    
    # Extract predicted values
    y_pred_confidence, y_pred_box, y_pred_classes = split_label(y_pred)

    # Calculate confidence loss
    loss = loss_fn_confidence(y_pred_confidence, y_true_confidence)

    # Aggregate loss for each label in the batch
    for i in range(y_true.shape[0]):

        contains_object = y_true_confidence[i].item() >= CONFIDENCE_THRESHOLD

        # Only add class and box loss if the label contains an object

        if contains_object:
            loss_class = loss_fn_class(y_pred_classes[i,:], y_true_classes[i,0])
            loss_box = loss_fn_box(y_pred_box[i,:], y_true_box[i,:])
            loss = loss + loss_class + loss_box

    return loss

def train(model):
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for epoch in range(EPOCH_COUNT):
        aggregate_loss = 0.0

        for X, Y_true in loc_train:
            optimizer.zero_grad()

            Y_pred = model(X)
            
            loss = loss_fn(Y_true, Y_pred)

            loss.backward()

            optimizer.step()

            aggregate_loss += loss.item()

        print(f"Epoch {epoch+1}/{EPOCH_COUNT} - Loss: {aggregate_loss}")

def validate(model, dataset):
    with torch.no_grad():

        correct_count = 0
        total_count = 0

        aggregate_iou = 0.0
        aggregate_iou_count = 0

        for x, y_true in dataset:

            y_true_confidence = y_true[0:1].item()
            y_true_box = y_true[1:5].tolist()
            y_true_class = y_true[5:].long().item()

            y_pred = model(x)
            y_pred_confidence = torch.sigmoid(y_pred[0, 0:1]).item()
            y_pred_box = y_pred[0, 1:5].tolist()
            y_pred_class = y_pred[0, 5:].argmax(dim=-1).tolist()

            total_count += 1

            if y_true_confidence < CONFIDENCE_THRESHOLD:
                if y_pred_confidence < CONFIDENCE_THRESHOLD:
                    correct_count += 1
                continue

            if y_pred_confidence < CONFIDENCE_THRESHOLD:
                continue

            if y_true_class == y_pred_class:
                correct_count += 1

            aggregate_iou_count += 1
            aggregate_iou += get_iou(y_true_box, y_pred_box)

        accuracy = correct_count / total_count
        iou = aggregate_iou / aggregate_iou_count
        performance = (accuracy * ACCURACY_WEIGHT) + (iou * IOU_WEIGHT)

        print(f"Validation - Accuracy: {accuracy}, IoU: {iou}, Performance: {performance}")

### Train models

In [470]:
cnnv1_1 = CnnV1()

train(cnnv1_1)
validate(cnnv1_1, loc_val)

Epoch 1/5 - Loss: 133675.81755542755
Epoch 2/5 - Loss: 133671.76979637146
Epoch 3/5 - Loss: 133671.40523147583
Epoch 4/5 - Loss: 133407.37353515625
Epoch 5/5 - Loss: 132834.73357963562
Validation - Accuracy: 0.08757575757575757, IoU: 1.2817654849965137, Performance: 0.6846706212861356
