In [1]:

import torch.nn as nn
import torch 
from dataclasses import dataclass
from torch.nn import functional as F
from transformers import AutoTokenizer
import time
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
import glob 
import random
import datetime

# Pick the best available accelerator instead of assuming Apple MPS exists:
# MPS on Apple silicon, then CUDA, and finally plain CPU as a fallback.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# import gc
# print(gc.collect())
# print(torch.mps.empty_cache())

In [6]:

class MyGPTDataLoader:
    """Sequential (B, T) batch provider over a single tokenised text corpus.

    The corpus is read from ``input_file_path`` and tokenised once. Batches are
    consecutive, non-overlapping windows of ``B*T`` tokens; the targets are the
    inputs shifted by one token. When the remaining tokens cannot fill another
    batch, ``input_file_path`` is rewritten as the concatenation of
    ``input_files_list`` in a freshly shuffled order, re-tokenised, and the
    read position is rewound to 0.

    Args:
        B: number of sequences per batch.
        T: number of tokens per sequence.
        input_file_path: path of the merged corpus file (read at construction,
            overwritten at every reshuffle).
        input_files_list: list of source file paths merged at every reshuffle.
            The list is shuffled in place.
    """

    def __init__(self, B, T, input_file_path, input_files_list):
        self.input_file_path = input_file_path
        self.input_files_list = input_files_list
        self.B = B
        self.T = T

        # Read the corpus before loading the tokenizer so a missing file fails fast.
        data = self._read_corpus()

        self.enc = AutoTokenizer.from_pretrained("sarvamai/sarvam-1")
        self.tokens = self._encode(data)
        self.current_position = 0

        print(f"total tokens {len(self.tokens)}")
        print(f"1 epoch = {len(self.tokens) // (B*T)} batches")

    def _read_corpus(self):
        """Return the full text of the merged corpus file."""
        # Explicit UTF-8: the platform default encoding is not guaranteed to
        # decode non-ASCII text correctly.
        with open(self.input_file_path, "r", encoding="utf-8") as file:
            return file.read()

    def _encode(self, text):
        """Tokenise ``text`` and return the ids as a 1-D tensor."""
        return torch.tensor(self.enc.encode(text))

    def _reshuffle_corpus(self):
        """Rebuild the merged corpus in a new random file order and rewind."""
        # With no source files there is nothing to merge; rewriting would only
        # wipe the existing corpus, so just rewind over the current tokens.
        if self.input_files_list:
            random.shuffle(self.input_files_list)

            # Mode "w" already truncates the file, no explicit truncate needed.
            with open(self.input_file_path, "w", encoding="utf-8") as output_file:
                # TODO: a separator token could be written after every document
                # so the model can tell documents apart.
                for input_files_path in self.input_files_list:
                    with open(input_files_path, "r", encoding="utf-8") as input_file:
                        output_file.write(input_file.read())

            self.tokens = self._encode(self._read_corpus())

        self.current_position = 0

    def next_batch(self):
        """Return the next ``(x, y)`` pair, each of shape ``(B, T)``.

        ``y`` is ``x`` shifted left by one token. After slicing, the read
        position advances by ``B*T``; if another full batch no longer fits,
        the corpus is reshuffled and the position reset.

        Raises:
            ValueError: if fewer than ``B*T + 1`` tokens are available at the
                current position (corpus too small for one batch).
        """
        B, T = self.B, self.T
        span = B * T

        # One extra token is needed so the last input position has a target.
        buff = self.tokens[self.current_position : self.current_position + span + 1]
        if buff.numel() != span + 1:
            raise ValueError(
                f"need {span + 1} tokens for a ({B}, {T}) batch but only "
                f"{buff.numel()} are available"
            )
        x = buff[:-1].view(B, T)
        y = buff[1:].view(B, T)

        self.current_position += span

        if self.current_position + span + 1 > len(self.tokens):
            self._reshuffle_corpus()

        return x, y


@dataclass
class MyGPTConfig:
    """Model hyper-parameters consumed by the GPT modules in this notebook."""

    # Maximum sequence length; also the size of the position-embedding table.
    n_ctx: int = 1024
    # Number of rows in the token-embedding table / width of the output logits.
    vocab_size: int = 50257
    # Width of the residual stream.
    n_embed: int = 768
    # Attention heads per block; must divide n_embed.
    n_head: int = 12
    # Number of stacked transformer blocks.
    n_layer: int = 12


