In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset
from torchvision import transforms
from torchvision.ops import box_convert
from torchvision.utils import draw_bounding_boxes
import matplotlib.pyplot as plt


from project_functions import *
from project_objects import *
from project_models import *
from project_constants import DEVICE as device
from project_constants import SEED

In [None]:
# Global configuration: tensors default to float64, and torch's RNG is seeded for reproducibility.
torch.set_default_dtype(torch.double)
torch.manual_seed(SEED)

## 2 Object Localization
#### First we load and inspect the localization datasets

In [None]:
# These files presumably hold pickled dataset objects rather than plain tensors.
# PyTorch >= 2.6 defaults torch.load to weights_only=True, which refuses to unpickle such objects.
# weights_only=False keeps the previous behaviour. Only do this for trusted files, because
# unpickling can execute arbitrary code.
loc_train = torch.load('data/localization_train.pt', weights_only=False)
loc_val = torch.load('data/localization_val.pt', weights_only=False)
loc_test = torch.load('data/localization_test.pt', weights_only=False)

In [None]:
# Number of samples in each localization split.
for split_name, split_data in (('Train', loc_train), ('Val', loc_val), ('Test', loc_test)):
    print(f'{split_name} data size: {len(split_data)}')

In [None]:
# Inspect the structure of a single (image, label) training sample.
first_img, first_label = loc_train[0]

print(f'Shape of first image: {first_img.shape}')
print(f'Type of first image: {type(first_img)}')

print(f'\nShape of first label: {first_label.shape}')
print(f'Type of first label: {type(first_label)}')
first_label

In [None]:
# Instance counts (presumably per class) for each localization split loaded above.
# count_instances comes from the project helper modules imported at the top.
count_instances(loc_train, 'Training Data')
count_instances(loc_val, 'Validation Data')
count_instances(loc_test, 'Test Data')


#### Plotting one image from each class

In [None]:
# Visual check of the training images (plot_images comes from the project helper modules).
plot_images(loc_train)

In [None]:
# NOTE(review): the arguments are presumably (dataset, class label, start index). TODO confirm against plot_class.
plot_class(loc_train, 4, 0)

#### Defining a normalizer and a preprocessor

In [None]:
# Stack every training image, then reduce over dims (0, 2, 3). This gives per-channel statistics,
# assuming images are (C, H, W) (TODO confirm). Only training data is used, so val/test stay unseen.
imgs = torch.stack([img for img, _ in loc_train])
channel_mean = imgs.mean(dim=(0, 2, 3))
channel_std = imgs.std(dim=(0, 2, 3))

normalizer_pipe = transforms.Normalize(channel_mean, channel_std)

# Preprocessing pipeline. Normalization is currently its only step.
preprocessor = transforms.Compose([normalizer_pipe])

In [None]:
def _normalize_split(dataset):
    """Return a list of (preprocessed image, unchanged label) pairs for ``dataset``."""
    return [(preprocessor(img), label) for img, label in dataset]

loc_train_norm = _normalize_split(loc_train)
loc_val_norm = _normalize_split(loc_val)
loc_test_norm = _normalize_split(loc_test)

