In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import datetime
from Models.Transformer import Transformer
import tensorboard

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the SCAN "addprim_jump" train/test splits; each example is a dict whose
# "commands" and "actions" fields hold the source and target strings.
(train_raw, test_raw) = tfds.load("scan/addprim_jump", split=["train", "test"])

def standardize(text):
    """Wrap each string with "<SOS>" and "<EOS>" marker tokens.

    This is passed as the ``standardize`` callable of the TextVectorization
    layers below, so it *replaces* the layer's default standardization: text
    is not lower-cased and punctuation is not stripped. That keeps the
    upper-case action tokens intact. The default punctuation stripping would
    presumably also remove the angle brackets of the markers.

    A single space separator makes the markers separate tokens under
    TextVectorization's default whitespace split.
    """
    return tf.strings.join(["<SOS>", text, "<EOS>"], separator=" ")


# One vectorizer per side of the translation pair, both using the custom
# <SOS>/<EOS> standardization. Vocabularies are learned from the training split.
command_processor = tf.keras.layers.TextVectorization(standardize=standardize)
action_processor = tf.keras.layers.TextVectorization(standardize=standardize)

command_processor.adapt(train_raw.map(lambda example: example["commands"]))
action_processor.adapt(train_raw.map(lambda example: example["actions"]))

def process_scan(pair):
    """Turn a batch of SCAN string pairs into teacher-forcing model inputs.

    Args:
        pair: a batched example dict with "commands" and "actions" string
            tensors (batched, since the results are sliced along axis 1).

    Returns:
        ``((command_ids, decoder_input), decoder_target)``. The decoder input
        is the action ids without their last position. The target is the same
        ids without their first position, so each step predicts the next token.
    """
    command_ids = command_processor(pair["commands"])
    action_ids = action_processor(pair["actions"])
    decoder_input, decoder_target = action_ids[:, :-1], action_ids[:, 1:]
    return (command_ids, decoder_input), decoder_target

BATCH_SIZE = 32
# Larger than the training split (the fit log shows 458 batches of 32 per
# epoch), so the shuffle is a full one. Reshuffled every epoch by default.
SHUFFLE_BUFFER = 20_000

# Build new pipelines instead of reassigning train_raw/test_raw, so the raw
# datasets stay unbatched and the cell is idempotent.
# Order matters:
#   - shuffle the cheap raw strings first;
#   - batch, then vectorize the whole batch in `map`;
#   - prefetch last, so the vectorization itself overlaps with training.
train_ds = (
    train_raw
    .shuffle(SHUFFLE_BUFFER)
    .batch(BATCH_SIZE, drop_remainder=True)
    .map(process_scan, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

# Keep the final partial batch so validation metrics cover the whole test split.
# NOTE(review): assumes the Transformer accepts a variable batch size — confirm.
val_ds = (
    test_raw
    .batch(BATCH_SIZE)
    .map(process_scan, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)


Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
2024-03-20 19:49:36.864703: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


In [3]:
# Sanity-check one batch. Print its shapes and assert the invariants that
# process_scan guarantees:
#   - all three tensors share the batch size;
#   - decoder input and target are the same action sequence shifted by one,
#     so they have identical shapes.
for (command, action_in), action_out in train_ds.take(1):
    print(command.shape)
    print(action_in.shape)
    print(action_out.shape)
    assert command.shape[0] == action_in.shape[0] == action_out.shape[0]
    assert action_in.shape == action_out.shape

(32, 11)
(32, 41)
(32, 41)


In [8]:
# Transformer hyperparameters.
num_layers, num_heads = 2, 4
d_model, dff = 32, 64
dropout_rate = 0.1

# Vocabulary sizes are taken from the adapted vectorizers.
input_vocab_size = len(command_processor.get_vocabulary())
target_vocab_size = len(action_processor.get_vocabulary())

transformer = Transformer(
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    dff=dff,
    input_vocab_size=input_vocab_size,
    target_vocab_size=target_vocab_size,
    dropout_rate=dropout_rate,
)

In [9]:
lr = 1e-3  # Adam learning rate, i.e. 0.001
optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

def masked_loss(label, pred):
    """Mean sparse cross-entropy over non-padding target tokens.

    Args:
        label: integer target token ids; id 0 marks padding and is ignored.
        pred: unnormalized logits over the target vocabulary (last axis).

    Returns:
        Scalar loss: the per-token losses summed over non-padding positions,
        divided by the number of such positions. Returns 0 rather than NaN if
        the batch contains no non-padding positions.
    """
    mask = label != 0
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none')
    loss = loss_object(label, pred)

    # Zero out padding positions, then average only over real tokens.
    mask = tf.cast(mask, dtype=loss.dtype)
    loss *= mask
    return tf.math.divide_no_nan(tf.reduce_sum(loss), tf.reduce_sum(mask))


def masked_accuracy(label, pred):
    """Token-level accuracy over non-padding target positions.

    Args:
        label: integer target token ids; id 0 marks padding and is ignored.
        pred: scores over the target vocabulary along the last axis.

    Returns:
        Fraction of non-padding positions where ``argmax(pred)`` equals the
        label. Returns 0 rather than NaN if there are no such positions.
    """
    # axis=-1 works for any rank of `pred`, not only (batch, seq, vocab).
    pred = tf.argmax(pred, axis=-1)
    label = tf.cast(label, pred.dtype)

    mask = label != 0
    match = (label == pred) & mask

    match = tf.cast(match, dtype=tf.float32)
    mask = tf.cast(mask, dtype=tf.float32)
    return tf.math.divide_no_nan(tf.reduce_sum(match), tf.reduce_sum(mask))

In [10]:
# Both the loss and the metric ignore padding positions (label id 0).
transformer.compile(
    optimizer=optimizer,
    loss=masked_loss,
    metrics=[masked_accuracy],
)

In [11]:
EPOCHS = 20

# Each run logs to a fresh timestamped TensorBoard directory.
# histogram_freq=1 records weight histograms every epoch.
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

# Keep the History object so per-epoch loss/accuracy can be inspected or plotted.
history = transformer.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=val_ds,
    callbacks=[tensorboard_callback],
)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
 37/458 [=>............................] - ETA: 11s - loss: 0.0696 - masked_accuracy: 0.9720

KeyboardInterrupt: 