In [18]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import tiktoken

In [2]:
# Report whether CUDA is usable and, if it is, the name of the current GPU.
# torch.cuda.get_device_name() raises on machines without a usable CUDA
# device, so it is only queried when CUDA is available; otherwise the second
# element of the displayed tuple is None and the notebook keeps running.
cuda_available = torch.cuda.is_available()
cuda_available, (torch.cuda.get_device_name() if cuda_available else None)

(True, 'NVIDIA T1200 Laptop GPU')

In [3]:
# Read the entire text file into memory as one UTF-8 decoded string.
with open("./the_things.txt", encoding="utf-8") as file:
    raw_text = file.read()

# Length of the loaded string (shown as the cell output).
len(raw_text)

39181

In [4]:
# Load tiktoken's "gpt2" encoding; its `n_vocab` later sizes the model vocabulary.
tokenizer = tiktoken.get_encoding("gpt2")

In [5]:
# Encode the raw text into token ids; the cell output is the number of tokens.
token_ids = tokenizer.encode(raw_text)
len(token_ids)

9275

In [6]:
class BookDataset(Dataset):
    """Sliding-window dataset of (input, target) token-id pairs.

    Every sample is a window of ``max_length`` consecutive token ids paired
    with the same window shifted one position to the right, so the target at
    position ``t`` is the token that follows input position ``t``.

    Args:
        token_ids: Sliceable sequence of integer token ids (e.g. a list).
        max_length: Number of tokens in each input and each target window.
        stride: Distance between the start positions of consecutive windows.
    """

    def __init__(self, token_ids, max_length, stride):
        # Stopping before ``len(token_ids) - max_length`` leaves room for the
        # shifted target window, so inputs and targets are always full-length.
        window_starts = range(0, len(token_ids) - max_length, stride)

        self.input_ids = [
            torch.tensor(token_ids[start : start + max_length])
            for start in window_starts
        ]
        self.target_ids = [
            torch.tensor(token_ids[start + 1 : start + 1 + max_length])
            for start in window_starts
        ]

    def __len__(self):
        """Return the number of windows in the dataset."""
        return len(self.input_ids)

    def __getitem__(self, index):
        """Return the ``(input_ids, target_ids)`` tensor pair at ``index``."""
        pair = (self.input_ids[index], self.target_ids[index])
        return pair

In [7]:
# Windows of 256 tokens whose start positions are 128 tokens apart.
book_ds = BookDataset(token_ids=token_ids, max_length=256, stride=128)

In [8]:
# Shuffled mini-batches of 4 windows; an incomplete final batch is dropped.
book_loader = DataLoader(
    dataset=book_ds,
    batch_size=4,
    shuffle=True,
    drop_last=True,
)

In [15]:
# Hyperparameters consumed by the model classes defined in this notebook.
GPT_CONFIG = dict(
    vocab_size=tokenizer.n_vocab,  # rows of the token embedding / width of the output head
    context_length=1024,  # rows of the position embedding and side of the causal mask
    embed_dim=768,  # width of every token representation
    num_heads=12,  # attention heads per block (must divide embed_dim)
    num_layers=12,  # number of stacked transformer blocks
    dropout_rate=0.1,  # probability used by every nn.Dropout
    qkv_bias=False,  # whether the query/key/value projections have a bias
)

In [10]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head scaled dot-product self-attention.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Total output size across all heads; must be divisible by
            ``num_heads``.
        context_length: Side length of the pre-computed causal mask.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query/key/value projections carry a bias.
    """

    def __init__(
        self, in_dim, out_dim, context_length, dropout, num_heads, qkv_bias=False
    ):
        super().__init__()
        assert out_dim % num_heads == 0, "out_dim must be divisible by num_heads"

        self.out_dim = out_dim
        self.num_heads = num_heads
        self.head_dim = out_dim // num_heads

        # Separate projections for queries, keys and values.
        self.query_weight = nn.Linear(in_dim, out_dim, bias=qkv_bias)
        self.key_weight = nn.Linear(in_dim, out_dim, bias=qkv_bias)
        self.value_weight = nn.Linear(in_dim, out_dim, bias=qkv_bias)

        # Mixes the concatenated head outputs.
        self.output_projection = nn.Linear(out_dim, out_dim)
        self.dropout = nn.Dropout(dropout)

        # Ones strictly above the diagonal mark the positions a token may not
        # attend to. Stored as a buffer so it follows the module across devices.
        future_positions = torch.triu(
            torch.ones(context_length, context_length), diagonal=1
        )
        self.register_buffer("mask", future_positions)

    def _split_heads(self, projected, batch_size, num_tokens):
        """Reshape (batch, tokens, out_dim) into (batch, heads, tokens, head_dim)."""
        per_head = projected.view(
            batch_size, num_tokens, self.num_heads, self.head_dim
        )
        return per_head.transpose(1, 2)

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (batch, tokens, in_dim).

        Returns:
            Tensor of shape (batch, tokens, out_dim).
        """
        batch_size, num_tokens, _ = x.shape

        k = self._split_heads(self.key_weight(x), batch_size, num_tokens)
        q = self._split_heads(self.query_weight(x), batch_size, num_tokens)
        v = self._split_heads(self.value_weight(x), batch_size, num_tokens)

        # (batch, heads, tokens, tokens) raw similarity of each query to each key.
        scores = q @ k.transpose(-2, -1)

        # Crop the mask to the current sequence length and hide future tokens.
        causal_mask = self.mask.bool()[:num_tokens, :num_tokens]
        scores = scores.masked_fill(causal_mask, -torch.inf)

        # Scale by sqrt(head_dim) before the softmax; -inf entries become 0.
        scale = self.head_dim ** 0.5
        weights = self.dropout(torch.softmax(scores / scale, dim=-1))

        # Move heads back next to head_dim and merge them into out_dim.
        merged = (weights @ v).transpose(1, 2).contiguous()
        merged = merged.view(batch_size, num_tokens, self.out_dim)

        return self.output_projection(merged)

