optimization techniques :
- always use powers of 2 for hyperparameters initialization. for example, in this code, the vocab size for GPT2 tokenizer is 50257 which is not divisible be 2 so increase the vocab size to 50304. even though we have increased the vocab size, we may think that compute time will increase but actually the compute time decreases. because almost all kernals, all operations are done in powers of 2(like getting 32bits of memory per second so if we want to read 33 bits then it will take 2 seconds and also reading 64 bits will take same 2 seconds), but dont increase to much

- use @tf.function decorator for custom train, val and test methods in tensorflow or use pytorch.compile(). this will enable graph execution rather than executing in eager mode

- use flash attention instead of self attention. the calculations and output are same but flash attention is more optimized. in tensorflow, you can use tensorflow.keras.layers.MultiHeadAttention() with flash_attention=True to enable flash_attention. 'False' to disable or 'None' to automatically decide by keras. or use keras.config.enable_flash_attention() to enable flash attention globally and keras.config.disable_flash_attention() to disable flash attention globally. you will find easiest implementation of flash attention in pytorch using torch.nn.functional.scaled_dot_product_attention().

-  use BF16 operations instead of FP32. both operations offer same range but BF16 has slightly lower precision which may affect the performance of the model negatively but its negligible but we get almost 16X speed then using FP32

- gradient clipping : clip the optimizer gradients : by clipping the gradients, we basically try to control the gradients from changing too much at once. it will control such that the jump is not higher than the threshold mentioned. clipnorm=1.0 is mostly used. monitoring the norm is a good practice(as the epochs increases, norm should be smoothening, no spikes, which indicates smooth training)

- optimizing learning rate(you can use learning rate schedulers like lr_scheduler, cosine_decay, etc), weight decay and beta values of the optmimzer.

- batch size is connected with all other hyperparameters

- use gradiesnt accumulation when you can't fit big batches into memory
for example, lets' say that we can only fit batch size of 32 into the memory but we need to perform training with batch size of 128 in that case we will train with batch size of 32 and only update the gradients after 4 mini batches which means the effective batch size will be 32*4=128

- varying batch_size : model is trained with smaller batch size and then gradually increased



resource : https://www.youtube.com/watch?v=l8pRSuU81PU

In [1]:
!pip install transformers
!pip install tiktoken

Collecting tiktoken
  Downloading tiktoken-0.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Downloading tiktoken-0.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m19.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tiktoken
Successfully installed tiktoken-0.9.0


In [2]:
import tensorflow as tf
import numpy as np
import transformers
import time

In [3]:
# model configuration

class Config:
    vocab_size: int = 50257   #final token(50256) is <|endoftext|> which is used for both BOS(beginning of sentence) and EOS(end of sentence)
    n_positions: int = 1024   #context length
    n_embed: int = 768  #embedding dimension
    n_layer: int = 12   #number of decoder blocks
    n_head: int = 12  #number of attention heads inside each block
    n_inner: int = None  #dimensionality of inner feed-forward layers. None will set it to 4*n_embed
    activation_function: str = 'gelu'  #activation to use inside the feed-forward layers
    resid_pdrop: int = 0.1  #dropout for all fully connected layer in decoder block
    embed_pdrop: int = 0.1  #dropout for embedding layers
    attn_pdrop: int = 0.1   #dropout for attention
    layer_norm_epsilon: int = 1e-5  #epsilon for layer normalization
    initializer_range: int = 0.02
    batch_size: int = 512

config=Config()

In [4]:
tf.keras.mixed_precision.set_global_policy('float32')

#gpt2  model

class FFN(tf.keras.layers.Layer):
    def __init__(self):
        super(FFN, self).__init__()

        self.hidden_units = 4*config.n_embed if config.n_inner is None else config.n_inner
        self.activation = tf.keras.activations.get(config.activation_function)

        self.hidden1 = tf.keras.layers.Dense(self.hidden_units)
        self.final_hidden = tf.keras.layers.Dense(config.n_embed)
        self.dropout = tf.keras.layers.Dropout(config.resid_pdrop)

    def call(self, input, training=False):
        x=self.hidden1(input)
        x=self.activation(x, approximate=True)
        x=self.final_hidden(x)
        x=self.dropout(x, training=training)
        return x

