In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import datetime
from Models.Transformer import Transformer
from Models.ESBNTransformer import ESBNTransformer2
from Utils.Logging import BatchLoggingModel
from Utils.preprocess_scan import generate_data
import tensorboard
# Run tf.function-wrapped code (including Keras train/eval steps) as compiled
# graphs rather than eagerly; eager mode is only useful for debugging.
tf.config.run_functions_eagerly(False)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Build the train/validation datasets and fetch the vocabulary sizes that the
# models below take as input_vocab_size (commands) and target_vocab_size (actions).
(train_ds, val_ds,
 command_vocab_len, action_vocab_len) = generate_data()

2024-04-02 11:08:14.060263: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


In [3]:
# Masked loss and accuracy: both ignore positions whose label id is 0.
def masked_loss(label, pred):
  """Mean sparse categorical cross-entropy over positions with a non-zero label.

  Args:
    label: integer target ids. Positions equal to 0 (presumably the padding
      id -- TODO confirm against the preprocessing) are excluded.
    pred: unnormalized logits (``from_logits=True``) whose last axis indexes
      the target vocabulary.

  Returns:
    Scalar tensor: summed per-position loss divided by the number of kept
    positions. Returns 0 instead of NaN when no position is kept.
  """
  mask = label != 0
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')
  loss = loss_object(label, pred)

  # Zero out the loss at masked positions before averaging.
  mask = tf.cast(mask, dtype=loss.dtype)
  loss *= mask

  # A plain division yields NaN (0/0) for an all-masked batch, which would
  # poison the optimizer update; divide_no_nan returns 0 in that case.
  return tf.math.divide_no_nan(tf.reduce_sum(loss), tf.reduce_sum(mask))


def masked_accuracy(label, pred):
  """Fraction of non-zero-label positions where the arg-max prediction matches.

  Args:
    label: integer target ids; positions equal to 0 are excluded
      (presumably padding -- TODO confirm).
    pred: logits or probabilities whose last axis indexes the target
      vocabulary.

  Returns:
    Scalar float32 tensor in [0, 1]. Returns 0 instead of NaN when no
    position is kept.
  """
  # Use the last axis as the class axis (same as axis=2 for rank-3 input,
  # but also valid for other ranks, matching how the loss treats `pred`).
  pred = tf.argmax(pred, axis=-1)
  label = tf.cast(label, pred.dtype)

  mask = label != 0
  match = (label == pred) & mask

  match = tf.cast(match, dtype=tf.float32)
  mask = tf.cast(mask, dtype=tf.float32)
  # Avoid 0/0 = NaN when every position is masked.
  return tf.math.divide_no_nan(tf.reduce_sum(match), tf.reduce_sum(mask))

In [4]:
# Model hyperparameters; the baseline Transformer cell later in this notebook
# uses the same values so the two models are comparable.
num_layers, d_model, dff, num_heads = 2, 32, 64, 4
dropout_rate = 0.1

esbn_transformer = ESBNTransformer2(
    input_vocab_size=command_vocab_len,
    target_vocab_size=action_vocab_len,
    num_layers=num_layers,
    num_heads=num_heads,
    d_model=d_model,
    dff=dff,
    dropout_rate=dropout_rate,
)

In [5]:
# Adam with Keras' default settings, optimizing the masked loss and reporting
# masked accuracy as the metric.
esbn_transformer.compile(optimizer="Adam", loss=masked_loss, metrics=[masked_accuracy])

In [6]:
# Give the ESBN model its own TensorBoard prefix. Previously it logged under
# "scan_transformer/", the same prefix as the baseline Transformer, so the two
# runs were indistinguishable except by timestamp.
log_dir = "logs/fit/scan_esbn_transformer/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# histogram_freq=1 writes weight histograms every epoch.
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
# Keep the History object so per-epoch loss/metrics can be inspected later.
esbn_history = esbn_transformer.fit(train_ds,
                                    epochs=20,
                                    validation_data=val_ds,
                                    callbacks=[tensorboard_callback])

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
 88/458 [====>.........................] - ETA: 18s - loss: 0.0553 - masked_accuracy: 0.9783

KeyboardInterrupt: 

In [7]:
# Layer-by-layer parameter counts for the trained ESBN model.
esbn_transformer.summary()

Model: "esbn_transformer2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 encoder (Encoder)           multiple                  42784     
                                                                 
 decoder (Decoder)           multiple                  76288     
                                                                 
 decoder_1 (Decoder)         multiple                  76288     
                                                                 
 esbn_encoder (ESBNEncoder)  multiple                  52323     
                                                                 
 dense_18 (Dense)            multiple                  650       
                                                                 
Total params: 248,333
Trainable params: 248,333
Non-trainable params: 0
_________________________________________________________________


In [None]:
# Ordinary (baseline) Transformer, using the same hyperparameters as the ESBN
# model for a like-for-like comparison.
# NOTE: a leftover `tf.config.run_functions_eagerly(True)` was removed here. It
# globally switched off graph execution, which slows training, persists into
# every later cell, and differs from the graph-mode ESBN run.
num_layers = 2
d_model = 32
dff = 64
num_heads = 4
dropout_rate = 0.1

transformer = Transformer(
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    dff=dff,
    # `command_processor` / `action_processor` are not defined anywhere in this
    # notebook (NameError on a fresh kernel); use the vocab sizes returned by
    # generate_data(), as the ESBN model does.
    input_vocab_size=command_vocab_len,
    target_vocab_size=action_vocab_len,
    dropout_rate=dropout_rate)

transformer.compile(
    loss=masked_loss,
    optimizer="Adam",
    metrics=[masked_accuracy])

log_dir = "logs/fit/scan_transformer/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
# Keep the History object so per-epoch loss/metrics can be compared with the ESBN run.
transformer_history = transformer.fit(train_ds,
                                      epochs=20,
                                      validation_data=val_ds,
                                      callbacks=[tensorboard_callback])

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20

In [None]:
# Layer-by-layer parameter counts for the baseline Transformer.
transformer.summary()

Model: "transformer"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 encoder_1 (Encoder)         multiple                  85024     
                                                                 
 decoder_1 (Decoder)         multiple                  152256    
                                                                 
 dense_37 (Dense)            multiple                  330       
                                                                 
Total params: 237,610
Trainable params: 237,610
Non-trainable params: 0
_________________________________________________________________