In [None]:
# Reshuffle the training batches every epoch so SGD does not see the same fixed order each time.
# A seeded generator keeps the shuffle order reproducible. Validation stays in order.
train_loader = torch.utils.data.DataLoader(
    loc_train_norm, batch_size=64, shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
val_loader = torch.utils.data.DataLoader(loc_val_norm, batch_size=64, shuffle=False)

loss_fn = LocalizationLoss()

#### Defining and training the model

In [None]:
model_name = 'model5test2'

# Re-seed immediately before building the model so its initial weights are reproducible.
torch.manual_seed(SEED)
model = MyCNN5((48,60,1))
model.to(device=device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.001)

train_kwargs = dict(
    n_epochs=50,
    optimizer=optimizer,
    model=model,
    loss_fn=loss_fn,
    train_loader=train_loader,
    val_loader=val_loader,
)
loss_train, loss_val, train_perform, val_perform, losses_separated = train(**train_kwargs)

# Training curves, per-component losses and sample validation predictions.
plot_loss(loss_train, loss_val, model_name, save_model=True)
loss_component_names = ['detection loss', 'localization loss', 'classification loss']
plot_lists(losses_separated, loss_component_names, model_name, save_model=True)
y_true, y_pred = predict(model, val_loader)
plot_predictions(loc_val, y_true, y_pred, label=4, start_idx=0, fig_name=model_name, save_model=True)

#### Selecting the best model

In [None]:
# Only one candidate (model, validation performance) pair is passed, so there is nothing to choose
# between yet. Add further trained models and their scores to these lists to compare runs.
best_model, best_performance = model_selector([model], [val_perform])

# Print additional details of the best model
print("Best Model Details\n--------------------------------------------------------------")
print(f"Network architecture/ layout: {best_model}\n")
print(f"Validation Performance: {best_performance}")

#### Evaluating the best model on the unseen test data (TBD)

In [None]:
# Evaluate in a fixed order on the normalized test split.
test_loader = torch.utils.data.DataLoader(loc_test_norm, batch_size=64, shuffle=False)

test_acc, test_iou, test_performance = compute_performance(best_model, test_loader)
print('-' * 10 + 'Test Performance' + '-' * 10)
print(f"Test Accuracy: {test_acc}\nTest IOU: {test_iou}\nOverall Performance: {test_performance}")

In [None]:
# True and predicted labels for the (normalized) test set, used by the plotting cell below.
y_true, y_pred = predict(best_model, test_loader)

In [None]:
# Raw (un-normalized) test images are passed for display, while the predictions came from
# loc_test_norm. NOTE(review): label=6 presumably selects which class to show. TODO confirm.
plot_predictions(loc_test, y_true, y_pred, label=6, start_idx=0)

## 3 Object Detection

#### Loading and inspecting the data

In [None]:
# Ground-truth object label lists (global coordinates) for each detection split.
train_labels, val_labels, test_labels = (
    torch.load(path)
    for path in ('data/list_y_true_train.pt', 'data/list_y_true_val.pt', 'data/list_y_true_test.pt')
)

In [None]:
# Number of label entries in each detection split.
for split_name, split_labels in (('Train', train_labels), ('Val', val_labels), ('Test', test_labels)):
    print(f'{split_name} label size: {len(split_labels)}')

In [None]:
# These files hold dataset objects, not plain tensors: merge_datasets slices them as d2[:][0].
# PyTorch >= 2.6 defaults torch.load to weights_only=True, which refuses to unpickle them.
# weights_only=False keeps the previous behaviour. Only do this for trusted files, because
# unpickling can execute arbitrary code.
det_train = torch.load('data/detection_train.pt', weights_only=False)
det_val = torch.load('data/detection_val.pt', weights_only=False)
det_test = torch.load('data/detection_test.pt', weights_only=False)

In [None]:
# Number of (image, label) samples in each detection split. These are datasets, not label lists.
print(f'Train data size: {len(det_train)}')
print(f'Val data size: {len(det_val)}')
print(f'Test data size: {len(det_test)}')

In [None]:
# Size of the detection output grid: rows x columns.
H_OUT, W_OUT = 2, 3

In [None]:
def global_to_local(labels_list: list, grid_dimension: tuple):
    '''
    Convert one image's object labels from global image coordinates to a local grid tensor.

    Parameters
    ----------
    labels_list : list of 1-D tensors, one per object (may be empty). Elements 1-4 are read as
        the box (cx, cy, w, h), normalized to the whole image (the cxcywh layout that
        plot_detection_data draws). All other elements are copied unchanged. Each label's
        length should equal ``grid_dimension[2]``.
    grid_dimension : (rows, cols, vector_length) of the output grid.

    Returns
    -------
    Tensor of shape (rows, cols, vector_length).
        - The cell containing a box centre holds that object's label.
        - cx and cy are re-measured from the cell's left/top edge and rescaled to cell units.
        - w and h are expressed in cell units.
        - All other cells are zero.
        - If several centres fall in the same cell, the last one wins.
        - Centres at or beyond 1.0 are assigned to the last column/row.
    '''
    n_rows = grid_dimension[0]
    n_cols = grid_dimension[1]
    vector_length = grid_dimension[2]

    local_tensor = torch.zeros(n_rows, n_cols, vector_length)

    # Right / bottom edge of every cell, in normalized image coordinates.
    x_edges = [(i + 1) / n_cols for i in range(n_cols)]
    y_edges = [(i + 1) / n_rows for i in range(n_rows)]

    for label in labels_list:
        label = label.clone()  # never modify the caller's tensors
        # Pick the first cell whose far edge lies beyond the centre. A centre exactly on the image
        # edge (1.0) matches no cell, so default to the last one instead of raising StopIteration.
        x_cell = next((i for i, edge in enumerate(x_edges) if label[1] < edge), n_cols - 1)
        y_cell = next((i for i, edge in enumerate(y_edges) if label[2] < edge), n_rows - 1)

        # Measure the centre from the cell's left / top edge...
        if x_cell != 0:
            label[1] -= x_edges[x_cell - 1]
        if y_cell != 0:
            label[2] -= y_edges[y_cell - 1]

        # ...then rescale positions and sizes from image units to cell units.
        label[1] *= n_cols
        label[3] *= n_cols
        label[2] *= n_rows
        label[4] *= n_rows

        local_tensor[y_cell][x_cell] = label

    return local_tensor

def prepare_labels(label_dataset:list, grid_dimension:tuple):
    '''
    Convert every image's label list to a local grid and stack the grids channels-first.

    Each entry of ``label_dataset`` is passed to ``global_to_local`` with ``grid_dimension``
    (rows, cols, vector_length). The result has shape (N, vector_length, rows, cols), with
    N = len(label_dataset).
    '''
    grids = [global_to_local(image_labels, grid_dimension) for image_labels in label_dataset]
    # (N, rows, cols, vector_length) -> (N, vector_length, rows, cols)
    return torch.stack(grids).permute(0, 3, 1, 2)

def merge_datasets(d1, d2):
    '''
    Build a new TensorDataset from the images of ``d2`` and the label tensor ``d1``.

    ``d2[:][0]`` takes the first tensor of the full slice of ``d2``. This presumes ``d2`` is
    TensorDataset-like, i.e. slicing returns a tuple of tensors with the images first.
    TensorDataset requires ``d1`` and the images to have the same first dimension.
    '''
    images = d2[:][0]
    labels = d1[:]
    return TensorDataset(images, labels)

def local_to_global(labels_tensor:torch.Tensor, grid_dimension:tuple):
    '''
    Convert a batch of local label grids back to per-image lists of global labels.

    This is the inverse of ``global_to_local`` for occupied cells.

    Parameters
    ----------
    labels_tensor : batch of grids, in either layout:
        - channels-last (N, rows, cols, vector_length), the layout this function always read;
        - channels-first (N, vector_length, rows, cols), as returned by ``prepare_labels``.
    grid_dimension : (rows, cols, vector_length), the same tuple given to ``prepare_labels``.

    Returns
    -------
    A list with exactly one entry per image, so it stays index-aligned with the images
    (prepare_labels/merge_datasets pair labels with images by position). Each entry is a
    list of 1-D label tensors. Elements 1-4 (cx, cy, w, h) are converted back to normalized
    whole-image coordinates; other elements are unchanged. The list is empty when no cell is
    occupied.

    A cell counts as occupied when any of its values is non-zero. NOTE(review): for raw model
    outputs that is almost every cell. Threshold on element 0 first if it is a confidence score.
    '''
    n_rows, n_cols, vector_length = grid_dimension[0], grid_dimension[1], grid_dimension[2]

    # Accept channels-first input by moving the label vector to the last axis.
    if labels_tensor.shape[-1] != vector_length and labels_tensor.shape[1] == vector_length:
        labels_tensor = labels_tensor.permute(0, 2, 3, 1)

    list_of_tensors = []
    for grid in labels_tensor:
        image_labels = []
        occupied = grid.any(dim=-1)
        # nonzero() yields cells in row-major order, like the boolean-mask indexing it replaces.
        for row, col in occupied.nonzero().tolist():
            label = grid[row, col].clone()
            # Undo global_to_local: shift by the cell offset, then rescale from cell to image units.
            label[1] = (label[1] + col) / n_cols
            label[2] = (label[2] + row) / n_rows
            label[3] = label[3] / n_cols
            label[4] = label[4] / n_rows
            image_labels.append(label)
        list_of_tensors.append(image_labels)

    return list_of_tensors

In [None]:
# Label part (index 1 of the (image, label) pair) of the first detection training sample.
det_train[0][1]

In [None]:
# Grid rows/cols come from the H_OUT / W_OUT constants. 6 is the length of each object's label vector.
train_labels_local = prepare_labels(train_labels, (H_OUT, W_OUT, 6))
train_labels_local

In [None]:
# Pair det_train's images with the locally re-encoded labels.
det_train_manual = merge_datasets(d1=train_labels_local, d2=det_train)
det_train_manual

In [None]:
# Shape is (N, vector_length, rows, cols): prepare_labels permutes the grids to channels-first.
train_labels_local.shape  # TODO: we need to fix this (what exactly is unclear; confirm)

In [None]:
# Unshuffled loader over the re-encoded dataset, compared with train_loader2 further down.
train_loader1 = torch.utils.data.DataLoader(det_train_manual, batch_size=64, shuffle=False)

In [None]:
# Unshuffled loader over the original det_train, so batches line up with train_loader1.
train_loader2 = torch.utils.data.DataLoader(det_train, batch_size=64, shuffle=False)

In [None]:
def are_dataloaders_equal(dataloader1, dataloader2):
    '''
    Return True if two data loaders yield identical batches in the same order.

    Every element of each batch (e.g. images, then labels) is compared with ``torch.equal``, so
    shapes and values must match exactly. Loaders of different length, or batches with a
    different number of elements, compare unequal instead of being silently truncated by
    ``zip``. Batches are compared as they stream, so nothing is accumulated in memory.
    '''
    import itertools

    missing = object()  # marks the side of zip_longest that ran out first
    for batch1, batch2 in itertools.zip_longest(dataloader1, dataloader2, fillvalue=missing):
        if batch1 is missing or batch2 is missing:
            return False
        if len(batch1) != len(batch2):
            return False
        if not all(torch.equal(data1, data2) for data1, data2 in zip(batch1, batch2)):
            return False
    return True

# Sanity check: True means the labels rebuilt with prepare_labels/merge_datasets reproduce det_train exactly.
are_dataloaders_equal(train_loader1, train_loader2)

In [None]:
# Box (cx, cy, w, h = elements 1..4) of every object in the first training image.
for object_label in train_labels[0]:
    box = object_label[1:5]
    print(box)

#### Plotting some images from the detection dataset

In [None]:
def plot_detection_data(imgs, global_labels, start_idx=0):
    """Plot a 2x5 grid of detection images with their ground-truth boxes; each title lists the class ids.

    Parameters
    ----------
    imgs : indexable sequence of image tensors (C, H, W). Values are presumably in [0, 1],
        because they are scaled by 255 before the uint8 conversion. TODO confirm.
    global_labels : sequence aligned with ``imgs``. Each entry is a label tensor or a list of
        them. Elements 1-4 are the box in normalized (global) cxcywh coordinates and the last
        element is the class id.
    start_idx : index of the first image to show.
    """
    fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(8,3))
    n_panels = axes.size

    for i, ax in enumerate(axes.flat):
        img, labels = imgs[i + start_idx], global_labels[i + start_idx]
        img_height, img_width = img.shape[-2], img.shape[-1]
        img = (img * 255).byte()
        labels = labels if isinstance(labels, list) else [labels]
        label_classes = ''

        for label in labels:
            label_classes += f'{int(label[-1])}  '
            # Scale the normalized box to pixels. It stays floating point through the format
            # conversion; the old uint8 cast truncated each value and overflowed above 255 px.
            bbox = label[1:5].clone()
            bbox[0] *= img_width
            bbox[1] *= img_height
            bbox[2] *= img_width
            bbox[3] *= img_height

            converted_bbox = box_convert(bbox, in_fmt='cxcywh', out_fmt='xyxy')
            img = draw_bounding_boxes(img, converted_bbox.unsqueeze(0), colors='lightgreen')

        img = img.numpy().transpose((1, 2, 0))
        # draw_bounding_boxes expands single-channel images to 3 channels. An image without any
        # box stays (H, W, 1), which imshow rejects, so drop the singleton channel.
        if img.shape[-1] == 1:
            img = img[..., 0]
        ax.imshow(img, cmap='gray')
        ax.set_title(label_classes)
        ax.axis('off')

    fig.suptitle(f'Image {start_idx} - {start_idx + n_panels - 1}')

