<a href="https://colab.research.google.com/github/Angelvj/Alzheimer-disease-classification/blob/main/code/generate_ensemble_tfrecords.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [13]:
# Colab only
from google.colab import drive

In [14]:
import sys
import numpy as np, os, shutil, math
import tensorflow as tf, csv
from sklearn.model_selection import StratifiedKFold, KFold, train_test_split
import nibabel as nib
import skimage.transform as transform
import matplotlib.pyplot as plt

In [15]:
def load_image(path, add_axis=True):
    """Read the image stored at `path` and return its data as a float32 array.

    Args:
        path: Location of the image file; it is handed to `nib.load`.
        add_axis: When true, a singleton axis is inserted at position 3 so the
            result carries a trailing channel axis.

    Returns:
        A float32 NumPy array holding the image data.
    """
    volume = np.asarray(nib.load(path).dataobj, dtype=np.float32)
    # The channel axis is only inserted on request.
    return np.expand_dims(volume, axis=3) if add_axis else volume

def downscale(image, shape):
    """Resize `image` to `shape` with anti-aliasing switched on.

    Anti-aliasing is intended for shrinking; for upscaling it should be
    disabled, so this helper should not be used to enlarge an image.

    Args:
        image: Array to resize.
        shape: Target output shape.

    Returns:
        The resized array produced by `skimage.transform.resize`.
    """
    resize_options = {'mode': 'constant', 'anti_aliasing': True}
    return transform.resize(image, shape, **resize_options)

def standarize(X):
    """Standardise `X` in place to zero mean and unit standard deviation.

    If the standard deviation is not strictly positive (for example a constant
    array), every element is multiplied by zero instead.

    Args:
        X: NumPy array modified in place.

    Returns:
        None; the result is written into `X`.
    """
    # Both statistics are computed before X is touched.
    center, spread = np.mean(X), np.std(X)

    if not spread > 0:
        # Nothing to scale by: collapse the array to zeros.
        X *= 0
        return

    X -= center
    X /= spread

def max_intensity_normalization(X, proportion):
    """Divide `X` in place by the mean of its largest values.

    The number of values averaged is `int(X.size * proportion)`, clamped to the
    range [1, X.size].

    Args:
        X: NumPy array modified in place.
        proportion: Fraction of the elements (the largest ones) whose mean is
            used as the divisor.

    Returns:
        None; the result is written into `X`. An empty array is left untouched.
    """
    total = X.size
    if total == 0:
        return  # nothing to normalise

    # Clamp to at least one value: a count of 0 would turn the slice below
    # into [-0:], which selects the WHOLE array and silently normalises by the
    # global mean instead of by the brightest values.
    n_max_values = min(max(int(total * proportion), 1), total)

    # Sorting the flattened values directly yields the same top-k values as an
    # argsort followed by fancy indexing, without building an index array.
    top_values = np.sort(X, axis=None)[-n_max_values:]
    X /= np.mean(top_values)

def preprocess_image(X, steps, arguments):
    """Apply the preprocessing `steps` to `X`, in order.

    Each step is called as `step(X)` when its entry in `arguments` is None,
    otherwise as `step(X, *args)`. Steps may either modify `X` in place and
    return None, or return a new array; a non-None return value replaces `X`
    for the following steps, so steps that cannot work in place (such as a
    resize) take effect.

    Args:
        X: Image array to preprocess.
        steps: Iterable of callables taking the image as first argument.
        arguments: Iterable, parallel to `steps`, with the extra positional
            arguments of each step (a sequence) or None for no extra arguments.

    Returns:
        The preprocessed image: the original object when every step worked in
        place, otherwise the array returned by the last step that returned one.
    """
    for step, args in zip(steps, arguments):
        result = step(X) if args is None else step(X, *args)
        # In-place steps return None; keep the current array in that case.
        if result is not None:
            X = result
    return X

In [16]:
# A TFRecord feature can hold three kinds of data: byte strings, integers and floats.
# Values are always stored as lists; a single data element is a list of size 1.
def _bytestring_feature(list_of_bytestrings):
    """Wrap a list of byte strings in a `tf.train.Feature`."""
    bytes_list = tf.train.BytesList(value=list_of_bytestrings)
    return tf.train.Feature(bytes_list=bytes_list)

def _float_feature(list_of_floats): # float32
    """Wrap a list of floats in a `tf.train.Feature`."""
    float_list = tf.train.FloatList(value=list_of_floats)
    return tf.train.Feature(float_list=float_list)

