In [None]:
%run include.ipynb
%run FileIO.ipynb
%run Medical_IO.ipynb
import glob
import cv2
from scipy import signal
from torch.utils.data.sampler import SubsetRandomSampler

class Data_topo(torch.utils.data.Dataset):
    """Dataset of grayscale images, each paired with a companion ``.dat`` path.

    Every ``*.<target_type>`` file directly under ``root`` is one sample. Its
    companion path is ``os.path.join(root_pds, <file name> + ".dat")``; that path
    is only returned by ``__getitem__``, never opened here.

    Side effect: the constructor changes the process working directory to ``root``.
    """

    def __init__(self, root, root_pds, target_type, transform=None):
        """
        @root: folder holding the images
        @root_pds: folder holding the companion ``.dat`` files (joined as given)
        @target_type: image file extension without the dot, e.g. "png"
        @transform: optional callable applied to each uint8 image
        Raises FileNotFoundError if no matching image exists, and IOError if the
        first image cannot be decoded.
        """
        self.root             = root
        self.transform        = transform
        self.address_book     = []
        self.address_book_pds = []
        # Resolve root BEFORE the chdir: joining a relative root after
        # os.chdir(root) produced root/root/<file>, which does not exist.
        root_abs = os.path.abspath(root)
        # The chdir itself is kept because other code may rely on the working
        # directory being the data folder after construction.
        os.chdir(root_abs)
        # glob returns matches in arbitrary order; sorting makes the sample order
        # (and therefore any seeded train/test split) reproducible.
        for file in sorted(glob.glob("*." + target_type)):
            self.address_book.append(os.path.join(root_abs, file))
            self.address_book_pds.append(os.path.join(root_pds, file + ".dat"))
        if not self.address_book:
            raise FileNotFoundError("No *.%s files found in %s" % (target_type, root_abs))
        # Print a summary of the first image as a quick sanity check.
        img_tease = self._read_image(self.address_book[0])
        print("Image shape: " + str(img_tease.shape))
        print("Image value range: %.2f - %.2f" %(np.amin(img_tease), np.amax(img_tease)))
        print("Image data type" + str(type(img_tease[0][0])))
        print("Required data type is np.uint8")

    @staticmethod
    def _read_image(path):
        """Read ``path`` as a single-channel image, raising instead of returning None."""
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:  # cv2.imread reports unreadable/missing files by returning None
            raise IOError("cv2 could not read image: %s" % path)
        return img

    def __len__(self):
        """Number of images found at construction time."""
        return len(self.address_book)

    def __getitem__(self, idx):
        """Return ``{'image': <uint8 image, transformed if set>, 'pd_path': <.dat path>}``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img = np.uint8(self._read_image(self.address_book[idx]))
        pd_path = self.address_book_pds[idx]
        if self.transform:
            img = self.transform(img)

        instance = {'image': img, 'pd_path': pd_path}
        return instance
    
class NII_with_label(torch.utils.data.Dataset):
    """Dataset of NIfTI volumes, each with an integer label parsed from its file name.

    Every ``*.nii`` file directly under ``root`` is one sample. Its label is the
    third ``'_'``-separated field of the file name.

    Side effect: the constructor changes the process working directory to ``root``.
    """

    def __init__(self, root, target_type, transform=None):
        """
        @root: folder holding the volumes
        @target_type: file extension without the dot; only 'nii' is supported
        @transform: optional callable applied to each float32 volume
        Raises ValueError for an unsupported target_type or an unparsable label,
        and FileNotFoundError if no matching file exists.
        """
        # Only 'nii' can be loaded below. Fail before touching the filesystem
        # (previously an unknown type crashed later with an unbound `tease`).
        if target_type != 'nii':
            raise ValueError("Data_with_label: unrecognized data type %r (only 'nii' is supported)"
                             % (target_type,))
        self.root = root
        self.transform = transform
        self.address_book = []
        self.labels = []
        # NOTE(review): presumably a Gaussian smoothing kernel (see Util_gen); it is
        # only referenced by the disabled convolution in __getitem__.
        self.kernel = Util_gen.generate_gaussian_kernel(3, 2, 3)

        # Resolve root BEFORE the chdir: joining a relative root after
        # os.chdir(root) produced root/root/<file>, which does not exist.
        root_abs = os.path.abspath(root)
        # The chdir is kept for callers that may rely on the working directory.
        os.chdir(root_abs)
        # Sorted so the sample order (and any seeded split) is reproducible.
        for file in sorted(glob.glob("*." + target_type)):
            self.address_book.append(os.path.join(root_abs, file))
            self.labels.append(self._parse_label(file))
        if not self.address_book:
            raise FileNotFoundError("No *.%s files found in %s" % (target_type, root_abs))
        # Print a summary of the first volume as a quick sanity check.
        tease = FileIO_MEDICAL.load_nii(self.address_book[0])
        print("Image shape: " + str(tease.shape))
        print("Image value range: %.2f - %.2f" %(np.amin(tease), np.amax(tease)))
        print("Image data type" + str(type(tease[0][0][0])))

    @staticmethod
    def _parse_label(filename):
        """Return the integer third '_' field of ``filename`` (e.g. 'a_b_3_c.nii' -> 3)."""
        try:
            return int(filename.split('_')[2])
        except (IndexError, ValueError) as err:
            raise ValueError("Cannot parse an integer label from the third '_' field of %r"
                             % (filename,)) from err

    def __len__(self):
        """Number of volumes found at construction time."""
        return len(self.address_book)

    def __getitem__(self, idx):
        """Return ``{'vol': <float32 volume, transformed if set>, 'label': <int>}``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        vol = np.float32(FileIO_MEDICAL.load_nii(self.address_book[idx]))
        # NOTE: smoothing with self.kernel is currently disabled:
        #vol = signal.convolve(vol, self.kernel, mode='same')
        if self.transform:
            vol = self.transform(vol)
        instance = {'vol': vol, 'label': self.labels[idx]}
        return instance
    
