In [2]:
import cv2
from pathlib import Path
from matplotlib import pyplot as plt
from torch.utils.data import Dataset, DataLoader, Sampler
import xml.etree.ElementTree as ET
import torch
import math
from random import shuffle

### Load image arrays only when needed, too much memory to load everything on initialization

In [3]:
class PASCALVOC(Dataset):
    def __init__(self,textfile ="train_val.txt",transform = None):
       # assert Path(rootfolder).exists(), "%s is an invalid path"%rootfolder

        self.classes = ('__background__','aeroplane', 'bicycle', 'bird', 'boat',
                         'bottle', 'bus', 'car', 'cat', 'chair',
                         'cow', 'diningtable', 'dog', 'horse',
                         'motorbike', 'person', 'pottedplant',
                         'sheep', 'sofa', 'train', 'tvmonitor')
        
        self.sizes = (645,429,285)
        self.all_img_names = []      
        
        self.rootfolder = "C:\\Users\\bing.DEFIDE\\Documents\\VOCtrainval_11-May-2012\\VOCdevkit\\VOC2012"

        listofimages = "C:\\Users\\bing.DEFIDE\\Documents\\VOCtrainval_11-May-2012\\VOCdevkit\\VOC2012\\ImageSets\\Main\\" + textfile
        
        file = open(listofimages,"r")
        for line in file:
            self.all_img_names.append((line.split(" "))[0])  
        file.close()

            
    def __getitem__(self,x):
        imgname = self.all_img_names[x]
        imagepath = self.rootfolder+ "\\JPEGImages\\" + imgname + ".jpg"
        annopath =  self.rootfolder+ "\\Annotations\\" + imgname + ".xml"
        
        img = cv2.imread(imagepath)
        img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
        resize_shape = self.eulc_dist(img,imgname)
        img = cv2.resize(img,(resize_shape,resize_shape))
        
        stepone = ET.parse(annopath)
        steptwo = stepone.findall('object')
        stepthree = steptwo[0].find('name')
        label = stepthree.text

        return (img,label)
    
    def __len__(self):
        return len(self.all_img_names)

    def __repr__(self):
        return ""
    
    def eulc_dist(self,img,imgname): #calc eulc dist and return the size to reshape to
        all_dist = []
        for x in self.sizes:
            vert_dist =  x - img.shape[0] 
            hori_dist = x - img.shape[1]
            dist = math.sqrt((vert_dist*vert_dist) + (hori_dist*hori_dist))
            all_dist.append(dist)
        
     #   if len(set(all_dist)) != len(self.sizes): #incase two exact same euclidean distances, which actually exists in pascal voc.. image res 314,400, equiv dist from 429 and 285
      #      raise Exception(imgname + " has two same euclidean distances computed")
        #in the event of equiv eucl dist, favor the larger number for now
        idx_of_smallest = all_dist.index(min(all_dist))
        return self.sizes[idx_of_smallest]
        

In [4]:
class Custom_RandomSampler(Sampler):
    def __init__(self,data_src):
        self.data_src = data_src
        
        self.len645 = None
        self.len429 = None
        self.len285 = None

    def __iter__(self):     
        
        list645 = []
        list429 = []
        list285 = []
        finallist = []
        for x in range(len(self.data_src)):
            if self.data_src[x][0].shape[1] == 645:
                list645.append(x)
            
            elif self.data_src[x][0].shape[1] == 429:
                list429.append(x)
            
            elif self.data_src[x][0].shape[1] == 285:
                list285.append(x) 
        
        
        self.len645 = len(list645)
        self.len429 = len(list429)
        self.len285 = len(list285)
        
        shuffle(list645)
        shuffle(list429)
        shuffle(list285)
        
        if self.len645 != 0:
            finallist.extend(list645)
        if self.len429 != 0:
            finallist.extend(list429)
        if self.len285 !=0:
            finallist.extend(list285)
        return iter(finallist)
    def __len__(self):
        return len(self.data_src)
        
        
                

