<a href="https://colab.research.google.com/github/Angelvj/Alzheimer-disease-classification/blob/main/code/generate_covid19_tfrecords.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive

In [2]:
import sys
import numpy as np, os, shutil, math
import tensorflow as tf, csv
from sklearn.model_selection import StratifiedKFold, KFold, train_test_split
import nibabel as nib
import skimage.transform as transform
import matplotlib.pyplot as plt
from scipy import ndimage

In [3]:
def load_image(path, add_axis=True):
    """Read an image file with nibabel and return its voxel data as float32.

    Args:
        path: file path understood by ``nib.load``.
        add_axis: when true, insert a new axis at position 3 (a trailing
            channel axis for a 3D volume).

    Returns:
        ``np.ndarray`` of dtype float32.
    """
    volume = np.asarray(nib.load(path).dataobj, dtype=np.float32)
    return np.expand_dims(volume, axis=3) if add_axis else volume

def normalize(img, low=-1000, high=400):
    """Clip intensities to ``[low, high]`` and rescale them linearly to ``[0, 1]``.

    The default window (-1000, 400) is the one this notebook has always used;
    for these CT scans it presumably corresponds to Hounsfield units
    (NOTE(review): confirm against the data source).

    Args:
        img: NumPy array of intensities. It is not modified; a new array is
            returned.
        low: lower clipping bound, mapped to 0.
        high: upper clipping bound, mapped to 1.

    Returns:
        The clipped, rescaled array. NaN values pass through unchanged.

    Raises:
        ValueError: if ``high <= low`` (the rescale would divide by zero or
            invert the range).
    """
    if high <= low:
        raise ValueError(f'high ({high}) must be greater than low ({low})')
    # np.clip allocates a new array, so the caller's data is never mutated.
    clipped = np.clip(img, low, high)
    return (clipped - low) / (high - low)
    
def resize_img(img, shape=(64, 128, 128)):
    """Rotate a volume by 90 degrees, then linearly resample it to ``shape``.

    The rotation uses scipy's default rotation plane and keeps the input's
    array shape (``reshape=False``). The zoom factor for each of the first
    three axes is ``1 / (img.shape[i] / shape[i])``.

    Args:
        img: 3D array to resample.
        shape: target size of axes 0, 1 and 2.

    Returns:
        The rotated, resampled array (linear interpolation, ``order=1``).
    """
    zoom_factors = tuple(1 / (img.shape[axis] / shape[axis]) for axis in range(3))
    rotated = ndimage.rotate(img, 90, reshape=False)
    return ndimage.zoom(rotated, zoom_factors, order=1)

def process(path):
    """Load one scan and run the full preprocessing pipeline on it.

    Reads the volume without a channel axis, applies ``normalize`` and then
    ``resize_img``, and finally appends a trailing channel axis.
    """
    volume = load_image(path, add_axis=False)
    volume = resize_img(normalize(volume))
    return volume[..., np.newaxis]

In [4]:
# A TFRecord feature can hold three types of data: bytestrings, integers and floats.
# Values are always stored as lists, so a single data element is a list of size 1.
def _bytestring_feature(list_of_bytestrings):
    """Wrap a list of byte strings in a ``tf.train.Feature`` (bytes_list)."""
    values = tf.train.BytesList(value=list_of_bytestrings)
    return tf.train.Feature(bytes_list=values)

def _float_feature(list_of_floats):
    """Wrap a sequence of floats in a ``tf.train.Feature`` (float_list, float32)."""
    values = tf.train.FloatList(value=list_of_floats)
    return tf.train.Feature(float_list=values)

def _int_feature(list_of_ints):
    """Wrap a sequence of integers in a ``tf.train.Feature`` (int64_list)."""
    values = tf.train.Int64List(value=list_of_ints)
    return tf.train.Feature(int64_list=values)

def to_tfrecord(image, label, num_classes=None):
    """Build a ``tf.train.Example`` holding one flattened image and its one-hot label.

    Args:
        image: flat sequence of floats, stored under the ``'image'`` key.
        label: integer class index in ``[0, num_classes)``.
        num_classes: width of the one-hot label vector. Defaults to
            ``len(CLASSES)`` so the encoding always matches the configured
            class list.

    Returns:
        ``tf.train.Example`` with float features ``'image'`` and
        ``'one_hot_label'``.

    Raises:
        ValueError: if ``label`` is outside ``[0, num_classes)``. Without this
            check, a negative label would silently select a class from the end.
    """
    # The width must follow the class list. A hard-coded 3 adds an always-zero
    # column for a two-class dataset.
    if num_classes is None:
        num_classes = len(CLASSES)
    if not 0 <= label < num_classes:
        raise ValueError(f'label {label} is out of range for {num_classes} classes')

    one_hot_label = np.eye(num_classes, dtype=np.float32)[label]

    feature = {
        'image': _float_feature(image),
        'one_hot_label': _float_feature(one_hot_label.tolist())
    }

    # Create a Features message
    return tf.train.Example(features=tf.train.Features(feature=feature))

