In [1]:
import tiktoken
import torch
import torch.nn as nn
from torch.nn import functional as F

In [2]:
# Tokenise the whole corpus with tiktoken's "cl100k_base" encoding.
enc = tiktoken.get_encoding("cl100k_base")
# Decode the corpus explicitly as UTF-8 so the resulting token stream does not
# depend on the platform's default locale encoding (e.g. cp1252 on Windows).
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
tokens = enc.encode(text)
# First token id the encoder produces for a newline character.
padding_token = enc.encode("\n")[0]
# Vocabulary size reported by the encoder; later passed to the model constructor.
n_vocab = enc.n_vocab
len(tokens)

301829

In [112]:
# --- Hyperparameters -------------------------------------------------------
block_size = 8             # tokens per training window (context length)
batch_size = 4             # windows sampled per batch
n_embd = 32                # embedding width
n_heads = n_embd // 4      # one attention head per 4 embedding channels
n_blocks = 4               # number of stacked transformer blocks
# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [220]:
# Scratch check of slice semantics: `[:1]` keeps only the first element.
test = list(range(1, 8))
test[:1]

[1]

In [237]:
# Turn the token list into a tensor and split it 90% / 10%:
# the leading part is used for training, the tail is held out for validation.
data = torch.tensor(tokens, dtype=torch.long)
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

def get_batch(split):
    """Sample a random batch of (input, target) token windows.

    Args:
        split: "train" draws from ``train_data``; any other value draws
            from ``val_data``.

    Returns:
        Tuple ``(x, y)`` of tensors stacked from ``batch_size`` windows of
        ``block_size`` tokens each, both moved to ``device``. ``y`` holds the
        same windows as ``x`` shifted one token to the right.
    """
    source = train_data if split == "train" else val_data
    # randint's upper bound is exclusive, so even the shifted target window
    # (start + 1 .. start + block_size) stays inside `source`.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start + block_size])
        targets.append(source[start + 1:start + block_size + 1])
    x = torch.stack(inputs).to(device)
    y = torch.stack(targets).to(device)
    return x, y

torch.Size([4, 9])
torch.Size([6707, 4, 9])


In [None]:
@torch.no_grad()
def estimate_loss():
    """Average the model's loss over several random batches per split.

    Reads ``model``, ``eval_iters`` and ``get_batch`` from the notebook's
    global namespace and expects ``model(X, Y)`` to return a
    ``(logits, loss)`` pair.
    NOTE(review): ``eval_iters`` is not defined in the cells visible here —
    confirm it is set before this function is called.

    Returns:
        Dict mapping "train" and "val" to the mean loss (a 0-d tensor) over
        ``eval_iters`` batches of that split.
    """
    # Evaluation mode for the measurement; training mode is restored below.
    model.eval()
    mean_losses = {}
    for split in ("train", "val"):
        batch_losses = torch.zeros(eval_iters)
        for step in range(eval_iters):
            inputs, targets = get_batch(split)
            _logits, loss = model(inputs, targets)
            batch_losses[step] = loss.item()
        mean_losses[split] = batch_losses.mean()
    model.train()
    return mean_losses

In [161]:
class Head(nn.Module):
    """A single head of causal (masked) scaled dot-product self-attention.

    Args:
        block_size: Maximum sequence length; sets the size of the causal mask.
        n_embd: Size of the last dimension of the input.
        head_size: Size of the query/key/value projections and of the output's
            last dimension.
    """

    def __init__(self, block_size, n_embd, head_size):
        super().__init__()
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular mask stored as a buffer: it moves with the module
        # (.to / state_dict) but is not a trainable parameter.
        self.register_buffer('tril', torch.tril(torch.ones((block_size, block_size))))

    def forward(self, x):
        """Attend each position to itself and to earlier positions only.

        Args:
            x: Tensor of shape (batch, T, n_embd) with T <= block_size.

        Returns:
            Tensor of shape (batch, T, head_size).
        """
        _, T, _ = x.shape
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        # Scale by 1/sqrt(head_size): the dot product runs over the key
        # dimension, not over the input embedding width.
        wei = (q @ k.mT) * k.shape[-1] ** -0.5
        # Block attention to future positions before normalising.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        return wei @ v

In [162]:
class MultiHeadAttention(nn.Module):
    """Runs several ``Head`` modules side by side and mixes their outputs.

    Args:
        block_size: Maximum sequence length, forwarded to every head.
        n_embd: Embedding width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
    """

    def __init__(self, block_size, n_embd, n_heads):
        super().__init__()
        self.block_size, self.n_embd, self.n_heads = block_size, n_embd, n_heads
        # The embedding width has to split evenly across the heads.
        assert (n_embd % n_heads) == 0
        self.head_size = n_embd // n_heads
        heads = []
        for _ in range(n_heads):
            heads.append(Head(block_size, n_embd, self.head_size))
        self.heads = nn.ModuleList(heads)
        self.projection = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        """Concatenate every head's output on the last axis, then project."""
        per_head = [head(x) for head in self.heads]
        return self.projection(torch.cat(per_head, dim=-1))

In [163]:
class FeedForward(nn.Module):
    """Two-layer MLP: expand to 4 * n_embd, apply ReLU, project back to n_embd.

    Args:
        n_embd: Size of the input's and the output's last dimension.
    """

    def __init__(self, n_embd):
        super().__init__()
        self.n_embd = n_embd
        hidden = 4 * n_embd
        expand = nn.Linear(n_embd, hidden)
        contract = nn.Linear(hidden, n_embd)
        self.net = nn.Sequential(expand, nn.ReLU(), contract)

    def forward(self, x):
        """Apply the MLP to ``x``."""
        return self.net(x)

