# Read from Kaggle, unzip file

In [None]:
import urllib.request as urlrequest
from pathlib import Path

# Everything this notebook downloads or extracts lives under ./data,
# relative to the current working directory.
base_data_path = Path() / "data"
base_data_path.mkdir(parents=True, exist_ok=True)

# Final location of the extracted dataset, and the downloaded archive.
dataset_location = base_data_path / "dataset"
zip_path = base_data_path / "dataset.zip"

# The archive is only fetched when it is not already on disk.
if not zip_path.exists():

    dataset_url = "https://www.kaggle.com/api/v1/datasets/download/alvarole/asirra-cats-vs-dogs-object-detection-dataset"

    # Download into a sibling ".part" file and rename it at the end, so an
    # interrupted download never leaves a truncated dataset.zip behind
    # (which the exists() check above would then mistake for a complete one).
    partial_path = zip_path.with_name(zip_path.name + ".part")

    # The context manager closes the HTTP response even if writing fails.
    with urlrequest.urlopen(dataset_url) as response:

        # Size reported by the server: a string, or None when the header
        # is absent. Kept for inspection; the download does not depend on it.
        download_size = response.getheader("Content-Length")

        with open(partial_path, "wb") as f:
            # Stream in 1 MiB chunks instead of holding the archive in memory.
            while chunk := response.read(1024 * 1024):
                f.write(chunk)

    partial_path.replace(zip_path)


In [None]:

import zipfile
import os

def unzip():
    """Extract the downloaded archive and move its contents to `dataset_location`.

    The archive at `zip_path` is extracted into `base_data_path`; the
    extracted entry named "Asirra_ cat vs dogs" is then renamed to
    `dataset_location`.

    If `dataset_location` already exists nothing is done, so the cell can be
    re-run without `os.rename` failing on the existing target.
    """

    if dataset_location.exists():
        return

    inner_file = "Asirra_ cat vs dogs"

    # `archive` rather than `zip`, to avoid shadowing the builtin.
    with zipfile.ZipFile(zip_path, "r") as archive:
        archive.extractall(base_data_path)

    os.rename(base_data_path / inner_file, dataset_location)

# unzip()


# Dataset creation

In [None]:

import itertools
import xml.etree.ElementTree as ET
import torch
from typing import TypedDict


def patched_dataset_paths(dataset_location):

    return itertools.batched(dataset_location.iterdir(), 2)

class Objects(TypedDict):
    """Annotation of one labelled object.

    `bndbox` stores the box as (xmin, ymin, xmax, ymax).
    """

    # Text fields.
    name: str
    pose: str

    # Integer-valued fields.
    truncated: int
    difficult: int

    # Box coordinates; ordering is given in the class docstring.
    bndbox: torch.Tensor

class Metadata(TypedDict):
    """Annotation of one image.

    `size` stores the image dimensions as (width, height, depth).
    """

    # Image dimensions; ordering is given in the class docstring.
    size: torch.Tensor

    # One entry per labelled object.
    objects: list[Objects]

class MetaWithImage(Metadata):
    """`Metadata` extended with the location of the image it describes."""

    # `get_dataset` stores the `Path` objects yielded by `iterdir()` here,
    # so the annotation accepts both a `Path` and a plain string.
    img_path: Path | str

# Hand-written XML reader tailored to this dataset's annotation files.
def read_metadata(xml_file: Path) -> Metadata:
    """Read the labelling stored in an XML annotation file into a dict.

    Args:
        xml_file: Path of the annotation file (read as UTF-8).

    Returns:
        A dict with
        - `size`: integer tensor (width, height, depth), and
        - `objects`: one dict per `<object>` element with `name`, `pose`,
          `truncated`, `difficult` and a float tensor `bndbox` ordered
          (xmin, ymin, xmax, ymax).

    Raises:
        AttributeError: if a required element is missing.
        ValueError: if a numeric field does not hold a number.
    """

    # Canonicalising with strip_text=True trims the whitespace around every
    # text node, so the values read below need no further stripping.
    with open(xml_file, "r", encoding="utf-8") as f:
        text = ET.canonicalize(from_file=f, strip_text=True)

    root = ET.fromstring(text)

    # Values are looked up by tag name, so the documented ordering holds
    # regardless of the order in which the file lists the child elements.
    size_elem = root.find("size")
    size = torch.tensor([
        int(size_elem.find(tag).text)
        for tag in ("width", "height", "depth")
    ])

    objects: list[Objects] = []
    for obj in root.findall("object"):
        box_elem = obj.find("bndbox")
        objects.append(dict(
            name=obj.find("name").text,
            pose=obj.find("pose").text,
            truncated=int(obj.find("truncated").text),
            difficult=int(obj.find("difficult").text),
            bndbox=torch.tensor([
                float(box_elem.find(tag).text)
                for tag in ("xmin", "ymin", "xmax", "ymax")
            ]),
        ))

    metadata: Metadata = dict(
        size=size,
        objects=objects,
    )

    return metadata

def get_dataset(dataset_location) -> list[MetaWithImage]:
    """Build one metadata dict per (image, annotation) pair in `dataset_location`.

    Each pair produced by `patched_dataset_paths` is turned into the dict
    returned by `read_metadata` for the annotation file, extended with an
    `img_path` entry holding the image path.
    """

    return [
        read_metadata(annotation_file) | dict(img_path = image_file)
        for image_file, annotation_file in patched_dataset_paths(dataset_location)
    ]

