<a href="https://colab.research.google.com/github/ChirudeepG/Transformers-and-finetuning-with-LLMs/blob/main/297_tensorflow_text_generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [2]:
# Hyperparameters
batch_size = 16
block_size = 32
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
n_embd = 64
n_head = 4
n_layer = 4
dropout = 0.0

In [3]:
with open('/content/sample_data/S01E07 The Blackout.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [4]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

# Encoding the data
data = [stoi[c] for c in text]

In [5]:
n = int(0.9 * len(data))

train_data_tensor = tf.constant(data[:n], dtype=tf.int32)
val_data_tensor = tf.constant(data[n:], dtype=tf.int32)

In [6]:
def get_batch(data_tensor, batch_size, block_size):
    start_indices = tf.random.uniform((batch_size,), 0, len(data_tensor) - block_size, dtype=tf.int64)
    x_batch = tf.stack([data_tensor[start:start + block_size] for start in start_indices])
    y_batch = tf.stack([data_tensor[start + 1:start + block_size + 1] for start in start_indices])
    return x_batch, y_batch

In [7]:
class MultiHeadSelfAttention(layers.Layer):
    def __init__(self, embed_size, heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        self.values = layers.Dense(self.head_dim, use_bias=False)
        self.keys = layers.Dense(self.head_dim, use_bias=False)
        self.queries = layers.Dense(self.head_dim, use_bias=False)
        self.fc_out = layers.Dense(embed_size)

    def call(self, values, keys, query):
        N, seq_length, _ = query.shape
        value_len, key_len = values.shape[1], keys.shape[1]

        # Split embedding into self.head pieces
        values = tf.reshape(values, (N, value_len, self.heads, self.head_dim))
        keys = tf.reshape(keys, (N, key_len, self.heads, self.head_dim))
        queries = tf.reshape(query, (N, seq_length, self.heads, self.head_dim))

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Scaled dot-product attention
        attention = tf.einsum("nqhd,nkhd->nhqk", queries, keys)
        attention = attention / tf.math.sqrt(float(self.head_dim))
        attention = tf.nn.softmax(attention, axis=-1)

        out = tf.einsum("nhql,nlhd->nqhd", attention, values)
        out = tf.reshape(out, (N, seq_length, self.embed_size))
        out = self.fc_out(out)
        return out

class TransformerBlock(layers.Layer):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = MultiHeadSelfAttention(embed_size, heads)
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)

        self.feed_forward = keras.Sequential(
            [
                layers.Dense(forward_expansion * embed_size, activation="relu"),
                layers.Dense(embed_size),
            ]
        )

        self.dropout = layers.Dropout(dropout)

    def call(self, value, key, query):
        attention = self.attention(value, key, query)
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out



In [8]:
class BigramLanguageModel(keras.Model):
    def __init__(self, vocab_size, embed_size, heads, n_layers, max_length, forward_expansion, dropout):
        super(BigramLanguageModel, self).__init__()
        self.embedding = layers.Embedding(vocab_size, embed_size)
        self.positional_embedding = layers.Embedding(max_length, embed_size)
        self.transformer_blocks = [
            TransformerBlock(embed_size, heads, dropout, forward_expansion)
            for _ in range(n_layers)
        ]
        self.dropout = layers.Dropout(dropout)
        self.fc_out = layers.Dense(vocab_size)

    def call(self, x):
        N, seq_length = x.shape
        positions = tf.range(start=0, limit=seq_length, delta=1)
        out = self.embedding(x)
        out += self.positional_embedding(positions)

        for block in self.transformer_blocks:
            out = block(out, out, out)

        out = self.dropout(out)
        out = self.fc_out(out)
        return out


In [9]:
model = BigramLanguageModel(
    vocab_size,
    n_embd,
    n_head,
    n_layer,
    block_size,
    forward_expansion=n_embd * 4,
    dropout=dropout
)
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
generated_text = []
for iteration in range(max_iters):
    x_batch, y_batch = get_batch(train_data_tensor, batch_size, block_size)
    with tf.GradientTape() as tape:
        logits = model(x_batch)
        loss = loss_fn(y_batch, logits)

    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    if iteration % eval_interval == 0:
        print(f"Iteration {iteration}, Loss: {loss.numpy()}")



Iteration 0, Loss: 4.995068073272705
Iteration 100, Loss: 3.3651232719421387
Iteration 200, Loss: 3.32047438621521
Iteration 300, Loss: 3.290300130844116
Iteration 400, Loss: 3.366560220718384
Iteration 500, Loss: 3.29370379447937
Iteration 600, Loss: 3.3987698554992676
Iteration 700, Loss: 3.275597333908081
Iteration 800, Loss: 3.3413777351379395
Iteration 900, Loss: 3.042710542678833
Iteration 1000, Loss: 2.942659854888916
Iteration 1100, Loss: 2.6498396396636963
Iteration 1200, Loss: 2.543227195739746
Iteration 1300, Loss: 2.6203181743621826
Iteration 1400, Loss: 2.636044502258301
Iteration 1500, Loss: 2.4142119884490967
Iteration 1600, Loss: 2.5263314247131348
Iteration 1700, Loss: 2.2889983654022217
Iteration 1800, Loss: 2.366955280303955
Iteration 1900, Loss: 2.380937337875366
Iteration 2000, Loss: 2.32063889503479
Iteration 2100, Loss: 2.405618190765381
Iteration 2200, Loss: 2.21152663230896
Iteration 2300, Loss: 2.2271673679351807
Iteration 2400, Loss: 2.429385185241699
Iterati

In [10]:
def generate_text(model, start_string, max_generate_length=2000):
    # Convert start_string to tensor
    input_eval = [stoi[s] for s in start_string]
    input_eval = tf.expand_dims(input_eval, 0)

    generated_text = []

    model.reset_states()
    for i in range(max_generate_length):
        logits = model(input_eval)
        # Use a multinomial distribution to predict the token returned by the model
        predicted_id = tf.random.categorical(logits[:, 0, :], num_samples=1)[-1,0].numpy()


        # Append the predicted token to the input string and the generated text
        input_eval = tf.expand_dims([predicted_id], 0)
        generated_text.append(itos[predicted_id])

    return ''.join(generated_text)

start_string = " "  # You can use a space, or any other starting token
print(generate_text(model, start_string))


aq
 ' qsa qAAnACNA.,.....gtc
 aaI baqACCNCNWvnCNwCnCnARCNpWWcRAWP
 i..  a
 hc?     sc
 gc?]....  gWvnCsyCCnCCnNWc
 gggttgttcBCCnCnGM
  ,... nYNCCnOcynCni.....,   ghOMPynnAnCCCnBTYYWMyCCT
