In [None]:
import torch
import matplotlib.pyplot as plt
from src.models.localization.cnn_localizer import CNNLocalizer
from src.util.loss_funcs import localization_loss
from src.util.transform_dataset import TransformDataset, get_transform
from src.util.shrunkDataLoader import ShrunkDataLoader
from itertools import product
from src.models.localization import (
    cnn_network_1,
    cnn_network_2,
    resnet,
)

# Setup and constants


In [None]:
# Mini-batch size shared by the train / validation / test DataLoaders.
batch_size = 128

# float32 as the default dtype (the original note says this is for the MPS
# backend) and a fixed global seed for torch's default random generator.
torch.set_default_dtype(torch.float32)
torch.manual_seed(123)

# Performance metric computation


In [None]:
def intersection(bb1, bb2):
    """
    Overlap area of two axis-aligned boxes.

    Each box is indexed as (x_center, y_center, width, height). Returns 0 when
    the boxes are disjoint or only touch along an edge.
    """
    cx1, cy1, w1, h1 = bb1[0], bb1[1], bb1[2], bb1[3]
    cx2, cy2, w2, h2 = bb2[0], bb2[1], bb2[2], bb2[3]

    # The overlap rectangle is bounded by the innermost edge of each pair.
    x_lo = max(cx1 - w1 / 2, cx2 - w2 / 2)
    x_hi = min(cx1 + w1 / 2, cx2 + w2 / 2)
    y_lo = max(cy1 - h1 / 2, cy2 - h2 / 2)
    y_hi = min(cy1 + h1 / 2, cy2 + h2 / 2)

    no_overlap = x_lo >= x_hi or y_lo >= y_hi
    return 0 if no_overlap else (x_hi - x_lo) * (y_hi - y_lo)


def IoU(bb1, bb2):
    """
    Intersection over union of two boxes.

    Each box is indexed as (x_center, y_center, width, height). Returns 0.0
    when the union area is zero (both boxes are degenerate) instead of
    dividing by zero.
    """
    intersect_area = intersection(bb1, bb2)
    union_area = bb1[2] * bb1[3] + bb2[2] * bb2[3] - intersect_area

    # Two zero-area boxes have nothing to overlap. A plain division here would
    # raise ZeroDivisionError for Python numbers and produce NaN for tensors,
    # and a single NaN would poison any mean taken over the scores.
    if union_area == 0:
        return 0.0

    return intersect_area / union_area


def compute_IoU_localization(model, loader):
    """
    Mean IoU score of `model` over every sample yielded by `loader`.

    Index 0 of each row is treated as the object-present flag and indices 1-4
    as the bounding box. Samples whose label has an object are scored with
    IoU; samples without one are scored by whether the predicted flag is
    also 0.
    """
    scores = []
    for batch_images, batch_labels in loader:
        predictions = model.predict(batch_images)
        for predicted, expected in zip(predictions, batch_labels):
            if expected[0]:
                score = IoU(predicted[1:5], expected[1:5])
            else:
                # No object in the label: full credit only if none was predicted.
                score = predicted[0] == False
            scores.append(score)

    return torch.mean(torch.Tensor(scores))


def compute_accuracy_localization(model, loader):
    """
    Fraction of samples from `loader` that `model` gets right.

    A sample counts as correct when the class entry (index 5) and the
    object-present flag (index 0) both match the label, or when label and
    prediction both report no object.
    """
    hits = []
    for batch_images, batch_labels in loader:
        predictions = model.predict(batch_images)
        hits.extend(
            (guess[5] == truth[5] and guess[0] == truth[0])
            or (not truth[0] and not guess[0])
            for guess, truth in zip(predictions, batch_labels)
        )

    return torch.mean(torch.Tensor(hits))

# Data loading and preprocessing


In [None]:
def _load_split(path):
    """Load one saved dataset split and print the shape of its first tensor."""
    # NOTE(review): weights_only=False unpickles arbitrary objects, so this
    # must only ever be pointed at trusted files.
    split = torch.load(path, weights_only=False)
    print(split.tensors[0].shape)
    return split


def _make_loader(dataset):
    """Wrap a dataset in a fixed-order DataLoader using the shared batch size."""
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)


# The transform is built from the training tensors only and then reused,
# unchanged, for the validation and test splits.
_train_raw = _load_split("data/localization_train.pt")
train_transform = get_transform(_train_raw.tensors[0])
train = TransformDataset(_train_raw, train_transform)

val = TransformDataset(_load_split("data/localization_val.pt"), train_transform)

test = TransformDataset(_load_split("data/localization_test.pt"), train_transform)


# TODO seed data loaders
train_loader = _make_loader(train)
val_loader = _make_loader(val)
test_loader = _make_loader(test)

# Model training and selection


