In [None]:
# default_exp datagenerator

# datagenerator

> Helpers for discovering image files, deriving their labels, and streaming the samples into a `tf.data.Dataset`.

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export
import tensorflow as tf
import numpy as np
from glob import glob
import os
import pathlib

from typing import Union

from chitra.image import read_image, resize_image

# Data generator
## components
> Components are methods that can easily be overridden:
- image path gen
- image label gen
- image resizer


> The generator object will also support callbacks that can update the components.




In [None]:
#export
def get_filenames(root_dir):
    """Return the entries found directly under `root_dir`.

    Args:
        root_dir: directory to scan, as a string or path-like object.

    Returns:
        A list of path strings matching `<root_dir>/*`, in the order `glob`
        produces them.
    """
    pattern = pathlib.Path(root_dir).joinpath('*')
    return glob(str(pattern))
    
def get_label(filename):
    """Return the name of the directory that directly contains `filename`.

    The path is split on '/', so `a/b/c.jpg` yields `b`. A path with no '/'
    raises `IndexError`.
    """
    # Only the last two separators matter, so split from the right.
    pieces = filename.rsplit('/', 2)
    return pieces[-2]


In [None]:
#export
class ImageSizeList():
    """Queue of target image sizes, handed out one per `get_size` call.

    Accepts a single size such as `(96, 96)`, a sequence of sizes such as
    `[(96, 96), (64, 64)]`, or `None` / an empty sequence when no size is wanted.

    Attributes:
        start_size: first size supplied, or `None` when none was given.
        last_size: final size supplied, or `None` when none was given.
        curr_size: size most recently handed out (initially the first size).
        img_sz_list: sizes that have not been handed out yet.
    """
    def __init__(self, img_sz_list=None):
        if isinstance(img_sz_list, (list, tuple)):
            if img_sz_list and not isinstance(img_sz_list[0], (list, tuple)):
                # A bare size like (96, 96) was passed: wrap it into a one-item queue.
                img_sz_list = [img_sz_list]
            else:
                # Copy so get_size() never pops from the caller's own sequence, and
                # so that a tuple of sizes becomes a poppable list.
                img_sz_list = list(img_sz_list)

        self.start_size = None
        self.last_size = None
        self.curr_size = None
        self.img_sz_list = img_sz_list

        try:
            self.start_size = img_sz_list[0]
            self.last_size = img_sz_list[-1]
            self.curr_size = img_sz_list[0]
        except (IndexError, TypeError):
            # Empty sequence (IndexError) or a non-indexable value such as None (TypeError).
            print('No item present in the image size list')

    def get_size(self):
        """Return the next queued size, or the last one returned once the queue is empty.

        Returns:
            The size popped from the front of the queue. When nothing is left (or
            no sizes were ever supplied) the current size is returned unchanged,
            which is `None` if no size was ever available.
        """
        try:
            self.curr_size = self.img_sz_list.pop(0)
        except (IndexError, AttributeError):
            # IndexError: queue exhausted. AttributeError: no poppable queue (e.g. None).
            print(f'Returning the last set size which is: {self.curr_size}')

        return self.curr_size

In [None]:
# Build a size list with no sizes at all; the constructor reports that nothing is present.
img_sz_list = ImageSizeList(None)
# With nothing queued, get_size() falls back to the current size, which is None here.
img_sz_list.get_size()

No item present in the image size list
Returning the last set size which is: None


In [None]:
#export
class LabelEncoder():
    """Map each label to its integer position in `labels`."""

    def __init__(self, labels):
        """Store `labels` and build the label -> index lookup.

        When a label occurs more than once, its last position is the one kept.
        """
        self.labels = labels
        lookup = {}
        for position, name in enumerate(self.labels):
            lookup[name] = position
        self.label_to_idx = lookup

    def encode(self, label):
        """Return the index recorded for `label`; raises `KeyError` if it is unknown."""
        index = self.label_to_idx[label]
        return index

