In [None]:
import torch
from torch.utils.data import TensorDataset, DataLoader
from src.models.localization.cnn_localizer import CNNLocalizer
from src.util.loss_funcs import detection_loss
from src.util.transform_dataset import TransformDataset, get_transform
import matplotlib.pyplot as plt

# Detection

In [None]:
# Global settings: reproducible RNG, tensor dtype, and batch size for all loaders.
torch.manual_seed(123)
# float32 is already PyTorch's default dtype; set explicitly here. TODO: maybe remove.
torch.set_default_dtype(torch.float32)
# Print tensors in full instead of abbreviating large ones with "...".
torch.set_printoptions(profile="full")
batch_size = 128

### Load data and preprocessing

In [None]:
# Input image height / width.
H_in = 48
W_in = 60
# Detection grid: number of rows / columns of cells laid over the image.
H_out = 2
W_out = 3
# Size of one grid cell in image units (true division, so these are floats).
CELL_HEIGHT = H_in / H_out
CELL_WIDTH = W_in / W_out

def get_cell(x, y):
    """Return the (row, col) grid cell that contains a normalized point.

    Args:
        x: horizontal coordinate as a fraction of the image width
           (presumably in [0, 1]; it is scaled by W_in below).
        y: vertical coordinate as a fraction of the image height
           (scaled by H_in below).

    Returns:
        (row, col) as ints, clamped to [0, H_out - 1] x [0, W_out - 1].
    """
    row = int((y * H_in) // CELL_HEIGHT)
    col = int((x * W_in) // CELL_WIDTH)
    # Clamp so a point exactly on the bottom/right border (y == 1.0 or x == 1.0)
    # maps to the last cell instead of an index one past the grid.
    row = min(max(row, 0), H_out - 1)
    col = min(max(col, 0), W_out - 1)
    return row, col


def convert_Y_label(Y: torch.Tensor):
    """Convert one image's list of boxes into an (H_out, W_out, 6) grid target.

    Args:
        Y: iterable of rows, each unpacking into six values (p, x, y, w, h, c)
           where x, y, w, h are fractions of the image width/height
           (presumably box center and size — TODO confirm against the data).
           Y itself is not modified.

    Returns:
        Tensor of shape (H_out, W_out, 6) in the default dtype. Each box is
        written to the cell that contains (x, y). Inside that cell, x and y
        become offsets from the cell's top-left corner in cell units, and
        w, h are rescaled from image fractions to cell units. Cells without
        a box stay all-zero.

    Notes:
        If two boxes fall in the same cell, the later one overwrites the earlier.
        NOTE(review): rows with p == 0 (if the data contains padding rows) are
        placed like real boxes — confirm whether they should be skipped.
    """
    converted_Y = torch.zeros(H_out, W_out, 6)

    for digit in Y:
        # Copy to Python floats: unpacking a tensor row yields views, and
        # in-place ops such as `w *= W_out` on those views would silently
        # modify the caller's label tensor.
        p, x, y, w, h, c = (float(v) for v in digit)
        row, col = get_cell(x, y)

        x_cell = (x * W_in - col * CELL_WIDTH) / CELL_WIDTH
        y_cell = (y * H_in - row * CELL_HEIGHT) / CELL_HEIGHT
        w_cell = w * W_out  # == w * W_in / CELL_WIDTH
        h_cell = h * H_out  # == h * H_in / CELL_HEIGHT

        converted_Y[row, col] = torch.tensor([p, x_cell, y_cell, w_cell, h_cell, c])

    return converted_Y

def revert_y_label(Y: torch.Tensor):
    """Inverse of convert_Y_label: flatten a grid target back to image-level boxes.

    Args:
        Y: grid indexed as Y[row][col] -> six values (p, x, y, w, h, c), with
           H_out rows and W_out columns (the layout convert_Y_label produces).
           Y itself is not modified.

    Returns:
        Tensor of shape (H_out * W_out, 6), one row per cell in row-major
        order. x, y, w, h are expressed as fractions of the image
        width/height, matching convert_Y_label's input. Empty cells
        (p == 0) are included; callers filter on p.
    """
    boxes = []
    for row in range(H_out):
        for col in range(W_out):
            # float() copies the values, so no in-place op can touch the caller's grid.
            p, x, y, w, h, c = (float(v) for v in Y[row][col])
            # Undo convert_Y_label: cell offset -> image units -> image fraction.
            x_img = (x * CELL_WIDTH + col * CELL_WIDTH) / W_in
            y_img = (y * CELL_HEIGHT + row * CELL_HEIGHT) / H_in
            boxes.append((p, x_img, y_img, w / W_out, h / H_out, c))
    return torch.Tensor(boxes)

In [None]:
# Raw per-image box labels; converted to grid targets below.
train_true = torch.load("data/list_y_true_train.pt")
val_true = torch.load("data/list_y_true_val.pt")
test_true = torch.load("data/list_y_true_test.pt")

# These files hold pickled dataset objects (presumably TensorDatasets, given
# `.tensors[0]`). weights_only=False performs full unpickling, which can
# execute arbitrary code — only load these files from a trusted source.
train_images = torch.load("data/detection_train.pt", weights_only=False).tensors[0]
print(train_images.shape)
val_images = torch.load("data/detection_val.pt", weights_only=False).tensors[0]
test_images = torch.load("data/detection_test.pt", weights_only=False).tensors[0]

# Convert every image's box list into an (H_out, W_out, 6) grid target.
splits_true = (train_true, val_true, test_true)
converted_data = [torch.zeros(len(split), H_out, W_out, 6) for split in splits_true]
for grid, split in zip(converted_data, splits_true):
    for j, y_true in enumerate(split):
        grid[j] = convert_Y_label(y_true)

train_labels, val_labels, test_labels = converted_data

# The transform is built from the training images only and reused for val/test.
transforms = get_transform(train_images)

# Shuffle the training split so batch order differs between epochs (order is
# reproducible via the seed set above). Val/test stay ordered so evaluation
# and the plotted examples are deterministic.
train_loader = DataLoader(TransformDataset(TensorDataset(train_images, train_labels), transforms), batch_size=batch_size, shuffle=True)
val_loader = DataLoader(TransformDataset(TensorDataset(val_images, val_labels), transforms), batch_size=batch_size, shuffle=False)
test_loader = DataLoader(TransformDataset(TensorDataset(test_images, test_labels), transforms), batch_size=batch_size, shuffle=False)


### Training

In [None]:
from itertools import product

from src.models.detection.cnn_detector import CNNDetector

# Hyperparameter grid: one model is trained per (learning rate, epochs) pair.
learning_rates = [0.001]
epochs = [20]

# Intended to map each trained model to its validation score (mean of IoU and
# accuracy). NOTE(review): the scoring below is commented out, so this dict
# currently stays empty.
models = {}

for learning_rate, num_epochs in product(learning_rates, epochs):
    model = CNNDetector(
        loss_fn=detection_loss,
        learning_rate=learning_rate,
        num_epochs=num_epochs,
    )
    model.fit(train_loader)

    # TODO: re-enable validation scoring once the metric helpers are imported:
    # IoU_score = compute_IoU_localization(model, val_loader, None)
    # accuracy_score = compute_accuracy_localization(model, val_loader, None)
    # print(f'Learning rate: {learning_rate}, num_epochs: {num_epochs}')
    # print(f'IoU score: {IoU_score}')
    # print(f'Accuracy score: {accuracy_score}')
    # models[model] = (IoU_score+accuracy_score) / 2

### Prediction

In [None]:
def _box_rectangle(x, y, w, h, color):
    """Build an unfilled Rectangle from a center-format box given in image fractions."""
    return plt.Rectangle(
        ((x - w / 2) * W_in, (y - h / 2) * H_in),
        w * W_in,
        h * H_in,
        linewidth=3,
        edgecolor=color,
        facecolor="none",
    )


def draw_multiple(img, out, target, threshold=0.0):
    """Plot one image with its predicted (red) and ground-truth (green) boxes.

    Args:
        img: image tensor, squeezed along dim 0 before display — presumably
             (1, H, W) single-channel; TODO confirm.
        out: model output for this image, assumed to use the same
             (H_out, W_out, 6) grid layout as `target`.
        target: grid target as produced by convert_Y_label.
        threshold: a predicted box is drawn when its p exceeds this value.
            The default 0.0 keeps the original behavior. NOTE(review): if the
            model outputs probabilities rather than logits, a higher value
            (e.g. 0.5) is probably intended.

    Returns:
        (fig, ax) so callers can save or close the figure.
    """
    fig, ax = plt.subplots()

    ax.imshow(img.detach().cpu().squeeze(0).numpy(), cmap="gray")
    ax.axis("off")

    # revert_y_label works on the whole grid (not a single cell) and returns one
    # row per cell, with coordinates as image fractions.
    pred_boxes = revert_y_label(out)
    true_boxes = revert_y_label(target)

    for (po, xo, yo, wo, ho, _), (pt, xt, yt, wt, ht, _) in zip(pred_boxes, true_boxes):
        if po > threshold:
            ax.add_patch(_box_rectangle(xo, yo, wo, ho, "r"))
        if pt != 0:
            ax.add_patch(_box_rectangle(xt, yt, wt, ht, "g"))

    return fig, ax

In [None]:
# `models` maps model -> validation score, but it stays empty while the scoring
# in the training cell is commented out. Fall back to the last trained model
# instead of letting max() raise on an empty dict.
best_model = max(models, key=models.get) if models else model
N_EXAMPLES = 20  # number of validation images to visualize

images, labels = next(iter(val_loader))
# Predict with the selected model; detach so plotting never touches autograd state.
outs = best_model.predict(images).detach().cpu()
for i in range(min(N_EXAMPLES, len(images))):
    draw_multiple(images[i], outs[i], labels[i])