class SelfAttention(nn.Module):
    """Multi-head causal self-attention with a fused q/k/v projection.

    ``config`` must expose ``n_embed``, ``n_head`` and ``n_ctx``; ``n_embed``
    has to be divisible by ``n_head`` because the embedding is split evenly
    across the heads.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        embed_dim = config.n_embed
        context_len = config.n_ctx
        assert embed_dim % config.n_head == 0

        # A single linear layer produces queries, keys and values at once.
        self.c_attn = nn.Linear(embed_dim, 3 * embed_dim)
        self.c_proj = nn.Linear(embed_dim, embed_dim)
        # Marker read by the model's weight initialiser to scale this layer's std.
        self.c_proj.MYGPT_SCALE_INIT = 1

        # Lower-triangular mask buffer. forward() does not read it (masking is
        # done via is_causal=True), but it remains part of the state dict.
        lower_triangle = torch.tril(torch.ones(context_len, context_len))
        self.register_buffer("bias", lower_triangle.view(1, 1, context_len, context_len))

    def forward(self, x):
        """Map ``x`` of shape (B, T, C) to an attended tensor of the same shape."""
        batch, seq_len, width = x.size()
        n_head = self.config.n_head
        head_dim = width // n_head

        def to_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(batch, seq_len, n_head, head_dim).transpose(1, 2)

        fused = self.c_attn(x)
        q, k, v = (to_heads(part) for part in fused.split(self.config.n_embed, dim=2))

        # is_causal=True applies the lower-triangular mask internally.
        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        # Put the heads back side by side, then apply the output projection.
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.c_proj(merged)



class MLP(nn.Module):
    """Position-wise feed-forward network: expand 4x, tanh-GELU, project back."""

    def __init__(self, config):
        super().__init__()
        hidden_dim = 4 * config.n_embed

        self.c_fc = nn.Linear(config.n_embed, hidden_dim)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(hidden_dim, config.n_embed)
        # Marker read by the model's weight initialiser to scale this layer's std.
        self.c_proj.MYGPT_SCALE_INIT = 1

    def forward(self, x):
        """Apply the feed-forward stack; the last dimension of ``x`` is preserved."""
        return self.c_proj(self.gelu(self.c_fc(x)))


class GPTOneBlock(nn.Module):
    """One pre-norm transformer block: attention then MLP, each on a residual path.

    Attention is the stage where tokens exchange information with each other;
    the MLP then processes every position independently. Both sub-layers are
    added onto an untouched copy of ``x`` so gradients have a direct route from
    the output back to the input, which helps optimisation in deep stacks.
    """

    def __init__(self, config):
        super().__init__()

        self.ln_1 = nn.LayerNorm(config.n_embed)
        self.attn = SelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embed)
        self.mlp = MLP(config)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        # Residual 1: normalise, attend, add back onto the stream.
        attn_out = self.attn(self.ln_1(x))
        x = x + attn_out
        # Residual 2: normalise, feed-forward, add back onto the stream.
        mlp_out = self.mlp(self.ln_2(x))
        return x + mlp_out



class MyGPT2(nn.Module):
    """GPT-2 style decoder-only language model.

    ``config`` must expose ``vocab_size``, ``n_ctx``, ``n_embed``, ``n_head``
    and ``n_layer``. The token-embedding matrix and the output projection
    share one weight tensor.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Sub-modules are created in the order wte, wpe, h, ln_f.
        token_table = nn.Embedding(config.vocab_size, config.n_embed)   # token embeddings
        position_table = nn.Embedding(config.n_ctx, config.n_embed)     # position embeddings
        blocks = nn.ModuleList(GPTOneBlock(config) for _ in range(config.n_layer))
        final_norm = nn.LayerNorm(config.n_embed)                       # norm after the last block
        self.transformer = nn.ModuleDict(
            {"wte": token_table, "wpe": position_table, "h": blocks, "ln_f": final_norm}
        )

        # Projects the residual stream back onto the vocabulary.
        self.lm_head = nn.Linear(config.n_embed, config.vocab_size, bias=False)

        # Weight tying: the embedding lookup reuses the lm_head matrix.
        self.transformer.wte.weight = self.lm_head.weight

        self.apply(self.__init_weights)

    def __init_weights(self, module):
        """Initialise one sub-module; invoked for every module via ``self.apply``."""
        # FIXME: wte and lm_head share a tensor, so it is initialised twice.
        if isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.01)
            return
        if not isinstance(module, nn.Linear):
            return

        std = 0.02
        if hasattr(module, "MYGPT_SCALE_INIT"):
            # Each block adds two residual contributions (attention and MLP),
            # so flagged layers are scaled by 1/sqrt(2 * n_layer).
            std *= (2 * self.config.n_layer) ** -0.5
        torch.nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    def forward(self, idx, targets=None):
        """Run the model on token ids ``idx`` of shape (B, T).

        Returns ``(logits, loss)``. ``logits`` has shape (B, T, vocab_size);
        ``loss`` is the cross-entropy against ``targets`` when they are given,
        otherwise ``None``.
        """
        _, seq_len = idx.size()
        assert seq_len <= self.config.n_ctx

        positions = torch.arange(0, seq_len, dtype=torch.long, device=idx.device)
        hidden = self.transformer.wpe(positions) + self.transformer.wte(idx)

        for block in self.transformer.h:
            hidden = block(hidden)

        logits = self.lm_head(self.transformer.ln_f(hidden))

        if targets is None:
            return logits, None

        # Flatten batch and time so every position is one classification example.
        flat_logits = logits.view(-1, logits.size(-1))
        loss = F.cross_entropy(flat_logits, targets.view(-1))
        return logits, loss

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        """Build an AdamW optimizer with weight decay applied only to >=2-D tensors.

        Trainable parameters with two or more dimensions receive
        ``weight_decay``; those with fewer than two dimensions get none.
        ``device_type`` is accepted but not used.
        """
        trainable = [p for _, p in self.named_parameters() if p.requires_grad]
        decay_params = [p for p in trainable if p.dim() >= 2]
        nodecay_params = [p for p in trainable if p.dim() < 2]

        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")

        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0},
        ]
        return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)


