In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import torch.optim as optim
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import random_split
from collections import Counter
import seaborn as sns
from sklearn.metrics import confusion_matrix
from torchvision.utils import draw_bounding_boxes
from torchvision.ops import box_convert, box_iou

In [None]:
# Reproducibility and numeric setup shared by every cell below.
SEED = 265
torch.manual_seed(SEED)

# Tensors and modules created without an explicit dtype default to float64.
torch.set_default_dtype(torch.double)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## 2 Object Localization
#### First we load and inspect the localization datasets

In [None]:
# Load the three localization splits (train / validation / test) from disk.
# NOTE(review): torch.load unpickles the file, so only load files from a trusted source.
# Recent PyTorch releases default to weights_only=True, which may reject pickled
# dataset objects -- TODO confirm against the installed version.
loc_train = torch.load('data/localization_train.pt')
loc_val = torch.load('data/localization_val.pt')
loc_test = torch.load('data/localization_test.pt')

In [None]:
# Number of samples in each localization split, one line per split.
print(
    f'Train data size: {len(loc_train)}',
    f'Val data size: {len(loc_val)}',
    f'Test data size: {len(loc_test)}',
    sep='\n',
)

In [None]:
# Peek at the first training sample to learn the image / label layout.
first_img, first_label = loc_train[0]

print(f'Shape of first image: {first_img.shape}')
print(f'Type of first image: {type(first_img)}')

print(f'\nShape of first label: {first_label.shape}')
# Fixed: the f-string used to end with a stray ")" that was printed after the type.
print(f'Type of first label: {type(first_label)}')
first_label

In [None]:
def count_instances(data, data_name=None) -> None:
    """Print how many samples of each class a dataset contains.

    Every element of ``data`` is an ``(image, label)`` pair. A label whose first
    entry equals 0 is tallied under the key 99; every other label is tallied
    under ``int(label[-1])``. Counts are printed in ascending key order as
    ``key: count`` lines, preceded by a header line when ``data_name`` is given.
    """
    tally = Counter()
    for _, label in data:
        if label[0] == 0:
            tally[99] += 1
        else:
            tally[int(label[-1])] += 1

    if data_name is not None:
        print(f'Class distribution in {data_name}')

    for class_key in sorted(tally):
        print(f'{class_key}: {tally[class_key]}')

# Class distribution of each localization split (key 99 collects labels whose first entry is 0).
count_instances(loc_train, 'Training Data')
count_instances(loc_val, 'Validation Data')
count_instances(loc_test, 'Test Data')


#### Plotting one image from each class

In [None]:
def plot_images(data):
    """Plot one example image per class (0-9) with its bounding box, plus one image without an object.

    Args:
        data: iterable of ``(image, label)`` pairs. ``label[0]`` is the detection
            flag, ``label[1:5]`` the box as normalized ``cxcywh`` and ``label[-1]``
            the class index. Images are assumed to be ``(C, H, W)`` with values in
            [0, 1] -- TODO confirm against the dataset.

    The dataset is never modified: the box is scaled into a new tensor instead of
    being multiplied in place (the slice ``label[1:5]`` is a view into the label).
    """
    _, axes = plt.subplots(nrows=2, ncols=6, figsize=(8,3))

    for i, ax in enumerate(axes.flat):
        ax.axis('off')

        if i == 11:
            # The 2x6 grid has twelve axes but only eleven panels are used.
            continue

        if i == 10:
            # Panel 10 shows the first sample whose detection flag is 0.
            empty_img = next((img for img, label in data if int(label[0]) == 0), None)
            if empty_img is not None:
                ax.imshow(empty_img.numpy().transpose((1, 2, 0)), cmap='gray')
            ax.set_title('None')
            continue

        # Require the detection flag too, so a label without an object is not
        # picked just because its class entry happens to equal i.
        sample = next(
            ((img, label[1:5]) for img, label in data
             if int(label[0]) == 1 and int(label[-1]) == i),
            None,
        )
        ax.set_title(i)
        if sample is None:
            # No example of this class in `data`; leave the panel empty.
            continue

        img, bbox = sample
        img_height, img_width = img.shape[-2], img.shape[-1]

        # Out-of-place scaling to pixels: `bbox * scale` allocates a new tensor.
        scale = bbox.new_tensor([img_width, img_height, img_width, img_height])
        pixel_bbox = box_convert(bbox * scale, in_fmt='cxcywh', out_fmt='xyxy')
        # Integer pixel coordinates without the 8-bit wrap-around of a uint8 cast.
        pixel_bbox = pixel_bbox.round().long().unsqueeze(0)

        img_with_bbox = draw_bounding_boxes((img * 255).byte(), pixel_bbox, colors='red')
        ax.imshow(img_with_bbox.numpy().transpose((1, 2, 0)), cmap='gray')

