We begin by writing a multi-head attention class in TensorFlow. To do this, we assume that our input has already been tokenized and mapped to vectors.

We will import a pre-trained tokenizer that maps text to a list of token IDs, which is just a list of integers. The tokenizer we use performs sub-word tokenization, but it is helpful to think of it as assigning one ID to each individual word.

In [None]:
import tensorflow as tf
from transformers import BertTokenizer,TFBertModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def tokenization_test(input_test):
    inputs = tokenizer(input_test, return_tensors='tf',max_length=10,return_attention_mask=True,padding='max_length')
    ID_test=inputs['input_ids']
    mask=inputs["attention_mask"]
    print(ID_test)
    print(mask)

if __name__ == '__main__':
    tokenization_test("Hello, my dog is cute")

tf.Tensor([[  101  7592  1010  2026  3899  2003 10140   102     0     0]], shape=(1, 10), dtype=int32)
tf.Tensor([[1 1 1 1 1 1 1 1 0 0]], shape=(1, 10), dtype=int32)


The above code snippet is just to get a sense of how things work. In the output, 101 and 102 are the special `[CLS]` and `[SEP]` tokens, and 0 is the padding token. With the BERT tokenizer, each sequence can be at most 512 tokens long. Setting `truncation=True` in the tokenizer arguments cuts longer inputs down to the limit; with `truncation=False` they are left as they are, and the model will fail when it receives them. The next step is embedding and positional encoding. There are multiple ways to do this. In my small LLM code, I will use a pre-defined model, since the focus is the transformer architecture. But it is also good to understand how this step works, both in practice and in the original paper "Attention Is All You Need".

When we prepare our training data, we will ask the tokenizer to add padding, because training requires all sequences in a batch to have the same length.
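To make the padding and truncation step concrete, here is a minimal pure-Python sketch of what `padding='max_length'` and `truncation=True` do to a list of token IDs. The helper name `pad_or_truncate` is hypothetical, and the pad ID of 0 is assumed to match BERT's `[PAD]` token; the real tokenizer does more than this (sub-word splitting, special tokens).

```python
def pad_or_truncate(token_ids, max_length, pad_id=0):
    """Return (ids, mask), both lists of length max_length."""
    ids = list(token_ids[:max_length])      # truncate if too long
    mask = [1] * len(ids)                   # 1 marks a real token
    n_pad = max_length - len(ids)
    return ids + [pad_id] * n_pad, mask + [0] * n_pad   # 0 marks padding

# Token IDs taken from the tokenizer output shown earlier
ids, mask = pad_or_truncate([101, 7592, 1010, 2026, 3899, 2003, 10140, 102], max_length=10)
print(ids)
print(mask)
```

The two printed lists should have the same layout as the `input_ids` and `attention_mask` tensors printed by `tokenization_test`.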

In [102]:
# Here we work with a single batch, and let the number of token IDs be N_id
# The input here is a 1 X N_id tensor
# The embedding maps the token ID at each position to a vector of size D
# The positional encoding creates another vector of the same size D
# The model output is a tensor of size 1 X N_id X D

Embedding_model=TFBertModel.from_pretrained('bert-base-uncased')

def embedding_test(input_test):
    inputs = tokenizer(input_test, return_tensors='tf')
    Embedding_tmp=Embedding_model(inputs)
    output_embedding=Embedding_tmp.last_hidden_state

    print(f"output_shape={output_embedding.shape}")

if __name__ == '__main__':
    embedding_test("Make America Great Again")    


Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions w

output_shape=(1, 6, 768)


In [3]:
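# TensorFlow offers an embedding layer for custom embeddings
# input_dim is the vocabulary size (number of distinct token IDs), not the sequence length
# embedding = tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=embedding_dim)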

In [72]:
# Fixed positional encoding is included for demonstration purposes
# In practice, we can use a positional encoding layer that is optimized during training
# Here N_id_max is the longest sequence length the layer supports
# The tokenizer pads every sequence to this length
# Later we will add an attention mask to deal with the padding when computing attention scores
class fixed_pos_encod_layer(tf.keras.layers.Layer):
      def __init__(self, N_id_max, embedding_dim):
            assert embedding_dim % 2 == 0, "embedding_dim must be even"
            super(fixed_pos_encod_layer, self).__init__()

            self.N_id_max=N_id_max
            self.embedding_dim=embedding_dim
    
      def _pos_encod(self):
            # position ranges from 0 to N_id_max-1
            # tf.range(0, embedding_dim, 2) already gives the even indices 2i = 0, 2, 4, ...
            # Since each frequency needs a sin and a cos, by convention embedding_dim is even
            # Position indices, shape N_id_max X 1
            pos_index = tf.range(self.N_id_max, dtype=tf.float32)[:, tf.newaxis]
            # frequency factors 1/10000^(2i/D); it is more efficient to use tf.exp than tf.pow
            omega = tf.exp(-tf.range(0,self.embedding_dim,2,dtype=tf.float32)/self.embedding_dim *tf.math.log(10000.0))
            
            angles=pos_index*omega  # N_id_max X D/2

            # the sin and cos halves are concatenated here rather than interleaved
            pos_encoding=tf.concat([tf.sin(angles),tf.cos(angles)],axis=-1)
            
            # add batch dimension to prepare for broadcasting
            return tf.expand_dims(pos_encoding,0)
      
      def call(self, inputs):
            # inputs is a tensor of size B X N_id X D
            N_id=tf.shape(inputs)[1]
            return inputs + self._pos_encod()[:,:N_id,:]
      
class pos_encod_layer(tf.keras.layers.Layer):
      def __init__(self, N_id_max, embedding_dim):
            assert embedding_dim % 2 == 0, "embedding_dim must be even"
            super(pos_encod_layer, self).__init__()

            self.N_id_max=N_id_max
            self.embedding_dim=embedding_dim
            
            # for trainable positional encoding, register the table as a weight so that it is updated during training
            self.pos_encoding=self.add_weight(name="pos_encoding",
                                              shape=(1,N_id_max,embedding_dim),
                                              initializer='random_normal',
                                              trainable=True)
            
      def call(self,inputs):
            # inputs is a tensor of size B X N_id X D
            N_id=tf.shape(inputs)[1]
            # the weight already has a leading batch dimension of 1, so we only slice the sequence axis
            return inputs + self.pos_encoding[:,:N_id,:]
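As a cross-check of the sinusoidal formula, here is a small pure-Python sketch of the encoding from the paper, PE(pos, 2i) = sin(pos / 10000^(2i/D)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/D)). The function name `sinusoidal_encoding` is made up for illustration. Note that this version interleaves sin and cos as in the paper, whereas the fixed layer above concatenates the two halves, which is only a permutation of the same features.

```python
import math

def sinusoidal_encoding(n_positions, dim):
    """Return an n_positions x dim table of interleaved sin/cos encodings."""
    assert dim % 2 == 0, "dim must be even"
    table = []
    for pos in range(n_positions):
        row = []
        for two_i in range(0, dim, 2):
            # exp(-(2i/dim) * ln(10000)) equals 10000 ** (-2i/dim)
            omega = math.exp(-(two_i / dim) * math.log(10000.0))
            row.extend([math.sin(pos * omega), math.cos(pos * omega)])
        table.append(row)
    return table

pe = sinusoidal_encoding(n_positions=4, dim=6)
print(len(pe), len(pe[0]))   # table dimensions
print(pe[0])                 # at position 0 every sin term is 0.0 and every cos term is 1.0
```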
      

Now our data is ready to feed into the transformer. A transformer consists of an encoder, a decoder, or both. Their underlying building block is multi-head attention (MHA), which is the core idea behind "Attention Is All You Need". In the conventional architecture, each encoder layer applies one self-attention MHA followed by a feed-forward network (FFN) to process the input sequence. Each decoder layer applies two attention mechanisms, a masked self-attention MHA and then a cross-attention MHA that connects to the encoder's output, followed by an FFN to handle the target sequence. Multiple such layers are stacked in both the encoder and the decoder to form the full model, though you can tweak this setup for custom experiments.
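For reference, the scaled dot-product attention and its multi-head form from the paper can be written as follows. Here $d_k$ is the per-head dimension (called `key_dim` in the code), $h$ is the number of heads, and $W_i^Q$, $W_i^K$, $W_i^V$, $W^O$ are the learned projection matrices; the notation follows the paper rather than the variable names in this notebook.

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V

\mathrm{head}_i = \mathrm{Attention}\!\left(Q W_i^{Q},\; K W_i^{K},\; V W_i^{V}\right),
\qquad
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^{O}
```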

In [None]:
# Multi-head attention layer
# The input is a tensor of size B X N_id X D
# Each position is represented by a vector of dimension D
# W_q, W_k, W_v => Q, K, V, e.g. Q = inputs \cdot W_q
# The attention score is Q \cdot K^T
# Q (query) carries the information at each position of your input
# K (key) carries the information that you are comparing it to
# D = d_head * Nof_heads
# e.g. 1000 = 10 * 100
class MHA(tf.keras.layers.Layer):
        def __init__(self, embedding_dim, Nof_heads):
            super(MHA,self).__init__()
            assert embedding_dim % Nof_heads == 0, "embedding_dim must be divisible by num_heads"

            self.embedding_dim=embedding_dim
            self.Nof_heads=Nof_heads  
            self.key_dim= embedding_dim // Nof_heads
 
            # weight matrices for Q, K, V and the output projection
            self.W_q=self.add_weight(name="W_q",
                                     shape=(embedding_dim, embedding_dim),
                                     initializer="random_normal",
                                     trainable=True)
            
            self.W_k=self.add_weight(name="W_k",
                                     shape=(embedding_dim, embedding_dim),
                                     initializer="random_normal",
                                     trainable=True)
            
            self.W_v=self.add_weight(name="W_v",
                                     shape=(embedding_dim, embedding_dim),
                                     initializer="random_normal",
                                     trainable=True)
            
            self.W_out=self.add_weight(name="W_out",
                                     shape=(embedding_dim, embedding_dim),
                                     initializer="random_normal",
                                     trainable=True)
            
        def split_heads(self,Vector):
            
            # Vector is a tensor of size B X N_id X D
            # the name emphasizes that we are splitting a vector in the embedding space
            # Output is B X Nof_heads X N_id X key_dim

            # Split the last dimension into (Nof_heads, key_dim)
            input_reshaped=tf.reshape(Vector,(tf.shape(Vector)[0],tf.shape(Vector)[1],self.Nof_heads,self.key_dim))
            # Transpose to get the standard format 
            
            return tf.transpose(input_reshaped,perm=[0,2,1,3]) # B X Nof_heads X N_id  X key_dim
        
        def call(self,Q,K,V,mask=None):
                
                # Project Q,K,V
                Q=tf.matmul(Q,self.W_q) # B X N_id X D
                K=tf.matmul(K,self.W_k) # B X N_id X D
                V=tf.matmul(V,self.W_v) # B X N_id X D
                
                # split the vector into different heads B X Nof_heads X N_id  X key_dim
                Q=self.split_heads(Q)
                K=self.split_heads(K)
                V=self.split_heads(V)

                # calculate the attention scores for each head
                # keep in mind that we have to use a tf tensor for the denominator 
                Scores= tf.matmul(Q,K,transpose_b=True) /tf.math.sqrt(tf.cast(self.key_dim,tf.float32))
                # for each batch and head, Q has size N_id X key_dim and K^T has size key_dim X N_id 
                # Scores is a tensor of size B X Nof_heads X N_id X N_id
                if mask is not None:
                      # mask uses 1 = attend, 0 = block and must broadcast against Scores
                      # blocked positions get a large negative score, so softmax gives them ~0 weight
                      mask=tf.where(mask == 0, -1e9, 0.0)
                      Scores= Scores + mask

                Soft_max=tf.nn.softmax(Scores,axis=-1)
                 
                # output per head is B X Nof_heads X N_id X key_dim
                Output= tf.matmul(Soft_max,V) 
                # transpose back to the original format
                Output=tf.transpose(Output,perm=[0,2,1,3])
                # concatenate the heads; Q is B X Nof_heads X N_id X key_dim here, so tf.shape(Q)[2] is N_id
                Output=tf.reshape(Output,(tf.shape(Q)[0], tf.shape(Q)[2], self.embedding_dim))
                # for self-attention this is the same shape as the layer input
                # final projection
                O=tf.matmul(Output, self.W_out)      
               
                return O 
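To see the masking trick in isolation, here is a minimal single-head sketch on nested Python lists, without batching or learned projections. The function names `softmax` and `attention` are illustrative only; the mask convention (1 = attend, 0 = block) is assumed to match the layer above.

```python
import math

def softmax(row):
    m = max(row)                                  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V, mask=None):
    """Single-head scaled dot-product attention.

    Q: n_q x d, K: n_k x d, V: n_k x d_v, mask: n_q x n_k with 1 = attend, 0 = block.
    """
    d = len(Q[0])
    out = []
    for i, q in enumerate(Q):
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
        if mask is not None:
            # add a large negative number wherever the mask is 0
            scores = [s + (-1e9 if mask[i][j] == 0 else 0.0) for j, s in enumerate(scores)]
        weights = softmax(scores)
        out.append([sum(w * v[c] for w, v in zip(weights, V)) for c in range(len(V[0]))])
    return out

Q = K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
causal = [[1, 0], [1, 1]]                         # position 0 cannot see position 1
result = attention(Q, K, V, mask=causal)
print(result[0])                                  # equals V[0]: position 0 attends only to itself
print(result[1])                                  # a weighted mix of V[0] and V[1]
```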
        


In [6]:
# Feed forward network layer
class Feed_forward_network(tf.keras.layers.Layer):
        def __init__(self,embedding_dim, expanding_dim):
            super(Feed_forward_network,self).__init__()
            self.embedding_dim=embedding_dim
            self.expanding_dim=expanding_dim
            self.ffn=tf.keras.Sequential([
                tf.keras.layers.Dense(expanding_dim,activation='relu'),
                tf.keras.layers.Dense(embedding_dim)])
        
        def call(self,input):
              return self.ffn(input)


In [7]:
class encoder_layer(tf.keras.layers.Layer):
        # the input is the result of the positional encoding, namely B X N_id X D
        # the parameters configure the MHA and FFN sub-layers
        def __init__(self,embedding_dim,Nof_heads,expanding_dim):
            super(encoder_layer,self).__init__()
            # initialize the parameters
            self.embedding_dim=embedding_dim
            self.Nof_heads=Nof_heads
            self.expanding_dim=expanding_dim
            # create the layers
            # the mask is passed at call time, not to the constructor
            self.MHA=MHA(embedding_dim,Nof_heads)
            self.FFN=Feed_forward_network(embedding_dim,expanding_dim)
            self.LN1=tf.keras.layers.LayerNormalization(epsilon=1e-6)
            self.LN2=tf.keras.layers.LayerNormalization(epsilon=1e-6)

        def call(self,inputs,source_mask=None):
              # MHA
              Out_MHA=self.MHA(inputs,inputs,inputs,mask=source_mask)
              # residual connection+ layer normalization
              Out1=self.LN1(inputs+Out_MHA)       
              # Feed forward network
              Out_ffn=self.FFN(Out1)
              # residual connection+ layer normalization
              return self.LN2(Out1+Out_ffn)     

class decoder_layer(tf.keras.layers.Layer):
      def __init__(self, embedding_dim, Nof_heads,expanding_dim):
            super(decoder_layer,self).__init__()
            # initialize the parameters 
            self.embedding_dim=embedding_dim
            self.Nof_heads=Nof_heads
            self.expanding_dim=expanding_dim

            # create the layers
            # masks are passed at call time, not to the constructor
            self.MHA_masked=MHA(embedding_dim,Nof_heads)   
            self.MHA_cross= MHA(embedding_dim,Nof_heads)   
            self.FFN=Feed_forward_network(embedding_dim,expanding_dim)
            self.LN1=tf.keras.layers.LayerNormalization(epsilon=1e-6)
            self.LN2=tf.keras.layers.LayerNormalization(epsilon=1e-6)
            self.LN3=tf.keras.layers.LayerNormalization(epsilon=1e-6)

      def call(self,target,encoder_output,source_mask=None,target_mask=None):
            # target is the embedded target sequence (or the previous decoder layer's output)
            # encoder_output is the output of the encoder stack
            # MHA masked
            Out_MHA_masked=self.MHA_masked(target,target,target,mask=target_mask)
            # residual connection+ layer normalization
            Out1=self.LN1(target+Out_MHA_masked)
            # MHA cross
            Out_MHA_cross=self.MHA_cross(Out1,encoder_output,encoder_output,mask=source_mask)
            # residual connection+ layer normalization
            Out2=self.LN2(Out1+Out_MHA_cross)
            # Feed forward network
            Out_ffn=self.FFN(Out2)
            # residual connection+ layer normalization
            return self.LN3(Out2+Out_ffn)
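Both layer types repeat the same pattern, LayerNorm(x + sublayer(x)), around every sub-layer. A minimal pure-Python sketch of that pattern on a single vector is given below; the names `layer_norm` and `residual_block` are illustrative, and the sketch omits the learned scale and shift that `tf.keras.layers.LayerNormalization` also has.

```python
import math

def layer_norm(x, eps=1e-6):
    """Normalize one vector to zero mean and (almost) unit variance."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def residual_block(x, sublayer):
    """Post-norm residual connection: LayerNorm(x + sublayer(x))."""
    y = sublayer(x)
    return layer_norm([a + b for a, b in zip(x, y)])

x = [1.0, 2.0, 3.0, 4.0]
# toy sublayer standing in for MHA or the FFN
out = residual_block(x, lambda v: [0.5 * t for t in v])
print(out)
```

Because the normalization is applied after the residual sum, this is the "post-norm" arrangement used in the original paper.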

Now we build the encoder and the decoder. Each one simply loops through a list of layers and passes the result of one layer on to the next.
 

In [8]:
class Encoder(tf.keras.layers.Layer):

      # a stack of N_layers encoder layers
      def __init__(self, embedding_dim, Nof_heads,expanding_dim, N_layers):
            super(Encoder,self).__init__()
            #initialize the parameters
            self.embedding_dim=embedding_dim
            self.Nof_heads=Nof_heads
            self.expanding_dim=expanding_dim
            self.N_layers=N_layers
            # create the layers
            self.encoder_layers=[encoder_layer(embedding_dim,Nof_heads,expanding_dim) for _ in range(N_layers)]
            
      def call(self,inputs,source_mask=None):
            #inputs is the output of the positional encoding
            for i in range(self.N_layers):
                  inputs=self.encoder_layers[i](inputs,source_mask)
            return inputs
      
class Decoder(tf.keras.layers.Layer):
      def __init__(self, embedding_dim, Nof_heads, expanding_dim, N_layers):
            super(Decoder,self).__init__()
            #initialize the parameters
            self.embedding_dim=embedding_dim
            self.Nof_heads=Nof_heads
            self.expanding_dim=expanding_dim
            self.N_layers=N_layers
            # create the layers
            self.decoder_layers=[decoder_layer(embedding_dim,Nof_heads,expanding_dim) for _ in range(N_layers)]
            
      def call(self,target,encoder_output,source_mask=None,target_mask=None):
            # target is the output of the positional encoding for the target sequence
            for i in range(self.N_layers):
                  target=self.decoder_layers[i](target,encoder_output,source_mask,target_mask)
            return target


For text summarization, we need both the encoder and the decoder, which take the following parameters: embedding_dim, Nof_heads, expanding_dim, N_layers. What follows is part 2 of the transformer blog post.

In [9]:
import pandas as pd

In [124]:
from datasets import load_dataset
dataset=load_dataset("wikitext", "wikitext-2-raw-v1")
# we first understand the structure of the dataset 
print(dataset)

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 4358
    })
    train: Dataset({
        features: ['text'],
        num_rows: 36718
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 3760
    })
})


This tells us that the dataset comes with three subsets: test, train, and validation. For our purpose, the train part alone will be sufficient.

In [125]:
training_data=dataset['train']['text']
print(type(training_data))
print(training_data[1])
print(dataset['train'][1])
# df = pd.DataFrame(training_data, columns=["text"])
# print("DataFrame Info:")
# print(df.info())
# print("\nFirst 5 rows:")
# print(df.head())
# print("\nBasic stats:")
# print(df.describe())
# print("\nSample 5 random rows:")
# print(df.sample(5))

<class 'list'>
 = Valkyria Chronicles III = 

{'text': ' = Valkyria Chronicles III = \n'}


In [126]:
# tokenize data

Tokenized_data=tokenizer(training_data,
                          return_tensors='tf',
                          padding="max_length",
                          truncation=True,
                          add_special_tokens=False,
                          max_length=50,
                          return_attention_mask=True)

input_ids=Tokenized_data['input_ids']
padding_mask=Tokenized_data['attention_mask']

print(input_ids[1])

tf.Tensor(
[ 1027 11748  4801  4360 11906  3523  1027     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0], shape=(50,), dtype=int32)


In [129]:
print(input_ids[10])

tf.Tensor(
[ 1996  2208  1005  1055  2645  2291  1010  1996 22312  2291  1010  2003
  3344  2058  3495  2013 11748  4801  7895 11906  1012  2076  6416  1010
  2867  7276  2169  3131  2478  1037  2327  1030  1011  1030  2091  7339
  1997  1996 11686  4949  1024  2320  1037  2839  2003  3479  1010  1996
  2447  5829], shape=(50,), dtype=int32)


In [112]:
# Take tokenized data and return batched input and target sequences, plus a causal mask
class token_to_data():
        def __init__(self,input_ids,padding_mask,batch_size):
               self.input_ids=input_ids
               self.padding_mask=tf.identity(padding_mask)
               self.batch_size=batch_size
               self._data_pre()
               self._create_dataset()

        def _data_pre(self):
                #sequence length
                self.seq_len=tf.shape(self.input_ids)[-1]-1

                # parameters for batches
                N_seq  =  tf.shape(self.input_ids)[0]
                self.N_batch=  N_seq // self.batch_size
                N_trimmed=  self.N_batch * self.batch_size

                # input sequence
                self.input_sequence=self.input_ids[:N_trimmed,:-1]
                # target sequence
                self.target_sequence=self.input_ids[:N_trimmed,1:] 
                # padding mask
                self.padding_mask=self.padding_mask[:N_trimmed,:-1]
                
                # batch the data
                self.batched_input_sequence=tf.reshape(self.input_sequence,(self.N_batch,self.batch_size,self.seq_len))
                self.batched_target_sequence=tf.reshape(self.target_sequence,(self.N_batch,self.batch_size,self.seq_len))
                self.batched_padding_mask=tf.reshape(self.padding_mask,(self.N_batch,self.batch_size,1,1,self.seq_len))

                # compute causal mask
                self.causal_mask=tf.linalg.band_part(tf.ones((self.seq_len, self.seq_len)), -1, 0)
                self.causal_mask=tf.reshape(self.causal_mask,(1,1,self.seq_len, self.seq_len))
        
        
        def _create_dataset(self):
                self.dataset = tf.data.Dataset.from_tensor_slices((self.batched_input_sequence, self.batched_target_sequence, self.batched_padding_mask) )
                self.dataset = self.dataset.shuffle(buffer_size=1000).prefetch(tf.data.AUTOTUNE)

        def get_dataset(self):
            return self.dataset

        def get_batch(self, batch_id):
                return (self.batched_input_sequence[batch_id],
                                        self.batched_target_sequence[batch_id],
                                          self.batched_padding_mask[batch_id])
        def get_causal_mask(self):
               return self.causal_mask
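The two key ideas in `_data_pre`, shifting the sequence by one position to get the targets and building a lower-triangular causal mask, can be illustrated on plain lists. This is a small sketch with hypothetical helper names (`make_example`, `causal_mask`), not a replacement for the batched TensorFlow version.

```python
def make_example(token_ids):
    """Split one token sequence into (input, target), shifted by one position."""
    return token_ids[:-1], token_ids[1:]

def causal_mask(n):
    """Lower-triangular n x n mask: position i may attend to positions 0..i."""
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]

inputs, targets = make_example([1996, 2208, 1005, 1055, 2645])
print(inputs)    # [1996, 2208, 1005, 1055]
print(targets)   # [2208, 1005, 1055, 2645]
for row in causal_mask(len(inputs)):
    print(row)
```

At every position the target is the token that follows the input token, and the mask prevents a position from looking at the tokens it is supposed to predict.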
 

In [113]:

class decoderonlylayer(tf.keras.layers.Layer):
      def __init__(self, embedding_dim, Nof_heads,expanding_dim):
            super(decoderonlylayer,self).__init__()
            # initialize the parameters 
            self.embedding_dim=embedding_dim
            self.Nof_heads=Nof_heads
            self.expanding_dim=expanding_dim

            # create the layers
            self.MHA_masked=MHA(embedding_dim,Nof_heads)   
            self.FFN=Feed_forward_network(embedding_dim,expanding_dim)
            self.LN1=tf.keras.layers.LayerNormalization(epsilon=1e-6)
            self.LN2=tf.keras.layers.LayerNormalization(epsilon=1e-6)
      

      def call(self,input,mask=None):
            # input is the output of the positional encoding (or of the previous layer)
            # masked MHA (self-attention)
            Out_MHA_masked=self.MHA_masked(input,input,input,mask=mask)
            # residual connection+ layer normalization
            Out1=self.LN1(input+Out_MHA_masked)
            # Feed forward network
            Out_ffn=self.FFN(Out1)
            # residual connection+ layer normalization
            return self.LN2(Out1+Out_ffn)

In [None]:
# Here we simplify things and stack decoder-only layers directly
class next_token_prediction(tf.keras.Model):
      def __init__(self,vocab_size,N_id_max,embedding_dim,Nof_heads,expanding_dim,N_layers):
            # embedding_dim is the size of the embedding vector, denoted as D
            # Nof_heads is the number of heads in the multi-head attention layer
            # expanding_dim is the hidden size of the feed forward network
            # N_layers is the number of stacked decoder-only layers
            super(next_token_prediction,self).__init__()

            self.embedding_layer = tf.keras.layers.Embedding(vocab_size, embedding_dim)
            self.position_encoding= fixed_pos_encod_layer(N_id_max,embedding_dim)
            self.decoder_layers=[decoderonlylayer(embedding_dim,Nof_heads,expanding_dim) for _ in range(N_layers)]
            self.output_layer = tf.keras.layers.Dense(vocab_size)

      def call(self,input,mask=None):
            tmp=self.embedding_layer(input)
            tmp=self.position_encoding(tmp)
            for layer in self.decoder_layers:
                  tmp=layer(tmp,mask=mask)
            return self.output_layer(tmp)      
                  

model= next_token_prediction(vocab_size = 30522,
                             N_id_max=50,
                             embedding_dim = 128,
                             Nof_heads = 4,
                             expanding_dim = 512,
                             N_layers = 2)           
     



In [114]:
# Loss and Optimizer
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=0.001)     
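To clarify what `from_logits=True` means, here is a pure-Python sketch of the per-token quantity the loss is based on: the negative log of the softmax probability assigned to the correct token. The function name is made up for illustration, and the Keras loss additionally averages these values over all positions by default.

```python
import math

def sparse_cross_entropy_from_logits(logits, target):
    """Return -log(softmax(logits)[target]) using the log-sum-exp trick."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

# Uniform logits over 4 classes give a loss of ln(4)
print(sparse_cross_entropy_from_logits([0.0, 0.0, 0.0, 0.0], target=2))
# A confident, correct prediction gives a loss close to 0
print(sparse_cross_entropy_from_logits([10.0, 0.0, 0.0], target=0))
```

With a vocabulary of 30522 tokens, a model that guesses uniformly would have a loss of about ln(30522), which is roughly 10.3.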

In [None]:
input_ids=Tokenized_data['input_ids']
padding_mask=Tokenized_data['attention_mask']


data = token_to_data(input_ids, padding_mask, batch_size=32)
dataset = data.get_dataset()
causal_mask = data.get_causal_mask()

In [116]:
@tf.function
def train_step(inputs, targets, causal_mask):
    with tf.GradientTape() as tape:
        predictions = model(inputs, mask=causal_mask)
        loss = loss_fn(targets, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

In [None]:
# Training Loop
epochs = 5
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    total_loss = 0
    num_batches = 0
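    # padding_mask is unpacked but not used here: only the causal mask is applied,
    # and the loss is averaged over padded positions as well
    for inputs, targets, padding_mask in dataset:
        loss = train_step(inputs, targets, causal_mask)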
        total_loss += loss
        num_batches += 1
    avg_loss = total_loss / num_batches
    print(f"Average loss: {avg_loss:.4f}")

    

Epoch 1/5
Average loss: 2.9560
Epoch 2/5
Average loss: 2.6811
Epoch 3/5


2025-04-01 15:56:35.594281: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Average loss: 2.4780
Epoch 4/5
Average loss: 2.2810
Epoch 5/5
Average loss: 2.0946


In [121]:
# Now let's do inference

    

def text_generator(input_string, model, max_length, tokenizer):
    # Tokenize input string and get token IDs as a list
    input_tokens = tokenizer.encode(input_string, add_special_tokens=False)  # Returns a list
    input_ids = tf.constant([input_tokens], dtype=tf.int32)  # Shape: (1, input_len)
    generated = input_tokens.copy()  # Now a list

    # max_length must not exceed N_id_max of the positional encoding layer
    for _ in range(max_length - len(input_tokens)):  # len works on list
        seq_len = tf.shape(input_ids)[1]
        causal_mask = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
        causal_mask = tf.reshape(causal_mask, (1, 1, seq_len, seq_len))

        logits = model(input_ids, mask=causal_mask)
        next_token_logits = logits[:, -1, :]
        # suppress token ID 0 ([PAD]) so that padding is never generated
        next_token_logits = tf.where(tf.range(tf.shape(next_token_logits)[-1]) == 0, -1e9, next_token_logits)
        next_token = tf.argmax(next_token_logits, axis=-1, output_type=tf.int32)
        generated.append(next_token.numpy()[0])  # Append to list
        
        # Overwrite previous output with new string
        print(tokenizer.decode(generated), end='\r', flush=True)
        input_ids = tf.constant([generated], dtype=tf.int32)
        if next_token.numpy()[0] == 102:  # Stop at [SEP]
            break
    
    final_text = tokenizer.decode(generated)
    print(final_text)  # Final print without \r to keep it
    return final_text   
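Stripped of the tokenizer and TensorFlow details, `text_generator` is a greedy decoding loop: take the most likely next token, append it, and repeat until the length limit or a stop token is reached. The sketch below shows only that control flow; `greedy_decode` and `toy_model` are hypothetical names, and the toy model is a stand-in that has nothing to do with the trained network.

```python
def greedy_decode(next_logits, prompt_ids, max_length, stop_id):
    """Repeatedly append the argmax token until max_length or stop_id is reached."""
    generated = list(prompt_ids)
    while len(generated) < max_length:
        logits = next_logits(generated)                         # scores for the next token
        next_id = max(range(len(logits)), key=logits.__getitem__)  # argmax
        generated.append(next_id)
        if next_id == stop_id:
            break
    return generated

def toy_model(ids):
    """Stand-in for the trained model: always favors (last_id + 1) mod 5."""
    logits = [0.0] * 5
    logits[(ids[-1] + 1) % 5] = 1.0
    return logits

print(greedy_decode(toy_model, prompt_ids=[0, 1], max_length=10, stop_id=4))
```

Greedy decoding is deterministic and tends to produce repetitive text, which is consistent with the repeated phrases in the sample output of this notebook; sampling-based strategies are a common alternative.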
         

   


In [122]:
input_str="I love dogs"
generated_text = text_generator(model=model,
                 input_string=input_str,
                  max_length=50,
                  tokenizer=tokenizer
                  )


i love dogs and chains ( welsh : ffordd g. 39 ( 3 ) ( piano ), rac ( welsh : ffordd gyswllt, and the most common name of the " king's " ( " ).
