In [1]:
import os
import torch
import time
import pandas as pd
import nibabel as nib
import numpy as np
import random
import nibabel as nib
import numpy as np
from scipy.ndimage import zoom
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image

import gc
# Free unreachable Python objects and ask PyTorch to release cached, unused CUDA memory
# (useful when this cell is re-run in a kernel that already held large tensors).
gc.collect()
torch.cuda.empty_cache()

In [2]:
# Root directory of the preprocessed MIRIAD scans. Set the MIRIAD_DATA_PATH
# environment variable to point elsewhere without editing the notebook.
DATA_PATH = os.environ.get("MIRIAD_DATA_PATH", "/home/admin1/Arindam/Alzheimer/PreprocessedData/miriad")

# Target volume size after resizing: img_size x img_size x depth voxels.
config = {
    'img_size': 192,
    'depth': 192,
}

# Seed for the `random` module, which drives the shuffle used to build the splits.
SEED = 0
random.seed(SEED)

In [3]:
class DataPaths():
    """Locate MIRIAD MRI volumes and write train/validation/test path lists to CSV.

    Parameters
    ----------
    data_path : str
        Root directory searched recursively for files ending in ``mni_norm.nii.gz``.
    """

    def __init__(self, data_path=DATA_PATH):
        self.data_path = data_path

    def mri_path_loading(self):
        """Return every ``*mni_norm.nii.gz`` path under ``self.data_path``, shuffled.

        The paths are sorted before shuffling: ``os.walk`` yields entries in
        filesystem order, which is not guaranteed to be identical across machines
        or copies of the data. Sorting makes the result depend only on the state
        of the ``random`` module.
        """
        paths = sorted(
            os.path.join(root, file_name)
            for root, _dirs, files in os.walk(self.data_path, topdown=True)
            for file_name in files
            if file_name.endswith('mni_norm.nii.gz')
        )
        random.shuffle(paths)
        return paths

    @staticmethod
    def _split_sizes(n, train_split, val_split):
        """Return ``(n_train, n_val, n_test)`` for ``n`` items; test receives the remainder."""
        n_train = int(train_split * n)
        n_val = int(val_split * n)
        return n_train, n_val, n - n_train - n_val

    def train_val_test_division(self, train_split=0.7, val_split=0.15, stratify=True, save_dir=None):
        """Split the scans into train/val/test sets and save each set as a CSV.

        Parameters
        ----------
        train_split, val_split : float
            Fractions of scans assigned to training and validation. The test set
            receives whatever remains, so ``train_split + val_split`` must not exceed 1.
        stratify : bool
            If True, the fractions are applied separately to AD and HC scans, so
            every split keeps the class ratio. If False, they are applied to the
            pooled, shuffled list of scans.
        save_dir : str or None
            Directory for the CSV files. Defaults to ``<cwd>/data/MIRIAD``. It is
            created, including missing parents, if it does not exist.

        Returns
        -------
        tuple of str
            Paths of the train, validation and test CSVs. Each CSV has a single
            ``image_path`` column.

        Raises
        ------
        ValueError
            If the split fractions are invalid or a scan path contains neither
            ``_AD_`` nor ``_HC_``.
        """
        # Small tolerance so that float sums such as 0.7 + 0.3 are accepted.
        if train_split < 0 or val_split < 0 or train_split + val_split > 1 + 1e-9:
            raise ValueError(
                f"Invalid split fractions: train_split={train_split}, val_split={val_split}")

        image_paths = self.mri_path_loading()
        AD_image_paths, HC_image_paths, unlabelled = [], [], []
        for im_path in image_paths:
            # The class is encoded in the path; '_AD_' is checked first.
            if "_AD_" in im_path:
                AD_image_paths.append(im_path)
            elif "_HC_" in im_path:
                HC_image_paths.append(im_path)
            else:
                unlabelled.append(im_path)
        # An explicit error (not `assert`) so the check survives `python -O`
        # and the message identifies the offending file.
        if unlabelled:
            raise ValueError(
                f"Cannot infer AD/HC label for {len(unlabelled)} scan(s), e.g. {unlabelled[0]}")

        splits = {'train': [], 'val': [], 'test': []}
        # Stratified: split each class on its own; HC scans precede AD scans in each CSV.
        groups = (HC_image_paths, AD_image_paths) if stratify else (image_paths,)
        for group in groups:
            n_train, n_val, _ = self._split_sizes(len(group), train_split, val_split)
            splits['train'] += group[:n_train]
            splits['val'] += group[n_train:n_train + n_val]
            splits['test'] += group[n_train + n_val:]

        # Per-class counts, keyed e.g. 'train_ad', reported for every split.
        ad_set = set(AD_image_paths)
        no_of_images = {}
        for cls in ('ad', 'hc'):
            for split_name in ('train', 'val', 'test'):
                n_ad = sum(p in ad_set for p in splits[split_name])
                no_of_images[f"{split_name}_{cls}"] = n_ad if cls == 'ad' else len(splits[split_name]) - n_ad
        print(no_of_images)
        len_train, len_val, len_test = (len(splits[s]) for s in ('train', 'val', 'test'))
        print("Total number of train, validation and test images are {}, {} and {} respectively.".format(len_train, len_val, len_test))

        save_path = save_dir if save_dir is not None else os.path.join(os.getcwd(), 'data/MIRIAD')
        # makedirs(exist_ok=True) creates missing parents (e.g. 'data') and is safe on re-runs.
        os.makedirs(save_path, exist_ok=True)

        csv_paths = []
        for split_name in ('train', 'val', 'test'):
            csv_path = os.path.join(save_path, f'{split_name}_mri_scan_list.csv')
            pd.DataFrame(splits[split_name], columns=['image_path']).to_csv(csv_path, index=False)
            csv_paths.append(csv_path)
        return tuple(csv_paths)

