In [1]:
# default_exp datagenerator

# datagenerator

> Build an image dataset from a directory of files and stream `(image, label)` pairs through a generator.

In [2]:
#hide
from nbdev.showdoc import *

In [3]:
#export
import tensorflow as tf
import numpy as np
from glob import glob
import os
import pathlib

from typing import Union

from chitra.image import read_image, resize_image

# Data generator
## components
> Components are the callables a dataset is built from; each one can easily be overridden.
- image path generator
- image label generator
- image resizer


> The generator object will also support callbacks that can update the components.




In [4]:
#export
def get_filenames(root_dir):
    """List the direct children of ``root_dir``.

    Args:
        root_dir: directory to scan, as a string or path object.

    Returns:
        A list of path strings matching ``<root_dir>/*`` (whatever order
        ``glob`` produces; entries starting with a dot are not matched).
    """
    pattern = pathlib.Path(root_dir).joinpath('*')
    return glob(str(pattern))
    
def get_label(filename):
    """Return the name of the directory that directly contains ``filename``.

    The path is split with ``pathlib`` instead of a literal ``'/'`` so that
    path objects are accepted as well as strings, and so that the platform's
    own separator is honoured.

    Args:
        filename: path to a file, as a string or ``os.PathLike`` object.

    Returns:
        The second-to-last path component, e.g. ``'cat'`` for
        ``'root/cat/img.jpg'``.

    Raises:
        IndexError: if the path has fewer than two components.
    """
    return pathlib.PurePath(filename).parts[-2]


In [5]:
#export
class ImageSizeList():
    """Queue of image sizes that is consumed one size per `get_size` call.

    Attributes:
        img_sz_list: private list of the sizes that have not been handed out yet.
        start_size: first size supplied at construction, or None.
        last_size: final size supplied at construction, or None.
        curr_size: size most recently returned by `get_size` (initially the
            first size), or None when no size was supplied.
    """
    def __init__(self, img_sz_list=None):
        """
        Args:
            img_sz_list: a single size such as ``(96, 96)``, a list/tuple of
                sizes such as ``[(96, 96), (64, 64)]``, or None / an empty
                sequence for "no sizes". Anything that is not a list or tuple
                is treated as "no sizes".
        """
        if isinstance(img_sz_list, (list, tuple)):
            # A flat sequence like (96, 96) is one size, not a list of sizes.
            # The emptiness check keeps an empty sequence from raising here.
            if img_sz_list and not isinstance(img_sz_list[0], (list, tuple)):
                img_sz_list = [img_sz_list]
            # Work on a private list: get_size() pops from it, which must not
            # mutate the caller's list and must also work for tuple input.
            img_sz_list = list(img_sz_list)
        else:
            img_sz_list = []

        self.start_size = None
        self.last_size = None
        self.curr_size = None
        self.img_sz_list = img_sz_list

        if img_sz_list:
            self.start_size = img_sz_list[0]
            self.last_size = img_sz_list[-1]
            self.curr_size = img_sz_list[0]
        else:
            print('No item present in the image size list')

    def get_size(self):
        """Return the next queued size.

        Once the queue is exhausted the most recently returned size is
        returned again (None if no size was ever supplied) and a notice is
        printed.
        """
        if self.img_sz_list:
            self.curr_size = self.img_sz_list.pop(0)
        else:
            print(f'Returning the last set size which is: {self.curr_size}')

        return self.curr_size

In [6]:
# Smoke check: with no sizes configured the constructor reports that the list
# is empty, and get_size() falls back to the current size, which is None.
img_sz_list = ImageSizeList(None)
img_sz_list.get_size()

No item present in the image size list
Returning the last set size which is: None


In [10]:
#export
class Dataset():
    """Directory-backed image dataset assembled from replaceable components.

    Three callables are stored as instance attributes so that they can be
    swapped at runtime with `update_component`:

    - ``get_filenames(root_dir)`` returns the files that make up the dataset
    - ``read_image(filename)`` loads one image
    - ``get_label(filename)`` derives the label of one file
    """
    def __init__(self, root_dir, image_size=None):
        """
        Args:
            root_dir: location passed to ``get_filenames`` to list the files.
            image_size: size, or list of sizes, wrapped in an `ImageSizeList`;
                `generator` resizes images only when that list yields a
                truthy size.
        """
        self.get_filenames = get_filenames
        self.read_image = read_image
        self.get_label = get_label

        self.root_dir = root_dir
        self.filenames = self.get_filenames(root_dir)
        self.num_files = len(self.filenames)
        self.img_sz_list = ImageSizeList(image_size)

    def __len__(self):
        """Number of files currently listed."""
        return len(self.filenames)

    def _process(self, filename):
        """Load one file and return it as an ``(image, label)`` pair."""
        image = self.read_image(filename)
        label = self.get_label(filename)
        return image, label

    def _reload(self):
        """Re-list the files under ``root_dir`` with the current ``get_filenames``."""
        self.filenames = self.get_filenames(self.root_dir)
        self.num_files = len(self.filenames)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair for the file at position ``idx``."""
        filename = self.filenames[idx]
        return self._process(filename)

    def update_component(self, component_name, new_component, reload=True):
        """Replace one component and optionally refresh the file list.

        Args:
            component_name: attribute name to overwrite, e.g. ``'get_filenames'``.
            new_component: the replacement object.
            reload: when true (the default) the file list is rebuilt after
                the swap; pass False to skip the re-scan.
        """
        setattr(self, component_name, new_component)
        print(f'{component_name} updated with {new_component}')
        # Honour the flag: previously the file list was rebuilt unconditionally.
        if reload:
            self._reload()

    def generator(self,):
        """Yield ``(image, label)`` for every file, in list order.

        One size is requested from ``self.img_sz_list`` per pass over the
        data; when it is truthy every image is resized to it before being
        yielded.
        """
        img_sz = self.img_sz_list.get_size()
        n = len(self.filenames)
        for i in range(n):
            image, label = self[i]
            if img_sz:
                image = resize_image(image, img_sz)
            yield image, label
    

In [11]:
# NOTE(review): absolute, machine-specific path - point this at a local copy
# of the dataset before re-running the notebook.
# Two image sizes are queued for the generator.
ds = Dataset('/data/aniket/tiny-imagenet/data/tiny-imagenet-200/train', image_size=[(96,96), (64,64)])

In [12]:
def load_files(path):
    """Collect every entry matching ``<path>/*/images/*``.

    Args:
        path: root directory whose sub-directories each hold an ``images``
            folder.

    Returns:
        The list of matching path strings produced by ``glob``.
    """
    pattern = '{}/*/images/*'.format(path)
    matches = glob(pattern)
    return matches

In [13]:
# Swap the default file lister for load_files; with the default reload
# setting the file list is rebuilt right after the swap.
ds.update_component('get_filenames', load_files)

get_filenames updated with <function load_files at 0x7faf6b1d3320>


In [15]:
# Pull a single (image, label) pair from the generator and show the image
# shape next to its label, then stop.
for e in ds.generator():
    print(e[0].shape, e[1])
    break

(64, 64, 3) images
