In [113]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, RandomSampler, DistributedSampler

In [21]:
data_dir = os.path.join(os.path.dirname(os.getcwd()), "Data/Tiny shakespeare/input.txt")

In [22]:
with open(data_dir, 'r') as f:
    text = f.read()

In [203]:
vocab = sorted(list(set(text)))
vocab_size = len(sorted(list(set(text)))) 

# Hyperparameters
batch_size = 4 #B
block_size = 10 #T
emb_dim = 32 #C

if torch.cuda.is_available():
    device = "cuda"
elif torch.has_mps:
    device = "mps"
else:
    device = "cpu"


In [204]:
token_encodings = {}
token_decodings = {}
for i, token in enumerate(vocab):
    token_encodings[token] = i
    token_decodings[i] = token

In [205]:
def encode(txt):
    enc_char = [token_encodings[char] for char in txt]
    return enc_char

def decode(enc_tokens):
    dec_char = [token_decodings[idx] for idx in enc_tokens]
    # decoded_str = "".join(dec_char);
    return dec_char

def generate_batch(batch_size, block_size):
    idx = torch.randint(0, vocab_size - max_tokens - 1, (batch_size,))
    data = torch.tensor(
        [encode(text[i : i + block_size]) for i in idx], device=device
    ) # B x T 
    targets = torch.tensor(
        [encode(text[i + 1 : i + block_size + 1]) for i in idx], device=device
    ) # B x T 
    return data, targets

token_emb_table = nn.Embedding(vocab_size, emb_dim, device=device)

In [206]:
data, targets = generate_batch(batch_size, block_size)

In [218]:
class GPT(nn.Module):
    def __init__(self, vocab_size, emb_dim):
        super().__init__()
        self.token_emb_table = nn.Embedding(vocab_size, vocab_size, device=device)
        self.net = nn.Sequential(nn.Linear(vocab_size, emb_dim), nn.ReLU(), nn.Linear(emb_dim, vocab_size), nn.ReLU())
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim

    def forward(self, x, targets=None):
        x = self.token_emb_table(x) # B, T, C
        logits = self.net(x)
        B, T, C = logits.shape
        loss = nn.functional.cross_entropy(logits.view(B*T, C), targets.view(B*T)) if targets is not None else None 
        return logits, loss
    
    def generate(self, idx, block_size):
        for _ in range(block_size):
            logits, loss = self.forward(idx)
            logits = logits[:, -1, :]
            probabs = nn.functional.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probabs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx


    def train(self, num_steps, batch_size):
        optimizer = torch.optim.AdamW(self.parameters(), lr=3e-4, betas=(0.9, 0.999))
        for step in range(num_steps):
            optimizer.zero_grad()
            data, targets = generate_batch(batch_size, block_size)
            logits, loss = self.forward(data, targets)
            loss.backward()
            optimizer.step()
            if (step+1) % 100 == 0:
                print(f"Step {step}, loss {loss.item()}")


In [221]:
class SelfAttention(nn.Module):
    def __init__(self, emb_size, head_size):
        super().__init__()
        self.emb_size = emb_size
        self.head_size = head_size
        self.q = nn.Linear(emb_size, self.head_size)
        self.k = nn.Linear(emb_size, self.head_size)
        self.v = nn.Linear(emb_size, self.head_size)
    
    def forward(self, x):
        q = self.q(x) # B, T, C -> B, T, H
        k = self.k(x)
        v = self.v(x)
        wei = q @ k.T / np.sqrt(self.head_size) # B, T, H @ B, H, T -> B, T, T
        wei = nn.functional.softmax(wei, dim=-1)
        logits = wei @ v # B, T, H  
        return logits

In [222]:
class MultiHeadedAttention(nn.Module):
    def __init__(self, n_heads, emb_size):
        super().__init__()
        self.n_heads = n_heads
        self.emb_size = emb_size
        self.head_size = emb_size // n_heads
    
    def forward(self, x):
        logits = []
        for i in range(self.n_heads):
            att_head = SelfAttention(self.emb_size, self.head_size)
            logits.append(att_head(x))
        logits = torch.stack(logits, dim=-1)
        return logits


In [224]:
sa = SelfAttention(emb_dim, 4)

In [225]:
sa(data)

RuntimeError: MPS device does not support linear for non-float inputs

In [210]:
gpt = GPT(vocab_size, emb_dim).to(device)
logits, loss = gpt(data, targets)

In [217]:
print(*(decode(gpt.generate(torch.zeros((1,1), dtype=torch.long, device=device), block_size=100)[0].tolist())))


 Z x ? u V R ' Z z e d   w e   w e f o L u n B e K Q - : L V Z q : g e   X P x y   w e d   s p e y e n y   S s p e   h e f D y   ? q k 3 e r c e n . c e r P A G A   C 3 f C ? I   C J q c e n y   w e  


In [214]:
gpt.train(200, 10)

Step 99, loss 2.713660955429077
Step 199, loss 2.395210027694702
