<a href="https://colab.research.google.com/github/Angelvj/Alzheimer-disease-classification/blob/main/code/check_tfrecords.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports

In [None]:
import numpy as np, os
import tensorflow as tf
import nibabel as nib

In [None]:
# Kaggle only
from kaggle_datasets import KaggleDatasets
from kaggle_secrets import UserSecretsClient

In [None]:
if os.path.exists('cloned_repo'):
    shutil.rmtree('cloned_repo')
    
!git clone -l -s https://github.com/Angelvj/TFG.git cloned_repo

# Imports from my github repo
from cloned_repo.code.image_reading import *

# Initialize TPU (if present)

In [None]:
# Requested accelerator. Any value other than "TPU" skips TPU initialization
# and uses TensorFlow's default distribution strategy.
DEVICE = "TPU"

# Defined up front so both names exist whichever branch runs below.
tpu = None
strategy = None

if DEVICE == "TPU":
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        strategy = tf.distribute.experimental.TPUStrategy(tpu)
    except ValueError:
        print('Could not connect to TPU')
        tpu = None

# Fallback: no TPU was requested, or connecting to it failed. Previously a
# non-"TPU" DEVICE left `strategy` undefined and the REPLICAS line below crashed.
if strategy is None:
    strategy = tf.distribute.get_strategy()
    print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))

# Lets tf.data pick parallelism / prefetch buffer sizes at runtime.
AUTO     = tf.data.experimental.AUTOTUNE
# Number of replicas in the strategy; used to scale the global batch size.
REPLICAS = strategy.num_replicas_in_sync
print(f'Number of accelerators: {REPLICAS}')

# Access GCS

In [None]:
# Kaggle only
# Obtain a Google Cloud credential through the Kaggle secrets client.
user_secrets = UserSecretsClient()
user_credential = user_secrets.get_gcloud_credential()

# Hand the credential to TensorFlow (presumably so tf.io can read the GCS
# bucket -- confirm against the kaggle_secrets documentation).
user_secrets.set_tensorflow_credential(user_credential)

# GCS location of the 'ad-preprocessed-tfrecords-20skf' Kaggle dataset; the
# dataset-creation cell appends modality/split sub-paths to it.
GCS_DS_PATH = KaggleDatasets().get_gcs_path('ad-preprocessed-tfrecords-20skf')

In [None]:
# Google Colab
# Placeholder: replace the string below with the real GCS path of the dataset.
# Running this cell rebinds GCS_DS_PATH, overriding any value set by the
# Kaggle-only cell.
GCS_DS_PATH = 'set path to GCS here'

# Read images and labels from TFRecords

In [None]:
def read_tfrecord(example):
    """Parse one serialized example into an ``(image, one_hot_label)`` pair.

    Args:
        example: A serialized example, as yielded by ``tf.data.TFRecordDataset``.

    Returns:
        A tuple ``(image, one_hot_label)``: the float32 image values reshaped
        to the 4-element shape stored in the record, and the one-hot label
        as a dense float32 tensor of length 3.
    """
    feature_spec = {
        # The image is stored flattened; its real shape is in the "shape" feature.
        "image": tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.float32, allow_missing=True),
        "label": tf.io.FixedLenFeature([], tf.int64),
        "one_hot_label": tf.io.VarLenFeature(tf.float32),
        "shape": tf.io.FixedLenFeature([4], tf.int64),
        "filename": tf.io.FixedLenFeature([], tf.string),  # Only for test, TODO: delete
    }
    parsed = tf.io.parse_single_example(example, feature_spec)

    # Restore the image to the shape recorded alongside it.
    image = tf.reshape(parsed['image'], parsed['shape'])

    # VarLenFeature parses to a sparse tensor; densify it and pin its length to 3.
    dense_label = tf.sparse.to_dense(parsed['one_hot_label'])
    one_hot_label = tf.reshape(dense_label, [3])

    return image, one_hot_label

