In [None]:
import os
import datetime
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, Input
import tensorboard
from tensorboard.plugins.hparams import api as hp
from src.preprocessing.image_gen import MultiTaskImageGen2, BalanceImageGenerator


In [None]:
# Locate the Cloud TPU by name/zone/project, attach this process to it,
# initialise the TPU system, and build the distribution strategy that later
# cells use via `strategy.scope()`.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
    tpu='fyp',
    zone='europe-west4-a',
    project='eeefyp',
)
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)


In [None]:
# Run identifier; embedded in LOGDIR so every run logs to its own directory.
TIMESTAMP = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
AUTOTUNE = tf.data.experimental.AUTOTUNE

# `__file__` is only defined when this code runs as a script/module. Inside a
# Jupyter kernel it does not exist and the original expression raised
# NameError, so fall back to the kernel's working directory there.
# NOTE(review): assumes the kernel is started in the notebook's own directory
# -- confirm for your deployment.
try:
    _BASE_DIR = os.path.dirname(os.path.abspath(__file__))
except NameError:
    _BASE_DIR = os.getcwd()

DATADIR = os.path.join(_BASE_DIR, 'data/clean')
LOGDIR = os.path.join(_BASE_DIR, f'logs/norm_exp/{TIMESTAMP}')
BATCH_SIZE = 16
CLASS_LABELS = ['FAW', 'zinc_def', 'healthy']
EPOCHS = 20

In [None]:
# Per-record feature schema handed to MultiTaskImageGen2 when opening each
# TFRecord file: three single-element int64 fields, one single-element string
# field for the image, and a variable-length float32 field for the labels.
# NOTE(review): presumably consumed by a tf.io example parser inside
# MultiTaskImageGen2 -- confirm against src.preprocessing.image_gen.
feature_description = {
    **{
        int_field: tf.io.FixedLenFeature([1], tf.int64)
        for int_field in ('rows', 'cols', 'channels')
    },
    'image': tf.io.FixedLenFeature([1], tf.string),
    'labels': tf.io.VarLenFeature(tf.float32),
}

In [None]:
def _load_class_records(record_path):
    """Open one class's TFRecord file and collect its splits.

    Builds a ``MultiTaskImageGen2`` for ``DATADIR/record_path`` using the
    shared ``feature_description`` and returns
    ``(generator, first_split, second_split, train_images)``: the two values
    produced by ``split_dataset()`` in the order it yields them, followed by
    the result of ``get_train_img()``.
    """
    generator = MultiTaskImageGen2(os.path.join(DATADIR, record_path), feature_description)
    first_split, second_split = generator.split_dataset()
    return generator, first_split, second_split, generator.get_train_img()


# The two values from split_dataset() are bound as (test, val) for each class.
# NOTE(review): confirm that return order against MultiTaskImageGen2.
# `img_ds` accumulates the training images of all classes via .concatenate(),
# in the order FAW -> healthy -> zinc deficiency.
ds_faw, test_faw, val_faw, img_ds = _load_class_records('final/faw.tfrecord')

ds_healthy, test_healthy, val_healthy, train_healthy = _load_class_records('final/healthy.tfrecord')
img_ds = img_ds.concatenate(train_healthy)

ds_zinc, test_zinc, val_zinc, train_zinc = _load_class_records('final/zinc_def.tfrecord')
img_ds = img_ds.concatenate(train_zinc)

# ds_nlb = MultiTaskImageGen(os.path.join(DATADIR, 'final/NLB/nlb'), 256, CLASS_LABELS)
# ds_nlb_h = MultiTaskImageGen(os.path.join(DATADIR, 'final/NLB/nlb'), 256, CLASS_LABELS)

In [None]:
# Test set: the three per-class test splits joined end to end, shuffled with
# a 1000-element buffer, then batched at twice the training batch size.
test = (
    test_faw
    .concatenate(test_healthy)
    .concatenate(test_zinc)
    .shuffle(1000)
    .batch(2*BATCH_SIZE)
)

# Validation set: same per-class concatenation and batch size, no shuffling.
val = (
    val_faw
    .concatenate(val_healthy)
    .concatenate(val_zinc)
    .batch(2*BATCH_SIZE)
)


In [None]:
# Number of entries in the healthy-class directory; used as the per-class
# sample count when sizing an epoch.
num_healthy = len(os.listdir(os.path.join(DATADIR, 'final/healthy')))

# NOTE(review): 3 * 0.8 * 0.8 looks like (number of classes) x (two successive
# 80% train splits) -- confirm against split_dataset() and the balancing logic.
# np.ceil returns a NumPy float; Keras documents `steps_per_epoch` as an
# integer, so convert explicitly instead of passing a float to model.fit.
STEPS_PER_EPOCH = int(np.ceil(3*0.8*0.8*num_healthy/BATCH_SIZE))
print(num_healthy, STEPS_PER_EPOCH)


In [None]:
# Build the balanced training stream: each per-class generator is called to
# obtain its dataset, the three are handed to BalanceImageGenerator together
# with the batch size, and calling the resulting object yields `balance_ds`.
# NOTE(review): the exact balancing/repeat semantics live in
# src.preprocessing.image_gen -- confirm there.
class_balancer = BalanceImageGenerator(BATCH_SIZE, ds_faw(), ds_healthy(), ds_zinc())
balance_ds = class_balancer()


In [None]:
# Keras objects that own variables (metrics allocate theirs at construction,
# the optimizer on first use) must be created under the same distribution
# strategy scope as the model they are compiled with. The model is built under
# `strategy.scope()`, so create these under that scope as well; otherwise
# compile can reject them as belonging to a different strategy.
with strategy.scope():
    initializer = tf.keras.initializers.he_normal()
    loss = tf.keras.losses.BinaryCrossentropy()
    optimizer = tf.keras.optimizers.Adam()
    METRICS = [tf.keras.metrics.BinaryAccuracy(name='acc'),
               tf.keras.metrics.Precision(name='psn'),
               tf.keras.metrics.Recall(name='rcl'),
               tf.keras.metrics.AUC(name='AUC')]

In [None]:
# Build the model under the TPU strategy scope so its variables are placed by
# the strategy.
# NOTE(review): `make_model` is neither defined nor imported in the visible
# part of this notebook -- confirm where it comes from, or this cell fails on
# a fresh kernel.
with strategy.scope():
    model = make_model(
        (256, 256, 3),
        METRICS,
        optimizer,
        loss,
        initializer,
    )
    

In [None]:
# Training callbacks: TensorBoard logging (with per-epoch histograms) into
# LOGDIR, and early stopping after 5 epochs without improvement, restoring the
# best weights seen.
training_callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir=LOGDIR, histogram_freq=1),
    tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True),
]

# NOTE(review): the epoch cap is hard-coded to 100 here while the config cell
# defines EPOCHS = 20, which is otherwise unused -- confirm which is intended.
model.fit(
    balance_ds,
    epochs=100,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_data=val,
    callbacks=training_callbacks,
)

In [None]:
# Score the trained model on the held-out test set, logging to the same
# TensorBoard directory as training. Kept as the cell's last expression so the
# returned metric values are displayed.
evaluation_logger = tf.keras.callbacks.TensorBoard(log_dir=LOGDIR)
model.evaluate(test, callbacks=[evaluation_logger])