In [83]:
import glob
import os
import xml.etree.ElementTree as ET

import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from scipy.stats import multivariate_normal
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from torchvision import transforms

In [84]:

class NimbroDataset(Dataset):
    """Object-centre heat-map dataset built from Pascal-VOC-style XML files.

    Each ``*.xml`` file in ``image_dir`` describes one image. ``__getitem__``
    returns ``(x, target)``:

    * ``x`` is the image converted to RGB, resized with
      ``transforms.Resize(image_size)`` and normalised with the usual ImageNet
      mean/std.
    * ``target`` is a float32 tensor of shape
      ``(len(object_to_channel), image_size[0] // 4, image_size[1] // 4)``
      with one Gaussian blob per annotated object centre, each channel scaled
      so that its peak is ~1 (empty channels stay 0).

    Args:
        image_dir: directory containing the XML annotations; the images are
            looked up in the same directory.
        image_size: ``(height, width)`` of the network input; the target is a
            quarter of this in each dimension.
        object_to_channel: label -> target channel. Objects whose label is not
            a key here are ignored.
        object_var: label -> variance of the isotropic Gaussian drawn for that
            label, in target-pixel units.
    """

    # Defaults used when the caller passes no mapping. They are copied per
    # instance, so no dict is ever shared between datasets.
    DEFAULT_OBJECT_TO_CHANNEL = {'ball': 0, 'goalpost': 2, 'robot': 1}
    DEFAULT_OBJECT_VAR = {'ball': 5, 'goalpost': 5, 'robot': 10}

    def __init__(self, image_dir, image_size, object_to_channel=None, object_var=None):
        # None instead of dict literals as defaults: a mutable default would be
        # shared by every instance created without explicit mappings.
        if object_to_channel is None:
            object_to_channel = self.DEFAULT_OBJECT_TO_CHANNEL
        if object_var is None:
            object_var = self.DEFAULT_OBJECT_VAR
        self.object_to_channel = dict(object_to_channel)
        self.object_var = dict(object_var)
        self.image_size = image_size

        # Sorted so that index -> file is reproducible (glob order depends on
        # the filesystem); escaped so characters like '[' in image_dir are
        # matched literally instead of as glob patterns.
        pattern = os.path.join(glob.escape(image_dir), '*.xml')
        self.xml_files = sorted(glob.glob(pattern))

        # Input pipeline: resize (via PIL) and normalise with ImageNet stats.
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.resize_inp = transforms.Compose([transforms.ToPILImage(), transforms.Resize(image_size), transforms.ToTensor()])

        # Per-index cache of (image_path, target); None until first access.
        # With DataLoader(num_workers > 0) each worker process fills its own copy.
        self.data = [None] * len(self.xml_files)

    def __getitem__(self, index):
        """Return ``(normalised_resized_image, target)`` for sample ``index``.

        The XML is parsed and the target rendered only on first access; the
        result is cached in ``self.data``. The image is re-read every time.
        """
        xml_path = self.xml_files[index]
        cached = self.data[index]
        if cached is None:
            img_path, annot = self.parse_annotations(xml_path)
        else:
            img_path, target = cached

        x = self._load_image(xml_path, img_path)

        if cached is None:
            # Centres are in original-image pixels, so the target needs the
            # un-resized (height, width) to rescale them.
            target = self.annotation_to_target(annot, x.shape[1], x.shape[2])
            self.data[index] = (img_path, target)

        x = self.normalize(self.resize_inp(x))
        return x, target

    def __len__(self):
        """Number of XML annotation files found in ``image_dir``."""
        return len(self.xml_files)

    def _load_image(self, xml_path, img_path):
        """Load the image belonging to ``xml_path`` as a (3, H, W) float tensor.

        Candidates, in order:
        1. ``xml_path`` with its suffix replaced by ``img_path``'s extension.
        2. The basename of ``img_path`` inside the XML file's directory.

        Raises:
            OSError: the last error if no candidate could be opened.
        """
        candidates = []
        ext = os.path.splitext(img_path)[1]
        if ext:
            # Swap only the file suffix. A plain str.replace('xml', ...) would
            # also rewrite 'xml' occurring inside directory names.
            candidates.append(os.path.splitext(xml_path)[0] + ext)
        candidates.append(os.path.join(os.path.dirname(xml_path), os.path.basename(img_path)))

        last_error = None
        for candidate in candidates:
            try:
                image = Image.open(candidate)
            except OSError as err:  # missing file or unrecognised image format
                last_error = err
                continue
            # `with` releases the file handle PIL keeps open for lazy loading.
            # convert('RGB') guarantees 3 channels for the 3-channel Normalize
            # (RGBA / grayscale images would otherwise fail there).
            with image:
                return TF.to_tensor(image.convert('RGB'))
        raise last_error

    def parse_annotations(self, fname):
        """Parse one VOC-style XML annotation file.

        Args:
            fname: path of the XML file.

        Returns:
            ``(image_path, points)``. ``image_path`` is the longer of the
            ``<path>`` / ``<filename>`` texts. ``points`` maps every label in
            ``object_to_channel`` to a float array of shape ``(N, 2)`` holding
            ``(row, col)`` centres; N may be 0. Goalposts use the bottom-centre
            of their box, all other objects the box centre. Objects with
            labels not in ``object_to_channel`` are skipped.

        Raises:
            ValueError: an ``<object>`` lacks a ``<name>`` or a ``<bndbox>``
                with at least four values.
        """
        points = {label: [] for label in self.object_to_channel}
        image_path = ''
        for elem in ET.parse(fname).iter():
            if elem.tag in ('path', 'filename'):
                text = elem.text or ''
                if len(text) > len(image_path):
                    image_path = text
            elif elem.tag == 'object':
                # Reset per object so a missing <name>/<bndbox> can never
                # reuse the previous object's values.
                label, bbox = None, None
                for child in elem:
                    if child.tag == 'name':
                        label = (child.text or '').strip()
                    elif child.tag == 'bndbox':
                        # Read positionally; assumed VOC order xmin, ymin, xmax, ymax.
                        bbox = [float(value.text) for value in child]
                if not label or bbox is None or len(bbox) < 4:
                    raise ValueError(f'{fname}: <object> without a usable <name>/<bndbox>')
                if label not in points:
                    continue  # no target channel configured for this class
                xmin, ymin, xmax, ymax = bbox[:4]
                col = (xmin + xmax) / 2
                row = ymax if label == 'goalpost' else (ymin + ymax) / 2
                points[label].append((row, col))
        return image_path, {
            label: np.asarray(pts, dtype=float).reshape(-1, 2)
            for label, pts in points.items()
        }

    def annotation_to_target(self, annot, h, w):
        """Render object centres as per-channel Gaussian heat-maps.

        Args:
            annot: label -> ``(N, 2)`` array of ``(row, col)`` centres in
                original-image pixels, as returned by ``parse_annotations``.
            h: height of the original image.
            w: width of the original image.

        Returns:
            float32 tensor of shape ``(len(object_to_channel),
            image_size[0] // 4, image_size[1] // 4)``. Every channel is divided
            by its own maximum (+1e-8), so non-empty channels peak at ~1 and
            empty channels stay 0.
        """
        out_h, out_w = self.image_size[0] // 4, self.image_size[1] // 4
        target = np.zeros((len(self.object_to_channel), out_h, out_w), dtype='float32')

        # (out_h, out_w, 2) grid of (row, col) coordinates, built once and
        # reused for every object.
        grid = np.moveaxis(np.indices((out_h, out_w), dtype=float), 0, -1)

        for label, centres in annot.items():
            if len(centres) == 0:
                continue
            channel = self.object_to_channel[label]
            variance = self.object_var[label]
            cov = [[variance, 0], [0, variance]]
            for row, col in centres:
                # Rescale the centre from original-image to target resolution.
                mean = [row * out_h / h, col * out_w / w]
                blob = multivariate_normal(mean, cov).pdf(grid)
                # pdf peaks at 1 / (2*pi*variance); multiplying by the variance
                # gives every blob the same peak height regardless of its size.
                # reshape also restores dims that scipy squeezes when out_h or
                # out_w is 1.
                target[channel] += variance * np.reshape(blob, (out_h, out_w))

        # Per-channel max normalisation; works for any number of channels.
        target /= target.max(axis=(1, 2), keepdims=True) + 1e-8
        return torch.from_numpy(target)