def _int_feature(list_of_ints): # int64
    """Wrap a list of integers in a `tf.train.Feature`."""
    int64_list = tf.train.Int64List(value=list_of_ints)
    return tf.train.Feature(int64_list=int64_list)

def to_tfrecord(image_pet, image_mri, label, num_classes=3):
    """Build a `tf.train.Example` holding a PET image, an MRI image and a label.

    Args:
        image_pet: Sequence of float values for the PET image.
        image_mri: Sequence of float values for the MRI image.
        label: Integer class index, used to pick a row of the identity matrix.
        num_classes: Length of the one-hot label vector. Defaults to 3, the
            value that was previously hard-coded.

    Returns:
        A `tf.train.Example` with the float features 'image_pet', 'image_mri'
        and 'one_hot_label'.
    """
    # Row `label` of the identity matrix is the one-hot encoding of the class.
    one_hot_label = np.eye(num_classes, dtype=np.float32)[label]

    feature = {
        'image_pet': _float_feature(image_pet),
        'image_mri': _float_feature(image_mri),
        'one_hot_label': _float_feature(one_hot_label.tolist())
    }

    # Wrap the feature map in a Features message and that in an Example.
    return tf.train.Example(features=tf.train.Features(feature=feature))

In [20]:
def generate_tfrecords(list_of_filenames, labels, dir, tfrec_name, num_folds=15, 
                       stratify=True, shuffle=True, random_state=None, make_summary=True):
    """Write paired PET/MRI images and their labels into `num_folds` TFRecord files.

    The samples are partitioned with (Stratified)KFold and the held-out indices
    of each fold are written to one file named
    `{tfrec_name}_{fold}-{num_samples}.tfrec` inside `dir`.

    Args:
        list_of_filenames: Two parallel sequences of paths; element 0 holds the
            PET image paths and element 1 the MRI image paths.
        labels: Integer class index of every sample, parallel to the paths.
        dir: Output directory; created if it does not exist.
        tfrec_name: Prefix of the TFRecord files and of the summary CSV.
        num_folds: Number of TFRecord files to produce.
        stratify: Use StratifiedKFold (class-balanced files) instead of KFold.
        shuffle: Whether the splitter shuffles the samples.
        random_state: Seed for the splitter; only forwarded when `shuffle` is
            true, because scikit-learn rejects a seed without shuffling.
        make_summary: When true, write `{tfrec_name}_summary.csv` in `dir` with
            one row per file: its name, sample count and per-class counts
            (columns follow the module-level CLASSES list).

    Raises:
        ValueError: If the PET paths, MRI paths and labels differ in length.
    """
    pet_filenames = list_of_filenames[0]
    mri_filenames = list_of_filenames[1]
    # Fancy indexing with the fold indices below needs an array.
    labels = np.asarray(labels)

    if not (len(pet_filenames) == len(mri_filenames) == len(labels)):
        raise ValueError(
            'PET paths, MRI paths and labels must have the same length, got '
            f'{len(pet_filenames)}, {len(mri_filenames)} and {len(labels)}')

    os.makedirs(dir, exist_ok=True)

    # shuffle/random_state are keyword-only in current scikit-learn releases.
    splitter_cls = StratifiedKFold if stratify else KFold
    kfold = splitter_cls(n_splits=num_folds, shuffle=shuffle,
                         random_state=random_state if shuffle else None)

    summary_file = None
    csv_writer = None
    try:
        if make_summary:
            summary_filename = os.path.join(dir, tfrec_name) + '_summary.csv'
            # A single handle is kept for header and rows and is closed in the
            # finally clause, so every row is flushed to disk.
            summary_file = open(summary_filename, 'w', encoding='UTF8', newline='')
            csv_writer = csv.writer(summary_file)
            csv_writer.writerow(['tfrec_id', '#samples'] + list(CLASSES))

        for n, (_, indices) in enumerate(kfold.split(pet_filenames, labels)):

            name = f'{tfrec_name}_{n}-{len(indices)}.tfrec'

            if make_summary:
                num_samples = str(len(indices))
                # Classes absent from this fold keep a count of zero.
                classes, count = np.unique(labels[indices], return_counts=True)
                class_counts = np.zeros(len(CLASSES), dtype=np.int64)
                class_counts[classes] = count
                row = [name] + [num_samples] + list(class_counts.astype(str))
                csv_writer.writerow(row)

            with tf.io.TFRecordWriter(os.path.join(dir, name)) as writer:

                for index in indices:

                    label = labels[index]

                    # NaNs are replaced in place right after loading.
                    img_pet = np.nan_to_num(load_image(pet_filenames[index]), copy=False)
                    img_mri = np.nan_to_num(load_image(mri_filenames[index]), copy=False)

                    img_pet = preprocess_pet(img_pet)
                    img_mri = preprocess_mri(img_mri)

                    # Images are stored flattened.
                    example = to_tfrecord(img_pet.ravel(), img_mri.ravel(), label)
                    writer.write(example.SerializeToString())
    finally:
        if summary_file is not None:
            summary_file.close()

