In [2]:
import os
import xml.etree.ElementTree as ET
import numpy as np
import cv2
import pickle
import copy
import tensorflow as tf

In [18]:
# Root of all datasets, relative to the notebook's working directory.
DATA_PATH = 'data'
# PASCAL VOC data lives in data/pascal_voc ...
PASCAL_PATH = os.path.join(DATA_PATH, 'pascal_voc')
# ... and the pickled label caches in data/pascal_voc/cache.
CACHE_PATH = os.path.join(DATA_PATH, 'pascal_voc', 'cache')

In [95]:
class dataset_pascal_voc(object):
    """Batch loader for PASCAL VOC 2007 images and grid-shaped labels.

    Labels come from a pickled cache (see ``get_labels``). Each entry is a dict
    with keys ``'imname'`` (image path passed to ``cv2.imread``), ``'label'``
    (array indexed ``[row, col, channel]`` over a ``cell_size x cell_size`` grid)
    and ``'flipped'`` (mirror the image horizontally when reading it).
    When ``self.flipped`` is True, mirrored copies of every entry are appended,
    doubling the dataset.

    Args:
        phase: split name used to build the cache file name
            (e.g. ``'train'`` -> ``pascal_train_gt_labels.pkl``).
        rebuild: if True, ignore an existing cache. Rebuilding labels from the
            VOC XML annotations is not implemented here, so this raises.
    """

    def __init__(self, phase, rebuild=False):
        # NOTE: attribute keeps the original 'devkil' spelling for compatibility.
        self.devkil_path = os.path.join(PASCAL_PATH, 'VOCdevkit')
        self.data_path = os.path.join(self.devkil_path, 'VOC2007')
        self.cache_path = CACHE_PATH
        self.batch_size = 20
        self.image_size = 448
        self.cell_size = 7
        self.classes = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
                        'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
                        'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
                        'train', 'tvmonitor']
        # Map class name -> integer index (its position in self.classes).
        self.class_to_ind = dict(zip(self.classes, range(len(self.classes))))
        # Whether to append horizontally mirrored copies of every example.
        self.flipped = True
        self.phase = phase
        self.rebuild = rebuild
        # Index of the next example get() will emit from self.labels_got.
        self.cursor = 0
        # 1-based pass counter; incremented each time the cursor wraps around.
        self.epoch = 1
        self.labels_got = None
        self.prepare()

    def get(self):
        """Return the next batch of images and labels.

        Returns:
            ``(X_img, Y_labels)`` with shapes
            ``(batch_size, image_size, image_size, 3)`` and
            ``(batch_size, cell_size, cell_size, 25)``.

        Each example consumed advances ``self.cursor``. When the end of the data
        is reached, the examples are reshuffled, the cursor wraps to 0 and
        ``self.epoch`` is incremented, so successive calls walk the dataset.

        Raises:
            ValueError: if no labels are loaded.
        """
        if self.labels_got is None or len(self.labels_got) == 0:
            raise ValueError("no labels loaded; cannot build a batch")
        X_img = np.zeros((self.batch_size, self.image_size, self.image_size, 3))
        Y_labels = np.zeros((self.batch_size, self.cell_size, self.cell_size, 25))
        for count_batch in range(self.batch_size):
            entry = self.labels_got[self.cursor]
            X_img[count_batch] = self.read_image(entry['imname'], entry['flipped'])
            Y_labels[count_batch] = entry['label']
            # Move to the next example; without this every slot got the same one.
            self.cursor += 1
            if self.cursor >= len(self.labels_got):
                np.random.shuffle(self.labels_got)
                self.cursor = 0
                self.epoch += 1
        return X_img, Y_labels

    def read_image(self, img_name, flipped=False):
        """Load an image, resize it, convert BGR->RGB and scale pixels to [-1, 1].

        Args:
            img_name: path of the image file.
            flipped: if True, mirror the image along its width axis.

        Returns:
            float32 array of shape ``(image_size, image_size, 3)``.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``img_name``.
        """
        image = cv2.imread(img_name)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError("could not read image: {}".format(img_name))
        image = cv2.resize(image, (self.image_size, self.image_size))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image = (image / 255.0) * 2 - 1.0
        if flipped:
            # Reverse the width axis, matching the column flip in data_augment.
            image = image[:, ::-1, :]
        return image

    def prepare(self):
        """Load labels, optionally append flipped copies, shuffle, and store them.

        Returns:
            The shuffled list of label dicts (also stored in ``self.labels_got``).
        """
        print("In prepare")
        # List of dicts, each holding the image path, the grid label and the flip flag.
        labels_got = self.get_labels()
        if self.flipped:
            print('Appending horizontally-flipped training examples ...')
            labels_copy = copy.deepcopy(labels_got)
            labels_got = self.data_augment(labels_got, labels_copy)
        np.random.shuffle(labels_got)
        self.labels_got = labels_got
        print("labels len", len(labels_got))
        return labels_got

    def data_augment(self, orig_labels, labels_copy):
        """Mirror each entry of ``labels_copy`` left<->right and append to ``orig_labels``.

        Args:
            orig_labels: list of label dicts; extended in place.
            labels_copy: independent (deep) copy of ``orig_labels``; its dicts are
                modified in place.

        Returns:
            ``orig_labels`` after extension (the same list object).
        """
        print("Create flipped data")
        for entry in labels_copy:
            entry['flipped'] = True
            # Reverse grid columns (axis 1) to mirror the grid horizontally.
            flipped_label = entry['label'][:, ::-1, :]
            # For cells whose channel 0 equals 1, mirror channel 1 as a pixel
            # x coordinate within [0, image_size - 1].
            occupied = flipped_label[:, :, 0] == 1
            flipped_label[occupied, 1] = self.image_size - 1 - flipped_label[occupied, 1]
            entry['label'] = flipped_label
        orig_labels += labels_copy
        return orig_labels

    def get_labels(self):
        """Load the ground-truth label list for ``self.phase`` from the pickle cache.

        Returns:
            The unpickled list of label dicts.

        Raises:
            NotImplementedError: if ``self.rebuild`` is True (parsing the VOC XML
                annotations is not implemented in this class).
            FileNotFoundError: if the cache file does not exist.
        """
        cache_file = os.path.join(self.cache_path, 'pascal_' + self.phase + '_gt_labels.pkl')
        print(self.cache_path)
        if self.rebuild:
            raise NotImplementedError(
                "rebuilding labels from VOC annotations is not implemented; "
                "use rebuild=False with an existing cache")
        if not os.path.isfile(cache_file):
            raise FileNotFoundError("label cache not found: {}".format(cache_file))
        print("Getting labels from " + cache_file)
        # SECURITY: pickle.load can execute arbitrary code; only load trusted caches.
        with open(cache_file, 'rb') as labels_file:
            labels_got = pickle.load(labels_file)
        print("original labels length :", len(labels_got))
        return labels_got

In [96]:
# Build the training split (loads the label cache and adds flipped copies).
pascal_dataset = dataset_pascal_voc(phase='train')

In prepare
data/pascal_voc/cache
Getting labels from data/pascal_voc/cache/pascal_train_gt_labels.pkl
original labels length : 5011
Appending horizontally-flipped training examples ...
Create flipped data
labels len 10022


In [97]:
# Draw one batch of images and their grid labels.
imgs, lbls = pascal_dataset.get()


In [98]:
# Show the batch shapes: images first, then labels.
for batch_array in (imgs, lbls):
    print(batch_array.shape)

(20, 448, 448, 3)
(20, 7, 7, 25)


In [None]:
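# Scratch cell: standalone snippets used while developing; not part of the pipeline above.

#------------------------------------------------

# Flip an image horizontally by reversing its width axis (image[:, ::-1, :]). Uncomment to test.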
#from matplotlib import pyplot as plt
#image = cv2.imread("cat2.jpeg")
#print(image.shape)
#plt.imshow(image)
#plt.show()
#print("*"*30)
#img_flip = image[:,::-1,:]
#print(img_flip.shape)
#plt.imshow(img_flip)
#plt.show()

#------------------------------------------------
