# Notebook for testing split equal function

In [1]:
from ..medmnist.info import INFO
from ..medmnist.dataset import PathMNIST, ChestMNIST, DermaMNIST, OCTMNIST, PneumoniaMNIST, RetinaMNIST, \
    BreastMNIST, OrganMNISTAxial, OrganMNISTCoronal, OrganMNISTSagittal

ImportError: attempted relative import with no known parent package

In [3]:
class PrepareMedMNIST:
    '''Build MedMNIST datasets and data loaders for the train/val/test splits.

    NOTE(review): methods of this class reference ``transforms`` (torchvision)
    and ``data`` (torch.utils) at call time, but the visible import cell does
    not import them -- make sure both are imported before instantiating.
    '''

    def __init__(self, input_args, dataset_info):
        '''
        Load the train/val/test splits, wrap each in a DataLoader and split the
        training set via ``self.splitDataset``.

        input_args (dict) -- run configuration:
            "data_name"         --> name of dataset                                     --> string
            "data_root"         --> folder of medmnist data                             --> string
            "output_root"       --> folder for saving results                           --> string
            "n_epochs"          --> n of epochs in training                             --> int
            "batch_size"        --> batch size                                          --> int
            "learning_rate"     --> learning rate of the optimizer                      --> float
            "momentum"          --> momentum of optimizer                               --> float
            "train_size"        --> percentage of data for training                     --> int
            "weight_decay"      --> weight decay                                        --> float
            "model"             --> used model architecture                             --> string
            "n_studentnets"     --> n of studentnet iterations in pseudolabeling        --> int
            "operation"         --> train a model (False) or make predictions (True)    --> boolean
            "task"              --> task: "Pseudolabel", "MTSS", "NoisyStudent", "Baseline" --> string
            "optimizer"         --> optimizer                                           --> string
            "LR_decay"          --> decay of learning rate: default 0                   --> float
            "LR_milestones"     --> milestones for lr decay                             --> int
            "Loss_function"     --> loss function                                       --> string
            "augmentations"     --> list of augmentations                               --> list
            "download"          --> download the data                                   --> boolean
            Of these, this class itself reads only "data_name", "data_root",
            "download", "batch_size", "model" and "augmentations".

        dataset_info (dict) -- dataset metadata (stored as self.dataset_info,
        not otherwise read by this class):
            "description"       --> description of dataset
            "url"               --> download url
            "MD5"               --> checksum of the download
            "task"              --> dataset task: "multi-class", "binary-class", "ordinal regression", "multi-label, binary-class"
            "label"             --> labels dictionary with class number and description
            "n_channels"        --> binary image (1), rgb image (3)
            "n_samples"         --> number of samples per split: "train": XX, "val": XX, "test": XX
            "license"           --> license for usage
        '''
        self.input_args = input_args
        self.dataset_info = dataset_info

        # create transformations of dataset
        # NOTE(review): the same (possibly random) augmentation pipeline is
        # applied to val/test as well; consider a separate eval transform.
        transform = self.createTransform(augmentations=self.input_args["augmentations"])

        # one dataset per split (previously all three loaded the 'train' split)
        self.dataset_train = self.prepareDataSet('train', transform)
        self.dataset_val = self.prepareDataSet('val', transform)
        self.dataset_test = self.prepareDataSet('test', transform)

        # only the training loader is shuffled; evaluation loaders iterate in
        # dataset order so their outputs are reproducible
        self.dataloader_train = self.createDataLoader(self.dataset_train)
        self.dataloader_val = self.createDataLoader(self.dataset_val, shuffle=False)
        self.dataloader_test = self.createDataLoader(self.dataset_test, shuffle=False)

        # split the training set into labeled/unlabeled parts
        # NOTE(review): splitDataset is not defined in this class; it must be
        # provided elsewhere (e.g. a later notebook cell or a subclass),
        # otherwise this line raises AttributeError.
        (self.train_dataset_labeled, self.train_subset_labeled,
         self.train_dataset_unlabeled, self.train_subset_unlabeled) = \
            self.splitDataset(self.dataset_train, self.dataloader_train)

    def createTransform(self, image_size=32, augmentations=None):
        '''Compose the image transform: ToTensor, optional resize, augmentations.

        image_size    -- currently unused; kept for interface compatibility.
        augmentations -- list of augmentation names ("centerCrop",
                         "colorJitter", "gaussianBlur", "normalize",
                         "randomHorizontalFlip", "randomVerticalFlip",
                         "randomRotation"). Defaults to
                         input_args["augmentations"] when None.

        Unknown names are skipped with a warning. EfficientNet models
        ("EfficientNet-b0/b1/b7") additionally get Resize(256) right after
        ToTensor.
        '''
        import warnings

        if augmentations is None:
            augmentations = self.input_args["augmentations"]

        aug_values = {
            "CenterCrop"   : {"size": 10},
            "ColorJitter"  : {"brightness": 0, "contrast": 0, "saturation": 0, "hue": 0},
            "GaussianBlur" : {"kernel": [3, 3], "sigma": 0.1},
            "Normalize"    : {"mean": [0.5], "std": [0.5]},
            "RandomHorizontalFlip" : {"probability": 0.5},
            "RandomVerticalFlip" : {"probability": 0.5},
            "RandomRotation" : {"degrees": [-20, 20]},
        }

        # map the names accepted in `augmentations` to transform factories
        builders = {
            "centerCrop": lambda: transforms.CenterCrop(
                aug_values["CenterCrop"]["size"]),
            "colorJitter": lambda: transforms.ColorJitter(
                brightness=aug_values["ColorJitter"]["brightness"],
                contrast=aug_values["ColorJitter"]["contrast"],
                saturation=aug_values["ColorJitter"]["saturation"],
                hue=aug_values["ColorJitter"]["hue"]),
            "gaussianBlur": lambda: transforms.GaussianBlur(
                kernel_size=aug_values["GaussianBlur"]["kernel"],
                sigma=aug_values["GaussianBlur"]["sigma"]),
            "normalize": lambda: transforms.Normalize(
                mean=aug_values["Normalize"]["mean"],
                std=aug_values["Normalize"]["std"]),
            "randomHorizontalFlip": lambda: transforms.RandomHorizontalFlip(
                p=aug_values["RandomHorizontalFlip"]["probability"]),
            "randomVerticalFlip": lambda: transforms.RandomVerticalFlip(
                p=aug_values["RandomVerticalFlip"]["probability"]),
            "randomRotation": lambda: transforms.RandomRotation(
                degrees=aug_values["RandomRotation"]["degrees"]),
        }

        transform_list = [transforms.ToTensor()]
        if self.input_args["model"] in ("EfficientNet-b0", "EfficientNet-b1", "EfficientNet-b7"):
            transform_list.append(transforms.Resize(256))

        for aug in augmentations:
            builder = builders.get(aug)
            if builder is None:
                warnings.warn(f"augmentation {aug!r} not found; skipping it")
                continue
            transform_list.append(builder())

        return transforms.Compose(transform_list)

    def prepareDataSet(self, split, transform):
        '''Instantiate the MedMNIST dataset selected by input_args["data_name"].

        split     -- split name passed to the dataset class
                     (this class uses 'train', 'val' and 'test').
        transform -- transform applied to every image.

        Raises KeyError with a descriptive message when input_args["data_name"]
        is not one of the supported flags.
        '''
        flag_to_class = {
            "pathmnist": PathMNIST,
            "chestmnist": ChestMNIST,
            "dermamnist": DermaMNIST,
            "octmnist": OCTMNIST,
            "pneumoniamnist": PneumoniaMNIST,
            "retinamnist": RetinaMNIST,
            "breastmnist": BreastMNIST,
            "organmnist_axial": OrganMNISTAxial,
            "organmnist_coronal": OrganMNISTCoronal,
            "organmnist_sagittal": OrganMNISTSagittal,
        }
        data_name = self.input_args["data_name"]
        try:
            DataClass = flag_to_class[data_name]
        except KeyError:
            raise KeyError(
                f"unknown data_name {data_name!r}; expected one of {sorted(flag_to_class)}"
            ) from None

        return DataClass(root=self.input_args["data_root"],
                         split=split,
                         transform=transform,
                         download=self.input_args["download"])

    def createDataLoader(self, dataset, shuffle=True):
        '''Wrap `dataset` in a DataLoader with batch size input_args["batch_size"].

        shuffle -- reshuffle the data every epoch (default True, as before).
        '''
        return data.DataLoader(dataset=dataset,
                               batch_size=self.input_args["batch_size"],
                               shuffle=shuffle)

In [None]:
# NOTE(review): PrepareMedMNIST.__init__ requires two positional arguments
# (input_args, dataset_info); called with none, this line raises TypeError.
prepareData = PrepareMedMNIST()