In [7]:
# Sampling and batching constants for this run.
num_sentences_to_generate, max_seq_length = 5, 300
B, T = 4, 1024

# Training statistics go to a per-day CSV file; start it fresh with only the header row.
logs_path = f"logs_{datetime.datetime.today().date()}.csv"
with open(logs_path, "w+") as logs_file:
    logs_file.write("step,loss,norm,lr,token_per_sec,time\n")

In [None]:
# Build the data loader first: its tokenizer decides how large the model's
# vocabulary has to be.
input_data_files = glob.glob("dataset/*/hi/*")
train_dataloader = MyGPTDataLoader(B, T, "shayar.txt", input_data_files)
print("DataLoader Ready")

# The embedding table must cover every id the tokenizer can emit, otherwise
# ids past the end of the table index out of bounds. Use the tokenizer's full
# vocabulary rounded up to a multiple of 128, and never go below the
# previously hard-coded 50304.
tokenizer_vocab_size = len(train_dataloader.enc)
padded_vocab_size = -(-tokenizer_vocab_size // 128) * 128
model_vocab_size = max(50304, padded_vocab_size)

model = MyGPT2(MyGPTConfig(vocab_size=model_vocab_size))
# The model is about to be trained, so put it in training mode (not eval).
model.train()
model.to(device)
print("Model Loaded")


# Gradient accumulation: several (B, T) micro-batches are summed to emulate one
# optimizer step over `total_req_batch_size` tokens.
total_req_batch_size = 32768
assert total_req_batch_size % (B * T) == 0, "we should fit"
grad_accum_steps = total_req_batch_size // (B * T)
print(f"TOtal batch simulation : {total_req_batch_size}, and it will be reached per {grad_accum_steps} steps")

In [None]:
# Custom wrapper for the constant phase
class ConstantLRScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Scheduler that pins every parameter group to one fixed learning rate."""

    def __init__(self, optimizer, constant_lr, last_epoch=-1):
        # Stored before the base initialiser runs, because that initialiser
        # already performs a first step that calls get_lr().
        self.constant_lr = constant_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return ``constant_lr`` once per parameter group."""
        return [self.constant_lr] * len(self.base_lrs)
    


# Learning-rate plan: linear warm-up to `max_lr`, cosine decay down to
# `constant_lr`, then hold `constant_lr` for the rest of training.
max_lr = 6e-4
warmup_steps = 100
total_var_lr_steps = 900
constant_lr = 0.1 * max_lr

optimizer = model.configure_optimizers(
    weight_decay=0.1,
    learning_rate=max_lr,
    betas=(0.9, 0.95),
    device_type=device,
)

# The three phases, built in the order they are used.
warmup_phase = LinearLR(optimizer, start_factor=0.001, end_factor=1.0, total_iters=warmup_steps)
decay_phase = CosineAnnealingLR(optimizer, T_max=total_var_lr_steps - warmup_steps, eta_min=constant_lr)
hold_phase = ConstantLRScheduler(optimizer, constant_lr=constant_lr)

# Chain the phases; the milestones are the steps at which the next phase takes over.
lr_scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup_phase, decay_phase, hold_phase],
    milestones=[warmup_steps, total_var_lr_steps],
)

