In [4]:
import tensorflow as tf
from tensorflow.python.keras.layers import Dense, Dropout, Embedding, Layer
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.optimizers import adam_v2

print("TensorFlow version: ", tf.__version__)
print("Connecting to TPU...")
resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect(tpu='node-7',zone='us-central1-f')
strategy = tf.distribute.TPUStrategy(resolver)
print("Done!")
print("Number of accelerators: ", strategy.num_replicas_in_sync)

TensorFlow version:  2.12.0
Connecting to TPU...
INFO:tensorflow:Deallocate tpu buffers before initializing tpu system.


INFO:tensorflow:Deallocate tpu buffers before initializing tpu system.


INFO:tensorflow:Initializing the TPU system: node-7


INFO:tensorflow:Initializing the TPU system: node-7


INFO:tensorflow:Finished initializing TPU system.


INFO:tensorflow:Finished initializing TPU system.


INFO:tensorflow:Found TPU system:


INFO:tensorflow:Found TPU system:


INFO:tensorflow:*** Num TPU Cores: 8


INFO:tensorflow:*** Num TPU Cores: 8


INFO:tensorflow:*** Num TPU Workers: 1


INFO:tensorflow:*** Num TPU Workers: 1


INFO:tensorflow:*** Num TPU Cores Per Worker: 8


INFO:tensorflow:*** Num TPU Cores Per Worker: 8


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)


Done!
Number of accelerators:  8


In [None]:
def positional_encoding(position, d_model):
    angle_rads = tf.range(position, dtype=tf.float32)[:, tf.newaxis] * 1 / tf.pow(10000, (2 * tf.range(0, d_model, dtype=tf.float32)) / d_model)
    angle_rads_even = tf.math.sin(angle_rads[:, 0::2])
    angle_rads_odd = tf.math.cos(angle_rads[:, 1::2])
    angle_rads = tf.stack([angle_rads_even, angle_rads_odd], axis=-1)
    angle_rads = tf.reshape(angle_rads, (-1, d_model))
    pos_encoding = angle_rads[tf.newaxis, ...]
    return tf.cast(pos_encoding, dtype=tf.float32)

# Encoder layer
class EncoderLayer(Layer):
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(EncoderLayer, self).__init__()

        self.mha = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
        # When using MultiHeadAttention inside a custom layer, the custom layer must implement its own build() method and call MultiHeadAttention's _build_from_signature() there. This enables weights to be restored correctly when the model is loaded.
        self.ffn = tf.keras.Sequential([
            Dense(dff, activation='relu'),
            Dense(d_model)
        ])

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)
        
    def build(self, input_shape):
        self.mha._build_from_signature(input_shape, input_shape, input_shape)
        super(EncoderLayer, self).build(input_shape)

    def call(self, x, training, mask):
        attn_output = self.mha(x, x, x, attention_mask=mask)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(x + attn_output)

        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        out2 = self.layernorm2(out1 + ffn_output)

        return out2
    

# Decoder layer
class DecoderLayer(Layer):
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(DecoderLayer, self).__init__()

        self.mha1 = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
        self.mha2 = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)

        self.ffn = tf.keras.Sequential([
            Dense(dff, activation='relu'),
            Dense(d_model)
        ])

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)
        self.dropout3 = Dropout(rate)
        
    def build(self, input_shape):
        self.mha1._build_from_signature(input_shape, input_shape, input_shape)
        self.mha2._build_from_signature(input_shape, input_shape, input_shape)
        super(DecoderLayer, self).build(input_shape)

    def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
        attn1 = self.mha1(x, x, x, attention_mask=look_ahead_mask)
        attn1 = self.dropout1(attn1, training=training)
        out1 = self.layernorm1(attn1 + x)

        attn2 = self.mha2(out1, enc_output, enc_output, attention_mask=padding_mask)
        attn2 = self.dropout2(attn2, training=training)
        out2 = self.layernorm2(attn2 + out1)

        ffn_output = self.ffn(out2)
        ffn_output = self.dropout3(ffn_output, training=training)
        out3 = self.layernorm3(ffn_output + out2)

        return out3

