In [1]:
import os
import glob
import matplotlib.pylab as plt
import matplotlib as mpl
import numpy as np
from torchvision import transforms
import torch
import cv2
from skimage.io import imread
from skimage.transform import resize
from torch.utils.data import Dataset

In [2]:
class Cases:
    """Index of per-case BEFORE/AFTER image folders under a preprocessing root.

    Expected on-disk layout (as implied by the glob patterns used below)::

        <path>/<case>/BEFORE/*.JPG
        <path>/<case>/AFTER/*.JPG

    Every immediate sub-directory of ``path`` is treated as one case. On
    construction, all BEFORE and AFTER images are read eagerly with
    ``imread`` into the ``before`` / ``after`` dicts, keyed by case name.
    """

    # Glob for image files inside a BEFORE/AFTER folder. fnmatch is
    # case-sensitive on POSIX, so e.g. ``*.jpg`` files are not picked up.
    IMAGE_GLOB = '*.JPG'

    def __init__(self, path='../data/preprocess/'):
        self.path = path
        # Take the directory entry's own name instead of slicing ``entry.path``
        # at a fixed offset (a fixed slice only works for a path of exactly
        # that length). Sorted so case order is reproducible across runs.
        with os.scandir(path) as entries:
            self.cases = sorted(entry.name for entry in entries if entry.is_dir())
        self.before_dict = self._fill_dicts()
        self.after_dict = self._fill_dicts(kind='AFTER')

    def _image_paths(self, case, kind) -> list:
        """Return the sorted image file paths in ``<path>/<case>/<kind>/``."""
        # Escape the caller-supplied parts so characters such as ``[``, ``*``
        # or ``?`` in a directory name are matched literally, not as glob syntax.
        pattern = os.path.join(glob.escape(self.path), glob.escape(case), kind, self.IMAGE_GLOB)
        return sorted(glob.glob(pattern))

    def _fill_dicts(self, kind='BEFORE') -> dict:
        """Load every image of one ``kind`` ('BEFORE' or 'AFTER') for each case.

        Returns:
            dict mapping case name -> list of images returned by ``imread``,
            in sorted file-path order. A case with no matching files maps to [].
        """
        return {
            case: [imread(file) for file in self._image_paths(case, kind)]
            for case in self.cases
        }

    @property
    def before(self) -> dict:
        """Loaded BEFORE images, keyed by case name."""
        return self.before_dict

    @property
    def after(self) -> dict:
        """Loaded AFTER images, keyed by case name."""
        return self.after_dict

    def get_array(self) -> list:
        """Return the file paths of all images, case by case.

        For each case (in ``self.cases`` order) the BEFORE paths come first,
        then the AFTER paths. Despite the method name, the result is a plain
        ``list`` of path strings, not a numpy array.
        """
        paths = []
        for case in self.cases:
            paths += self._image_paths(case, 'BEFORE')
            paths += self._image_paths(case, 'AFTER')
        return paths
    
class Data(Dataset):
    """Map-style PyTorch dataset that loads RGB images from a list of file paths.

    Args:
        data: indexable sequence of image file paths.
        transform: optional callable invoked as ``transform(image=img)`` whose
            result is indexed with ``['image']`` (albumentations-style calling
            convention, as used below — confirm against the pipeline passed in).
    """

    def __init__(self, data, transform = None):
        self.data = data
        self.transform = transform

    def __len__(self):
        """Number of image paths in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load the image at ``idx`` as an RGB array and apply ``transform`` if set.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file (``cv2.imread``
                signals failure by returning ``None`` rather than raising).
        """
        path = self.data[idx]
        # The second argument of ``cv2.imread`` is an IMREAD_* flag, not a
        # colour-conversion code: passing ``cv2.COLOR_BGR2RGB`` (value 4, i.e.
        # IMREAD_ANYCOLOR) left pixels in BGR order. Read as 3-channel BGR, then
        # convert to RGB explicitly.
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        if image is None:
            raise FileNotFoundError(f"cv2.imread could not read image: {path!r}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # augmentations
        if self.transform is not None:
            image = self.transform(image=image)['image']

        return image