In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [40]:
# ----- HYPERPARAMETERS
max_iters = 9000
eval_iters = 300
eval_interval = max_iters//6
batch_size = 64
block_size = 32
n_embed = 32
lr = 1e-3
device = torch.device('mps' if torch.backends.mps.is_available else 'cpu')
num_heads = 4
head_size = n_embed // num_heads
num_layers = 32
print(device)

mps


In [3]:
with open('tiny-shakeshpere.txt','r') as f:
    text = f.read()

In [4]:
print(f'Length of the dataset : {len(text)}')

Length of the dataset : 1115393


In [5]:
chars = sorted(list(set(text)))
vocab_size = len(chars)

In [6]:
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}

encode = lambda s: [stoi[ch] for ch  in s]
decode = lambda l: ''.join([itos[i] for i in l])

print(encode('hi there'))
print(decode(encode('hi there')))

[46, 47, 1, 58, 46, 43, 56, 43]
hi there


In [7]:
data = torch.tensor(encode(text),dtype = torch.long)
print(data.shape,data.dtype)

torch.Size([1115393]) torch.int64


In [8]:
# Splitting the dataset into training and test
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

In [9]:
# Generating data for prediction
torch.manual_seed(1337)

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size,(batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i+1:i + block_size + 1] for i in ix])
    x,y = x.to(device),y.to(device)
    return x,y

xb,yb = get_batch('train')
print(xb.shape,yb.shape)

torch.Size([64, 32]) torch.Size([64, 32])


In [10]:
@torch.no_grad()
def eval_loss():
    out = {}
    model.eval()
    for split in ['train','val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X,Y = get_batch(split)
            logits,loss = model(X,Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [11]:
class Head(nn.Module):
    def __init__(self,head_size):
        super().__init__()
        self.key = nn.Linear(n_embed,head_size,bias = False)
        self.query = nn.Linear(n_embed,head_size,bias = False)
        self.value = nn.Linear(n_embed,head_size,bias = False)

        self.register_buffer('tril',torch.tril(torch.ones(block_size,block_size)))

    def forward(self,x):
        B,T,C = x.shape
        q = self.query(x)
        k = self.key(x)

        wei = q @ k.transpose(-2,-1) * k.shape[-1] ** -0.5
        wei = wei.masked_fill(self.tril[:T,:T] == 0,float('-inf'))
        wei = F.softmax(wei,dim  = 1)
        v = self.value(x)
        # print(v.shape,wei.shape,(wei@v).shape)
        out = wei @ v
        return out

In [12]:
class MultiHeadedAttention(nn.Module):
    def __init__(self,num_heads,head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embed,n_embed)

    def forward(self,x):
        out = torch.stack([h(x) for h in self.heads],dim = -1)
        out = self.proj(x)
        return out
        

In [13]:
class FeedForward(nn.Module):
    def __init__(self,n_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed,4*n_embed),
            nn.ReLU(),
            nn.Linear(4*n_embed,n_embed)
        )
    def forward(self,x):
        return self.net(x)

In [14]:
class Block(nn.Module):
    def __init__(self,n_embed,n_heads):
        super().__init__()
        head_size = n_embed // n_heads
        self.sa = MultiHeadedAttention(n_heads,head_size)
        self.ffwd = FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self,x):
        x = x + self.sa(self.ln1(x))
        x = x +self.ffwd(self.ln2(x))
        return x
        

In [15]:
class BigramLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size,n_embed)
        self.position_embeddings = nn.Embedding(block_size,n_embed)
        self.lm_head = nn.Linear(n_embed,vocab_size)
        self.blocks = nn.Sequential(*[Block(n_embed,num_heads) for _ in range(num_layers)])
        self.ln_f = nn.LayerNorm(n_embed)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        
    
    def forward(self,idx,targets = None):
        B,T = idx.shape
        token_embeddings = self.token_embedding(idx) # Shape = (B,T,C)
        positional_embeddings = self.position_embeddings(torch.arange(T,device = device)) # Shape = (T,C)
        x = token_embeddings + positional_embeddings # Shape = (B,T,C)
        x = self.blocks(x)
        logits = self.lm_head(x) # Shape -> (B,T,C)
        if targets is None:
            loss = None
        else:
            B,T,C = logits.shape
            logits = logits.view(B*T,C)
            
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits,targets)
        return logits,loss


    def generate(self,idx,max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:,-block_size:]
            logits,loss = self(idx_cond)
            logits = logits[:,-1,:]
            probs = F.softmax(logits,dim = 1)
            idx_next = torch.multinomial(probs,num_samples = 1)
            idx = torch.cat((idx,idx_next),dim = 1)
        return idx

In [42]:
model = BigramLanguageModel()
m = model.to(device)
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

0.408769 M parameters


In [17]:
# Writing the optimizer for backpropogation
optimizer = torch.optim.AdamW(m.parameters(),lr=lr)

In [18]:
for iterator in range(max_iters):
    if iterator % eval_interval == 0:
        losses = eval_loss()
        print(f'Step : {iterator} | Train Loss : {losses['train']} | Eval Loss : {losses['val']}')
    xb,yb = get_batch('train')
    logits,loss = m(xb,yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

Step : 0 | Train Loss : 4.173321723937988 | Eval Loss : 4.173290729522705
Step : 1500 | Train Loss : 2.486271858215332 | Eval Loss : 2.4989116191864014
Step : 3000 | Train Loss : 2.4739644527435303 | Eval Loss : 2.490596294403076
Step : 4500 | Train Loss : 2.4654290676116943 | Eval Loss : 2.487050771713257
Step : 6000 | Train Loss : 2.4647457599639893 | Eval Loss : 2.488279104232788
Step : 7500 | Train Loss : 2.463229179382324 | Eval Loss : 2.4949629306793213


In [19]:
idx = torch.zeros((1,1),dtype = torch.long)
idx = idx.to(device)
print(decode(m.generate(idx,max_new_tokens=500)[0].tolist()))


Foasth prse tizenderst el
O d frnie hy:


Hak, CORineg agnthe t rrigoucowor d s nge?
Ten, rsothy, chouspo is mppry way avend ouburser sickes bokecard dhiceny

He tw el fe oupise he, lbustselownthous;
Nom w
T:
TIONTouly me EUSerks, angnditheland's oe, oghithet f, badogienthofathatey foueay wad,
ureisold array ngestyockield, murs, in mamybalorthyongmyoorord Vofetthindy st
Hefil brveseay alstwanerm to, oupompl wee d pre h, gavit gin Thean apsts lathise my d erouerse IO:
ED d nghathicerire.
II IS:
Y