In [None]:
#export
class Dataset():
    """Image dataset backed by the files discovered under `root_dir`.

    The loading pipeline is assembled from replaceable components stored as
    instance attributes (`get_filenames`, `read_image`, `get_label`); any of
    them can be swapped with `update_component`.
    """
    MAPPINGS = {
        # Python scalar type -> TensorFlow dtype, used to infer generator output types.
        'PY_TO_TF': {str: tf.string, bytes: tf.string, bool: tf.bool,
                     int: tf.int32, float: tf.float32},

        }

    def __init__(self, root_dir, image_size=None, transforms=None, label_encoder=None):
        """Discover the files under `root_dir` and set up the default components.

        Args:
            root_dir: directory handed to the `get_filenames` component.
            image_size: a single size or a sequence of sizes, wrapped in an
                `ImageSizeList`; each `generator()` call takes the next size.
            transforms: optional callable applied to every image after resizing.
            label_encoder: stored on the instance.
                NOTE(review): nothing in this class applies it to the labels yet.
        """
        self.get_filenames = get_filenames
        self.read_image = read_image
        self.get_label = get_label
        self.label_encoder = label_encoder
        self.transforms = transforms

        self.root_dir = root_dir
        self.filenames = self.get_filenames(root_dir)
        self.num_files = len(self.filenames)
        self.img_sz_list = ImageSizeList(image_size)

    def __len__(self):
        """Number of files currently known to the dataset."""
        return len(self.filenames)

    def _process(self, filename):
        """Load the image stored at `filename` and derive its label."""
        image = self.read_image(filename)
        label = self.get_label(filename)
        return image, label

    def _reload(self):
        """Re-scan `root_dir` with the current `get_filenames` component."""
        self.filenames = self.get_filenames(self.root_dir)
        self.num_files = len(self.filenames)

    @staticmethod
    def _to_tf_dtype(value):
        """Return the TensorFlow dtype that describes `value`.

        Raises:
            KeyError: `value` is a plain Python object with no entry in `MAPPINGS`.
        """
        if tf.is_tensor(value):
            return value.dtype
        if isinstance(value, (np.ndarray, np.generic)):
            # NumPy arrays/scalars (e.g. produced by `transforms`) carry their own dtype.
            return tf.as_dtype(value.dtype)
        return Dataset.MAPPINGS['PY_TO_TF'][type(value)]

    def _capture_return_types(self):
        """Infer the dtypes yielded by `generator()` from its first sample.

        Returns:
            A tuple with one TensorFlow dtype per yielded element (a 1-tuple
            when the generator yields a single non-tuple value).

        Raises:
            ValueError: the generator yields nothing, so there is no sample to inspect.
        """
        # Probing must not use up an entry of the size queue: generator() pops
        # one size per call, so snapshot the queue here and put it back afterwards.
        sizes = self.img_sz_list
        queue = sizes.img_sz_list
        queue_backup = list(queue) if isinstance(queue, list) else None
        size_backup = sizes.curr_size
        try:
            for outputs in self.generator():
                break
            else:
                raise ValueError('Cannot infer output types: the generator yielded no samples')
        finally:
            if queue_backup is not None:
                queue[:] = queue_backup  # restore in place, keeping the same list object
            sizes.curr_size = size_backup

        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        return tuple(Dataset._to_tf_dtype(value) for value in outputs)

    def __getitem__(self, idx):
        """Return the `(image, label)` pair for the file at position `idx`."""
        filename = self.filenames[idx]
        return self._process(filename)

    def update_component(self, component_name, new_component, reload=True):
        """Replace the attribute `component_name` with `new_component`.

        Args:
            component_name: attribute to overwrite, e.g. `'get_filenames'`.
            new_component: the replacement object (typically a callable).
            reload: when true (the default), re-scan `root_dir` afterwards so
                that `filenames` reflects the new component.
        """
        setattr(self, component_name, new_component)
        print(f'{component_name} updated with {new_component}')
        if reload:
            self._reload()

    def generator(self,):
        """Yield `(image, label)` for every file, resized and transformed.

        Each call takes the next size from the size list; a falsy size skips
        resizing, and a falsy `transforms` skips the transform step.
        """
        img_sz = self.img_sz_list.get_size()
        n = len(self.filenames)
        for i in range(n):
            image, label = self[i]
            if img_sz: image = resize_image(image, img_sz)
            if self.transforms: image = self.transforms(image)
            yield image, label

    def get_tf_dataset(self, output_shape=None):
        """Wrap `generator` in a `tf.data.Dataset`.

        Args:
            output_shape: optional shapes forwarded to `from_generator`.

        Returns:
            The dataset produced by `tf.data.Dataset.from_generator`, with
            output types inferred from the first sample.
        """
        datagen = tf.data.Dataset.from_generator(
            self.generator,
            output_types=self._capture_return_types(),
            output_shapes=output_shape,
        )
        return datagen

In [None]:
# Training split location. The default is the original machine-specific path;
# set TINY_IMAGENET_TRAIN_DIR to run this notebook against a copy stored elsewhere.
TRAIN_DIR = os.environ.get(
    'TINY_IMAGENET_TRAIN_DIR',
    '/data/aniket/tiny-imagenet/data/tiny-imagenet-200/train',
)
# image_size is wrapped in an ImageSizeList, which serves one size per generator() call.
ds = Dataset(TRAIN_DIR, image_size=[(96,96), (64,64)])

In [None]:
def load_files(path):
    """List every entry matching `<path>/*/images/*`.

    Returns:
        A list of path strings in the order `glob` produces them.
    """
    pattern = f'{path}/*/images/*'
    matches = glob(pattern)
    return matches

def get_label(path):
    """Return the third-from-last '/'-separated component of `path`.

    For `root/n01/images/a.jpg` this is `n01`. A path with fewer than three
    components raises `IndexError`.
    """
    # Only the last three separators matter, so split from the right.
    pieces = path.rsplit('/', 3)
    return pieces[-3]
    

In [None]:
# Swap in the helpers defined in the previous cell. Each call overwrites the
# named attribute on `ds` and then re-scans its root_dir for filenames.
ds.update_component('get_filenames', load_files)
ds.update_component('get_label', get_label)

get_filenames updated with <function load_files at 0x7faffca3e3b0>
get_label updated with <function get_label at 0x7faff0ca7560>


In [None]:
# Pull a single sample to inspect the image dtype and the label.
# Note that every generator() call takes the next size from the dataset's size list.
for e in ds.generator():
    print(e[0].dtype, e[1])
    break  # one sample is enough

<dtype: 'float32'> n03584254


In [None]:
def get_tf_dataset(ds):
    """Build a `tf.data.Dataset` from `ds.generator`.

    The output types passed to `from_generator` come from
    `ds._capture_return_types()`; no output shapes are supplied.
    """
    sample_source = ds.generator
    sample_types = ds._capture_return_types()
    return tf.data.Dataset.from_generator(sample_source, sample_types)

In [None]:
# Build the tf.data pipeline from `ds`; as the cell's last expression, its repr is displayed.
get_tf_dataset(ds)

Returning the last set size which is: (64, 64)


<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.float32, tf.string)>