In [1]:
import os
import sys

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import glob
import cv2
import pickle
from random import shuffle

In [2]:
#UTIL FUNCTIONS
class utils:
    """Helpers for reading, shuffling and (un)pickling labelled image data.

    Attributes:
        logs: when truthy, the pickle helpers print their status messages
            through ``log_messages``.
    """

    def __init__(self, logs):
        self.logs = logs

    def splitpath(self, p):
        """Return the last component of ``p`` (used as the class/folder name).

        The earlier ``p.split('/')[1]`` picked the *second* component, which is
        only the folder name when the path is exactly two levels deep
        (``'output/6'``); ``'data/output/6'`` yielded ``'output'``.
        """
        return os.path.basename(os.path.normpath(p))

    def read_img(self, path, size=None):
        """Read the image at ``path`` and optionally resize it.

        Args:
            path: image file to read with ``plt.imread``.
            size: forwarded unchanged as the target size to ``cv2.resize``
                (OpenCV expects ``(width, height)``); ``None`` skips resizing.

        Returns:
            The image array as returned by ``plt.imread`` / ``cv2.resize``.
        """
        img = plt.imread(path)
        if size is None:
            return img
        return cv2.resize(img, size)

    def shuffle_labeled_data(self, imgs, lbls):
        """Shuffle ``imgs`` and ``lbls`` in place, keeping each pair aligned.

        Both sequences are overwritten through slice assignment, so callers
        holding a reference to the original lists observe the new order.

        Returns:
            The same ``imgs`` and ``lbls`` objects that were passed in.
        """
        combined = list(zip(imgs, lbls))
        if not combined:
            # zip(*[]) produces nothing, so the two-target unpacking below
            # would raise ValueError on empty input.
            return imgs, lbls
        shuffle(combined)
        imgs[:], lbls[:] = zip(*combined)
        return imgs, lbls

    def export_pickle(self, imgs, lbls, export_pickle_name):
        """Write ``(img, lbl)`` pairs to ``export_pickle_name`` as a pickled list."""
        with open(export_pickle_name, 'wb') as handle:
            # Materialise the pairs: a bare zip object is a lazy iterator, and
            # pickling it stores iterator state instead of plain data.
            pickle.dump(list(zip(imgs, lbls)), handle, protocol=pickle.HIGHEST_PROTOCOL)
        self.log_messages("Data is saved in pickle file with name {}".format(export_pickle_name), self.logs)

    def import_pickle(self, import_pickle_name):
        """Load ``(img, lbl)`` pairs written by ``export_pickle``.

        Returns:
            ``(imgs, lbls)`` as two lists (both empty when the file holds no pairs).
        """
        self.log_messages(import_pickle_name, self.logs)
        # SECURITY: pickle.load can execute arbitrary code; only open files
        # that come from a trusted source.
        with open(import_pickle_name, 'rb') as handle:
            # list() also accepts files written by the older zip-based export.
            pairs = list(pickle.load(handle))
        if pairs:
            imgs, lbls = (list(column) for column in zip(*pairs))
        else:
            imgs, lbls = [], []
        self.log_messages("Total {} images and labels loaded from pickle file.".format(len(imgs)), self.logs)
        return imgs, lbls

    def log_messages(self, msg, log=False):
        """Print ``msg`` when ``log`` is truthy; otherwise do nothing."""
        if log:
            print(msg)

In [3]:
# Shared helper instance used by the loaders below; status logging switched on.
UTIL = utils(logs=True)

