# Custom Datasets

__Prerequisites__

- [Convolutional Neural Networks](https://github.com/AI-Core/Convolutional-Neural-Networks)

So far, we have only used the MNIST dataset, which is easily accessible through torchvision. What do we do when we have our own data which we want to use with PyTorch?

To be compatible with a PyTorch DataLoader, we have to write a class from which we can create instances of the dataset. This class must override the \_\_len\_\_ and \_\_getitem\_\_ methods. The \_\_len\_\_ method must return the number of examples in the dataset we are loading. The \_\_getitem\_\_ method must return a single example given its index.
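
As a minimal sketch of that contract, the toy class below implements only the two required methods in plain Python. The name `SquaresDataset` and its contents are hypothetical and are not part of PyTorch or of this notebook. For an index-based dataset like this one, a DataLoader relies on the same two calls: `len(dataset)` to know how many indices exist and `dataset[idx]` to fetch each example.

```python
class SquaresDataset:
    """Hypothetical toy dataset: example i is the pair (i, i squared)."""

    def __init__(self, n):
        # Store whatever is needed to produce items later
        self.n = n

    def __len__(self):
        # Number of examples in the dataset
        return self.n

    def __getitem__(self, idx):
        # Return one (input, label) example for the given index
        if not 0 <= idx < self.n:
            raise IndexError(idx)
        return idx, idx ** 2


dataset = SquaresDataset(5)
print(len(dataset))   # 5
print(dataset[3])     # (3, 9)
```

In practice such a class would usually subclass `torch.utils.data.Dataset` and return tensors, so that a DataLoader can stack the examples into batches.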

In our \_\_init\_\_ method, we store the variables which we will need to get items from our dataset. In this case, these are the paths to the image files and to their annotation files.

Today, we will implement a class which loads the S40 detection dataset. Detection is when we want our algorithm to draw rectangles around the locations of specific objects within the image, as opposed to classification, where we simply have a binary output indicating whether the object is contained within the image.
The dataset consists of an "images" folder, which contains the input images, and an "annotations" folder, which contains one xml file per image. Each xml file has the same name as its image and holds the co-ordinates of the top-left and bottom-right corners of the rectangular bounding box.
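
To make the annotation format concrete, here is a small sketch that parses a made-up annotation with Python's built-in XML parser. The XML string, the file name and the co-ordinate values are invented for illustration. Only the `object`, `bndbox`, `xmin`, `ymin`, `xmax` and `ymax` tag names are taken from the loader code in this notebook, and the real files may contain additional tags.

```python
import xml.etree.ElementTree as ET

# Hypothetical annotation contents (values invented for illustration)
example_xml = """
<annotation>
    <filename>example.jpg</filename>
    <object>
        <bndbox>
            <xmin>120</xmin>
            <ymin>80</ymin>
            <xmax>360</xmax>
            <ymax>400</ymax>
        </bndbox>
    </object>
</annotation>
"""

root = ET.fromstring(example_xml)
bndbox_xml = root.find("object").find("bndbox")  # tag that holds the box corners

# Read the four corner co-ordinates as integers
corners = {tag: int(bndbox_xml.find(tag).text) for tag in ("xmin", "ymin", "xmax", "ymax")}
print(corners)  # {'xmin': 120, 'ymin': 80, 'xmax': 360, 'ymax': 400}
```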

In [None]:
import os
from PIL import Image, ImageDraw
import xml.etree.ElementTree as ET
import torch
from torchvision import transforms

class S40dataset(torch.utils.data.Dataset):   # subclass the PyTorch Dataset base class

    def __init__(self, img_dir='S40-data/images', annotation_dir='S40-data/annotations', transform=None):
        self.img_dir = img_dir
        self.annotation_dir = annotation_dir
        self.transform = transform

        self.img_names = os.listdir(img_dir)                                  # list all files in the img folder
        self.img_names.sort()                                                   # order the images alphabetically
        self.img_names = [os.path.join(img_dir, img_name) for img_name in self.img_names]   # join folder and file names

        self.annotation_names = os.listdir(annotation_dir)                      # list all annotation files
        self.annotation_names.sort()                                            # order annotation files alphabetically
        self.annotation_names = [os.path.join(annotation_dir, ann_name) for ann_name in self.annotation_names]   # join folder and file names

        print(self.img_names)
        print(self.annotation_names)

    def __getitem__(self, idx):
        img_name = self.img_names[idx] #get the path of the image at that index
        img = Image.open(img_name) #open the image using the path

        annotation_name = self.annotation_names[idx] #get the path to the label file
        annotation_tree = ET.parse(annotation_name) #use xml parser to load the file
        bndbox_xml = annotation_tree.find("object").find("bndbox") #get the tag which contains our labels
        
        #get the x and y values for the corners of the rectangle
        xmax = int(bndbox_xml.find('xmax').text) 
        ymax = int(bndbox_xml.find('ymax').text)
        xmin = int(bndbox_xml.find('xmin').text)
        ymin = int(bndbox_xml.find('ymin').text)
        #print(xmax, ymax, xmin, ymin)

        # Convert from corner co-ordinates format into center co-ordinate, width and height format
        w = xmax - xmin
        h = ymax - ymin
        x = int(xmin + w / 2)
        y = int(ymin + h / 2)

        # Normalise the labels so the values are expressed as a proportion of the whole image
        x /= img.size[0]
        w /= img.size[0]
        y /= img.size[1]
        h /= img.size[1]

        bndbox = (x, y, w, h)
        
        #apply any transforms
        if self.transform:
            img = self.transform(img)

        bndbox = torch.tensor(bndbox)

        return img, bndbox

    def __len__(self):
        return len(self.img_names)

# Convert from center co-ordinate, width and height format into corner co-ordinates format
def unpack_bndbox(bndbox, img):
    # Take the first box in the batch as plain Python floats (avoids modifying the input tensor in place)
    x, y, w, h = bndbox[0].tolist()
    x *= img.size[0]
    w *= img.size[0]
    y *= img.size[1]
    h *= img.size[1]
    xmin = x - w / 2
    xmax = x + w / 2
    ymin = y - h / 2
    ymax = y + h / 2
    bndbox = [xmin, ymin, xmax, ymax]
    return bndbox

def show(batch, pred_bndbox=None):
    img, bndbox = batch

    img = img[0]
    print(img.shape)
    img = transforms.ToPILImage()(img)
    img = transforms.Resize((512, 512))(img)
    draw = ImageDraw.Draw(img)

    bndbox = unpack_bndbox(bndbox, img)
    print(bndbox)
    draw.rectangle(bndbox)
    if pred_bndbox is not None:
        pred_bndbox = unpack_bndbox(pred_bndbox, img)
        draw.rectangle(pred_bndbox, outline='red')   # draw the prediction in a different colour
    img.show()
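
The co-ordinate conversion used in `__getitem__` and `unpack_bndbox` can be checked in isolation. The sketch below uses plain Python with hypothetical helper names (`corners_to_center`, `center_to_corners`) and a made-up box and image size. Unlike `__getitem__`, it does not truncate the center to an integer before normalising, so for these values the round trip recovers the original corners exactly.

```python
def corners_to_center(xmin, ymin, xmax, ymax, img_w, img_h):
    """Corner format -> normalised (x_center, y_center, width, height)."""
    w = xmax - xmin
    h = ymax - ymin
    x = xmin + w / 2
    y = ymin + h / 2
    # Express every value as a proportion of the image size
    return x / img_w, y / img_h, w / img_w, h / img_h


def center_to_corners(x, y, w, h, img_w, img_h):
    """Normalised center format -> corner format in pixels."""
    x, w = x * img_w, w * img_w
    y, h = y * img_h, h * img_h
    return x - w / 2, y - h / 2, x + w / 2, y + h / 2


# Made-up example: a box inside a 640 x 512 image
box = corners_to_center(120, 80, 360, 400, img_w=640, img_h=512)
print(box)        # (0.375, 0.46875, 0.375, 0.625)

restored = center_to_corners(*box, img_w=640, img_h=512)
print(restored)   # (120.0, 80.0, 360.0, 400.0)
```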

How do we build an algorithm to learn solutions to this problem? Find out in the CNN Detection notebook!

__Next Steps__

- [CNN Detection](https://github.com/AI-Core/Convolutional-Neural-Networks/blob/master/CNN%20Detection.ipynb)