# Per-step history filled in by the training loop.
steps, lr_list = [], []

In [None]:
# Number of optimizer steps: the variable-LR phase (warm-up + cosine decay)
# followed by 1000 further steps at the constant learning rate.
# NOTE(review): the loop bound previously used `total_train_steps`, which is
# not defined anywhere in this notebook and raised NameError on a fresh
# kernel; `total_var_lr_steps` is assumed to be the intended base — confirm.
num_train_steps = total_var_lr_steps + 1000

# Tokens consumed per optimizer step (all micro-batches together).
tokens_per_step = train_dataloader.B * train_dataloader.T * grad_accum_steps

for step in range(num_train_steps):
    optimizer.zero_grad()
    t0 = time.time()
    loss_accum = 0
    for grad_accum_step in range(grad_accum_steps):

        x, y = train_dataloader.next_batch()
        x, y = x.to(device), y.to(device)
        logits, loss = model(x, y)
        # Scale so the accumulated gradient equals the mean over the micro-batches.
        loss /= grad_accum_steps
        loss.backward()

        loss_accum += loss.detach()
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    optimizer.step()
    lr_scheduler.step()

    # Wait for queued MPS work so the timing below covers the whole step.
    if str(device).startswith("mps"):
        torch.mps.synchronize()
    t1 = time.time()

    steps.append(step)
    lr_list.append(lr_scheduler.get_last_lr())

    step_time = t1 - t0
    tokens_per_sec = tokens_per_step / step_time
    current_lr = lr_scheduler.get_last_lr()[0]

    #NOTE : Maybe per 100 more suitable
    if step % 200 == 0:
        checkpoint = {
            "model" : model.state_dict()}

        torch.save(checkpoint, f"model_{step}.ckpt")
        print(f"step: {step} | loss:{loss_accum.item():.4f} | norm : {norm:.4f} | lr : {current_lr} | token_per_sec = {tokens_per_sec} | time : {step_time}")


    # Append this step's statistics to the CSV log.
    with open(logs_path, "a") as logs_file:
        logs_file.write(f"{step},{loss_accum.item():.4f},{norm:.4f},{current_lr},{tokens_per_sec},{step_time}\n")


In [9]:
# #TODO : 
#     # Use torch.autocast() to run the forward pass / loss in fp16 --> not supported on Mac yet (support was only pushed recently).
#     # Enable torch.compile for the model; it is currently not supported for the MPS backend.


# import torch._dynamo
# torch._dynamo.config.suppress_errors = True

In [None]:
# checkpoint = {"model" : model.state_dict()}
# torch.save(checkpoint, f"model_{step}.ckpt")
# print(f"step: {step} | loss:{loss_accum.item():.4f} | norm : {norm:.4f} | lr : {lr_scheduler.get_last_lr()[0]} | token_per_sec = {(train_dataloader.B * train_dataloader.T * grad_accum_steps) / (t1-t0)}) | time : {t1-t0}")