In [None]:
def plot_class(data:torch.tensor, class_label:int, start_idx:int=0) -> None:
    """Plot up to 10 images of one class with their bounding boxes, starting at a chosen index.

    Args:
        data: iterable of ``(image, label)`` pairs; ``label[0]`` is the detection
            flag, ``label[1:5]`` the normalized ``cxcywh`` box, ``label[-1]`` the class.
        class_label: class whose images are shown.
        start_idx: index (within the images of that class) of the first image shown.

    The labels in ``data`` are left untouched. The previous version multiplied
    ``label[1:5]`` in place, which rescaled the stored labels of every plotted image.
    Panels stay empty when the class has fewer than ``start_idx + 10`` images.
    """
    # Keep only samples that contain an object of the requested class.
    samples = [
        (img, label[1:5]) for img, label in data
        if int(label[0]) == 1 and int(label[-1]) == class_label
    ]
    selected = samples[start_idx:start_idx + 10]

    _, axes = plt.subplots(nrows=2, ncols=5, figsize=(8,3))
    for ax in axes.flat:
        ax.axis('off')

    for ax, (img, bbox) in zip(axes.flat, selected):
        img_height, img_width = img.shape[-2], img.shape[-1]

        # Out-of-place scaling to pixels so the dataset label is not modified.
        scale = bbox.new_tensor([img_width, img_height, img_width, img_height])
        pixel_bbox = box_convert(bbox * scale, in_fmt='cxcywh', out_fmt='xyxy')
        # Integer pixel coordinates without the 8-bit wrap-around of a uint8 cast.
        pixel_bbox = pixel_bbox.round().long().unsqueeze(0)

        img_with_bbox = draw_bounding_boxes((img * 255).byte(), pixel_bbox, colors='lightgreen')
        ax.imshow(img_with_bbox.numpy().transpose((1, 2, 0)), cmap='gray')

    if selected:
        last_idx = start_idx + len(selected) - 1
        plt.suptitle(f'CLASS {class_label} - Image {start_idx} to {last_idx}')
    else:
        plt.suptitle(f'CLASS {class_label} - no images from index {start_idx}')

    plt.show()

plot_class(loc_train, 3, 10)

#### Defining a normalizer and a preprocessor

In [None]:
# Stack every training image along a new leading dimension so that
# channel-wise statistics can be computed in one call.
imgs = torch.stack([sample[0] for sample in loc_train])

# Normalizer built from the training split only: mean and standard deviation
# are reduced over dimensions 0, 2 and 3, leaving one value per channel.
normalizer_pipe = transforms.Normalize(
    mean=imgs.mean(dim=(0, 2, 3)),
    std=imgs.std(dim=(0, 2, 3)),
)

# Preprocessing pipeline; currently normalization is its only step.
preprocessor = transforms.Compose([normalizer_pipe])

In [None]:
# Replace each split by a list of (preprocessed image, label) pairs.
# NOTE: the names are rebound, so running this cell twice normalizes twice.
loc_train, loc_val, loc_test = (
    [(preprocessor(img), label) for img, label in split]
    for split in (loc_train, loc_val, loc_test)
)

#### Defining the loss function

In [None]:
class LocalizationLoss(nn.Module):
    """Combined detection + box-regression + classification loss.

    Expected layout (per row):
        y_pred: ``[detection logit, 4 box values, class logits...]``
        y_true: ``[detection flag, 4 box values, class index]``

    The detection term is computed on every row. The box and class terms are
    computed only on rows whose true detection flag equals 1.
    """
    def __init__(self):
        super().__init__()
        self.L_a = nn.BCEWithLogitsLoss()  # detection loss
        self.L_b = nn.MSELoss()  # localization loss
        self.L_c = nn.CrossEntropyLoss()  # classification loss

    def forward(self, y_pred, y_true):
        """Return the scalar loss ``L_a + L_b + L_c`` for a batch.

        Args:
            y_pred: ``(N, 5 + n_classes)`` raw model outputs.
            y_true: ``(N, 6)`` targets.

        Returns:
            0-dim tensor. When no row contains an object only the detection
            term is returned. The mean-reduced MSE / cross-entropy of an empty
            selection is NaN and would otherwise poison the gradients.
        """
        det_pred = y_pred[:, 0]
        bbox_pred = y_pred[:, 1:5]
        class_pred = y_pred[:, 5:]

        det_true = y_true[:, 0]
        bbox_true = y_true[:, 1:5]
        class_true = y_true[:, -1].long()

        L_a = self.L_a(det_pred, det_true)

        object_present = det_true == 1
        if not object_present.any():
            return L_a

        L_b = self.L_b(bbox_pred[object_present], bbox_true[object_present])
        L_c = self.L_c(class_pred[object_present], class_true[object_present])

        return L_a + L_b + L_c

