In [1]:
import tensorflow as tf
import datetime
from Models.Transformers import Transformer, ESBNTransformer, ESBNTransformerSCA
from Utils.preprocess_scan import generate_data
import tensorboard
%load_ext tensorboard
# Do not force tf.function code to run eagerly; leave graph execution enabled.
tf.config.run_functions_eagerly(False)

  from .autonotebook import tqdm as notebook_tqdm


In this notebook, we compare two variants of integrating the ESBN into a Transformer against the standard Transformer. The Transformer code is adapted from the TensorFlow blog.

In [2]:
# Build the training and validation datasets and get the sizes of the
# command (model input) and action (model target) vocabularies.
train_ds, val_ds, command_vocab_len, action_vocab_len = generate_data()

2024-04-03 19:40:51.760034: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


In [3]:
# Define evaluation metrics
def masked_loss(label, pred):
  """Sparse categorical cross-entropy averaged over non-padding positions.

  Positions where `label` equals 0 are treated as padding and do not
  contribute to the average.

  Args:
    label: Integer class ids; the value 0 marks positions to ignore.
    pred: Unnormalized logits over the classes for each position of `label`.

  Returns:
    Scalar tensor: the summed per-position loss divided by the number of
    non-padding positions, or 0 when every position is padding.
  """
  # reduction='none' keeps one loss value per position so that padding
  # positions can be zeroed out before averaging.
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')
  loss = loss_object(label, pred)

  mask = tf.cast(label != 0, dtype=loss.dtype)
  loss *= mask

  # divide_no_nan yields 0 instead of NaN when the mask sums to 0
  # (a batch containing only padding), so one such batch cannot poison training.
  return tf.math.divide_no_nan(tf.reduce_sum(loss), tf.reduce_sum(mask))


def masked_accuracy(label, pred):
  """Accuracy of the arg-max prediction over non-padding positions.

  Positions where `label` equals 0 are treated as padding and are excluded
  from both the numerator and the denominator.

  Args:
    label: Integer class ids; the value 0 marks positions to ignore.
    pred: Per-class scores with the classes on the last axis.

  Returns:
    Scalar float32 tensor: the number of correctly predicted non-padding
    positions divided by the number of non-padding positions, or 0 when
    every position is padding.
  """
  # Classes live on the last axis (the same axis the loss treats as classes);
  # for rank-3 predictions this is identical to axis=2.
  pred = tf.argmax(pred, axis=-1)
  label = tf.cast(label, pred.dtype)

  mask = label != 0
  # A position counts as a hit only if it is both correct and not padding.
  match = (label == pred) & mask

  match = tf.cast(match, dtype=tf.float32)
  mask = tf.cast(mask, dtype=tf.float32)
  # divide_no_nan yields 0 instead of NaN when there are no non-padding positions.
  return tf.math.divide_no_nan(tf.reduce_sum(match), tf.reduce_sum(mask))

In [4]:
# Baseline: ordinary Transformer
# Hyperparameters (the ESBN cells below use the same values).
num_layers, d_model, dff, num_heads, dropout_rate = 2, 32, 64, 4, 0.1

transformer = Transformer(
    num_layers=num_layers, d_model=d_model, num_heads=num_heads, dff=dff,
    input_vocab_size=command_vocab_len, target_vocab_size=action_vocab_len,
    dropout_rate=dropout_rate,
)
transformer.compile(loss=masked_loss, optimizer="Adam", metrics=[masked_accuracy])

# Each run logs to its own timestamped TensorBoard directory.
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = "logs/fit/scan_transformer/" + run_stamp
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
)

transformer.fit(
    train_ds,
    epochs=20,
    validation_data=val_ds,
    callbacks=[tensorboard_callback],
)

# Layer-by-layer parameter counts of the trained model.
transformer.summary()

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Model: "transformer"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 encoder (Encoder)           multiple                  42784     
                                                                 
 decoder (Decoder)           multiple                  76288     
                                                                 
 dense_8 (Dense)             multiple                  330       
                                                                 