In [None]:
def grid_search(params, train_loader, val_loader, max_epochs=300):
    """
    Train one CNNLocalizer per hyperparameter combination and return the best.

    Args:
        params: dict with the iterables "learning_rates", "weight_decays",
            "momentums" and "networks"; every combination of their values
            is trained.
        train_loader: first argument passed to `model.fit`.
        val_loader: second argument passed to `model.fit`; also the data the
            IoU and accuracy scores are computed on.
        max_epochs: epoch budget handed to each CNNLocalizer (default 300).

    Returns:
        (best_model, best_score), where the score is the mean of the
        validation IoU and the validation accuracy. Ties go to the
        combination that was trained first.

    Raises:
        ValueError: if the grid is empty, so no model was trained.
    """
    best_model = None
    best_score = None

    for learning_rate, weight_decay, momentum, network in product(
        params["learning_rates"],
        params["weight_decays"],
        params["momentums"],
        params["networks"],
    ):
        print("Starting training for:")
        print(f"Network: {network}")
        print(f"Learning rate: {learning_rate}")
        print(f"Weight decay: {weight_decay}")
        print(f"Momentum: {momentum}")

        model = CNNLocalizer(
            loss_fn=localization_loss,
            learning_rate=learning_rate,
            max_epochs=max_epochs,
            network=network,
            weight_decay=weight_decay,
            momentum=momentum,
        )
        model.fit(train_loader, val_loader)

        IoU_score = compute_IoU_localization(model, val_loader)
        accuracy_score = compute_accuracy_localization(model, val_loader)

        print(f"IoU score: {IoU_score}")
        print(f"Accuracy score: {accuracy_score}")
        score = (IoU_score + accuracy_score) / 2

        # Keep only the running best. Holding every trained model in a dict
        # until the search ends makes memory grow with the grid size, and all
        # but one of those models are discarded on return anyway.
        # Strict ">" keeps the earliest model on ties, as max() did.
        if best_model is None or score > best_score:
            best_model, best_score = model, score

    if best_model is None:
        raise ValueError("grid_search received an empty hyperparameter grid")

    return best_model, best_score

### Hyperparameters


In [None]:
# Search space for grid_search: every combination of the values below is trained.
params = dict(
    learning_rates=[1e-3, 1e-4, 1e-5],
    weight_decays=[1e-2, 0],
    momentums=[0.8, 0.9],
    # Candidate network classes.
    networks=[
        cnn_network_1.CNN1,
        cnn_network_2.CNN2,
        resnet.ResNet18Localization,
    ],
)

### Architecture Selection


In [None]:
# First pass: run the full grid against a ShrunkDataLoader built with
# fraction=0.4 (presumably a 40% subset of the training data; verify against
# ShrunkDataLoader) to pick a network.
shrunken_train_loader = ShrunkDataLoader(train_loader, fraction=0.4)
best_model, val_score = grid_search(params, shrunken_train_loader, val_loader)

best_architechture = best_model.get_params()["network"]
print(
    f"Best architechture: {best_architechture.__name__}",
    f"validation score: {val_score}",
    sep="\n",
)

# Restrict the search space to the winning network for the next grid search.
params.update(networks=[best_architechture])

### Hyperparameter tuning & model selection


In [None]:
# Second pass: grid search on the full training loader with the current
# contents of `params`.
best_model, val_score = grid_search(params, train_loader, val_loader)
print(
    f"Best model: {best_model.get_params()}",
    f"validation score: {val_score}",
    sep="\n",
)

In [None]:
# Score the selected model on the test loader with the same two metrics,
# combined the same way as during model selection.
test_IoU_score = compute_IoU_localization(best_model, test_loader)
test_accuracy_score = compute_accuracy_localization(best_model, test_loader)
test_score = (test_IoU_score + test_accuracy_score) / 2

print(
    "Model performance on unseen data:",
    f"IoU score: {test_IoU_score}",
    f"Accuracy score: {test_accuracy_score}",
    f"Combined score: {test_score}",
    sep="\n",
)

### Visualization


In [None]:
def draw(img, out, target):
    """
    Plot an image with the predicted (red) and target (green) bounding boxes.

    Args:
        img: image tensor; a leading dimension of size 1 is squeezed away
            before plotting (assumes a CPU tensor shaped (1, H, W); TODO
            confirm).
        out: prediction row indexed as (pc, x, y, w, h, class).
        target: label row with the same layout.

    Box values are scaled by the image width and height, with (x, y) taken as
    the box centre. A box is drawn only when its pc entry is greater than 0,
    so a label without an object no longer gets a stray rectangle. Both rows
    are printed as text below the image.
    """
    po, xo, yo, wo, ho = out[0:5]
    pt, xt, yt, wt, ht = target[0:5]
    _, ax = plt.subplots()

    img = img.squeeze(0).numpy()
    # Scale boxes by the actual image size instead of a hard-coded 60x48.
    img_h, img_w = img.shape[:2]
    ax.imshow(img, cmap="gray")
    ax.axis("off")

    def _box(x, y, w, h, color):
        """Build an unfilled rectangle patch from a centre-format box."""
        return plt.Rectangle(
            ((x - w / 2) * img_w, (y - h / 2) * img_h),
            w * img_w,
            h * img_h,
            linewidth=3,
            edgecolor=color,
            facecolor="none",
        )

    if po > 0:
        ax.add_patch(_box(xo, yo, wo, ho, "r"))
    if pt > 0:
        ax.add_patch(_box(xt, yt, wt, ht, "g"))

    # Caption sits just below the bottom edge of the image.
    ax.text(
        0,
        img_h + 5,
        f"Predicted — pc: {out[0]}, x: {out[1]}, y: {out[2]}, w: {out[3]}, h: {out[4]}, class: {out[5]}\nTarget — pc: {target[0]}, x: {target[1]}, y: {target[2]}, w: {target[3]}, h: {target[4]}, class: {target[5]}",
    )

In [None]:
# Visualise predictions for up to 20 samples of the first test batch.
images, labels = next(iter(test_loader))
outs = best_model.predict(images).cpu()
# Cap at the batch length so a short first batch does not raise IndexError.
n_examples = min(20, len(images))
for i in range(n_examples):
    draw(images[i], outs[i], labels[i])