<a href="https://colab.research.google.com/github/Mithun-033/Text-To-SQL-GPT/blob/main/GPT_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [49]:
!pip install galore_torch -q
!pip install torchinfo -q

In [50]:
import torch
import torch.nn as nn
import numpy as np
from dataclasses import dataclass
import time
from galore_torch import GaLoreAdamW
from torchinfo import summary

In [30]:
# Pick the compute device once: CUDA when PyTorch can see a GPU, CPU otherwise.
if torch.cuda.is_available():
    dev = "cuda"
else:
    dev = "cpu"

In [67]:
@dataclass
class Config:
    """Hyper-parameters shared by the model definition and the training loop.

    Several defaults are also read straight off the class elsewhere in the
    notebook (``Config.b_size``, ``Config.cwl``, ``Config.vocab_size``).
    """
    n_embed: int = 768       # width of the token / position embeddings
    cwl: int = 1024          # context window length (rows of the position table)
    b_size: int = 32         # sequences per batch
    n_head: int = 16         # attention heads per block
    # Evaluated once, while the class body runs, from the two defaults above
    # (768 // 16). Overriding n_embed or n_head on an instance does NOT
    # recompute it; pass head_size explicitly in that case.
    head_size: int = n_embed // n_head
    vocab_size: int = 50304  # rows of the token embedding / logits per position
    n_layer: int = 16        # number of transformer blocks


In [32]:
class AttentionHead(nn.Module):
    """A single causal self-attention head.

    Projects the input to query/key/value vectors of width ``config.head_size``
    and attends with a causal mask, so position ``t`` only sees positions ``<= t``.

    Args:
        config: object exposing ``n_embed`` (input width) and ``head_size``
            (width of this head's q/k/v projections).
    """
    def __init__(self,config):
        super().__init__()
        self.config=config

        self.q=nn.Linear(config.n_embed,config.head_size,bias=False)
        self.k=nn.Linear(config.n_embed,config.head_size,bias=False)
        self.v=nn.Linear(config.n_embed,config.head_size,bias=False)

        self.dropout=nn.Dropout(p=0.15)

    def forward(self,x):
        """Attend over ``x`` of shape (B,T,C) and return a (B,T,H) tensor."""
        key=self.k(x)  # (B,T,H)
        query=self.q(x)  # (B,T,H)
        value=self.v(x)  # (B,T,H)

        # Fused equivalent of softmax(causal_mask(Q @ K^T / sqrt(H))) @ V.
        # The result is already the attended values, not the attention weights.
        attended=nn.functional.scaled_dot_product_attention(query,key,value,is_causal=True)  # (B,T,H)

        return self.dropout(attended)



In [44]:
class MultiHeadAttention(nn.Module):
    """Runs ``config.n_head`` attention heads side by side and merges them.

    Every head receives the same input; the per-head outputs are concatenated
    along the last dimension, projected back to ``config.n_embed`` and passed
    through dropout (p=0.2).
    """
    def __init__(self, config):
        super().__init__()
        self.config = config

        heads = []
        for _ in range(config.n_head):
            heads.append(AttentionHead(config))
        self.Multi = nn.ModuleList(heads)

        concat_width = config.n_head * config.head_size
        self.project = nn.Linear(concat_width, config.n_embed)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        """Map (B,T,C) input to (B,T,C) output."""
        per_head = [head(x) for head in self.Multi]
        merged = torch.cat(per_head, dim=-1)  # (B,T,H*N)
        projected = self.project(merged)      # (B,T,C)
        return self.dropout(projected)


