In [3]:
import os
import shutil

import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

from sklearn.model_selection import train_test_split
import torch
import torch.utils.data as torch_data
from skimage.transform import resize

In [2]:
# Root directory holding the per-split folders of .npy slices; change it in
# one place when the data lives elsewhere. Paths stay plain strings with '/'
# separators, exactly as before.
DATA_ROOT = 'dl_bio'
PATH_TO_SOURCE_TRAIN = f'{DATA_ROOT}/ax_t2_source_train'
PATH_TO_SOURCE_VAL = f'{DATA_ROOT}/ax_t2_source_val'
PATH_TO_SOURCE_TEST = f'{DATA_ROOT}/ax_t2_source_test'
# Seed value passed to the dataset constructor.
RANDOM_STATE = 42

In [7]:
def npy_load(path):
    """Read one array from a NumPy ``.npy`` file.

    The file is opened in binary mode and closed again before returning.

    Parameters
    ----------
    path : str or os.PathLike
        Location of the ``.npy`` file.

    Returns
    -------
    numpy.ndarray
        Whatever ``np.load`` produces for the file contents.
    """
    with open(path, mode='rb') as handle:
        array = np.load(handle)
    return array

In [32]:
class SRGANMRIDataLoader(torch_data.Dataset):
    """Dataset of MRI slices stored as individual ``.npy`` files.

    Each item is a ``(path, image, downsampled)`` triple. ``image`` is the
    array loaded from disk. ``downsampled`` is its first channel resized to
    ``target_size``, with a leading axis of length 1 added back.

    Despite the name, this is a ``torch.utils.data.Dataset``, not a
    ``DataLoader``; wrap it in ``torch.utils.data.DataLoader`` for batching.

    Parameters
    ----------
    path_to_data : str or os.PathLike
        Directory scanned (non-recursively) for files ending in ``.npy``.
    seed : int, optional
        Stored as ``self.seed``; not currently used by this class.
    target_size : tuple of int, default (160, 160)
        Output shape passed to ``skimage.transform.resize`` for the
        low-resolution copy.
    """

    def __init__(self, path_to_data, seed=None, target_size=(160, 160)):
        super().__init__()
        self.seed = seed
        self.target_size = tuple(target_size)
        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sort so the index -> file mapping is reproducible across machines.
        self.images = [
            os.path.join(path_to_data, file_name)
            for file_name in sorted(os.listdir(path_to_data))
            if file_name.endswith('.npy')
        ]

    def __len__(self):
        """Return the number of ``.npy`` files found in the directory."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(path, image, downsampled)`` for the file at ``idx``."""
        path = self.images[idx]
        image = npy_load(path)
        return path, image, self.sample_resize(image)

    def sample_resize(self, slice):  # param name kept for keyword-call compatibility
        """Resize the first channel of ``slice`` to ``self.target_size``.

        Only ``slice[0]`` is resized; any further channels are ignored.
        ``preserve_range=True`` keeps the original intensity values instead of
        rescaling them. A new leading axis of length 1 is added to the result.
        """
        # assumes input is (channels, H, W) as in the sample shown in the
        # notebook output — TODO confirm for multi-channel data
        resized = resize(slice[0], self.target_size, preserve_range=True)
        return np.expand_dims(resized, axis=0)

In [33]:
# NOTE(review): named `train_loader`, but this is a Dataset over the *test*
# split (PATH_TO_SOURCE_TEST) — confirm whether PATH_TO_SOURCE_TRAIN was meant.
train_loader = SRGANMRIDataLoader(
    path_to_data=PATH_TO_SOURCE_TEST,
    seed=RANDOM_STATE,
)

In [34]:
# Fetch one sample via indexing (which calls __getitem__): `name` is the file
# path, `slic` the array loaded from disk, `resized` the downsampled copy.
name, slic, resized = train_loader[0]

In [35]:
# Compare the loaded slice's shape with its downsampled counterpart.
(slic.shape, resized.shape)

((1, 320, 320), (1, 160, 160))