<a href="https://colab.research.google.com/github/Angelvj/Alzheimer-disease-classification/blob/main/code/generate_tfrecords.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook converts the original dataset into TFRecords (which offer better I/O performance, among other advantages).

# Imports

In [21]:
from google.colab import drive
from sklearn.model_selection import StratifiedKFold, KFold, train_test_split
from scipy import ndimage
import os, csv
import numpy as np
import tensorflow as tf
import nibabel as nib
import skimage.transform as transform

# Image loading and preprocessing

In [22]:
def load_image(path, add_axis=True):
    """Read the image stored at `path` with nibabel and return it as a float32 array.

    When `add_axis` is true, an extra axis of length 1 is inserted at
    position 3 to act as the channel axis.
    """
    volume = np.asarray(nib.load(path).dataobj, dtype=np.float32)
    # Position 3 is the trailing axis for a 3-D volume.
    return np.expand_dims(volume, axis=3) if add_axis else volume

def downscale(image, shape):
    """Resize `image` to `shape` with anti-aliasing enabled.

    Meant for shrinking an image; for upscaling, anti_aliasing should be false.
    """
    resize_options = {'mode': 'constant', 'anti_aliasing': True}
    return transform.resize(image, shape, **resize_options)


def standarize(X):
    """Standardize `X` to zero mean and unit variance.

    Mean and standard deviation are computed over every element of `X`.
    If the standard deviation is not positive (e.g. a constant image),
    `X * 0` is returned instead of dividing by it.
    """
    mu, sigma = np.mean(X), np.std(X)
    if not sigma > 0:
        return X * 0
    return (X - mu) / sigma

def max_intensity_normalization(X, proportion):
    """Divide `X` by the mean of its highest-intensity values.

    Parameters
    ----------
    X : np.ndarray
        Image of any shape.
    proportion : float
        Fraction of the elements of `X`, taken from the brightest end, whose
        mean is used as the divisor (e.g. 0.01 for the top 1%).

    Returns
    -------
    np.ndarray
        `X` divided by the mean of its `int(X.size * proportion)` largest
        values (at least one value is always used).

    Notes
    -----
    If the selected values average to 0 the division yields inf/nan, as
    usual for numpy.
    """
    # Always keep at least one value: with a count of 0 the slice `[-0:]`
    # would select the *whole* array and the divisor would silently become
    # the global mean instead of the maximum intensity.
    n_max_values = max(1, int(X.size * proportion))
    # Sorting the flattened array and taking its tail gives the largest values.
    top_values = np.sort(X, axis=None)[-n_max_values:]
    return X / np.mean(top_values)

def minmax(X):
    """Rescale `X` linearly so that its minimum maps to 0 and its maximum to 1.

    A constant input (maximum equal to minimum) has no range to rescale;
    zeros are returned in that case instead of the NaNs produced by 0/0,
    matching what `standarize` does for constant images.

    Parameters
    ----------
    X : np.ndarray
        Image of any shape.

    Returns
    -------
    np.ndarray
        Array of the same shape as `X` with values in [0, 1].
    """
    lowest = np.min(X)
    span = np.max(X) - lowest
    if span == 0:
        # Constant image: X - lowest is already all zeros, so divide by 1
        # rather than by 0.
        span = 1
    return (X - lowest) / span

#### Preprocessing for COVID-19 data ###
def normalize(X, min_value=-1000, max_value=400):
    """Clip `X` to [min_value, max_value] and rescale that window to [0, 1].

    The input array is left untouched; a new array is returned.

    Parameters
    ----------
    X : array-like
        Image intensities.
    min_value, max_value : number, optional
        Bounds of the intensity window. Values below `min_value` map to 0
        and values above `max_value` map to 1. The defaults (-1000, 400)
        are presumably a Hounsfield-unit window for CT scans (TODO confirm).

    Returns
    -------
    np.ndarray
        Array with the shape of `X` and values in [0, 1].
    """
    X = np.asarray(X)
    if not np.issubdtype(X.dtype, np.floating):
        # Integer images may not be able to represent the bounds; the result
        # is floating point anyway because of the division below.
        X = X.astype(np.float64)
    # np.clip returns a new array, so the caller's data is never modified.
    clipped = np.clip(X, min_value, max_value)
    return (clipped - min_value) / (max_value - min_value)