In [34]:
class MLP(nn.Module):
    """Position-wise feed-forward network.

    Expands ``config.n_embed`` to five times that width, applies GELU and
    dropout (p=0.1), then projects back to ``config.n_embed``.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config

        width = config.n_embed
        hidden = 5 * width
        expand = nn.Linear(width, hidden)
        contract = nn.Linear(hidden, width)
        self.layer = nn.Sequential(expand, nn.GELU(), nn.Dropout(0.1), contract)

    def forward(self, x):
        """Apply the feed-forward stack to the last dimension of ``x``."""
        out = self.layer(x)
        return out

In [46]:
class Block(nn.Module):
    """One pre-norm transformer block.

    Applies LayerNorm -> multi-head attention and LayerNorm -> MLP, each added
    back onto its own input through a residual connection.

    Args:
        config: configuration object forwarded to ``MultiHeadAttention`` and
            ``MLP``; ``config.n_embed`` sets the LayerNorm width.
    """
    def __init__(self,config):
        super().__init__()
        self.config=config

        self.PreNorm1=nn.LayerNorm(config.n_embed)
        self.attention=MultiHeadAttention(config)
        self.PreNorm2=nn.LayerNorm(config.n_embed)
        self.FeedForwardLayer=MLP(config)

    def forward(self,x):
        """Return the block output; same shape as ``x``."""
        x=x+self.attention(self.PreNorm1(x))  # residual around attention
        # The MLP must receive the normalised activations, i.e. PreNorm2 applied
        # to x, not the LayerNorm module object itself.
        x=x+self.FeedForwardLayer(self.PreNorm2(x))  # residual around MLP

        return x

In [64]:
class GPT(nn.Module):
    """Decoder-only language model.

    Sums token and learned position embeddings, runs ``config.n_layer``
    ``Block`` modules, applies a final LayerNorm and maps to vocabulary logits
    with a bias-free linear head whose weight is shared with the token
    embedding table.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config

        width = config.n_embed
        self.embed = nn.Embedding(config.vocab_size, width)
        self.pos_embed = nn.Embedding(config.cwl, width)

        layers = [Block(config) for _ in range(config.n_layer)]
        self.blocks = nn.Sequential(*layers)

        self.final_norm = nn.LayerNorm(width)
        self.Dense = nn.Linear(width, config.vocab_size, bias=False)

        # Weight tying: the output head reuses the token embedding matrix.
        self.Dense.weight = self.embed.weight

    def forward(self, x):
        """Map token ids of shape (B,T) to logits of shape (B,T,vocab_size)."""
        _, seq_len = x.shape

        positions = torch.arange(seq_len, device=x.device)
        token_vecs = self.embed(x)                 # (B,T,C)
        position_vecs = self.pos_embed(positions)  # (T,C), broadcast over B

        hidden = self.blocks(token_vecs + position_vecs)
        hidden = self.final_norm(hidden)
        return self.Dense(hidden)


In [65]:
def generator(ids,b_size,cwl):
    """Yield consecutive (input, target) batches for next-token prediction.

    Walks ``ids`` in non-overlapping windows of ``b_size*cwl`` tokens. Each
    window also reads one extra token so the targets can be the inputs shifted
    left by one position.

    Args:
        ids: 1-D NumPy array (a read-only memmap is fine) of integer token ids.
        b_size: number of sequences per batch.
        cwl: number of tokens per sequence.

    Yields:
        ``(X, y)``: two ``(b_size, cwl)`` int64 CPU tensors, where ``y`` holds
        the token that follows each position of ``X``.
    """
    step=b_size*cwl
    # A window starting at i needs tokens i .. i+step inclusive, so the last
    # valid start is len(ids)-step-1; range() excludes its stop value.
    for i in range(0,len(ids)-step,step):
        # Copy into a writable int64 buffer: from_numpy warns on read-only
        # memmap slices, and embedding lookups / cross-entropy targets need
        # integer (long) indices whatever dtype the tokens were stored in.
        window=np.array(ids[i:i+step+1],dtype=np.int64)
        Tot=torch.from_numpy(window)
        X=Tot[:-1].view(b_size,cwl)
        y=Tot[1:].view(b_size,cwl)
        yield X,y


In [68]:
# Build the network with the default Config and wrap it with torch.compile
# (the summary below therefore reports the outer module as OptimizedModule).
model = torch.compile(GPT(Config()))
# nn.Module.to() moves the parameters in place, so no re-assignment is needed.
model.to(dev)

summary(model)

Layer (type:depth-idx)                                  Param #
OptimizedModule                                         --
├─GPT: 1-1                                              --
│    └─Embedding: 2-1                                   38,633,472
│    └─Linear: 2-2                                      787,200
│    └─Sequential: 2-3                                  --
│    │    └─Block: 3-1                                  8,265,984
│    │    └─Block: 3-2                                  8,265,984
│    │    └─Block: 3-3                                  8,265,984
│    │    └─Block: 3-4                                  8,265,984
│    │    └─Block: 3-5                                  8,265,984
│    │    └─Block: 3-6                                  8,265,984
│    │    └─Block: 3-7                                  8,265,984
│    │    └─Block: 3-8                                  8,265,984
│    │    └─Block: 3-9                                  8,265,984
│    │    └─Block: 3-10           

In [70]:
# Bookkeeping shared by the scheduler setup and the training loop.
epochs = 1
criterion = nn.CrossEntropyLoss()
# True division, so this is a float (about 3051.76 with the default Config).
# It sizes the cosine schedule and is the denominator of the progress print.
# NOTE(review): the notebook does not say which span of data the 100M-token
# literal covers (one token file or the whole run) - confirm.
total_batches = 100_000_000 / (Config.b_size * Config.cwl)




In [None]:
# Split the parameters into two optimizer groups: matrices (ndim >= 2) get
# GaLore's low-rank gradient projection; vectors (biases, LayerNorm
# weights) take the plain AdamW update.
galore_params=[]
normal_params=[]
for _name,param in model.named_parameters():
    if param.ndim<2:
        normal_params.append(param)
    else:
        galore_params.append(param)

