# Dementia Template - Dummy Data + 3 Towers

Dementia template with dummy data using three autoencoder towers.

In [5]:
import os, numpy as np, pandas as pd
from tensorflow import losses, optimizers, metrics
from tensorflow.keras import Input, Model, layers, callbacks, regularizers
from jarvis.train import custom, datasets, params
from jarvis.train.client import Client
from jarvis.utils.general import gpus, overload, tools as jtools

import tensorflow as tf
import tensorflow.keras.backend as K

In [6]:
class CustomModel(Model):
    """Keras ``Model`` whose per-batch training logic is written out explicitly.

    The step uses the loss, optimizer and metrics configured in ``compile()``,
    so the model is still trained through the regular ``fit()`` API.
    """

    def train_step(self, data):
        """Run one optimisation step on a single batch.

        Arguments:
            data: The batch handed over by ``fit()``: either ``(x, y)`` or
                ``(x, y, sample_weight)``. ``x`` and ``y`` are passed through
                untouched, so they may be tensors or nested structures
                (e.g. dicts of tensors).

        Returns:
            Dict mapping the name of every entry in ``self.metrics`` to its
            current value.
        """
        # Unpack the data. Its structure depends on the model and on what is
        # passed to `fit()`.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            x, y = data
            sample_weight = None

        # NOTE: no shape printing here. The previous debug print referenced
        # `sys` (never imported) and `x.shape` / `y.shape`, which do not exist
        # when the inputs/targets are dicts.
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # The loss function is the one configured in `compile()`.
            loss_value = self.compiled_loss(
                y,
                y_pred,
                sample_weight=sample_weight,
                regularization_losses=self.losses,
            )

        # Compute gradients and update the weights.
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss_value, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, y_pred, sample_weight=sample_weight)

        # Report every tracked metric by name.
        return {m.name: m.result() for m in self.metrics}


In [7]:
def get_inputs():
    """Build the symbolic Keras input for one dummy volume.

    Returns:
        Dict with the single key ``"dat"`` holding an ``Input`` layer named
        ``'dat'`` with per-sample shape ``(96, 160, 160, 1)`` and a fixed
        batch size of 3.
    """
    # Only the shape is needed to declare an Input, so no dummy arrays are
    # allocated and no throw-away Input layers are created.
    volume_shape = (96, 160, 160, 1)
    return {"dat": Input(volume_shape, batch_size=3, name='dat')}

In [8]:
def cosine_similarity(vects):
    """Compute a negated, averaged cosine-style score between two tensors.

    Both tensors are L2-normalised along axis 1; their element-wise product
    is averaged over axis 1 and the result is negated. Because the product is
    averaged rather than summed, the value is the negative cosine similarity
    divided by the size of axis 1 (for rank-2 inputs), not the raw cosine
    similarity.

    Arguments:
        vects: Sequence of exactly two tensors that can be multiplied
            element-wise.

    Returns:
        Tensor in which axis 1 has been reduced to size 1 (``keepdims=True``).
    """
    first, second = vects

    unit_first = tf.math.l2_normalize(first, axis=1)
    unit_second = tf.math.l2_normalize(second, axis=1)

    mean_product = tf.math.reduce_mean(
        unit_first * unit_second, axis=1, keepdims=True
    )
    return -mean_product

In [9]:
def euclidean_distance(vects):
    """Compute the Euclidean distance between two tensors along axis 1.

    Arguments:
        vects: Sequence of exactly two tensors that can be subtracted
            element-wise.

    Returns:
        Tensor of distances in which axis 1 has been reduced to size 1
        (``keepdims=True``).
    """
    first, second = vects

    squared_diff = tf.math.square(first - second)
    sum_square = tf.math.reduce_sum(squared_diff, axis=1, keepdims=True)

    # The squared distance is floored at the Keras epsilon before the root
    # is taken, so the square root never sees a value below epsilon.
    floor = tf.keras.backend.epsilon()
    return tf.math.sqrt(tf.math.maximum(sum_square, floor))

