In [31]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, RandomSampler, DistributedSampler

In [32]:
# Location of the Tiny Shakespeare corpus, resolved relative to the parent of
# the current working directory. Despite its name, this is a file path.
data_dir = os.path.join(
    os.path.dirname(os.getcwd()),
    "Data/Tiny shakespeare/input.txt",
)

In [33]:
# Read the whole corpus into memory. The encoding is passed explicitly so that
# decoding (and therefore the character vocabulary built from `text`) does not
# depend on the platform's default locale encoding.
with open(data_dir, "r", encoding="utf-8") as f:
    text = f.read()

In [41]:
# Character-level vocabulary: every distinct character of the corpus, sorted so
# that the character -> id assignment is the same on every run.
vocab = sorted(set(text))
vocab_size = len(vocab)  # reuse the list instead of rebuilding and re-sorting the set
data_size = len(text)

# Hyperparameters
batch_size = 4  # B: sequences per batch
block_size = 10  # T: tokens per sequence (context length)
emb_size = 32  # C: embedding width

# Pick the best available backend: CUDA first, then Apple's Metal backend
# (MPS), and finally the CPU.
# `torch.has_mps` is deprecated and only reports whether PyTorch was *built*
# with MPS support; `torch.backends.mps.is_available()` also checks that the
# current machine can actually use it.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"


In [35]:
# Lookup tables between characters and integer token ids. A character's id is
# its position in the sorted vocabulary.
token_encodings = {token: i for i, token in enumerate(vocab)}  # char -> id
token_decodings = dict(enumerate(vocab))  # id -> char

In [50]:
def encode(txt):
    """Translate text into a list of integer token ids.

    Args:
        txt: Iterable of characters (typically a ``str``). Every character
            must be a key of ``token_encodings``.

    Returns:
        list: One token id per character, in input order.

    Raises:
        KeyError: If a character is not present in ``token_encodings``.
    """
    return list(map(token_encodings.__getitem__, txt))

def decode(enc_tokens):
    """Translate a sequence of token ids back into text.

    Args:
        enc_tokens: Iterable of token ids; each must be a key of
            ``token_decodings``.

    Returns:
        str: The characters for the given ids, concatenated in order.

    Raises:
        KeyError: If an id is not present in ``token_decodings``.
    """
    return "".join(token_decodings[token_id] for token_id in enc_tokens)

def generate_batch(batch_size, block_size):
    """Sample a random batch of training windows from the corpus.

    Each sample is a window of ``block_size`` characters of ``text``; its
    target is the same window shifted one character to the right, so the
    model learns to predict the next character at every position.

    Args:
        batch_size: Number of windows to sample (B).
        block_size: Number of tokens per window (T).

    Returns:
        tuple: ``(data, targets)``, two ``B x T`` tensors of token ids placed
        on ``device``.
    """
    # A start offset i is usable while text[i + 1 : i + block_size + 1] still
    # fits in the corpus, i.e. i <= data_size - block_size - 1. The upper bound
    # of torch.randint is exclusive, so it has to be data_size - block_size for
    # the last valid window to be reachable.
    # .tolist() yields plain Python ints, which is what string slicing expects.
    starts = torch.randint(0, data_size - block_size, (batch_size,)).tolist()
    data = torch.tensor(
        [encode(text[i : i + block_size]) for i in starts], device=device
    )  # B x T
    targets = torch.tensor(
        [encode(text[i + 1 : i + block_size + 1]) for i in starts], device=device
    )  # B x T
    return data, targets

In [58]:
# Draw one batch and print the decoded input windows as a quick sanity check.
data, targets = generate_batch(batch_size, block_size)
print([decode(row.cpu().numpy()) for row in data])

[' caps and ', ' were they', 'his chambe', 'urs, the b']


In [59]:
# Attention configuration read by the GPT model defined further down.
# Note: MultiHeadedAttention derives its per-head width from
# emb_size // n_heads, so `head_size` is only stored on the model.
num_heads, head_size = 4, 64

