In [23]:
import glob
import os
import xml.etree.ElementTree as ET

import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from scipy.stats import multivariate_normal
from torch.utils.data import DataLoader,Dataset,ConcatDataset
from torchvision import transforms

In [24]:
class NimbroDataset(Dataset):
    """Detection dataset built from Pascal-VOC style XML annotation files.

    Each ``*.xml`` file in ``image_dir`` describes one image and its objects.
    ``__getitem__`` returns the normalised, resized image together with a
    heat-map target that has one channel per object class; every annotated
    object is drawn as an isotropic Gaussian blob centred on its bounding box.
    Targets are computed once per index and cached on the instance.

    Args:
        image_dir: Directory containing the XML annotation files (and,
            normally, the images they reference).
        image_size: ``(height, width)`` the input image is resized to. The
            target is produced at a quarter of this resolution.
        object_to_channel: Maps an object label to its target channel.
            Defaults to ``{'ball': 0, 'goalpost': 1, 'robot': 2}``.
        object_var: Maps an object label to the variance (in squared pixels of
            the resized input) of its Gaussian blob. Defaults to
            ``{'ball': 30, 'goalpost': 30, 'robot': 60}``.
    """

    # Per-channel factors that bring blob peaks to roughly 1 after the
    # ``resize_op`` round-trip (ToPILImage -> Resize -> ToTensor).
    _TARGET_SCALE = (1 / 0.1373, 1 / 0.1451, 1 / 0.149)

    def __init__(self, image_dir, image_size, object_to_channel=None, object_var=None):
        # None sentinels instead of dict literals: defaults must not be shared
        # (and possibly mutated) across instances.
        if object_to_channel is None:
            object_to_channel = {'ball': 0, 'goalpost': 1, 'robot': 2}
        if object_var is None:
            object_var = {'ball': 30, 'goalpost': 30, 'robot': 60}

        # Sorted so the index -> annotation mapping is reproducible.
        self.xml_files = sorted(glob.glob(os.path.join(image_dir, '*.xml')))

        # channel for each object
        self.object_to_channel = object_to_channel

        # variance for each object (higher for robot)
        self.object_var = object_var

        # for resizing and normalising
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.resize_inp = transforms.Compose([transforms.ToPILImage(), transforms.Resize(image_size), transforms.ToTensor()])
        self.resize_op = transforms.Compose([transforms.ToPILImage(), transforms.Resize((image_size[0] // 4, image_size[1] // 4)), transforms.ToTensor()])

        # Per-index cache: None until computed, then (resolved image path, target).
        self.data = [None] * len(self.xml_files)

    def __getitem__(self, index):
        """Return ``(x, target)`` for the annotation file at ``index``.

        ``x`` is the RGB image resized to ``image_size`` and normalised;
        ``target`` has ``len(object_to_channel)`` channels at a quarter of
        ``image_size``. Raises ``FileNotFoundError`` if the referenced image
        cannot be opened.
        """
        xml_file = self.xml_files[index]
        cached = self.data[index]

        if cached is not None:
            image_file, target = cached
            image = self._load_image(image_file)
        else:
            img_path, annot = self.parse_annotations(xml_file)
            image_file, image = self._find_image(xml_file, img_path)

        x = TF.to_tensor(image)
        x = self.resize_inp(x)
        x = self.normalize(x)

        if cached is None:
            out_h, out_w = x.shape[1], x.shape[2]
            orig_w, orig_h = image.size
            # Annotation centres are in the original image's pixel frame, but
            # the target canvas has the resized input's shape: rescale them.
            scale = np.array([out_h / orig_h, out_w / orig_w])
            annot = {label: centres * scale for label, centres in annot.items()}

            target = self.annotation_to_target(annot, out_h, out_w)
            target = self.resize_op(target)
            # for scaling values approx between 0 and 1
            for channel, factor in enumerate(self._TARGET_SCALE[:target.shape[0]]):
                target[channel] *= factor
            self.data[index] = (image_file, target)

        return x, target

    def __len__(self):
        return len(self.xml_files)

    @staticmethod
    def _load_image(path):
        """Read ``path`` as a 3-channel RGB image and release the file handle."""
        # convert('RGB') also makes grayscale / palette / RGBA files compatible
        # with the 3-channel Normalize used on the input.
        with Image.open(path) as img:
            return img.convert('RGB')

    @staticmethod
    def _image_candidates(xml_file, img_path):
        """List the paths where the image for ``xml_file`` may live, most likely first.

        1. The file named in the XML, looked up next to the XML file (the
           stored path may come from another machine, possibly Windows).
        2. The XML file's own name with the referenced image's extension.
        """
        directory = os.path.dirname(xml_file)
        basename = img_path.replace('\\', '/').split('/')[-1]
        candidates = [os.path.join(directory, basename)]
        extension = os.path.splitext(basename)[1]
        if extension:
            candidates.append(os.path.splitext(xml_file)[0] + extension)
        return candidates

    def _find_image(self, xml_file, img_path):
        """Open the first existing candidate image; return ``(path, image)``."""
        last_error = None
        for candidate in self._image_candidates(xml_file, img_path):
            try:
                return candidate, self._load_image(candidate)
            except OSError as err:  # missing file, directory, unreadable image
                last_error = err
        raise FileNotFoundError(f'no image found for annotation {xml_file!r}') from last_error

    @staticmethod
    def _box_coords(box):
        """Return ``(xmin, ymin, xmax, ymax)`` of a ``<bndbox>`` element.

        Named children are used when all four are present; otherwise the first
        four children are taken positionally as xmin, ymin, xmax, ymax.
        """
        named = [box.findtext(tag) for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
        if all(value is not None for value in named):
            return tuple(float(value) for value in named)
        return tuple(float(child.text) for child in list(box)[:4])

    def parse_annotations(self, fname):
        '''
        Parse one annotation XML file.

        Returns ``(image_path, dpoints)``. ``image_path`` is the longest of the
        ``<path>`` / ``<filename>`` texts. ``dpoints`` maps every label in
        ``self.object_to_channel`` to a float array of shape ``(n, 2)`` holding
        the ``(y, x)`` bounding-box centres of that label's objects. Objects
        whose label has no channel, or that lack a name or a box, are skipped.
        '''
        tree = ET.parse(fname)
        points = {label: [] for label in self.object_to_channel}
        image_path = ''
        for elem in tree.iter():
            if elem.tag in ('path', 'filename'):
                text = elem.text or ''
                if len(text) > len(image_path):
                    image_path = text
            elif elem.tag == 'object':
                label = (elem.findtext('name') or '').strip()
                box = elem.find('bndbox')
                if label not in points or box is None:
                    continue
                xmin, ymin, xmax, ymax = self._box_coords(box)
                # centre stored as (y, x), i.e. (row, column)
                points[label].append(((ymin + ymax) / 2, (xmin + xmax) / 2))
        dpoints = {label: np.array(pts, dtype=float).reshape(-1, 2) for label, pts in points.items()}
        return image_path, dpoints

    def annotation_to_target(self, annot, h, w):
        '''
        Render annotations as a ``(len(object_to_channel), h, w)`` float32 tensor.

        Every ``(y, x)`` centre in ``annot[label]`` (truncated to integer
        pixels) adds an isotropic Gaussian with variance
        ``self.object_var[label]`` to channel ``self.object_to_channel[label]``.
        The density is multiplied by the variance, so each blob peaks at
        ``1 / (2 * pi)`` regardless of its width.
        '''
        target = np.zeros((len(self.object_to_channel), h, w), dtype='float32')

        # (h, w, 2) grid of (y, x) pixel coordinates, evaluated in one call.
        yy, xx = np.meshgrid(np.arange(h, dtype=float), np.arange(w, dtype=float), indexing='ij')
        grid = np.stack((yy, xx), axis=-1)

        for label, centres in annot.items():
            var = self.object_var[label]
            cov = [[var, 0], [0, var]]
            channel = self.object_to_channel[label]
            for cy, cx in centres:
                rv = multivariate_normal([int(cy), int(cx)], cov)
                # reshape: scipy squeezes singleton dims of the result
                target[channel] += var * rv.pdf(grid).reshape(h, w)

        return torch.from_numpy(target)