class Data_dmt(torch.utils.data.Dataset):
    """Dataset of binary matrices read with ``FileIO.read_matrix_binary(path, 'i')``.

    Each sample is shifted by -1 and rescaled so that 0 maps to -1 and the
    sample's maximum maps to 1 (see ``_normalize``).

    Side effect: the constructor changes the process working directory to ``root``.
    """

    def __init__(self, root, target_type, transform=None):
        """
        @root: folder holding the matrix files
        @target_type: file extension without the dot
        @transform: optional callable applied to each normalized float32 matrix
        Raises FileNotFoundError if no matching file exists.
        """
        self.root             = root
        self.transform        = transform
        self.address_book     = []
        # Resolve root BEFORE the chdir: joining a relative root after
        # os.chdir(root) produced root/root/<file>, which does not exist.
        root_abs = os.path.abspath(root)
        # The chdir is kept for callers that may rely on the working directory.
        os.chdir(root_abs)
        # Sorted so the sample order (and any seeded split) is reproducible.
        for file in sorted(glob.glob("*." + target_type)):
            self.address_book.append(os.path.join(root_abs, file))
        if not self.address_book:
            raise FileNotFoundError("No *.%s files found in %s" % (target_type, root_abs))
        # Print a summary of the first matrix as a quick sanity check.
        img_tease = FileIO.read_matrix_binary(self.address_book[0], 'i')
        print("Image shape: " + str(img_tease.shape))
        print("Image value range: %.2f - %.2f" %(np.amin(img_tease), np.amax(img_tease)))
        print("Image data type" + str(type(img_tease[0][0])))
        print("Required data type is np.int")

    @staticmethod
    def _normalize(raw):
        """Cast to float32, subtract 1, then map 0 -> -1 and the maximum -> 1."""
        img = np.float32(raw) - 1.0
        half = np.amax(img) / 2.0
        # A maximum of 0 would divide by zero and fill the sample with NaN/inf;
        # in that case the shifted values are returned unscaled.
        if half == 0:
            return img
        return (img - half) / half

    def __len__(self):
        """Number of matrices found at construction time."""
        return len(self.address_book)

    def __getitem__(self, idx):
        """Return the normalized matrix; transformed and tanh-squashed if a transform is set."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img = self._normalize(FileIO.read_matrix_binary(self.address_book[idx], 'i'))
        if self.transform:
            img = self.transform(img)
            # NOTE(review): tanh is applied only on the transform path; without a
            # transform the normalized numpy array is returned as-is. Confirm intended.
            img = torch.tanh(img)
        return img
    
class Data_fetcher(object):
    """Build one of the supported datasets and wrap it in DataLoader(s).

    Both public entry points share the same dataset factory (``_build_dataset``)
    and differ only in how the sample indices are partitioned.
    """

    @staticmethod
    def _build_dataset(name, data_path, scalor):
        # Construct dataset `name` with its fixed preprocessing pipeline. "topo",
        # "dmt" and "nii" read the global FLAGS defined elsewhere in the project.
        if name == "cifar10":
            return dset.CIFAR10(root=data_path, download=True,
                                transform=transforms.Compose([
                                    transforms.Resize([64, 64]),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                ]))
        if name == "celeba":  # The data should be under a folder under root: root/celeba/*.png
            return dset.ImageFolder(root=data_path,
                                    transform=transforms.Compose([
                                        transforms.Resize(64),
                                        transforms.CenterCrop(64),
                                        transforms.ToTensor(),
                                        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                    ]))
        if name == "topo":
            return Data_topo(data_path, FLAGS.pds_path, FLAGS.data_extension,
                             transform=transforms.Compose([
                                 transforms.ToPILImage(),
                                 transforms.ToTensor(),
                                 transforms.Normalize([scalor], [scalor]),
                             ]))
        if name == "dmt":
            return Data_dmt(data_path, FLAGS.data_extension,
                            transform=transforms.Compose([
                                transforms.ToPILImage(),
                                transforms.ToTensor(),
                            ]))
        if name == "nii":
            return NII_with_label(data_path, FLAGS.data_extension,
                                  transform=transforms.Compose([
                                      transforms.ToTensor(),
                                  ]))
        raise NotImplementedError('Unrecognized dataset %s' % name)

    @staticmethod
    def _shuffled_indices(dataset_size, shuffle, random_seed):
        # Return [0, dataset_size), permuted when `shuffle`; a None/negative seed
        # means unseeded (np.random.seed(-1) used to raise ValueError here).
        indices = list(range(dataset_size))
        if shuffle:
            if random_seed is None or random_seed < 0:
                np.random.shuffle(indices)
            else:
                # Same permutation as np.random.seed(s); np.random.shuffle(...), but
                # without reseeding the global generator.
                np.random.RandomState(random_seed).shuffle(indices)
        return indices

    @staticmethod
    def _kfold_positions(n, xfold):
        # Split positions [0, n) into `xfold` contiguous folds of n // xfold
        # elements; the last fold also takes the remainder.
        if xfold < 1:
            raise ValueError("xfold must be >= 1, got %r" % (xfold,))
        fold_base = n // xfold
        folds = [list(range(k * fold_base, (k + 1) * fold_base)) for k in range(xfold - 1)]
        folds.append(list(range((xfold - 1) * fold_base, n)))
        return folds

    @staticmethod
    def _hold_out_fold(indices, folds, fold_idx):
        # Return (rest, held_out): entries of `indices` outside / inside fold `fold_idx`.
        # Negative fold_idx used to pass the old assert and leak the held-out fold
        # into training, so it is rejected explicitly.
        if not 0 <= fold_idx < len(folds):
            raise ValueError("fold_idx must be in [0, %d), got %r" % (len(folds), fold_idx))
        rest = [indices[p] for k, fold in enumerate(folds) if k != fold_idx for p in fold]
        held_out = [indices[p] for p in folds[fold_idx]]
        return rest, held_out

    @staticmethod
    def _subset_loader(dataset, indices, batch_size, batch_workers, drop_last=False):
        # DataLoader that samples the given indices of `dataset` in random order.
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                           sampler=SubsetRandomSampler(indices),
                                           num_workers=int(batch_workers),
                                           drop_last=drop_last)

    @staticmethod
    def _full_loader(dataset, batch_size, batch_workers, shuffle, drop_last):
        # DataLoader over the whole dataset.
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                                           num_workers=int(batch_workers), drop_last=drop_last)

    @staticmethod
    def _split_by_fraction(dataset, batch_size, batch_workers, shuffle, drop_last, test_split, random_seed):
        # Loaders for fetch_dataset: a single loader, or (train, test) when test_split > 0.
        if not test_split > 0.0:
            return Data_fetcher._full_loader(dataset, batch_size, batch_workers, shuffle, drop_last)
        dataset_size = len(dataset)
        split = int(np.floor(test_split * dataset_size))
        indices = Data_fetcher._shuffled_indices(dataset_size, shuffle, random_seed)
        # The first `split` (shuffled) indices are the test set. drop_last applies
        # only to training so no evaluation sample is dropped.
        train_loader = Data_fetcher._subset_loader(dataset, indices[split:], batch_size, batch_workers, drop_last)
        test_loader = Data_fetcher._subset_loader(dataset, indices[:split], batch_size, batch_workers)
        return train_loader, test_loader

    @staticmethod
    def _split_by_scheme(dataset, batch_size, batch_workers, shuffle, drop_last,
                         datasplit_scheme, test_split, xfold, fold_idx, random_seed):
        # Loaders for fetch_dataset_wValidation according to `datasplit_scheme`.
        if datasplit_scheme == "All":
            print("Using all data for trianing in data fetcher.")
            return Data_fetcher._full_loader(dataset, batch_size, batch_workers, shuffle, drop_last)
        if datasplit_scheme == "Test":
            print("Test mode in data fetcher.")
            indices = Data_fetcher._shuffled_indices(len(dataset), shuffle, random_seed)
            folds = Data_fetcher._kfold_positions(len(indices), xfold)
            train_indices, test_indices = Data_fetcher._hold_out_fold(indices, folds, fold_idx)
            return (Data_fetcher._subset_loader(dataset, train_indices, batch_size, batch_workers, drop_last),
                    Data_fetcher._subset_loader(dataset, test_indices, batch_size, batch_workers))
        if datasplit_scheme == "Valid":
            print("Validation mode in data fetcher.")
            dataset_size = len(dataset)
            split = int(np.floor(test_split * dataset_size))
            indices = Data_fetcher._shuffled_indices(dataset_size, shuffle, random_seed)
            # The last `split` shuffled indices form the test set; the rest is k-folded.
            test_indices = indices[dataset_size - split:]
            train_valid_indices = indices[:dataset_size - split]
            folds = Data_fetcher._kfold_positions(len(train_valid_indices), xfold)
            train_indices, valid_indices = Data_fetcher._hold_out_fold(train_valid_indices, folds, fold_idx)
            return (Data_fetcher._subset_loader(dataset, train_indices, batch_size, batch_workers, drop_last),
                    Data_fetcher._subset_loader(dataset, valid_indices, batch_size, batch_workers),
                    Data_fetcher._subset_loader(dataset, test_indices, batch_size, batch_workers))
        raise ValueError("Unrecognized datasplit_scheme %r; expected 'All', 'Test' or 'Valid'"
                         % (datasplit_scheme,))

    @staticmethod
    def fetch_dataset(name, data_path, batch_size, batch_workers, shuffle, drop_last, scalor, test_split=0.0, random_seed=-1):
        '''
        Build dataset `name` and return DataLoader(s) for it.
        @name: "cifar10", "celeba", "topo", "dmt" or "nii"
        @data_path: dataset root folder
        @batch_size: number of instances per batch
        @batch_workers: number of DataLoader worker processes
        @shuffle: if to shuffle the data (and the split indices when splitting)
        @drop_last: drop the incomplete last batch (training loader only when splitting)
        @scalor: value passed as both arguments of transforms.Normalize for "topo"
        @test_split: fraction of the data held out for test; 0 disables the split
        @random_seed: seed for the split shuffle; negative means unseeded
        Returns one DataLoader when test_split <= 0, else (train_loader, test_loader).
        Raises NotImplementedError for an unknown `name`.
        '''
        dataset = Data_fetcher._build_dataset(name, data_path, scalor)
        return Data_fetcher._split_by_fraction(dataset, batch_size, batch_workers, shuffle,
                                               drop_last, test_split, random_seed)

    @staticmethod
    def fetch_dataset_wValidation(name, data_path, batch_size, batch_workers, shuffle, drop_last, scalor, datasplit_scheme, test_split, xfold, fold_idx, random_seed=-1):
        '''
        Advanced version of fetch_dataset with cross-validation splits.
        @name: name of the dataset (see fetch_dataset)
        @data_path: dataset root folder
        @batch_size: number of instances per batch
        @batch_workers: number of workers to fetch data
        @shuffle: if to shuffle the data / split indices
        @drop_last: drop the incomplete last batch (training loader only when splitting)
        @scalor: scale the data (see fetch_dataset)
        @datasplit_scheme:
            "All"   - one loader over all data; returns dataloader
            "Test"  - all data is split into `xfold` folds, fold `fold_idx` is the
                      test set (test_split is NOT used); returns (train, test)
            "Valid" - the last `test_split` fraction of the shuffled data is the test
                      set; the remainder is split into `xfold` folds and fold
                      `fold_idx` is the validation set; returns (train, valid, test)
        @test_split: fraction of the data for test ("Valid" only)
        @xfold: number of folds (>= 1)
        @fold_idx: which fold to hold out, in [0, xfold)
        @random_seed: seed for the index shuffle; negative means unseeded
        Raises ValueError for an unknown scheme or an out-of-range xfold/fold_idx.
        '''
        dataset = Data_fetcher._build_dataset(name, data_path, scalor)
        return Data_fetcher._split_by_scheme(dataset, batch_size, batch_workers, shuffle, drop_last,
                                             datasplit_scheme, test_split, xfold, fold_idx, random_seed)