#### Function to compute the input size of the fully connected layer

In [None]:
def get_output_size(input_size, layer):
    """Number of activations a 2-D convolution produces for a given input size.

    Uses the output-shape formula from the ``nn.Conv2d`` documentation::

        out = floor((in + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1

    The previous version omitted the ``+ 1`` and the floor, so e.g. a 3x3
    convolution with padding 1 was reported as shrinking the image by one pixel.

    Args:
        input_size: ``(H_in, W_in)`` of the layer input.
        layer: object exposing ``out_channels`` and 2-tuples ``kernel_size``,
            ``padding`` and ``stride`` (as ``nn.Conv2d`` does). ``dilation`` is
            optional and defaults to ``(1, 1)``. String paddings are not supported.

    Returns:
        int: ``H_out * W_out * C_out``, the flattened size of the layer output.
    """
    H_in, W_in = input_size[0], input_size[1]
    kernel_size = layer.kernel_size
    padding = layer.padding
    stride = layer.stride
    dilation = getattr(layer, 'dilation', (1, 1))

    H_out = (H_in + 2*padding[0] - dilation[0]*(kernel_size[0] - 1) - 1) // stride[0] + 1
    W_out = (W_in + 2*padding[1] - dilation[1]*(kernel_size[1] - 1) - 1) // stride[1] + 1

    return H_out * W_out * layer.out_channels

In [None]:
# Stand-alone convolution layers with the same arguments as the layers of MyCNN,
# created at module level (presumably for experimenting with the size helpers;
# they are not referenced again in the visible part of the notebook).
conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1, device=device, dtype=torch.double)
conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1, device=device, dtype=torch.double)
conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, device=device, dtype=torch.double)

In [None]:
def output_pipe(layers, input_size, num_outputs):
    """Propagate a spatial size through a stack of layers and return the flattened feature count.

    The previous version fed the scalar result of one layer back in as the
    ``(H, W)`` input of the next (which fails on the second layer) and returned
    nothing. This version tracks height, width and channel count explicitly.

    Args:
        layers: iterable of layers exposing ``kernel_size``, ``stride`` and
            ``padding`` as ints or 2-tuples (e.g. ``nn.Conv2d``, ``nn.MaxPool2d``).
            ``dilation`` and ``out_channels`` are used when present.
        input_size: ``(H_in, W_in)`` of the first layer's input.
        num_outputs: currently unused; kept so existing call sites keep working.

    Returns:
        int: ``H * W * C`` after the last layer, where ``C`` is the most recent
        ``out_channels`` seen (``H * W`` if no layer defines one).
    """
    def _pair(value):
        # Layers store these hyper-parameters either as an int or as a 2-tuple.
        return tuple(value) if isinstance(value, (tuple, list)) else (value, value)

    height, width = input_size[0], input_size[1]
    channels = None

    for layer in layers:
        kernel = _pair(layer.kernel_size)
        stride = _pair(layer.stride)
        padding = _pair(layer.padding)
        dilation = _pair(getattr(layer, 'dilation', 1))

        # Output-shape formula shared by convolution and pooling layers.
        height = (height + 2*padding[0] - dilation[0]*(kernel[0] - 1) - 1) // stride[0] + 1
        width = (width + 2*padding[1] - dilation[1]*(kernel[1] - 1) - 1) // stride[1] + 1
        channels = getattr(layer, 'out_channels', channels)

    if channels is None:
        return height * width
    return height * width * channels

#### Defining models

In [None]:
class MyCNN(nn.Module):
    """Three-convolution network with a single linear head of 15 outputs.

    The linear layer expects ``12*15*64`` flattened features. With the two 2x2
    poolings in ``forward`` this corresponds to single-channel 48x60 inputs --
    assumption to be confirmed against the dataset.
    """
    def __init__(self):
        super().__init__()
        # Hyper-parameters shared by all three convolutions.
        conv_args = dict(kernel_size=3, stride=1, padding=1, device=device, dtype=torch.double)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, **conv_args)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, **conv_args)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, **conv_args)
        self.fc1 = nn.Linear(12*15*64, 15, device=device, dtype=torch.double)
        self.flatten = nn.Flatten()

    def forward(self, x):
        """Return the raw 15-value output for each image in ``x``."""
        # conv -> ReLU -> 2x2 max-pool, twice
        features = F.max_pool2d(F.relu(self.conv1(x)), 2, stride=2)
        features = F.max_pool2d(F.relu(self.conv2(features)), 2, stride=2)
        # third convolution is not followed by pooling
        features = F.relu(self.conv3(features))
        return self.fc1(self.flatten(features))

