# Read from Kaggle, unzip file

In [None]:
%load_ext autoreload
%autoreload

import urllib.request as urlrequest
from pathlib import Path

# All downloaded/extracted data lives under ./data (relative to the
# current working directory).
base_data_path = Path() / "data"

base_data_path.mkdir(parents = True, exist_ok = True)

# where the extracted dataset directory ends up
dataset_location = base_data_path / "dataset"

# where the downloaded archive is stored
zip_path = base_data_path / "dataset.zip"

# only download when the archive is not already present
if not zip_path.exists():

    dataset_url = "https://www.kaggle.com/api/v1/datasets/download/alvarole/asirra-cats-vs-dogs-object-detection-dataset"

    # Download into a sibling ".part" file first: if the transfer is
    # interrupted, no truncated dataset.zip is left behind for the
    # exists() check above to mistake for a complete download.
    partial_path = zip_path.with_name(zip_path.name + ".part")

    # the context manager closes the HTTP response even on errors
    with urlrequest.urlopen(
        dataset_url,
    ) as response:

        download_size = response.getheader("Content-Length")

        with open(partial_path, "wb") as f:

            # stream in 1 MiB chunks instead of holding the whole
            # archive in memory
            while chunk := response.read(1024 * 1024):
                f.write(chunk)

    # move the finished download into its final place
    partial_path.replace(zip_path)


In [None]:

import zipfile
import os

def unzip():
    '''
    Extract the archive at `zip_path` into `base_data_path` and rename
    the extracted top-level directory to `dataset_location`.

    Relies on the module-level `zip_path`, `base_data_path` and
    `dataset_location`.
    '''

    # name of the top-level directory inside the archive
    inner_file = "Asirra_ cat vs dogs"

    # `archive` rather than `zip`, so the builtin is not shadowed
    with zipfile.ZipFile(zip_path, "r") as archive:
        archive.extractall(base_data_path)

    os.rename(base_data_path / inner_file, dataset_location)

# unzip()


# Dataset creation

In [None]:

import itertools
import xml.etree.ElementTree as ET
import torch
from typing import TypedDict


def patched_dataset_paths(dataset_location):

    return itertools.batched(dataset_location.iterdir(), 2)

class Objects(TypedDict):
    '''
    Label information of a single annotated object, as read from the
    annotation xml.

    `bndbox`: (xmin,ymin,xmax,ymax)
    '''

    # text of the <name> tag (the class label)
    name: str
    # text of the <pose> tag
    pose: str
    # integer value of the <truncated> tag
    truncated: int
    # integer value of the <difficult> tag
    difficult: int
    # bounding box corners
    bndbox: torch.Tensor

class Metadata(TypedDict):
    '''
    Labelling of one image: the image dimensions and every annotated
    object in it.

    `size`: (width, height, depth)
    '''

    # image dimensions
    size: torch.Tensor
    # one entry per annotated object
    objects: list[Objects]

class MetaWithImage(Metadata):
    '''
    `Metadata` extended with the location of the image it describes.
    '''

    # `get_dataset` stores the `Path` objects yielded by `iterdir()` here,
    # so a plain `str` annotation would be inaccurate.
    img_path: Path | str

# specific xml reader implementation for the lolz
def read_metadata(xml_file: Path) -> Metadata:
    '''
    Read labeling from xml file into dict.

    The returned dict has

    `size`:
    - integer tensor (width, height, depth), read from the tags of
    those names under <size>.

    `objects`:
    - list with one dict per <object> element, holding `name`, `pose`,
    `truncated`, `difficult` and `bndbox`, a float tensor
    (xmin, ymin, xmax, ymax).

    Values are looked up by tag name, so the result does not depend
    on the order in which the tags appear in the file.
    '''

    # canonicalisation with strip_text removes the whitespace around
    # text content, so e.g. names come out without padding
    with open(xml_file, "r", encoding = "utf-8") as f:
        text = ET.canonicalize(from_file=f, strip_text = True)

    tree = ET.fromstring(text)

    size_elem = tree.find("size")
    size = torch.tensor([
        int(size_elem.find(tag).text)
        for tag in ("width", "height", "depth")
    ])

    objects: list[Objects] = []
    for obj in tree.findall("object"):

        box_elem = obj.find("bndbox")
        objects.append(dict(
            name = obj.find("name").text,
            pose = obj.find("pose").text,
            truncated = int(obj.find("truncated").text),
            difficult = int(obj.find("difficult").text),
            bndbox = torch.tensor([
                float(box_elem.find(tag).text)
                for tag in ("xmin", "ymin", "xmax", "ymax")
            ])
        ))

    metadata: Metadata = dict(
        size = size,
        objects = objects
    )

    return metadata

def get_dataset(dataset_location) -> list[MetaWithImage]:
    '''
    Build one metadata dict per (image, xml) pair found in
    `dataset_location`: the parsed xml labelling with the image path
    added under "img_path".
    '''

    return [
        read_metadata(xml_path) | dict(img_path = img)
        for img, xml_path in patched_dataset_paths(dataset_location)
    ]

def dataset_splits(dataset: list[MetaWithImage] | None = None, fractions: tuple[float, ...] = (0.8, 0.1, 0.1)):
    '''
    Randomly split `dataset` into parts sized according to `fractions`.

    `dataset`:
    - items to split. When None, the dataset is read from the
    module-level `dataset_location` via `get_dataset()`.

    `fractions`:
    - passed as the lengths argument of `torch.utils.data.random_split`.

    Returns the result of `torch.utils.data.random_split`, one split per
    entry of `fractions`.
    '''

    if dataset is None:
        dataset = get_dataset(dataset_location)

    return torch.utils.data.random_split(dataset, fractions)



    


In [None]:
import torch
from torchvision.io import decode_image
import torchvision.transforms.v2.functional as tvt