def resize_img(img, shape=(64, 128, 128)):
    """Rotate a 3-D volume by 90 degrees and resample it to `shape`.

    The rotation keeps the array shape (`reshape=False`) and uses scipy's
    default rotation axes. The volume is then zoomed with linear
    interpolation (`order=1`) so that each of its three axes reaches the
    corresponding size in `shape`.
    """
    # Per-axis zoom factor, written as the reciprocal of current/target size.
    zoom_factors = tuple(1 / (img.shape[axis] / shape[axis]) for axis in range(3))
    rotated = ndimage.rotate(img, 90, reshape=False)
    return ndimage.zoom(rotated, zoom_factors, order=1)

# Functions for generating tfrecords

In [24]:
# A TFRecord feature holds one of three kinds of data: bytestrings, integers or floats.
# Values are always stored as lists, so a single data element is a list of size 1.
def _bytestring_feature(list_of_bytestrings):
    """Wrap a list of bytestrings in a `tf.train.Feature`."""
    payload = tf.train.BytesList(value=list_of_bytestrings)
    return tf.train.Feature(bytes_list=payload)

def _float_feature(list_of_floats): # float32
    """Wrap a list of floats in a `tf.train.Feature`."""
    payload = tf.train.FloatList(value=list_of_floats)
    return tf.train.Feature(float_list=payload)

def _int_feature(list_of_ints): # int64
    """Wrap a list of integers in a `tf.train.Feature`."""
    payload = tf.train.Int64List(value=list_of_ints)
    return tf.train.Feature(int64_list=payload)

def to_tfrecord(image, label, num_classes=3):
    """Build a `tf.train.Example` holding one image and its one-hot label.

    Parameters
    ----------
    image : sequence of float
        Flattened image values, stored under the key 'image'.
    label : int
        Class index of the image.
    num_classes : int, optional
        Length of the one-hot vector stored under 'one_hot_label'.
        Defaults to 3, the value that used to be hard-coded.

    Returns
    -------
    tf.train.Example
        Message ready to be serialized into a TFRecord file.
    """
    # Row `label` of the identity matrix is the one-hot encoding of `label`.
    one_hot_label = np.eye(num_classes, dtype=np.float32)[label]

    feature = {
        'image': _float_feature(image),
        'one_hot_label': _float_feature(one_hot_label.tolist())
    }

    # Create a Features message
    return tf.train.Example(features=tf.train.Features(feature=feature))


def to_tfrecord_2(image_pet, image_mri, label, num_classes=3):
    """Build a `tf.train.Example` holding a PET image, an MRI image and their label.

    Parameters
    ----------
    image_pet : sequence of float
        Flattened PET image values, stored under the key 'image_pet'.
    image_mri : sequence of float
        Flattened MRI image values, stored under the key 'image_mri'.
    label : int
        Class index shared by both images.
    num_classes : int, optional
        Length of the one-hot vector stored under 'one_hot_label'.
        Defaults to 3, the value that used to be hard-coded.

    Returns
    -------
    tf.train.Example
        Message ready to be serialized into a TFRecord file.
    """
    # Row `label` of the identity matrix is the one-hot encoding of `label`.
    one_hot_label = np.eye(num_classes, dtype=np.float32)[label]

    feature = {
        'image_pet': _float_feature(image_pet),
        'image_mri': _float_feature(image_mri),
        'one_hot_label': _float_feature(one_hot_label.tolist())
    }

    # Create a Features message
    return tf.train.Example(features=tf.train.Features(feature=feature))