In [13]:
class Custom_BatchSampler(Sampler):
    def __init__(self,sampler,batch_size,drop_last = True):
    #    if not isinstance(sampler, Sampler):
     #       raise ValueError("sampler should be an instance of "
        #                     "torch.utils.data.Sampler, but got sampler={}"
       #                      .format(sampler))
     #   if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
      #          batch_size <= 0:
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError("batch_size should be a positive integral value, "
                             "but got batch_size={}".format(batch_size))
            
        if drop_last != True:
            raise Exception("drop_last input={} is invalid. Only drop_last = True is implemented for this function".format(drop_last))
        self.sampler = sampler
        
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch =[]
        #slicing the list is probably easier and cleaner though and shouldnt take much memory
        for count,value in enumerate(self.sampler):
            #Assign only after calling sampler's iter method.
            self.len645 = self.sampler.len645
            self.len429 = self.sampler.len429
            self.len285 = self.sampler.len285
            
            if count<self.len645:
                batch.append(value)
                if len(batch) == self.batch_size:
                    yield batch
                    batch=[]
            elif count == self.len645-1:
                batch.append(value)
                if len(batch) == self.batch_size:
                    yield batch
                    batch =[]
                else:
                    batch = []
                    continue
            
            elif self.len645<=count< self.len429 -1:
                batch.append(value)
                if len(batch) == self.batch_size:
                    yield batch
                    batch =[]
            elif count == self.len429 -1:
                batch.append(value)
                if len(batch) == self.batch_size:
                    yield batch
                    batch =[]
                else:
                    batch =[]
                    continue
            
            elif count>= self.len285:
                batch.append(value)
                if len(batch) == self.batch_size:
                    yield batch
                    batch = []
    

    def __len__(self): 
        if self.drop_last:
            length = (self.len645//self.batch_size) + (self.len429//self.batch_size) + (self.len285//self.batch_size)
            return length
        else: #not gonna use this option, not modifying
            raise Exception("drop_last == False is not implemented for this function")
            #return (len(self.sampler) + self.batch_size - 1) // self.batch_size

        

In [8]:
hi = PASCALVOC()

In [9]:
rdm = Custom_RandomSampler(hi)

In [7]:
rdm.len645

In [14]:
batch = Custom_BatchSampler(rdm,8,drop_last = False)

Exception: drop_last input=False is invalid. Only drop_last = True is implemented for this function

In [44]:
len(batch)

RandomSampler Len is called


727

In [39]:
def dummy(x):
    for x in range(x):
        return x
      #  print (type(x[0]))
    

In [40]:
dummy(19)

0

In [42]:
traindata = DataLoader(hi,batch_sampler = Custom_BatchSampler(Custom_RandomSampler(hi),8))

RandomSampler Init is called
BatchSampler Init is called


BatchSampler Iter is called
RandomSampler Iter is called
[tensor([[[[114,  98,  81],
          [116,  99,  66],
          [105,  87,  54],
          ...,
          [151, 130, 121],
          [150, 136, 126],
          [156, 149, 141]],

         [[118,  94,  79],
          [ 98,  83,  47],
          [ 82,  58,  30],
          ...,
          [150, 131, 122],
          [151, 138, 129],
          [158, 151, 144]],

         [[107,  90,  77],
          [ 78,  55,  39],
          [ 62,  34,  20],
          ...,
          [147, 131, 122],
          [150, 139, 131],
          [158, 151, 145]],

         ...,

         [[233, 174, 130],
          [235, 175, 131],
          [238, 178, 136],
          ...,
          [211, 121,  41],
          [216, 124,  46],
          [214, 122,  44]],

         [[228, 174, 128],
          [235, 181, 135],
          [236, 180, 136],
          ...,
          [213, 122,  43],
          [214, 122,  45],
          [209, 117,  40]],

         [[235, 172, 131],
     