# %autoreload is provided by the autoreload extension, which must be loaded first
%load_ext autoreload
%autoreload 2

In [None]:
#default_exp utils.data.dataset

# Dataset

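Children of the PyTorch `Dataset` class for the CovidX and Cassava Leaf datasets.

Helpers to load the train/test lists and split them into training and validation sets.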

In [None]:
#export
from PrimeCNNv3.utils.vizualize import *
from PrimeCNNv3.imports import *

In [None]:
#export
def get_dataset(filepath, seed=None):
    '''
        Read a split file (e.g. the CovidX train/test list) and return its
        records in shuffled order.

        Args:

            filepath:
                path to the train/test file, one record (image name and
                label, as parsed by CovidXDataset) per line

            seed:
                default is None
                if given, NumPy's global RNG is seeded with it before
                shuffling so the returned order is reproducible

        return:
            numpy array of str, one entry per non-blank line of the file
    '''
    with open(filepath, 'r') as file:
        # Skip blank lines (e.g. a trailing empty line): an empty record has
        # no fields and would break the field parsing in the datasets.
        cxr_list = np.array([line.rstrip('\n') for line in file if line.strip()])

    # Sort first so the shuffled order depends only on the file's contents
    # and the seed, not on the line order inside the file.
    cxr_list.sort()

    if seed is not None:
        np.random.seed(seed)

    np.random.shuffle(cxr_list)

    return cxr_list

In [None]:
#export
def get_csv_dataset(csv_filepath, seed=None):
    '''
        Read a CSV split file and return its data rows in shuffled order.

        Args:

            csv_filepath:
                path to the train/test CSV file; the first line is treated
                as the column heading and discarded

            seed:
                default is None
                if given, NumPy's global RNG is seeded with it before
                shuffling so the returned order is reproducible

        return:
            numpy array of str, one entry per non-blank data row
            (e.g. "image_name,label")
    '''
    with open(csv_filepath, 'r') as file:
        lines = [line.rstrip('\n') for line in file]

    # lines[0] is the column heading. Blank lines (e.g. a trailing empty line)
    # are dropped because an empty row cannot be split into name and label.
    cxr_list = np.array([line for line in lines[1:] if line.strip()])

    # Sort first so the shuffled order depends only on the file's contents
    # and the seed, not on the row order inside the file.
    cxr_list.sort()

    if seed is not None:
        np.random.seed(seed)

    np.random.shuffle(cxr_list)

    return cxr_list

In [None]:
#export
def get_train_val_split(train_list, valid_pct = 0.2, seed = None):
    '''
        Randomly split a list of records into a training and a validation part.

        Args:

            train_list:
                list or numpy array containing the image and labels string

            valid_pct:
                default = 0.2
                validation split ratio in [0, 1];
                int(valid_pct * len(train_list)) records go to validation

            seed:
                default is None
                seed value for reproducibility (seeds NumPy's global RNG)

        return:
            (train_list, valid_list) as numpy arrays

        raises:
            ValueError if valid_pct is outside [0, 1]
    '''
    if not 0 <= valid_pct <= 1:
        raise ValueError(f"valid_pct must be in [0, 1], got {valid_pct}")

    # np.asarray lets plain Python lists be fancy-indexed below; numpy arrays
    # pass through unchanged.
    content_list = np.asarray(train_list)

    if seed is not None:
        np.random.seed(seed)

    idx = list(range(len(content_list)))
    np.random.shuffle(idx)

    valid_split = int(valid_pct * len(content_list))

    # The first `valid_split` shuffled indices form the validation set and
    # the remaining ones the training set.
    valid_idx = idx[:valid_split]
    train_idx = idx[valid_split:]

    return content_list[train_idx], content_list[valid_idx]

In [None]:
#export
class CovidXDataset(Dataset):
    '''CovidXv5 Dataset

        Each entry of `data_list` is a whitespace-separated record whose
        second field is the image file name (relative to `root_dir`) and whose
        third field is the class name (a key of `CLASSES`).
    '''

    def __init__(self, root_dir, data_list, seed = None, transform = None, MAX_VAL = 255.0, dtype = 'float32', image_mode = 'RGB'):
        '''
            Args:
            root_dir (Path or str): Directory with all the images.
            data_list (numpy array): list of images and labels
            seed (int): seed used by `show_images` to pick the same images
                on every call when `rand` is False
            transform (callable):  transform to be applied
                on a sample, required to convert np array to torch tensor;
                called as `transform(image=image)` and must return a mapping
                with an 'image' key
            MAX_VAL (float): stored on the instance; not used by this class
            dtype (str): dtype the loaded image array is cast to
            image_mode (str): PIL mode the image is converted to on load
        '''
        self.seed = seed
        self.root_dir = root_dir
        self.transform = transform
        self.data_list = data_list
        # class name -> integer label
        self.CLASSES = {'normal' : 0, 'pneumonia' : 1, 'COVID-19' : 2}
        self.MAX_VAL = MAX_VAL
        self.dtype = dtype
        self.mode = image_mode

    def _parse_record(self, idx):
        '''Return (image_path, class_name) for the record at position `idx`.'''
        fields = self.data_list[idx].split()
        return Path(self.root_dir) / fields[1], fields[2]

    def __getitem__(self, idx):
        '''Return (image, integer label) for the record at position `idx`.'''
        image_path, class_name = self._parse_record(idx)
        label = self.CLASSES[class_name]

        with Image.open(image_path) as img:
            image = np.array(img.convert(self.mode)).astype(self.dtype)

        # Python's `random` is re-seeded from NumPy's global RNG before each
        # transform call. The draw happens even without a transform, so
        # NumPy's RNG stream is consumed the same way in both cases.
        # NOTE(review): presumably so the transform's randomness follows
        # np.random seeding - confirm.
        transform_seed = np.random.randint(2147483647)

        if self.transform:
            random.seed(transform_seed)
            augmented = self.transform(image=image)
            image = augmented['image']

        return image, label

    def __len__(self):
        return len(self.data_list)

    def show_images(self, n, figsize= (10,10), nrows = 1, ncols = None, rand = False):
        '''
            Shows n images with their labels in an nrows x ncols grid.

            rand: if False, the images are picked with a private RNG seeded
                with `self.seed`, so the same images are shown on every call
                and the global `random` state is left untouched
        '''
        rng = random if rand else random.Random(self.seed)
        indices = rng.sample(range(len(self)), n)

        if ncols is None:
            # ceiling division so that all n images fit in the grid
            ncols = -(-n // nrows)

        # squeeze=False always yields a 2-D array of axes, so iterating .flat
        # also works for a 1x1 grid
        _, axs = plt.subplots(nrows, ncols, figsize = figsize, squeeze = False)
        axes = list(axs.flat)

        for idx, ax in zip(indices, axes):
            image_path, class_name = self._parse_record(idx)
            with Image.open(image_path) as img:
                image = np.array(img.convert('RGB'))
            show_image(image, ax = ax, title = class_name)

        # hide grid cells left over when n does not fill the grid
        for ax in axes[len(indices):]:
            ax.axis('off')

        plt.tight_layout()

    def _get_Stats(self):
        '''
            Calculates number of samples in each of the class

            return dictionary {class name: count}, ordered like CLASSES
        '''
        class_dist = {key : 0 for key in self.CLASSES.keys()}
        for element in self.data_list:
            label = element.split()[2]
            class_dist[label] += 1

        return class_dist

    def show_distribution(self, figsize = (5,5)):
        '''Plot a bar chart of the number of samples per class.'''
        class_dist = self._get_Stats()
        _, ax = plt.subplots(1,1, figsize = figsize)

        ax.bar(class_dist.keys(), class_dist.values())

    def get_Weighted_RandomSampler(self,replacement = True, seed = 2147483647, use_generator = True):
        '''
            replacement: True: with or False: without replacement

            seed value is not used if use_generator is false
            returns WeightedRandomSampler for imbalanced classes; each
            sample is weighted by 1 / (number of samples in its class)
        '''
        # _get_Stats is ordered like CLASSES, whose values are 0, 1, 2 in
        # insertion order, so position i holds the count of label i.
        class_counts = torch.as_tensor(list(self._get_Stats().values())).float()
        # Classes with zero samples get an infinite weight, but no sample
        # ever indexes them.
        class_weight = 1. / class_counts

        labels = torch.as_tensor([self.CLASSES[record.split()[2]] for record in self.data_list],
                                 dtype = torch.long)
        sample_weight = class_weight[labels]

        generator = torch.Generator().manual_seed(seed) if use_generator else None

        return WeightedRandomSampler(weights = sample_weight, num_samples = len(sample_weight),
                                     replacement = replacement, generator = generator)
        

In [None]:
#export
class CassavaLeafDataset(Dataset):
    '''CassavaLeafDataset Dataset

        Each entry of `data_list` is a comma-separated row
        "image_name,label" where label is the integer class index
        (a key of `CLASSES`).
    '''

    def __init__(self, root_dir, data_list, seed = None, transform = None, MAX_VAL = 255.0, dtype = 'float32', image_mode = 'RGB'):
        '''
            Args:
            root_dir (Path or str): Directory with all the images.
            data_list (numpy array): list of images and labels
            seed (int): seed used by `show_images` to pick the same images
                on every call when `rand` is False
            transform (callable):  transform to be applied
                on a sample, required to convert np array to torch tensor;
                called as `transform(image=image)` and must return a mapping
                with an 'image' key
            MAX_VAL (float): stored on the instance; not used by this class
            dtype (str): dtype the loaded image array is cast to
            image_mode (str): PIL mode the image is converted to on load
        '''
        self.seed = seed
        self.root_dir = root_dir
        self.transform = transform
        self.data_list = data_list

        # integer label -> human-readable class name
        self.CLASSES = {0:'Cassava Bacterial Blight (CBB)',
                        1: 'Cassava Brown Streak Disease (CBSD)',
                        2: 'Cassava Green Mottle (CGM)',
                        3: 'Cassava Mosaic Disease (CMD)',
                        4: 'Healthy'
                       }
        self.MAX_VAL = MAX_VAL
        self.dtype = dtype
        self.mode = image_mode

    def _parse_record(self, idx):
        '''Return (image_path, integer label) for the row at position `idx`.'''
        fields = self.data_list[idx].split(',')
        # the label column is the class index itself; CLASSES is keyed by int,
        # so the string must be converted rather than looked up in CLASSES
        return Path(self.root_dir) / fields[0], int(fields[1])

    def __getitem__(self, idx):
        '''Return (image, integer label) for the row at position `idx`.'''
        image_path, label = self._parse_record(idx)

        with Image.open(image_path) as img:
            image = np.array(img.convert(self.mode)).astype(self.dtype)

        # Python's `random` is re-seeded from NumPy's global RNG before each
        # transform call. The draw happens even without a transform, so
        # NumPy's RNG stream is consumed the same way in both cases.
        # NOTE(review): presumably so the transform's randomness follows
        # np.random seeding - confirm.
        transform_seed = np.random.randint(2147483647)

        if self.transform:
            random.seed(transform_seed)
            augmented = self.transform(image=image)
            image = augmented['image']

        return image, label

    def __len__(self):
        return len(self.data_list)

    def show_images(self, n, figsize= (10,10), nrows = 1, ncols = None, rand = False):
        '''
            Shows n images with their class names in an nrows x ncols grid.

            rand: if False, the images are picked with a private RNG seeded
                with `self.seed`, so the same images are shown on every call
                and the global `random` state is left untouched
        '''
        rng = random if rand else random.Random(self.seed)
        indices = rng.sample(range(len(self)), n)

        if ncols is None:
            # ceiling division so that all n images fit in the grid
            ncols = -(-n // nrows)

        # squeeze=False always yields a 2-D array of axes, so iterating .flat
        # also works for a 1x1 grid
        _, axs = plt.subplots(nrows, ncols, figsize = figsize, squeeze = False)
        axes = list(axs.flat)

        for idx, ax in zip(indices, axes):
            image_path, label = self._parse_record(idx)
            with Image.open(image_path) as img:
                image = np.array(img.convert('RGB'))
            show_image(image, ax = ax, title = self.CLASSES[label])

        # hide grid cells left over when n does not fill the grid
        for ax in axes[len(indices):]:
            ax.axis('off')

        plt.tight_layout()

    def _get_Stats(self):
        '''
            Calculates number of samples in each of the class

            return dictionary {integer label: count}, ordered like CLASSES
        '''
        class_dist = {key : 0 for key in self.CLASSES.keys()}
        for element in self.data_list:
            label = element.split(',')[1]
            class_dist[int(label)] += 1

        return class_dist

    def show_distribution(self, figsize = (5,5)):
        '''Plot a bar chart of the number of samples per class index.'''
        class_dist = self._get_Stats()
        _, ax = plt.subplots(1,1, figsize = figsize)

        ax.bar(class_dist.keys(), class_dist.values())

    def get_Weighted_RandomSampler(self,replacement = True, seed = 2147483647, use_generator = True):
        '''
            replacement: True: with or False: without replacement

            seed value is not used if use_generator is false
            returns WeightedRandomSampler for imbalanced classes; each
            sample is weighted by 1 / (number of samples in its class)
        '''
        # _get_Stats is ordered like CLASSES (keys 0..4), so position i holds
        # the count of label i.
        class_counts = torch.as_tensor(list(self._get_Stats().values())).float()
        # Classes with zero samples get an infinite weight, but no sample
        # ever indexes them.
        class_weight = 1. / class_counts

        labels = torch.as_tensor([int(record.split(',')[1]) for record in self.data_list],
                                 dtype = torch.long)
        sample_weight = class_weight[labels]

        generator = torch.Generator().manual_seed(seed) if use_generator else None

        return WeightedRandomSampler(weights = sample_weight, num_samples = len(sample_weight),
                                     replacement = replacement, generator = generator)
        

In [None]:
# Regenerate the library modules from the project notebooks with nbdev
# (this notebook's `#export` cells go to the module named by `#default_exp`);
# the "Converted ..." lines below are the output of this cell.
from nbdev.export import notebook2script; notebook2script()

Converted 00_utils.data.dataset.ipynb.
Converted 01_utils.data.dataloaders.ipynb.
Converted 02_utils.vizualize.ipynb.
Converted 03_callbacks.ipynb.
Converted 04_learner.ipynb.
Converted 05_metrics.ipynb.
Converted index.ipynb.