In [None]:
def compute_performance(model, loader):
    '''
    Evaluate a localization model on a data loader.

    A sample counts as correct when either no object is predicted and none is
    present, or an object is predicted, one is present, and the predicted class
    matches the true class.

    IoU is accumulated only for samples where an object is both predicted and
    present, pairing each predicted box with the true box of the same sample.
    Boxes are converted from ``cxcywh`` to ``xyxy`` first because ``box_iou``
    expects ``xyxy``. The previous version passed ``cxcywh`` boxes and summed
    the full pairwise matrix that ``box_iou`` returns, i.e. it also added the
    IoU between boxes of different images.

    Args:
        model: network returning ``[det logit, cx, cy, w, h, class logits...]`` per sample.
        loader: iterable of ``(imgs, labels)`` batches; labels are
            ``[det flag, cx, cy, w, h, class]`` per sample.

    Returns:
        ``(acc, iou, performance)`` where ``acc = correct / total``,
        ``iou = iou_sum / total`` and ``performance`` is their mean.
        NOTE(review): IoU is averaged over all samples, including those without
        an object -- confirm this is the intended metric.
    '''
    model.eval()
    correct = 0
    total = 0
    iou_sum = 0

    with torch.inference_mode():
        for imgs, labels in loader:
            imgs = imgs.to(device=device, dtype=torch.double)
            labels = labels.to(device=device, dtype=torch.double)

            outputs = model(imgs)

            object_detected = torch.sigmoid(outputs[:, 0]) > 0.5
            _, class_pred = torch.max(outputs[:, 5:], dim=1)

            object_present = labels[:, 0].int() == 1
            class_true = labels[:, -1].int()

            total += labels.shape[0]
            correct += (~object_detected & ~object_present).sum()
            correct += (object_detected & object_present & (class_pred == class_true)).sum()

            # IoU is only defined where a predicted and a true box both exist.
            matched = object_detected & object_present
            if matched.any():
                bbox_pred = outputs[matched, 1:5]
                # Negative predicted widths/heights would yield inverted boxes.
                bbox_pred = torch.cat((bbox_pred[:, :2], bbox_pred[:, 2:].clamp(min=0)), dim=1)
                pred_xyxy = box_convert(bbox_pred, in_fmt='cxcywh', out_fmt='xyxy')
                true_xyxy = box_convert(labels[matched, 1:5], in_fmt='cxcywh', out_fmt='xyxy')

                # box_iou returns an (N, N) pairwise matrix; the diagonal holds
                # the IoU of each prediction with its own ground-truth box.
                iou_sum += box_iou(pred_xyxy, true_xyxy).diagonal().sum()

    acc =  correct / total
    iou = iou_sum / total

    performance = (acc + iou) / 2

    return acc, iou, performance

In [None]:
# model 2

In [None]:
# model 3

In [None]:
# model 4

In [None]:
# model 5

#### Function to plot training and validation loss

In [None]:
def plot_loss(train_loss:list, val_loss:list, title:str) -> None:
    """Draw the per-epoch training and validation loss curves on one axis.

    Epochs are numbered from 1; each curve has as many points as its list has entries.
    """
    _, ax = plt.subplots()

    curves = ((train_loss, 'Training loss'), (val_loss, 'Validation loss'))
    for losses, legend_name in curves:
        epochs = np.arange(1, len(losses) + 1)
        ax.plot(epochs, losses, label=legend_name)

    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()

In [None]:
def train(n_epochs, optimizer, model, loss_fn, train_loader, val_loader):
    """Train ``model`` for ``n_epochs`` and report loss and performance.

    After every epoch the mean training and validation loss (averaged over
    batches) is printed. After the last epoch ``compute_performance`` is run on
    both loaders and the results are printed.

    Args:
        n_epochs: number of passes over ``train_loader``.
        optimizer: optimizer holding the parameters of ``model``.
        model: network to train; moved between train/eval mode here.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar tensor.
        train_loader, val_loader: iterables of ``(imgs, labels)`` batches.

    Returns:
        ``(losses_train, losses_val, train_performance, val_performance)``:
        two lists with one mean loss per epoch, followed by the overall
        performance score on the training and on the validation data.
    """
    n_batch_train = len(train_loader)
    n_batch_val = len(val_loader)

    losses_train = []
    losses_val = []

    # Start from clean gradients in case the optimizer was used before.
    optimizer.zero_grad(set_to_none=True)

    for epoch in range(1, n_epochs + 1):

        loss_train = 0
        loss_val = 0

        model.train()

        for imgs, labels in train_loader:

            imgs = imgs.to(device=device, dtype=torch.double)
            labels = labels.to(device=device, dtype=torch.double)

            outputs = model(imgs)

            loss = loss_fn(outputs, labels)
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

            loss_train += loss.item()

        model.eval()

        # No gradients are needed for validation.
        with torch.inference_mode():
            for imgs, labels in val_loader:

                imgs = imgs.to(device=device, dtype=torch.double)
                labels = labels.to(device=device, dtype=torch.double)

                outputs = model(imgs)

                loss = loss_fn(outputs, labels)
                loss_val += loss.item()

        losses_train.append(loss_train / n_batch_train)
        losses_val.append(loss_val / n_batch_val)

        print('{}  |  Epoch {}  |  Training loss {:.3f}'.format(datetime.now().strftime('%H:%M:%S'), epoch, loss_train / n_batch_train))
        print('{}  |  Epoch {}  |  Validation loss {:.3f}'.format(datetime.now().strftime('%H:%M:%S'), epoch, loss_val / n_batch_val))

    train_acc, train_iou, train_performance = compute_performance(model, train_loader)
    val_acc, val_iou, val_performance = compute_performance(model, val_loader)
    print(f'Training performance: Accuracy = {train_acc}, IOU = {train_iou}, Overall = {train_performance}')
    # Fixed: this line reports the validation metrics but was labelled "Training performance".
    print(f'Validation performance: Accuracy = {val_acc}, IOU = {val_iou}, Overall = {val_performance}')

    return losses_train, losses_val, train_performance, val_performance

