In [3]:
from glob import escape as glob_escape, glob
from os.path import basename as path_basename
from os.path import join as path_join
from os.path import splitext as path_splitext
from typing import List, Optional, Tuple, Union

from numpy import load as np_load
from torch.utils.data import DataLoader, Dataset

class _TransformChain:
    """Callable that applies a list of transforms left to right.

    Used by ``MoDataSet`` when ``transform_dict`` is given as a list. It is a
    module-level class rather than a closure, so a dataset holding it stays
    picklable as long as the wrapped transforms are.
    """

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, value):
        for transform in self.transforms:
            value = transform(value)
        return value


class DataAccess(Dataset):
    """Map-style dataset over index-aligned, in-memory sequences ``X`` and ``Y``.

    ``dataset[i]`` returns ``{'x': X[i], 'y': Y[i]}`` after applying the
    optional callables ``transforms['x']`` / ``transforms['y']``.

    The class is defined at module level, not inside ``MoDataSet.__init__``.
    It is therefore created once, and its instances can be pickled, e.g. by
    ``DataLoader`` workers started with ``spawn``.
    NOTE(review): under ``spawn``, classes defined in a notebook's ``__main__``
    may still need to live in a ``.py`` module; confirm for your platform.
    """

    def __init__(self, X, Y, transforms: Optional[dict] = None):
        self.X = X
        self.Y = Y
        # None rather than a shared mutable ``{}`` default.
        self.transforms = {} if transforms is None else transforms

    def _load(self, item):
        """Materialise one stored entry; identity for in-memory data."""
        return item

    def __getitem__(self, index):
        data_x = self._load(self.X[index])
        data_y = self._load(self.Y[index])
        if 'x' in self.transforms:
            data_x = self.transforms['x'](data_x)
        if 'y' in self.transforms:
            data_y = self.transforms['y'](data_y)
        return {'x': data_x, 'y': data_y}

    def __len__(self):
        return len(self.Y)

    def to_loader(self, **args):
        """Wrap this dataset in a ``DataLoader``; keyword arguments are forwarded."""
        return DataLoader(self, **args)


class DataByPath(DataAccess):
    """``DataAccess`` whose ``X``/``Y`` entries are file paths read with ``numpy.load``."""

    def __init__(self, x_paths, y_paths, transforms: Optional[dict] = None):
        super().__init__(x_paths, y_paths, transforms=transforms)

    # ``x_paths`` / ``y_paths`` are kept as aliases of ``X`` / ``Y`` so the
    # attribute names of the original path-based dataset still work.
    @property
    def x_paths(self):
        return self.X

    @x_paths.setter
    def x_paths(self, value):
        self.X = value

    @property
    def y_paths(self):
        return self.Y

    @y_paths.setter
    def y_paths(self, value):
        self.Y = value

    def _load(self, item):
        return np_load(item)


