In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
device= "cuda" if torch.cuda.is_available() else "cpu"
print(device)
blockSize= 64
batchSize= 128
maxIters=3000
evalIters= 100
evalInterval= 500
learningRate=3e-3
nEmbedding =384
nLayer= 4
nHead=4
dropout= 0.2

cuda


In [2]:
with open("wizard_of_oz.txt","r", encoding ="UTF-8") as f:
    text=f.read()
character=sorted(list(set(text)))
vocab_size = len(character)

In [3]:
stringToInt={ ch:i for i,ch in enumerate(character)}
intToString={ i:ch for i,ch in enumerate(character)}

encode= lambda a : [stringToInt[i] for i in a]
decode= lambda a: "".join([intToString[i] for i in a])

data = torch.tensor(encode(text), dtype=torch.long)

In [4]:
n=int(0.8* len(data))
trainData=data[:n]
valData=data[n:]

x=trainData[:blockSize]
y=trainData[1:blockSize+1]

for i in range(0,blockSize):
    context=x[:i+1]
    target=y[i]
    print("when input is: ",context," target is ",target)

when input is:  tensor([80])  target is  tensor(28)
when input is:  tensor([80, 28])  target is  tensor(39)
when input is:  tensor([80, 28, 39])  target is  tensor(42)
when input is:  tensor([80, 28, 39, 42])  target is  tensor(39)
when input is:  tensor([80, 28, 39, 42, 39])  target is  tensor(44)
when input is:  tensor([80, 28, 39, 42, 39, 44])  target is  tensor(32)
when input is:  tensor([80, 28, 39, 42, 39, 44, 32])  target is  tensor(49)
when input is:  tensor([80, 28, 39, 42, 39, 44, 32, 49])  target is  tensor(1)
when input is:  tensor([80, 28, 39, 42, 39, 44, 32, 49,  1])  target is  tensor(25)
when input is:  tensor([80, 28, 39, 42, 39, 44, 32, 49,  1, 25])  target is  tensor(38)
when input is:  tensor([80, 28, 39, 42, 39, 44, 32, 49,  1, 25, 38])  target is  tensor(28)
when input is:  tensor([80, 28, 39, 42, 39, 44, 32, 49,  1, 25, 38, 28])  target is  tensor(1)
when input is:  tensor([80, 28, 39, 42, 39, 44, 32, 49,  1, 25, 38, 28,  1])  target is  tensor(44)
when input is:

In [5]:
def getbatch(split):
    data= trainData if split == 'train' else valData
    ix= torch.randint(len(data)-blockSize , (batchSize,))
    #print(ix)
    x=torch.stack([data[i:i+blockSize] for i in ix])
    y=torch.stack([data[i+1:i+blockSize+1] for i in ix])
    x,y= x.to(device), y.to(device)
    return x, y
x, y =getbatch('train')

In [6]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train','val']:
        losses= torch.zeros(evalIters)
        for k in range(evalIters):
            x,y = getbatch(split)
            logits, loss =model(x,y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [7]:
class Head(nn.Module):
    def __init__(self,head_size):
        super().__init__()
        self.key= nn.Linear(nEmbedding, head_size, bias=False)
        self.query= nn.Linear(nEmbedding, head_size, bias=False)
        self.value= nn.Linear(nEmbedding, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(blockSize, blockSize))) 
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        B,T,C= x.shape
        k= self.key(x)
        q= self.query(x)
        weight= q @ k.transpose(-2,-1)* k.shape[-1]**-0.5
        weight= weight.masked_fill(self.tril[:T, :T] ==0 , float('-inf'))
        weight= F.softmax(weight, dim=-1)
        weight= self.dropout(weight)
        
        v=self.value(x)
        out= weight @ v
        return out
        
class MultiHeadAttention(nn.Module):
    def __init__(self, numHeads , headSize):
        super().__init__()
        self.heads= nn.ModuleList([Head(headSize) for _ in range(numHeads)])
        self.proj= nn.Linear(headSize * numHeads, nEmbedding)
        self.dropout= nn.Dropout(dropout)

    def forward(self,x):
        out = torch.cat([h(x) for h  in self.heads], dim =-1)
        out= self.dropout(self.proj(out))
        return out
        
class FeedForward(nn.Module):
    def __init__(self,nEmbedding):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(nEmbedding, 4 * nEmbedding),
            nn.ReLU(),
            nn.Linear(4 * nEmbedding, nEmbedding),
            nn.Dropout(dropout),
        )
    def forward(self,x ):
        return self.net(x)
        