def generate_tfrecords(filenames, labels, dir, tfrec_name, preprocess=None, num_folds=15, stratify=True,
                       shuffle=True, random_state=None, make_summary=True):
    """Split the images into `num_folds` folds and write each fold to its own TFRecord file.

    Parameters
    ----------
    filenames : sequence of str
        Paths of the images; each one is read with `load_image`.
    labels : array-like of int
        Class index of every image, aligned with `filenames`.
    dir : str
        Output directory; created if it does not exist.
    tfrec_name : str
        Prefix of the output files. Fold `n` containing `k` samples is
        written to `{tfrec_name}_{n}-{k}.tfrec`.
    preprocess : callable, optional
        Applied to every loaded image before it is flattened and stored.
    num_folds : int, optional
        Number of folds, i.e. of TFRecord files.
    stratify : bool, optional
        Use `StratifiedKFold` (class proportions kept per fold) instead of `KFold`.
    shuffle : bool, optional
        Whether the splitter shuffles the samples before splitting.
    random_state : int or None, optional
        Seed for the shuffling. Only forwarded when `shuffle` is true,
        because scikit-learn rejects a seed without shuffling.
    make_summary : bool, optional
        Also write `{tfrec_name}_summary.csv` with the number of samples and
        the per-class counts of every TFRecord. The class columns come from
        the module-level `CLASSES`, which must be defined before the call.
    """
    os.makedirs(dir, exist_ok=True)
    # Fancy indexing (`labels[indices]`) below needs an array, not a list.
    labels = np.asarray(labels)

    # `shuffle` and `random_state` are keyword-only in current scikit-learn.
    splitter_cls = StratifiedKFold if stratify else KFold
    kfold = splitter_cls(n_splits=num_folds, shuffle=shuffle,
                         random_state=random_state if shuffle else None)

    summary_file = None
    csv_writer = None
    if make_summary:
        summary_filename = os.path.join(dir, tfrec_name) + '_summary.csv'
        summary_file = open(summary_filename, 'w', encoding='UTF8', newline='')

    try:
        if summary_file is not None:
            csv_writer = csv.writer(summary_file)
            csv_writer.writerow(['tfrec_id', '#samples'] + list(CLASSES))

        # Only the held-out indices of each split are used: they partition
        # the dataset into `num_folds` disjoint groups.
        for n, (_, indices) in enumerate(kfold.split(filenames, labels)):

            name = f'{tfrec_name}_{n}-{len(indices)}.tfrec'

            if csv_writer is not None:
                classes, count = np.unique(labels[indices], return_counts=True)
                # Classes absent from this fold keep a count of 0.
                class_counts = np.zeros(len(CLASSES), dtype=np.int64)
                class_counts[classes] = count
                csv_writer.writerow([name, str(len(indices))] + list(class_counts.astype(str)))

            with tf.io.TFRecordWriter(os.path.join(dir, name)) as writer:
                for index in indices:
                    # NaNs are replaced in place before any preprocessing.
                    img = np.nan_to_num(load_image(filenames[index]), copy=False)
                    if preprocess is not None:
                        img = preprocess(img)
                    example = to_tfrecord(img.ravel(), labels[index])
                    writer.write(example.SerializeToString())
    finally:
        # Close the summary even if writing a record fails, so the rows
        # written so far are flushed to disk.
        if summary_file is not None:
            summary_file.close()


