In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
import matplotlib.pyplot as plt
from math import ceil, floor
from datetime import datetime 
from dataclasses import dataclass
from torchvision import datasets, transforms
from torch.utils.data import Dataset as torchDataset
from src.models.localization.cnn_localizer import CNNLocalizer
from src.util.loss_funcs import localization_loss
import seaborn as sns
from src.util.transform_dataset import TransformDataset, get_transform


# Setup and constants

In [None]:
# Fix the seed of PyTorch's global random number generator.
torch.manual_seed(123)
# float32 is already PyTorch's default floating-point dtype; this line only makes that choice explicit.
torch.set_default_dtype(torch.float32) # TODO maybe remove
# Batch size passed to the train/val/test DataLoaders created further down.
batch_sizes = 128

# Model definition

done in separate file

## Localization


done in separate file

### Load data and preprocessing

done in separate file

In [None]:
def intersection(bb1, bb2):
    """
    Compute the overlap area of two axis-aligned bounding boxes.

    Each box is indexed as (x_center, y_center, width, height): the box
    spans ``bb[0] +/- bb[2]/2`` horizontally and ``bb[1] +/- bb[3]/2``
    vertically.

    Returns 0 when the boxes do not overlap (boxes that merely touch count
    as non-overlapping), otherwise the area of the shared rectangle.
    """
    left = max(bb1[0] - bb1[2]/2, bb2[0] - bb2[2]/2)
    right = min(bb1[0] + bb1[2]/2, bb2[0] + bb2[2]/2)
    # `top` is the larger of the two lower y-bounds and `bot` the smaller of
    # the two upper y-bounds, so the boxes overlap vertically iff top < bot.
    top = max(bb1[1] - bb1[3]/2, bb2[1] - bb2[3]/2)
    bot = min(bb1[1] + bb1[3]/2, bb2[1] + bb2[3]/2)

    # The vertical test was previously inverted (`bot >= top`), which reported
    # zero area for every genuinely overlapping pair of boxes.
    if left >= right or top >= bot:
        return 0

    width = right - left
    height = bot - top

    return width*height
   
def IoU(bb1, bb2):
    """
    Intersection-over-union of two boxes given as (x_center, y_center, width, height).

    The union is the sum of both box areas minus their intersection.
    Returns 0.0 when the union is not positive (e.g. two zero-area boxes),
    instead of dividing by zero / producing NaN.
    """
    intersect_area = intersection(bb1, bb2)
    union_area = bb1[2]*bb1[3] + bb2[2] * bb2[3] - intersect_area
    # Degenerate boxes have no meaningful overlap ratio; report zero overlap.
    if union_area <= 0:
        return 0.0
    return intersect_area / union_area
   
def compute_IoU_localization(model, loader, preprocessor):
    """
    Mean IoU of the model's predicted boxes over every sample in `loader`.

    For each (images, labels) batch the model's predictions are paired with
    the labels, and entries 1..4 of both are compared with `IoU`.
    `preprocessor` is accepted but not used.
    """
    per_sample_scores = [
        IoU(prediction[1:5], label[1:5])
        for batch_images, batch_labels in loader
        for prediction, label in zip(model.predict(batch_images), batch_labels)
    ]
    return torch.mean(torch.Tensor(per_sample_scores))


def compute_accuracy_localization(model, loader, preprocessor):
    """
    Fraction of samples in `loader` that the model classifies correctly.

    A sample counts as correct when the argmax over the prediction's entries
    from index 5 onward equals the label's entry 5 AND the sign test on the
    prediction's entry 0 (`> 0`) equals the label's entry 0.
    `preprocessor` is accepted but not used.
    """
    n_correct = 0
    n_seen = 0
    for batch_images, batch_labels in loader:
        predictions = model.predict(batch_images)
        for prediction, label in zip(predictions, batch_labels):
            n_seen += 1
            class_matches = torch.argmax(prediction[5:]) == label[5]
            presence_matches = (prediction[0] > 0) == label[0]
            n_correct += class_matches and presence_matches

    return n_correct / n_seen

