In [1]:
!pip install torch torchvision



In [2]:
!pip install wget



In [3]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [13]:
# We always start with a dataset to train on: the tiny Shakespeare corpus, saved as input.txt.
# The download is skipped when input.txt already exists. wget.download never overwrites an
# existing file; it would save another copy as "input (1).txt" on every re-run.
import os

import wget

url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
if not os.path.exists("input.txt"):
    wget.download(url, out="input.txt")

In [4]:
# Read the whole corpus into a single string.
with open('input.txt', mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

In [15]:
# Peek at the first 200 characters of the corpus.
text[0:200]

'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you'

In [5]:
# Character-level tokenizer: the vocabulary is every distinct character in the corpus,
# sorted so that the character -> id assignment is the same on every run.
chars = sorted(set(text))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}  # character -> integer id
itos = dict(enumerate(chars))                 # integer id -> character


def encode(s):
    """Map a string to a list of integer token ids (KeyError on a character not in the vocabulary)."""
    return [stoi[c] for c in s]


def decode(l):
    """Map an iterable of integer token ids back to a string."""
    return ''.join(itos[i] for i in l)

In [6]:
# Integer-encode the corpus, then keep the first 90% for training and the rest for validation.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train, val = data[:n], data[n:]

In [69]:
# Hyperparameters and run configuration; later cells read these globals.
torch.manual_seed(1337)  # reproducible weight init, batch sampling and text generation

batch_size = 64       # sequences per optimisation step
block_size = 256      # maximum context length in tokens
n_emd = 384           # embedding width of the residual stream
max_iters = 5000      # number of training steps
n_heads = 6           # attention heads per block; each head gets n_emd // n_heads features
dropout = 0.2         # dropout probability used throughout the model
n_layer = 6           # number of transformer blocks
device = "cpu"        # device the model and every batch are placed on
eval_iters = 200      # batches per loss estimate (only used if the training cell estimates the loss)
eval_interval = 2000  # report the loss every this many steps
l_r = 3e-4            # Adam learning rate