def generate_tfrecords_2(list_of_filenames, labels, dir, tfrec_name, num_folds=15,
                       stratify=True, shuffle=True, random_state=None, make_summary=True):
    """Write paired PET/MRI images into `num_folds` TFRecord files.

    Parameters
    ----------
    list_of_filenames : sequence of two sequences of str
        `list_of_filenames[0]` holds the PET paths and `list_of_filenames[1]`
        the MRI paths; position `i` of both must belong to the same sample.
    labels : array-like of int
        Class index of every PET/MRI pair.
    dir : str
        Output directory; created if it does not exist.
    tfrec_name : str
        Prefix of the output files. Fold `n` containing `k` samples is
        written to `{tfrec_name}_{n}-{k}.tfrec`.
    num_folds : int, optional
        Number of folds, i.e. of TFRecord files.
    stratify : bool, optional
        Use `StratifiedKFold` instead of `KFold`.
    shuffle : bool, optional
        Whether the splitter shuffles the samples before splitting.
    random_state : int or None, optional
        Seed for the shuffling. Only forwarded when `shuffle` is true,
        because scikit-learn rejects a seed without shuffling.
    make_summary : bool, optional
        Also write `{tfrec_name}_summary.csv` with the number of samples and
        the per-class counts of every TFRecord. The class columns come from
        the module-level `CLASSES`.

    Raises
    ------
    ValueError
        If the PET list, the MRI list and `labels` differ in length.

    Notes
    -----
    Images are preprocessed with the module-level functions `preprocess_pet`
    and `preprocess_mri`, which must be defined before the call.
    """
    pet_filenames, mri_filenames = list_of_filenames[0], list_of_filenames[1]
    labels = np.asarray(labels)
    if not (len(pet_filenames) == len(mri_filenames) == len(labels)):
        raise ValueError(
            'PET filenames, MRI filenames and labels must have the same length, got '
            f'{len(pet_filenames)}, {len(mri_filenames)} and {len(labels)}')

    os.makedirs(dir, exist_ok=True)

    # `shuffle` and `random_state` are keyword-only in current scikit-learn.
    splitter_cls = StratifiedKFold if stratify else KFold
    kfold = splitter_cls(n_splits=num_folds, shuffle=shuffle,
                         random_state=random_state if shuffle else None)

    summary_file = None
    csv_writer = None
    if make_summary:
        summary_filename = os.path.join(dir, tfrec_name) + '_summary.csv'
        summary_file = open(summary_filename, 'w', encoding='UTF8', newline='')

    try:
        if summary_file is not None:
            csv_writer = csv.writer(summary_file)
            csv_writer.writerow(['tfrec_id', '#samples'] + list(CLASSES))

        # Only the held-out indices of each split are used: they partition
        # the dataset into `num_folds` disjoint groups.
        for n, (_, indices) in enumerate(kfold.split(pet_filenames, labels)):

            name = f'{tfrec_name}_{n}-{len(indices)}.tfrec'

            if csv_writer is not None:
                classes, count = np.unique(labels[indices], return_counts=True)
                # Classes absent from this fold keep a count of 0.
                class_counts = np.zeros(len(CLASSES), dtype=np.int64)
                class_counts[classes] = count
                csv_writer.writerow([name, str(len(indices))] + list(class_counts.astype(str)))

            with tf.io.TFRecordWriter(os.path.join(dir, name)) as writer:
                for index in indices:
                    # NaNs are replaced in place before any preprocessing.
                    img_pet = np.nan_to_num(load_image(pet_filenames[index]), copy=False)
                    img_mri = np.nan_to_num(load_image(mri_filenames[index]), copy=False)

                    img_pet = preprocess_pet(img_pet)
                    img_mri = preprocess_mri(img_mri)

                    example = to_tfrecord_2(img_pet.ravel(), img_mri.ravel(), labels[index])
                    writer.write(example.SerializeToString())
    finally:
        # Close the summary even if writing a record fails, so the rows
        # written so far are flushed to disk.
        if summary_file is not None:
            summary_file.close()

# Configure where to save tfrecords

In [5]:
# Mount Google Drive in the Colab runtime; the datasets are read from and written to it.
drive.mount('/content/drive')
# Root folder on the mounted Drive: input datasets and generated TFRecords live under it.
DATA_PATH = '/content/drive/MyDrive/data/'

Mounted at /content/drive


# Preprocessed images

In [10]:
# Name of the input dataset folder inside DATA_PATH
DS = 'ad-preprocessed'
DS_PATH = DATA_PATH + DS
# The position of a class in this list is the integer label given to its images
CLASSES = ['NOR', 'AD', 'MCI'] # Classes in the dataset
# Used as random_state for the train/test splits
SEED = 156 # Arbitrary seed

## PET

In [None]:
# Collect the PET image paths of every class, labelling each with its class index
pet_paths = np.empty((0,), dtype=str)
pet_labels = np.empty((0,), dtype=np.int64)

for label, c in enumerate(CLASSES):
    pattern = os.path.join(DS_PATH, c, 'PET') + '/*.nii'
    class_paths = np.array(tf.io.gfile.glob(pattern))
    class_labels = np.full(len(class_paths), label, dtype=np.int64)
    pet_paths = np.concatenate((pet_paths, class_paths))
    pet_labels = np.concatenate((pet_labels, class_labels))

# Sort by path before splitting (presumably so the seeded split does not depend on glob order)
sort_order = np.argsort(pet_paths)
pet_paths = pet_paths[sort_order]
pet_labels = pet_labels[sort_order]