c?R...... 
aonCnCnCNCnCnCnCnAnBCnNcPe  hWM
   ac!lb'WWcynCnCnNAnAnvnA..  nNJWWWc
 ugLbBWc?qCNunCNCnCNAttgtgWWcyNACCNAm
   tthCFnCnNCCCT
  qRAnCnCnAnNwAnNcWMc?Anttc?nNAnCnAAnNOcyCnAnnCNCCnCCCNCnAnCnCtc?W[
 qR........ hAiiDcaI.,....  nOc!
.. gWMc?Rc
 nCnAR..,.   gtc?WMc!rtc?aetaqa gttcyCnNwwCCnOtttc
 unAnCnnwCCnY:sSgc
... qAnCnpwcc?fii c
hc...'... qAnbc?......  
lbnCnA.   cAnCnCCNcypCCnCNAnNCCnCnCnwdddrc!Anii   nCNCCCnCnAnCpvNcc?
    gttgcbnAnwnCnCNCnATcBCnjnmTCNYCnAnNo
  qAnCnCCnYFnCnCnCCN  ggcyTMpTjjnCnNA...,...   q
WM?
 ahAlsJAp:.....,.  e nCnCnAnGWWWMWvCNCCnQtggggM cP
 gtttc!
  nO
nCCnCnCnCnNAbc?qAnJ:l  aqRR. a anCnCnCnAwo nCnAtttc...... lddddddrc
u   a  tc?fqRqR..,.   onNCCpCCCnNAnCnCnCnnCnCnGOcyNNbn(pWTO
a
o c?[R..,.etc?RCn
 nCnA..,......  aa
c
aqAntt nCCnCNCnNc?
hTACnBbc!J
 qRc?.,. ggt. acyCCNCCnB