####################################################################
# NOTE : another optimization technique is to use flash attention. pytorch has the best implementation. you can also use tensorflow.keras.layers.MultiHeadAttention(flash_attention=True) and keras.config.enable_flash_attention() to enable flash_attention gloablly or keras.config.disable_flash_attention() to disable flash_attention globally
######################################################################
# the attention layer in gpt2 is slightly different then tf.keras.layers.MultiHeadAttention() because in gpt2's attention combines QKV projections and then split.
class GPT2Attention(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.c_attn = tf.keras.layers.Dense(3*config.n_embed)  # 3*768 for QKV
        self.c_proj = tf.keras.layers.Dense(config.n_embed)
        self.dropout = tf.keras.layers.Dropout(0.1)

    def split_heads(self, x):
        head_dim=config.n_embed//config.n_head
        return tf.transpose(tf.reshape(x, (tf.shape(x)[0], tf.shape(x)[1], config.n_head, head_dim)), (0, 2, 1, 3))

    def call(self, x):
        """use flash attention"""
        batch_size = tf.shape(x)[0]
        qkv = self.c_attn(x)
        q, k, v = tf.split(qkv, 3, axis=-1)
        q, k, v = self.split_heads(q), self.split_heads(k), self.split_heads(v)

        attn = tf.matmul(q, k, transpose_b=True)/tf.sqrt(64.0)
        attn = tf.nn.softmax(attn, axis=-1)
        attn = self.dropout(attn)

        x = tf.matmul(attn, v)
        x = tf.reshape(tf.transpose(x, (0, 2, 1, 3)), (batch_size, -1, config.n_embed))
        return self.c_proj(x)

class DecoderBlock(tf.keras.layers.Layer):
    def __init__(self):
        super(DecoderBlock, self).__init__()

        self.attention = GPT2Attention()
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon)

        self.ffn = FFN()

    def call(self, input, training=False):
        x=self.layernorm1(input)
        x=self.attention(x)
        # skip connection
        sk1=x+input
        x=self.layernorm2(sk1)
        x=self.ffn(x, training=training)
        # skip connection
        sk2=x+sk1

        return sk2

class Decoder(tf.keras.layers.Layer):
    def __init__(self):
        super(Decoder, self).__init__()

        self.embedding=tf.keras.layers.Embedding(
            input_dim=config.vocab_size,
            output_dim=config.n_embed,
            embeddings_initializer=tf.keras.initializers.GlorotUniform()
        )

        self.pos_embedding=tf.keras.layers.Embedding(
            input_dim=config.n_positions,
            output_dim=config.n_embed,
            embeddings_initializer=tf.keras.initializers.GlorotUniform()
        )

        self.decoder_blocks=[DecoderBlock() for _ in range(config.n_layer)]

    def call(self, input, training=False):

        # token embedding
        x1 = self.embedding(input)
        # position embedding
        seq_length=tf.shape(input)[1]
        batch_size=tf.shape(input)[0]
        pos_ids=tf.range(seq_length)
        pos_ids=tf.expand_dims(pos_ids, 0)
        pos_ids=tf.tile(pos_ids, [batch_size, 1])
        # pos_ids=np.resize(np.arange(input.shape[1]), input.shape)
        # pos_ids=np.resize(np.arange(input.shape[1]), config.batch_size*input.shape[1])
        # pos_ids=pos_ids.reshape((config.batch_size, config.n_positions))
        x2=self.pos_embedding(pos_ids)
        # merging the position embeddings
        x=x1+x2

        # decoder blocks
        for block in self.decoder_blocks:
            x=block(x, training=training)

        return x

# for weight tying
class OutputProjection(tf.keras.layers.Layer):
    def __init__(self, embedding_layer):
        super(OutputProjection, self).__init__()
        self.embedding_layer=embedding_layer
    def call(self, x):
        return tf.matmul(x, self.embedding_layer.embeddings, transpose_b=True)


