In [58]:
import numpy as np

# Memory-map the two token files read-only as flat uint16 arrays; slices are read
# from disk on demand instead of loading the whole file.
train_data = np.memmap('../data/train.bin', dtype=np.uint16, mode='r')
val_data = np.memmap('../data/val.bin', dtype=np.uint16, mode='r')

In [59]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [60]:
block_size = 32  # tokens per sampled sequence (context length)
batch_size = 16  # sequences sampled per batch

In [61]:
def get_batch(split):
    """Sample a random batch of input/target token windows.

    Args:
        split: 'train' reads from ``train_data``; any other value reads from ``val_data``.

    Returns:
        Tuple ``(x, y)`` of int64 tensors, each of shape (batch_size, block_size).
        ``y`` holds the same windows as ``x`` shifted one token to the right.
    """
    source = train_data if split == 'train' else val_data

    # One random start offset per sequence; the upper bound leaves room for the
    # target window, which extends one token past the input window.
    starts = torch.randint(len(source) - block_size, (batch_size,)).tolist()

    def window(offset):
        # uint16 -> int64 because embedding lookups expect long indices.
        return torch.from_numpy(source[offset:offset + block_size].astype(np.int64))

    inputs = torch.stack([window(start) for start in starts])
    targets = torch.stack([window(start + 1) for start in starts])
    return inputs, targets

In [62]:
# Draw one random training batch: x holds the inputs, y the same windows shifted one token ahead.
x, y = get_batch('train')

In [63]:
x