In [18]:
def preprocess_pet(img):
    """Preprocessing hook for PET images; currently returns `img` unchanged."""
    processed = img  # placeholder: no PET-specific preprocessing yet
    return processed

def preprocess_mri(img):
    """Preprocessing hook for MRI images; currently returns `img` unchanged."""
    processed = img  # placeholder: no MRI-specific preprocessing yet
    return processed

In [None]:
# Class names; the position in this list is the integer label of the class.
CLASSES = ['NOR', 'AD', 'MCI']

drive.mount('/content/drive')
DATA_PATH = '/content/drive/MyDrive/data/'

# Seed shared by both train/test splits so PET and MRI are split identically.
SEED = 27

DS = 'ad-preprocessed'
DS_PATH = DATA_PATH + DS
OUT_DS = 'tfrec-ensemble_mri_pet'
OUT_PATH = DATA_PATH + OUT_DS

# Path to PET images
pet_paths = np.empty((0,), dtype=str)
pet_labels = np.empty((0,), dtype=np.int64)


for label, c in enumerate(CLASSES):
    pattern = os.path.join(DS_PATH, c, 'PET') + '/*.nii'
    pet_paths = np.concatenate((pet_paths, np.array(tf.io.gfile.glob(pattern))))
    # Label every path added in this iteration with the current class index.
    pet_labels = np.concatenate((pet_labels, np.full(len(pet_paths) - len(pet_labels), label, dtype=np.int64)))

# Sort by path so the order does not depend on the order glob returns files in.
idx = np.argsort(pet_paths)
pet_paths, pet_labels = pet_paths[idx], pet_labels[idx]

X_train_pet, X_test_pet, y_train_pet, y_test_pet = train_test_split(pet_paths, pet_labels,
                                                    test_size = 0.2,
                                                    random_state = SEED,
                                                    stratify = pet_labels)

# Path to MRI, grey matter images
mri_grey_paths = np.empty((0,), dtype=str)
mri_grey_labels = np.empty((0,), dtype=np.int64)

for label, c in enumerate(CLASSES):
    pattern = os.path.join(DS_PATH, c, 'MRI/grey') + '/*.nii'
    mri_grey_paths = np.concatenate((mri_grey_paths, np.array(tf.io.gfile.glob(pattern))))
    mri_grey_labels = np.concatenate((mri_grey_labels, np.full(len(mri_grey_paths) - len(mri_grey_labels), label, dtype=np.int64)))

idx = np.argsort(mri_grey_paths)
mri_grey_paths, mri_grey_labels = mri_grey_paths[idx], mri_grey_labels[idx]

X_train_mri, X_test_mri, y_train_mri, y_test_mri = train_test_split(mri_grey_paths, mri_grey_labels,
                                                    test_size = 0.2,
                                                    random_state = SEED,
                                                    stratify = mri_grey_labels)

# Every record stores one PET image, one MRI image and a single label, so the
# two independently built splits must agree class by class at every position.
# NOTE(review): this only checks the classes; that index i refers to the same
# subject in both modalities relies on matching file names - confirm.
if not (np.array_equal(y_train_pet, y_train_mri) and np.array_equal(y_test_pet, y_test_mri)):
    raise ValueError('PET and MRI splits are not aligned: labels differ between modalities')

# One sample per TFRecord file (num_folds equals the number of samples).
generate_tfrecords([X_train_pet, X_train_mri], y_train_pet, OUT_PATH + '/train', 'train',
                   len(X_train_pet), False, False)

# The test records must carry the TEST labels (previously y_train_pet was passed).
generate_tfrecords([X_test_pet, X_test_mri], y_test_pet, OUT_PATH + '/test', 'test',
                   len(X_test_pet), False, False)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