def dataset_splits(dataset: list[MetaWithImage] | None = None, fractions: tuple[float, ...] = (0.8, 0.1, 0.1)):
    """Randomly split the dataset into parts of the given relative sizes.

    Args:
        dataset: Entries to split. When None, the dataset is loaded with
            `get_dataset(dataset_location)`.
        fractions: Relative size of each part; forwarded unchanged to
            `torch.utils.data.random_split` as its `lengths` argument.

    Returns:
        The list of subsets produced by `torch.utils.data.random_split`,
        one per entry of `fractions`.
    """

    if dataset is None:
        dataset = get_dataset(dataset_location)

    return torch.utils.data.random_split(dataset, fractions)



    


In [None]:
import torch
from torchvision.io import read_image
import torchvision.transforms.v2.functional as tvt

class CatsAndDogsDataset(torch.utils.data.Dataset):
    """Dataset of resized images paired with normalised annotations.

    Each item is `(image, metadata)`: the image is read from the entry's
    `img_path` and resized to `resize_to`; the metadata is the entry with its
    bounding boxes scaled to [0, 1] and `size` replaced by the resized
    dimensions.
    """

    def __init__(self, data: list[MetaWithImage], resize_to = (300,300)):
        """Store the target size and pre-compute the transformed metadata.

        Args:
            data: Metadata entries (anything iterable, e.g. a split returned
                by `random_split`). The entries are not modified.
            resize_to: Target image size as (height, width), passed
                unchanged to torchvision's `resize`.
        """

        self.resize_to = resize_to
        self.data = [self.metadata_transform(val) for val in data]

    def __len__(self):
        """Return the number of entries."""
        return len(self.data)

    def image_transform(self, img):
        """Resize `img` to `self.resize_to`."""

        return tvt.resize(img, self.resize_to)

    def metadata_transform(self, metadata: MetaWithImage):
        """Return a copy of `metadata` with boxes scaled to the range [0,1].

        Every `bndbox` is divided by (width, height, width, height) of the
        original image, and `size` becomes (resized width, resized height,
        depth). New dicts are built instead of editing `metadata` in place,
        so wrapping the same entries in several datasets cannot normalise a
        box more than once.
        """

        # `resize_to` is (height, width) for `resize`, whereas `size` is
        # stored as (width, height, depth); swap accordingly.
        resize_height, resize_width = self.resize_to
        width, height, depth = (int(dim) for dim in metadata["size"])

        scale = torch.tensor([width, height] * 2)
        objects = [
            {**obj, "bndbox": obj["bndbox"] / scale}
            for obj in metadata["objects"]
        ]

        return {
            **metadata,
            "size": torch.tensor([resize_width, resize_height, depth]),
            "objects": objects,
        }

    def __getitem__(self, idx):
        """Return the resized image and the transformed metadata at `idx`."""
        metadata = self.data[idx]
        image = read_image(metadata["img_path"])
        return self.image_transform(image), metadata


In [None]:
# Split the dataset, then wrap every raw split in CatsAndDogsDataset
# (train first, then validation, then test).
train_split, validation_split, test_split = dataset_splits()
train_split, validation_split, test_split = (
    CatsAndDogsDataset(raw_split)
    for raw_split in (train_split, validation_split, test_split)
)

# Quick sanity check: number of training samples.
print(len(train_split))

# Test datasets with plotting

In [None]:
import matplotlib.pyplot as plt
import torchvision.transforms.v2.functional as vision_transforms
from torchvision.utils import draw_bounding_boxes

def to_plottable(img):
    """Convert `img` with torchvision's `to_pil_image` so it can be plotted."""

    pil_image = vision_transforms.to_pil_image(img)
    return pil_image

def add_bb(img, meta: MetaWithImage):
    """Draw every bounding box listed in `meta` onto `img` in cyan.

    The stored boxes are multiplied by (width, height, width, height) taken
    from `meta["size"]` before drawing. Each scaled box is printed. Returns
    the image with all boxes drawn (`img` itself when there are no objects).
    """

    width, height, _ = (dim.item() for dim in meta["size"])
    to_pixels = torch.tensor([width, height, width, height])

    annotated = img
    for entry in meta["objects"]:
        box = entry["bndbox"].reshape((-1, 4)) * to_pixels
        print(box)
        annotated = draw_bounding_boxes(annotated, box, colors = "cyan")

    return annotated

# Take the first training sample and overlay its bounding boxes.
im, meta = train_split[0]
im = add_bb(im, meta)

# Show the annotated image and print the object annotations.
plt.figure()
plt.imshow(to_plottable(im))
print(meta["objects"])
plt.show()


# Test IoU calculation

In [None]:
import src.default_box as default_box
import src.utils.math as math_utils


# First training sample: its first object's box as a (-1, 4) tensor,
# plus the image width and height recorded in the metadata.
im, meta = train_split[0]
width, height, _ = (val.item() for val in meta["size"])
bndbox = meta["objects"][0]["bndbox"].reshape([-1, 4])

print(width, height)

# Default boxes generated on a grid with the same step in both directions.
step = 40
centers = default_box.default_box_centers(width, height, width_step = step, height_step = step)
boxes = default_box.default_boxes(scale = 0.7, centers = centers)

# Merge the first two dimensions so that each row is one box.
num_boxes, num_ratios, _ = boxes.shape
boxes = boxes.reshape([num_boxes*num_ratios, 4])

print(bndbox)
print(boxes.shape)

# `boxes` already has shape (num_boxes*num_ratios, 4) at this point.
iou = math_utils.intersection_over_union(boxes, bndbox)
print(iou.shape)
iou = iou.reshape([num_ratios, len(bndbox), num_boxes])
print(iou.shape)

# Number of IoU values at or above each threshold.
thresholds = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)
print([(iou >= val).sum() for val in thresholds])