In [5]:

def generate_tfrecords(filenames, labels, dir, tfrec_name, num_folds=15,
                       stratify=True, shuffle=True, random_state=None, make_summary=True):
    """Split images into ``num_folds`` folds and write one TFRecord file per fold.

    Each fold is written to ``<dir>/<tfrec_name>_<fold>-<n_samples>.tfrec``.
    Every image is loaded with ``process``, NaNs are replaced with
    ``np.nan_to_num``, and the flattened volume is serialized with
    ``to_tfrecord``.

    Args:
        filenames: indexable sequence of image paths.
        labels: integer class indices aligned with ``filenames``, each in
            ``[0, len(CLASSES))``. Lists are accepted; they are converted to a
            NumPy array so they can be indexed per fold.
        dir: output directory, created if missing.
        tfrec_name: prefix for the record files and the summary CSV.
        num_folds: number of folds, i.e. number of TFRecord files written.
        stratify: use ``StratifiedKFold`` instead of ``KFold``.
        shuffle: shuffle the samples before splitting into folds.
        random_state: seed for the shuffle (only meaningful with ``shuffle``).
        make_summary: also write ``<dir>/<tfrec_name>_summary.csv`` with the
            sample count and per-class counts of every record file.
    """
    labels = np.asarray(labels)
    os.makedirs(dir, exist_ok=True)

    # shuffle and random_state are keyword-only in current scikit-learn;
    # passing them positionally raises TypeError.
    splitter_cls = StratifiedKFold if stratify else KFold
    kfold = splitter_cls(n_splits=num_folds, shuffle=shuffle, random_state=random_state)

    # Open the summary once and make sure it is closed (and flushed) even if
    # processing an image raises part-way through.
    summary_file = None
    if make_summary:
        summary_filename = os.path.join(dir, tfrec_name) + '_summary.csv'
        summary_file = open(summary_filename, 'w', encoding='UTF8', newline='')
    try:
        csv_writer = None
        if summary_file is not None:
            csv_writer = csv.writer(summary_file)
            csv_writer.writerow(['tfrec_id', '#samples'] + list(CLASSES))

        for n, (_, indices) in enumerate(kfold.split(filenames, labels)):
            name = f'{tfrec_name}_{n}-{len(indices)}.tfrec'

            with tf.io.TFRecordWriter(os.path.join(dir, name)) as writer:
                for index in indices:
                    img = np.nan_to_num(process(filenames[index]), copy=False)
                    example = to_tfrecord(img.ravel(), labels[index])
                    writer.write(example.SerializeToString())

            # Summarize a fold only after its record was written successfully.
            if csv_writer is not None:
                classes, count = np.unique(labels[indices], return_counts=True)
                class_counts = np.zeros(len(CLASSES), dtype=np.int64)
                class_counts[classes] = count
                csv_writer.writerow([name, str(len(indices))]
                                    + list(class_counts.astype(str)))
    finally:
        if summary_file is not None:
            summary_file.close()

In [6]:
# Class folder names inside the dataset directory; list position = integer label.
CLASSES = ['normal', 'covid']
SEED = 27  # random_state for the stratified train/test split

drive.mount('/content/drive')
DATA_PATH = '/content/drive/My Drive/data/'

Mounted at /content/drive


In [9]:
DS = 'COVID19'
DS_PATH = DATA_PATH + DS

# Collect scan paths per class; a class's index in CLASSES is its integer label.
# Glob makes no ordering guarantee, so results are sorted to keep the seeded
# train/test split below reproducible across runs and machines.
path_chunks = []
label_chunks = []
for label, class_name in enumerate(CLASSES):
    pattern = os.path.join(DS_PATH, class_name) + '/*.nii.gz'
    class_paths = sorted(tf.io.gfile.glob(pattern))
    assert class_paths, f'no scans matched {pattern}'
    path_chunks.extend(class_paths)
    label_chunks.extend([label] * len(class_paths))

ct_paths = np.array(path_chunks, dtype=str)
ct_labels = np.array(label_chunks, dtype=np.int64)

X_train, X_test, y_train, y_test = train_test_split(ct_paths, ct_labels, test_size=0.2,
                                                    random_state=SEED, stratify=ct_labels)

In [None]:
OUT_DS = 'tfrec-covid19'
OUT_PATH = DATA_PATH + OUT_DS

# One TFRecord per scan: num_folds equals the split size and KFold runs without
# shuffling, so every fold holds exactly one sample, in input order.
for split_name, split_paths, split_labels in (('train', X_train, y_train),
                                              ('test', X_test, y_test)):
    generate_tfrecords(split_paths, split_labels, OUT_PATH + '/' + split_name, split_name,
                       num_folds=len(split_paths), stratify=False, shuffle=False)