In [1]:
!pip install torchinfo



In [59]:
import pandas as pd
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.trainers import BpeTrainer
import torch
import torch.nn as nn
from torchinfo import summary
import time
import math
from torch.optim.lr_scheduler import LambdaLR

In [60]:
device="cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [4]:
!curl -L -o poem-dataset.zip \
https://www.kaggle.com/api/v1/datasets/download/marufchowdhury/poem-dataset

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100 10.3M  100 10.3M    0     0  18.0M      0 --:--:-- --:--:-- --:--:-- 57.5M


In [5]:
!unzip poem-dataset.zip


Archive:  poem-dataset.zip
  inflating: Poems_Dataset.csv       
  inflating: poemDatasetWithSummary.csv  


In [61]:
df=pd.read_csv("Poems_Dataset.csv")
df=df["Poem Content"]
data=df.tolist()


In [62]:
#Hyper Parameters
context_window_length=128
batch_size=256
n_embed=288
n_head=9
n_layers=9
v_size=9000
head_size=n_embed//n_head

In [63]:
tokenizer=Tokenizer(BPE())
tokenizer.pre_tokenizer=ByteLevel(add_prefix_space=True)
trainer=BpeTrainer(vocab_size=v_size)
tokenizer.train_from_iterator(data,trainer)






In [64]:
out=tokenizer.encode("hello my guy how are you")
out.tokens,out.ids


(['Ġhell', 'o', 'Ġmy', 'Ġguy', 'Ġhow', 'Ġare', 'Ġyou'],
 [2368, 78, 280, 5194, 588, 363, 257])

In [65]:
all_ids=[]

for s in data:
    all_ids.extend(tokenizer.encode(s).ids)

idss=torch.tensor(all_ids,dtype=torch.long).to(device)

In [66]:
len(idss)

5462341

In [67]:

ids=idss[:1300000]
len(ids)

1300000

In [68]:
def generator(ids,batch_size,cwl):
    X=[]
    Y=[]
    count=0

    for i in range(len(ids)-cwl):
        X.append(ids[i:i+cwl])
        Y.append(ids[i+1:i+cwl+1])
        count+=1

        if count==batch_size:
            yield torch.stack(X).to(device),torch.stack(Y).to(device)
            X=[]
            Y=[]
            count=0

In [69]:
class AttentionHead(nn.Module):
    def __init__(self,head_size):
        super().__init__()
        self.key=nn.Linear(n_embed,head_size) #(B,T,C)-->(B,T,H)
        self.query=nn.Linear(n_embed,head_size) #(B,T,C)-->(B,T,H)
        self.value=nn.Linear(n_embed,head_size)  #(B,T,C)-->(B,T,H
        self.dropout=nn.Dropout(0.1)

    def forward(self,x):
        k=self.key(x)     #(B,T,H)
        q=self.query(x)   #(B,T,H)
        v=self.value(x)   #(B,T,H)

        # Do Dot product of k and q

        weights=k@q.transpose(-2,-1)*head_size**-0.5  # (B,T,H) x (B,H,T) --> (B,T,T)
        T=x.size(1)
        mask=torch.tril(torch.ones(T,T,device=x.device))
        weights=weights.masked_fill(mask==0,float('-inf'))
        weights=nn.functional.softmax(weights,dim=-1)
        weights = self.dropout(weights)

        output=weights@v #(B,T,T) x (B,T,H) --> (B,T,H)
        return output

In [70]:
class MultiHead(nn.Module):
    def __init__(self,n_head,head_size):
        super().__init__()
        self.heads=nn.ModuleList([AttentionHead(head_size) for _ in range(n_head)])
        self.project=nn.Linear(n_head*head_size,n_embed)
        self.dropout=nn.Dropout(0.1)
    def forward(self,x):
        out=torch.cat([h(x) for h in self.heads],dim=-1)  # (B,T,H*N)
        #out=self.project(out)  # (B,T,H*N) --> (B,T,C) 
        out = self.dropout(out)
        return out

In [71]:
class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.FF=nn.Sequential(
            nn.Linear(n_embed,3*n_embed),
            nn.ReLU(),
            nn.Linear(3*n_embed,n_embed),
            nn.Dropout()
        )

    def forward(self,x):
        return self.FF(x)

In [72]:
class Block(nn.Module):
    def __init__(self,n_embed,n_head):
        super().__init__()
        head_size=n_embed//n_head
        self.SelfAtt = MultiHead(n_head, head_size)
        self.ffwd = FeedForward()
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self,x):
        x = x + self.SelfAtt(self.ln1(x)) + self.ffwd(self.ln2(x))
        return x  #(B,T,C)