In [8]:
class head(nn.Module):
    """One head of causal (masked) self-attention.

    Projects the input to queries, keys and values of width ``head_size`` and lets each
    position attend only to itself and earlier positions. Reads the module-level
    ``n_emd``, ``block_size`` and ``dropout`` settings at construction time.

    Args:
        head_size: width of the query/key/value projections and of the output.
    """

    def __init__(self, head_size):
        super().__init__()
        self.query = nn.Linear(n_emd, head_size, bias=False)
        self.key = nn.Linear(n_emd, head_size, bias=False)
        self.value = nn.Linear(n_emd, head_size, bias=False)
        # tril[i, j] == 1 iff position i may attend to position j (j <= i). A buffer follows
        # .to(device) with the module but is not a trainable parameter.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the attention head.

        Args:
            x: tensor of shape (B, T, n_emd) with T <= block_size.

        Returns:
            Tensor of shape (B, T, head_size).
        """
        _, t, _ = x.shape
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # Scaled dot-product scores, shape (B, T, T); the 1/sqrt(head_size) factor keeps
        # the softmax from saturating as head_size grows.
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        # Block attention to future positions before normalising.
        wei = wei.masked_fill(self.tril[:t, :t] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        return wei @ v

In [9]:
class Multiheadattention(nn.Module):
    """Several attention heads run in parallel, concatenated, projected back to n_emd.

    Args:
        n_heads: number of independent ``head`` modules.
        head_size: output width of each head; the concatenation is n_heads * head_size wide.
    """

    def __init__(self, n_heads, head_size):
        super().__init__()
        heads = [head(head_size) for _ in range(n_heads)]
        self.heads = nn.ModuleList(heads)
        concat_width = n_heads * head_size
        self.proj = nn.Linear(concat_width, n_emd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Join every head's output along the feature axis, then mix with the projection.
        per_head = [attn(x) for attn in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)


In [10]:
class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4x width, ReLU, project back, then dropout.

    Args:
        n_embds: input and output feature width.
    """

    def __init__(self, n_embds):
        super().__init__()
        hidden = 4 * n_embds  # width of the hidden layer
        layers = [
            nn.Linear(n_embds, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embds),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # Operates on the last (feature) axis only, so every position is transformed independently.
        return self.net(x)

In [11]:
class block(nn.Module):
    """Pre-norm transformer block: x + attn(LN(x)), then x + mlp(LN(x)).

    Args:
        n_embds: residual-stream width.
        n_head: number of attention heads; each head gets n_embds // n_head features.
    """

    def __init__(self, n_embds, n_head):
        super().__init__()
        self.sa = Multiheadattention(n_head, n_embds // n_head)
        self.fwd = FeedForward(n_embds)
        self.n1 = nn.LayerNorm(n_embds)
        self.n2 = nn.LayerNorm(n_embds)

    def forward(self, x):
        # Residual connection around each sub-layer; LayerNorm is applied to the sub-layer input.
        attended = x + self.sa(self.n1(x))
        return attended + self.fwd(self.n2(attended))

In [67]:
class gptlanguagemodel(nn.Module):
    """Character-level GPT: token + position embeddings, a stack of transformer blocks,
    a final LayerNorm and a linear head producing next-token logits.

    Sizes come from the module-level settings ``vocab_size``, ``n_emd``, ``block_size``,
    ``n_heads`` and ``n_layer``.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_emd)
        self.positional_embedding = nn.Embedding(block_size, n_emd)
        self.layer = nn.LayerNorm(n_emd)            # final LayerNorm before the output head
        self.leniar = nn.Linear(n_emd, vocab_size)  # output head (name kept for state_dict compatibility)
        self.blocks = nn.Sequential(*[block(n_emd, n_heads) for _ in range(n_layer)])

        self.apply(self._init_weight)

    def _init_weight(self, module):
        """Initialise Linear and Embedding weights from N(0, 0.02^2); zero Linear biases.

        The std must be small: a Linear with fan-in d and weight std s scales unit-variance
        inputs by about s * sqrt(d), so s = 0.2 with d = 384 would amplify every projection
        roughly 4x and saturate the attention softmax at initialisation.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, target=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: tensor of token ids, shape (B, T) with T <= block_size.
            target: optional tensor of next-token ids, shape (B, T).

        Returns:
            (logits, loss). Without ``target``: logits of shape (B, T, vocab_size) and
            loss ``None``. With ``target``: logits flattened to (B*T, vocab_size) and the
            mean cross-entropy loss.
        """
        _, t = idx.shape
        tok = self.token_embedding(idx)  # (B, T, n_emd)
        # Positions are created on the input's device so the model does not rely on the
        # global `device` setting.
        pos = self.positional_embedding(torch.arange(t, device=idx.device))  # (T, n_emd)
        out = self.blocks(tok + pos)  # pos broadcasts over the batch dimension
        out = self.layer(out)
        logits = self.leniar(out)

        if target is None:
            return logits, None

        b, t, c = logits.shape
        logits = logits.view(b * t, c)
        target = target.view(b * t)
        loss = F.cross_entropy(logits, target)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens and append them to ``idx``.

        Args:
            idx: tensor of token ids of shape (B, T) holding the conditioning context.
            max_new_tokens: number of tokens to sample.

        Returns:
            Tensor of shape (B, T + max_new_tokens).

        Runs without autograd tracking. Dropout stays active if the model is in training
        mode; call ``model.eval()`` first to disable it while sampling.
        """
        for _ in range(max_new_tokens):
            # Positional embeddings only exist for block_size positions, so crop the context.
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            # Only the distribution for the final position is needed.
            probs = F.softmax(logits[:, -1, :], dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx


    



In [70]:
# Instantiate the model on `device` and create an Adam optimiser over its parameters.
# nn.Module.to moves the module in place and returns it, so `m` and `model` are the same object.
model = gptlanguagemodel().to(device)
m = model
optimizer = torch.optim.Adam(m.parameters(), lr=l_r)



In [None]:
# Training loop, followed by a sample from the trained model.
# NOTE(review): get_batch is defined in a cell further down. On a fresh kernel, run that
# cell before this one (or move it above); otherwise this cell raises NameError.


@torch.no_grad()
def estimate_loss():
    """Average the loss over `eval_iters` random batches of each split.

    Returns:
        dict mapping 'data' (the training split) and 'val' to a float mean loss.

    Puts the model in eval mode (dropout off) while measuring and back into training
    mode before returning.
    """
    model.eval()
    means = {}
    for split in ('data', 'val'):
        split_losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            xb, yb = get_batch(split)
            _, loss = model(xb, yb)
            split_losses[k] = loss.item()
        means[split] = split_losses.mean().item()
    model.train()
    return means


model.train()
for iters in range(max_iters):
    xb, yb = get_batch('data')
    _, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    if iters % eval_interval == 0 or iters == max_iters - 1:
        # A single batch's loss is noisy; report losses averaged over many batches instead.
        losses = estimate_loss()
        print(f"step {iters}: train loss {losses['data']:.4f}, val loss {losses['val']:.4f}")

# Sample 500 new tokens starting from token id 0, with dropout disabled.
model.eval()
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))


In [21]:
def get_batch(split):
    """Sample a random batch of (input, target) sequences.

    Args:
        split: 'data' or 'train' selects the training split; any other value selects
            the validation split.

    Returns:
        (x, y): tensors of shape (batch_size, block_size) on `device`, where y is x
        shifted one position to the right in the source text (the next-token targets).
    """
    source = train if split in ('data', 'train') else val
    # torch.randint's upper bound is exclusive, so the largest start is
    # len - block_size - 1 and the target slice below stays in bounds.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    x = torch.stack([source[i:i + block_size] for i in starts])
    y = torch.stack([source[i + 1:i + block_size + 1] for i in starts])
    return x.to(device), y.to(device)

In [45]:
# Shape of one training batch, (batch_size, block_size). Draws its own batch so the cell
# does not depend on a variable created by a later cell.
xb, yb = get_batch('data')
b, t = xb.shape
b, t

(64, 256)

In [42]:
# First ten token ids of the encoded corpus.
data[0:10]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47])

In [None]:
# Scratch check of the embedding shapes used by gptlanguagemodel.forward:
# token embeddings (B, T, n_emd) + position embeddings (T, n_emd) broadcast to (B, T, n_emd).
x, y = get_batch('data')  # pass a split name; passing the `data` tensor silently selects val
t = x.shape[1]            # sequence length of this batch
yo = nn.Embedding(block_size, n_emd, device=device)   # position-embedding table
yo1 = nn.Embedding(vocab_size, n_emd, device=device)  # token-embedding table
yo3 = yo1(x)                                # (B, T, n_emd)
yo4 = yo(torch.arange(t, device=device))    # (T, n_emd)
print(yo3.shape, yo4.shape)
yo2 = yo3 + yo4
yo2.shape

torch.Size([64, 256, 384]) torch.Size([256, 384])


torch.Size([64, 256, 384])