In [None]:
# Shuffle the training samples every epoch so SGD does not see the same batch
# composition and order each time. A dedicated generator seeded with SEED
# keeps the shuffling reproducible.
train_loader = torch.utils.data.DataLoader(
    loc_train,
    batch_size=64,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
# Evaluation data keeps its original order.
val_loader = torch.utils.data.DataLoader(loc_val, batch_size=64, shuffle=False)

# Re-seed right before building the model so its initial weights are reproducible.
torch.manual_seed(SEED)
model = MyCNN()
model.to(device=device)

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0, weight_decay=0)
loss_fn = LocalizationLoss()

In [None]:
# Train the model for five epochs; `train` returns the per-epoch mean losses and
# the overall performance score on the training and validation data.
loss_train, loss_val, train_perform, val_perform = train(
    n_epochs=5,
    optimizer=optimizer,
    model=model,
    loss_fn=loss_fn,
    train_loader=train_loader,
    val_loader=val_loader
)

# Loss curves for the run above.
plot_loss(loss_train, loss_val, 'Model')

#### Selecting the best model TBD

In [None]:
def model_selector(models:list, performances:list):
    """Return the model with the highest validation performance score.

    Args:
        models: candidate models.
        performances: ``performances[i]`` is the score of ``models[i]``.

    Returns:
        ``(best_model, best_performance)``. On ties the earliest model wins.
        For an empty ``models`` list the result is ``(None, 0)``.

    The first candidate is always accepted, so a model is returned even when
    every score is zero or negative (previously ``None`` was returned then).
    """
    best_model = None
    best_performance = 0
    found = False

    for idx, model in enumerate(models):
        if not found or performances[idx] > best_performance:
            best_model = model
            best_performance = performances[idx]
            found = True

    return best_model, best_performance

In [None]:
# Fixed: `performance` was never defined; the validation score returned by
# `train` is stored in `val_perform`.
best_model, best_performance = model_selector([model], [val_perform])

# Print additional details of the best model
print("Best Model Details\n--------------------------------------------------------------")
print(f"Network architecture/ layout: {best_model}\n")
print(f"Validation Performance: {best_performance}")
# TODO: also report the optimizer parameters and the validation accuracy of the best model.

#### Evaluating the best model on unseen data TBD

In [None]:
# Evaluate the selected model once on the held-out test split.
test_loader = torch.utils.data.DataLoader(loc_test, batch_size=64, shuffle=False)

test_acc, test_iou, test_performance = compute_performance(best_model, test_loader)

# Header framed by ten dashes on each side, then the three metrics.
print(10*'-', 'Test Performance', 10*'-', sep='')
print(f"Test Accuracy: {test_acc}\nTest IOU: {test_iou}\nOverall Performance: {test_performance}")