In [11]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with learnable scale and shift.

    Each vector along the last dimension is normalized to zero mean and unit
    (biased) variance, then multiplied by ``scale`` and offset by ``shift``.

    Args:
        embed_dim: Size of the last dimension of the inputs.
    """

    def __init__(self, embed_dim):
        super().__init__()

        # Added to the variance to avoid division by zero.
        self.epsilon = 1e-5
        self.scale = nn.Parameter(torch.ones(embed_dim))
        self.shift = nn.Parameter(torch.zeros(embed_dim))

    def forward(self, x):
        """Normalize ``x`` along its last dimension; the shape is unchanged."""
        x_mean = x.mean(dim=-1, keepdim=True)
        # Layer normalization uses the biased (divide-by-N) variance. The
        # default of Tensor.var is the unbiased estimator, which deviates from
        # torch.nn.LayerNorm and yields NaN when the last dimension has size 1.
        x_variance = x.var(dim=-1, keepdim=True, unbiased=False)
        x_norm = (x - x_mean) / torch.sqrt(x_variance + self.epsilon)

        return self.scale * x_norm + self.shift

In [12]:
class GELU(nn.Module):
    """GELU activation computed with the tanh approximation.

    ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Apply the activation element-wise; the output has the shape of ``x``."""
        sqrt_two_over_pi = torch.sqrt(torch.tensor(2.0 / torch.pi))
        cubic_term = x + 0.044715 * torch.pow(x, 3)
        gate = 1 + torch.tanh(sqrt_two_over_pi * cubic_term)
        return 0.5 * x * gate

In [13]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: expand to 4x width, GELU, project back.

    Args:
        embed_dim: Size of the last dimension of both input and output.
    """

    def __init__(self, embed_dim):
        super().__init__()

        hidden_dim = 4 * embed_dim
        expand = nn.Linear(embed_dim, hidden_dim)
        activation = GELU()
        contract = nn.Linear(hidden_dim, embed_dim)

        self.layers = nn.Sequential(expand, activation, contract)

    def forward(self, x):
        """Run ``x`` through the expand -> activation -> contract stack."""
        transformed = self.layers(x)
        return transformed

In [16]:
class Transformer(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each with a residual.

    Each sub-layer is computed as ``x + dropout(sublayer(norm(x)))``.

    Args:
        config: Mapping with the keys ``embed_dim``, ``context_length``,
            ``dropout_rate``, ``num_heads`` and ``qkv_bias``.
    """

    def __init__(self, config):
        super().__init__()

        self.mha = MultiHeadAttention(
            in_dim=config["embed_dim"],
            out_dim=config["embed_dim"],
            context_length=config["context_length"],
            dropout=config["dropout_rate"],
            num_heads=config["num_heads"],
            qkv_bias=config["qkv_bias"],
        )

        self.ffn = FeedForward(config["embed_dim"])

        self.norm1 = LayerNorm(config["embed_dim"])
        self.norm2 = LayerNorm(config["embed_dim"])

        # nn.Dropout holds no parameters, so one instance serves both branches.
        self.shortcut_dropout = nn.Dropout(config["dropout_rate"])

    def forward(self, x):
        """Apply the block to ``x``; the output has the same shape as ``x``."""
        # Attention sub-layer with residual connection.
        shortcut = x
        x = self.norm1(x)
        x = self.mha(x)
        x = self.shortcut_dropout(x)
        # Out-of-place add: never mutates a tensor that the caller or autograd
        # may still reference.
        x = x + shortcut

        # Feed-forward sub-layer with residual connection.
        shortcut = x
        x = self.norm2(x)
        x = self.ffn(x)
        # The dropout module must be *called*; assigning the module itself to
        # ``x`` makes the following addition fail with a TypeError.
        x = self.shortcut_dropout(x)
        x = x + shortcut

        return x

In [17]:
class BookGPT(nn.Module):
    """GPT-style language model: embeddings, transformer stack, norm, output head.

    Args:
        config: Mapping with the keys ``vocab_size``, ``embed_dim``,
            ``context_length``, ``dropout_rate`` and ``num_layers`` (plus
            whatever ``Transformer`` reads from the same mapping).
    """

    def __init__(self, config):
        super().__init__()

        vocab_size = config["vocab_size"]
        embed_dim = config["embed_dim"]

        # Learned lookup tables for token identity and absolute position.
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(config["context_length"], embed_dim)

        self.dropout = nn.Dropout(config["dropout_rate"])

        blocks = [Transformer(config) for _ in range(config["num_layers"])]
        self.transformers = nn.Sequential(*blocks)

        self.norm = LayerNorm(embed_dim)
        # Projects each final hidden state onto vocabulary logits.
        self.out_head = nn.Linear(embed_dim, vocab_size, bias=False)

    def forward(self, x):
        """Map token ids of shape (batch, seq_len) to logits (batch, seq_len, vocab_size)."""
        _, seq_len = x.shape

        # Position indices are created on the input's device so the lookup
        # works for CPU and GPU inputs alike.
        positions = torch.arange(seq_len, device=x.device)

        # The (seq_len, embed_dim) position embeddings broadcast over the batch.
        hidden = self.token_embedding(x) + self.position_embedding(positions)
        hidden = self.dropout(hidden)
        hidden = self.transformers(hidden)
        hidden = self.norm(hidden)

        return self.out_head(hidden)