Total params: 119,402
Trainable params: 119,402
Non-trainable params: 0
_________________________________________________________________


In [5]:
# ESBN variant of the Transformer (ESBNTransformer)
# Hyperparameters (same values as the baseline Transformer cell).
num_layers, d_model, dff, num_heads, dropout_rate = 2, 32, 64, 4, 0.1

esbn_transformer = ESBNTransformer(
    num_layers=num_layers, d_model=d_model, num_heads=num_heads, dff=dff,
    input_vocab_size=command_vocab_len, target_vocab_size=action_vocab_len,
    dropout_rate=dropout_rate,
)
esbn_transformer.compile(loss=masked_loss, optimizer="Adam", metrics=[masked_accuracy])

# Each run logs to its own timestamped TensorBoard directory.
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = "logs/fit/scan_transformer_esbn/" + run_stamp
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
)

esbn_transformer.fit(
    train_ds,
    epochs=20,
    validation_data=val_ds,
    callbacks=[tensorboard_callback],
)

# Layer-by-layer parameter counts of the trained model.
esbn_transformer.summary()

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Model: "esbn_transformer"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 encoder_1 (Encoder)         multiple                  42784     
                                                                 
 decoder_1 (Decoder)         multiple                  76288     
                                                                 
 decoder_2 (Decoder)         multiple                  76288     
                                                                 
 esbn_encoder (ESBNEncoder)  multiple                  52323     
                                                                 
 dense_27 (Dense)            multiple                  650       
                 

In [6]:
# SCA variant of the ESBN Transformer (ESBNTransformerSCA)
# Hyperparameters (same values as the baseline Transformer cell).
num_layers, d_model, dff, num_heads, dropout_rate = 2, 32, 64, 4, 0.1

esbn_transformer_sca = ESBNTransformerSCA(
    num_layers=num_layers, d_model=d_model, num_heads=num_heads, dff=dff,
    input_vocab_size=command_vocab_len, target_vocab_size=action_vocab_len,
    dropout_rate=dropout_rate,
)
esbn_transformer_sca.compile(loss=masked_loss, optimizer="Adam", metrics=[masked_accuracy])

# Each run logs to its own timestamped TensorBoard directory.
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = "logs/fit/scan_transformer_esbn_sca/" + run_stamp
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
)

esbn_transformer_sca.fit(
    train_ds,
    epochs=20,
    validation_data=val_ds,
    callbacks=[tensorboard_callback],
)

# Layer-by-layer parameter counts of the trained model.
esbn_transformer_sca.summary()

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Model: "esbn_transformer_sca"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 encoder_2 (Encoder)         multiple                  42784     
                                                                 
 decoder_3 (Decoder)         multiple                  76288     
                                                                 
 decoder_4 (Decoder)         multiple                  76288     
                                                                 
 esbn_encoder_cross_attentio  multiple                 52323     
 n (ESBNEncoderCrossAttentio                                     
 n)                                                              
             

In [7]:
# Print the layer-by-layer parameter summary of the ESBNTransformerSCA model.
esbn_transformer_sca.summary()

Model: "esbn_transformer_sca"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 encoder_2 (Encoder)         multiple                  42784     
                                                                 
 decoder_3 (Decoder)         multiple                  76288     
                                                                 
 decoder_4 (Decoder)         multiple                  76288     
                                                                 
 esbn_encoder_cross_attentio  multiple                 52323     
 n (ESBNEncoderCrossAttentio                                     
 n)                                                              
                                                                 
 dense_46 (Dense)            multiple                  650       
                                                                 
Total params: 248,333
Trainable params: 248,33

In [9]:
# Open TensorBoard inline on logs/fit, the parent directory the training cells log their runs under.
%tensorboard --logdir logs/fit

Reusing TensorBoard on port 6006 (pid 58338), started 0:04:30 ago. (Use '!kill 58338' to kill it.)