# Encoder
class Encoder(Layer):
    def __init__(self, num_layers, d_model, num_heads,
        dff, input_vocab_size, maximum_position_encoding, rate=0.1):
        super(Encoder, self).__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.embedding = Embedding(input_vocab_size, d_model)
        self.pos_encoding = positional_encoding(maximum_position_encoding, self.d_model)

        self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate) for _ in range(num_layers)]
        self.dropout = Dropout(rate)

    def call(self, x, training, mask):
        seq_len = tf.shape(x)[1]

        x = self.embedding(x)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x += self.pos_encoding[:, :seq_len, :]

        x = self.dropout(x, training=training)

        for i in range(self.num_layers):
            x = self.enc_layers[i](x, training, mask)

        return x

# Decoder
class Decoder(Layer):
    def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size, maximum_position_encoding, rate=0.1):
        super(Decoder, self).__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.embedding = Embedding(target_vocab_size, d_model)
        self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)

        self.dec_layers = [DecoderLayer(d_model, num_heads, dff, rate) for _ in range(num_layers)]
        self.dropout = Dropout(rate)

    def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
        seq_len = tf.shape(x)[1]
        # attention_weights = {}
        # perhaps save the attention weights here?

        x = self.embedding(x)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x += self.pos_encoding[:, :seq_len, :]

        x = self.dropout(x, training=training)

        for i in range(self.num_layers):
            x = self.dec_layers[i](x, enc_output, training, look_ahead_mask, padding_mask)

        return x

def create_padding_mask(seq):
    seq = tf.convert_to_tensor(seq)
    if len(seq.shape) == 3:
        seq = tf.reduce_sum(seq, axis=-1)
    seq = tf.cast(tf.math.equal(seq, 0), tf.float32)
    return seq[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)



def create_look_ahead_mask(size):
    if size == 1:
        return tf.zeros((1, 1))
    else:
        mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
        return mask  # (seq_len, seq_len)


class Transformer(Model):
    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, target_vocab_size, max_position_encoding, rate=0.1):
        super(Transformer, self).__init__()

        self.encoder = Encoder(num_layers, d_model, num_heads, dff, input_vocab_size, max_position_encoding, rate)
        self.decoder = Decoder(num_layers, d_model, num_heads, dff, target_vocab_size, max_position_encoding, rate)

        self.final_layer = Dense(target_vocab_size)

    def create_masks(self, inp):
        enc_padding_mask = create_padding_mask(inp)
        look_ahead_mask = create_look_ahead_mask(tf.shape(inp)[1])
        dec_padding_mask = enc_padding_mask

        return enc_padding_mask, look_ahead_mask, dec_padding_mask

    def call(self, inputs, training):
        inp = inputs
        enc_padding_mask, look_ahead_mask, dec_padding_mask = self.create_masks(inp)

        enc_output = self.encoder(inp, training, enc_padding_mask)
        
        # Create a start token for the target sequence
        start_token = tf.constant([[1]], dtype=tf.int32)  # Assuming 1 is the start token ID
        dec_input = tf.tile(start_token, [tf.shape(inp)[0], 1])  # (batch_size, 1)

        dec_output = self.decoder(dec_input, enc_output, training, look_ahead_mask, dec_padding_mask)

        final_output = self.final_layer(dec_output)

        return final_output


In [12]:
import os
import time
from transformers import GPT2Tokenizer
tokenizer: GPT2Tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

saved_transformers_folder = 'gs://saved_transformers'
os.makedirs(saved_transformers_folder, exist_ok=True)
saved_transformer_path = f'{saved_transformers_folder}/v2'

loss_history = []