# Images only (labels come from the global label list) for training samples 10-19.
imgs = [sample_img for sample_img, _ in det_train]
plot_detection_data(imgs=imgs, global_labels=train_labels, start_idx=10)

#### Normalizing the dataset (TBD: not implemented yet)

#### Defining the loss function

In [None]:
class DetectionLoss(nn.Module):
    """Detection loss: the localization loss summed over every cell of the output grid.

    ``y_pred`` and ``y_true`` are channels-first grids (batch, vector_length, grid_h, grid_w).
    This is the layout of TestCNN2's conv output and of ``prepare_labels``, which permutes to
    (N, C, H, W). Each cell's (batch, vector_length) slice is scored with LocalizationLoss.
    """

    def __init__(self):
        super().__init__()
        self.Localization_loss = LocalizationLoss()
        # The training loop expects loss_fn to return (loss, loss_tuples). This could instead be
        # made optional there.
        self.loss_tuples = None

    def forward(self, y_pred, y_true):
        """Return ``(total_loss, self.loss_tuples)``, summing the per-cell localization losses."""
        grid_h, grid_w = y_pred.shape[-2], y_pred.shape[-1]
        loss = 0
        for row in range(grid_h):
            for col in range(grid_w):
                cell_out = self.Localization_loss(y_pred[..., row, col], y_true[..., row, col])
                # LocalizationLoss is also passed directly to ``train`` as a loss_fn, so it may
                # return a (loss, extras) pair. Keep only the loss in that case.
                cell_loss = cell_out[0] if isinstance(cell_out, tuple) else cell_out
                loss = loss + cell_loss

        return loss, self.loss_tuples

