In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, RandomSampler, DistributedSampler

In [2]:
data_dir = os.path.join(os.path.dirname(os.getcwd()), "Data/Tiny shakespeare/input.txt")

In [3]:
with open(data_dir, 'r') as f:
    text = f.read()

In [4]:
vocab = sorted(list(set(text)))
vocab_size = len(sorted(list(set(text)))) 
data_size = len(text)
# Hyperparameters
batch_size = 40 #B
block_size = 200 #T
emb_size = 256 #C
num_blocks = 4
num_heads = 4
head_size = 512

if torch.cuda.is_available():
    device = "cuda"
elif torch.has_mps:
    device = "mps"
else:
    device = "cpu"


In [5]:
token_encodings = {}
token_decodings = {}
for i, token in enumerate(vocab):
    token_encodings[token] = i
    token_decodings[i] = token

In [6]:
def encode(txt):
    enc_char = [token_encodings[char] for char in txt]
    return enc_char

def decode(enc_tokens):
    dec_char = [token_decodings[idx] for idx in enc_tokens]
    decoded_str = "".join(dec_char)
    return decoded_str

def generate_batch(batch_size, block_size):
    idx = torch.randint(0, data_size - block_size - 1, (batch_size,))
    data = torch.tensor(
        [encode(text[i : i + block_size]) for i in idx], device=device
    ) # B x T 
    targets = torch.tensor(
        [encode(text[i + 1 : i + block_size + 1]) for i in idx], device=device
    ) # B x T 
    return data, targets

In [7]:
data, targets = generate_batch(batch_size, block_size)
# print([decode(data[i].cpu().numpy()) for i in range(data.shape[0])])

In [13]:
test = open(data_dir, 'r').read()
len(test)

1115394

In [8]:
class ShakespeareDataset(Dataset):
    def __init__(self, data_dir, train=True):
        super().__init__()
        self.data = open(data_dir, 'r').read()
        train_dataset, val_dataset = torch.nn.utils.data.random_split(self.data, [int(len(self.data) * 0.8), len(self.data) - int(len(self.data) * 0.8)])
        if train:
            self.dataset = train_dataset
        else:
            self.dataset = val_dataset
    
    def __getitem__(self, idx):
        # idx = torch.randint(0, data_size - block_size - 1, (batch_size,))
        data = torch.tensor(
            [encode(self.dataset[i : i + block_size]) for i in idx], device=device
        ) # B x T 
        targets = torch.tensor(
            [encode(self.dataset[i + 1 : i + block_size + 1]) for i in idx], device=device
        ) # B x T 
        return data, targets
    
    def __len__(self):
        return len(self.dataset)

In [None]:
dataset = ShakespeareDataset(data_dir)
sampler = DistributedSampler(dataset, num_replicas=1, rank=0)
dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

In [9]:
class SelfAttention(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.emb_size = emb_size
        self.head_size = head_size
        self.q = nn.Linear(emb_size, self.head_size, device=device)
        self.k = nn.Linear(emb_size, self.head_size, device=device)
        self.v = nn.Linear(emb_size, self.head_size, device=device)
    
    def forward(self, x):
        q = self.q(x) # B, T, C -> B, T, H
        k = self.k(x)
        v = self.v(x)
        B, T, H = q.shape
        wei = q @ k.transpose(-1, -2) / np.sqrt(self.head_size) # B, T, H @ B, H, T -> B, T, T
        # print(wei.shape)
        mask = torch.tril(torch.ones(B, T, T)).to(device)
        wei = wei.masked_fill(mask == 0, float('-inf'))
        wei = nn.functional.softmax(wei, dim=-1)
        out = wei @ v # B, T, H  
        return out

In [10]:
class MultiHeadedAttention(nn.Module):
    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.emb_size = emb_size
        self.head_size = emb_size // n_heads
    
    def forward(self, x):
        out = []
        for i in range(self.n_heads):
            att_head = SelfAttention(self.head_size)
            out.append(att_head(x))
        # print(len(out), out[0].shape)
        logits = torch.cat(out, dim=-1)
        return logits


In [None]:
class FeedForwardBlock(nn.Module):
    def __init__(self, num_blocks):
        super().__init__()

In [16]:
class GPT(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_emb_table = nn.Embedding(vocab_size, emb_size, device=device)
        self.pos_emb_table = nn.Embedding(block_size, emb_size, device=device)
        self.ff_net = nn.Linear(emb_size, emb_size, device=device)
        self.vocab_size = vocab_size
        self.num_heads = num_heads
        self.head_size = head_size
        self.mha = MultiHeadedAttention(self.num_heads)
        self.layer_norm = nn.LayerNorm(emb_size, device=device, dtype=torch.float32)
        self.final_ll = nn.Linear(emb_size, vocab_size, device=device)

    def forward(self, x, targets=None):
        token_emb = self.token_emb_table(x) # B, T, C
        # print(x.shape, token_emb.shape)
        pos_emb = self.pos_emb_table(torch.arange(x.shape[-1], device=device)) # T, C
        x = token_emb + pos_emb # B, T, C
        for _ in range(num_blocks):
            x_res = x
            x = self.layer_norm(x)
            x = self.mha(x) # B, T, C
#             x = x_res + x
#             x_res = x
            x = x_res + self.ff_net(x) # B, T, vocab_size
        logits = self.final_ll(x)
        # print(logits.shape)
        B, T, C = logits.shape
        if targets is not None:
            loss_fn = torch.nn.CrossEntropyLoss()
            # targets = self.token_emb_table(targets)
            loss = loss_fn(logits.view(B*T, C), targets.view(B*T))
        else:
            loss = None
        
        return logits, loss
    
    def generate(self, idx, max_tokens):
        for _ in range(max_tokens):
            idx_slice = idx[:, -block_size:]
            logits, loss = self.forward(idx_slice)
            logits = logits[:, -1, :]
            probabs = nn.functional.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probabs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
            
        return decode(idx[0].tolist())


    def train(self, num_steps, batch_size):
        optimizer = torch.optim.AdamW(self.parameters(), lr=3e-4, betas=(0.9, 0.999))
        for step in range(num_steps):
            optimizer.zero_grad()
            data, targets = generate_batch(batch_size, block_size)
            logits, loss = self.forward(data, targets)
            loss.backward()
            optimizer.step()
            if (step+1) % 10 == 0:
                print(f"Step {step}, loss {loss.item()}")


In [17]:
gpt = GPT(vocab_size).to(device)

In [38]:
np.sum([p.numel() for p in gpt.parameters()])

150849

In [25]:
logits, loss = gpt(data, targets)

In [None]:
print(gpt.generate(torch.zeros((1,1), dtype=torch.long, device=device), max_tokens=60))


In [None]:
gpt.train(1000, 100)

In [None]:
logits, loss = gpt(data, targets)

In [None]:
import matplotlib.pyplot as plt
plt.colorbar(plt.imshow(logits[:, -1, :].detach().cpu()))