In [164]:
class TransformerBlock(nn.Module):
    """Residual block: LayerNorm + self-attention, then LayerNorm + feed-forward.

    Args:
        block_size: Maximum sequence length, forwarded to the attention layer.
        n_embd: Embedding width.
        n_heads: Number of attention heads.
    """

    def __init__(self, block_size, n_embd, n_heads):
        super().__init__()
        self.block_size, self.n_embd, self.n_heads = block_size, n_embd, n_heads
        self.self_attention = MultiHeadAttention(block_size, n_embd, n_heads)
        self.feedforward = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Add the attention output, then the feed-forward output, to ``x``."""
        # Each sub-layer sees a normalised input; its output is added back
        # onto the un-normalised residual stream.
        attended = self.self_attention(self.ln1(x))
        x = x + attended
        mlp_out = self.feedforward(self.ln2(x))
        return x + mlp_out

In [172]:
class Transformer(nn.Module):
    """Token-level language model built from a stack of ``TransformerBlock``s.

    Args:
        block_size: Maximum sequence length (number of learned positions).
        n_vocab: Vocabulary size; sizes the token embedding and output layer.
        n_embd: Embedding width.
        n_heads: Attention heads per block.
        n_blocks: Number of stacked transformer blocks.
    """

    def __init__(self, block_size, n_vocab, n_embd, n_heads, n_blocks):
        super().__init__()
        self.block_size = block_size
        self.n_vocab = n_vocab
        self.n_embd = n_embd
        self.n_heads = n_heads
        self.n_blocks = n_blocks

        self.embedding = nn.Embedding(self.n_vocab, self.n_embd)
        self.positional_embedding = nn.Embedding(self.block_size, self.n_embd)
        self.blocks = nn.Sequential(*[TransformerBlock(self.block_size, self.n_embd, self.n_heads) for _ in range(self.n_blocks)])
        self.ln = nn.LayerNorm(self.n_embd)
        self.lm_head = nn.Linear(self.n_embd, self.n_vocab)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: Integer token ids whose last dimension is the sequence
                length T, with T <= block_size.
            targets: Optional integer token ids with the same shape as ``idx``.

        Returns:
            ``logits`` (``idx.shape + (n_vocab,)``) when ``targets`` is None;
            otherwise the tuple ``(logits, loss)`` where ``loss`` is the mean
            cross-entropy between ``logits`` and ``targets``.

        Raises:
            ValueError: If the sequence is longer than ``block_size``.
        """
        T = idx.shape[-1]
        if T > self.block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.block_size}"
            )
        tok_embd = self.embedding(idx)
        # Positions follow the actual sequence length and live on the same
        # device as the input, so short prompts and GPU inputs both work.
        positions = torch.arange(T, device=idx.device)
        pos_embd = self.positional_embedding(positions)
        x = tok_embd + pos_embd
        x = self.blocks(x)
        x = self.ln(x)
        logits = self.lm_head(x)

        if targets is None:
            return logits
        # cross_entropy wants (N, classes) logits and (N,) class indices.
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)), targets.reshape(-1)
        )
        return logits, loss

In [173]:
# Build the model from the notebook's hyperparameters and smoke-test one forward pass.
# NOTE(review): `batches` is not defined in any cell visible here — confirm where it comes from.
model = Transformer(
    block_size=block_size,
    n_vocab=n_vocab,
    n_embd=n_embd,
    n_heads=n_heads,
    n_blocks=n_blocks,
)
model(batches[0])

tensor([[[ 1.2706, -0.1236,  1.1974,  ...,  0.4005, -0.2644, -0.2203],
         [-0.4871, -0.3172, -0.0730,  ..., -0.4406,  1.2822,  0.0958],
         [-0.2143, -0.2342, -0.5432,  ..., -1.0173,  0.4597,  0.1983],
         ...,
         [-0.8043, -0.4944,  0.0324,  ...,  0.2026,  1.5352, -0.0804],
         [-0.3575,  1.1241, -0.3230,  ...,  0.1771,  0.5986, -0.2469],
         [-0.3133, -0.0172,  0.2017,  ..., -0.6932, -0.0298,  0.1927]],

        [[ 0.6229, -0.7065,  0.4110,  ...,  0.6928, -0.8343,  0.5626],
         [-1.1165, -0.6047, -0.7571,  ..., -0.4478,  1.3511,  0.3368],
         [-0.3604, -0.6866, -1.1339,  ..., -0.3051,  0.6021,  0.3928],
         ...,
         [-0.9221, -1.2039, -0.5934,  ...,  0.0033,  0.5720, -0.0851],
         [-0.1702,  0.8628,  0.1382,  ...,  0.0643,  0.3970, -0.3166],
         [-0.4408, -0.1644, -0.3611,  ..., -0.4050, -1.4617,  0.2096]],

        [[ 0.6229, -0.7065,  0.4110,  ...,  0.6928, -0.8343,  0.5626],
         [-0.4555, -0.1284,  0.2218,  ..., -0

In [15]:
# Display the first element of `batches`.
# NOTE(review): `batches` is not defined in any cell visible here — confirm where it comes from.
batches[0]

tensor([[ 5451, 47317,   512, 10438,   584, 10570,   904,  4726],
        [   11,  6865,   757,  6604,   382,  2460,   512, 96945],
        [   11,  6604,   382,  5451, 47317,   512,  2675,   527],
        [  682, 20250,  4856,   311,  2815,  1109,   311,  2138]])

In [71]:
# Shape check: concatenating two (4, 8, 128) tensors on the last axis gives (4, 8, 256).
torch.cat([torch.randn(4, 8, 128) for _ in range(2)], dim=-1).shape

torch.Size([4, 8, 256])