In [10]:
def loss(margin=1):
    """Create a contrastive loss function bound to a fixed ``margin``.

    Arguments:
        margin: Number subtracted from in the hinge term
            ``max(margin - y_pred, 0)`` (default is 1).

    Returns:
        The ``contrastive_loss(y_true, y_pred)`` closure, which captures
        ``margin``.
    """

    def contrastive_loss(y_true, y_pred):
        """Calculate the contrastive loss.

        The value is
        ``mean((1 - y_true) * y_pred**2 + y_true * max(margin - y_pred, 0)**2)``.

        Arguments:
            y_true: Labels, broadcastable against ``y_pred``.
            y_pred: Predictions.

        Returns:
            Scalar tensor holding the mean over all elements.
        """
        hinge = tf.math.maximum(margin - (y_pred), 0)

        unlabelled_term = (1 - y_true) * tf.math.square(y_pred)
        labelled_term = (y_true) * tf.math.square(hinge)

        return tf.math.reduce_mean(unlabelled_term + labelled_term)

    return contrastive_loss


In [11]:
def prepare_model(inputs):
    """Build and compile the three-tower siamese autoencoder.

    One 3D convolutional autoencoder is created and applied to three input
    volumes, so all three towers share the same weights. Each tower yields a
    10-unit sigmoid encoding (``enc``) and a reconstruction (``dec``). Every
    pair of encodings is passed through ``cosine_similarity``, batch
    normalisation and a 1-unit sigmoid ``Dense`` layer (``ctr1``..``ctr3``).

    Arguments:
        inputs: Currently unused; the input layers are created inside this
            function. The parameter is kept so existing calls such as
            ``prepare_model(get_inputs())`` keep working.

    Returns:
        Compiled ``CustomModel`` with inputs ``inp_1``, ``inp_2``, ``inp_3``
        and outputs ``ctr1``, ``ctr2``, ``ctr3``, ``enc`` (tower 1 encoding),
        ``dec1``, ``dec2``, ``dec3``.
    """

    # Per-sample shape shared by the autoencoder input and all three towers.
    volume_shape = (96, 160, 160, 1)

    # --- Define lambda functions

    kwargs = {
        'kernel_size': (3, 3, 3),
        'padding': 'same',
        'kernel_initializer': 'he_uniform'
    }
    conv = lambda x, filters, strides : layers.Conv3D(filters=filters, strides=strides, **kwargs)(x)
    norm = lambda x : layers.BatchNormalization()(x)
    acti = lambda x : layers.LeakyReLU()(x)
    tran = lambda x, filters, strides : layers.Conv3DTranspose(filters=filters, strides=strides, **kwargs)(x)

    # conv -> activation -> batch norm; conv2 / tran2 use stride 2.
    conv1 = lambda filters, x : norm(acti(conv(x, filters, strides=1)))
    conv2 = lambda filters, x : norm(acti(conv(x, filters, strides=(2, 2, 2))))
    tran2 = lambda filters, x : norm(acti(tran(x, filters, strides=(2, 2, 2))))

    # --- Define autoencoder network

    inp = Input(volume_shape)
    e1 = conv1(4, inp)
    e2 = conv1(8, conv2(8, e1))
    e3 = conv1(16, conv2(16, e2))
    e4 = conv1(32, conv2(32, e3))
    e5 = layers.Conv3D(filters=4, kernel_size=(1, 1, 1))(e4)
    e6 = layers.Flatten()(e5)
    e7 = layers.Dense(10, activation='sigmoid', name="enc")(e6)
    # The decoder branches off e4: three stride-2 transposed convolutions
    # undo the three stride-2 convolutions of the encoder.
    d1 = tran2(16, e4)
    d2 = conv1(8, tran2(8, d1))
    d3 = conv1(4, tran2(8, d2))
    d4 = layers.Conv3D(filters=1, kernel_size=(1, 1, 1), name="dec")(d3)

    autoencoder_logits = {"enc": e7, "dec": d4}
    autoencoder_network = CustomModel(inputs=inp, outputs=autoencoder_logits)

    # --- Define contrastive network

    tower_inputs = [Input(volume_shape, name="inp_%d" % i) for i in (1, 2, 3)]
    # The same model instance is called on every input -> shared weights.
    towers = [autoencoder_network(tower_input) for tower_input in tower_inputs]

    # Tower index pairs behind ctr1, ctr2, ctr3 (in that order).
    tower_pairs = [(0, 1), (0, 2), (1, 2)]
    merged = [
        layers.Lambda(cosine_similarity)([towers[a]["enc"], towers[b]["enc"]])
        for a, b in tower_pairs
    ]
    normed = [layers.BatchNormalization()(m) for m in merged]

    siamese_logits = {}
    for idx, normed_score in enumerate(normed, start=1):
        name = "ctr%d" % idx
        siamese_logits[name] = layers.Dense(1, activation="sigmoid", name=name)(normed_score)
    # Pass-through layers only give the tower outputs their final names.
    siamese_logits["enc"] = layers.Layer(name="enc")(towers[0]["enc"])
    for idx, tower in enumerate(towers, start=1):
        name = "dec%d" % idx
        siamese_logits[name] = layers.Layer(name=name)(tower["dec"])

    siamese = CustomModel(inputs=tower_inputs, outputs=siamese_logits)

    siamese.compile(
        optimizer=optimizers.Adam(learning_rate=0.001),
        loss={
            'ctr1': loss(),
            'ctr2': loss(),
            'ctr3': loss(),
            'dec1': losses.MeanSquaredError(),
            'dec2': losses.MeanSquaredError(),
            'dec3': losses.MeanSquaredError(),
            'enc': losses.BinaryCrossentropy()
        },
        metrics={
            # `Accuracy` checks exact equality between targets and the float
            # sigmoid outputs and therefore stays at ~0. `BinaryAccuracy`
            # thresholds the predictions first. The name is pinned to
            # 'accuracy' so the reported metric key does not change.
            'enc': metrics.BinaryAccuracy(name='accuracy'),
        },
        experimental_run_tf_function=False
    )

    return siamese