tensor([[12276,  1375,    44,     0,  8401,  3652, 12471,  2911,  3268,  1268,
         12953,  4905,  8635,  7521, 11435,  3300,  4899,    44,     0,     0,
          5400,  5758,  1776,  7373,  4551, 12953,  1336,  2683,  4913,  1361,
          4971,    44],
        [ 6102,  5240,  2156,  5921,    44,     0,     0,  5707, 11982,  5691,
          7826, 12605, 12953,  1387,  3696,  3003,  2055,  2172,    44,     0,
          4826,  2605,  1296,  5984, 12772, 12953, 10863,  4417,  6103,  5773,
          4839,    44],
        [ 3871,  5024,    36,    36,  5449, 12953,  9464,  2651,  7386,  9098,
          2250,  3311, 10857,    44,     0,  4858,  4178,  1285, 12841, 10912,
          7195,  2113, 12953,  4930,  1390,  1342,  1414,  2051, 11079,  3344,
            44,     0],
        [ 9507, 11658,  4899,    44,     0,  3723,  7735,  3477,  1286,  7737,
          1779,  3309, 12953,  2603,  4867,  1369,  9218,  7733,  3810,  4784,
            44,     0,  2643, 10359,  4148,  5234,  7762,  

In [64]:
y

tensor([[ 1375,    44,     0,  8401,  3652, 12471,  2911,  3268,  1268, 12953,
          4905,  8635,  7521, 11435,  3300,  4899,    44,     0,     0,  5400,
          5758,  1776,  7373,  4551, 12953,  1336,  2683,  4913,  1361,  4971,
            44,     0],
        [ 5240,  2156,  5921,    44,     0,     0,  5707, 11982,  5691,  7826,
         12605, 12953,  1387,  3696,  3003,  2055,  2172,    44,     0,  4826,
          2605,  1296,  5984, 12772, 12953, 10863,  4417,  6103,  5773,  4839,
            44,     0],
        [ 5024,    36,    36,  5449, 12953,  9464,  2651,  7386,  9098,  2250,
          3311, 10857,    44,     0,  4858,  4178,  1285, 12841, 10912,  7195,
          2113, 12953,  4930,  1390,  1342,  1414,  2051, 11079,  3344,    44,
             0, 11668],
        [11658,  4899,    44,     0,  3723,  7735,  3477,  1286,  7737,  1779,
          3309, 12953,  2603,  4867,  1369,  9218,  7733,  3810,  4784,    44,
             0,  2643, 10359,  4148,  5234,  7762,  2956,  

In [65]:
n_embd = 384  # embedding width (features per token)
n_head = 6  # attention heads; n_embd must be divisible by it (384 / 6 = 64 features per head)
n_layer = 6  # NOTE(review): not read by any visible cell; the block stacks below use config.n_layer
dropout = 0.0  # probability handed to nn.Dropout inside the attention / MLP modules
bias = False  # whether the nn.Linear layers of the attention / MLP modules get a bias term

In [66]:
# Random stand-in for a batch of token embeddings: (batch_size, block_size, n_embd).
x1 = torch.randn(batch_size, block_size, n_embd)
# A single linear layer produces queries, keys and values at once (3 * n_embd features).
attn = nn.Linear(n_embd, 3 * n_embd, bias=bias)
# Run the projection once and reuse it for both the shape print and the split.
qkv = attn(x1)
print(qkv.shape)
print('Split the last dimension into q, k, v')
q, k, v = torch.split(qkv, split_size_or_sections=n_embd, dim=2)
print(f'Shape of q, k, v: {q.shape}, {k.shape}, {v.shape}')

torch.Size([16, 32, 1152])
Split the last dimension into q, k, v
Shape of q, k, v: torch.Size([16, 32, 384]), torch.Size([16, 32, 384]), torch.Size([16, 32, 384])


In [67]:
# Split the channel dim into (n_head, head_size) and move the head axis in front of the
# time axis so every head is processed like an extra batch dimension:
# (B, T, C) -> (B, T, nh, hs) -> (B, nh, T, hs)
q, k, v = (
    t.view(batch_size, block_size, n_head, n_embd // n_head).transpose(1, 2)
    for t in (q, k, v)
)
print(f'Shape of q, k, v: {q.shape}, {k.shape}, {v.shape}')

Shape of q, k, v: torch.Size([16, 6, 32, 64]), torch.Size([16, 6, 32, 64]), torch.Size([16, 6, 32, 64])


In [68]:
# Lower-triangular causal mask: position i may only attend to positions <= i.
# The two leading singleton dims let it broadcast over the batch and head dimensions.
mask = torch.ones(block_size, block_size).tril().view(1, 1, block_size, block_size)
mask[:, :, :block_size, :block_size]

tensor([[[[1., 0., 0.,  ..., 0., 0., 0.],
          [1., 1., 0.,  ..., 0., 0., 0.],
          [1., 1., 1.,  ..., 0., 0., 0.],
          ...,
          [1., 1., 1.,  ..., 1., 0., 0.],
          [1., 1., 1.,  ..., 1., 1., 0.],
          [1., 1., 1.,  ..., 1., 1., 1.]]]])

In [69]:
# Output projection and the two dropout layers used by the attention walkthrough below.
c_proj = nn.Linear(in_features=n_embd, out_features=n_embd, bias=bias)
attn_dropout, resid_dropout = nn.Dropout(dropout), nn.Dropout(dropout)

In [70]:
import math

# Scaled dot-product scores: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
# Future positions (zeros in the causal mask) become -inf, i.e. zero weight after softmax.
att = att.masked_fill(mask[:, :, :block_size, :block_size] == 0, float('-inf'))
att = attn_dropout(F.softmax(att, dim=-1))
# Weighted sum of the values: (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs),
# then the heads are laid back side by side: (B, T, nh, hs) -> (B, T, C)
y1 = (att @ v).transpose(1, 2).contiguous().view(batch_size, block_size, n_embd)
y1 = resid_dropout(c_proj(y1))

In [71]:
y1.shape

torch.Size([16, 32, 384])

In [72]:
class SelfAttention(nn.Module):
    """Multi-head causal self-attention.

    Sizes are read from the notebook-level hyperparameters ``n_embd``, ``n_head``,
    ``dropout``, ``bias`` and ``block_size``.

    Input and output of ``forward`` are both (B, T, n_embd), with T <= block_size.
    """

    def __init__(self) -> None:
        super().__init__()
        assert n_embd % n_head == 0, "The remainder of embedding and head number should be zero."

        # One projection yields queries, keys and values; the other mixes the heads back.
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=bias)
        self.c_proj = nn.Linear(n_embd, n_embd, bias=bias)

        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        # Lower-triangular causal mask, registered as a buffer so it follows the module
        # across devices without being a trainable parameter.
        self.register_buffer('bias', torch.tril(torch.ones(block_size, block_size))
                                          .view(1, 1, block_size, block_size))

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (B, T, C) and return (B, T, C)."""
        B, T, C = x.shape # batch_size, sequence length (<= block_size), embedding dimension (n_embd)

        q, k, v = torch.split(self.c_attn(x), split_size_or_sections=n_embd, dim=2)
        # (B, T, C) -> (B, nh, T, hs): heads become an extra batch-like dimension.
        q = q.view(B, T, n_head, C // n_head).transpose(1, 2)
        k = k.view(B, T, n_head, C // n_head).transpose(1, 2)
        v = v.view(B, T, n_head, C // n_head).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # Slice the mask to the actual sequence length so inputs shorter than
        # block_size broadcast correctly against the (B, nh, T, T) scores.
        att = att.masked_fill(mask=(self.bias[:, :, :T, :T] == 0), value=float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)

        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        # Use this module's own projection and dropout, not the notebook-level globals.
        return self.resid_dropout(self.c_proj(y))

In [73]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable scale and no bias."""

    def __init__(self, n_dim) -> None:
        super().__init__()
        # Per-feature gain, initialised to one.
        self.weight = nn.Parameter(torch.ones(n_dim))

    def forward(self, input):
        # The normalised shape is taken from the gain vector; no shift term is applied.
        return F.layer_norm(
            input,
            normalized_shape=self.weight.shape,
            weight=self.weight,
            bias=None,
            eps=1e-5,
        )

In [74]:
class MLP(nn.Module):
    """Position-wise feed-forward net: expand to 4 * n_embd, GELU, project back, dropout."""

    def __init__(self) -> None:
        super().__init__()
        hidden = 4 * n_embd
        self.c_fc = nn.Linear(n_embd, hidden, bias=bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, n_embd, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # expand -> non-linearity -> contract -> dropout
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))

In [75]:
class SelfAttentionBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP sub-layers, each inside a residual connection."""

    def __init__(self) -> None:
        super().__init__()
        self.ln_1 = LayerNorm(n_embd)
        self.attn = SelfAttention()
        self.ln_2 = LayerNorm(n_embd)
        self.mlp = MLP()

    def forward(self, x):
        # Each sub-layer sees the normalised input and its output is added back onto the stream.
        attended = x + self.attn(self.ln_1(x))
        return attended + self.mlp(self.ln_2(attended))

In [76]:
from dataclasses import dataclass

@dataclass
class GPTConfig:
    """Hyperparameters read when the embedding tables, block stack and output head are built."""
    block_size: int = 32  # rows of the position-embedding table, i.e. the longest supported sequence
    vocab_size: int = 12992  # rows of the token-embedding table and width of the output logits
    n_layer: int = 12  # how many SelfAttentionBlock modules are stacked
    n_head: int = 12  # NOTE(review): not read in the visible cells; SelfAttention uses the module-level n_head
    n_embd: int = 384  # embedding width
    dropout: float = 0.0  # probability for the embedding dropout layer

In [77]:
# Instantiate the dataclass. Binding the bare class only exposes class-level defaults,
# and any later attribute assignment would mutate the class itself.
config = GPTConfig()

In [78]:
# Compare the token-embedding weight with the weight of a bias-free output projection:
# both have shape (vocab_size, n_embd).
emb_w = nn.Embedding(config.vocab_size, config.n_embd).weight
proj_w = nn.Linear(config.n_embd, config.vocab_size, bias=False).weight
print(f'{emb_w.shape}, {proj_w.shape}')

torch.Size([12992, 384]), torch.Size([12992, 384])


In [79]:
# Building blocks of the model, kept as separate variables for the step-by-step walkthrough.
wte = nn.Embedding(config.vocab_size, config.n_embd)  # token embeddings
wpe = nn.Embedding(config.block_size, config.n_embd)  # position embeddings
drop = nn.Dropout(config.dropout)
self_attn = nn.ModuleList([SelfAttentionBlock() for _ in range(config.n_layer)])
ln_f = LayerNorm(config.n_embd)  # final layer norm

# Output head without a bias term, consistent with the bias-free layers used in the rest
# of the model (nn.Linear adds a bias unless told otherwise).
lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

In [80]:
# x is the two-dimensional token-id batch sampled earlier: b sequences of t tokens each.
b, t = x.size()

In [81]:
# Look up an embedding vector for every token id in x: (b, t) -> (b, t, n_embd).
tok_emb = wte(x)

In [82]:
# Position indices 0 .. t-1 and their embeddings: (t,) -> (t, n_embd).
pos = torch.arange(0, t, dtype=torch.long)
pos_emb = wpe(pos)

In [83]:
# Add token and position embeddings (pos_emb broadcasts over the batch dim), then apply dropout.
h = drop(tok_emb + pos_emb)

In [84]:
# Feed the hidden states through every block of the stack, in order.
for block in self_attn:
    h = block(h)

In [85]:
# Final layer norm on the output of the last block.
res = ln_f(h)

In [86]:
res

tensor([[[-0.0616, -2.8554,  2.5387,  ...,  0.5750, -0.2359,  0.7878],
         [-0.7315,  0.4567,  0.1484,  ...,  2.6605, -0.6875,  0.2774],
         [-1.3246, -0.3199, -0.0086,  ..., -0.1334, -1.1427,  0.9465],
         ...,
         [ 0.4935,  0.9426,  0.7683,  ..., -2.3904, -0.5727, -0.4588],
         [ 0.9080, -0.1385,  1.3717,  ...,  0.6032, -2.4350, -1.1443],
         [-0.2722,  0.1268, -0.4437,  ...,  0.1638, -1.0512,  1.0896]],

        [[-0.2998,  0.1574,  1.8886,  ...,  0.4090,  1.0215,  1.2194],
         [-0.0888,  1.5727,  1.7173,  ...,  0.3340, -0.3872,  0.4777],
         [-1.7268, -0.3312,  1.3689,  ..., -2.2727, -0.9690,  1.4416],
         ...,
         [-0.7515,  1.2760, -0.7825,  ..., -0.7137,  0.4080, -0.6182],
         [-0.5858, -0.0110,  0.2162,  ..., -0.4953, -1.1886, -0.5534],
         [-0.1956,  0.2227, -0.4753,  ...,  0.0632, -1.1565,  1.3665]],

        [[ 0.8393,  1.0732,  0.5114,  ..., -0.0429,  0.7551,  1.5836],
         [-0.3344,  1.0009,  0.2859,  ...,  0

In [87]:
res.shape

torch.Size([16, 32, 384])

In [88]:
# Keep only the features of the final time step of every sequence.
last_step = res[:, [-1], :] # note: using list [-1] to preserve the time dim
last_step.shape

torch.Size([16, 1, 384])

In [89]:
# Project the last position's features to vocabulary logits: (B, 1, n_embd) -> (B, 1, vocab_size).
inference_logits = lm_head(last_step)
print(inference_logits.shape)
inference_logits

torch.Size([16, 1, 12992])


tensor([[[-0.0269, -0.3299,  0.1040,  ..., -0.6036,  0.8881,  0.5105]],

        [[ 0.1028, -0.4665,  0.0260,  ..., -0.4697,  0.8926,  0.4367]],

        [[ 0.0830,  0.2496,  0.0176,  ..., -0.0999,  0.8963,  0.9351]],

        ...,

        [[ 0.2034,  0.1112,  0.1992,  ...,  0.2128,  0.2024,  1.2172]],

        [[ 0.7887,  0.4334,  0.2768,  ...,  0.1259,  0.1355,  0.3733]],

        [[-0.3308,  0.3745,  0.9104,  ..., -0.5784,  1.0488,  0.4708]]],
       grad_fn=<ViewBackward0>)

In [90]:
# Divisor applied to the logits in the next cell; 1.0 leaves them unchanged.
temperature = 1.0

In [91]:
# Drop the singleton time dim and apply the sampling temperature. The scaled logits are
# assigned back (not merely displayed) so the softmax in the following cell actually
# uses the temperature.
print(inference_logits[:, -1, :].shape)
inference_logits = inference_logits[:, -1, :] / temperature
inference_logits

torch.Size([16, 12992])


tensor([[-0.0269, -0.3299,  0.1040,  ..., -0.6036,  0.8881,  0.5105],
        [ 0.1028, -0.4665,  0.0260,  ..., -0.4697,  0.8926,  0.4367],
        [ 0.0830,  0.2496,  0.0176,  ..., -0.0999,  0.8963,  0.9351],
        ...,
        [ 0.2034,  0.1112,  0.1992,  ...,  0.2128,  0.2024,  1.2172],
        [ 0.7887,  0.4334,  0.2768,  ...,  0.1259,  0.1355,  0.3733],
        [-0.3308,  0.3745,  0.9104,  ..., -0.5784,  1.0488,  0.4708]],
       grad_fn=<DivBackward0>)

In [92]:
# Turn the logits into a probability distribution over the vocabulary and sample
# one next-token id per sequence: idx_next has shape (B, 1).
probs = F.softmax(inference_logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
idx_next

tensor([[ 2205],
        [ 6291],
        [ 7252],
        [ 9763],
        [12030],
        [  850],
        [12805],
        [10389],
        [  173],
        [ 6519],
        [ 3376],
        [ 9164],
        [ 5579],
        [ 5365],
        [12385],
        [ 1417]])

In [93]:
# Training path: project every time step (not just the last one) to vocabulary logits.
logits = lm_head(res)

In [94]:
logits.size()

torch.Size([16, 32, 12992])

In [95]:
# Flatten batch and time so each row holds one position's logits; C here is the vocabulary size.
logits = logits.view(-1, logits.size(-1)) # (B, T, C) => (B * T, C)
logits

tensor([[-1.0369,  0.1713,  0.1788,  ..., -0.2416,  0.2519, -0.0380],
        [-0.4233, -0.3859,  0.4675,  ..., -0.4487,  0.6512,  0.4881],
        [-0.4201, -0.2159,  0.0110,  ..., -0.4234,  0.6843, -0.4219],
        ...,
        [ 0.3948,  0.2101,  1.1524,  ..., -1.0070, -0.4542,  0.9740],
        [ 0.5829,  0.8687,  0.0483,  ...,  1.1608,  0.0434,  0.1346],
        [-0.3308,  0.3745,  0.9104,  ..., -0.5784,  1.0488,  0.4708]],
       grad_fn=<ViewBackward0>)

In [96]:
y

tensor([[ 1375,    44,     0,  8401,  3652, 12471,  2911,  3268,  1268, 12953,
          4905,  8635,  7521, 11435,  3300,  4899,    44,     0,     0,  5400,
          5758,  1776,  7373,  4551, 12953,  1336,  2683,  4913,  1361,  4971,
            44,     0],
        [ 5240,  2156,  5921,    44,     0,     0,  5707, 11982,  5691,  7826,
         12605, 12953,  1387,  3696,  3003,  2055,  2172,    44,     0,  4826,
          2605,  1296,  5984, 12772, 12953, 10863,  4417,  6103,  5773,  4839,
            44,     0],
        [ 5024,    36,    36,  5449, 12953,  9464,  2651,  7386,  9098,  2250,
          3311, 10857,    44,     0,  4858,  4178,  1285, 12841, 10912,  7195,
          2113, 12953,  4930,  1390,  1342,  1414,  2051, 11079,  3344,    44,
             0, 11668],
        [11658,  4899,    44,     0,  3723,  7735,  3477,  1286,  7737,  1779,
          3309, 12953,  2603,  4867,  1369,  9218,  7733,  3810,  4784,    44,
             0,  2643, 10359,  4148,  5234,  7762,  2956,  

In [97]:
y.view(-1).shape  # (B, T) => (B * T,)

torch.Size([512])

In [98]:
# Mean cross-entropy between the flattened logits and targets; targets equal to -1 are ignored.
F.cross_entropy(logits, y.view(-1), ignore_index=-1)

tensor(9.7090, grad_fn=<NllLossBackward0>)

In [99]:
# The same building blocks as in the walkthrough, grouped in one ModuleDict so they can
# be addressed as transformer.<name>.
transformer = nn.ModuleDict(dict(
    wte = nn.Embedding(config.vocab_size, config.n_embd),  # token embeddings
    wpe = nn.Embedding(config.block_size, config.n_embd),  # position embeddings
    drop = nn.Dropout(config.dropout),
    self_attn_block = nn.ModuleList([SelfAttentionBlock() for _ in range(config.n_layer)]),
    ln_f = LayerNorm(config.n_embd)  # final layer norm
))

# Output head without a bias term, consistent with the bias-free layers used in the rest
# of the model (nn.Linear adds a bias unless told otherwise).
lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

In [100]:
# Forward pass of the assembled model on the batch x.
b, t = x.size()
device = x.device

# Position indices live on the same device as the input batch.
pos = torch.arange(0, t, dtype=torch.long, device=device)
tok_emb = transformer.wte(x)
pos_emb = transformer.wpe(pos)

# Sum the embeddings, run the block stack, finish with the final layer norm.
h = transformer.drop(tok_emb + pos_emb)
for block in transformer.self_attn_block:
    h = block(h)
res = transformer.ln_f(h)

In [101]:
# Compute the loss: logits for every position, flattened to (B * T, vocab_size),
# against the flattened targets; targets equal to -1 are ignored.
logits = lm_head(res)
F.cross_entropy(logits.view(-1, logits.shape[-1]), y.view(-1), ignore_index=-1)

tensor(9.7024, grad_fn=<NllLossBackward0>)

In [102]:
# For generation at inference time: only the last position is projected.
# Indexing with the list [-1] keeps the time dimension, giving (B, 1, vocab_size).
logits = lm_head(res[:, [-1], :])
logits.shape

torch.Size([16, 1, 12992])