<a href="https://colab.research.google.com/github/Angelvj/Alzheimer-disease-classification/blob/main/code/check_tfrecords.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports

In [1]:
import numpy as np, os
import tensorflow as tf
import nibabel as nib
from google.colab import drive

# Connect to TPU

In [None]:
# Select the accelerator and build the matching tf.distribute strategy.
# DEVICE: 'TPU', 'GPU', or anything else (e.g. 'CPU') for the default strategy.
DEVICE = 'TPU'

if DEVICE == 'TPU':
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        STRATEGY = tf.distribute.experimental.TPUStrategy(tpu)
    except ValueError:
        # TPUClusterResolver raises ValueError when no TPU can be found.
        print('Could not connect to TPU, setting default strategy')
        tpu = None
        STRATEGY = tf.distribute.get_strategy()
elif DEVICE == 'GPU':
    tpu = None
    STRATEGY = tf.distribute.MirroredStrategy()
else:
    # Any other DEVICE value used to leave STRATEGY undefined, crashing the
    # REPLICAS line below with a NameError; fall back to the default strategy.
    tpu = None
    STRATEGY = tf.distribute.get_strategy()

AUTO = tf.data.experimental.AUTOTUNE
# Number of replicas the strategy runs in sync; get_dataset multiplies the
# batch size by this.
REPLICAS = STRATEGY.num_replicas_in_sync

print(f'Number of accelerators: {REPLICAS}')

In [74]:
# Number of diagnostic classes; get_label encodes NOR, AD and MCI as one-hot
# vectors of this length, and read_tfrecord reshapes labels to it.
NUM_CLASSES = 3
# Static shape every decoded TFRecord image is reshaped to: a 3-D volume plus
# a trailing channel axis (load_image appends that axis at position 3).
IMG_SHAPE = (79, 95, 68, 1)

def load_image(path):
    """Load an image file with nibabel as a float32 array with a channel axis.

    The voxel data is read through ``dataobj`` and cast to ``np.float32``;
    a singleton axis is then inserted at position 3 for the channel.
    NOTE(review): assumes the stored image is a 3-D volume, so the new axis
    is the trailing one.
    """
    volume = np.asarray(nib.load(path).dataobj, dtype=np.float32)
    return np.expand_dims(volume, 3)  # dummy channel axis

def max_intensity_normalization(X, proportion):
    """Divide ``X`` in place by the mean of its largest values.

    The divisor is the mean of the top ``proportion`` fraction of elements of
    ``X`` (computed over the flattened array). At least one and at most all
    elements are used.

    Args:
        X: floating-point ndarray of any shape; modified in place.
        proportion: fraction of elements to average, e.g. 0.01 for the top 1%.

    Returns:
        None. ``X`` is updated in place, which is what preprocess_image
        relies on.
    """
    if X.size == 0:
        return  # nothing to normalize (and no mean to compute)

    n_max_values = int(X.size * proportion)
    # int() truncates, so a small array or proportion can give 0. The old
    # argsort()[-0:] slice then selected *every* element and used the whole
    # array mean. Clamp to [1, X.size].
    n_max_values = min(max(n_max_values, 1), X.size)

    # np.partition (O(n)) places the n_max_values largest values at the end of
    # the flattened copy; a full sort is not needed to average them.
    top_values = np.partition(X, X.size - n_max_values, axis=None)[-n_max_values:]
    X /= np.mean(top_values)

def preprocess_image(X, steps, arguments):
    """Apply each preprocessing step to ``X`` in place, in order.

    Args:
        X: image array; each step is expected to modify it in place.
        steps: callables, each invoked as ``step(X, *args)``.
        arguments: one entry per step. A tuple/list entry is unpacked as that
            step's extra positional arguments. Any other value (e.g. the bare
            float produced by ``(0.01)``) is passed as a single argument.

    Raises:
        ValueError: if ``steps`` and ``arguments`` differ in length (zip would
            otherwise silently skip the unmatched steps).
    """
    steps = list(steps)
    arguments = list(arguments)
    if len(steps) != len(arguments):
        raise ValueError(
            f'Got {len(steps)} preprocessing steps but '
            f'{len(arguments)} argument entries'
        )
    for step, args in zip(steps, arguments):
        # Each step gets its own arguments. The previous ``*arguments`` passed
        # every entry to every step, which only worked for a single step.
        if not isinstance(args, (tuple, list)):
            args = (args,)
        step(X, *args)


def read_tfrecord(example):
    """Decode one serialized TFRecord example.

    The example is parsed with two variable-length float features ("image"
    and "one_hot_label") and one scalar string feature ("filename").

    Returns:
        ``(image, one_hot_label, filename)``. The image is reshaped to
        IMG_SHAPE, the label to ``[NUM_CLASSES]``, and the filename is a
        scalar string tensor.
    """
    features = tf.io.parse_single_example(
        example,
        {
            "image": tf.io.VarLenFeature(tf.float32),
            "one_hot_label": tf.io.VarLenFeature(tf.float32),
            "filename": tf.io.FixedLenFeature([], tf.string),
        },
    )

    # Reshape to static constants: the TPU needs the size to be known, so
    # reshaping by a per-example shape feature, e.g.
    #     tf.reshape(example['image'], example['shape'])
    # does not work here.
    image = tf.reshape(tf.sparse.to_dense(features['image']), IMG_SHAPE)
    one_hot_label = tf.reshape(
        tf.sparse.to_dense(features['one_hot_label']), [NUM_CLASSES]
    )
    return image, one_hot_label, features['filename']