class Block(nn.Module):
    def __init__(self,nEmbedding, nHead):
        super().__init__()
        headSize= nEmbedding // nHead
        self.selfAttention= MultiHeadAttention(nHead, headSize)
        self.feedForward= FeedForward(nEmbedding)
        self.layerNorm_1= nn.LayerNorm(nEmbedding)
        self.layerNorm_2= nn.LayerNorm(nEmbedding)
        
    def forward(self, x):
        y= self.selfAttention(x)
        x=self.layerNorm_1(x+y)
        y=self.feedForward(x)
        x=self.layerNorm_2(x+y)
        return x
        
class GPT_LLM(nn.Module):
    def __init__(self,vocab_size):
        super().__init__()
        self.token_embedding_table= nn.Embedding(vocab_size,nEmbedding)
        self.positional_embedding_table = nn.Embedding(blockSize, nEmbedding)

        self.blocks = nn.Sequential(*[Block(nEmbedding, nHead= nHead) for _ in range(nLayer)])
        self.layerNormFinal =nn.LayerNorm(nEmbedding)
        self.languageModelHead = nn.Linear(nEmbedding , vocab_size)

        self.apply(self._init_weights)

    def _init_weights(self,module):
        if isinstance(module,nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self,index, targets=None):
        B,T = index.shape
        
        tokenEmbedding  = self.token_embedding_table(index)
        positionalEmbedding = self.positional_embedding_table(torch.arange(T, device= device))
        x= tokenEmbedding + positionalEmbedding
        x= self.blocks(x)
        x= self.layerNormFinal(x)
        logits = self.languageModelHead(x)
        
        if targets is None :
             loss=None
        else:
            B,T,C =logits.shape
            logits= logits.view(B*T,C)
            targets= targets.view(B*T)
            loss= F.cross_entropy(logits,targets)
        return logits, loss
        
    def generate(self,index , max_new_tokens):
        for _ in range(max_new_tokens):
            index_cond = index[:, -blockSize: ]
            logits , loss = self.forward(index_cond)
            logits= logits[:, -1, :]
            probs= F.softmax(logits, dim=-1)
            index_next = torch.multinomial(probs, num_samples=1)
            index= torch.cat((index,index_next), dim=1)
        return index

model= GPT_LLM(vocab_size)
m = model.to(device)

In [8]:
optimizer= torch.optim.AdamW(model.parameters(), lr=learningRate)
for iter in range(maxIters):
    if iter % evalIters == 0:
        losses= estimate_loss()
        print(f'step : {iter} , loss : {losses}')
    
    xb,yb= getbatch('train')
    logits , loss = model.forward(xb,yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
print(loss.item())

step : 0 , loss : {'train': tensor(4.4961), 'val': tensor(4.5008)}
step : 100 , loss : {'train': tensor(3.1492), 'val': tensor(3.1629)}
step : 200 , loss : {'train': tensor(3.1463), 'val': tensor(3.1531)}
step : 300 , loss : {'train': tensor(3.1428), 'val': tensor(3.1507)}
step : 400 , loss : {'train': tensor(3.1429), 'val': tensor(3.1493)}
step : 500 , loss : {'train': tensor(3.1413), 'val': tensor(3.1536)}
step : 600 , loss : {'train': tensor(3.1411), 'val': tensor(3.1515)}
step : 700 , loss : {'train': tensor(3.1405), 'val': tensor(3.1555)}
step : 800 , loss : {'train': tensor(3.1413), 'val': tensor(3.1490)}
step : 900 , loss : {'train': tensor(3.1431), 'val': tensor(3.1493)}
step : 1000 , loss : {'train': tensor(3.1478), 'val': tensor(3.1572)}
step : 1100 , loss : {'train': tensor(3.1404), 'val': tensor(3.1523)}
step : 1200 , loss : {'train': tensor(3.1369), 'val': tensor(3.1541)}
step : 1300 , loss : {'train': tensor(3.1420), 'val': tensor(3.1483)}
step : 1400 , loss : {'train': t

In [10]:
context = torch.zeros((1,1),dtype= torch.long, device =device)
generatedChars=decode(m.generate(context,max_new_tokens=500)[0].tolist())
print(generatedChars)


w dheen   pnrnnpd ttaarehwkl er
 edk,  ne soeau' heei gescfonP nnottlti o hiTytnegg
:yhiewt ledlsoh aytrmanyrt rlhe nr u  ddnv dt.tu wmitsa  omngbHsiwidga aelru hau.eeedi tlok atth ecvh rf  v prm nrny reagboansoan
o gdlnso a o hf  g   eaelmsoac nwceenuhm   onphe o tyrc siL caae. ekedcWt lJtrmIosu inleusaat,"selodene tt hro "
 ewoe re ehe  sth  glg aaelshs,eeuOGbnhi myh easer;tlotat, e
leu nc ohreymudhgee  y hnwennd nr, tgh sftbutee
a ooasn  fpo o see
fomyooovs akshfntt edlpiglr.ah
ka t arnbr oyl
