In [None]:
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import utils, transforms
from glob import glob
import os

In [None]:
class NiftiDataset(Dataset):
    """PyTorch dataset serving the NIfTI (.nii.gz) files found in one directory."""

    def __init__(self, source_dir, transform=None):
        """
        Index every ``*.nii.gz`` file located directly inside ``source_dir``.

        source_dir: path to the directory holding the images
        transform: optional callable applied to each loaded array
            (probably None or ToTensor)
        """
        pattern = os.path.join(source_dir, "*.nii.gz")
        matches = glob(pattern)
        # Sort so that an index always maps to the same file.
        matches.sort()
        self.source_fns = matches
        self.transform = transform

    def __len__(self):
        """Return how many image files were found."""
        return len(self.source_fns)

    def __getitem__(self, idx:int):
        """Load the idx-th file (sorted path order) and return its data as float64,
        passed through ``transform`` when one was given."""
        volume = nib.load(self.source_fns[idx]).get_fdata(dtype=np.float64)
        if self.transform is None:
            return volume
        return self.transform(volume)


In [None]:
# Root folder holding the three image sets; edit this one line to relocate the data.
DATA_ROOT = '/media/data/Track_2'

batch_size = 1
# Worker processes per loader. Kept separate from batch_size so that changing the
# batch size does not silently change how many processes each loader spawns.
NUM_WORKERS = 1

# Every set uses the same transform: ToTensor only, no intensity normalisation.
good_data = NiftiDataset(source_dir=os.path.join(DATA_ROOT, 'good'),
                         transform=transforms.ToTensor())
bad_data = NiftiDataset(source_dir=os.path.join(DATA_ROOT, 'bad'),
                        transform=transforms.ToTensor())
pure_data = NiftiDataset(source_dir=os.path.join(DATA_ROOT, 'mni'),
                         transform=transforms.ToTensor())

# NOTE: each loader shuffles on its own, so iterating them side by side does not
# pair up corresponding files across the three sets.
good_dataloader = DataLoader(good_data, batch_size=batch_size,
                             shuffle=True, num_workers=NUM_WORKERS)
bad_dataloader = DataLoader(bad_data, batch_size=batch_size,
                            shuffle=True, num_workers=NUM_WORKERS)
pure_dataloader = DataLoader(pure_data, batch_size=batch_size,
                             shuffle=True, num_workers=NUM_WORKERS)

In [None]:
def normalize_brain(brainy_batch):
    """Standardize every volume of a 4-D batch to zero mean and unit std.

    brainy_batch: 4-D tensor whose first dimension indexes the volumes; the
        statistics are taken over dimensions 1, 2 and 3 of each volume.
    returns: a new tensor of the same shape; the input is not modified.

    A volume with zero standard deviation produces non-finite values, because
    the result is divided by that deviation.
    """
    # keepdim=True leaves size-1 axes in place so the statistics broadcast
    # against the batch without manual re-indexing.
    mean_val = brainy_batch.mean(dim=[1, 2, 3], keepdim=True)
    std_val = brainy_batch.std(dim=[1, 2, 3], keepdim=True)
    return (brainy_batch - mean_val) / std_val

def show_brains(images_batch, slice_idx=50):
    """Draw one 2-D slice from every volume of a batch as an image grid.

    images_batch: 4-D tensor with the volumes along the first dimension, or a
        single 3-D volume (a leading batch dimension is added for it).
    slice_idx: index along the second batch dimension of the slice to show;
        defaults to 50, the value that used to be hard-coded.

    Draws onto the current matplotlib figure and returns nothing.
    """
    if images_batch.dim() == 3:
        # A lone volume: add the batch axis the rest of the function relies on.
        images_batch = images_batch.unsqueeze(0)

    images_batch = normalize_brain(images_batch)

    # Rescale the whole batch to [0, 1] so imshow can render the floats directly.
    lowest, highest = images_batch.min(), images_batch.max()
    images_batch = (images_batch - lowest) / (highest - lowest)

    # Pick the slice, then re-insert a channel axis: make_grid takes (B, C, H, W).
    slices = images_batch[:, slice_idx, :, :].unsqueeze(1)
    grid = utils.make_grid(slices)

    # make_grid returns channels-first; imshow wants channels-last.
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

# Step through the three loaders together and plot a slice of each good/bad batch.
# The third loop target must be a fresh name: reusing `pure_dataloader` would
# rebind the loader itself and leave `pure_batch` undefined.
for good_batch, bad_batch, pure_batch in zip(good_dataloader, bad_dataloader, pure_dataloader):
    # Batches are handed to show_brains whole: it slices along the second of four
    # dimensions, so indexing [0] first would strip the batch axis it needs.
    print(good_batch[0].size())  # size of one sample
    print(good_batch.size())     # size of the full batch
    print(pure_batch.size())
    for title, batch in (("Good Brain", good_batch), ("Bad Brain", bad_batch)):
        plt.figure()
        print(title)
        show_brains(batch)
        plt.axis('off')
        plt.ioff()
        plt.show()