def _iter_img_batches(imgs, lbls, resize, batch_size):
    """Yield one lazy ``zip(images, labels)`` per full batch of ``batch_size`` paths."""
    # Floor division: a trailing batch smaller than batch_size is not yielded.
    for i in range(len(imgs) // batch_size):
        lo, hi = batch_size * i, batch_size * (i + 1)
        # Generator expression: files are only read when the batch is consumed.
        batch_imgs = (UTIL.read_img(x, resize) for x in imgs[lo:hi])
        yield zip(batch_imgs, lbls[lo:hi])


def get_img_data(classes_path, classes, get_path_only, shuffle_data, img_format="png", resize = None, batch_size = 32, return_generator = False):
    """Collect image paths per class folder and return them as paths, arrays or batches.

    Previously this function contained a ``yield`` itself, which made every
    call return a generator, so the ``return imgs, lbls`` paths could never
    hand a tuple back to the caller. Batching now lives in a private helper.

    Args:
        classes_path: iterable of class directories to search.
        classes: label for each entry of ``classes_path`` (paired with ``zip``).
        get_path_only: when truthy, return the file paths without reading
            them. Takes precedence over ``return_generator``.
        shuffle_data: when truthy, shuffle paths and labels together.
        img_format: file extension matched inside each class directory.
        resize: forwarded to ``UTIL.read_img`` as the target size, or ``None``.
        batch_size: number of images per yielded batch in generator mode.
        return_generator: when truthy (and ``get_path_only`` is falsy), return
            a generator of batches instead of loading everything.

    Returns:
        ``(paths, labels)`` when ``get_path_only`` is truthy; otherwise a
        generator yielding ``zip(images, labels)`` per full batch when
        ``return_generator`` is truthy; otherwise ``(images, labels)``.

    Raises:
        ValueError: in generator mode when ``batch_size`` is smaller than 1.
    """
    imgs = []
    lbls = []
    for cp, c in zip(classes_path, classes):
        imgs_path = glob.glob("{}/*.{}".format(cp, img_format))
        imgs.extend(imgs_path)
        lbls.extend([c] * len(imgs_path))
    if shuffle_data:
        imgs, lbls = UTIL.shuffle_labeled_data(imgs, lbls)

    if get_path_only:
        return imgs, lbls
    if return_generator:
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1, got {}".format(batch_size))
        return _iter_img_batches(imgs, lbls, resize, batch_size)
    imgs = [UTIL.read_img(x, resize) for x in imgs]
    return imgs, lbls

In [5]:

def dataset_load(dataset_dir, class_source, csv_path = None, get_path_only = False, img_format='jpg', resize = None, export_pickle_name = 'data.pickle', shuffle_data = True, log_msgs = True, batch_size = 32, return_generator = True):
    """Load a labelled image dataset from class folders or from a CSV index.

    Args:
        dataset_dir: directory whose entries are treated as classes
            (``class_source='FOLDER'`` only).
        class_source: ``'FOLDER'`` or ``'CSV'``.
        csv_path: CSV file with ``path`` and ``class`` columns; required for
            ``class_source='CSV'``.
        get_path_only: when truthy, keep file paths instead of reading images.
        img_format: file extension searched in each class folder (FOLDER only).
        resize: target size forwarded to the image reader, or ``None``.
        export_pickle_name: file to pickle the loaded data into; ``None``
            disables the export. Not used when the FOLDER generator is returned.
        shuffle_data: when truthy, shuffle images and labels together.
        log_msgs: when truthy, print the total number of loaded images.
        batch_size: batch size passed to ``get_img_data`` in generator mode.
        return_generator: FOLDER only; when truthy, the result of
            ``get_img_data(..., return_generator=True)`` is returned directly
            and nothing is logged or pickled.

    Returns:
        ``(imgs, lbls)``, or for FOLDER with ``return_generator`` truthy
        whatever ``get_img_data`` returns in generator mode.

    Raises:
        ValueError: if ``class_source`` is neither ``'FOLDER'`` nor ``'CSV'``
            (previously this silently returned ``None``).
        AssertionError: if ``class_source='CSV'`` and ``csv_path`` is ``None``.
    """
    if class_source not in ('FOLDER', 'CSV'):
        raise ValueError("class_source must be 'FOLDER' or 'CSV', got {!r}".format(class_source))

    if class_source == 'FOLDER':
        classes_path = glob.glob(dataset_dir + '/*')
        classes = list(map(UTIL.splitpath, classes_path))

        if return_generator:
            # resize is now forwarded; it used to be hard-coded to None here.
            return get_img_data(classes_path, classes, get_path_only, img_format=img_format, resize=resize, shuffle_data=shuffle_data, batch_size=batch_size, return_generator=return_generator)

        imgs, lbls = get_img_data(classes_path, classes, get_path_only, img_format=img_format, resize=resize, shuffle_data=shuffle_data)
    else:
        assert (csv_path is not None), "Please provide a valid CSV file path"

        df = pd.read_csv(csv_path)
        lbls = list(df["class"])
        imgs_path = list(df["path"])
        if shuffle_data:
            # Rebind both names from the return value so the paths read below
            # and the labels returned are guaranteed to be in the same order.
            imgs_path, lbls = UTIL.shuffle_labeled_data(imgs_path, lbls)

        if get_path_only:
            imgs = imgs_path
        else:
            imgs = [UTIL.read_img(x, size=resize) for x in imgs_path]

    UTIL.log_messages("Total images {}".format(len(imgs)), log_msgs)

    if export_pickle_name is not None:
        UTIL.export_pickle(imgs, lbls, export_pickle_name)

    return imgs, lbls
    
def dataset_load_pickle(path, shuffle = True):
    """Load images and labels from a pickle file, optionally shuffling them.

    Args:
        path: pickle file to read through ``UTIL.import_pickle``.
        shuffle: when truthy, shuffle images and labels together.

    Returns:
        ``(imgs, lbls)`` as two lists. The earlier ``return imgs.lbls`` was an
        attribute access and raised AttributeError.
    """
    imgs, lbls = UTIL.import_pickle(path)
    if shuffle:
        # The shuffle helper assigns through slices, which needs mutable lists.
        imgs, lbls = UTIL.shuffle_labeled_data(list(imgs), list(lbls))
    return imgs, lbls


In [6]:
# Load the PNG dataset under output/ in folder mode; with dataset_load's
# default return_generator=True this yields batches rather than full lists.
gen = dataset_load(
    'output/',
    'FOLDER',
    get_path_only=False,
    img_format='png',
)

In [None]:
# NOTE(review): dataset_load returns early in generator mode, before its pickle
# export, so the generator call above does not create data.pickle. Run
# dataset_load with return_generator=False first, or point this at an existing file.
dataset_load_pickle('data.pickle')

In [10]:
# Pull the next batch from the generator and split its (image, label) pairs
# into a tuple of images (a) and a tuple of labels (b).
v = next(gen)
a, b = zip(*v)

In [11]:
# Show the labels of the batch unpacked in the previous cell.
print(b)

('6', '7', '4', '7', '5', '2', '6', '6', '0', '1', '5', '0', '6', '5', '0', '2', '3', '0', '4', '0', '5', '5', '1', '7', '6', '3', '7', '4', '3', '1', '3', '1')
