In [None]:
import torch
import matplotlib.pyplot as plt
from src.models.localization.cnn_localizer import CNNLocalizer
from src.util.loss_funcs import localization_loss
from src.util.transform_dataset import TransformDataset, get_transform
from torch.nn import Sigmoid
import torch.nn.functional as F
import numpy as np

In [None]:
# Sanity check: can PyTorch's MPS (Apple-silicon GPU) backend be used on this machine?
mps_available = torch.backends.mps.is_available()
mps_available

# Setup and constants

In [None]:
# Global seed for reproducibility; seed every RNG library the notebook imports.
SEED = 123
torch.manual_seed(SEED)
np.random.seed(SEED)
# float32 is already PyTorch's initial default dtype; this only guards against an
# earlier change of the default. TODO maybe remove
torch.set_default_dtype(torch.float32)
batch_size = 128

# Model definition

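The model is defined in a separate module (`src/models/localization/cnn_localizer.py`), imported at the top of this notebook.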

## Localization


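The localization model (`CNNLocalizer`) and its loss (`localization_loss`) are implemented in separate modules under `src/`; see the imports at the top of this notebook.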

### Load data and preprocessing

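Preprocessing is implemented in a separate module (`src/util/transform_dataset.py`); the code below loads the saved data splits and wraps them with that transform.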

In [None]:
def intersection(bb1, bb2):
    """
    Area of overlap between two axis-aligned boxes given in centre format.

    Each box is indexable as (x_center, y_center, width, height).
    Returns 0 when the boxes do not overlap (boxes that only touch along an
    edge count as non-overlapping).
    """
    # Horizontal extent of the overlap: rightmost left edge .. leftmost right edge.
    left = max(bb1[0] - bb1[2]/2, bb2[0] - bb2[2]/2)
    right = min(bb1[0] + bb1[2]/2, bb2[0] + bb2[2]/2)
    # Vertical extent: larger of the (y - h/2) edges .. smaller of the (y + h/2) edges.
    # This holds whether the y axis points up or down.
    low = max(bb1[1] - bb1[3]/2, bb2[1] - bb2[3]/2)
    high = min(bb1[1] + bb1[3]/2, bb2[1] + bb2[3]/2)

    if left >= right or low >= high:
        return 0

    return (right - left) * (high - low)
   
def IoU(bb1, bb2):
    """
    Intersection-over-union of two centre-format boxes (x, y, w, h).

    Returns 0.0 when the union area is not positive (e.g. two zero-size boxes,
    or a predicted box with a negative width/height) rather than producing a
    nan/inf or a negative ratio.
    """
    intersect_area = intersection(bb1, bb2)
    union_area = bb1[2]*bb1[3] + bb2[2]*bb2[3] - intersect_area
    if union_area <= 0:
        return 0.0
    return intersect_area / union_area
   
def compute_IoU_localization(model, loader, preprocessor, detection_threshold=0.5):
    """
    Compute the mean per-sample localization score of `model` over `loader`.

    For each sample:
      * if target[0] is truthy (object present), the score is the IoU between
        the predicted box pred[1:5] and the target box target[1:5];
      * otherwise the score is 1.0 when the model predicts "no object"
        (sigmoid(pred[0]) <= detection_threshold) and 0.0 when it predicts one.

    Args:
        model: object whose ``predict(images)`` returns one output row per image.
        loader: iterable of ``(images, labels)`` batches.
        preprocessor: unused; kept for interface compatibility.
        detection_threshold: probability above which a detection is predicted.

    Returns:
        0-d float tensor with the mean score (NaN if the loader yields no samples).
    """
    scores = []
    with torch.no_grad():
        for images, labels in loader:
            # Predictions may live on an accelerator; labels come from the loader on CPU.
            out = model.predict(images).cpu()
            for pred, target in zip(out, labels):
                predicted_detection = torch.sigmoid(pred[0]).item() > detection_threshold
                if target[0]:
                    score = IoU(pred[1:5], target[1:5])
                else:
                    score = not predicted_detection
                # Normalise tensor / bool / int scores to a plain float.
                scores.append(float(score))

    return torch.mean(torch.tensor(scores, dtype=torch.float32))


def compute_accuracy_localization(model, loader, preprocessor):
    """
    Compute accuracy of the model on the given dataset
    """
    accuracy_scores = []
    for images, labels in loader:
        out = model.predict(images)
        for pred, target in zip(out,labels):
            pred_class = torch.argmax(pred[5:])
            predicted_detection = F.sigmoid(pred[0]).item() > 0.5
            accuracy_scores.append(pred_class == target[5] and predicted_detection == target[0] or not target[0] and not predicted_detection)
            
    return torch.mean(torch.Tensor(accuracy_scores))

In [None]:
# Load the data
train = torch.load("data/localization_train.pt", weights_only=False)
train_transform = get_transform(train.tensors[0])
train = TransformDataset(train, train_transform)

val = torch.load("data/localization_val.pt", weights_only=False)
val = TransformDataset(val, train_transform)

test = torch.load("data/localization_test.pt", weights_only=False)
test = TransformDataset(test, train_transform)

# TODO seed data loaders

train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=False)
val_loader = torch.utils.data.DataLoader(val, batch_size=batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=False)


In [None]:
from itertools import product

learning_rates = [1e-2, 1e-1]
epochs = [0, 5, 10]  # num_epochs=0 is presumably an untrained baseline — confirm in CNNLocalizer

# Validation score of each trained model: mean of IoU score and accuracy.
models = {}
# (learning_rate, num_epochs) each model was trained with, so the selected model can be reported.
model_configs = {}

for learning_rate, num_epochs in product(learning_rates, epochs):
    # Re-seed per configuration so each run is reproducible on its own and does not
    # depend on which configurations ran before or on re-running this cell.
    # (Assumes CNNLocalizer draws its randomness from torch's global RNG.)
    torch.manual_seed(123)
    model = CNNLocalizer(loss_fn=localization_loss, learning_rate=learning_rate, num_epochs=num_epochs)
    model.fit(train_loader)

    IoU_score = compute_IoU_localization(model, val_loader, None)
    accuracy_score = compute_accuracy_localization(model, val_loader, None)

    print(f'Learning rate: {learning_rate}, num_epochs: {num_epochs}')
    print(f'IoU score: {IoU_score}')
    print(f'Accuracy score: {accuracy_score}')
    models[model] = (IoU_score + accuracy_score) / 2
    model_configs[model] = (learning_rate, num_epochs)

### Training

In [None]:
def draw(img, out, target):
    """
    Show an image with its predicted box (red) and ground-truth box (green).

    The first five entries of `out` and `target` are read as
    (presence, x_center, y_center, width, height). Box coordinates are scaled
    by the image width/height in pixels, i.e. they are treated as fractions of
    the image size. The raw `out` and `target` vectors are printed below the image.

    Args:
        img: single-channel image tensor, (1, H, W) or (H, W) (assumed grayscale).
        out: model output row for this image.
        target: label row for this image.
    """
    # Convert 0-d tensors to plain floats for matplotlib.
    _, xo, yo, wo, ho = (float(v) for v in out[0:5])
    _, xt, yt, wt, ht = (float(v) for v in target[0:5])

    fig, ax = plt.subplots()

    img = img.detach().cpu().squeeze(0).numpy()
    height, width = img.shape[-2:]
    ax.imshow(img, cmap="gray")
    ax.axis("off")

    rect_out = plt.Rectangle(((xo - wo/2) * width, (yo - ho/2) * height), wo * width, ho * height,
                             linewidth=3, edgecolor='r', facecolor='none')
    rect_target = plt.Rectangle(((xt - wt/2) * width, (yt - ht/2) * height), wt * width, ht * height,
                                linewidth=3, edgecolor='g', facecolor='none')

    ax.add_patch(rect_out)
    ax.add_patch(rect_target)
    # Place the raw vectors just below the bottom edge of the image.
    ax.text(0, height + 5, f"{out},\n{target}")

### Predictions

In [None]:
# Select the model with the best combined validation score and visualize its
# predictions on (up to) the first 20 images of the first validation batch.
best_model = max(models, key=models.get)
images, labels = next(iter(val_loader))
outs = best_model.predict(images).cpu()
for i in range(min(20, len(images))):
    draw(images[i], outs[i], labels[i])

### Model selection and evaluation