# Stratified 80/20 train/test split
split_options = dict(test_size=0.2, random_state=SEED, stratify=pet_labels)
X_train, X_test, y_train, y_test = train_test_split(pet_paths, pet_labels, **split_options)

OUT_DS = 'tfrec-pet-preprocessed'
OUT_PATH = DATA_PATH + OUT_DS

def preprocess(image):
    """Scale the image by the mean of its brightest 1% of values."""
    return max_intensity_normalization(image, 0.01)

# One fold per image, so every TFRecord file holds a single sample
generate_tfrecords(X_train, y_train, OUT_PATH + '/train', 'train', preprocess,
                   num_folds=len(X_train), stratify=False, shuffle=False)

generate_tfrecords(X_test, y_test, OUT_PATH + '/test', 'test', preprocess,
                   num_folds=len(X_test), stratify=False, shuffle=False)

## MRI

In [None]:
# Collect the grey-matter MRI paths of every class, labelling each with its class index
mri_grey_paths = np.empty((0,), dtype=str)
mri_grey_labels = np.empty((0,), dtype=np.int64)

for label, c in enumerate(CLASSES):
    pattern = os.path.join(DS_PATH, c, 'MRI/grey') + '/*.nii'
    class_paths = np.array(tf.io.gfile.glob(pattern))
    class_labels = np.full(len(class_paths), label, dtype=np.int64)
    mri_grey_paths = np.concatenate((mri_grey_paths, class_paths))
    mri_grey_labels = np.concatenate((mri_grey_labels, class_labels))

# Sort by path before splitting (presumably so the seeded split does not depend on glob order)
sort_order = np.argsort(mri_grey_paths)
mri_grey_paths = mri_grey_paths[sort_order]
mri_grey_labels = mri_grey_labels[sort_order]

# Stratified 80/20 train/test split
split_options = dict(test_size=0.2, random_state=SEED, stratify=mri_grey_labels)
X_train, X_test, y_train, y_test = train_test_split(mri_grey_paths, mri_grey_labels,
                                                    **split_options)

# --- Variant 1: downscaled and standardized ---
OUT_DS = 'tfrec-mri-preprocessed-downscale'
OUT_PATH = DATA_PATH + OUT_DS

def preprocess(image):
    """Downscale to (75, 90, 75, 1), then standardize."""
    return standarize(downscale(image, (75, 90, 75, 1)))

# One fold per image, so every TFRecord file holds a single sample
generate_tfrecords(X_train, y_train, OUT_PATH + '/train', 'train', preprocess,
                   num_folds=len(X_train), stratify=False, shuffle=False, random_state=None)

generate_tfrecords(X_test, y_test, OUT_PATH + '/test', 'test', preprocess,
                   num_folds=len(X_test), stratify=False, shuffle=False)


# --- Variant 2: original resolution, standardized ---
OUT_DS = 'tfrec-mri-preprocessed'
OUT_PATH = DATA_PATH + OUT_DS

def preprocess(image):
    """Standardize the image at its original resolution."""
    return standarize(image)

generate_tfrecords(X_train, y_train, OUT_PATH + '/train', 'train', preprocess,
                   num_folds=len(X_train), stratify=False, shuffle=False, random_state=None)

generate_tfrecords(X_test, y_test, OUT_PATH + '/test', 'test', preprocess,
                   num_folds=len(X_test), stratify=False, shuffle=False)

# Non preprocessed images (PET)

In [None]:
DS = 'ad-raw'
DS_PATH = DATA_PATH + DS
CLASSES = ['NOR', 'AD', 'MCI'] # Classes in the dataset
SEED = 156 # Arbitrary seed

# Path to images
pet_paths = np.empty((0,), dtype=str)
pet_labels = np.empty((0,), dtype=np.int64)

for label, c in enumerate(CLASSES):
    pattern = os.path.join(DS_PATH, c, 'PET') + '/*.nii'
    pet_paths = np.concatenate((pet_paths, np.array(tf.io.gfile.glob(pattern))))
    pet_labels = np.concatenate((pet_labels, np.full(len(pet_paths) - len(pet_labels), label, dtype=np.int64)))

