In [None]:
import os
import _pickle as pickle 
import torch
import numpy as np
from multiprocessing import Pool

class Data:
    """In-memory image dataset built from folders of pickled arrays.

    Each folder is one class: the folder's name is the label of every file in
    it, unless ``replace_classes`` remaps that name. After construction:

    * ``self.data`` is a list of single-entry dicts ``{label: tensor}``. The
      tensor is the unpickled payload passed through ``torch.Tensor`` with a
      leading dimension of size 1 added.
    * ``self.label_dict`` maps every distinct label to an integer index
      (labels taken in sorted order).

    ``dataset[i]`` returns ``(tensor, label_index)`` and ``len(dataset)`` the
    number of loaded files.

    SECURITY: files are read with ``pickle.load``, which can execute arbitrary
    code. Only point this class at folders you trust.
    """

    def __init__(self, path_to_pickle_folders, replace_classes=None, maximum_per_folder=None, num_workers=None):
        """Load every pickle in the given folders and convert them to tensors.

        Args:
            path_to_pickle_folders: one folder path, or a list/tuple of them.
            replace_classes: optional ``{folder_name: label}`` mapping used to
                rename (and thereby merge) classes. Folders not listed keep
                their own name as label.
            maximum_per_folder: load at most this many files per folder
                (``None`` loads all). Files are taken in sorted name order so
                the subset is reproducible.
            num_workers: size of the multiprocessing pool used per folder.
                ``None`` (default) lets ``Pool`` choose; ``0`` loads in the
                current process without spawning workers.
        """
        if not isinstance(path_to_pickle_folders, (list, tuple)):
            path_to_pickle_folders = [path_to_pickle_folders]
        if replace_classes is None:
            replace_classes = {}

        # Accumulate into a local list and publish it at the end: the bound
        # method handed to the pool is pickled together with `self`, so keeping
        # self.data empty while loading avoids re-sending everything loaded so
        # far to the workers.
        self.data = []
        loaded = []

        for folder in path_to_pickle_folders:
            # normpath drops a trailing separator, which would otherwise make
            # basename() return an empty label.
            folder_name = os.path.basename(os.path.normpath(folder))
            print("Unpacking", folder_name)
            working_label = replace_classes.get(folder_name, folder_name)

            for payload in self._loadFolder(folder, maximum_per_folder, num_workers):
                loaded.append({working_label: payload})

        self.data = loaded
        self.convetLabels()
        self.dataToTensor()

    def _loadFolder(self, folder, maximum_per_folder, num_workers):
        """Unpickle the files of one folder and return the readable payloads."""
        # os.listdir returns names in arbitrary order; sort so that
        # maximum_per_folder always selects the same files.
        files_to_unpickle = [os.path.join(folder, name) for name in sorted(os.listdir(folder))]
        files_to_unpickle = files_to_unpickle[:maximum_per_folder]

        if num_workers == 0 or not files_to_unpickle:
            results = [self.parsePickle(path) for path in files_to_unpickle]
        else:
            pool = Pool(num_workers)
            try:
                results = pool.map(self.parsePickle, files_to_unpickle)
            finally:
                # Release the workers even when map() raises.
                pool.close()
                pool.join()

        # parsePickle signals an unreadable file with None. Test identity, not
        # truthiness: arrays either raise on bool() or may be legitimately falsy.
        payloads = [result for result in results if result is not None]
        skipped = len(results) - len(payloads)
        if skipped:
            print("Skipped", skipped, "unreadable file(s) in", folder)
        return payloads

    def dataToTensor(self):
        """Replace every payload in ``self.data`` by a tensor with a leading size-1 dimension.

        Not idempotent: calling it again adds another leading dimension.
        """
        for position, data_dict in enumerate(self.data):
            label, array = next(iter(data_dict.items()))
            self.data[position] = {label: torch.Tensor(array).unsqueeze(0)}

    def convetLabels(self):
        """Build ``self.label_dict``, mapping each distinct label to its index in sorted order.

        (The method name is misspelled; it is kept as-is for existing callers.)
        """
        all_labels = [next(iter(data_dict)) for data_dict in self.data]
        self.label_dict = {label: index for index, label in enumerate(sorted(set(all_labels)))}

    def parsePickle(self, path_to_pickle):
        """Return the object stored in ``path_to_pickle``, or ``None`` if it cannot be read.

        Loading is best-effort: unreadable or corrupt files are skipped by the
        caller rather than aborting the whole load.
        """
        try:
            # NOTE(security): pickle.load can run arbitrary code; trusted files only.
            with open(path_to_pickle, 'rb') as f:
                return pickle.load(f)
        except Exception:
            return None

    def __getitem__(self, idx):
        '''Return ``(img, label)``: the tensor at ``idx`` and its integer label index.'''
        word_label, img = next(iter(self.data[idx].items()))
        return img, self.label_dict[word_label]

    def __len__(self):
        """Number of loaded samples."""
        return len(self.data)

## Example: load two classes (any hemorrhage vs. no hemorrhage)

In [66]:
%%time

# One folder per class; the folder name is used as the class label.
training_folders = [
    "Processed/train/epidural",
    "Processed/train/intraparenchymal",
    "Processed/train/subarachnoid",
    "Processed/train/intraventricular",
    "Processed/train/subdural",
    "Processed/train/nohem",
]

# Relabel every hemorrhage subtype as "any"; "nohem" keeps its own name,
# leaving two classes. Only 10 files per folder are loaded.
train_data = Data(
    training_folders,
    replace_classes=dict.fromkeys(
        ("epidural", "intraparenchymal", "subarachnoid", "intraventricular", "subdural"),
        "any",
    ),
    maximum_per_folder=10,
)

Unpacking epidural
Unpacking intraparenchymal
Unpacking subarachnoid
Unpacking intraventricular
Unpacking subdural
Unpacking nohem
CPU times: user 2.43 s, sys: 3.44 s, total: 5.88 s
Wall time: 8.29 s


In [67]:
train_data[-1][0].shape

torch.Size([1, 512, 512])

## Example: load the 5 hemorrhage subtype classes

In [None]:
# No replace_classes here, so each folder name stays its own label.
training_folders = [
    "Processed/train/epidural",
    "Processed/train/intraparenchymal",
    "Processed/train/subarachnoid",
    "Processed/train/intraventricular",
    "Processed/train/subdural",
]

train_data = Data(path_to_pickle_folders=training_folders)