In [21]:
import glob
import os
import xml.etree.ElementTree as ET

import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from torchvision import transforms

In [22]:
class NimbroSegmentDataset(Dataset):
    """Image / label-map pairs for three-class segmentation (background, field, line).

    Expected on-disk layout (from the glob pattern and path mapping below)::

        <image_dir>/image/<stem>.<ext>    input images
        <image_dir>/target/<stem>.png     label images

    Each item is ``(x, target)``:

    * ``x`` -- the input converted to RGB, resized to ``image_size`` and
      normalised with the standard ImageNet mean/std.
    * ``target`` -- class ids 0 = background, 1 = field, 2 = line, at
      ``image_size // target_scale`` per dimension, with singleton dims
      squeezed away. Dtype is whatever ``TF.to_tensor`` yields for the label
      image (float for 8-bit labels).

    Args:
        image_dir: Dataset root containing ``image/`` and ``target/``.
            A trailing path separator is optional.
        image_size: ``(height, width)`` of the network input. It must be a
            pair because the target size is derived per dimension.
        target_scale: Integer down-scaling factor of the label map
            (default 4).

    Note:
        Encoded targets are cached per index on first access. With a
        multi-process ``DataLoader`` each worker fills its own copy of the
        dataset, and therefore of this cache.
    """

    def __init__(self, image_dir, image_size, target_scale=4):
        # Sorted so the index -> file mapping is reproducible (raw glob order
        # depends on the filesystem).
        self.input_files = sorted(glob.glob(os.path.join(image_dir, 'image', '*.*')))

        # Standard ImageNet channel statistics.
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])
        # Resize PIL images directly. A tensor -> PIL -> tensor round trip
        # re-quantises with mul(255).byte(), which can truncate pixel values.
        self.resize_inp = transforms.Resize(image_size)
        target_size = (image_size[0] // target_scale, image_size[1] // target_scale)
        # NEAREST keeps label values exact (no blending between classes).
        self.resize_op = transforms.Resize(target_size, interpolation=Image.NEAREST)

        # Per-index cache of encoded targets; None means "not computed yet".
        self.data = [None] * len(self.input_files)

    @staticmethod
    def _target_path(image_path):
        """Map ``<root>/image/<stem>.<ext>`` to ``<root>/target/<stem>.png``.

        Only the file name is rewritten, so a root directory whose name
        contains "image" or "jpg" is left intact. An "image" substring inside
        the stem is still renamed to "target", matching the naming rule this
        dataset used before.
        """
        root = os.path.dirname(os.path.dirname(image_path))
        stem = os.path.splitext(os.path.basename(image_path))[0]
        return os.path.join(root, 'target', stem.replace('image', 'target') + '.png')

    @staticmethod
    def _encode_target(target):
        """Turn a label-image tensor into class ids 0/1/2 (same dtype, squeezed).

        Distinct pixel values are ranked. The smallest is background (0), the
        second largest is line (2), and everything else is field (1). If only
        one distinct value exists, the whole map is background.
        """
        unique_vals = target.unique()  # torch.unique returns ascending order by default
        encoded = torch.ones_like(target)  # assume field everywhere
        if unique_vals.numel() >= 2:
            encoded[target == unique_vals[-2]] = 2
        if unique_vals.numel() >= 1:
            # Applied after the line label so background wins when the
            # second-largest value is also the smallest (two-value maps).
            encoded[target == unique_vals[0]] = 0
        return encoded.squeeze()

    def __getitem__(self, index):
        """Return ``(x, target)`` for sample ``index``; see the class docstring."""
        image_path = self.input_files[index]

        target = self.data[index]
        if target is None:
            # Tensor conversion happens inside the `with` so the lazily
            # loaded image is read before its file is closed.
            with Image.open(self._target_path(image_path)) as label_img:
                label = TF.to_tensor(self.resize_op(label_img))
            target = self._encode_target(label)
            self.data[index] = target

        with Image.open(image_path) as img:
            # convert('RGB') guarantees 3 channels. Grayscale or RGBA inputs
            # would otherwise break the 3-channel normalisation.
            x = TF.to_tensor(self.resize_inp(img.convert('RGB')))
        x = self.normalize(x)

        return x, target

    def __len__(self):
        return len(self.input_files)