In [1]:
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from glob import glob
from PIL import Image, ImageEnhance
import os

In [None]:
# Crop box handed to PIL's Image.crop in MaskDataset._preprocess. Pillow reads
# the 4-tuple as (left, upper, right, lower), so cw/ch are the right/bottom
# edge coordinates in pixels, not a width/height.
cx, cy, cw, ch = 17, 41, 367, 471

# Dataset root; MaskDataset appends 'train' or 'test' to the root it is given.
dataroot = '/opt/ml/input/purified'

class MaskDataset(Dataset):
    """Image dataset read from ``<dataroot>/train`` or ``<dataroot>/test``.

    Train mode reads ``<dataroot>/train/<cls>/*.*`` for every class folder
    ``cls`` in ``range(n_class)`` and labels each image with the index
    (0, 1 or 2) of the group in ``_CLASS_GROUPS`` that contains ``cls``.
    Test mode reads the flat ``<dataroot>/test/*.*`` listing and yields
    images without labels.

    Every image is converted to RGB, cropped with the module-level box
    ``(cx, cy, cw, ch)`` and contrast/colour enhanced (see ``_preprocess``).
    In train mode the random augmentation pipeline ``self.transform`` is then
    applied, producing a 64x64 tensor.

    Args:
        dataroot: Directory containing ``train`` and/or ``test`` sub-folders.
        isTrain: True for the labelled train split, False for the test split.
        n_class: Number of class folders; must match the hard-coded grouping
            in ``_CLASS_GROUPS`` (18).
        test_transform: Optional callable applied to test images. The default
            (None) returns the preprocessed PIL image unchanged. Note that
            ``DataLoader``'s default collate function cannot batch PIL images;
            pass e.g. ``T.Compose([T.Resize((64, 64)), T.ToTensor()])`` to get
            tensors matching the train pipeline.

    Raises:
        ValueError: If the class groups do not partition ``range(n_class)``.
        FileNotFoundError: If no files are found under the selected split.
    """

    # Class folders grouped into the 3 labels emitted in train mode.
    _CLASS_GROUPS = (
        (0, 1, 2, 3, 4, 5),
        (6, 7, 8, 9, 10, 11),
        (12, 13, 14, 15, 16, 17),
    )

    def __init__(self, dataroot: str, isTrain: bool, n_class=18, *, test_transform=None):
        self.x = []
        self.y = []
        self.isTrain = isTrain
        self.test_transform = test_transform
        # Random augmentation pipeline, applied in train mode only.
        self.transform = T.Compose([
            T.RandomRotation((0, 15)),
            T.Resize((64, 64)),
            T.RandomAutocontrast(0.3),
            T.RandomHorizontalFlip(0.5),
            T.ToTensor(),
        ])

        if isTrain:
            dataroot = os.path.join(dataroot, 'train')
            label_of = {
                cls: group
                for group, members in enumerate(self._CLASS_GROUPS)
                for cls in members
            }
            n_listed = sum(len(members) for members in self._CLASS_GROUPS)
            # Each class folder must belong to exactly one group; this also
            # guarantees label_of[cls] exists for every cls looped over below.
            if n_listed != n_class or set(label_of) != set(range(n_class)):
                raise ValueError(
                    '[MaskDataset Exception]: class groups must partition '
                    f'range({n_class}) without duplicates'
                )

            for cls in range(n_class):
                # glob returns files in arbitrary order; sort for reproducible
                # dataset indices (and therefore reproducible splits).
                cls_paths = sorted(glob(os.path.join(dataroot, str(cls), '*.*')))
                self.x.extend(cls_paths)
                self.y.extend([label_of[cls]] * len(cls_paths))
        else:
            dataroot = os.path.join(dataroot, 'test')
            self.x = sorted(glob(os.path.join(dataroot, '*.*')))

        # A wrong root otherwise yields a silent zero-length dataset.
        if not self.x:
            raise FileNotFoundError(
                f'[MaskDataset Exception]: no image files found under {dataroot}'
            )

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label)`` in train mode, else the test image.

        In test mode the item is the preprocessed PIL image, passed through
        ``test_transform`` when one was supplied.
        """
        # Image.open is lazy and keeps the file handle open; convert() loads
        # the pixels, so the handle can be closed when the with-block exits.
        # RGB conversion keeps the channel count consistent across images.
        with Image.open(self.x[idx]) as img:
            X = img.convert('RGB')
        X = self._preprocess(X)

        if not self.isTrain:
            if self.test_transform is not None:
                X = self.test_transform(X)
            return X

        if self.transform:
            X = self.transform(X)
        return X, self.y[idx]

    def _preprocess(self, X: Image.Image) -> Image.Image:
        """Crop to the module-level box and apply a fixed enhancement.

        ``(cx, cy, cw, ch)`` is passed to ``Image.crop``, which reads it as
        ``(left, upper, right, lower)``. Contrast is then enhanced by a factor
        of 5 and colour by a factor of 0.8 (1.0 would leave it unchanged).
        """
        X = X.crop((cx, cy, cw, ch))
        X = ImageEnhance.Contrast(X).enhance(5)
        X = ImageEnhance.Color(X).enhance(0.8)
        return X