with strategy.scope():
    # Hyperparameters
    # The transformer model currently has TK parameters.
    num_layers = 6
    d_model = 512
    num_heads = 8
    dff = 2048
    input_vocab_size = tokenizer.vocab_size
    target_vocab_size = tokenizer.vocab_size
    max_position_encoding = 128
    dropout_rate = 0.1
    learning_rate = 5e-5
    batch_size = 32
    epochs = 3
    warmup_steps = 4000

    # Create the transformer model
    # Load the weights if the saved model exists
    if os.path.exists(saved_transformer_path):
        print('Loading the saved model')
        model = tf.keras.models.load_model(saved_transformer_path, custom_objects={'Transformer': Transformer})
        if model is None:
            raise Exception('Failed to load the saved model')
        transformer: Transformer = model
        print('Done')
    else:
        print('Loaded model not found. Creating a new model')
        transformer = Transformer(num_layers, d_model, num_heads, dff, input_vocab_size, target_vocab_size, max_position_encoding, dropout_rate)
        print('Done')


    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')

    def loss_function(real, pred):
        mask = tf.math.logical_not(tf.math.equal(real, 0))
        loss_ = loss_object(real, pred)

        mask = tf.cast(mask, dtype=loss_.dtype)
        loss_ *= mask

        return tf.reduce_sum(loss_) / tf.reduce_sum(mask)
    
    class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
        def __init__(self, d_model, warmup_steps=4000):
            super(CustomSchedule, self).__init__()

            self.d_model = tf.cast(d_model, tf.float32)
            self.warmup_steps = warmup_steps

        def __call__(self, step):
            step = tf.cast(step, tf.float32)
            arg1 = tf.math.rsqrt(step)
            arg2 = step * (self.warmup_steps ** -1.5)
            learning_rate = tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

            return learning_rate

    learning_rate = CustomSchedule(d_model, warmup_steps)
    optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

bucket_path = 'gs://dataset_w/'
input_tfrecord_files = [f'{bucket_path}wikitrain_{i:04d}.tfrecord' for i in range(79)]

# Function to parse a single example from the TFRecord files
def create_windows(sequence, window_size, step=1):
    sequence_length = tf.shape(sequence)[0]
    if sequence_length < window_size:
        # Pad the sequence with zeros if it's shorter than the window size
        pad_size = window_size - sequence_length
        sequence = tf.concat([sequence, tf.zeros(pad_size, dtype=tf.int64)], axis=0)
        num_windows = 1
    else:
        num_windows = (sequence_length - window_size) // step + 1

    windows = tf.TensorArray(dtype=tf.int64, size=num_windows, dynamic_size=True)

    for i in range(num_windows):
        windows = windows.write(i, sequence[i * step:i * step + window_size])

    return windows.stack()


@tf.function
def _parse_function(example_proto):
    feature_description = {
        'token_ids': tf.io.VarLenFeature(tf.int64),
    }
    parsed_features = tf.io.parse_single_example(example_proto, feature_description)
    input_sequence = tf.sparse.to_dense(parsed_features['token_ids'])
    input_sequences = create_windows(input_sequence, max_position_encoding)

    def process_sequences(seq):
        inp = seq[:-1]
        tar = seq[1:]  # Get the last element as the target
        return inp, tar

    input_sequences, target_sequences = tf.map_fn(process_sequences, input_sequences, dtype=(tf.int64, tf.int64))

    return input_sequences, target_sequences  # Add an extra dimension at the end of the tensor



# Load and preprocess the dataset from the TFRecord files
def load_dataset(input_files):
    input_ds = tf.data.TFRecordDataset(input_files)
    input_ds = input_ds.map(_parse_function)
    return input_ds

def print_sequences_as_words(inp, tar):
    inp_tokens = tokenizer.batch_decode(inp.numpy(), skip_special_tokens=True)
    tar_tokens = tokenizer.batch_decode([tar.numpy()], skip_special_tokens=True)

    print("Input:")
    for seq in inp_tokens:
        print(seq)

    print("\nTarget:")
    for seq in tar_tokens:
        print(seq)

print('Processing dataset...')
input_dataset = load_dataset(input_tfrecord_files)
input_dataset = input_dataset.flat_map(lambda x, y: tf.data.Dataset.from_tensor_slices((x, y)))

def print_dataset(input_dataset, num_examples=1):
    for i, (inp, tar) in enumerate(input_dataset.take(num_examples)):
        print(f"Example {i + 1}:")
        print("Input: ", inp.numpy())
        print("Target: ", tar.numpy())
        print_sequences_as_words(inp, tar)
        print("\n")