def load_dataset(filenames):
    """Build an unbatched dataset of parsed examples from TFRecord files.

    Args:
        filenames: Paths of the TFRecord files to read.

    Returns:
        A ``tf.data.Dataset`` whose elements are the ``(image, one_hot_label)``
        pairs produced by ``read_tfrecord``.
    """
    # Determinism is switched off so tf.data may apply order-altering
    # optimizations while reading and parsing in parallel.
    relaxed_order = tf.data.Options()
    relaxed_order.experimental_deterministic = False

    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    return records.with_options(relaxed_order).map(read_tfrecord, num_parallel_calls=AUTO)

In [None]:
def data_augment(image, one_hot_class):
    """Augmentation hook for a single example; currently a pass-through.

    Image augmentation calls (for instance random flips or saturation
    changes on ``image``) belong here. No transformation is applied yet.

    Args:
        image: The example's image.
        one_hot_class: The example's one-hot label.

    Returns:
        The tuple ``(image, one_hot_class)``, unchanged.
    """
    example = (image, one_hot_class)
    return example

def get_batched_dataset(filenames, batch_size = 4, train=False, augment=True):
    """Build a cached, batched and prefetched dataset from TFRecord files.

    Args:
        filenames: Paths of the TFRecord files to read.
        batch_size: Per-replica batch size; the dataset is batched with
            ``batch_size * REPLICAS`` examples.
        train: If True, the dataset repeats indefinitely and is shuffled.
        augment: If True and ``train`` is True, ``data_augment`` is mapped
            over the examples. Ignored when ``train`` is False.

    Returns:
        A ``tf.data.Dataset`` yielding batches of ``(image, one_hot_label)``.
    """
    dataset = load_dataset(filenames)
    dataset = dataset.cache()  # Only if dataset fits in RAM
    if train:
        dataset = dataset.repeat()
        # Augmentation is applied after cache() so that cached examples stay
        # un-augmented and the mapping is re-evaluated on every repetition.
        # Previously the `augment` flag was accepted but silently ignored.
        if augment:
            dataset = dataset.map(data_augment, num_parallel_calls=AUTO)
        # NOTE(review): the shuffle buffer equals the number of files, which
        # is a full shuffle only if each file holds a single example -- confirm.
        dataset = dataset.shuffle(len(filenames))
    dataset = dataset.batch(batch_size * REPLICAS)
    dataset = dataset.prefetch(AUTO)
    return dataset

# Create datasets (examples)

In [None]:
# GCS_DS_PATH = '/kaggle/input/ad-preprocessed-tfrecords-20skf'

def get_filenames(pattern):
    """Return the file paths matching ``pattern`` via ``tf.io.gfile.glob``."""
    return tf.io.gfile.glob(pattern)


# One train/test dataset pair per modality: PET, MRI grey and MRI white.
# Only the train datasets are built with train=True.
pet_train = get_batched_dataset(
    get_filenames(GCS_DS_PATH + '/PET/train/*.tfrec'),
    train=True,
    batch_size=8,
)
pet_test = get_batched_dataset(
    get_filenames(GCS_DS_PATH + '/PET/test/*.tfrec'),
    train=False,
    batch_size=8,
)
mri_grey_train = get_batched_dataset(
    get_filenames(GCS_DS_PATH + '/MRI/grey/train/*.tfrec'),
    train=True,
    batch_size=8,
)
mri_grey_test = get_batched_dataset(
    get_filenames(GCS_DS_PATH + '/MRI/grey/test/*.tfrec'),
    train=False,
    batch_size=8,
)
mri_white_train = get_batched_dataset(
    get_filenames(GCS_DS_PATH + '/MRI/white/train/*.tfrec'),
    train=True,
    batch_size=8,
)
mri_white_test = get_batched_dataset(
    get_filenames(GCS_DS_PATH + '/MRI/white/test/*.tfrec'),
    train=False,
    batch_size=8,
)