In [None]:
def predict(model, loader):
    '''
    Run ``model`` over ``loader`` and collect targets and predictions.

    Args:
        model: network returning ``[det logit, 4 box values, class logits...]`` per sample.
        loader: iterable of ``(imgs, labels)`` batches.

    Returns:
        ``(y_true, y_pred)``: the labels of all batches concatenated along
        dim 0, and a tensor whose rows hold the first five raw outputs followed
        by the predicted class index. Both are double tensors on ``device``;
        two empty tensors are returned for an empty loader.

    Batches are gathered in lists and concatenated once, instead of growing the
    result tensors with ``torch.cat`` on every iteration.
    '''
    model.eval()

    true_batches = []
    pred_batches = []

    with torch.inference_mode():
        for imgs, labels in loader:

            imgs = imgs.to(device=device, dtype=torch.double)
            # Same dtype as the rest of the notebook uses for labels.
            labels = labels.to(device=device, dtype=torch.double)
            outputs = model(imgs)

            _, class_pred = torch.max(outputs[:, 5:], dim=1)
            # Store the class index in the dtype of the raw outputs so the
            # columns can share one tensor.
            class_column = class_pred.unsqueeze(1).to(outputs.dtype)

            true_batches.append(labels)
            pred_batches.append(torch.cat((outputs[:, :5], class_column), dim=1))

    if not true_batches:
        return torch.empty(0, device=device), torch.empty(0, device=device)

    return torch.cat(true_batches, dim=0), torch.cat(pred_batches, dim=0)

In [None]:
def plot_predictions(imgs, y_true:torch.tensor, y_pred:torch.tensor, label:int=0, start_idx:int=0) -> None:
    """Plot up to 10 images of one true class with true (green) and predicted (red) boxes.

    Args:
        imgs: images in the same order as the rows of ``y_true`` / ``y_pred``;
            assumed ``(C, H, W)`` with values in [0, 1] -- TODO confirm.
        y_true: rows ``[det flag, cx, cy, w, h, class]`` with normalized boxes.
        y_pred: rows ``[det logit, cx, cy, w, h, predicted class]``.
        label: true class whose images are shown.
        start_idx: index (within the images of that class) of the first image shown.

    The true box is drawn when the detection flag is 1, the predicted box when
    the sigmoid of the detection logit exceeds 0.5. Panels stay empty when the
    class has fewer than ``start_idx + 10`` images.
    """
    def _pixel_xyxy(bbox, img_width, img_height):
        """Scale a normalized cxcywh box to pixels and return it as a (1, 4) integer xyxy tensor clipped to the image."""
        scale = bbox.new_tensor([img_width, img_height, img_width, img_height])
        cxcywh = bbox * scale
        # A network may predict negative sizes; clamp so the box is not inverted.
        cxcywh = torch.cat((cxcywh[:2], cxcywh[2:].clamp(min=0)))
        xyxy = box_convert(cxcywh, in_fmt='cxcywh', out_fmt='xyxy')
        # Clip to the image instead of casting to uint8, which wraps around
        # for negative or >255 coordinates.
        upper = bbox.new_tensor([img_width - 1, img_height - 1, img_width - 1, img_height - 1])
        xyxy = torch.min(xyxy.clamp(min=0), upper)
        return xyxy.round().long().cpu().unsqueeze(0)

    class_mask = y_true[:, -1] == label
    class_imgs = [img for img, keep in zip(imgs, class_mask) if keep]
    class_true, class_pred = y_true[class_mask], y_pred[class_mask]

    _, axes = plt.subplots(nrows=2, ncols=5, figsize=(8,3))

    last_idx = start_idx
    for i, ax in enumerate(axes.flat):
        ax.axis('off')

        idx = start_idx + i
        if idx >= len(class_imgs):
            # Fewer images of this class than panels; leave the rest empty.
            continue
        last_idx = idx

        img = class_imgs[idx]
        img_height, img_width = img.shape[-2], img.shape[-1]
        img = (img * 255).byte()

        if int(class_true[idx][0]) == 1:
            true_box = _pixel_xyxy(class_true[idx][1:5], img_width, img_height)
            img = draw_bounding_boxes(img, true_box, colors='lightgreen')

        if torch.sigmoid(class_pred[idx][0]) > 0.5:
            pred_box = _pixel_xyxy(class_pred[idx][1:5], img_width, img_height)
            img = draw_bounding_boxes(img, pred_box, colors='red')

        ax.imshow(img.numpy().transpose((1, 2, 0)), cmap='gray')
        ax.set_title(f'Pred: {int(class_pred[idx][-1])}')

    plt.suptitle(f'True label: {label} - Image {start_idx} to {last_idx}')
    plt.show()

In [None]:
# Targets and predictions of the selected model for every test sample.
y_true, y_pred = predict(best_model, test_loader)

In [None]:
# Reload the test split from disk; `loc_test` was earlier replaced by its
# preprocessed (normalized) version, while plot_predictions scales images by 255
# (presumably expecting the original pixel range -- TODO confirm).
loc_test = torch.load('data/localization_test.pt')

In [None]:
# Images of the reloaded test split, in the same order as the rows of y_true / y_pred
# (the test loader is not shuffled).
imgs = [img for img,_ in loc_test]
# Show ten test images whose true class is 3, starting at the 11th such image.
plot_predictions(imgs, y_true, y_pred, label=3, start_idx=10)

# 3 Object Detection