# Sort by path, as the other dataset cells do, so that the seeded split below
# does not depend on the order in which glob happens to list the files.
idx = np.argsort(pet_paths)
pet_paths, pet_labels = pet_paths[idx], pet_labels[idx]

X_train, X_test, y_train, y_test = train_test_split(pet_paths, pet_labels,
                                                    test_size = 0.2,
                                                    random_state = SEED,
                                                    stratify = pet_labels)

# Every variant below writes one TFRecord per image (num_folds == number of images).

# Raw
OUT_DS = 'tfrec-pet-raw'
OUT_PATH = DATA_PATH + OUT_DS

generate_tfrecords(X_train, y_train, OUT_PATH + '/train', 'train', None,
                   num_folds=len(X_train), stratify=False, shuffle=False)

generate_tfrecords(X_test, y_test, OUT_PATH + '/test', 'test', None,
                   num_folds=len(X_test), stratify=False, shuffle=False)

# Standarized
OUT_DS = 'tfrec-pet-raw-standarized'
OUT_PATH = DATA_PATH + OUT_DS

def preprocess(image):
    """Standardize the image to zero mean and unit variance."""
    image = standarize(image)
    return image

generate_tfrecords(X_train, y_train, OUT_PATH + '/train', 'train', preprocess,
                   num_folds=len(X_train), stratify=False, shuffle=False)

generate_tfrecords(X_test, y_test, OUT_PATH + '/test', 'test', preprocess,
                   num_folds=len(X_test), stratify=False, shuffle=False)

# Minmax
OUT_DS = 'tfrec-pet-raw-minmax'
OUT_PATH = DATA_PATH + OUT_DS

def preprocess(image):
    """Rescale the image to the [0, 1] range."""
    image = minmax(image)
    return image

generate_tfrecords(X_train, y_train, OUT_PATH + '/train', 'train', preprocess,
                   num_folds=len(X_train), stratify=False, shuffle=False)

generate_tfrecords(X_test, y_test, OUT_PATH + '/test', 'test', preprocess,
                   num_folds=len(X_test), stratify=False, shuffle=False)

# Minmax + standarized
OUT_DS = 'tfrec-pet-raw-minmax-standarized'
OUT_PATH = DATA_PATH + OUT_DS

def preprocess(image):
    """Rescale the image to [0, 1], then standardize it."""
    image = minmax(image)
    image = standarize(image)
    return image

# This variant was configured but never written: its settings were overwritten
# by the next variant before any call to generate_tfrecords.
generate_tfrecords(X_train, y_train, OUT_PATH + '/train', 'train', preprocess,
                   num_folds=len(X_train), stratify=False, shuffle=False)

generate_tfrecords(X_test, y_test, OUT_PATH + '/test', 'test', preprocess,
                   num_folds=len(X_test), stratify=False, shuffle=False)

# Maxintensity
OUT_DS = 'tfrec-pet-raw-maxintensity'
OUT_PATH = DATA_PATH + OUT_DS

def preprocess(image):
    """Scale the image by the mean of its brightest 1% of values."""
    image = max_intensity_normalization(image, 0.01)
    return image

generate_tfrecords(X_train, y_train, OUT_PATH + '/train', 'train', preprocess,
                   num_folds=len(X_train), stratify=False, shuffle=False)

generate_tfrecords(X_test, y_test, OUT_PATH + '/test', 'test', preprocess,
                   num_folds=len(X_test), stratify=False, shuffle=False)

# COVID-19 images

In [None]:
DS = 'COVID19'
DS_PATH = DATA_PATH + DS
CLASSES = ['normal', 'covid'] # Classes in the dataset
SEED = 156 # Arbitrary seed

# Collect the image paths of every class, labelling each with its class index
covid_paths = np.empty((0,), dtype=str)
covid_labels = np.empty((0,), dtype=np.int64)

for label, c in enumerate(CLASSES):
    pattern = os.path.join(DS_PATH, c) + '/*.nii.gz'
    class_paths = np.array(tf.io.gfile.glob(pattern))
    class_labels = np.full(len(class_paths), label, dtype=np.int64)
    covid_paths = np.concatenate((covid_paths, class_paths))
    covid_labels = np.concatenate((covid_labels, class_labels))


