In [1]:
import requests
from datasets import load_dataset, DatasetDict
from typing import List, Optional
import torch
from torch import nn
from torch.nn import functional as F
from tqdm.notebook import tqdm

import string
import random
from ftfy import fix_text
from collections import defaultdict
from heapq import heappush, heappop
from timeit import timeit
from time import time, sleep

In [2]:
wiki = load_dataset("rahular/simple-wikipedia")

In [3]:
wiki

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 769764
    })
})

In [4]:
test_split = 0.1

train_size = int(0.9*len(wiki['train']))

def wiki_filter(row):
    return len(row['text'])>500

train = wiki['train'].select(range(train_size)).filter(wiki_filter)
test = wiki['train'].select(range(train_size, len(wiki['train']))).filter(wiki_filter)

wiki = DatasetDict({
    'train' : train,
    'test' : test
})

In [5]:
wiki

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 49998
    })
    test: Dataset({
        features: ['text'],
        num_rows: 2855
    })
})

In [6]:
stoi = {c:i for i,c in enumerate(sorted(list(string.printable)))}
default_int = stoi[' ']
print('Vocab(printable chars):\n',''.join(sorted(stoi.keys())))

Vocab(printable chars):
 	
 !"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~


In [7]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

### "Tokenizer"

In [6]:
class Tokenizer:
    def __init__(self):
            
        self.chars = sorted(list(string.printable))
        self.itos = {i:c for i,c in enumerate(sorted(list(self.chars)))}
        self.stoi = {c:i for i,c in enumerate(sorted(list(self.chars)))}
        self.default_int = self.stoi[' '] # space as the default replacement of the unknown char
        
    def tokenize(self, text: str):
        return [self.stoi.get(char, self.default_int) for char in fix_text(text)]
    
    def decode(self, tokens: List[int]):
        return ''.join([self.itos[token] for token in tokens])

In [7]:
tokenizer = Tokenizer()

In [10]:
tokens = tokenizer.tokenize('hello')
tokens

[77, 74, 81, 81, 84]

In [352]:
tokenizer.decode(tokens)

'hello'

In [353]:
# train = train.map(tokenizer.tokenize)

train.shuffle().select(range(1))['text']

['Sir Michael Terence Wogan (; 3 August 1938 – 31 January 2016), better known as Terry Wogan, was a veteran Irish-British radio and television broadcaster, who has worked for the British Broadcasting Corporation in the United Kingdom for most of his career. Before he retired from the weekday breakfast programme "Wake Up to Wogan" on BBC Radio 2 on 18 December 2009, Wogan had a regular eight million listeners, making him the most listened to radio broadcaster of any European nation. He began his career at Raidió Teilifís Éireann where he presented shows such as "Jackpot" in the 1960s.']

In [11]:
def get_batches(data, batch_size, context_length, device=None):
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        
    data = data.shuffle().select(range(batch_size))
    min_data_size = min(len(item['text']) for item in data)
    min_data_size = min(min_data_size, context_length)
    block_size = random.randint(int(min_data_size*0.5), int(min_data_size*0.8))

    data = [tokenizer.tokenize(item['text']) for item in data]

    rand_starts = torch.randint(min_data_size-block_size, (batch_size,))
    
    x, y = torch.empty((0,block_size), dtype=torch.int), torch.empty((0,block_size), dtype=torch.int)
    for start, text in zip(rand_starts, data):
        try:
            x = torch.cat((x, torch.tensor(text[start:start+block_size]).unsqueeze(0) ), dim = 0)
            y = torch.cat((y, torch.tensor(text[start+1:start+block_size+1]).unsqueeze(0) ), dim = 0)    
        except Exception as e:
            print(f"Error during batch creation : {e}")
    
    return x.to(device), y.to(device)

In [355]:
x, y = get_batches(train,2, 1000)
x.shape, y.shape

(torch.Size([2, 455]), torch.Size([2, 455]))

In [8]:
embedding_dim = 768
num_heads = 12
max_iters = 60000
eval_interval = 300
eval_iters = 50
lr = 0.3e-4
dropout = 0.1
vocab_size = len(tokenizer.chars)
num_blocks = 10

batch_size = 8
context_length = 800

device = 'cuda' if torch.cuda.is_available() else 'cpu'

### Multihead Latent Attention

