In [None]:
import torch
from torch.utils.data import TensorDataset, DataLoader
from src.util.loss_funcs import detection_loss
from src.util.transform_dataset import TransformDataset, get_transform
import matplotlib.pyplot as plt
import numpy as np

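# Detection

Train grid-based CNN digit detectors (a 2×3 output grid over 48×60 images), sweep learning rate, weight decay and momentum, pick the best model by validation mAP@0.5, and plot its predictions.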


In [None]:
# Seed the global torch RNG so weight init and any DataLoader shuffling are reproducible.
torch.manual_seed(123)

# float32 is already PyTorch's initial default dtype, so this only matters if
# something earlier changed it.  TODO maybe remove
torch.set_default_dtype(torch.float32)

# Print tensors in full (no "..." summarisation) when inspecting labels/outputs.
torch.set_printoptions(profile="full")

batch_size = 128  # mini-batch size shared by the train/val/test DataLoaders

### Load data and preprocessing


In [None]:
H_in, W_in = 48, 60
H_out, W_out = 2, 3
CELL_WIDTH, CELL_HEIGHT = W_in / W_out, H_in / H_out


def get_cell(x, y):
    """Return the ``(row, col)`` grid cell containing a normalised point.

    Args:
        x, y: Point coordinates as fractions of the image width / height.

    Returns:
        Integer ``(row, col)`` with ``0 <= row < H_out`` and ``0 <= col < W_out``.
        Points on (or, through rounding, slightly beyond) the image border are
        clamped into the nearest valid cell instead of yielding an out-of-range index.
    """
    row = int((y * H_in) // CELL_HEIGHT)
    col = int((x * W_in) // CELL_WIDTH)
    # Without clamping, x == 1.0 / y == 1.0 would map to col W_out / row H_out.
    row = min(max(row, 0), H_out - 1)
    col = min(max(col, 0), W_out - 1)
    return row, col


def convert_Y_label(Y):
    """Encode one image's digit annotations into an ``(H_out, W_out, 6)`` grid target.

    Each entry of ``Y`` is unpacked as ``(p, x, y, w, h, c)`` with the box centre and
    size given as fractions of the whole image.  The digit is placed in the cell that
    contains its centre.  ``x``/``y`` become offsets inside that cell (fractions of a
    cell), and ``w``/``h`` are rescaled to cell units.  If two digits fall into the same
    cell, the later one overwrites the earlier one.  ``Y`` itself is never modified.

    Returns:
        A float32 tensor of shape ``(H_out, W_out, 6)``; empty cells are all zeros.
    """
    converted_Y = [[[0.0] * 6 for _ in range(W_out)] for _ in range(H_out)]

    for digit in Y:
        # float() copies each value out.  When ``digit`` is a tensor row, the unpacked
        # elements are views, so in-place ops like ``w *= W_out`` would write back
        # into the caller's annotations.
        p, x, y, w, h, c = (float(v) for v in digit)
        row, col = get_cell(x, y)
        x_cell = (x * W_in - col * CELL_WIDTH) / CELL_WIDTH
        y_cell = (y * H_in - row * CELL_HEIGHT) / CELL_HEIGHT
        converted_Y[row][col] = [p, x_cell, y_cell, w * W_out, h * H_out, c]

    return torch.tensor(converted_Y, dtype=torch.float32)


def revert_y_label(Y):
    """Decode an ``(H_out, W_out, 6)`` grid back into image-normalised boxes.

    Per-cell inverse of ``convert_Y_label``.  Cell-relative ``x``/``y`` offsets are
    mapped back to fractions of the image, and ``w``/``h`` go from cell units back to
    image units.  ``p`` and ``c`` pass through unchanged.  Every cell is decoded,
    including empty ones.

    ``Y`` is not modified.  The result is a new CPU float32 tensor of shape
    ``(H_out, W_out, 6)``.
    """
    reverted_y = []
    for row in range(H_out):
        reverted_row = []
        for col in range(W_out):
            # Copy values out; in-place ops on the unpacked elements would mutate ``Y``.
            p, x, y, w, h, c = (float(v) for v in Y[row][col])
            reverted_row.append(
                (
                    p,
                    (x + col) * CELL_WIDTH / W_in,
                    (y + row) * CELL_HEIGHT / H_in,
                    w / W_out,
                    h / H_out,
                    c,
                )
            )
        reverted_y.append(reverted_row)
    return torch.tensor(reverted_y, dtype=torch.float32)

In [None]:
train_true = torch.load("data/list_y_true_train.pt")
val_true = torch.load("data/list_y_true_val.pt")
test_true = torch.load("data/list_y_true_test.pt")

# NOTE(review): ``weights_only=False`` is needed to restore the pickled TensorDataset,
# but it unpickles arbitrary objects — only load these files from a trusted source.
train_images = torch.load("data/detection_train.pt", weights_only=False).tensors[0]
val_images = torch.load("data/detection_val.pt", weights_only=False).tensors[0]
test_images = torch.load("data/detection_test.pt", weights_only=False).tensors[0]


def _encode_labels(annotations):
    """Encode every image's annotation list with ``convert_Y_label``.

    Returns a float tensor of shape ``(len(annotations), H_out, W_out, 6)``.
    """
    encoded = torch.zeros(len(annotations), H_out, W_out, 6)
    for j, image_annotations in enumerate(annotations):
        encoded[j] = convert_Y_label(image_annotations)
    return encoded


converted_data = [_encode_labels(split) for split in (train_true, val_true, test_true)]
train_labels, val_labels, test_labels = converted_data

# NOTE(review): the transform is built from the training images only and reused for
# val/test — presumably to share normalisation statistics; confirm in src.util.
transforms = get_transform(train_images)


def _make_loader(images, labels, shuffle):
    """Wrap an (images, labels) pair in the shared transform and a batched DataLoader."""
    return DataLoader(
        TransformDataset(TensorDataset(images, labels), transforms),
        batch_size=batch_size,
        shuffle=shuffle,
    )


# Reshuffle training batches every epoch.  Keep val/test order fixed so evaluation
# and the plotted validation batch are deterministic.
train_loader = _make_loader(train_images, train_labels, shuffle=True)
val_loader = _make_loader(val_images, val_labels, shuffle=False)
test_loader = _make_loader(test_images, test_labels, shuffle=False)

In [None]:
import torchmetrics.detection


def prune_data(pred, target, threshold=0.5):
    """Decode one image's predicted and target grids and keep only occupied cells.

    Args:
        pred: Predicted grid for one image, in the layout ``revert_y_label`` decodes
            (``p`` at channel 0).
        target: Ground-truth grid for the same image, same layout.
        threshold: A predicted cell is kept when its ``p`` exceeds this value.

    Returns:
        ``(pruned_pred, pruned_target)``.  Each is a ``(k, 6)`` tensor of kept cells
        in row-major cell order, or ``None`` when no cell is kept.  Target cells are
        kept when their ``p`` is non-zero.
    """
    reverted_pred = revert_y_label(pred).view(-1, 6)
    reverted_target = revert_y_label(target).view(-1, 6)

    # Masks span all H_out * W_out flattened cells, not just the first grid row count.
    pruned_pred = reverted_pred[reverted_pred[:, 0] > threshold]
    pruned_target = reverted_target[reverted_target[:, 0] != 0]

    return (
        pruned_pred if len(pruned_pred) else None,
        pruned_target if len(pruned_target) else None,
    )


def _to_metric_dict(pruned, with_scores):
    """Turn pruned ``(k, 6)`` rows (or ``None``) into a torchmetrics box/label dict."""
    if pruned is None:
        # Empty image: shapes (0, 4) boxes, (0,) scores and (0,) long labels.
        pruned = torch.empty(0, 6)
    result = {"boxes": pruned[:, 1:5], "labels": pruned[:, 5].long()}
    if with_scores:
        result["scores"] = pruned[:, 0]
    return result


def compute_mAP(model, loader):
    """Compute mAP at IoU 0.5 of ``model`` over every image yielded by ``loader``.

    Predictions are thresholded and decoded with ``prune_data``.  Boxes are in
    ``cxcywh`` format as fractions of the image, and the cell's ``p`` is used as the
    detection score.

    Returns:
        The dict produced by ``MeanAveragePrecision.compute()`` (e.g. ``"map"``).
    """
    mAP_metric = torchmetrics.detection.MeanAveragePrecision(
        box_format="cxcywh",
        iou_type="bbox",
        iou_thresholds=[0.5],
        backend="pycocotools",
    )

    for X, targets in loader:
        preds = model.predict(X).cpu()
        # Fresh lists per batch: the metric accumulates across update() calls itself,
        # so re-sending earlier batches would count those images multiple times.
        pred_dicts, target_dicts = [], []
        for pred, target in zip(preds, targets):
            pruned_preds, pruned_target = prune_data(pred, target)
            pred_dicts.append(_to_metric_dict(pruned_preds, with_scores=True))
            target_dicts.append(_to_metric_dict(pruned_target, with_scores=False))

        mAP_metric.update(preds=pred_dicts, target=target_dicts)

    return mAP_metric.compute()

### Training


In [None]:
from src.models.detection.cnn_detector import CNNDetector
from itertools import product

learning_rates = [1e-2, 1e-3, 1e-4]
weight_decays = [1e-2, 1e-3]
momentums = [0.8, 0.9]
max_epochs = 300
models = {}  # fitted model -> validation mAP@0.5 as a float (higher is better)
mAP_results = {}  # fitted model -> full MeanAveragePrecision.compute() dict

for learning_rate, weight_decay, momentum in product(
    learning_rates, weight_decays, momentums
):
    model = CNNDetector(
        loss_fn=detection_loss,
        learning_rate=learning_rate,
        max_epochs=max_epochs,
        weight_decay=weight_decay,
        momentum=momentum,
    )
    # delta/patience are forwarded to CNNDetector.fit (presumably early-stopping settings).
    model.fit(train_loader, val_loader, delta=0.03, patience=3)

    mAP_score = compute_mAP(model, val_loader)
    print(
        f"Model with lr: {learning_rate} — weight decay: {weight_decay} — momentum: {momentum}\nmAP score: {mAP_score}"
    )
    # compute() returns a dict of tensors.  Dicts are not orderable, so store the
    # scalar "map" to let max(models, key=models.get) select the best model.
    mAP_results[model] = mAP_score
    models[model] = float(mAP_score["map"])

### Prediction


In [None]:
def draw_multiple(img, out, target, threshold=0.5):
    """Plot an image with its cell grid, predicted boxes (red) and target boxes (green).

    Args:
        img: Image tensor with a leading singleton dimension (squeezed before display).
        out: Predicted grid for ``img``, in the layout ``revert_y_label`` decodes.
        target: Ground-truth grid in the same layout.
        threshold: A predicted box is drawn when its ``p`` exceeds this value.

    Returns:
        The matplotlib ``Axes`` that was drawn on.
    """
    out = revert_y_label(out)
    target = revert_y_label(target)
    _, ax = plt.subplots()
    img = img.squeeze(0).numpy()
    ax.imshow(img, cmap="gray")

    # Grid lines at the cell boundaries used by the label encoding.
    ax.set_xticks(np.arange(0, img.shape[1], CELL_WIDTH))
    ax.set_yticks(np.arange(0, img.shape[0], CELL_HEIGHT))
    ax.grid(True, color="b", linestyle="-", linewidth=2)
    ax.set_xticklabels([])
    ax.set_yticklabels([])

    def _box(x, y, w, h, color):
        # (x, y, w, h) are a centre-format box in image fractions; matplotlib wants
        # the lower-left corner and size in pixels.
        return plt.Rectangle(
            ((x - w / 2) * W_in, (y - h / 2) * H_in),
            w * W_in,
            h * H_in,
            linewidth=3,
            edgecolor=color,
            facecolor="none",
        )

    for row in range(out.shape[0]):
        for col in range(out.shape[1]):
            po, xo, yo, wo, ho, _ = out[row][col]
            pt, xt, yt, wt, ht, _ = target[row][col]
            if po > threshold:
                ax.add_patch(_box(xo, yo, wo, ho, "r"))
            # Empty target cells decode to a zero-size box at the cell corner, which
            # would still render as a dot with linewidth 3; skip them.
            if pt != 0:
                ax.add_patch(_box(xt, yt, wt, ht, "g"))
    return ax

In [None]:
def _scalar_map(score):
    # MeanAveragePrecision.compute() returns a dict, which max() cannot compare;
    # rank on its "map" entry, and accept an already-scalar score as-is.
    return float(score["map"]) if isinstance(score, dict) else float(score)


best_model = max(models, key=lambda m: _scalar_map(models[m]))
images, labels = next(iter(val_loader))
outs = best_model.predict(images).cpu()
n_plots = min(20, len(images))  # the first validation batch may hold fewer than 20 images
for i in range(n_plots):
    draw_multiple(images[i], outs[i], labels[i])