# Download the dataset from Kaggle and unzip it

In [18]:
import urllib.request as urlrequest
from pathlib import Path

import shutil

# Everything lives under ./data (relative to the current working directory).
base_data_path = Path() / "data"
base_data_path.mkdir(parents=True, exist_ok=True)

# Where the extracted dataset folder ends up.
dataset_location = base_data_path / "dataset"

# Local copy of the downloaded archive; its presence means "already downloaded".
zip_path = base_data_path / "dataset.zip"

if not zip_path.exists():

    dataset_url = "https://www.kaggle.com/api/v1/datasets/download/alvarole/asirra-cats-vs-dogs-object-detection-dataset"

    # Write to a ".part" file first and rename it only once the whole body has been
    # received. An interrupted download would otherwise leave a truncated zip that
    # the exists() check above treats as complete on every later run.
    partial_path = zip_path.with_name(zip_path.name + ".part")
    try:
        with urlrequest.urlopen(dataset_url) as response, open(partial_path, "wb") as f:
            # Stream in chunks instead of holding the whole archive in memory.
            shutil.copyfileobj(response, f)
    except BaseException:
        partial_path.unlink(missing_ok=True)
        raise
    partial_path.replace(zip_path)


In [19]:

import zipfile
import os

def unzip(zip_file=None, extract_to=None, target=None, inner_file="Asirra_ cat vs dogs"):
    """
    Extract the dataset archive and move its inner folder to the dataset location.

    zip_file:   archive to extract; defaults to the notebook's `zip_path`.
    extract_to: directory the archive is extracted into; defaults to `base_data_path`.
    target:     final location of the folder; defaults to `dataset_location`.
    inner_file: name of the folder inside the archive that is renamed to `target`.

    If `target` already exists nothing is done, so the call can be repeated safely.
    """
    zip_file = Path(zip_path if zip_file is None else zip_file)
    extract_to = Path(base_data_path if extract_to is None else extract_to)
    target = Path(dataset_location if target is None else target)

    if target.exists():
        return

    # "archive" rather than "zip" to avoid shadowing the builtin.
    with zipfile.ZipFile(zip_file, "r") as archive:
        archive.extractall(extract_to)

    (extract_to / inner_file).rename(target)

# unzip()


# Dataset creation

In [22]:

import itertools
import xml.etree.ElementTree as ET
import torch


def patched_dataset_paths(dataset_location):
    """
    Return an iterator of (image_path, xml_path) tuples found in `dataset_location`.

    Files are paired by stem (file name without its last suffix): each stem must
    have exactly one ".xml" file and exactly one other file, which is taken as the image.
    Pairs are ordered by sorted file name.

    Raises ValueError if some stem does not form such a pair. Pairing by directory
    order alone is unsafe because iterdir() gives no ordering guarantee.
    """
    groups = {}
    for path in sorted(Path(dataset_location).iterdir()):
        groups.setdefault(path.stem, []).append(path)

    pairs = []
    for stem, files in groups.items():
        xml_files = [f for f in files if f.suffix.lower() == ".xml"]
        image_files = [f for f in files if f.suffix.lower() != ".xml"]
        if len(xml_files) != 1 or len(image_files) != 1:
            raise ValueError(
                f"expected one image and one .xml file for {stem!r}, "
                f"got {[f.name for f in files]}"
            )
        pairs.append((image_files[0], xml_files[0]))

    return iter(pairs)