In [None]:
# Load the three splits. The transform is derived from the training tensors
# only and then reused unchanged for the validation and test splits.
train = torch.load("data/localization_train.pt")
train_transform = get_transform(train.tensors[0])
train = TransformDataset(train, train_transform)
val = TransformDataset(torch.load("data/localization_val.pt"), train_transform)
test = TransformDataset(torch.load("data/localization_test.pt"), train_transform)

# TODO seed data loaders

# One loader per split, all with the same batch size and no shuffling.
train_loader, val_loader, test_loader = (
    torch.utils.data.DataLoader(split, batch_size=batch_sizes, shuffle=False)
    for split in (train, val, test)
)


In [None]:
import optuna

def objective(trial: optuna.trial.Trial):
    """
    Optuna objective: train a CNNLocalizer with sampled hyperparameters and
    score it on the validation loader.

    Args:
        trial: the live Optuna trial used to sample `learning_rate`
            (log-uniform in [1e-4, 1e-1]) and `num_epochs` (integer in [5, 20]).

    Returns:
        float: the mean of the validation IoU and the validation accuracy.
    """
    learning_rate = trial.suggest_float("learning_rate", 1e-4, 1e-1, log=True)
    num_epochs = trial.suggest_int("num_epochs", 5, 20)

    model = CNNLocalizer(loss_fn=localization_loss, learning_rate=learning_rate, num_epochs=num_epochs)
    model.fit(train_loader)

    IoU_score = compute_IoU_localization(model, val_loader, None)
    accuracy_score = compute_accuracy_localization(model, val_loader, None)

    # Hand Optuna a plain Python float rather than a 0-d tensor.
    return float((IoU_score + accuracy_score) / 2)

# Seed the sampler so the sequence of suggested hyperparameters is repeatable
# across runs (TPESampler is Optuna's default single-objective sampler; the
# seed matches the one used for torch.manual_seed).
study = optuna.create_study(direction="maximize", sampler=optuna.samplers.TPESampler(seed=123))
study.optimize(objective, n_trials=20)

print(f'Best validation score: {study.best_value}')
print(f'Best params: {study.best_params}')

### Training

In [None]:
def draw(img, out, target):
    """
    Plot an image with the predicted (red) and target (green) bounding boxes.

    Args:
        img: image tensor; a leading singleton dimension is squeezed away
            before plotting, so the remaining array must be 2-D (height, width).
        out: model output whose first five entries are (p, x, y, w, h);
            x, y are the box centre and all four box values are scaled by the
            image width/height, i.e. expected as fractions of the image size.
        target: label with the same layout as `out`.

    The text representation of `out` and `target` is written just below the image.
    """
    _, xo, yo, wo, ho = out[0:5]
    _, xt, yt, wt, ht = target[0:5]

    fig, ax = plt.subplots()

    img = img.squeeze(0).numpy()
    # Take the pixel dimensions from the image itself rather than hard-coding
    # 60x48, so the boxes stay aligned for other image sizes.
    img_height, img_width = img.shape[:2]
    ax.imshow(img, cmap="gray")
    ax.axis("off")

    boxes = (((xo, yo, wo, ho), 'r'), ((xt, yt, wt, ht), 'g'))
    for (x, y, w, h), colour in boxes:
        # Convert centre/size fractions to the top-left corner and size in pixels;
        # float() keeps matplotlib independent of the tensor type.
        corner = (float((x - w/2) * img_width), float((y - h/2) * img_height))
        ax.add_patch(plt.Rectangle(corner, float(w * img_width), float(h * img_height),
                                   linewidth=3, edgecolor=colour, facecolor='none'))

    # Place the caption 5 pixels below the bottom edge of the image.
    ax.text(0, img_height + 5, f"{out},\n{target}")

In [None]:
# Retrain with the best hyperparameters found by the Optuna study.
model = CNNLocalizer(loss_fn=localization_loss, learning_rate=study.best_params["learning_rate"], num_epochs=study.best_params["num_epochs"])
model.fit(train_loader)
# Visualise predictions on the first validation batch.
images, labels = next(iter(val_loader))
outs = model.predict(images)
# Show at most ten examples; bounding by the batch length avoids an IndexError
# when the batch holds fewer than ten samples.
for i in range(min(10, len(images))):
    draw(images[i], outs[i].cpu(), labels[i])

### Predictions

### Model selection and evaluation