class GPT(tf.keras.models.Model):
    def __init__(self):
        super(GPT, self).__init__()
        self.decoder=Decoder()
        self.final_layernorm=tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon)
        self.output_projection=OutputProjection(self.decoder.embedding)

    def build(self, input_shape):
        self.input_layer = tf.keras.layers.Input(shape=input_shape[1:])
        _ = self.call(self.input_layer)
        super(GPT, self).build(input_shape)

    def call(self, input, training=False):
        x=self.decoder(input, training=training)
        # final layer normalization
        x=self.final_layernorm(x)

        # GPT2 uses weight tying, where we use the same weights of embedding layer to the final output dense layer to reduce the number of parameters. to achieve this, we perform matrix multiplication on the decoder's final output to the transpose of embedding layer's weights to get the logits.
        logits=self.output_projection(x)

        return logits

    @tf.function   #boosts performance by executing in graph model rather than eager mode(default). also use in inference methods like test_step()
    def train_step(self, data):
        """custom training methos"""
        loss=0.0
        inputs, targets=data
        with tf.GradientTape() as tape:
            logits=self.call(inputs, training=True)
            loss=self.compute_loss(x=inputs, y=targets, y_pred=logits)  #we will only use last token to calcualte the loss because only the last token is predicted using all the input tokens or consider it as the prediction of next token given the input tokens. it is taken cared at custom loss function
        # compute gradients
        gradients=tape.gradient(loss, self.trainable_variables)
        # apply gradients
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        return {'loss': loss}

    @tf.function
    def on_validation(self, data):
        """custom validation"""
        val_loss=0.0
        inputs, targets=data
        logits=self.call(inputs, training=False)
        val_loss=self.compute_loss(x=inputs, y=targets, y_pred=logits)
        return {'val_loss': val_loss}


model=GPT()
model.build(input_shape=(config.batch_size, config.n_positions))
print(model.summary())


None


In [5]:
gpt=transformers.TFGPT2Model.from_pretrained('gpt2')
gpt.summary()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

All PyTorch model weights were used when initializing TFGPT2Model.

All the weights of TFGPT2Model were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFGPT2Model for predictions without further training.


Model: "tfgpt2_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 transformer (TFGPT2MainLay  multiple                  124439808 
 er)                                                             
                                                                 