#### Loading and inspecting the data

In [None]:
# Load the per-split label lists for the detection task.
# NOTE(review): torch.load unpickles the file; only load trusted files.
train_labels = torch.load('data/list_y_true_train.pt')
val_labels = torch.load('data/list_y_true_val.pt')
test_labels = torch.load('data/list_y_true_test.pt')

In [None]:
# Number of entries in each label list, one line per split.
print(
    f'Train label size: {len(train_labels)}',
    f'Val label size: {len(val_labels)}',
    f'Test label size: {len(test_labels)}',
    sep='\n',
)

In [None]:
# Load the three detection splits (train / validation / test) from disk.
# NOTE(review): torch.load unpickles the file; only load trusted files.
det_train = torch.load('data/detection_train.pt')
det_val = torch.load('data/detection_val.pt')
det_test = torch.load('data/detection_test.pt')

In [None]:
# Fixed: these lines report the size of the detection datasets, but were
# labelled "label size" (copied from the label-list cell).
print(f'Train data size: {len(det_train)}')
print(f'Val data size: {len(det_val)}')
print(f'Test data size: {len(det_test)}')

In [None]:
def local_to_global(bbox, W_out, H_out, W_img, H_img):
    """Placeholder: currently returns ``bbox`` unchanged.

    TODO: not implemented yet. Judging by the name it is meant to convert
    grid-cell-local box coordinates to global ones; the size arguments
    (``W_out``, ``H_out``, ``W_img``, ``H_img``) are not used.
    """


    return bbox

    


def global_to_local(global_coordinates, H_out, W_out):
    """Convert a normalized global box centre to grid-cell-local coordinates.

    The image is divided into ``H_out`` rows and ``W_out`` columns of equal
    size. The centre ``(x, y)`` (both in [0, 1]) falls into exactly one cell;
    the local coordinates are its offset from that cell's top-left corner,
    measured in cell units (so they lie in [0, 1]).

    The previous version compared x against the row boundaries' counterpart
    (and y against the column boundaries), derived the row index from a
    reversed list, failed for a coordinate of exactly 0, and printed debug output.

    Args:
        global_coordinates: sequence or tensor whose first two entries are the
            global centre ``x`` and ``y``; further entries are ignored.
        H_out: number of grid rows.
        W_out: number of grid columns.

    Returns:
        ``(x_local, y_local, (x_cell, y_cell))`` where ``x_cell`` is the column
        index and ``y_cell`` the row index of the cell containing the centre.
    """
    # float() accepts both plain numbers and 0-dim tensors.
    x_global = float(global_coordinates[0])
    y_global = float(global_coordinates[1])

    # Clamp so that a coordinate of exactly 1.0 (or slightly out of range)
    # still maps to a valid cell.
    x_cell = max(0, min(int(x_global * W_out), W_out - 1))
    y_cell = max(0, min(int(y_global * H_out), H_out - 1))

    x_local = x_global * W_out - x_cell
    y_local = y_global * H_out - y_cell

    return x_local, y_local, (x_cell, y_cell)

# Example box (normalized cx, cy, w, h) for trying out the conversion by hand.
global_cord = [0.5, 0.6771, 0.1667, 0.4792]
#print(global_to_local(global_cord, 2, 3))

In [None]:
def plot_detection_data(data, H_out, W_out):
    """Plot the first 10 detection samples with the boxes of every occupied grid cell.

    Args:
        data: indexable dataset of ``(image, label)`` pairs. ``label`` is
            iterated as rows of grid cells, each cell being
            ``[det flag, x, y, w, h, ...]``.
        H_out: number of grid rows.
        W_out: number of grid columns.

    The box centre is taken to be relative to its grid cell, so the pixel
    centre is ``(col + x) * cell_width`` / ``(row + y) * cell_height``. The
    previous version divided by the cell boundary, which is zero for the first
    row and column.
    NOTE(review): width and height are treated as cell-relative too, matching
    how ``prepare_labels`` scales them -- confirm against the dataset description.
    """
    _, axes = plt.subplots(nrows=2, ncols=5, figsize=(8,3))

    for i, ax in enumerate(axes.flat):
        ax.axis('off')
        if i >= len(data):
            # Fewer samples than panels; leave the rest empty.
            continue

        img, label = data[i]
        img_height, img_width = img.shape[-2], img.shape[-1]
        img = (img * 255).byte()

        cell_width = img_width / W_out
        cell_height = img_height / H_out

        for row_idx, row in enumerate(label):
            for col_idx, gridcell in enumerate(row):
                if int(gridcell[0]) != 1:
                    continue

                x_local, y_local, w_local, h_local = (float(v) for v in gridcell[1:5])

                # Cell-local -> pixel coordinates (cxcywh), built as a new
                # tensor so the dataset label is not modified.
                pixel_bbox = torch.tensor([
                    (col_idx + x_local) * cell_width,
                    (row_idx + y_local) * cell_height,
                    w_local * cell_width,
                    h_local * cell_height,
                ])

                converted_bbox = box_convert(pixel_bbox, in_fmt='cxcywh', out_fmt='xyxy')
                converted_bbox = converted_bbox.round().long().unsqueeze(0)

                img = draw_bounding_boxes(img, converted_bbox, colors='lightgreen')

        ax.imshow(img.numpy().transpose((1, 2, 0)), cmap='gray')
        ax.set_title(i)

