In [None]:
# %pip install -q "monai[nibabel]"

In [None]:
import glob
import nibabel as nb
import numpy as np
import torch
import torch.utils.data as data

from monai.utils import first

In [None]:

class MRIDataset(data.Dataset):
    """In-memory 2-D slice dataset built from paired NIfTI image/label files.

    Each image/label file pair is loaded with nibabel and split into slices
    along its first array axis. All slices from all pairs are stacked into
    two in-memory arrays, ``self.X`` (images) and ``self.y`` (labels), so
    ``len(dataset)`` is the total number of slices across every file.

    Args:
        X_files: Paths to the image volumes.
        y_files: Paths to the label volumes, paired with ``X_files`` by
            position. The caller is responsible for matching order.
        transforms: Stored as ``self.transforms`` but not applied anywhere in
            this class. NOTE(review): wire it into ``__getitem__`` once the
            expected call signature (image only vs. image and label) is fixed.

    Raises:
        ValueError: If the two file lists differ in length, if a paired image
            and label have a different number of slices, or if no slices were
            loaded at all.
    """

    def __init__(self, X_files, y_files, transforms=None):
        self.X_files = X_files
        self.y_files = y_files
        self.transforms = transforms

        # Materialize both lists once so their lengths can be compared.
        # zip() alone would silently drop the unmatched tail of the longer one.
        x_paths, y_paths = list(X_files), list(y_files)
        if len(x_paths) != len(y_paths):
            raise ValueError(
                f"got {len(x_paths)} image files but {len(y_paths)} label files"
            )

        img_array = list()
        label_array = list()
        for vol_f, label_f in zip(x_paths, y_paths):
            img, label = nb.load(vol_f), nb.load(label_f)
            img_data = np.array(img.get_fdata())
            label_data = np.array(label.get_fdata())
            # np.array copied the data, so dropping nibabel's cached float
            # array does not invalidate img_data / label_data.
            img.uncache()
            label.uncache()
            if img_data.shape[0] != label_data.shape[0]:
                raise ValueError(
                    f"{vol_f} has {img_data.shape[0]} slices but "
                    f"{label_f} has {label_data.shape[0]}"
                )
            # extend() iterates over axis 0, appending one entry per slice.
            img_array.extend(img_data)
            label_array.extend(label_data)

        X = self._stack_slices(img_array, "image")
        y = self._stack_slices(label_array, "label")
        # Insert a singleton channel axis unless the stacked images are already 4-D.
        self.X = X if len(X.shape) == 4 else X[:, np.newaxis, :, :]
        self.y = y
        print(self.X.shape, self.y.shape)

    @staticmethod
    def _stack_slices(slices, kind="image"):
        """Stack same-shaped slices along a new leading axis.

        Always stacks, even for a single slice, so the result keeps its
        leading slice axis.

        Raises:
            ValueError: If ``slices`` is empty.
        """
        if not slices:
            raise ValueError(
                f"no {kind} slices were loaded; check the input file lists"
            )
        return np.stack(slices, axis=0)

    def __getitem__(self, index):
        """Return ``(image, label)`` tensors for slice ``index``.

        ``torch.from_numpy`` shares memory with the underlying arrays, so no
        copy is made here.
        """
        img = torch.from_numpy(self.X[index])
        label = torch.from_numpy(self.y[index])
        return img, label

    def __len__(self):
        """Number of slices (the leading dimension of ``self.y``)."""
        return len(self.y)

In [3]:
# Paired image/label files; pairing relies on both sorted lists lining up
# file-for-file. `*` (not `**`): without recursive=True they match the same
# files, and `*` doesn't falsely suggest a recursive search.
UKB_ROOT = 'temp/UKB/processed_resized'
ukb_volumes = sorted(glob.glob(f'{UKB_ROOT}/volume_cropped/*.nii.gz'))
ukb_labels = sorted(glob.glob(f'{UKB_ROOT}/label_cropped/*.nii.gz'))
assert ukb_volumes, f'no volumes found under {UKB_ROOT}/volume_cropped'
assert len(ukb_volumes) == len(ukb_labels), (
    f'{len(ukb_volumes)} volumes vs {len(ukb_labels)} labels'
)

ds = MRIDataset(ukb_volumes, ukb_labels)
loader = torch.utils.data.DataLoader(ds, batch_size=3, num_workers=4)

# Pull one batch to sanity-check the tensor shapes coming out of the loader.
im, seg = first(loader)
print(im.shape, seg.shape)

(4320, 1, 145, 100) (4320, 145, 100)
torch.Size([3, 1, 145, 100]) torch.Size([3, 145, 100])


In [4]:
# Draw the first batch again. `first` iterates `loader` from the start, and
# the loader was built without shuffle, so this should match the batch above.
im, seg = first(loader)
print(*(tensor.shape for tensor in (im, seg)))

torch.Size([3, 1, 145, 100]) torch.Size([3, 145, 100])