Total params: 124439808 (474.70 MB)
Trainable params: 124439808 (474.70 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


number of paramters from our model and gpt model are exactly matching

In [6]:
def load_weights(model, gpt):

    # all layer names and weights of the respective layers from huggingface gpt2-small
    gpt_layer_name=[i.name for i in gpt.layers[0].submodules]
    gpt_weights=[i.weights for i in gpt.layers[0].submodules]

    # setting weights
    # embedding layer
    model.decoder.embedding.set_weights([gpt_weights[gpt_layer_name.index('wte')][0].numpy()])
    # position embedding layer
    model.decoder.pos_embedding.set_weights([gpt_weights[gpt_layer_name.index('wpe')][0].numpy()])

    # decoder blocks
    for i in range(config.n_layer):
        weights=gpt_weights[gpt_layer_name.index(f'h_._{i}')]
        # layernorm1 weights
        gamma=weights[0].numpy()
        beta=weights[1].numpy()
        model.decoder.decoder_blocks[i].layernorm1.set_weights([gamma, beta])
        # layernrom2 weights
        gamma=weights[6].numpy()
        beta=weights[7].numpy()
        model.decoder.decoder_blocks[i].layernorm2.set_weights([gamma, beta])
        # attention weights
        c_attn_weights=weights[2].numpy()
        c_attn_bias=weights[3].numpy()
        c_proj_weights=weights[4].numpy()
        c_proj_bias=weights[5].numpy()
        model.decoder.decoder_blocks[i].attention.c_attn.set_weights(
            [c_attn_weights, c_attn_bias.reshape((c_attn_bias.shape[1]))]
        )
        model.decoder.decoder_blocks[i].attention.c_proj.set_weights(
            [c_proj_weights, c_proj_bias.reshape((c_proj_bias.shape[1]))]
        )
        # feed forward layer weights
        c_fc_weights=weights[8].numpy()
        c_fc_bias=weights[9].numpy()
        c_proj_weights=weights[10].numpy()
        c_proj_bias=weights[11].numpy()
        model.decoder.decoder_blocks[i].ffn.hidden1.set_weights(
            [c_fc_weights, c_fc_bias.reshape((c_fc_bias.shape[1]))]
        )
        model.decoder.decoder_blocks[i].ffn.final_hidden.set_weights(
            [c_proj_weights, c_proj_bias.reshape((c_proj_bias.shape[1]))]
        )

    # final layer normalization weights
    gamma=gpt_weights[gpt_layer_name.index('ln_f')][0].numpy()
    beta=gpt_weights[gpt_layer_name.index('ln_f')][1].numpy()
    model.final_layernorm.set_weights([gamma, beta])

    return model

model=load_weights(model, gpt)

In [7]:
# gpt2 tokenizer - you either load the tokenizer from huggingface or use tiktoken(OpenAI's official library)
# we are going to use from huggingface
tokenizer=transformers.AutoTokenizer.from_pretrained('gpt2')

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [8]:
text='what is computer?'
seq=tokenizer(text)['input_ids']
seq=np.array(seq).reshape((1, len(seq)))
# logits prediction
logits=model.predict(seq, verbose=0)
# softmax
ypred=tf.nn.softmax(logits)
ypred=tf.argmax(ypred, axis=-1)
ypred_decoded = tokenizer.decode(ypred[0])

print(f'original text : {text}')
print(f'encoded sequence : {seq}')
print(f'input shape : {seq.shape}')
print(f'predicted tokens : {ypred}')
print(f'predicted sequence : {ypred_decoded}')
print(f'output logits shape : {logits.shape}\t output shape : {ypred.shape}')

original text : what is computer?
encoded sequence : [[10919   318  3644    30]]
input shape : (1, 4)
predicted tokens : [[  11  340   30 3644]]
predicted sequence : , it? computer
output logits shape : (1, 4, 50257)	 output shape : (1, 4)


in the output ',' is predicted using 'what' from input, 'it' is predicted using 'what' and 'is' from input and so on till last output token is predicted using all of the input tokens

In [9]:
# sample text generation
sample="Hello, I'm a super star who like to"
max_length=30
temperature=1.5  # to control the randomness of the output. balance between coherence and creativity.
# low temperature(<1.0 - like 0.1, 0.2, etc) - picks the most probable output(the best but the options/variety will be low) - best for most accurate answers
# high temperature(>1.0 - like 1.5, 1.6, etc) - picks less probable but adds more variety to the output(not the best but creative) - most suited for creative writing tasks, etc
# default temperature - 1.0

top_k=5  # select top-5 outputs

seq = tokenizer(sample)['input_ids']
topk_outputs=[]

for i in range(max_length):
    input=seq[-config.n_positions:]
    input=np.array(input).reshape((1, len(input)))

    logits=model(input)
    scaled_logits=logits/temperature
    scaled_logits=scaled_logits[0, -1, :] #selecting the last token
    # top_k predictions with their indices
    top_k_logits, top_k_indices = tf.math.top_k(scaled_logits, k=top_k)
    topk_outputs.append(top_k_indices.numpy())

    # probability of the top_k selected predictions
    top_k_proba = tf.nn.softmax(top_k_logits).numpy()
    # next token selection
    next_token = np.random.choice(top_k_indices.numpy(), p=top_k_proba)
    seq.append(next_token)


print(f'sample text : {sample}')
print(f'top {top_k} samples : ')
# decoding the sequences
for i in range(top_k):
    seq = tokenizer.decode(np.array(topk_outputs)[:, i])
    print(f'sequence {i+1} - \r{seq}')

sample text : Hello, I'm a super star who like to
top 5 samples : 
sequence 1 -  star who who star, star, star, star, star I, star Star,, a Star a Star, Star Star a Star a Star Star
sequence 2 -  be, I'm a, who super who super star I who star Star, I Star, star, star a super star, star, superstar
sequence 3 -  like like star, who'm a I to I I super, Star super I a star who, who super to Super, to super to Super a
sequence 4 -  love a super like to am I who a stars to Star to super Super star who super I super I, super star to who, super star star
sequence 5 -  have I to to I to to and I Star who stars a. I to to to to Super to Super who a a I Super who a,


the model is not yet fine-tuned

## Micro GPT

we will build a smaller version of GPT2 to pre-train and fine-tune it with some data

In [10]:
config=Config()
# new model parameters
config.vocab_size=50307
config.n_positions=128  #context length
config.n_embed=64   #embedding dimension
config.n_layer=4  #decoder blocks
config.n_head=4   #attention heads in each block
config.n_inner=None
config.activation_function='gelu'
config.resid_pdrop=0.15   #dropout
config.embed_pdrop=0.15  #dropout for embedding layers
config.attn_pdrop=0.15   #dropout for attention
config.layer_norm_epsilon=1e-5
config.initializer_range=0.02
config.batch_size=64

In [25]:
tf.keras.mixed_precision.set_global_policy('float32')

model=GPT()
model.build(input_shape=(config.batch_size, config.n_positions))
print(model.summary())

None


micro GPT has ~3.5 million parameters only

In [12]:
# downloading tiny shakesphere dataset
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-02-16 04:15:09--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-02-16 04:15:10 (25.4 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [13]:
with open('./input.txt', 'r', encoding='utf-8') as f:
    data=f.read()
data=tokenizer(data)['input_ids']

x=[]
y=[]

for i in range(1, len(data)):
    seq=data[:i][-config.n_positions:]
    x.append(seq)
    y.append(data[i])

x=tf.keras.preprocessing.sequence.pad_sequences(x, padding='post', maxlen=config.n_positions)
y=np.array(y).reshape((-1, 1))

print(x.shape, y.shape)

Token indices sequence length is longer than the specified maximum sequence length for this model (338025 > 1024). Running this sequence through the model will result in indexing errors


(338024, 128) (338024, 1)


In [27]:
def custom_loss(ytrue, ypred):
    """custom loss function to consider only the final token from logits"""
    return tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(ytrue, ypred[:, -1, :])

model.compile(
    optimizer=tf.keras.optimizers.AdamW(learning_rate=3e-4, beta_1=0.90, beta_2=0.95, epsilon=1e-8, clipnorm=1.0, weight_decay=0.1),    #GPT2 used beta1=0.9, beta2=0.95, eposilon=1e-8, clipnorm=1.0, learning_rate=3e-4
    loss=custom_loss
)

In [28]:
# enable TF32 instead of using FP32 for floating point calculation. using TF32 boosts the speed ~8x then FP32. TF32 is lower precision than FP32 but its negligible so performance is not affected much. use tensorflow math functions instead of numpy as much as possible to take full advantage of it.
tf.config.experimental.enable_tensor_float_32_execution(False)
# to make the computation more faster, we will switch to BF16. here also the signed bits/exponent bits is unchanged(same as FP32) so even though the precision is lower then FP32, its very negligible so its 16x faster then FP32 performance drop is also negligible.
# to use Bfloat16 :
#  set mixed precision policy which will combine FP32 and BF16 so some operations like matrix multiplication and convolutions will use BF16 and other operations will use FP32
tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
# verification
print(tf.keras.mixed_precision.global_policy())

<DTypePolicy "mixed_bfloat16">


![text](/content/img.png)

In [32]:
# fitting the model on small subset-bf16
model.fit(x[:2000], y[:2000], epochs=50, batch_size=config.batch_size)

Epoch 1/50
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m45s[0m 473ms/step - loss: 6.1857
Epoch 2/50
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 83ms/step - loss: 5.6876
Epoch 3/50
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 82ms/step - loss: 5.4282
Epoch 4/50
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 82ms/step - loss: 5.2854
Epoch 5/50
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 83ms/step - loss: 5.2381
Epoch 6/50
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 86ms/step - loss: 5.1830
Epoch 7/50
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 85ms/step - loss: 5.1568
Epoch 8/50
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 88ms/step - loss: 5.1317
Epoch 9/50
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 86ms/step - loss: 5.0603
Epoch 10/50
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 86ms/step - loss: 5.02

<keras.src.callbacks.history.History at 0x78123a189d10>

In [49]:
# custom training loop with gradient accumulation - instead of updating the gradients at every batch, we load smaller batches then calculate the gradients, add them and update the gradients only after accumulating them for total batch size. for example, let total_batch_size=512 but we cant load 512 batch into the memory at the same time so we load batch of 32, calculate the gradients on the mini_batches and add them till we reach batch size of 512 (i.e., we keep adding the gradients 16 times (32*16=512)) and update the optimizer using the accumulated gradients. in this way we can take advantage of using large batch size eventhough we dont have the compute power to handle large batches at the same time. this gradient accumulation can also be used when we want to train the model on increasing batch size like in gpt2/3 where batch size was increasing as the epoch increases.

epochs=50
total_batch_size=512  #we cant fit this many batches into the memory at once
mini_batch_size=64   #larger the total_batch_size and mini_batch_size, more efficient the training is.
assert total_batch_size%mini_batch_size==0  #confirm that total batch size is divisible by batch size
grad_acum_steps=total_batch_size//mini_batch_size
assert len(x[:2048])%total_batch_size==0  #making sure that total batch size is a multiple of total data length
steps=len(x[:2048])//total_batch_size  #number of steps to cover entire data for 1 epoch

trainset=tf.data.Dataset.from_tensor_slices((x[:2048], y[:2048])).shuffle(buffer_size=16).batch(batch_size=mini_batch_size)

tf.keras.mixed_precision.set_global_policy('float32')
model=GPT()
model.build(input_shape=(config.batch_size, config.n_positions))
tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')


optimizer=tf.keras.optimizers.AdamW(learning_rate=3e-4, beta_1=0.90, beta_2=0.95, epsilon=1e-8, clipnorm=1.0, weight_decay=0.1)
loss_fn=custom_loss

@tf.function
def train_step(inputs, targets, model, loss_fn, accumulated_gradients, apply_optimizer=False, optimizer=None):
    with tf.GradientTape() as tape:
        logits=model(inputs, training=True)
        loss=loss_fn(targets, logits)
    gradients=tape.gradient(loss, model.trainable_variables)
    accumulated_gradients=[(accum_grad+grad) for accum_grad, grad in zip(accumulated_gradients,gradients)]

    if apply_optimizer:
        optimizer.apply_gradients(zip(accumulated_gradients, model.trainable_variables))
        return loss


    return loss, accumulated_gradients

accumulated_gradients=[tf.zeros_like(var) for var in model.trainable_variables]
for epoch in range(epochs):
    start=time.time()
    epoch_loss=0.0

    batch_loss=0.0
    for step, data in enumerate(trainset):
          #to store gradients

        inputs, targets=data
        if step%grad_acum_steps==0:
            loss = train_step(inputs, targets, model, loss_fn, accumulated_gradients, apply_optimizer=True, optimizer=optimizer)
            batch_loss+=loss
            batch_loss/=grad_acum_steps
            epoch_loss+=batch_loss
            batch_loss=0.0
            accumulated_gradients=[tf.zeros_like(var) for var in model.trainable_variables]
            # print(f'updated at {step}')

        else:
            loss, accumulated_gradients = train_step(inputs, targets, model, loss_fn, accumulated_gradients)
            batch_loss+=loss

    epoch_loss/=steps
    end=time.time()
    print(f'epoch:{epoch+1}/{epochs} | avg.loss/epoch:{epoch_loss:.7f} | time:{(end-start):.3f} | tokens/sec:{2048/(end-start):.3f}')


epoch:1/50 | avg.loss/epoch:8.4196215 | time:14.065 | tokens/sec:145.607
epoch:2/50 | avg.loss/epoch:8.3361015 | time:2.906 | tokens/sec:704.786
epoch:3/50 | avg.loss/epoch:8.2450428 | time:2.873 | tokens/sec:712.833
epoch:4/50 | avg.loss/epoch:8.1490850 | time:2.885 | tokens/sec:709.957
epoch:5/50 | avg.loss/epoch:8.0501728 | time:2.900 | tokens/sec:706.280
epoch:6/50 | avg.loss/epoch:7.9489222 | time:2.950 | tokens/sec:694.230
epoch:7/50 | avg.loss/epoch:7.8459454 | time:2.932 | tokens/sec:698.601
epoch:8/50 | avg.loss/epoch:7.7420397 | time:2.971 | tokens/sec:689.387
epoch:9/50 | avg.loss/epoch:7.6361456 | time:2.995 | tokens/sec:683.834
epoch:10/50 | avg.loss/epoch:7.5302601 | time:3.053 | tokens/sec:670.750
epoch:11/50 | avg.loss/epoch:7.4247112 | time:3.064 | tokens/sec:668.474
epoch:12/50 | avg.loss/epoch:7.3181930 | time:3.147 | tokens/sec:650.741
epoch:13/50 | avg.loss/epoch:7.2114353 | time:3.150 | tokens/sec:650.098
epoch:14/50 | avg.loss/epoch:7.1056900 | time:5.104 | token

this is the pretrained stage of the model. we will then fine-tune the model based on our needs. after then we can check the performance of the model using datasets like helloswag, etc

for fine-tuning see ./finetuning.ipynb