def load_dataset(filenames):
    """Build an unbatched dataset of decoded records from TFRecord files.

    Files are read in parallel and each record is decoded with read_tfrecord.
    ``experimental_deterministic`` is switched off to allow order-altering
    optimizations, so record order is not guaranteed.
    """
    unordered = tf.data.Options()
    unordered.experimental_deterministic = False

    return (
        tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
        .with_options(unordered)
        .map(read_tfrecord, num_parallel_calls=AUTO)
    )

def get_dataset(filenames, batch_size = 8):
    """Return a batched, prefetched dataset over ``filenames``.

    ``batch_size`` is per replica; the actual batch holds
    ``batch_size * REPLICAS`` records.
    """
    global_batch_size = batch_size * REPLICAS
    return load_dataset(filenames).batch(global_batch_size).prefetch(AUTO)

In [None]:
# Mount Google Drive so the data folder below is reachable from Colab.
drive.mount('/content/drive')

DATA_PATH = '/content/drive/My Drive/data/'
# Root of the TFRecord dataset (expects train/ and test/ subfolders);
# may be replaced by a GCS path.
DS_PATH = f'{DATA_PATH}tfrec-20skf-PET-spatialnorm-elastic-wfilenames'

# Check TFRecords

Compare each image read from the TFRecords with the original image file, and check that its label matches the file's class folder.

In [None]:
# Preprocessing configuration for the (currently commented-out)
# preprocess_image call in check_image; steps and args are paired by position.
preprocess_steps = [max_intensity_normalization]
# 0.01 is the `proportion` passed to max_intensity_normalization.
# NOTE(review): (0.01) is just the float 0.01, not a 1-tuple; write (0.01,)
# if each entry is meant to be an argument tuple.
preprocess_args = [(0.01)]

def check_image(path, tfrec_img):
    """Return whether ``tfrec_img`` exactly equals the image stored at ``path``.

    The original file is loaded with load_image, and its NaN/inf values are
    replaced in place with ``np.nan_to_num``. It is then compared
    element-wise against the TFRecord image.
    """
    original = load_image(path)
    np.nan_to_num(original, copy=False)
    # To compare against preprocessed records, enable:
    # preprocess_image(original, preprocess_steps, preprocess_args)
    return np.all(original == tfrec_img)

def get_label(path):
    """Return the one-hot class label implied by the folder names in ``path``.

    Folders are tested in order '/NOR/', '/AD/', '/MCI/', and the first match
    wins. Returns an int array of length 3, or None if no class folder
    appears in the path.
    """
    one_hot_by_folder = (
        ('/NOR/', [1, 0, 0]),
        ('/AD/', [0, 1, 0]),
        ('/MCI/', [0, 0, 1]),
    )
    for folder, one_hot in one_hot_by_folder:
        if folder in path:
            return np.array(one_hot, dtype=int)
    return None


def get_filenames(pattern):
    """Return the paths matching the glob ``pattern`` (local or GCS)."""
    return tf.io.gfile.glob(pattern)


def _load_split(split):
    """Return a batch_size=1 dataset over ``DS_PATH/<split>/*.tfrec``.

    Raises:
        FileNotFoundError: if no file matches. An empty glob would otherwise
            yield an empty dataset and the check would silently "pass".
    """
    files = get_filenames(DS_PATH + '/' + split + '/*.tfrec')
    if not files:
        raise FileNotFoundError(f'No .tfrec files found in {DS_PATH}/{split}')
    return get_dataset(files, batch_size=1)


def check_split(dataset, split_name):
    """Compare every record in ``dataset`` with its source image and label.

    Each mismatch is printed with the record index and source path, followed
    by a one-line summary for the split.

    Returns:
        The number of records whose image or label did not match.
    """
    n_records = 0
    n_mismatches = 0
    for images, labels, filenames in dataset:
        # get_dataset batches batch_size * REPLICAS records, so with more than
        # one replica a batch holds several records; check all of them
        # (the previous loop only looked at element [0]).
        for image, label, raw_name in zip(
            images.numpy(), labels.numpy(), filenames.numpy()
        ):
            filename = raw_name.decode()
            label = label.astype(int)

            image_ok = check_image(filename, image)
            if not image_ok:
                print(f'[{split_name} #{n_records}] Images differ: {filename}')

            label_ok = np.all(label == get_label(filename))
            if not label_ok:
                print(f'[{split_name} #{n_records}] '
                      f'Label and file do not match: {filename}')

            n_records += 1
            n_mismatches += not (image_ok and label_ok)

    print(f'{split_name}: checked {n_records} records, '
          f'{n_mismatches} with mismatches')
    return n_mismatches


pet_train = _load_split('train')
pet_test = _load_split('test')

# Verify both splits (previously pet_test was built but never checked).
for split_name, split_ds in (('train', pet_train), ('test', pet_test)):
    check_split(split_ds, split_name)