In [73]:
class GPT(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed=nn.Embedding(v_size,n_embed)  # (B,T) --> (B,T,C)
        self.pos_embed=nn.Embedding(context_window_length,n_embed) # (T) --> (T,C)

        self.blocks=nn.Sequential(*[Block(n_embed,n_head) for _ in range(n_layers)])
        self.final_layernorm = nn.LayerNorm(n_embed) # final layer norm
        self.lm_head = nn.Linear(n_embed, v_size)

    def forward(self,x):
        # x ==> (B,T)

        tok_embeds=self.embed(x) # (B,T,C)
        pos_embeds=self.pos_embed(torch.arange(x.size(1),device=x.device)) #(T,C)
        x=tok_embeds + pos_embeds # pos_embed r broadcasted and added to every batch element

        x=self.blocks(x)
        x=self.final_layernorm(x)
        logits=self.lm_head(x)

        return logits


    @torch.no_grad()
    def generate(model,idx,max_new_tokens):
        for _ in range(max_new_tokens):
            if idx.size(1)>context_window_length:
                idx_cond=idx[:,-context_window_length:]
            else:
                idx_cond=idx

            logits=model(idx_cond)
            probs=torch.softmax(logits[:,-1,:],dim=-1)
            next_token=torch.multinomial(probs,1)
            idx=torch.cat((idx,next_token),dim=1)

        return idx


In [77]:
model=GPT().to(device)
model=torch.compile(model)
optimizer=torch.optim.AdamW(model.parameters(),lr=0.00007,fused=True)
criterion=nn.CrossEntropyLoss()
def lr_lambda(epoch):
    return 0.5*(1+math.cos(math.pi*epoch/epochs))

scheduler=LambdaLR(optimizer,lr_lambda)
epochs=20

In [78]:
summary(model)

Layer (type:depth-idx)                                  Param #
OptimizedModule                                         --
├─GPT: 1-1                                              --
│    └─Embedding: 2-1                                   2,592,000
│    └─Embedding: 2-2                                   36,864
│    └─Sequential: 2-3                                  --
│    │    └─Block: 3-1                                  832,896
│    │    └─Block: 3-2                                  832,896
│    │    └─Block: 3-3                                  832,896
│    │    └─Block: 3-4                                  832,896
│    │    └─Block: 3-5                                  832,896
│    │    └─Block: 3-6                                  832,896
│    │    └─Block: 3-7                                  832,896
│    │    └─Block: 3-8                                  832,896
│    │    └─Block: 3-9                                  832,896
│    └─LayerNorm: 2-4                                 

In [None]:
scaler=torch.cuda.amp.GradScaler()

for i in range(epochs):
    step=0
    start_epoch=time.time()
    last_print_time=start_epoch

    for x,y in generator(ids,batch_size,context_window_length):
        optimizer.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast():
            logits=model(x)
            logits=logits.view(-1,logits.size(-1))
            y=y.view(-1)
            loss=criterion(logits,y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        step+=1
        if step%150==0:
            now=time.time()
            print(
                f"Epoch: {i+1}, "
                f"Step: {step}, "
                f"Loss: {loss.item():.4f}, "
                f"Time/100 batches: {(now-last_print_time):.2f} sec")
            last_print_time=now
    scheduler.step()
    
    end_epoch=time.time()
    print(f"Epoch {i+1} total time: {(end_epoch-start_epoch):.2f} sec")
    torch.save(model.state_dict(),f"Model_epoch_{i+1}.pt")

  scaler=torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
Exception ignored in: <function ExactWeakKeyDictionary.__setitem__.<locals>.<lambda> at 0x7b2e21a0f9c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 954, in <lambda>
    self.refs[idx] = weakref.ref(key, lambda ref: self._remove_id(idx))

KeyboardInterrupt: 


In [34]:
val_ids=idss[-20000:-1]
len(val_ids)

19999

In [37]:
for i in range(1,28):
    temp=torch.load(f"Model_epoch_{i}.pt",map_location=device)
    temp={k.replace("_orig_mod.",""):v for k,v in temp.items()}
    model=GPT()
    model.load_state_dict(temp)
    model.to(device)
    avg_loss=0
    count=0
    with torch.no_grad():
        model.eval()
        for x,y in generator(val_ids,batch_size,context_window_length):
            logits=model(x)
            logits=logits.view(-1,logits.size(-1))
            y=y.view(-1)
            loss=criterion(logits,y)
            count+=1
            avg_loss+=loss.item()

    print(f"Model_{i} Val Loss:{avg_loss/count}")
        

Model_1 Val Loss:6.1067119751657755
Model_2 Val Loss:5.841304796082633
Model_3 Val Loss:5.666963466576168
Model_4 Val Loss:5.545412727764675
Model_5 Val Loss:5.459126625742231
Model_6 Val Loss:5.394868578229632
Model_7 Val Loss:5.344340511730739
Model_8 Val Loss:5.30354768889291
Model_9 Val Loss:5.269190847873688
Model_10 Val Loss:5.240001865795681
Model_11 Val Loss:5.215300704751696
Model_12 Val Loss:5.1942658339227945
Model_13 Val Loss:5.175960012844631
Model_14 Val Loss:5.160851972443717
Model_15 Val Loss:5.149416131632669
Model_16 Val Loss:5.14115423815591
Model_17 Val Loss:5.136325606278011
Model_18 Val Loss:5.133300117083958
Model_19 Val Loss:5.132233653749738
Model_20 Val Loss:5.1347266009875705
Model_21 Val Loss:5.139067879744938
Model_22 Val Loss:5.144578048161098
Model_23 Val Loss:5.15180368082864
Model_24 Val Loss:5.160773115498679
Model_25 Val Loss:5.17087288413729
Model_26 Val Loss:5.1827168720109125
Model_27 Val Loss:5.194707700184414


In [None]:
while True:
    x=input("Enter starting text:")
    y=tokenizer.encode(x).ids
    context=torch.tensor([y],device=device)
    print(tokenizer.decode(model.generate(context, max_new_tokens=50)[0].tolist()))


Enter starting text:hello
hell o ’ w D u cour ’ come je ’ ro ’ ty ’ u y ’ i ’ ’ ele ky ex ä ’ min ex ’ mo m ’ t ’ j ’ ti ’ m ä ny j not earth ä w y i ’ a ’ i
Enter starting text:sun rises 
sun rises in k ids St ted out . A A point lo ch , - be er ers , T H ching to see more , W of pur P ass down from its corner , n be every sa ace your thr w al th the sweet i y ,
Enter starting text:sagarika
s ag ar i k a los ces their And Com i ó aws B es ition i y a la i j an am ty or a bled a j ic o ch once you can the building i ol er j u yo ec i en , C R una e a their que


KeyboardInterrupt: Interrupted by user