In [60]:
class SelfAttention(nn.Module):
    """A single head of causal (masked) scaled dot-product self-attention.

    Args:
        head_size: Width H of the query/key/value projections.
        in_features: Width C of the incoming embeddings. Defaults to the
            notebook-level ``emb_size``.
        layer_device: Device on which the projection layers are created.
            Defaults to the notebook-level ``device``.
    """

    def __init__(self, head_size, in_features=None, layer_device=None):
        super().__init__()
        # Fall back to the notebook-level settings when nothing is passed, so
        # existing ``SelfAttention(head_size)`` calls behave as before.
        self.emb_size = emb_size if in_features is None else in_features
        layer_device = device if layer_device is None else layer_device
        self.head_size = head_size
        self.q = nn.Linear(self.emb_size, self.head_size, device=layer_device)
        self.k = nn.Linear(self.emb_size, self.head_size, device=layer_device)
        self.v = nn.Linear(self.emb_size, self.head_size, device=layer_device)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Embeddings of shape ``(B, T, C)``.

        Returns:
            Tensor of shape ``(B, T, H)``. The output at position ``t`` only
            depends on input positions ``<= t``.
        """
        q = self.q(x)  # B, T, C -> B, T, H
        k = self.k(x)
        v = self.v(x)
        T = x.shape[-2]
        # Scaled dot-product scores: B, T, H @ B, H, T -> B, T, T
        wei = q @ k.transpose(-1, -2) / self.head_size ** 0.5
        # Lower-triangular causal mask, created directly on the input's device.
        # A single (T, T) mask broadcasts over the batch dimension, so there is
        # no need to materialise B copies or to copy it from the CPU each call.
        mask = torch.tril(torch.ones(T, T, device=x.device))
        wei = wei.masked_fill(mask == 0, float("-inf"))
        wei = nn.functional.softmax(wei, dim=-1)
        return wei @ v  # B, T, H

In [61]:
class MultiHeadedAttention(nn.Module):
    """Several causal self-attention heads run side by side.

    The embedding width is split evenly between the heads
    (``head_size = emb_size // n_heads``) and the head outputs are concatenated
    along the last dimension, giving an output of width
    ``n_heads * head_size``.

    Args:
        n_heads: Number of attention heads.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.emb_size = emb_size
        self.head_size = emb_size // n_heads
        # The heads must be created once, here, and stored in an nn.ModuleList.
        # Creating them inside forward() would give freshly randomised weights
        # on every call and keep them out of self.parameters(), so the
        # optimizer could never train the attention layers.
        self.heads = nn.ModuleList(
            [SelfAttention(self.head_size) for _ in range(n_heads)]
        )

    def forward(self, x):
        """Run every head on ``x`` and concatenate the results.

        Args:
            x: Embeddings of shape ``(B, T, C)``.

        Returns:
            Tensor of shape ``(B, T, n_heads * head_size)``.
        """
        head_outputs = [head(x) for head in self.heads]
        return torch.cat(head_outputs, dim=-1)


In [132]:
class GPT(nn.Module):
    """Minimal character-level language model.

    Token and position embeddings are summed, passed through one multi-head
    causal self-attention layer and projected to vocabulary logits.

    Args:
        vocab_size: Number of distinct tokens.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.token_emb_table = nn.Embedding(vocab_size, emb_size, device=device)
        self.pos_emb_table = nn.Embedding(block_size, emb_size, device=device)
        self.ff_net = nn.Linear(emb_size, vocab_size, device=device)
        self.vocab_size = vocab_size
        self.num_heads = num_heads
        self.head_size = head_size
        self.mha = MultiHeadedAttention(self.num_heads)

    def forward(self, x, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            x: Token ids of shape ``(B, T)`` with ``T <= block_size``.
            targets: Optional token ids of shape ``(B, T)``.

        Returns:
            tuple: ``(logits, loss)`` where ``logits`` has shape
            ``(B, T, vocab_size)`` and ``loss`` is the mean cross-entropy, or
            ``None`` when ``targets`` is not given.
        """
        token_emb = self.token_emb_table(x)  # B, T, C
        # Positions are created on the input's device rather than on the
        # notebook-level one, so the model follows wherever its inputs live.
        positions = torch.arange(x.shape[-1], device=x.device)
        pos_emb = self.pos_emb_table(positions)  # T, C
        hidden = self.mha(token_emb + pos_emb)  # B, T, C
        logits = self.ff_net(hidden)  # B, T, vocab_size
        if targets is None:
            return logits, None

        B, T, C = logits.shape
        loss_fn = torch.nn.CrossEntropyLoss()
        # reshape (unlike view) also accepts non-contiguous targets.
        loss = loss_fn(logits.reshape(B * T, C), targets.reshape(B * T))
        return logits, loss

    @torch.no_grad()  # sampling needs no gradients; avoids building a graph per token
    def generate(self, idx, max_tokens):
        """Sample ``max_tokens`` new tokens after the prompt ``idx``.

        Args:
            idx: Prompt token ids of shape ``(B, T)``.
            max_tokens: Number of tokens to append.

        Returns:
            str: The decoded text of the first sequence in the batch, prompt
            included.
        """
        for _ in range(max_tokens):
            # The position table only covers block_size positions, so the
            # context is cropped to the most recent block_size tokens.
            idx_slice = idx[:, -block_size:]
            logits, _ = self(idx_slice)
            probabs = nn.functional.softmax(logits[:, -1, :], dim=-1)
            idx_next = torch.multinomial(probabs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)

        return decode(idx[0].tolist())

    def train(self, num_steps=True, batch_size=None):
        """Run the training loop, or switch train/eval mode.

        This method shares its name with ``nn.Module.train(mode)``, which
        ``nn.Module.eval()`` relies on. To keep ``gpt.eval()`` and
        ``gpt.train()`` working, a call without ``batch_size`` is forwarded to
        ``nn.Module.train`` with the first argument taken as ``mode``.

        Args:
            num_steps: Number of optimisation steps (or the boolean ``mode``
                when ``batch_size`` is omitted).
            batch_size: Sequences per training batch.

        Returns:
            ``None`` after training; ``self`` when only the mode is switched.
        """
        if batch_size is None:
            return super().train(num_steps)
        return self._fit(num_steps, batch_size)

    def _fit(self, num_steps, batch_size):
        """Optimise the model with AdamW on randomly sampled batches."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=3e-4, betas=(0.9, 0.999))
        for step in range(num_steps):
            optimizer.zero_grad()
            data, targets = generate_batch(batch_size, block_size)
            _, loss = self(data, targets)
            loss.backward()
            optimizer.step()
            if (step+1) % 10 == 0:
                print(f"Step {step}, loss {loss.item()}")


In [133]:
# Build the model, then move it to the selected device.
gpt = GPT(vocab_size)
gpt = gpt.to(device)

In [134]:
# Forward pass on the sampled batch; passing targets also yields the loss.
logits, loss = gpt(data, targets=targets)

In [135]:
# Show the current value of `loss` as the cell output.
loss

tensor(4.2502, device='mps:0', grad_fn=<NllLossBackward0>)

In [139]:
# Sample 60 tokens, starting from a prompt that holds the single token id 0.
gpt.generate(
    torch.zeros(1, 1, dtype=torch.long, device=device),
    max_tokens=60,
)


"\nHRYhDzg LB ,&s qyN,kb:'T b.tsiWXG3:NlWef:FoohM?heS$CkPHa3e\nt"

In [138]:
# Train for 1000 steps with 5 sequences per batch.
gpt.train(num_steps=1000, batch_size=5)

Step 9, loss 4.026917934417725
Step 19, loss 3.9967408180236816
Step 29, loss 4.072494029998779
Step 39, loss 4.044800281524658
Step 49, loss 3.9608895778656006
Step 59, loss 3.947067975997925
Step 69, loss 4.034868240356445
Step 79, loss 3.88858699798584
Step 89, loss 3.866732120513916
Step 99, loss 3.9286835193634033
Step 109, loss 3.988818407058716
Step 119, loss 3.93001127243042
Step 129, loss 3.9781854152679443
Step 139, loss 3.9018938541412354
Step 149, loss 3.974118709564209
Step 159, loss 3.902158260345459
Step 169, loss 4.003384590148926
Step 179, loss 3.9022223949432373
Step 189, loss 3.9942989349365234
Step 199, loss 3.9868481159210205
Step 209, loss 3.8934073448181152
Step 219, loss 3.9983043670654297
Step 229, loss 3.9016008377075195
Step 239, loss 3.959432601928711
Step 249, loss 3.903341770172119
Step 259, loss 3.885266065597534
Step 269, loss 3.8942017555236816
Step 279, loss 3.904792070388794
Step 289, loss 4.027103900909424
Step 299, loss 3.875434637069702
Step 309, l