optimizer=GaLoreAdamW([
    # A group that sets "rank" is routed through the GaLore projector, which
    # also reads "update_proj_gap", "scale" and "proj_type" from the group.
    # The values below are the ones used in the GaLore README example.
    {"params":galore_params,"rank":128,"update_proj_gap":200,"scale":0.25,"proj_type":"std"},
    {"params":normal_params}
    ],
    lr=6e-4)

warmup_steps=20
# Schedulers count whole steps, so drop the fractional part of total_batches
# and keep T_max strictly positive.
# NOTE(review): the training loop calls scheduler.step() once every 4
# micro-batches and iterates over 11 token files - confirm that total_batches
# really is the number of scheduler steps the cosine decay should span.
cosine_steps=max(1,int(total_batches)-warmup_steps)
warmup_scheduler=torch.optim.lr_scheduler.LinearLR(optimizer=optimizer,
                                                   start_factor=0.1,
                                                   end_factor=1,
                                                   total_iters=warmup_steps)
cosine_scheduler=torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,
                                                            T_max=cosine_steps)

# milestones must be a list holding one switch point per scheduler boundary;
# a bare int fails inside SequentialLR when it takes len(milestones).
scheduler=torch.optim.lr_scheduler.SequentialLR(optimizer=optimizer,
                                                schedulers=[warmup_scheduler,cosine_scheduler],
                                                milestones=[warmup_steps])


In [None]:
# Training loop: stream the token files, accumulate gradients over several
# micro-batches per optimizer step, log and checkpoint periodically, and run a
# validation pass when a validation token array is available.
n_shards=11          # tokens_0.npy ... tokens_10.npy
grad_accum_steps=4   # micro-batches per optimizer / scheduler step
log_every=100        # micro-batches between log lines and checkpoints
eval_every=1000      # micro-batches between validation passes

# Micro-batch counter shared by all shards and epochs. It has to exist before
# the first `steps+=1`.
steps=0

# `val_ids` is not created by this cell. Pick it up if an earlier cell defined
# it; otherwise skip validation instead of crashing with NameError at the
# first validation step.
val_ids=globals().get("val_ids")
if val_ids is None:
    print("val_ids is not defined; validation passes will be skipped")

model.train()
# Zero the gradients once up front. Accumulation windows follow the global
# `steps` counter and may straddle a shard boundary, so gradients must not be
# cleared at the start of every shard.
optimizer.zero_grad()
start=time.time()
loss_accum=0
for k in range(n_shards):
    ids=np.load(f"tokens_{k}.npy",mmap_mode="r")
    for i in range(epochs):
        for x,y in generator(ids,Config.b_size,Config.cwl):
            # The generator yields CPU tensors while the model lives on `dev`;
            # token ids are cast to long for the embedding lookup and the loss.
            x=x.to(dev,dtype=torch.long)
            y=y.to(dev,dtype=torch.long).view(-1)

            # Only the forward pass and the loss run under autocast; backward,
            # the optimizer step, checkpointing and validation stay outside.
            with torch.autocast(device_type=dev,dtype=torch.bfloat16):
                out=model(x).view(-1,Config.vocab_size)
                loss=criterion(out,y)

            loss_accum+=loss.item()
            # Scale so the accumulated gradient is the mean over the window.
            (loss/grad_accum_steps).backward()

            steps+=1

            if steps%grad_accum_steps==0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            if steps%log_every==0:
                end=time.time()
                print(f"Loss :{loss_accum/log_every:.5f},Time :{end-start},Batches :{steps/total_batches}")
                # NOTE(review): `model` is the torch.compile wrapper, so the
                # saved keys are the wrapper's - confirm the loading code
                # expects that.
                torch.save({
                    "model":model.state_dict(),
                    "optimizer":optimizer.state_dict(),
                    "scheduler":scheduler.state_dict(),
                    "step":steps
                },"GPT-2.pt")

                start=end
                loss_accum=0

            if val_ids is not None and steps%eval_every==0:
                model.eval()
                loss_accum_val=0
                val_steps=0
                with torch.no_grad():
                    # Separate names keep the training batch and loss intact.
                    for val_x,val_y in generator(val_ids,Config.b_size,Config.cwl):
                        val_x=val_x.to(dev,dtype=torch.long)
                        val_y=val_y.to(dev,dtype=torch.long).view(-1)
                        with torch.autocast(device_type=dev,dtype=torch.bfloat16):
                            val_out=model(val_x).view(-1,Config.vocab_size)
                            val_loss=criterion(val_out,val_y)
                        loss_accum_val+=val_loss.item()
                        val_steps+=1
                model.train()  # back to training mode (re-enables dropout)

                # Guard against a validation array too short for one batch.
                if val_steps:
                    print(f"Val Loss :{loss_accum_val/val_steps:.4f}")