In [12]:
# Build the compiled three-tower model.
model = prepare_model(get_inputs())
# `summary()` prints the table itself and returns None, so wrapping it in
# `print()` would only append a stray "None" line.
model.summary()

Model: "custom_model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
inp_1 (InputLayer)              [(None, 96, 160, 160 0                                            
__________________________________________________________________________________________________
inp_2 (InputLayer)              [(None, 96, 160, 160 0                                            
__________________________________________________________________________________________________
inp_3 (InputLayer)              [(None, 96, 160, 160 0                                            
__________________________________________________________________________________________________
custom_model (CustomModel)      {'enc': (None, 10),  269087      inp_1[0][0]                      
                                                                 inp_2[0][0]         

In [13]:
def Generator(batch_size=4, volume_shape=(96, 160, 160, 1)):
    """
    Method to define a Python generator for training data

    Yields an endless stream of ``(xs, ys)`` batches filled with uniform
    random values in [-0.5, +0.5).

    Arguments:
        batch_size: Number of samples per yielded batch (default 4).
        volume_shape: Per-sample shape of every input volume and of every
            ``dec*`` target (default ``(96, 160, 160, 1)``).

    Yields:
        Tuple ``(xs, ys)`` of dicts of NumPy arrays:
        ``xs`` has keys ``inp_1``, ``inp_2``, ``inp_3`` with shape
        ``(batch_size,) + volume_shape``;
        ``ys`` has keys ``ctr1``..``ctr3`` with shape ``(batch_size, 1)``,
        ``enc`` with shape ``(batch_size, 10)`` and ``dec1``..``dec3`` with
        shape ``(batch_size,) + volume_shape``.
    """
    # --- Define lambda function for random values [-0.5, +0.5]
    lo = -0.5
    hi = +0.5
    rand = lambda shape : np.random.rand(*shape) * (hi - lo) + lo

    batch_shape = (batch_size,) + tuple(volume_shape)

    while True:

        xs = {}
        xs['inp_1'] = rand(batch_shape)
        xs['inp_2'] = rand(batch_shape)
        xs['inp_3'] = rand(batch_shape)

        # NOTE(review): targets are drawn from [-0.5, +0.5) like the inputs.
        # That is fine for a smoke test, but the `ctr*` / `enc` targets would
        # presumably be 0/1 labels with real data - confirm before reuse.
        ys = {}
        ys['ctr1'] = rand((batch_size, 1))
        ys['ctr2'] = rand((batch_size, 1))
        ys['ctr3'] = rand((batch_size, 1))
        ys['enc'] = rand((batch_size, 10))
        ys['dec1'] = rand(batch_shape)
        ys['dec2'] = rand(batch_shape)
        ys['dec3'] = rand(batch_shape)

        yield xs, ys
        
# One independent, endless batch stream per data split.
gen_train, gen_valid, gen_test = (Generator() for _ in range(3))

In [10]:
# Per-epoch metrics are appended to this CSV file by the callback.
csv_logger = callbacks.CSVLogger(filename="training_log.csv")

# --- Train
# `validation_freq` counts epochs. With a single epoch it has to be 1,
# otherwise `validation_data` is passed but validation never runs.
history = model.fit(
    x=gen_train,
    epochs=1,
    steps_per_epoch=5,
    validation_data=gen_valid,
    validation_steps=10,
    validation_freq=1,
    callbacks=[csv_logger]
)



<tensorflow.python.keras.callbacks.History at 0x7f712013dfd0>

In [57]:
import time

# Instantiate an optimizer to train the model.
optimizer = optimizers.Adam(learning_rate=1e-3)

# One loss function per model output, keyed by output name.
loss_fn = {
    'ctr1': loss(),
    'ctr2': loss(),
    'ctr3': loss(),
    'dec1': losses.MeanSquaredError(),
    'dec2': losses.MeanSquaredError(),
    'dec3': losses.MeanSquaredError(),
    'enc': losses.BinaryCrossentropy()
}

# Prepare the metrics.
# `Accuracy` compares targets and predictions for exact equality, which is
# practically never true for float sigmoid outputs (it reports 0.0000).
# `BinaryAccuracy` thresholds the predictions before comparing.
train_acc_metric = metrics.BinaryAccuracy()
val_acc_metric = metrics.BinaryAccuracy()

In [59]:
batch_size = 4        # Must match the batch size produced by `gen_train`.
epochs = 2
steps_per_epoch = 10  # `gen_train` is endless, so every epoch is cut off here.

# Model outputs whose losses are summed (unweighted) into the training loss.
output_names = ('enc', 'dec1', 'dec2', 'dec3', 'ctr1', 'ctr2', 'ctr3')

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    # Iterate over the batches of the dataset.
    # `range` comes first in `zip` so that exactly `steps_per_epoch` batches
    # are drawn and trained on. Breaking on `step == 10` after the weight
    # update silently trained on an 11th, unlogged batch.
    for step, (x_batch_train, y_batch_train) in zip(range(steps_per_epoch), gen_train):
        with tf.GradientTape() as tape:
            logits = model(x_batch_train, training=True)
            loss_value = sum(
                loss_fn[name](y_batch_train[name], logits[name])
                for name in output_names
            )
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # Update training metric.
        train_acc_metric.update_state(y_batch_train['enc'], logits['enc'])

        # Log every batch.
        print(
            "Training loss (for one batch) at step %d: %.4f"
            % (step, float(loss_value))
        )
        print("Seen so far: %d samples" % ((step + 1) * batch_size))

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_states()

    # TODO: add a validation pass here (e.g. a bounded number of batches from
    # `gen_valid`, scored with `val_acc_metric`); none is run at the moment.
    print("Time taken: %.2fs" % (time.time() - start_time))



Start of epoch 0
Training loss (for one batch) at step 0: 4.6364
Seen so far: 4 samples
Training loss (for one batch) at step 1: 4.6669
Seen so far: 8 samples
Training loss (for one batch) at step 2: 4.2962
Seen so far: 12 samples
Training loss (for one batch) at step 3: 3.9099
Seen so far: 16 samples
Training loss (for one batch) at step 4: 3.3724
Seen so far: 20 samples
Training loss (for one batch) at step 5: 3.2586
Seen so far: 24 samples
Training loss (for one batch) at step 6: 3.5254
Seen so far: 28 samples
Training loss (for one batch) at step 7: 3.3762
Seen so far: 32 samples
Training loss (for one batch) at step 8: 2.7795
Seen so far: 36 samples
Training loss (for one batch) at step 9: 3.6156
Seen so far: 40 samples
Training acc over epoch: 0.0000
Time taken: 956.13s

Start of epoch 1
Training loss (for one batch) at step 0: 1.7321
Seen so far: 4 samples
Training loss (for one batch) at step 1: 2.1212
Seen so far: 8 samples
Training loss (for one batch) at step 2: 2.8035
Seen