OUT_DS = 'tfrec-covid19'
OUT_PATH = DATA_PATH + OUT_DS

def preprocess(image):
    """Normalize and resize a scan, returning it with an axis re-added at position 3."""
    # Drop the length-1 axes (load_image adds one) so resize_img gets a plain volume
    volume = np.squeeze(image)
    volume = resize_img(normalize(volume))
    return np.expand_dims(volume, axis=3)

# No train/test split here: the whole dataset is written, one TFRecord per image
generate_tfrecords(covid_paths, covid_labels, OUT_PATH, 'covid_dataset', preprocess,
                   num_folds=len(covid_paths), stratify=False, shuffle=False)

# MRI + PET

In [None]:
DS = 'ad-preprocessed'
DS_PATH = DATA_PATH + DS
CLASSES = ['NOR', 'AD', 'MCI'] # Classes in the dataset
SEED = 156 # Arbitrary seed

# Path to PET images
pet_paths = np.empty((0,), dtype=str)
pet_labels = np.empty((0,), dtype=np.int64)


for label, c in enumerate(CLASSES):
    pattern = os.path.join(DS_PATH, c, 'PET') + '/*.nii'
    pet_paths = np.concatenate((pet_paths, np.array(tf.io.gfile.glob(pattern))))
    pet_labels = np.concatenate((pet_labels, np.full(len(pet_paths) - len(pet_labels), label, dtype=np.int64)))

idx = np.argsort(pet_paths)
pet_paths, pet_labels = pet_paths[idx], pet_labels[idx]

# Path to MRI, grey matter images
mri_grey_paths = np.empty((0,), dtype=str)
mri_grey_labels = np.empty((0,), dtype=np.int64)

for label, c in enumerate(CLASSES):
    pattern = os.path.join(DS_PATH, c, 'MRI/grey') + '/*.nii'
    mri_grey_paths = np.concatenate((mri_grey_paths, np.array(tf.io.gfile.glob(pattern))))
    mri_grey_labels = np.concatenate((mri_grey_labels, np.full(len(mri_grey_paths) - len(mri_grey_labels), label, dtype=np.int64)))

idx = np.argsort(mri_grey_paths)
mri_grey_paths, mri_grey_labels = mri_grey_paths[idx], mri_grey_labels[idx]

# The PET and MRI images are paired by position after sorting. Nothing below
# can work if the two lists do not even agree on the labels, so fail early.
# NOTE(review): this checks the labels only; that position i of both lists is
# the same subject still relies on the file naming - confirm.
if not np.array_equal(pet_labels, mri_grey_labels):
    raise ValueError('PET and MRI image lists do not match: their sorted labels differ')

# Both splits use the same seed and stratify on identical label arrays, so
# they are expected to select the same positions for PET and MRI.
X_train_pet, X_test_pet, y_train_pet, y_test_pet = train_test_split(pet_paths, pet_labels,
                                                    test_size = 0.2,
                                                    random_state = SEED,
                                                    stratify = pet_labels)

X_train_mri, X_test_mri, y_train_mri, y_test_mri = train_test_split(mri_grey_paths, mri_grey_labels,
                                                    test_size = 0.2,
                                                    random_state = SEED,
                                                    stratify = mri_grey_labels)


OUT_DS = 'tfrec-pet-mri'
OUT_PATH = DATA_PATH + OUT_DS

def preprocess_pet(image):
    """Scale the PET image by the mean of its brightest 1% of values."""
    image = max_intensity_normalization(image, 0.01)
    return image

def preprocess_mri(image):
    """Standardize the MRI image to zero mean and unit variance."""
    image = standarize(image)
    return image

# One fold per image pair, so every TFRecord file holds a single sample
generate_tfrecords_2([X_train_pet, X_train_mri], y_train_pet, OUT_PATH + '/train', 'train',
                   num_folds=len(X_train_pet), stratify=False, shuffle=False)

# The test records must carry the test labels (this call used to pass y_train_pet)
generate_tfrecords_2([X_test_pet, X_test_mri], y_test_pet, OUT_PATH + '/test', 'test',
                   num_folds=len(X_test_pet), stratify=False, shuffle=False)