# Minimal hand-written reader for this dataset's XML annotation files.
def read_metadata(xml_file: Path) -> dict:
    '''
    Read the labeling of one image from its XML annotation file into a dict.

    Returns a dict with two keys:
      "size":    {tag: int} for every element nested under <size>.
      "objects": one dict per <object> element, with keys
                 "name", "pose"  -> stripped text (None if the element is blank),
                 "truncated", "difficult" -> int,
                 "bndbox" -> {tag: float} for every element nested under <bndbox>,
                             in document order.
    '''
    # ET.parse opens the file in binary mode and lets the parser honor the
    # XML encoding declaration (defaulting to UTF-8).
    root = ET.parse(xml_file).getroot()

    def text_of(elem):
        # Strip surrounding whitespace so pretty-printed files parse cleanly;
        # blank text becomes None.
        stripped = elem.text.strip() if elem.text is not None else ""
        return stripped or None

    def numbers_under(parent, cast):
        # iter() yields `parent` itself first, then all descendants; skip the parent.
        return {elem.tag: cast(text_of(elem)) for elem in parent.iter() if elem is not parent}

    size = numbers_under(root.find("size"), int)

    objects = [
        dict(
            name=text_of(obj.find("name")),
            pose=text_of(obj.find("pose")),
            truncated=int(text_of(obj.find("truncated"))),
            difficult=int(text_of(obj.find("difficult"))),
            bndbox=numbers_under(obj.find("bndbox"), float),
        )
        for obj in root.findall("object")
    ]

    return dict(size=size, objects=objects)

def get_dataset(dataset_location) -> list[dict]:
    """
    Build one metadata record per (image, annotation) pair in `dataset_location`.

    Each record is the dict parsed from the annotation file, extended with an
    "img_path" key that holds the matching image path.
    """
    return [
        {**read_metadata(annotation_file), "img_path": image_file}
        for image_file, annotation_file in patched_dataset_paths(dataset_location)
    ]

def dataset_splits(dataset: list | None = None, fractions: tuple[float] = (0.8, 0.1, 0.1)):

    dataset = get_dataset(dataset_location) if dataset is None else dataset
    return  torch.utils.data.random_split(dataset, fractions)



    


In [26]:
import torch
from torchvision.io import read_image

class CatsAndDogsDataset(torch.utils.data.Dataset):
    """
    Map-style dataset yielding (image, metadata) pairs.

    `data` is an indexable, sized sequence of metadata dicts; each dict must
    contain an "img_path" entry pointing at the image file.
    """

    def __init__(self, data, transform=None, target_transform=None):

        self.data = data
        self.transform = transform  # optional callable applied to the decoded image
        self.target_transform = target_transform  # optional callable applied to the metadata dict

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        metadata = self.data[idx]
        # read_image is typed to take a str path, while get_dataset() stores
        # pathlib.Path objects here, so convert explicitly.
        image = read_image(str(metadata["img_path"]))
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            metadata = self.target_transform(metadata)
        return image, metadata


In [None]:


# Wrap each random split so that indexing it yields (image, metadata) pairs.
train_split, validation_split, test_split = (
    CatsAndDogsDataset(subset) for subset in dataset_splits()
)

print(len(train_split))

# Sanity-check the datasets by plotting a sample

In [None]:
import matplotlib.pyplot as plt
import torchvision.transforms.functional as vision_transforms
from torchvision.utils import draw_bounding_boxes

def to_plottable(img):
    """Convert an image tensor to a PIL image that plt.imshow can display."""
    pil_image = vision_transforms.to_pil_image(img)
    return pil_image

def add_bb(img, meta):
    """
    Draw every annotated bounding box of `meta` onto `img` in cyan.

    img:  image tensor accepted by torchvision's draw_bounding_boxes
          (presumably the uint8 tensor from read_image — confirm).
    meta: metadata dict whose "objects" entries each carry a "bndbox" dict.
          Assumes the bndbox tags are xmin/ymin/xmax/ymax (PASCAL VOC naming);
          TODO confirm against the annotation files.

    Returns the annotated image, or `img` unchanged if there are no objects.
    """
    objects = meta["objects"]
    if not objects:
        return img

    # Look the corners up by name instead of relying on the dict's value order,
    # which mirrors the XML child order: draw_bounding_boxes expects
    # (xmin, ymin, xmax, ymax), and a differently ordered file would otherwise
    # silently produce wrong boxes. All boxes are drawn in one call.
    boxes = torch.tensor(
        [[obj["bndbox"][key] for key in ("xmin", "ymin", "xmax", "ymax")] for obj in objects]
    )
    return draw_bounding_boxes(img, boxes, colors="cyan")

# Visual sanity check: show the first training sample with its boxes drawn on it.
plt.figure()
im, meta = train_split[0]
im = add_bb(im, meta)  # im is rebound to the image with the boxes drawn
plt.imshow(to_plottable(im))
print(meta["objects"])
plt.show()