class MoDataSet:
    """Build train / test / eval torch datasets from paired X and Y data.

    Two input modes:

    * ``X_data_or_dir`` / ``Y_data_or_dir`` are sequences. Items are served
      directly from memory (``DataAccess``).
    * Both are directory names, joined onto ``root``. Every file in the X
      directory matching ``*<x_format>`` is paired with the Y-directory file
      of the same stem and extension ``y_format``. Both are loaded with
      ``numpy.load`` (``DataByPath``). If ``y_format`` contains glob wildcards
      (the default ``'*'``), the Y file is located with ``glob``.

    Args:
        X_data_or_dir, Y_data_or_dir: input sequences, or directory names.
        root: prefix joined in front of both directory names.
        x_format, y_format: file extensions, with or without the leading dot.
        transform_dict: one of the following.

            * ``{'train'|'test'|'eval': {'x'|'y': callable}}``. A missing split
              or variable key means "no transform".
            * A list/tuple of callables, applied in order to both ``x`` and
              ``y`` of every split.
            * ``None``, meaning no transforms.
        test_size: ``test_size`` argument of sklearn's ``train_test_split`` for
            the test split; ``0`` disables it.
        eval_size: ``test_size`` of a second split taken from the *remaining*
            training data (after the test split); ``0`` disables it.
        **arg: forwarded to both ``train_test_split`` calls (e.g.
            ``random_state=0`` for reproducible splits).
            NOTE(review): a ``stratify=`` array is reused unchanged for the
            eval split, where its length no longer matches the training data.

    Attributes:
        TrainDataset, TestDataset, EvalDataset: the datasets. A disabled split
            stays ``None``.

    Raises:
        AssertionError: X is a directory name but Y is not.
        ValueError: no X file matched in directory mode.
    """

    _SPLITS = ('train', 'test', 'eval')

    def __init__(self,
                 X_data_or_dir: Union[str, list],
                 Y_data_or_dir: Union[str, list],
                 root: str = '',
                 x_format: str = '*',
                 y_format: str = '*',
                 transform_dict: Optional[Union[dict, list]] = None,
                 test_size: float = 0.2,
                 eval_size: float = 0, **arg):
        self.TrainDataset, self.TestDataset, self.EvalDataset = None, None, None
        transforms = self._normalize_transforms(transform_dict)
        x_format = self._as_extension(x_format)
        y_format = self._as_extension(y_format)

        from_files = isinstance(X_data_or_dir, str)
        if from_files:
            assert isinstance(Y_data_or_dir, str), 'the type of input X and Y must be same'
            X_train, Y_train = self._pair_files(root, X_data_or_dir, Y_data_or_dir,
                                                x_format, y_format)
        else:
            X_train, Y_train = X_data_or_dir, Y_data_or_dir

        if test_size > 0 or eval_size > 0:
            # Imported lazily: sklearn is only needed when a split is requested.
            from sklearn.model_selection import train_test_split
        if test_size > 0:
            X_train, X_test, Y_train, Y_test = train_test_split(
                X_train, Y_train, test_size=test_size, **arg)
        if eval_size > 0:
            X_train, X_eval, Y_train, Y_eval = train_test_split(
                X_train, Y_train, test_size=eval_size, **arg)

        dataset_cls = DataByPath if from_files else DataAccess
        self.TrainDataset = dataset_cls(X_train, Y_train, transforms=transforms['train'])
        if test_size > 0:
            self.TestDataset = dataset_cls(X_test, Y_test, transforms=transforms['test'])
        if eval_size > 0:
            self.EvalDataset = dataset_cls(X_eval, Y_eval, transforms=transforms['eval'])

    @classmethod
    def _normalize_transforms(cls, transform_dict):
        """Return ``{split: {'x'|'y': callable}}`` with an entry for every split."""
        if transform_dict is None:
            return {split: {} for split in cls._SPLITS}
        if isinstance(transform_dict, (list, tuple)):
            # A bare list is not callable; chain it once and share it for x and y.
            chain = _TransformChain(transform_dict)
            return {split: {'x': chain, 'y': chain} for split in cls._SPLITS}
        return {split: transform_dict.get(split) or {} for split in cls._SPLITS}

    @staticmethod
    def _as_extension(fmt: str) -> str:
        """Prefix ``fmt`` with '.' unless it already has one (an empty string stays empty)."""
        return fmt if not fmt or fmt.startswith('.') else '.' + fmt

    @staticmethod
    def _pair_files(root: str, x_dir: str, y_dir: str,
                    x_ext: str, y_ext: str) -> Tuple[List[str], List[str]]:
        """List the X files in ``root/x_dir`` and the paired Y path for each."""
        x_pattern = path_join(root, x_dir, '*' + x_ext)
        # glob order is arbitrary; sort so train_test_split(random_state=...)
        # yields the same split on every run.
        x_paths = sorted(glob(x_pattern))
        if not x_paths:
            raise ValueError('no files matching {!r}'.format(x_pattern))
        y_has_wildcard = any(ch in y_ext for ch in '*?[')
        y_paths = []
        for x_path in x_paths:
            stem = path_splitext(path_basename(x_path))[0]
            y_path = path_join(root, y_dir, stem + y_ext)
            if y_has_wildcard:
                # Escape the literal parts so only the extension pattern expands,
                # and keep only files whose stem is exactly ``stem``.
                y_pattern = path_join(glob_escape(root), glob_escape(y_dir),
                                      glob_escape(stem) + y_ext)
                matches = [m for m in sorted(glob(y_pattern))
                           if path_splitext(path_basename(m))[0] == stem]
                if matches:
                    y_path = matches[0]
            y_paths.append(y_path)
        return x_paths, y_paths


# Demo: split a small in-memory dataset and iterate over the training batches.
# random_state is forwarded to train_test_split so the split (and the printed
# batches) stay the same after a kernel restart.
mm = MoDataSet(X_data_or_dir=[10, 20, 30, 40, 70, 80, 90, 110, 40, 50, 60, 410],
               Y_data_or_dir=[7, 8, 9, 6, 2, 3, 4, 7, 8, 44, 12, 14],
               random_state=0)
for batch_idx, batch in enumerate(mm.TrainDataset.to_loader(batch_size=4)):
    print(batch_idx, batch)

0 {'x': tensor([ 30, 410,  10,  40]), 'y': tensor([ 9, 14,  7,  8])}
1 {'x': tensor([40, 80, 90, 50]), 'y': tensor([ 6,  3,  4, 44])}
2 {'x': tensor([70]), 'y': tensor([2])}


In [None]:
class Compose:
    """Chain several transforms into one callable.

    Calling the instance passes the input through each transform in order
    and returns the last result. Mirrors the interface of
    ``torchvision.transforms.Compose``, but is a plain Python class.

    Args:
        transforms (list of callables): transforms to apply, first to last.

    Example:
        >>> pipeline = Compose([lambda v: v + 1, lambda v: v * 2])
        >>> pipeline(3)
        8
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result

    def __repr__(self):
        # One indented line per transform, closing paren on its own line.
        body = ''.join(f'\n    {transform}' for transform in self.transforms)
        return f'{self.__class__.__name__}(' + body + '\n)'


In [4]:
# Scratch check: expand a list of transforms into the per-split,
# per-variable dict layout (the same list object for every entry).
transform_list = ['it', 'is']
transform_dict = {
    split: dict.fromkeys(('x', 'y'), transform_list)
    for split in ('train', 'test', 'eval')
}
print(transform_dict)

{'train': {'x': ['it', 'is'], 'y': ['it', 'is']}, 'test': {'x': ['it', 'is'], 'y': ['it', 'is']}, 'eval': {'x': ['it', 'is'], 'y': ['it', 'is']}}