print_dataset(input_dataset)
dataset = input_dataset.shuffle(buffer_size=1000)
dataset = input_dataset.batch(batch_size, drop_remainder=True)

# def create_tf_dataset(data, tokenizer):
#     def split_input_target(input_string):
#         parts = input_string.strip().split("? ")
#         event, year = parts[0], parts[-1]
#         return event, year

#     events, years = zip(*[split_input_target(item) for item in data])

#     # Encode events using GPT-2 tokenizer
#     encoded_events = [tokenizer.encode(event) for event in events]
#     encoded_years = [tokenizer.encode(year) for year in years]
    
#     # Find the maximum length among encoded events
#     max_length = max([len(event) for event in encoded_events])
#     max_y_length = max([len(year) for year in encoded_years])
    
#     # Pad the encoded events to have the same length
#     padded_events = [event + [0] * (max_length - len(event)) for event in encoded_events]
#     padded_years = [year + [0] * (max_y_length - len(year)) for year in encoded_years]
    
#     events_tensor = tf.data.Dataset.from_tensor_slices(padded_events)
#     years_tensor = tf.data.Dataset.from_tensor_slices(padded_years)

#     dataset = tf.data.Dataset.zip((events_tensor, years_tensor))

#     return dataset


# Example usage
# data = ["What year was the signing of the Declaration of Independence? The signing of the Declaration of Independence was in 1776.",
# "What year was the storming of the Bastille? The storming of the Bastille was in 1789.",
# "What year was the Battle of Waterloo? The Battle of Waterloo was in 1815.",
# "What year was the assassination of Abraham Lincoln? The assassination of Abraham Lincoln was in 1865.",
# "What year was the invention of the telephone by Alexander Graham Bell? The invention of the telephone by Alexander Graham Bell was in 1876.",
# "What year was the first successful powered airplane flight by the Wright brothers? The first successful powered airplane flight by the Wright brothers was in 1903.",
# "What year was the sinking of the Titanic? The sinking of the Titanic was in 1912.",
# "What year was the beginning of World War I? The beginning of World War I was in 1914.",
# "What year was the Russian Revolution? The Russian Revolution was in 1917.",
# "What year was the end of World War I? The end of World War I was in 1918.",
# "What year was the stock market crash that led to the Great Depression? The stock market crash that led to the Great Depression was in 1929.",
# "What year was the beginning of World War II? The beginning of World War II was in 1939.",
# "What year was the attack on Pearl Harbor? The attack on Pearl Harbor was in 1941.",
# "What year was the D-Day invasion during World War II? The D-Day invasion during World War II was in 1944.",
# "What year was the dropping of the atomic bombs on Hiroshima and Nagasaki? The dropping of the atomic bombs on Hiroshima and Nagasaki was in 1945.",
# "What year was the end of World War II? The end of World War II was in 1945.",
# "What year was the establishment of the United Nations? The establishment of the United Nations was in 1945.",
# "What year was the beginning of the Korean War? The beginning of the Korean War was in 1950.",
# "What year was the launch of Sputnik 1, the first artificial satellite? The launch of Sputnik 1, the first artificial satellite, was in 1957.",
# "What year was the Cuban Missile Crisis? The Cuban Missile Crisis was in 1962.",
# "What year was the assassination of John F. Kennedy? The assassination of John F. Kennedy was in 1963.",
# "What year was the first moon landing by Apollo 11? The first moon landing by Apollo 11 was in 1969.",
# "What year was the end of the Vietnam War? The end of the Vietnam War was in 1975.",
# "What year was the fall of the Berlin Wall? The fall of the Berlin Wall was in 1989.",
# "What year was the dissolution of the Soviet Union? The dissolution of the Soviet Union was in 1991.",
# "What year was the terrorist attacks on September 11? The terrorist attacks on September 11 were in 2001.",
# "What year was the beginning of the Iraq War? The beginning of the Iraq War was in 2003.",
# "What year was the invention of the World Wide Web by Tim Berners-Lee? The invention of the World Wide Web by Tim Berners-Lee was in 1989.",
# "What year was the assassination of Martin Luther King Jr.? The assassination of Martin Luther King Jr. was in 1968.",
# "What year was the discovery of DNA's double helix structure by James Watson and Francis Crick? The discovery of DNA's double helix structure was in 1953.",
# "What year was the first human heart transplant performed by Dr. Christiaan Barnard? The first human heart transplant was in 1967.",
# "What year was the Chernobyl nuclear disaster? The Chernobyl nuclear disaster was in 1986.",
# "What year was the launch of the Hubble Space Telescope? The launch of the Hubble Space Telescope was in 1990.",
# "What year was the Rwandan Genocide? The Rwandan Genocide was in 1994.",
# "What year was the Oklahoma City bombing? The Oklahoma City bombing was in 1995.",
# "What year was the cloning of Dolly the sheep? The cloning of Dolly the sheep was in 1996.",
# "What year was the death of Princess Diana? The death of Princess Diana was in 1997.",
# "What year was the Euro currency introduced? The Euro currency was introduced in 1999.",
# "What year was the Indian Ocean earthquake and tsunami? The Indian Ocean earthquake and tsunami was in 2004.",
# "What year was the election of Pope Francis? The election of Pope Francis was in 2013.",
# "What year was the Paris Agreement on climate change signed? The Paris Agreement on climate change was signed in 2016.",
# "What year was the Brexit referendum? The Brexit referendum was in 2016.",
# "What year was the first iPhone released? The first iPhone was released in 2007.",
# "What year was the election of Donald Trump as the 45th President of the United States? The election of Donald Trump was in 2016.",
# "What year was the completion of the Human Genome Project? The completion of the Human Genome Project was in 2003.",
# "What year was the founding of the World Health Organization? The founding of the World Health Organization was in 1948.",
# "What year was the assassination of Archduke Franz Ferdinand? The assassination of Archduke Franz Ferdinand was in 1914.",
# "What year was the start of the California Gold Rush? The start of the California Gold Rush was in 1848.",
# "What year was the completion of the Panama Canal? The completion of the Panama Canal was in 1914.",
# "What year was the discovery of penicillin by Alexander Fleming? The discovery of penicillin was in 1928.",
# "What year was the Montgomery Bus Boycott? The Montgomery Bus Boycott was in 1955.",
# "What year was the assassination of Mahatma Gandhi? The assassination of Mahatma Gandhi was in 1948.",
# "What year was the formation of the European Union? The formation of the European Union was in 1993.",
# "What year was the release of the first Harry Potter book by J.K. Rowling? The release of the first Harry Potter book was in 1997.",
# "What year was the start of the American Civil War? The start of the American Civil War was in 1861."]

# tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# tf_dataset = create_tf_dataset(data, tokenizer)

# dataset = strategy.experimental_distribute_dataset(tf_dataset)
print('Done!')

import os
import matplotlib.pyplot as plt
import datetime
from IPython.display import clear_output

def plot_loss(loss_history):
    clear_output(wait=True)  # Clear the output before plotting a new graph
    plt.plot(loss_history)
    plt.xlabel("Batch")
    plt.ylabel("Loss")
    
    # Get the current timestamp
    now = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    plt.title(f"Loss History\nLast Updated at: {now}")
    
    plt.show()

@tf.function
def train_step(inp, tar):
    with tf.GradientTape() as tape:
        predictions = transformer([inp, tar], training=True)
        loss = loss_function(tar, predictions)

    gradients = tape.gradient(loss, transformer.trainable_variables)
    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))

    return loss

print("Initializing training...")
try:
    for epoch in range(epochs):
        total_loss = tf.constant(0.0, dtype=tf.float32)
        for (batch, (inp, tar)) in enumerate(dataset):
            per_replica_losses = strategy.run(train_step, args=(inp, tar))
            loss = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
            total_loss += loss
            loss = loss.numpy()
            
            loss_history.append(loss)  # Save the loss of the current batch
            start = time.time()
            plot_loss(loss_history)  # Update the loss graph
            

        tf.saved_model.save(transformer, saved_transformer_path, options=tf.saved_model.SaveOptions(experimental_io_device='/job:localhost'))
        avg_loss = total_loss / (batch + 1)
        print(f'Epoch {epoch + 1}, Average loss: {avg_loss:.4f}')