plot_detection_data(data=det_train, H_out=2, W_out=3)

In [None]:
# Reload the label lists from disk (the same files were already loaded at the
# start of this section).
train_labels = torch.load('data/list_y_true_train.pt')
val_labels = torch.load('data/list_y_true_val.pt')
test_labels = torch.load('data/list_y_true_test.pt')

In [None]:
def prepare_labels(y_true, H_out, W_out):
    """Arrange the object labels of one image into a grid-shaped label tensor.

    For every object label in ``y_true``:
      1. the grid cell containing the box centre is determined,
      2. the centre is converted to coordinates local to that cell,
      3. width and height are rescaled to cell units,
      4. the converted label is stored in the cell's slot.

    Args:
        y_true: iterable of 6-element label tensors, one per object; entries
            1..4 hold the box as normalized global ``cx, cy, w, h``.
        H_out: number of grid rows.
        W_out: number of grid columns.

    Returns:
        Tensor of shape ``(H_out, W_out, 6)``; cells without an object stay zero.
        If two objects fall into the same cell, the later one overwrites the earlier.

    The grid shape now follows ``H_out`` / ``W_out`` (it used to be fixed at
    2x3) and the labels in ``y_true`` are no longer modified in place.
    """
    label_tensor = torch.zeros(H_out, W_out, 6)

    for label in y_true:
        x_local, y_local, (col, row) = global_to_local(label[1:5], H_out, W_out)

        # Work on a copy so the caller's label list stays in global coordinates.
        local_label = label.clone()
        local_label[1] = x_local
        local_label[2] = y_local
        local_label[3] *= W_out
        local_label[4] *= H_out

        label_tensor[row, col] = local_label

    return label_tensor

In [None]:
# Reload the training label list and display the entry at index 1.
train_labels = torch.load('data/list_y_true_train.pt')
label = train_labels[1]
label

In [None]:
# Label stored in the detection dataset for the sample at index 1, for
# comparison with the output of prepare_labels.
det_train[1][1]

In [None]:
# Build the grid label for the same entry with a grid of 2 rows and 3 columns.
prepare_labels(label, 2, 3)

#### Defining the loss function

In [None]:
class DetectionLoss(nn.Module):
    """Combined detection + box-regression + classification loss over grid cells.

    Every grid cell is treated as an independent sample with the layout
        y_pred: ``[detection logit, 4 box values, class logits...]``
        y_true: ``[detection flag, 4 box values, class index]``
    stored along the LAST dimension. Inputs may be 2-D ``(N, C)`` or carry
    extra grid dimensions such as ``(B, H, W, C)``; all leading dimensions are
    flattened. NOTE(review): predictions must be channel-last like the grid
    labels -- permute a ``(B, C, H, W)`` network output before calling.
    """
    def __init__(self):
        super().__init__()
        self.L_a = nn.BCEWithLogitsLoss()  # detection loss
        self.L_b = nn.MSELoss()  # localization loss
        self.L_c = nn.CrossEntropyLoss()  # classification loss

    def forward(self, y_pred, y_true):
        """Return the scalar loss ``L_a + L_b + L_c``.

        The detection term uses every cell. The box and class terms use only
        cells whose true detection flag equals 1. When no cell contains an
        object only the detection term is returned, because the mean-reduced
        MSE / cross-entropy of an empty selection is NaN.
        """
        # Collapse batch and grid dimensions: one row per grid cell.
        y_pred = y_pred.reshape(-1, y_pred.shape[-1])
        y_true = y_true.reshape(-1, y_true.shape[-1])

        det_pred = y_pred[:, 0]
        bbox_pred = y_pred[:, 1:5]
        class_pred = y_pred[:, 5:]

        det_true = y_true[:, 0]
        bbox_true = y_true[:, 1:5]
        class_true = y_true[:, -1].long()

        L_a = self.L_a(det_pred, det_true)

        object_present = det_true == 1
        if not object_present.any():
            return L_a

        L_b = self.L_b(bbox_pred[object_present], bbox_true[object_present])
        L_c = self.L_c(class_pred[object_present], class_true[object_present])

        return L_a + L_b + L_c