In [None]:
# Reshuffle the training batches every epoch. The seeded generator keeps the order reproducible.
# Validation stays in order.
train_loader = torch.utils.data.DataLoader(
    det_train_manual, batch_size=64, shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
val_loader = torch.utils.data.DataLoader(det_val, batch_size=64, shuffle=False)

loss_fn2 = DetectionLoss()

In [None]:
class TestCNN2(nn.Module):
    """Small fully-convolutional network producing a grid of detection vectors.

    Architecture:
    - Four conv (3x3, padding 1) + 2x2 max-pool stages downsample the single-channel input by 16.
    - Two further convolutions map it to 6 output channels, one 6-vector per grid cell.
      This is the label-vector length used with ``prepare_labels``.
    - Input size: for a 48x60 image (what the localization section passes to MyCNN5; TODO
      confirm detection images match), pooling gives 3x3. The final (4, 3) kernel with
      padding 1 then yields a 2x3 output grid.
    """

    def __init__(self, input_size=None):
        # nn.Module must be initialised before any attribute is assigned on it.
        super().__init__()
        # Kept for reference only; the layers below do not depend on it.
        self.input_size = input_size

        conv_kwargs = dict(stride=1, padding=1, device=device, dtype=torch.double)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, **conv_kwargs)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, **conv_kwargs)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3,3), **conv_kwargs)
        self.pool3 = nn.MaxPool2d(kernel_size=2)
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3,3), **conv_kwargs)
        self.pool4 = nn.MaxPool2d(kernel_size=2)
        self.conv5 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, **conv_kwargs)
        # Output head. This previously followed a discarded 15-channel conv6 that was immediately overwritten.
        self.conv6 = nn.Conv2d(in_channels=64, out_channels=6, kernel_size=(4,3), **conv_kwargs)

    def forward(self, x):
        """Map a (batch, 1, H, W) image tensor to a (batch, 6, grid_h, grid_w) prediction grid."""
        out = F.relu(self.conv1(x))
        out = self.pool1(out)
        out = F.relu(self.conv2(out))
        out = self.pool2(out)
        out = F.relu(self.conv3(out))
        out = self.pool3(out)
        out = F.relu(self.conv4(out))
        out = self.pool4(out)
        out = F.relu(self.conv5(out))
        # No activation on the head: raw values are returned for the loss to interpret.
        out = self.conv6(out)
        return out

In [None]:
# Use a detection-specific name. Reusing ``model_name`` from the localization section would make
# the figures saved below (presumably named after it) overwrite that run's figures.
det_model_name = 'testcnn2_detection'

torch.manual_seed(SEED)
model = TestCNN2()
model.to(device=device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.001)

loss_train, loss_val, train_perform, val_perform, losses_separated = train(
    n_epochs=2,
    optimizer=optimizer,
    model=model,
    loss_fn=loss_fn2,
    train_loader=train_loader,
    val_loader=val_loader
)

plot_loss(loss_train, loss_val, det_model_name, save_model=True)
# NOTE(review): DetectionLoss returns loss_tuples=None, so losses_separated may not contain the
# three component series named here. TODO confirm plot_lists copes with that.
plot_lists(losses_separated, ['detection loss', 'localization loss', 'classification loss'], det_model_name, save_model=True)
y_true, y_pred = predict(model, val_loader)
# val_loader iterates det_val, so the predictions belong with det_val, not the localization set
# loc_val. NOTE(review): plot_predictions was written for localization outputs; confirm it
# handles the grid-shaped detection labels.
plot_predictions(det_val, y_true, y_pred, label=4, start_idx=0, fig_name=det_model_name, save_model=True)