class CatsAndDogsDataset(torch.utils.data.Dataset):
    '''
    Dataset of (image, metadata) pairs, with the bounding boxes
    expressed relative to the image dimensions and the class names
    replaced by class indices.

    `class_dict`: dict
    - "background" matches zero, other values in `class_names` (passed
    to init) match subsequent indices incrementally.
    '''

    def __init__(self, data: list[MetaWithImage], class_names: list[str], resize_to = (300,300)):
        '''
        `data`:
        - metadata dicts as produced by `get_dataset()`. The items are
        not modified; transformed copies are stored in `self.data`.

        `class_names`:
        - names of classes in dataset. 

        `resize_to`:
        - should be square
        '''

        classes = ["background"] + class_names
        self.class_dict = {name: index for index, name in enumerate(classes)}
        self.resize_to = resize_to
        self.data = [self.metadata_transform(val) for val in data]

    def __len__(self):
        return len(self.data)

    def image_transform(self, img):
        '''
        Resize `img` to `self.resize_to`.
        '''

        return tvt.resize(img, self.resize_to)

    def metadata_transform(self, metadata: MetaWithImage):
        '''
        Transform the bndbox values to be in the range [0,1],
        and change the "objects" contents to be a dict
        { bndbox: torch.Tensor, class: torch.Tensor }, where `bndbox` is
        (N,4) and `class` (N). 

        "size" is replaced by (resize_x, resize_y, depth), with the first
        two taken from `self.resize_to`.

        Returns a new dict; `metadata` itself is left untouched so the
        same source data can be used to build several datasets.
        '''

        resize_x, resize_y = self.resize_to
        width, height, depth = (int(val) for val in metadata["size"])

        # divisor per coordinate of (xmin, ymin, xmax, ymax)
        scale = torch.tensor([width, height]*2, dtype=torch.float32)

        objects = metadata["objects"]

        if objects:
            bndboxes = torch.vstack([obj["bndbox"]/scale for obj in objects])
        else:
            # vstack does not accept an empty list; keep the documented
            # (N,4) shape with N = 0 for images without objects
            bndboxes = torch.empty((0, 4), dtype=torch.float32)

        classes = torch.tensor(
            [self.class_dict[obj["name"]] for obj in objects],
            dtype=torch.long
        )

        # shallow copy with the two transformed entries replaced
        return {
            **metadata,
            "size": torch.tensor([resize_x, resize_y, depth]),
            "objects": {
                "bndbox": bndboxes,
                "class": classes
            }
        }

    def __getitem__(self, idx):
        '''
        Return (image, metadata) for item `idx`: the image is decoded
        from the metadata's "img_path", converted with
        `to_dtype(..., scale = True)` and resized with `image_transform`.
        '''

        metadata = self.data[idx]
        img_path = metadata["img_path"]
        image = tvt.to_dtype(decode_image(img_path), scale = True)
        image = self.image_transform(image)
        return image, metadata


In [None]:
# wrap every random split of the dataset into a CatsAndDogsDataset
train_split, validation_split, test_split = (
    CatsAndDogsDataset(split, ["cat", "dog"])
    for split in dataset_splits()
)

print(len(train_split))

# Test datasets with plotting

In [None]:
import matplotlib.pyplot as plt
import torchvision.transforms.v2.functional as vision_transforms
from torchvision.utils import draw_bounding_boxes

def to_plottable(img):
    '''
    Convert `img` into a PIL image, which `plt.imshow` can display.
    '''

    pil_image = vision_transforms.to_pil_image(img)
    return pil_image

def add_bb(img, bndbox: torch.Tensor, color = "cyan"):
    '''
    Draw the bounding boxes in `bndbox` onto `img` and return the
    resulting image.

    `img`:
    - image tensor; its last two dimensions are taken as
    (height, width).

    `bndbox`:

    (N, 4) with (xmin,ymin,xmax,ymax) (relative to dimensions of image).
    '''

    # Scale by the dimensions of the image being drawn on, rather than
    # by a global `meta` that may belong to a different image.
    height, width = img.shape[-2:]
    bb = bndbox*torch.tensor([width, height, width, height])

    return draw_bounding_boxes(img, bb, colors = color)

# plot a randomly picked training image with its bounding boxes
plt.figure()
sample_idx = int(torch.rand(1).item()*len(train_split))
im, meta = train_split[sample_idx]
im = add_bb(im, meta["objects"]["bndbox"])
plt.imshow(to_plottable(im))
print(meta["objects"]["bndbox"])
plt.show()


# Test IoU calculation

In [None]:
import src.default_box as default_box
import src.utils.math as math_utils


# Try out the IoU calculation on the first training sample.
im, meta = train_split[0]
bndbox = meta["objects"]["bndbox"]
width, height, _ = [val.item() for val in meta["size"]]

