In [4]:
import os
import sys
from loguru import logger
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import glob
import cv2
import pickle
from random import shuffle
import progressbar
import time

In [5]:
# Send log records to stdout so they appear in the notebook cell output:
# timestamp in green, message coloured according to its level.
logger.add(
    sys.stdout,
    format="<green>{time}</green> <level>{message}</level>",
    colorize=True,
)

2

In [6]:
#UTIL FUNCTIONS
class utils:
    """Small helpers for reading images and (de)serialising labelled datasets.

    Parameters
    ----------
    logs : bool
        When truthy, the pickle helpers emit messages through ``log_messages``.
    """

    def __init__(self, logs):
        self.logs = logs

    def splitpath(self, p):
        """Return the last path component of ``p`` (e.g. a class-folder name).

        ``'dataset/cats'`` and ``'data/train/cats/'`` both give ``'cats'``.
        Taking the final component (rather than the second ``'/'``-separated
        piece) keeps labels correct for nested dataset directories and uses the
        platform's path separator.
        """
        return os.path.basename(os.path.normpath(p))

    def read_img(self, path, size = None):
        """Read the image at ``path`` with ``plt.imread``.

        If ``size`` is given, the image is passed through ``cv2.resize``, which
        expects ``(width, height)``.
        """
        img = plt.imread(path)
        if size is None:
            return img
        return cv2.resize(img, size)

    def shuffle_labeled_data(self, imgs, lbls):
        """Shuffle ``imgs`` and ``lbls`` together, keeping each pair aligned.

        The shuffle is done **in place** via slice assignment, so both
        arguments must be mutable sequences (lists); the same objects are
        returned for convenience. Empty input is returned unchanged.
        """
        combined = list(zip(imgs, lbls))
        if not combined:
            # zip(*[]) yields nothing, so the two-target unpack below would fail.
            return imgs, lbls
        shuffle(combined)
        imgs[:], lbls[:] = zip(*combined)
        return imgs, lbls

    def export_pickle(self, imgs, lbls, export_pickle_name):
        """Pickle the ``(image, label)`` pairs to ``export_pickle_name``.

        The pairs are materialised as a list of tuples instead of pickling a
        live ``zip`` iterator, so the file is a plain container that can be
        inspected after loading. ``import_pickle`` reads it with ``zip(*...)``,
        which works for both the old zip payload and this list payload.
        """
        pairs = list(zip(imgs, lbls))
        with open(export_pickle_name, 'wb') as handle:
            pickle.dump(pairs, handle, protocol=pickle.HIGHEST_PROTOCOL)
        self.log_messages("Data is saved in pickle file with name {}".format(export_pickle_name), self.logs)

    def import_pickle(self, import_pickle_name):
        """Load ``(image, label)`` pairs written by ``export_pickle``.

        Returns
        -------
        (imgs, lbls) : tuple of lists
            An empty dataset gives ``([], [])``.

        NOTE: ``pickle.load`` can execute arbitrary code; only load files you
        created yourself or otherwise trust.
        """
        self.log_messages("Loading pickle file {}".format(import_pickle_name), self.logs)
        with open(import_pickle_name, 'rb') as handle:
            pairs = list(pickle.load(handle))
        imgs = [img for img, _ in pairs]
        lbls = [lbl for _, lbl in pairs]
        self.log_messages("Total {} images and labels loaded from pickle file.".format(len(imgs)), self.logs)
        return imgs, lbls

    def log_messages(self, msg, log=False):
        """Log ``msg`` through loguru at numeric level 10, but only if ``log`` is truthy."""
        if log:
            logger.log(10, msg)

In [7]:
# Shared helper instance; logs=True enables the pickle helpers' log messages.
UTIL = utils(logs=True)
def get_img_data_gen(classes_path, classes, get_path_only, shuffle_data, img_format="png",
                     resize = None, batch_size = 32, drop_last = True):
    """Yield the dataset batch by batch as ``zip(images, labels)`` objects.

    Parameters
    ----------
    classes_path : list of str
        One directory per class; files are matched with ``<dir>/*.<img_format>``.
    classes : list
        Label for each entry of ``classes_path`` (same order).
    get_path_only : bool
        If True, each batch pairs file paths (not decoded images) with labels,
        mirroring ``get_img_data``.
    shuffle_data : bool
        Shuffle images and labels together once, before batching.
    img_format : str
        File extension, with or without a leading dot.
    resize : tuple or None
        Forwarded to ``UTIL.read_img`` (``(width, height)`` for ``cv2.resize``).
    batch_size : int
        Number of items per batch; must be positive.
    drop_last : bool
        If True (default), a trailing batch smaller than ``batch_size`` is
        skipped so every yielded batch has exactly ``batch_size`` items.

    Yields
    ------
    zip
        One ``zip`` of ``(image_or_path, label)`` pairs per batch. Images are
        decoded lazily, only when that batch is iterated.

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive (raised on the first ``next()``).
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer, got {}".format(batch_size))

    ext = img_format.lstrip('.')
    imgs = []
    lbls = []
    for cp, c in zip(classes_path, classes):
        imgs_path = glob.glob("{}/*.{}".format(cp, ext))
        imgs.extend(imgs_path)
        lbls.extend([c] * len(imgs_path))

    # Nothing to shuffle for an empty dataset (and zip(*[]) cannot be unpacked).
    if shuffle_data and imgs:
        imgs, lbls = UTIL.shuffle_labeled_data(imgs, lbls)

    stop = len(imgs) - len(imgs) % batch_size if drop_last else len(imgs)
    for start in range(0, stop, batch_size):
        batch_paths = imgs[start:start + batch_size]
        batch_lbls = lbls[start:start + batch_size]
        # The yield lives inside the loop so that every batch is produced.
        if get_path_only:
            yield zip(batch_paths, batch_lbls)
        else:
            yield zip((UTIL.read_img(p, resize) for p in batch_paths), batch_lbls)

def progressBar(i, x, resize, bar):
    """Tick ``bar`` to position ``i``, then load and return the image at ``x``.

    ``resize`` is forwarded unchanged to ``UTIL.read_img``. Meant to be called
    per item inside a comprehension so the bar advances as images load.
    """
    bar.update(i)
    loaded = UTIL.read_img(x, resize)
    return loaded

def get_img_data(classes_path, classes, get_path_only, shuffle_data, img_format="png",
                 resize = None, batch_size = 32, return_generator = False):
    """Collect every image in each class folder and return them with labels.

    Parameters
    ----------
    classes_path : list of str
        One directory per class; files are matched with ``<dir>/*.<img_format>``.
    classes : list
        Label for each entry of ``classes_path`` (same order).
    get_path_only : bool
        If True, return file paths instead of decoded images.
    shuffle_data : bool
        Shuffle images and labels together before returning.
    img_format : str
        File extension, with or without a leading dot.
    resize : tuple or None
        Forwarded to ``UTIL.read_img`` (``(width, height)`` for ``cv2.resize``).
    batch_size, return_generator :
        Currently unused; kept for signature compatibility.

    Returns
    -------
    (imgs, lbls) : tuple of lists
        ``imgs`` holds paths when ``get_path_only`` is True, otherwise the
        decoded images.
    """
    ext = img_format.lstrip('.')
    imgs = []
    lbls = []
    for cp, c in zip(classes_path, classes):
        imgs_path = glob.glob("{}/*.{}".format(cp, ext))
        imgs.extend(imgs_path)
        lbls.extend([c] * len(imgs_path))

    # Nothing to shuffle for an empty dataset (and zip(*[]) cannot be unpacked).
    if shuffle_data and imgs:
        imgs, lbls = UTIL.shuffle_labeled_data(imgs, lbls)

    if get_path_only:
        return imgs, lbls

    progress = progressbar.ProgressBar()
    imgs = [UTIL.read_img(x, resize) for x in progress(imgs)]
    return imgs, lbls

In [48]:
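### NOTE: the generated TFRecord is large because every example stores the decoded, uncompressed pixel buffer ('image_raw'); the four INT64 metadata fields add only a few bytes per record.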

def get_img_data_tfrecord(classes_path, labels, resize = (128,128), shuffle_data = True, img_format='.jpg',
                          output_path = 'tfrecord.tfrecord'):
    """Write every image found under ``classes_path`` into a single TFRecord file.

    Each record stores ``height``, ``width``, ``depth`` and ``label`` as INT64
    features plus the raw pixel buffer ``image_raw`` as bytes.

    Parameters
    ----------
    classes_path : list of str
        One directory per class; files are matched with ``<dir>/*.<img_format>``.
    labels : list
        Label for each directory. Each label is converted with ``int()``, so
        they must be integer-like (e.g. folder names ``'0'``, ``'1'``).
    resize : tuple or None
        ``(width, height)`` passed to ``UTIL.read_img``; a falsy value
        (``None`` / ``False``) keeps the original image size.
    shuffle_data : bool
        Shuffle images and labels together before writing.
    img_format : str
        File extension, with or without a leading dot (``'jpg'`` or ``'.jpg'``).
    output_path : str
        Destination TFRecord file.

    Returns
    -------
    int
        Number of records written.
    """
    import tensorflow as tf  # local import: only this function needs TensorFlow

    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    # Strip a leading dot so the pattern is "*.jpg", not "*..jpg".
    ext = img_format.lstrip('.')
    imgs = []
    lbls = []
    for cp, c in zip(classes_path, labels):
        imgs_path = glob.glob("{}/*.{}".format(cp, ext))
        imgs.extend(imgs_path)
        lbls.extend([c] * len(imgs_path))

    if shuffle_data and imgs:
        imgs, lbls = UTIL.shuffle_labeled_data(imgs, lbls)

    progress = progressbar.ProgressBar()
    with tf.io.TFRecordWriter(output_path) as writer:
        # Iterate over *all* collected images, not just the last class's file list.
        for index in progress(range(len(imgs))):
            image_raw = UTIL.read_img(imgs[index], resize or None)
            rows, cols = image_raw.shape[:2]
            # Grayscale images are 2-D; record them as a single channel.
            depth = image_raw.shape[2] if image_raw.ndim == 3 else 1

            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        'height': _int64_feature(rows),
                        'width': _int64_feature(cols),
                        'depth': _int64_feature(depth),
                        'label': _int64_feature(int(lbls[index])),
                        # tobytes() replaces ndarray.tostring(), removed in NumPy 2.0.
                        'image_raw': _bytes_feature(image_raw.tobytes())
                    }))
            writer.write(example.SerializeToString())
    return len(imgs)

In [51]:
def dataset_load(dataset_dir, class_source, csv_path = None, get_path_only = False, img_format='jpg',
                 resize = None, export_pickle_name = 'data.pickle', shuffle_data = True, log_msgs = True,
                 batch_size = 32, return_generator = False, save_tfrecord = False):
    """Load a labelled image dataset from class folders or from a CSV index.

    Parameters
    ----------
    dataset_dir : str
        Root directory whose immediate sub-directories are the classes
        (used when ``class_source == 'FOLDER'``).
    class_source : {'FOLDER', 'CSV'}
        ``'FOLDER'``: label = sub-directory name. ``'CSV'``: read the ``path``
        and ``class`` columns of ``csv_path``.
    csv_path : str, optional
        Required when ``class_source == 'CSV'``.
    get_path_only : bool
        Return file paths instead of decoded images.
    img_format : str
        Extension matched inside each class folder (FOLDER only).
    resize : tuple or None
        ``(width, height)`` forwarded to ``UTIL.read_img``.
    export_pickle_name : str or None
        If not None, the loaded ``(imgs, lbls)`` are pickled to this path.
    shuffle_data : bool
        Shuffle images and labels together.
    log_msgs : bool
        Emit progress / timing log messages.
    batch_size : int
        Batch size for the generator (only with ``return_generator=True``).
    return_generator : bool
        FOLDER only: return the lazy batch generator from ``get_img_data_gen``.
    save_tfrecord : bool
        FOLDER only: write a TFRecord file and return the placeholder ``(0, 0)``.

    Returns
    -------
    ``(imgs, lbls)``; a generator when ``return_generator``; ``(0, 0)`` when
    ``save_tfrecord``.

    Raises
    ------
    AssertionError
        ``class_source == 'CSV'`` without ``csv_path``.
    ValueError
        Unknown ``class_source``.
    """
    start_time = time.time()

    if class_source == 'FOLDER':
        # Only directories are classes; sort for a reproducible class order.
        classes_path = sorted(p for p in glob.glob(dataset_dir + '/*') if os.path.isdir(p))
        classes = list(map(UTIL.splitpath, classes_path))

        if return_generator:
            # Lazy: no image is read until the generator is iterated.
            return get_img_data_gen(classes_path, classes, get_path_only, img_format=img_format,
                                    resize=resize, shuffle_data=shuffle_data, batch_size=batch_size)

        if save_tfrecord:
            # Pass `resize` by keyword: the third positional slot of
            # get_img_data_tfrecord is `resize`, not `get_path_only`.
            get_img_data_tfrecord(classes_path, classes, resize=resize, img_format=img_format,
                                  shuffle_data=shuffle_data)
            UTIL.log_messages("Saved tfrecord file", log_msgs)
            return 0, 0

        imgs, lbls = get_img_data(classes_path, classes, get_path_only, img_format=img_format,
                                  resize=resize, shuffle_data=shuffle_data)

    elif class_source == 'CSV':
        assert (csv_path is not None), "Please provide a valid CSV file path"

        df = pd.read_csv(csv_path)
        imgs_path = list(df["path"])
        lbls = list(df["class"])
        if shuffle_data and imgs_path:
            imgs_path, lbls = UTIL.shuffle_labeled_data(imgs_path, lbls)

        imgs = imgs_path if get_path_only else [UTIL.read_img(x, size=resize) for x in imgs_path]
        UTIL.log_messages("Total images {}".format(len(imgs)), log_msgs)

    else:
        raise ValueError("class_source must be 'FOLDER' or 'CSV', got {!r}".format(class_source))

    if export_pickle_name is not None:
        UTIL.export_pickle(imgs, lbls, export_pickle_name)
    UTIL.log_messages("Total time taken {}".format(int(time.time() - start_time)), log_msgs)
    return imgs, lbls
    
def dataset_load_pickle(path, shuffle = True):
    """Load ``(imgs, lbls)`` from a pickle file via ``UTIL.import_pickle``.

    When ``shuffle`` is truthy the pairs are shuffled together with
    ``UTIL.shuffle_labeled_data``. (The parameter name shadows the imported
    ``random.shuffle`` inside this function only.) Only load pickle files you
    trust — unpickling can execute arbitrary code.
    """
    loaded_imgs, loaded_lbls = UTIL.import_pickle(path)
    if not shuffle:
        return loaded_imgs, loaded_lbls
    return UTIL.shuffle_labeled_data(loaded_imgs, loaded_lbls)


In [52]:
# Write every dataset/<class>/*.jpg image into one TFRecord file.
# With save_tfrecord=True, dataset_load returns the placeholder pair (0, 0).
imgs, lbls = dataset_load(
    'dataset/',
    'FOLDER',
    get_path_only=False,
    img_format='jpg',
    save_tfrecord=True,
)

100% |########################################################################|


In [7]:
# dataset_load_pickle('data.pickle')