# In [49]:
import torch
from torch.utils.data import DataLoader,Dataset,ConcatDataset
import glob
from PIL import Image
import torchvision.transforms.functional as TF
import xml.etree.ElementTree as ET
import numpy as np
from scipy.stats import multivariate_normal
from torchvision import transforms

# In [131]:
class NimbroDataset(Dataset):
    """Dataset of images paired with per-class centre maps built from XML files.

    Every ``*.xml`` file found in ``image_dir`` is one sample. The XML supplies
    the image file name and a list of ``<object>`` entries (a ``<name>`` plus a
    ``<bndbox>``). For each sample the dataset returns the normalised, resized
    image and a target with one channel per configured object class. A target
    channel is 1 everywhere except around object centres, where a Gaussian
    blob pulls the value down towards 0. Targets are computed on first access
    and then kept in memory.
    """

    DEFAULT_OBJECT_TO_CHANNEL = {'ball': 0, 'goalpost': 1, 'robot': 2}
    DEFAULT_OBJECT_VAR = {'ball': 30, 'goalpost': 30, 'robot': 100}

    # The pdf of an isotropic 2-D Gaussian with variance v peaks at
    # 1 / (2 * pi * v), so PEAK_SCALE * v * pdf peaks at 6.3 / (2 * pi),
    # i.e. just above 1.
    PEAK_SCALE = 6.3

    def __init__(self, image_dir, image_size, object_to_channel=None, object_var=None):
        """Index the annotation files and build the image transforms.

        Args:
            image_dir: Directory that is searched for ``*.xml`` files.
            image_size: Two-element size handed to ``transforms.Resize`` for
                the input image; the target is resized to a quarter of it
                along both entries.
            object_to_channel: Mapping from object label to target channel.
                Defaults to ``DEFAULT_OBJECT_TO_CHANNEL``.
            object_var: Mapping from object label to the variance of its
                Gaussian blob. Defaults to ``DEFAULT_OBJECT_VAR``.
        """
        if object_to_channel is None:
            object_to_channel = self.DEFAULT_OBJECT_TO_CHANNEL
        if object_var is None:
            object_var = self.DEFAULT_OBJECT_VAR

        # glob order is filesystem dependent; sort so indices are reproducible.
        self.xml_files = sorted(glob.glob(image_dir + '/*.xml'))
        # Copy so that later mutation by the caller cannot alter this dataset.
        self.object_to_channel = dict(object_to_channel)
        self.object_var = dict(object_var)
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        self.resize_inp = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(image_size),
            transforms.ToTensor(),
        ])
        self.resize_op = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((image_size[0] // 4, image_size[1] // 4)),
            transforms.ToTensor(),
        ])
        # Cache slots: an int means "not computed yet"; once a sample has been
        # served the slot holds an (image_path, target) tuple.
        self.data = list(range(len(self.xml_files)))

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample at ``index``.

        The image is always re-read from disk; the target is computed on the
        first access and served from the in-memory cache afterwards.
        """
        xml_file = self.xml_files[index]
        cached = self.data[index]
        is_cached = not isinstance(cached, int)
        if is_cached:
            img_path, target = cached
        else:
            img_path, annot = self.parse_annotations(xml_file)

        image = self._load_image(xml_file, img_path)
        x = TF.to_tensor(image)
        x = self.resize_inp(x)
        x = self.normalize(x)

        if not is_cached:
            target = self.annotation_to_target(annot, x.shape[1], x.shape[2])
            target = self.resize_op(target)
            self.data[index] = (img_path, target)

        return x, target

    def __len__(self):
        """Return the number of annotation files (one sample each)."""
        return len(self.xml_files)

    @staticmethod
    def _load_image(xml_file, img_path):
        """Open the image that belongs to ``xml_file`` and return it as RGB.

        The image is first looked up next to the XML file under the file name
        recorded in the annotation. If that cannot be opened, the XML file's
        own name with the image's extension is tried instead.
        """
        directory_parts = xml_file.split('/')[:-1]
        primary = '/'.join(directory_parts + [img_path.split('/')[-1]])
        extension = img_path.split('.')[-1]
        # Swap only the trailing "xml" extension; a plain str.replace would
        # also rewrite any "xml" occurring in directory names.
        fallback = xml_file[:-len('xml')] + extension

        try:
            source = Image.open(primary)
        except OSError:
            source = Image.open(fallback)
        with source:
            # The normalisation uses three channel means, so force RGB for
            # grayscale / RGBA / palette images. convert() loads the pixel
            # data, which lets the file be closed right away.
            return source.convert('RGB')

    def parse_annotations(self, fname):
        """Read the image path and object centres from one XML annotation.

        Args:
            fname: Path of the XML file.

        Returns:
            ``(image_path, dpoints)``. ``image_path`` is the longest text found
            in any ``<path>`` or ``<filename>`` element (``''`` if none).
            ``dpoints`` maps every label of ``self.object_to_channel`` to a
            float array of shape ``(n, 2)`` holding ``(y, x)`` box centres.

        Objects without a ``<bndbox>``, without a ``<name>``, or with a label
        that is not configured are skipped. The first four children of
        ``<bndbox>`` are read positionally as x-min, y-min, x-max, y-max.

        Raises:
            ValueError: If a ``<bndbox>`` has fewer than four children.
        """
        tree = ET.parse(fname)
        centers = {label: [] for label in self.object_to_channel}
        image_path = ''

        for node in tree.iter():
            if node.tag in ('path', 'filename'):
                # Keep the most complete reference; ignore empty tags.
                if node.text and len(node.text) > len(image_path):
                    image_path = node.text
            elif node.tag == 'object':
                # Reset per object so nothing leaks over from the previous one.
                label = None
                box = None
                for child in node:
                    if child.tag == 'name':
                        label = str(child.text)
                    elif child.tag == 'bndbox':
                        box = [float(coord.text) for coord in child]
                if box is None or label not in centers:
                    continue
                xmin, ymin, xmax, ymax = box[:4]
                centers[label].append([(ymin + ymax) / 2, (xmin + xmax) / 2])

        dpoints = {
            label: np.array(points, dtype=float).reshape((-1, 2))
            for label, points in centers.items()
        }
        return image_path, dpoints

    def annotation_to_target(self, annot, h, w):
        """Render object centres into a ``(channels, h, w)`` float32 tensor.

        Each channel starts at 1. For every centre of a label, a Gaussian blob
        with that label's variance is subtracted from the label's channel, so
        the value drops to about 0 at the centre. The result is clamped to
        ``[0, 1]``.

        Args:
            annot: Mapping from label to an iterable of ``(y, x)`` centres, as
                produced by ``parse_annotations``.
            h: Height of the target in pixels.
            w: Width of the target in pixels.

        Raises:
            KeyError: If ``annot`` contains a label missing from
                ``self.object_var`` or ``self.object_to_channel``.
        """
        target = np.ones((len(self.object_to_channel), h, w), dtype='float32')
        yy, xx = np.meshgrid(
            np.arange(h, dtype=float), np.arange(w, dtype=float), indexing='ij'
        )
        # (h, w, 2) grid of (y, x) positions; pdf() evaluates it in one call
        # instead of going through a Python list of h * w tuples.
        grid = np.stack([yy, xx], axis=-1)

        for label, centers in annot.items():
            var = self.object_var[label]
            channel = self.object_to_channel[label]
            for center in centers:
                rv = multivariate_normal(
                    [int(center[0]), int(center[1])], [[var, 0], [0, var]]
                )
                # pdf() squeezes singleton axes, so restore the (h, w) shape.
                blob = np.reshape(rv.pdf(grid), (h, w))
                target[channel] -= self.PEAK_SCALE * var * blob

        # Overlapping blobs (and the slightly-above-1 peak) push values below
        # zero; clamp so the map stays in [0, 1] before it is turned into an
        # image by resize_op. NOTE(review): negative floats presumably do not
        # survive ToPILImage's 8-bit conversion - see torchvision docs.
        np.clip(target, 0.0, 1.0, out=target)
        return torch.from_numpy(target)