<a href="https://colab.research.google.com/github/Mithun-033/Text-To-SQL-GPT/blob/main/GPT_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import numpy as np
from dataclasses import dataclass
import time

In [None]:
# Pick the accelerator when one is present; otherwise fall back to the CPU.
if torch.cuda.is_available():
    dev = "cuda"
else:
    dev = "cpu"

In [None]:
@dataclass
class Config:
    n_embed:int=1
    cwl:int=1
    b_size:int=1
    head_size :int=1
    n_head :int=1
    vocab_size :int=1
    n_layer :int=1


In [None]:
class AttentionHead(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.config=config

        self.q=nn.Linear(config.n_embed,config.head_size,bias=False)
        self.k=nn.Linear(config.n_embed,config.head_size,bias=False)
        self.v=nn.Linear(config.n_embed,config.head_size,bias=False)

        self.dropout=nn.Dropout(p=0.15)

    def forward(self,x):
        T=x.shape(1)

        keys=self.k(x)  # (B,T,H)
        query=self.q(x)  # (B,T,H)
        value=self.v(x)  # (B,T,H)

        weights=(keys@(query.transpose(-2,-1)))/self.config.head_size**0.5 # (B,T,T)
        mask=torch.trill(torch.ones(T,T,device=dev))
        weights=weights.masked_fill(mask==0,float("-inf"))

        weights=nn.Softmax(weights,dim=-1) #(B,T,T)

        logits=weights@value #(B,T,H)
        logits=self.dropout(logits)
        return logits



In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.config=config

        self.Multi=nn.ModuleList([AttentionHead() for _ in range(config.n_head)])
        self.project=nn.Linear(config.n_head*config.head_size,config.n_embed)
        self.dropout=nn.Dropout(0.2)

    def forward(self,x):
        output=torch.cat([head(x) for head in self.Multi],dim=-1) #(B,T,H*N)
        output=self.project(output) #(B,T,N)
        output=self.dropout(output)

        return output


In [None]:
class MLP(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.config=config

        self.layer=nn.Sequential(
            nn.Linear(config.n_embed,5*config.n_embed),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(5*config.n_embed,config.n_embed)
        )
    def forward(self,x):
        return self.layer(x)

In [None]:
class Block(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.config=config

        self.PreNorm1=nn.LayerNorm(config.n_embed)
        self.attention=MultiHeadAttention()
        self.PreNorm2=nn.LayerNorm(config.n_embed)
        self.FeedForwardLayer=MLP()

    def forward(self,x):
        x=x+self.attention(self.PreNorm1(x))
        x=x+self.FeedForwardLayer(self.PreNorm2)  # Residual connection

        return x

In [None]:
class GPT(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.config=config
        # X --> (B,T)
        self.embed = nn.Embedding(config.vocab_size,config.n_embed)  # (B,T,C)
        self.pos_embed=nn.Linear(config.cwl,config.n_embed) # (T,C)
        self.blocks=nn.Sequential(*[Block() for _ in range(config.n_layer)])
        self.Final_norm=nn.LayerNorm(config.n_embed)
        self.Dense=nn.Linear(config.n_embed,config.vocab_size)

    def forward(self,x):
        x=self.embed(x) # (B,T,C)
        pos=self.pos_embed(torch.arange(self.config.cwl),device=dev)
        x+=pos #sneaky broadcast

        out=self.blocks(x) # (B,T,C)
        out=self.Final_norm(out) # (B,T,C)

        logits=self.Dense(out) #(B,T,V)

        return logits



In [None]:
model=GPT()
model=torch.compile(model)
model.to(device)
optimizer=torch.optim.AdamW(model.parameters(),lr=6e-4)
criterion=nn.CrossEntropyLoss()
epochs=1
total_batches=len(ids)/b_size

In [None]:
for i in range(epochs):
    optimizer.zero_grad()
    steps=0
    start=time.time()
    loss_accum=0
    for x,y in generator(ids,batch_size,cwl):
        with torch.autocast(device_type=dev,dtype=torch.bfloat16):
            out=model(x)
            out=out.view(-1,1)
            y=y.view(-1)
            loss+=criterion(out,y)/4
            loss_accum+=loss.item()
            loss.backward()

            steps+=1

            if steps%4==0:
                optimizer.step()
                optimizer.zero_grad()

            if steps%100:
                end=time.time()
                print(f"Loss :{loss_accum},Time :{start-end},Batches :{steps/total_batches}")
                start=end
                loss_accum=0