print(width, height)
# Default boxes for a single scale, with centers generated from an
# eighth of the image width/height.
# NOTE(review): the exact meaning of the arguments is defined in
# src.default_box -- confirm there.
boxes = default_box.default_boxes(
    scale = 0.8,
    centers = default_box.default_box_centers(width//8, height//8)
)

# the boxes come as (num_boxes, num_ratios, 4); flatten them to a
# plain list of boxes
num_boxes, num_ratios, _ = boxes.shape
boxes = boxes.reshape([num_boxes*num_ratios, 4])

print(bndbox)
print(boxes.shape)
iou = math_utils.intersection_over_union(
    boxes,
    bndbox
)
print(iou.shape)
# NOTE(review): this reshape assumes a particular layout of the IoU
# result -- verify against math_utils.intersection_over_union
iou = iou.reshape([num_ratios, len(bndbox), num_boxes])
print(iou.shape)

# number of IoU values at or above each threshold
print([(iou >= val).sum() for val in [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9]])


# Design network and training setup

See [SSD paper](https://arxiv.org/abs/1512.02325), also [Dive into deep learning](https://d2l.ai/) has good practical material and examples.

SSD consists of a base network followed by successive prediction layers
which generate the class predictions and bounding box offsets at different
scales. 

- The base network downsamples the input, decreasing the width and height
while adding more channels.

- The prediction layers make the predictions
by convolving over their input and outputting a value for each class
and each default box offset. 
    * The output from a layer is further downsampled by e.g. pooling,
    creating a larger receptive field for the next layer. As a result,
    the scale of default boxes should increase further into the net.


## Network

In [None]:
import torch
from math import floor

import src.utils.reshape as reshape

def down_sampler(in_channels, out_channels, device = None):
    '''
    Build a block of a 3x3 convolution (padding 1) from `in_channels` to
    `out_channels`, a ReLU and a 2x2 max pooling.

    `device`:
    - device for the convolution parameters; the torch default device
    when None.
    '''

    target_device = torch.get_default_device() if device is None else device

    conv = torch.nn.Conv2d(
        in_channels = in_channels,
        out_channels = out_channels,
        kernel_size = 3,
        padding = 1,
        device = target_device
    )
    activation = torch.nn.ReLU()
    pool = torch.nn.MaxPool2d(kernel_size = 2)

    return torch.nn.Sequential(conv, activation, pool)

def prediction_layer(in_channels, num_classes, num_ratios, device = None):
    '''
    Build the pair of 3x3 convolutions (padding 1) making the
    predictions for one feature map:

    - "class_pred" with `num_classes*num_ratios` output channels
    - "box_pred" with `4*num_ratios` output channels

    `device`:
    - device for the convolution parameters; the torch default device
    when None.
    '''

    target_device = torch.get_default_device() if device is None else device

    def conv_to(out_channels):
        return torch.nn.Conv2d(
            in_channels = in_channels,
            out_channels = out_channels,
            kernel_size = 3,
            padding = 1,
            device = target_device
        )

    # class head is created first, then the box head
    class_head = conv_to(num_classes*num_ratios)
    box_head = conv_to(4*num_ratios)

    return torch.nn.ModuleDict({
        "class_pred": class_head,
        "box_pred": box_head
    })

def max_pool_change(in_size, kernel_size = 2):
    '''
    How much an input size of `in_size` changes using
    torch.nn.MaxPool2d(kernel_size = kernel_size), i.e. with the stride
    equal to the kernel size, no padding and no dilation.

    `kernel_size`:
    - pooling kernel size, 2 by default.

    Returns the output size as an int.
    '''

    # output size formula of MaxPool2d with padding = 0, dilation = 1
    # and stride = kernel_size
    return floor((in_size - (kernel_size-1) - 1)/kernel_size + 1)

def repeat_apply(func, _input, num):
    '''
    Apply `func` `num` times in a row, starting from `_input` and
    feeding each result into the next call. Returns the final result
    (`_input` itself when `num` is zero).
    '''

    result = _input
    for _step in range(num):
        result = func(result)

    return result

def generate_default_boxes(width, scale):
    '''
    Generate the default boxes of scale `scale` for a square feature
    map of side `width`.
    '''

    centers = default_box.default_box_centers(width, width)

    return default_box.default_boxes(scale = scale, centers = centers)

class SSD(torch.nn.Module):
    '''
    Single-shot detector: two separate base networks (one for box
    predictions, one for class predictions) followed by prediction
    layers applied to successively max-pooled feature maps.

    `num_classes`: int
    
    `num_ratios`: int
    - Number of ratios used for the default boxes.

    `default_boxes`: torch.Tensor
    - default boxes for each feature map layer, of shape
    (num_ratios*feature_map_height*feature_map_width,4). Indexing
    over boxes per pixel works by [num_ratios*i:num_ratios*(i+1), 4]

    `pixels_per_layer`: list of int
    - pixels per feature map layer, useful for iterating over `default_boxes`
    in layer-by-layer manner.

    `prediction_layers`: torch.nn.ModuleList
    - one ModuleDict with "class_pred" and "box_pred" per feature map layer.
    '''

    def __init__(self, num_classes, num_ratios, in_channels = 3, width = 300, device = None):
        '''
        `num_classes`:
        - Number of classes including background "class".

        `num_ratios`:
        - Number of ratios used for the default boxes.

        `in_channels`:
        - Number of channels of the input images.

        `width`:
        - Side length of the (square) input images; used to pre-generate
        the default boxes.

        `device`:
        - Device for the parameters; the torch default device when None.
        '''

        # initialise the Module machinery before assigning any attributes
        super().__init__()

        if device is None:
            device = torch.get_default_device()

        self.device = device

        base_channels = [in_channels,9,27,81]

        self.num_classes = num_classes
        self.num_ratios = num_ratios

        self.max_pool = torch.nn.MaxPool2d(kernel_size = 2)

        self.base_layers = torch.nn.ModuleDict({
            "box": torch.nn.Sequential(*[
                down_sampler(base_channels[i], base_channels[i+1], device = device)
                for i in range(len(base_channels)-1)
            ]),
            "class": torch.nn.Sequential(*[
                down_sampler(base_channels[i], base_channels[i+1], device = device)
                for i in range(len(base_channels)-1)
            ])
        })

        # A ModuleList (not a plain list) so that the prediction layers
        # are registered as submodules: otherwise their parameters are
        # missing from `parameters()` / `state_dict()` and an optimiser
        # built from `model.parameters()` never updates them.
        self.prediction_layers = torch.nn.ModuleList([
            prediction_layer(base_channels[-1], num_classes, num_ratios, device = device)
            for _ in range(3)
        ])
        
        # pre-generate the default boxes
        # ------------------------------
        # feature map size after the base network: one max pool per
        # down sampler
        start_size = repeat_apply(max_pool_change, width, len(base_channels)-1)
        
        scales = default_box.scales(len(self.prediction_layers))

        def_boxes = []
        self.pixels_per_layer = []

        for i in range(len(self.prediction_layers)):

            def_box = generate_default_boxes(start_size, scales[i])
            pixels = def_box.shape[0]
            self.pixels_per_layer.append(pixels)
            # shape to match the output from the box predictions:
            # [pixels*ratios, boxes]
            def_box = (
                def_box
                .flatten(start_dim = 1)
                .reshape([pixels*num_ratios, 4])
            )
            def_boxes.append(
                def_box
            )

            # forward() max pools the feature maps between prediction layers
            start_size = max_pool_change(start_size)

        self.default_boxes = torch.vstack(def_boxes)
        # ===============================



    def forward(self, x):
        '''
        Return dict with kwords class_preds, box_preds,
        both a list of torch.Tensor of predictions per feature map layer.

        The class predictions are log-probabilities (log-softmax over
        dim 2); the last two box offsets are passed through a softplus.
        All returned tensors are moved to the cpu.
        '''

        x = x.to(device = self.device)
        X_box = self.base_layers["box"](x)
        X_class = self.base_layers["class"](x)

        class_preds = []
        box_preds = []
        for layer in self.prediction_layers:

            # predict classes 
            # (reshape to have predicted classes per default box)
            class_pred = reshape.items_per_pixel(
                layer["class_pred"](X_class),
                self.num_classes
            )
            class_preds.append(
                torch.nn.functional.log_softmax(class_pred, dim=2).to("cpu")
            )

            # predict box offsets
            box_pred = reshape.items_per_pixel(layer["box_pred"](X_box), 4)
            # restrict w/h offset to > 0 (w/h used to scale width and height)
            box_pred[:,:,[2,3]] = torch.nn.functional.softplus(box_pred[:,:,[2,3]])

            box_preds.append(
                box_pred.to("cpu")
            )

            # shrink the feature maps for the next prediction layer
            X_box = self.max_pool(X_box)
            X_class = self.max_pool(X_class)

        return dict(
            class_preds = class_preds,
            box_preds = box_preds
        )


In [None]:
from itertools import accumulate
import src.utils.torch_util as torch_utils


# Build a model and check its outputs against the default boxes.
dev = "cpu"
# overrides the "cpu" choice above with the project's configured device
dev = torch_utils.TORCH_DEVICE
# 3 classes (background included), 5 default box ratios
model = SSD(3, 5, device = dev)

# test properties
print(model.default_boxes.shape)
print(model.pixels_per_layer)

# run a batch of two images through the model
img1 = train_split[0][0]
img2 = train_split[1][0]
img_stack = torch.stack([img1, img2])

res = model(img_stack)
print(res["class_preds"][0].shape)
# the per-layer predictions stacked along dim 1
print(f"{torch.hstack(res["class_preds"]).shape=}")

bpreds = res["box_preds"][0]
print(res["box_preds"][0].shape)

print(model.default_boxes[:5,:])

# should match in dimensions
box_preds = torch.hstack(res["box_preds"])
def_boxes = model.default_boxes

# broadcasting the default boxes over the batch only works when the
# shapes agree
box_sum = model.default_boxes + box_preds
print(f"{box_sum.shape=}")

# how to go over feature layers
# (cumulative pixel counts give the layer boundaries)
layer_indices = list(accumulate(model.pixels_per_layer))
print(layer_indices)

# the boxes after the second-to-last boundary should be exactly the
# boxes of the last layer
assert torch.all(
    torch.tensor(box_sum[0][layer_indices[-2]*model.num_ratios:,:].shape) 
    == torch.tensor([model.pixels_per_layer[-1]*model.num_ratios, 4])
)


## Utilities for calculating loss

In [None]:

# print(train_split[0])


# ground truth of the first training sample: the metadata's "objects"
# dict holds "bndbox" and "class", in that order
dat = train_split[0][1]
bndbox, clslist = dat["objects"].values()

def classes_and_boxes_truth(
    iou: torch.Tensor,
    ground_truth_classes: torch.Tensor,
    threshold = 0.5,
) -> tuple[torch.Tensor, torch.Tensor]:
    '''
    Calculate a tensor indicating which default box is considered
    to overlap which class and which ground truth box.

    Return shape is (number of default boxes), with each element
    a class indicator index (background is zero) or box index (-1 is
    no box).

    Each default box is assigned the ground truth box it overlaps the
    most; if that overlap is not above `threshold`, the default box
    is assigned to the background instead. With no ground truth boxes
    at all, every default box is background.

    `iou`:
    - The intersection-over-union of default boxes and ground truth
    boxes, as per
    
    > utils.math.intersection_over_union(model.default_boxes, ground_truth_boxes)

    Indexed as (ground truth box, default box).

    `ground_truth_classes`:
    - class index of each ground truth box.

    `threshold`:
    - overlaps less than or equal to this count as background.
    '''

    num_ground_truth, num_default = iou.shape

    # max() over an empty dimension raises, so handle images without
    # any ground truth boxes separately
    if num_ground_truth == 0:
        classes = torch.zeros(
            num_default,
            dtype = ground_truth_classes.dtype,
            device = ground_truth_classes.device
        )
        boxes = torch.full(
            (num_default,), -1, dtype = torch.long, device = iou.device
        )
        return classes, boxes

    # best overlapping ground truth box per default box
    matches_max = iou.max(dim=0)
    boxes = matches_max.indices
    classes = ground_truth_classes[boxes]

    background = matches_max.values <= threshold
    boxes[background] = -1
    classes[background] = 0
    return classes, boxes

# match the model's default boxes against the ground truth of the
# sample picked above
iou = math_utils.intersection_over_union(
    model.default_boxes,
    bndbox
)
classes, boxes = classes_and_boxes_truth(iou, clslist)

In [None]:
# inspect how many default boxes matched which class
print(iou.shape)
(iou > 0.5).sum()

class_preds = torch.hstack(res["class_preds"]).to("cpu")

print(class_preds.shape)

print(classes.shape)
print(classes.sum())
# number of default boxes per class index (0 is background)
print((classes == 0).sum())
print((classes == 1).sum())
print((classes == 2).sum())

# class loss of the first image of the batch, without any example picking
torch.nn.functional.nll_loss(class_preds[0], classes)

In [None]:


# matching predicted boxes to ground truth.
# - Only calculate for boxes which have a matching ground truth?
# - Only calculate for boxes which have a predicted class other than background?
#   - Or should box predictions and class predictions be considered
#   independent of each other? Is the box prediction layer expected
#   to learn to match the ground truth box independent of what the
#   target class actually is? Of course, the two are modeled as
#   separate, so independence should be assumed, I guess.

def offset_default_boxes(
    default_boxes: torch.Tensor,
    predicted_offsets: torch.Tensor,
):
    '''
    Offset `default_boxes` based on `predicted_offsets`.

    Columns 0 and 1 of the offsets translate both corners of a box;
    columns 2 and 3 then multiply the translated box's width and
    height, keeping its (xmin, ymin) corner in place. `default_boxes`
    is not modified.
    '''

    # work on a detached copy of the default boxes
    shifted = default_boxes.detach().clone()

    # translate both corners by the (x, y) offset
    shifted.add_(predicted_offsets[:, [0,1,0,1]])

    # scale width/height from the (already translated) top-left corner
    top_left = shifted[:,[0,1]]
    extent = (shifted[:,[2,3]] - top_left)*predicted_offsets[:, [2,3]]
    shifted[:,[2,3]] = top_left + extent

    return shifted

def calculate_box_loss(
    default_boxes: torch.Tensor,
    predicted_offsets: torch.Tensor,
    ground_truth_boxes: torch.Tensor,
    ground_truth_overlap_index: torch.Tensor
):
    '''
    Calculate a Smooth L1 Loss between predicted object boxes
    and actual ones. 
    
    `default_boxes` (xmin, ymin, xmax, ymax) will
    be offset by `predicted_offsets` (x,y,w,h): x/y is used
    to move the entire box along an axis, allowing any real value;
    w/h >= 0 is used to scale the box, keeping the xmin/ymin stationary
    while increasing/decreasing the distance of xmax/ymax from the
    former. The result of the offset is compared against the ground
    truth box in `ground_truth_boxes` which the default box in question
    matched against (see `classes_and_boxes_truth()`) based on 
    `ground_truth_overlap_index`.

    Default boxes with an index of -1 (no match) are left out. If no
    default box has a match, the loss is zero.
    '''

    # calculate only for values which had a match
    match = ground_truth_overlap_index != -1

    if not bool(match.any()):
        # The mean over an empty selection would be NaN. The sum of the
        # (empty) selection is zero, and remains connected to the
        # predictions for backward().
        return predicted_offsets[match].sum()

    predicted = offset_default_boxes(default_boxes[match], predicted_offsets[match])

    # match the ground truths to the default boxes based on
    # overlap
    indices = ground_truth_overlap_index[match]
    actual = ground_truth_boxes[indices,:]

    return torch.nn.functional.smooth_l1_loss(
        input = predicted,
        target = actual,
    )


# try the box loss on the first image of the batch predicted earlier
calculate_box_loss(
    model.default_boxes,
    box_preds[0],
    bndbox,
    boxes
)

### class losses

The losses here are at least supposed to perform the hard negative
mining (HNM) suggested for the SSD, namely only using some of the
many negative samples to train the model. This is supposed to
improve the training, as it emphasises positive examples more.

The two ways of calculating (or including) HNM here are to do it
1. based on the true values, so pick only a fraction of the true background
items for comparison
2.  based on the predicted values, so pick only a fraction of the predicted
background items for comparison.

I'm not sure exactly how the original paper did it, but, computational
cost aside, I figure there is no harm in doing it both ways and summing
the results (or combining them in some more complex way).
To see why, consider what happens when only one of the two is used.

1. By only picking values based on the true classes, there's possibly
very few false positive examples to compare: by picking a fraction of background indices from the true class
set, it's possible that many of the excluded indices would be (falsely)
predicted as some positive class. Therefore many of these false positive
examples would be lost. We would focus on what we do want to predict
as positive, but not as much what we don't.
2. By only picking background fractions
from the predicted set, many of the true positive examples would potentially be
lost, especially at first as the predictions are likely uniformly
distributed across the classes.

In general, the more each class
can be compared with each other, the better.

In [None]:


def calculate_class_loss_true(
    predicted_classes,
    matched_classes,
    negative_to_positive_ratio: float = 3.0
):
    '''
    Calculate an NLL loss between the class predictions in
    `predicted_classes`, (N,C) and true class values in `matched_classes`,
    (N). `predicted_classes` rows should be log-probabilities for each
    class, with the column index matching the set class index (0 being
    background).

    Potentially ignores excessive negative examples based on `matched_classes`
    and `negative_to_positive_ratio`. If no example at all gets picked
    (no positive examples), the loss is zero.

    `negative_to_positive_ratio`:
    - Determines how many more negative (background) class examples
    should be used compared to the number of available positive examples.
    Lower ratios increase focus on correct positive predictions, 
    and higher ones increase focus on correct background predictions.
    It is assumed that the number of negative examples is much higher,
    therefore this will only affect the picking of those examples.
    The number of picked negative examples is still between zero
    and the total number of negative examples.
    '''
    
    # pick negative/positive examples in a (at most) ratio:1 ratio
    # ------------------------------------------------------------
    true_background = matched_classes == 0
    num_background = int(true_background.sum())
    num_other = int(true_background.numel()) - num_background
    neg_to_pick = int(max(
        min(num_background, negative_to_positive_ratio*num_other),
        0
    ))

    # randomly get indices of negative examples
    # (assumes `matched_classes` is 1-D, as documented)
    background_idx = torch.argwhere(true_background).flatten()
    picked_idx = background_idx[torch.randperm(num_background)[:neg_to_pick]]

    # all positive examples plus the picked negative ones
    selected = torch.logical_not(true_background)
    selected[picked_idx] = True

    # ============================================================

    if not bool(selected.any()):
        # The mean over an empty selection would be NaN. The sum of the
        # (empty) selection is zero, and remains connected to the
        # predictions for backward().
        return predicted_classes[selected].sum()

    nll = torch.nn.functional.nll_loss(
        predicted_classes[selected],
        matched_classes[selected],
    )

    return nll


def calculate_class_loss_predicted(
    predicted_classes,
    matched_classes,
    negative_to_positive_ratio: float = 3.0
):
    '''
    Calculate an NLL loss between the class predictions in
    `predicted_classes`, (N,C) and true class values in `matched_classes`,
    (N). `predicted_classes` rows should be log-probabilities for each
    class, with the column index matching the set class index (0 being
    background).

    Potentially ignores excessive negative examples based on `predicted_classes`
    and `negative_to_positive_ratio`. If no example at all gets picked
    (nothing predicted as a positive class), the loss is zero.

    `negative_to_positive_ratio`:
    - Determines how many more negative (background) class examples
    should be used compared to the number of available positive examples.
    Lower ratios increase focus on correct positive predictions, 
    and higher ones increase focus on correct background predictions.
    It is assumed that the number of negative examples is much higher,
    therefore this will only affect the picking of those examples.
    The number of picked negative examples is still between zero
    and the total number of negative examples. 
    '''
    
    # pick negative/positive examples in a (at most) ratio:1 ratio
    # ------------------------------------------------------------
    # an example counts as negative when background is its most
    # probable predicted class
    predicted_background = torch.argmax(predicted_classes, dim=1) == 0
    num_background = int(predicted_background.sum())
    num_other = int(predicted_background.numel()) - num_background
    neg_to_pick = int(max(
        min(num_background, negative_to_positive_ratio*num_other),
        0
    ))

    # randomly get indices of negative examples
    background_idx = torch.argwhere(predicted_background).flatten()
    picked_idx = background_idx[torch.randperm(num_background)[:neg_to_pick]]

    # all predicted positive examples plus the picked negative ones
    selected = torch.logical_not(predicted_background)
    selected[picked_idx] = True

    # ============================================================

    if not bool(selected.any()):
        # The mean over an empty selection would be NaN. The sum of the
        # (empty) selection is zero, and remains connected to the
        # predictions for backward().
        return predicted_classes[selected].sum()

    nll = torch.nn.functional.nll_loss(
        predicted_classes[selected],
        matched_classes[selected],
    )

    return nll


def calculate_class_loss(
    predicted_classes,
    matched_classes,
    negative_to_positive_ratio: float = 3.0
):
    '''
    Calculate an NLL loss between the class predictions in
    `predicted_classes`, (N,C) and true class values in `matched_classes`,
    (N). `predicted_classes` rows should be log-probabilities for each
    class, with the column index matching the set class index (0 being
    background).

    The result is the sum of `calculate_class_loss_predicted()` and
    `calculate_class_loss_true()`, i.e. of the losses with the negative
    examples picked based on the predicted and on the true classes.

    `negative_to_positive_ratio`:
    - Determines how many more negative (background) class examples
    should be used compared to the number of available positive examples.
    Lower ratios increase focus on correct positive predictions, 
    and higher ones increase focus on correct background predictions.
    It is assumed that the number of negative examples is much higher,
    therefore this will only affect the picking of those examples.
    The number of picked negative examples is still between zero
    and the total number of negative examples. 
    '''
    
    nll = (
        calculate_class_loss_predicted(
            predicted_classes,
            matched_classes,
            negative_to_positive_ratio
        )
        + calculate_class_loss_true(
            predicted_classes,
            matched_classes,
            negative_to_positive_ratio
        )
    )

    return nll


In [None]:

def calculate_losses(
    default_boxes,
    predicted_offsets,
    predicted_classes,
    ground_truth_boxes,
    matched_classes,
    ground_truth_overlap_index
):
    '''
    Calculate the box and class losses of a single image.

    Return dict of l1 (box loss, see `calculate_box_loss()`) and
    nll (class loss, see `calculate_class_loss()`, using a
    negative-to-positive ratio of 3.0).
    '''

    return dict(
        l1 = calculate_box_loss(
            default_boxes,
            predicted_offsets,
            ground_truth_boxes,
            ground_truth_overlap_index
        ),
        nll = calculate_class_loss(predicted_classes, matched_classes, 3.0)
    )

# try both losses on the first image of the batch predicted earlier
calculate_losses(model.default_boxes, box_preds[0], class_preds[0], bndbox, classes, boxes)


In [None]:

def calculate_loss(
        model: SSD,
        prediction,
        ground_truth_boxes:list[torch.Tensor],
        ground_truth_classes: list[torch.Tensor],
        weight: float = 1.0,
        iou: list[torch.Tensor | None] | None = None
):
    '''
    Calculate the losses for `prediction`, the output of the forward
    pass of `model`, summed over all items of the batch.

    `ground_truth_boxes` has one (N,4) tensor per batch item, holding the
    bounding boxes (xmin, ymin, xmax, ymax) of the objects in the input
    image. `ground_truth_classes` has one (N) tensor per batch item, with
    the true class of each of those bounding boxes.

    Return dict of "combined", "l1", and "nll", where l1 is the loss
    of the predicted boxes, nll the loss of the predicted classes, and
    "combined" is l1 + weight*nll.

    `iou` can be passed if it has been calculated using

    > iou = src.utils.math.intersection_over_union(model.default_boxes, target_bndbox)

    for the batch items; items for which it is None (or all of them, if
    `iou` itself is None) get theirs calculated here.
    '''

    # predictions of all feature map layers stacked along dim 1
    all_class_preds = torch.hstack(prediction["class_preds"])
    all_box_preds = torch.hstack(prediction["box_preds"])

    l1_total = 0
    nll_total = 0

    for i in range(len(ground_truth_boxes)):

        item_boxes = ground_truth_boxes[i]

        # use the precalculated overlap when there is one
        item_iou = None if iou is None else iou[i]
        if item_iou is None:
            item_iou = math_utils.intersection_over_union(
                model.default_boxes,
                item_boxes
            )

        matched_classes, matched_boxes = classes_and_boxes_truth(
            item_iou,
            ground_truth_classes[i]
        )
        item_losses = calculate_losses(
            model.default_boxes,
            all_box_preds[i],
            all_class_preds[i],
            item_boxes,
            matched_classes,
            matched_boxes
        )

        l1_total += item_losses["l1"]
        nll_total += item_losses["nll"]

    return dict(
        combined = l1_total + weight*nll_total,
        l1 = l1_total,
        nll = nll_total
    )


# use same ground truths for testing
# the batch in `res` has two images, so the ground truth is repeated twice
calculate_loss(model, res, [bndbox]*2, [clslist]*2)


# Training

In [None]:
from itertools import batched
from random import sample


# side length the training images get resized to
image_width = 128

# fresh random splits, wrapped into datasets that resize the images
train_split, test_split, validation_split = (
    CatsAndDogsDataset(
        split,
        class_names = ["cat", "dog"],
        resize_to = (image_width, image_width)
    )
    for split in dataset_splits(fractions = (0.8,0.10,0.10))
)

# the PyTorch Dataloader seems to do some extra stuff that
# doesn't really fit the data from the dataset. Could
# probably create a new Dataloader class or maybe
# have the dataset return the data somewhat differently,
# but doing this for now
class Batcher:
    '''
    Iterator that iterates over given data in batches, possibly randomly,
    returning the set of indices and data points for each batch.
    '''

    def __init__(self, data: list, batch_size: int, shuffle = True):
        '''
        `shuffle`:
        - if False, Don't randomise the order.  
        '''

        self.idx_and_data = tuple(enumerate(data))
        self.batch_size = batch_size
        self.shuffle = shuffle
        self._iter = None

    def __len__(self):

        return len(self.data)
    
    def __iter__(self):

        if self.shuffle:
            self._iter = batched(
                sample(
                    self.idx_and_data,
                    len(self.idx_and_data)
                ),
                n = self.batch_size
            )
        else:
            self._iter = batched(
                self.idx_and_data,
                n = self.batch_size
            )

        return self

    def __next__(self):

        if self._iter is None:
            raise StopIteration

        idx, data = zip(*next(self._iter))
        return idx, data



In [None]:
def get_image_and_labels(data: list):
    '''
    Split a batch of (image, metadata) pairs into its parts.

    Return dict of

    - imgs: the images stacked into a single tensor
    - bndboxes: list with the "bndbox" of each item's "objects"
    - classes: list with the "class" of each item's "objects"
    '''

    images = []
    bndboxes = []
    classes = []

    for image, metadata in data:
        labels = metadata["objects"]
        images.append(image)
        bndboxes.append(labels["bndbox"])
        classes.append(labels["class"])

    return {
        "imgs": torch.stack(images),
        "bndboxes": bndboxes,
        "classes": classes
    }


def train(
    model: SSD,
    optimiser: torch.optim.Optimizer,
    train_data: Batcher,
    test_data: list,
    epochs: int,
    test_every = 10
):
    '''
    Train `model` on `train_data`, periodically evaluating on `test_data`.

    `train_data`:
    - iterable yielding `(indices, batch)` pairs, where `batch` is in
      the format accepted by `get_image_and_labels`.

    `test_data`:
    - samples in the format accepted by `get_image_and_labels`;
      evaluated as one single batch.

    `epochs`:
    - number of passes over `train_data`.

    `test_every`:
    - evaluate on the test data every `test_every` epochs, and always
      on the first and the last epoch.

    Returns `(l1_losses, nll_losses)`: the test losses divided by the
    number of test images, one entry per evaluation.
    '''

    print()

    l1_losses = []
    nll_losses = []

    # setup test data; it stays the same for every evaluation,
    # so the labels and IoUs are computed only once
    test_imgs, test_bndboxes, test_classes = get_image_and_labels(
        test_data
    ).values()

    test_ious = math_utils.batched_intersection_over_union(
        model.default_boxes,
        test_bndboxes
    )


    for epoch in range(epochs):

        for _batch_idx, batch in train_data:

            batch_size = len(batch)

            # setup training data
            train_imgs, train_bndboxes, train_classes = get_image_and_labels(
                batch
            ).values()

            train_ious = math_utils.batched_intersection_over_union(
                model.default_boxes,
                train_bndboxes
            )

            # train
            optimiser.zero_grad()

            pred = model(train_imgs)

            loss = calculate_loss(
                model,
                pred,
                train_bndboxes,
                train_classes,
                iou = train_ious
            )

            # Backpropagate both losses in a single pass. The gradients are
            # the same as for two separate backward calls (they accumulate),
            # but this does not break when the two losses share parts of
            # the graph, which is freed after the first backward call.
            total_loss = (loss["l1"] + loss["nll"]) / batch_size
            total_loss.backward()
            optimiser.step()

        if epoch % test_every == 0 or epoch == (epochs-1):

            with torch.no_grad():

                test_pred = model(test_imgs)
                test_loss = calculate_loss(
                    model,
                    test_pred,
                    test_bndboxes,
                    test_classes,
                    iou = test_ious
                )

            l1_losses.append(test_loss["l1"]/len(test_imgs))
            nll_losses.append(test_loss["nll"]/len(test_imgs))

        # shows the most recent test losses; trailing spaces and "\r" make
        # the next epoch's line overwrite this one
        print(f"{epoch+1}/{epochs} l1 Loss: {l1_losses[-1]:.5f} nll Loss: {nll_losses[-1]:.5f}", end=" "*30+"\r")


    print()
    return l1_losses, nll_losses

## Proper training loop

In [None]:
# Device the model is created on; assign "cpu" here instead to force
# training on the CPU.
model_dev = torch_utils.TORCH_DEVICE
model = SSD(3, 5, device = model_dev, width = image_width)

optim = torch.optim.Adam(model.parameters())

# use at most 3000 training examples
examples = min(len(train_split), 3000)
# not particularly the way to use Datasets, but works
train_dat = [train_split[i] for i in range(examples)]
# at most 5 test examples, bounded by the split size like the training subset
test_dat = [test_split[i] for i in range(min(len(test_split), 5))]

batch_size = 30

losses = train(
    model,
    optim,
    Batcher(train_dat, batch_size = batch_size, shuffle = True),
    test_dat,
    epochs = 10,
    test_every = 1
)

# TODO: check out TensorBoard

# test losses returned by `train`: l1 on the left, nll on the right
fig, ax = plt.subplots(1, 2)

ax[0].plot(losses[0], label="l1")
ax[0].legend()
ax[1].plot(losses[1], label="nll")
ax[1].legend()

## Testing loop

In [None]:
# Fresh model and optimiser used only for the profiling run below, so the
# model trained in the previous cell is left untouched.
test_model = SSD(
    3,
    5,
    device = torch_utils.TORCH_DEVICE,
    width = image_width
)
test_optim = torch.optim.Adam(test_model.parameters())


# TODO: lots of .to calls? The loss calculation also seems to be a time hog.
# Basically, the issue is having to move data from GPU to CPU, and
# the lack of vectorisation in parts. Without vectorisation, calculating
# purely on GPU seems to run slower than on CPU.
# Profile 100 epochs on a single batch of 30 training examples with one
# test example; the report is sorted by cumulative time.
train_test = [train_split[i] for i in range(30)]
test_test = [test_split[i] for i in range(1)]
%prun -s "cumulative" _ = train(test_model, test_optim, Batcher(train_test, batch_size = 30), test_test, epochs=100)

# Test trained model

In [None]:

# Visual check on the first validation example: ground-truth boxes are drawn
# in the default colour, predicted boxes in red.
im, meta = validation_split[0]

# inference only, so don't build the autograd graph
with torch.no_grad():
    pred = model(im.unsqueeze(dim=0))

bndbox, classes = meta["objects"]["bndbox"], meta["objects"]["class"]
iou = math_utils.intersection_over_union(model.default_boxes, bndbox)

cl, boxes = classes_and_boxes_truth(iou, classes)
# assumes default boxes without a matching ground-truth box are marked
# with -1 — TODO confirm against classes_and_boxes_truth
matched = boxes != -1
print(boxes.unique())
predicted = offset_default_boxes(model.default_boxes[matched], torch.hstack(pred["box_preds"])[0][matched])

base_im = add_bb(im, bndbox)
print(torch.hstack(pred["box_preds"]))
predicted_im = add_bb(base_im, predicted, color = "red")
plt.imshow(to_plottable(predicted_im))
    


In [None]:

# matched default boxes, the raw box predictions for them, and the
# boxes obtained by applying those predictions to the default boxes
print(
    model.default_boxes[matched],
    torch.hstack(pred["box_preds"])[0][matched],
    predicted,
    sep="\n",
)

# compare boxes
# TODO: Non-Maximum Suppression
print(boxes.unique())
# ground-truth box assigned to each matched default box
true_boxes = meta["objects"]["bndbox"][boxes[matched]]
print(true_boxes)
# summed absolute coordinate difference per matched box
print(torch.abs(predicted - true_boxes).sum(dim=1))

In [None]:

# compare classes

# predicted class for every default box of the single inspected image
pred_cls = torch.hstack(pred["class_preds"])[0].max(dim=1).indices

# (name, id) pairs ordered by class id. Iterating over these instead of a
# hard-coded range(3) keeps the report correct if the set of classes changes.
classes_by_id = sorted(validation_split.class_dict.items(), key = lambda tup: tup[1])
classes_names = [name for name, _ in classes_by_id]

print("predictions per class:", [((pred_cls == idx).sum().item(), name) for name, idx in classes_by_id])

print("Correct predictions in total:", (pred_cls == cl).sum().item())
print("Correct predictions per class:", [
    (torch.logical_and(pred_cls == idx, cl == idx).sum().item(), name)
    for name, idx in classes_by_id
])


## Calculate accuracies for all examples

In [None]:

val_imgs, val_bndboxes, val_classes = get_image_and_labels(
    [validation_split[i] for i in range(len(validation_split))]
).values()

with torch.no_grad():
    val_preds = model(val_imgs)

val_box_preds = torch.hstack(val_preds["box_preds"])

# predicted class per default box
val_class_preds = torch.hstack(val_preds["class_preds"])
val_class_preds = torch.max(val_class_preds, dim=-1).indices


val_iou = math_utils.batched_intersection_over_union(model.default_boxes, val_bndboxes)

class_dict_reverse = {v:k for k,v in validation_split.class_dict.items()}

# Per class id: running sum of the per-image accuracies, and the number of
# images that contributed to that sum.
class_score = {k: 0.0 for k in validation_split.class_dict.values()}
class_count = {k: 0 for k in validation_split.class_dict.values()}

for i in range(len(validation_split)):
    true_val_classes, true_val_boxes = classes_and_boxes_truth(val_iou[i], val_classes[i])

    # named `class_id` rather than `cl` so the `cl` tensor of the
    # single-image cell is not overwritten
    for class_id in map(int, true_val_classes.unique()):

        is_class = true_val_classes == class_id

        # share of this image's `class_id` boxes that were predicted correctly
        accuracy = (
            (val_class_preds[i][is_class] == true_val_classes[is_class]).sum()/(is_class.sum())
        )

        class_score[class_id] += accuracy.item()
        class_count[class_id] += 1

# Average over the images that actually contain the class; dividing by the
# total number of images would count every image without the class as 0.
# Classes that never occur keep a score of 0.0.
class_score = {
    class_dict_reverse[k]+" accuracy": v/max(class_count[k], 1)
    for k, v in class_score.items()
}



print(class_score)

# TODO

- Non-maximum suppression: also useful for evaluating the predictions, which is
  currently difficult to do in a sensible way, at least for the box predictions.
- Better accuracy calculations (currently only the per-class scores).
- Improve speed: the loss calculation is quite slow, and the GPU -> CPU -> GPU
  transfers slow down training.