In [1]:
import torch
from torch.utils.data import Dataset

import pickle

import numpy as np

from sklearn.model_selection import train_test_split

from glob import glob

import import_ipynb

from tools.ToolBox import json_loader

importing Jupyter notebook from ToolBox.ipynb


In [2]:
def angle_classifier(angle, n_class):
    """ Function that transforms an angle (in degrees) into a class index.

    The circle is divided into ``n_class`` bins of equal width (360 / n_class
    degrees) and the index of the bin containing ``angle`` is returned.

    Arguments:
        -> angle: an angle in degrees. Values outside [0, 360) are wrapped
                  around the circle first (360 -> 0, -5 -> 355, 725 -> 5).
        -> n_class: the number of expected classes (must be >= 1)

    Returns:
        The class index as an int in [0, n_class - 1].

    Raises:
        ValueError: if n_class is smaller than 1.
    """
    if n_class < 1:
        raise ValueError('n_class must be >= 1, got {}'.format(n_class))

    bin_width = 360 / n_class

    # Without wrapping, angle == 360 would yield n_class (one past the last
    # valid label) and negative angles would yield negative labels.
    wrapped = angle % 360

    index = int(wrapped // bin_width)

    # Float rounding can still land exactly on the upper edge (for instance
    # -1e-20 % 360 evaluates to 360.0), so keep the index inside the range.
    if index >= n_class:
        index = int(n_class) - 1

    return index

In [3]:
def dataset_aggregator(path = None, problem = None, n_classes = 1):
    """Gather the centriole images and angles stored in a set of .json files.

    Each file is read with json_loader. The entries found under the FIRST key
    of the loaded mapping are iterated, and each entry must provide an
    'angle' and an 'image' item.

    Arguments
        -> path: the folder where the .json files are located, with or without
                 a trailing separator. A value that is not an existing
                 directory is used as a raw glob prefix (path + '*.json').
                 None (default) searches the current working directory.
        -> problem: the nature of the problem. 'classification' converts each
                    angle to a class index with angle_classifier; any other
                    value (None per default) keeps the raw angle (regression).
        -> n_classes: number of classes handed to angle_classifier
                      (only used for classification)

    Returns
        (img_db, angle_db): img_db is a 'double' array in which a singleton
        axis has been inserted after the first one (N, 1, dim1, dim2);
        angle_db is a list holding the matching angle / class of each image.

    Raises
        ValueError: if no entry at all could be collected.
    """
    import os  # local import: keeps this cell self-contained

    # search all given .json in a path
    if path is None:
        pattern = './*.json'
    elif os.path.isdir(path):
        # works whether or not the folder was given with a trailing separator
        pattern = os.path.join(path, '*.json')
    else:
        pattern = path + '*.json'

    # sorted so that the sample order does not depend on the file system
    file_list = sorted(glob(pattern))

    angle_db = []
    img_db = []

    # for each .json
    for count, a_file in enumerate(file_list, start = 1):
        data = json_loader(a_file)
        label = a_file

        # append each centriole and each corresponding angle to separate lists
        try:
            first_key = list(data.keys())[0]
            label = first_key

            for a_centriole in data[first_key]:
                # read (and convert) both fields BEFORE storing either one, so
                # a faulty entry cannot leave the two lists misaligned
                angle = a_centriole['angle']
                image = a_centriole['image']

                if problem == 'classification':
                    angle = angle_classifier(angle, n_classes)

                angle_db.append(angle)
                img_db.append(image)
        except Exception:
            # best effort: report the faulty file and carry on with the others
            data_keys = data.keys() if hasattr(data, 'keys') else None
            print('file note treated, data key: {}, file name: {}\n'.format(data_keys, a_file))

        print('File treated: {}/{} , Current file: {}\n'.format(count, len(file_list), label))

    if not img_db:
        raise ValueError('no centriole entry could be loaded ({} file(s) matched {})'.format(len(file_list), pattern))

    # transform and reshape the img dataset in array
    img_db = np.array(img_db, dtype = 'double')
    img_db = img_db.reshape(img_db.shape[0], 1, img_db.shape[1], img_db.shape[2])

    # return an array containing centrioles images and a list containing angle
    return img_db, angle_db

In [4]:
class centriole_dataset(Dataset):
    """Dataset pairing each centriole image with its angle label.

    Arguments:
        -> img_db: indexable collection of images
        -> angle_db: indexable collection of labels; its length is the dataset length
        -> transform: optional callable applied to each sample dict before it is returned
        -> root_dir: stored on the instance, not used by the class itself
        -> problem: 'classification' casts the label to 'int', any other value to 'double'
    """

    def __init__(self, img_db, angle_db, transform = None, root_dir = None, problem = None):
        self.img_db, self.angle_db = img_db, angle_db
        self.transform = transform
        self.root_dir = root_dir
        self.problem = problem

    def __len__(self):
        # the label collection defines how many samples are available
        return len(self.angle_db)

    def __getitem__(self, idx):
        # class indices are integers, regression targets are doubles
        label_dtype = 'int' if self.problem == 'classification' else 'double'

        sample = {
            'image': self.img_db[idx],
            'angle': np.array(self.angle_db[idx], dtype = label_dtype),
        }

        if not self.transform:
            return sample

        return self.transform(sample)

In [5]:
def dataset_creator(path_json = './data_json', batch_size = 700, n_class = 72, save_dataset = False, random_state = None):
    """ A function that creates the train / validation loaders from the appropriate json files.

    Arguments:
        -> path_json: path to the folder holding the json files
        -> batch_size: size of the batch
        -> n_class: number of classes for a classification problem;
                    n_class = 1 selects a regression problem instead
        -> save_dataset: when true, pickle both loaders into the ./data/ directory
        -> random_state: optional seed forwarded to train_test_split so the
                         split can be reproduced (None keeps a random split)

    Returns:
        (train_loader, validation_loader)

    The created loaders have these properties:
        test_size = 0.2,
        shuffle = True (for train)
                  False (for validation)
        drop_last = True (for both), so a split holding fewer than
                    batch_size samples yields no batch at all
    """
    import os  # local import: keeps this cell self-contained

    problem = 'classification' if n_class != 1 else 'regression'

    # dataset_aggregator builds its glob as path + '*.json'; an existing folder
    # given without a trailing separator (as in the default value) must
    # therefore get one, otherwise nothing inside the folder is matched.
    if os.path.isdir(path_json):
        path_json = os.path.join(path_json, '')

    img_db, angle_db = dataset_aggregator(path = path_json, problem = problem, n_classes = n_class)

    x_train, x_test, y_train, y_test = train_test_split(
        img_db, angle_db, test_size = 0.20, random_state = random_state)

    training = centriole_dataset(img_db = x_train, angle_db = y_train, problem = problem)
    testing = centriole_dataset(img_db = x_test, angle_db = y_test, problem = problem)

    train_loader = torch.utils.data.DataLoader(training, batch_size = batch_size, shuffle = True, drop_last = True)
    validation_loader = torch.utils.data.DataLoader(testing, batch_size = batch_size, shuffle = False, drop_last = True)

    if save_dataset:
        # the class count is only part of the file name for classification
        if problem == 'classification':
            tag = problem + '_n' + str(n_class) + '_b' + str(batch_size) + '_.pth'
        else:
            tag = problem + '_b' + str(batch_size) + '_.pth'

        train_name = './data/train_data_p' + tag
        val_name = './data/validation_data_p' + tag

        # make sure the target folder exists before writing into it
        os.makedirs('./data', exist_ok = True)

        # 'with' guarantees the files are flushed and closed, even on error
        with open(train_name, 'wb') as train_file:
            pickle.dump(train_loader, train_file, protocol = 4)
        with open(val_name, 'wb') as val_file:
            pickle.dump(validation_loader, val_file, protocol = 4)

    return train_loader, validation_loader

In [6]:
def dataset_loader(path = '../data/', train_set = 'train_loader_dataset_b700_unNormalized.pth', val_set = 'validation_loader_dataset_b700_unNormalized.pth' ):
    """ A function that loads two pickled datasets from disk.

    Arguments:
        -> path     : path to the dataset location; it is concatenated as-is
                      with the file names, so it must end with a separator
        -> train_set: name of the train dataset file
        -> val_set  : name of the validation dataset file

    Returns:
        (train_loader, validation_loader): the two unpickled objects

    Raises:
        OSError: if one of the files cannot be opened
    """
    # SECURITY: unpickling executes arbitrary code embedded in the file.
    # Only load files that were produced by a trusted source.
    with open(path + train_set, 'rb') as train_file:
        train_loader = pickle.load(train_file)

    # 'with' closes each handle as soon as its content has been read
    with open(path + val_set, 'rb') as val_file:
        validation_loader = pickle.load(val_file)

    return train_loader, validation_loader