In [13]:
class CausalMultiHeadAttention(nn.Module):
    def __init__(self, embedding_dim: int, num_heads: int):
        super().__init__()
        head_size = embedding_dim//num_heads
        self.num_heads = num_heads
        self.latent_proj = nn.Linear(embedding_dim, embedding_dim//2)
        self.qkv_proj = nn.Linear(embedding_dim//2, embedding_dim*3)
        self.o_proj = nn.Linear(head_size*num_heads, embedding_dim)
        self.register_buffer('tril', torch.tril(torch.ones(context_length, context_length)).to(device))
        self.dropout = nn.Dropout(p=dropout)
        
    def forward(self, x, kv_cache = None):
        B, T, C = x.shape
        
        latent = self.latent_proj(x)
        qkv = self.qkv_proj(latent)
        queries, keys, values = qkv.split(C, dim=-1)
        
        queries = queries.view(B, T, self.num_heads, C//self.num_heads).transpose(1,2)
        keys = keys.view(B, T, self.num_heads, C//self.num_heads).transpose(1,2)
        values = values.view(B, T, self.num_heads, C//self.num_heads).transpose(1,2)
        
        wei = queries@keys.transpose(-2,-1)/(queries.shape[-1]**0.5)
        wei = wei.masked_fill(self.tril[:T,:T]==0, -torch.inf)
        weights = F.softmax(wei, dim = -1)
        
        out = weights@values

        out = out.transpose(1,2).reshape(B, T, C)
        out = self.o_proj(out)
        out = self.dropout(out)
        
        return out, None

### Multihead Attention

In [14]:
class CausalMultiHeadAttention(nn.Module):
    def __init__(self, embedding_dim: int, num_heads: int):
        super().__init__()
        head_size = embedding_dim//num_heads
        self.num_heads = num_heads
        self.qkv_proj = nn.Linear(embedding_dim, embedding_dim*3)
        self.o_proj = nn.Linear(head_size*num_heads, embedding_dim)
        self.register_buffer('tril', torch.tril(torch.ones(context_length, context_length)).to(device))
        self.dropout = nn.Dropout(p=dropout)
        
    def forward(self, x, kv_cache = None):
        B, T, C = x.shape
        
        qkv = self.qkv_proj(x)
        queries, keys, values = qkv.split(C, dim=-1)
        
        queries = queries.view(B, T, self.num_heads, C//self.num_heads).transpose(1,2)
        keys = keys.view(B, T, self.num_heads, C//self.num_heads).transpose(1,2)
        values = values.view(B, T, self.num_heads, C//self.num_heads).transpose(1,2)
        
        if kv_cache is not None and not self.training:
            past_keys, past_values = kv_cache
            keys = torch.cat((past_keys, keys), dim = 2)
            values = torch.cat((past_values, values), dim = 2)
        
        wei = queries@keys.transpose(-2,-1)/(queries.shape[-1]**0.5)
        # Masking only if kv cache is None, no concept of masking when only new/last token is passed(kv_cache exists)
        if kv_cache is None:
            wei = wei.masked_fill(self.tril[:T, :T]==0, -torch.inf)
        weights = F.softmax(wei, dim = -1)
        
        out = weights@values

        out = out.transpose(1,2).reshape(B, T, C)
        out = self.o_proj(out)
        out = self.dropout(out)
            
        return (out, (keys,values)) if not self.training else (out, None) 

In [2]:
class FeedForwardVanilla(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        self.up_proj = nn.Linear(embedding_dim, embedding_dim*8//3, bias=False)
        self.silu_proj = nn.Linear(embedding_dim, embedding_dim*8//3, bias=False)
        self.down_proj = nn.Linear(embedding_dim*8//3, embedding_dim, bias=False)
        self.dropout = nn.Dropout(p=dropout)
        
    def forward(self, x):
        out = F.silu(self.silu_proj(x))*self.up_proj(x)
        out = self.down_proj(out)
        out = self.dropout(out)
        
        return out

In [3]:
class FeedForward(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        self.hidden_dim = embedding_dim*8//3
        self.up_proj = nn.Linear(embedding_dim, self.hidden_dim*2, bias=False)
        self.down_proj = nn.Linear(self.hidden_dim, embedding_dim, bias=False)
        self.dropout = nn.Dropout(p=dropout)
        
    def forward(self, x):
        x1, x2 = self.up_proj(x).split(self.hidden_dim, dim = -1) 
        out = F.silu(x1)*x2
        out = self.down_proj(out)
        out = self.dropout(out)
        
        return out

### Parallelized SwiGLU benchmark

In [12]:
ed = 4096
dev = 'cuda'

x = torch.randn((1,200,ed)).to(dev)

ff_vanilla = FeedForwardVanilla(ed).to(dev)
ff = FeedForward(ed).to(dev)

def run_ffn_benchmark(num_trials = 1000):
    t1s = 0
    t2s = 0
    
    for _ in range(num_trials):
        t2 = timeit(lambda : ff(x), number = 1)
        t1 = timeit(lambda : ff_vanilla(x), number = 1)
        
        t1s += t1
        t2s += t2
    
    return t1s/num_trials, t2s/num_trials

t1, t2 = run_ffn_benchmark()

print(f"Vanilla FFN time: {t1 : .4f}s. Parallelized FFN time: {t2 : .4f}s. Speed-up : {100*(t1-t2)/t1 : .2f}%")

Vanilla FFN time:  0.0108s. Parallelized FFN time:  0.0050s. Speed-up :  53.84%


In [17]:
class DynamicTanh(nn.Module):
    def __init__(self, normalized_shape, eps=1e-4, init_alpha=0.5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(normalized_shape))
        self.beta = nn.Parameter(torch.zeros(normalized_shape))
        self.alpha = nn.Parameter(torch.ones(1)*init_alpha)
        
    def forward(self, x):
        out = F.tanh(self.alpha*x)
        out = self.gamma*out + self.beta
        return out

In [18]:
class DecoderBlock(nn.Module):
    def __init__(self, embedding_dim: int, num_heads):
        super().__init__()
        self.multi_head_attention = CausalMultiHeadAttention(embedding_dim, num_heads)
        self.feed_forward_net = FeedForward(embedding_dim)
        self.dynamic_tanh1 = DynamicTanh(embedding_dim)
        self.dynamic_tanh2 = DynamicTanh(embedding_dim)
        
    def forward(self, x, kv_cache = None):
        out, kv_cache = self.multi_head_attention(self.dynamic_tanh1(x), kv_cache)
        out = x + out
        out = out + self.feed_forward_net(self.dynamic_tanh2(out))
        
        return out, kv_cache

In [19]:
class GPT(nn.Module):
    def __init__(self, embedding_dim: int = 64, num_heads: int = 8, num_blocks = 8):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_embedding = nn.Embedding(context_length, embedding_dim)
        self.blocks = nn.ModuleList([
            DecoderBlock(embedding_dim, num_heads) for _ in range(num_blocks)
        ])
        self.dynamic_tanh = DynamicTanh(embedding_dim)
        self.lm_head = nn.Linear(embedding_dim, vocab_size)
        
        self.num_heads = num_heads
        self.embedding_dim = embedding_dim
        
    def forward(self, tokens, targets=None, kv_cache = None):

        if kv_cache is None:
            kv_cache = [None] * len(self.blocks)
            out = self.pos_embedding(torch.arange(tokens.shape[-1], device=device)) + self.embedding(tokens)
        else:
            # trim the kv_cache to keep the context valid
            T_past = kv_cache[0][0].shape[2]
            if T_past >= context_length:
                trim = lambda past_kv : (past_kv[0][:, :, -(context_length-1):, :], past_kv[1][:, :, -(context_length-1):, :])  
                kv_cache = [trim(kv_cache[i]) for i in range(len(self.blocks))] 
            
            tokens = tokens[:, [-1]]
            out = self.pos_embedding(torch.arange(T_past, T_past+1, device=device)) + self.embedding(tokens)

        for i, block in enumerate(self.blocks):
            out, updated_block_cache = block(out, kv_cache[i])
            kv_cache[i] = updated_block_cache
            
        out = self.dynamic_tanh(out)
        
        # If no targets, it is inference and we only care about the last token
        if targets is None:    
            out = out[:, [-1], :]
            
        logits = self.lm_head(out)
                                 
        if targets is None:
            return logits, kv_cache
            
        B, T, C = logits.shape

        logits = logits.view(B*T, C)
        targets = targets.view(B*T)
        loss = F.cross_entropy(logits, targets)
        
        return logits, loss
    
    @torch.no_grad()
    def generate(self, tokens, temperature = 1, top_k = None, max_new_tokens=100, use_cache = True):
        assert temperature>0, "temperature needs to be positive for comprehensible generations"
        if top_k is not None:
            assert top_k>0 and isinstance(top_k, int), "Non-positive or non-int top_k doesn't make sense!"
        
        kv_cache = None
        for _ in range(max_new_tokens): 
            context = tokens[:,-context_length:]
            
            if use_cache is False:
                logits, _ = self(context, None)
            else:
                logits, kv_cache = self(context, None, kv_cache)
            
            logits = logits[:,-1,:]/temperature
            
            if top_k is not None:
                logits = self._get_topk_logits(logits, top_k)
                
            probabilities = F.softmax(logits, dim=1)
            next_token = torch.multinomial(probabilities, 1)
            
            tokens = torch.cat((tokens, next_token), dim=1)
            
        return tokens
    
    def _get_topk_logits(self, logits, k):
        v, _ = torch.topk(logits, k, dim=-1)
        min_values = v[:, -1].unsqueeze(-1).expand_as(logits)
        
        return torch.where(logits < min_values, torch.full_like(logits, float('-inf')), logits)

#     def _get_topk_logits(self, logits, k: int):
#         heap = []
        
#         for logit in logits:
#             heappush(heap, logit)
#             if len(heap)>k:
#                 heappop(heap)
           
#         logits[logits<heap[0]] = -torch.inf
        
#         return logits

In [20]:
model = GPT(embedding_dim=embedding_dim, num_heads=num_heads, num_blocks=num_blocks).to(device)

In [21]:
model

GPT(
  (embedding): Embedding(100, 768)
  (pos_embedding): Embedding(800, 768)
  (blocks): ModuleList(
    (0-9): 10 x DecoderBlock(
      (multi_head_attention): CausalMultiHeadAttention(
        (qkv_proj): Linear(in_features=768, out_features=2304, bias=True)
        (o_proj): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (feed_forward_net): FeedForward(
        (up_proj): Linear(in_features=768, out_features=4096, bias=False)
        (down_proj): Linear(in_features=2048, out_features=768, bias=False)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (dynamic_tanh1): DynamicTanh()
      (dynamic_tanh2): DynamicTanh()
    )
  )
  (dynamic_tanh): DynamicTanh()
  (lm_head): Linear(in_features=768, out_features=100, bias=True)
)

In [22]:
print(f"Model has {sum(p.numel() for p in model.parameters())/1e6 :.2f}M parameters")

Model has 71.61M parameters


In [23]:
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

In [24]:
min_loss = torch.inf

In [25]:
model_path = "C:/Users/dhars/Documents/Sagemaker notebooks/GPT1-71M-wiki"

In [27]:
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [26]:
scaler = torch.amp.GradScaler()

In [None]:
for iter_ in tqdm(range(1, max_iters+1), colour='green'):
    model.train()
    
    x,y = get_batches(data=train, batch_size=batch_size, context_length=context_length, device=device)    
    
    with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
        logits, loss = model(tokens=x,targets=y)
        
    optimizer.zero_grad(set_to_none = True)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    
    if iter_%100==0:
        print(f"[{iter_}/{max_iters}]: Train loss: {loss.mean(): .2f}")
    if iter_%eval_interval==0 or iter_==max_iters:
        model.eval()
        eval_losses = torch.zeros(eval_iters)
        
        with torch.no_grad():
            for i in range(eval_iters):
                x,y = get_batches(test, batch_size, context_length, device)
                logits, loss = model(x,y)
                eval_losses[i] = loss
            eval_loss = eval_losses.mean()
            if eval_loss<min_loss:
                min_loss = eval_loss
                print(f"Eval loss improved: {eval_loss: .2f}, saving checkpoint")
                torch.save(model.state_dict(), model_path)

In [43]:
print(f'Min Eval loss: {min_loss: .2f}')

Min Eval loss:  1.16


In [29]:
def get_gpt_response(prompt: str, max_new_tokens: int =400, use_cache = True) -> str:
    model.eval()
    
    prompt = torch.tensor(tokenizer.tokenize(prompt), device=device).unsqueeze(0)
    out = tokenizer.decode(model.generate(prompt, max_new_tokens=max_new_tokens, use_cache=use_cache)[0].tolist())
    
    return out

In [66]:
print(get_gpt_response("Alan Turing was ", max_new_tokens=500))

Alan Turing was born as the 24th president of Sony. He made captains during public mosquititions from a private model during a bridge burning burning in 1854. Its market was high in World War II. It was unnocciated in Srvestle. This was called "GBD!". Missing transportation. The 28th Century replaced the Public War. The security is when hands so that the public British was adopted by a variety of books called ""The World "Hell: Editzure Eug" talking to "Du Checkpins"". Where gy not of the their mechix ovar jert


### KV cache speed up on CPU

In [334]:
timeit(lambda : get_gpt_response("Indian subcontinent", max_new_tokens=100), number=5)

40.247228200081736

In [335]:
timeit(lambda : get_gpt_response("Indian subcontinent", max_new_tokens=100, use_cache=False), number=5)

161.07912310003303

### KV cache speed up on GPU

In [29]:
timeit(lambda : get_gpt_response("Indian subcontinent", max_new_tokens=750), number=5)

95.50981409987435

In [28]:
timeit(lambda : get_gpt_response("Indian subcontinent", max_new_tokens=750, use_cache=False), number=5)

106.67140390002169