In [4]:
class ADNIAlzheimerDataset(Dataset):
    """PyTorch dataset of MRI volumes listed in a CSV, labelled AD (1) or HC (0).

    For each item, the volume is processed as follows:
    - it is zero-padded (centred) to ``PAD_SIZE`` voxels per axis;
    - it is linearly resized to ``(config['img_size'], config['img_size'], config['depth'])``;
    - it is returned as a float32 tensor with a leading channel axis.

    Parameters
    ----------
    image_df_paths : str
        Path to a CSV with an ``image_path`` column.
    transform : callable or None
        Optional function applied to the image tensor before it is returned.
    """

    # Side length each axis is zero-padded to before resizing.
    PAD_SIZE = 256

    def __init__(self, image_df_paths, transform=None):
        self.image_df_paths = image_df_paths
        self.transform = transform
        self.df = pd.read_csv(self.image_df_paths)
        self.desired_width = config['img_size']
        self.desired_height = config['img_size']
        self.desired_depth = config['depth']

    def __label_extract(self, im_path):
        """Return 1 if the path contains ``_AD_``, 0 if it contains ``_HC_``.

        Raises ValueError when neither marker is present, instead of returning
        None and failing later in an unrelated place.
        """
        if "_AD_" in im_path:
            return 1
        if "_HC_" in im_path:
            return 0
        raise ValueError(f"Cannot infer AD/HC label from path: {im_path}")

    def __len__(self):
        return len(self.df)

    def _pad_to_cube(self, image):
        """Zero-pad each axis of ``image`` symmetrically up to ``PAD_SIZE`` voxels.

        When the padding amount is odd, the extra voxel goes at the end of the
        axis, so the axis reaches exactly ``PAD_SIZE``. Axes already at least
        ``PAD_SIZE`` long are left unpadded; ``_resize`` then scales them down.
        """
        pad_width = []
        for dim in image.shape:
            total = max(self.PAD_SIZE - dim, 0)
            pad_width.append((total // 2, total - total // 2))
        return np.pad(image, pad_width, 'constant', constant_values=0)

    def _resize(self, image):
        """Linearly interpolate ``image`` to ``(desired_width, desired_height, desired_depth)``."""
        factors = (
            self.desired_width / image.shape[0],
            self.desired_height / image.shape[1],
            self.desired_depth / image.shape[-1],
        )
        return zoom(image, factors, order=1)

    def __getitem__(self, idx):
        """Load, pad, resize and label the ``idx``-th scan; return ``(image, label)``."""
        image_filepath = self.df['image_path'].iloc[idx]
        # Resolve the label first so an unlabelled path fails before the costly load.
        label = self.__label_extract(image_filepath)

        # Reorder the data to the closest canonical (RAS+) orientation before reading voxels.
        volume = nib.as_closest_canonical(nib.load(image_filepath)).get_fdata()
        volume = self._resize(self._pad_to_cube(volume))

        # Add a leading channel axis -> (1, width, height, depth).
        image = torch.from_numpy(volume[np.newaxis].astype('float32'))
        if self.transform is not None:
            image = self.transform(image)
        return image, label

In [5]:
data_paths = DataPaths()
# train_val_test_division() performs the directory walk and the shuffle itself, so no
# separate mri_path_loading() call is needed here. An extra call would advance the
# `random` state and change which scans land in which split.
trin_img_df_path, val_img_df_path, test_img_df_path = data_paths.train_val_test_division()

train_dataset = ADNIAlzheimerDataset(trin_img_df_path)
val_dataset = ADNIAlzheimerDataset(val_img_df_path)
test_dataset = ADNIAlzheimerDataset(test_img_df_path)

def saveTensors(dataset, data_type, save_root='/home/admin1/Arindam/Alzheimer/ViT/data/MIRIAD/MRIs'):
    """Save every item of ``dataset`` as ``<save_root>/<data_type>/<HC|AD>/<idx>.pt``.

    Parameters
    ----------
    dataset : indexable
        Supports ``len()`` and ``dataset[idx]``. Each item is a ``(tensor, label)``
        pair whose label is 0 (HC) or 1 (AD).
    data_type : str
        Sub-directory name for this split, e.g. ``'Train'``.
    save_root : str
        Parent directory for all splits. It is created, including missing parents,
        if it does not exist.
    """
    data_path = os.path.join(save_root, data_type)
    labels = {
        0 : 'HC',
        1 : 'AD'
    }

    # makedirs(exist_ok=True) creates missing parents and lets the cell be re-run
    # without FileExistsError.
    for class_name in labels.values():
        os.makedirs(os.path.join(data_path, class_name), exist_ok=True)

    print(f"Processing for {data_type} data is starting. Data will be saved at {data_path}")
    print(f"Total number of images are: {len(dataset)}")

    start = time.time()
    for idx in range(len(dataset)):
        tensor, label = dataset[idx]
        tensor_path = os.path.join(data_path, labels[label], f"{idx}.pt")
        torch.save(tensor, tensor_path)

        if (idx+1)%100==0:
            print(f"{idx+1} images done.")

    req_time = time.time() - start
    print(f"Total time required for processing the data is {req_time// 60} minutes {req_time%60} sec.")
    # Skip the per-image average for an empty dataset (avoids ZeroDivisionError).
    if len(dataset):
        print(f"Processing of a single image took {req_time/(1.0*len(dataset))} sec.")

{'train_ad': 325, 'val_ad': 69, 'test_ad': 71, 'train_hc': 170, 'val_hc': 36, 'test_hc': 37}
Total number of train, validation and test images are 495, 105 and 108 respectively.


In [6]:
saveTensors(dataset=test_dataset, data_type='Test')

Processing for Test data is starting. Data will be saved at /home/admin1/Arindam/Alzheimer/ViT/data/MIRIAD/MRIs/Test
Total number of images are: 108
100 images done.
Total time required for processing the data is 0.0 minutes 36.04144597053528 sec.
Processing of a single image took 0.3337170923197711 sec.


In [7]:
saveTensors(dataset=val_dataset, data_type='Val')

Processing for Val data is starting. Data will be saved at /home/admin1/Arindam/Alzheimer/ViT/data/MIRIAD/MRIs/Val
Total number of images are: 105
100 images done.
Total time required for processing the data is 0.0 minutes 35.26710748672485 sec.
Processing of a single image took 0.33587721415928434 sec.


In [8]:
saveTensors(dataset=train_dataset, data_type='Train')

Processing for Train data is starting. Data will be saved at /home/admin1/Arindam/Alzheimer/ViT/data/MIRIAD/MRIs/Train
Total number of images are: 495
100 images done.
200 images done.
300 images done.
400 images done.
Total time required for processing the data is 2.0 minutes 46.547834157943726 sec.
Processing of a single image took 0.33646027102614895 sec.
