In [0]:
import os
import glob
from PIL import Image, ImageDraw
import xml.etree.ElementTree as ET
import torch
from torchvision import transforms
from torch.utils.data import Dataset

In [0]:
class TissueDataset(Dataset):
    """Dataset of tissue images, each paired with one bounding-box annotation.

    Images (``*.png``) and annotations (``*.xml`` files containing an
    ``object/bndbox`` element with ``xmin``/``ymin``/``xmax``/``ymax`` tags)
    are listed from their directories, sorted alphabetically and paired by
    position: the i-th image in sorted order goes with the i-th annotation.

    Each item is ``(img, bndbox)``. ``img`` is the PIL image, with
    ``transform`` applied if one was given. ``bndbox`` is a float tensor
    ``(x_center, y_center, width, height)`` normalised by the image width and
    height.
    """

    def __init__(self, img_dir='data/images', annotation_dir='data/annotations', transform=None):
        """
        Args:
            img_dir: directory containing the ``.png`` images.
            annotation_dir: directory containing the ``.xml`` annotation files.
            transform: optional callable applied to the PIL image in ``__getitem__``.

        Raises:
            ValueError: if the number of images and annotations differ.
                Pairing is positional, so a count mismatch would otherwise
                silently attach the wrong box to an image.
        """
        self.img_dir = img_dir
        self.annotation_dir = annotation_dir
        self.transform = transform

        self.img_names = self._list_files(img_dir, '.png')
        self.annotation_names = self._list_files(annotation_dir, '.xml')

        if len(self.img_names) != len(self.annotation_names):
            raise ValueError(
                f'Found {len(self.img_names)} images in {img_dir!r} but '
                f'{len(self.annotation_names)} annotations in {annotation_dir!r}; '
                'images and annotations are paired by sorted order, so the counts must match.'
            )

    @staticmethod
    def _list_files(directory, extension):
        """Return the sorted full paths of files in ``directory`` whose names end with ``extension``.

        The extension check is case-insensitive and includes the dot, so e.g.
        ``x.apng`` does not match ``.png``.
        """
        names = sorted(name for name in os.listdir(directory) if name.lower().endswith(extension))
        return [os.path.join(directory, name) for name in names]

    @staticmethod
    def _read_bndbox(annotation_path, img_width, img_height):
        """Read the box of the first ``object/bndbox`` element in an XML annotation file.

        Args:
            annotation_path: path to the XML file.
            img_width: image width in pixels, used for normalisation.
            img_height: image height in pixels, used for normalisation.

        Returns:
            Tuple ``(x_center, y_center, width, height)`` expressed as fractions
            of the image width/height.

        Raises:
            ValueError: if the ``object/bndbox`` element or one of its corner
                tags is missing.
        """
        bndbox_xml = ET.parse(annotation_path).find('object/bndbox')
        if bndbox_xml is None:
            raise ValueError(f'No object/bndbox element in {annotation_path!r}')

        coords = {}
        for tag in ('xmin', 'ymin', 'xmax', 'ymax'):
            text = bndbox_xml.findtext(tag)
            if text is None:
                raise ValueError(f'Missing <{tag}> in {annotation_path!r}')
            # float() accepts both integer and fractional pixel coordinates
            coords[tag] = float(text)

        # Convert corner format into center/width/height. The center stays a
        # float: truncating it to int shifts odd-sized boxes by half a pixel.
        w = coords['xmax'] - coords['xmin']
        h = coords['ymax'] - coords['ymin']
        x = coords['xmin'] + w / 2
        y = coords['ymin'] + h / 2

        # Normalise so the values are a proportion of the whole image
        return (x / img_width, y / img_height, w / img_width, h / img_height)

    def __getitem__(self, idx):
        """Return ``(img, bndbox)`` for sample ``idx`` (see the class docstring)."""
        img_name = self.img_names[idx]
        # Image.open is lazy and keeps the file open; copy the pixels inside a
        # `with` block so the file handle is released straight away.
        with Image.open(img_name) as img_file:
            img = img_file.copy()

        # Normalise using the original image size, before any transform resizes it
        bndbox = self._read_bndbox(self.annotation_names[idx], img.size[0], img.size[1])

        if self.transform:
            img = self.transform(img)

        return img, torch.tensor(bndbox)

    def __len__(self):
        """Number of image/annotation pairs."""
        return len(self.img_names)

In [0]:
# Build the dataset from the default folders; each image is converted to a tensor on load.
tissueDataset = TissueDataset(img_dir='data/images', annotation_dir='data/annotations', transform=transforms.ToTensor())

In [39]:
# Sanity check: how many image/annotation pairs were found.
n_pairs = len(tissueDataset)
print('len dataset:', n_pairs)

len dataset: 1


In [0]:
# Convert from center-coordinate, width and height format into corner-coordinate format
def unpack_bndbox(bndbox, img):
    """Convert a normalised center-format box into pixel corner coordinates.

    Args:
        bndbox: four numbers ``(x_center, y_center, width, height)`` expressed
            as fractions of the image size, e.g. a tuple or a 1-D tensor.
        img: object whose ``size`` gives ``(width, height)`` in pixels, such as
            a PIL image.

    Returns:
        ``[xmin, ymin, xmax, ymax]`` in pixels, as plain Python floats.
    """
    img_width, img_height = img.size[0], img.size[1]

    # float() turns tensor elements into plain numbers. Iterating a tensor
    # yields views of it, and `*=` on such a view would scale the caller's
    # tensor in place (and fails outright on tensors that track gradients).
    x, y, w, h = (float(v) for v in bndbox)
    x *= img_width
    w *= img_width
    y *= img_height
    h *= img_height

    return [x - w / 2, y - h / 2, x + w / 2, y + h / 2]

In [0]:
def show(batch, pred_bndbox=None, size=(512, 512), pred_outline='red'):
    """Open an image with its ground-truth (and optionally predicted) box drawn on it.

    Args:
        batch: ``(img, bndbox)`` pair such as ``tissueDataset[i]``. ``img`` is an
            image tensor (or anything ``ToPILImage`` accepts). ``bndbox`` is a
            normalised ``(x_center, y_center, width, height)`` sequence.
        pred_bndbox: optional predicted box in the same normalised format.
        size: ``(height, width)`` the image is resized to before drawing.
        pred_outline: outline colour of the predicted box. It accepts any PIL
            colour spec and replaces the previous opaque integer ``1000``.
    """
    img, bndbox = batch

    if isinstance(img, torch.Tensor):
        # ToPILImage needs a CPU tensor without autograd history (e.g. model inputs)
        img = img.detach().cpu()
    img = transforms.ToPILImage()(img)
    img = transforms.Resize(size)(img)
    draw = ImageDraw.Draw(img)

    # Boxes are normalised, so they scale directly onto the resized image.
    # They are passed as plain floats so unpack_bndbox never operates on (and
    # scales in place) views of the caller's tensors.
    draw.rectangle(unpack_bndbox([float(v) for v in bndbox], img))
    if pred_bndbox is not None:
        draw.rectangle(unpack_bndbox([float(v) for v in pred_bndbox], img), outline=pred_outline)
    img.show()

In [0]:
# Visualise the first sample together with its ground-truth box.
show(batch=tissueDataset[0])