except Exception as e:
    print(e)
    print("Saving transformer...")
    tf.saved_model.save(transformer, saved_transformer_path, options=tf.saved_model.SaveOptions(experimental_io_device='/job:localhost'))
except KeyboardInterrupt:
    print("Saving transformer...")
    tf.saved_model.save(transformer, saved_transformer_path, options=tf.saved_model.SaveOptions(experimental_io_device='/job:localhost'))

Loaded model not found. Creating a new model
Done
Processing dataset...
Example 1:
Input:  [ 2025   998  1042   318   281  3098    12  9800  8353  1964   290  1919
  8876   326 28317 28398   444 10762 21218   290 11009   511  9014   351
  2116    12 39935    11  2116    12 47866   276 14515  1912   319 16171
    11 22849  6712    13  2312  6712   389  1690  3417   355  1181  1203
 14515    11  3584  1811  7035   423  5447   606   517  5734   355  7310
  6712  1912   319  1729    12    71   959   998   605   393  1479 15814
    13 32229  1042   338  4318 25800   351   584 35871   318   326   340
  6622   262  1181   284   307 38117    11 13114    11   290 13568    13
 32229  1042   318  3221  4624   319   262  1290    12  9464   286   262
  1964 10958    11   290   881   286   663 12446   290  2742  8876  4079
  3098    12  9800  8353 26146   286 27770]
Target:  [  998  1042   318   281  3098    12  9800  8353  1964   290  1919  8876
   326 28317 28398   444 10762 21218   290 11009   51

Exception ignored in: <function Executor.__del__ at 0x7f70e326f700>
Traceback (most recent call last):
  File "/home/adrian_fagerland/.local/lib/python3.9/site-packages/tensorflow/python/eager/executor.py", line 46, in __del__
    self.wait()
  File "/home/adrian_fagerland/.local/lib/python3.9/site-packages/tensorflow/python/eager/executor.py", line 65, in wait
    pywrap_tfe.TFE_ExecutorWaitForAllPendingNodes(self._handle)
tensorflow.python.framework.errors_impl.OutOfRangeError: End of sequence


Done!
Initializing training...
in user code:

    File "<ipython-input-10-6da9f6f795dc>", line 258, in train_step  *
        predictions = transformer([inp, tar], training=True)
    File "<ipython-input-5-43d3c2a885e1>", line 181, in call  *
        enc_output = self.encoder(inp, training, enc_padding_mask)
    File "<ipython-input-5-43d3c2a885e1>", line 103, in call  *
        x = self.embedding(x)

    AttributeError: 'list' object has no attribute 'dtype'

Saving transformer...


ValueError: Model <__main__.Transformer object at 0x7f70bde92880> cannot be saved because the input shapes have not been set. Usually, input shapes are automatically determined from calling `.fit()` or `.predict()`. To manually set the shapes, call `model.build(input_shape)`.

In [None]:
print(min(loss_history))

In [None]:
def predict_next_word(input_text, transformer, tokenizer, top_k=5, max_length=128):
    input_tokens_full = tokenizer.encode(input_text, return_tensors="tf")
    if input_tokens_full.shape[1] > max_length:
        input_tokens = input_tokens_full[:, -max_length:]
    else:
        input_tokens = input_tokens_full
    seq_len = input_tokens.shape[1]
    logits = transformer([input_tokens, tf.zeros((1, seq_len), dtype=tf.int32)], training=False)
    logits = logits[:, -1, :]  # Get the logits for the last token
    top_k_indices = tf.math.top_k(logits, k=top_k).indices
    top_k_tokens = [tokenizer.decode([token_id]) for token_id in top_k_indices.numpy()[0]]
    
    return top_k_tokens


input_text = """"""
predicted_words = predict_next_word(input_text, transformer, tokenizer, top_k=50)
print(f"Input: {input_text}")
print("Predicted next words:")
for i, word in enumerate(predicted_words):
    print(f"{i + 1}. {word}")