In [1]:
import torch
from torch import tensor
import torch.nn as nn
from torch.nn import functional as F

In [2]:
torch.manual_seed(1111)
gpu = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

batch_size = 32
context_window = 256
max_iter = 500
eval_iters = 100
eval_interval = 50
lr = 3e-4
n_embeds = 384
num_heads = 6
n_layer = 6
dropout = 0.2

In [3]:
with open('./data/shakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()

vocab = sorted(list(set(text)))
vocab_len = len(vocab)

stoi = {c:i for i,c in enumerate(vocab)}
itos = {i:c for i,c in enumerate(vocab)}
encode = lambda x: [stoi[c] for c in x]
decode = lambda x: ''.join([itos[i] for i in x])

In [4]:
data = torch.tensor(encode(text), dtype=torch.long)
index = int(0.9*len(data))

train_data = data[:index]
test_data = data[index:]

In [5]:
def data_batcher(split_type):
    data = train_data if split_type == "train" else test_data
    
    ix = torch.randint(len(data) - context_window, (batch_size,))
    
    context = torch.stack([data[i:i+context_window] for i in ix])
    target = torch.stack([data[i+1:i+context_window+1] for i in ix])
    
    context, target = context.to(gpu), target.to(gpu)
    
    return context, target

In [6]:
class SelfAttentionHead(nn.Module):
    
    def __init__(self, head_size):
        super().__init__()

        self.key = nn.Linear(n_embeds, head_size, bias=False)
        self.query = nn.Linear(n_embeds, head_size, bias=False)
        self.value = nn.Linear(n_embeds, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(context_window, context_window)))
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        
        B, T, C = x.shape
        
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)
        
        weights = q @ k.transpose(-2, -1) * C**-0.5
        weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        weights = F.softmax(weights, dim=-1)
        weights = self.dropout(weights)
        
        xbow = weights @ v
        return xbow 
    
class MHA(nn.Module):
    
    def __init__(self, head_size, num_heads):
        super().__init__()
        self.heads = nn.ModuleList((SelfAttentionHead(head_size) for _ in range(num_heads)))
        self.proj = nn.Linear(n_embeds, n_embeds)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        
        return out
    
class FeedForward(nn.Module):
    
    def __init__(self, n_embeds):
        super().__init__()
        
        self.net = nn.Sequential(
            nn.Linear(n_embeds, 4 * n_embeds),
            nn.ReLU(),
            nn.Linear(4 * n_embeds, n_embeds),
            nn.Dropout(dropout),
        )
        
    def forward(self, x):
        return self.net(x)
    
class Block(nn.Module):
    
    def __init__(self, n_embeds, num_heads):
        super().__init__()
        
        head_size = n_embeds // num_heads
        self.self_att = MHA(num_heads, head_size)
        self.ffn = FeedForward(n_embeds)
        self.ln1 = nn.LayerNorm(n_embeds)
        self.ln2 = nn.LayerNorm(n_embeds)
        
    def forward(self, x):
        x = x + self.self_att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x

class ScratchGPT(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.embeddings = nn.Embedding(vocab_len, n_embeds)
        self.positions = nn.Embedding(context_window, n_embeds)
        self.blocks = nn.Sequential(*[Block(n_embeds, num_heads) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embeds)
        self.lm_head = nn.Linear(n_embeds, vocab_len)
        
        self.apply(self.init_weights)
        
    def init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, context, targets=None):
        B, T = context.shape
        
        token_embeds = self.embeddings(context)
        pos_embeds = self.positions(torch.arange(T, device=gpu))
        x = token_embeds + pos_embeds
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        
        if not targets is None:
            logits = logits.view(-1, vocab_len)
            targets = targets.view(-1)
            loss = F.cross_entropy(logits, targets)
        else:
            loss = None
        
        return logits, loss
    
    def generate(self, context, max_preds):
        for _ in range(max_preds):
            context = context[:, -context_window:]
            
            logits, loss = self(context)
            logits = logits[:, -1, :]
            
            prob = F.softmax(logits, dim=-1)
            
            next_context = torch.multinomial(prob, num_samples=1)
            context = torch.cat((context, next_context), dim=1)
            
        return context
    
    @torch.no_grad()
    def calc_loss(self):
        LM.eval()
        output = {}
        for split in ['train', 'test']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                context, targets = data_batcher(split)
                logits, loss = LM(context, targets)
                losses[k] = loss.item()
            output[split] = losses.mean()
        LM.train()
        return output
    
    def trainLM(self):
        optimizer = torch.optim.AdamW(LM.parameters(), lr=lr)
            
        for iter in range(max_iter):
            
            if iter+1 % eval_interval == 0:
                losses = self.calc_loss()
                print(f"[{iter+1}]: train loss = {losses['train']}, eval loss = {losses['test']}")
                
            context, targets = data_batcher("train")
            logits, loss = self(context, targets)
            
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

In [7]:
LM = ScratchGPT()
LM = LM.to(gpu)
LM.trainLM()

[1]: train loss = 4.37258768081665, eval loss = 4.363542079925537


In [8]:
context = torch.zeros((1, 1), dtype=torch.long, device=gpu)
print(decode(LM.generate(context, 500)[0].tolist()))

unof tha ive'sHM, weZ:
OI
Cu n ofosandvemome i;d, bewit sinZ and iqWhe, cy co tind bomy yomend nghedSIurk:

Yosnd t nofingheisker'lnVWhe'sze,
:
LGcBuco i'd:
HOLEUMnon me d shor s ound teres ltouthe o s the bisbus bGe toed alvou xIOI avess o thoh gh,
CRWheng
