In [3]:
import numpy as np

# Memory-map the two binary files read-only; both are interpreted as flat arrays of
# uint16 values, so slicing them later reads only the requested range.
train_data = np.memmap('../data/train.bin', dtype=np.uint16, mode='r')
val_data = np.memmap('../data/val.bin', dtype=np.uint16, mode='r')

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [5]:
block_size = 32  # number of consecutive tokens in each sampled sequence
batch_size = 16  # number of sequences drawn per call to get_batch

In [6]:
def get_batch(split):
    """Draw a random batch of input/target sequences from one data split.

    Args:
        split: ``'train'`` selects ``train_data``; any other value selects ``val_data``.

    Returns:
        A pair ``(x, y)`` of int64 tensors, each of shape ``(batch_size, block_size)``.
        ``y`` holds the same windows as ``x`` shifted one position to the right.
    """
    source = train_data if split == 'train' else val_data

    # One random start offset per sequence; the upper bound leaves room for the
    # shifted target window.
    starts = torch.randint(len(source) - block_size, (batch_size,))

    def window(offset):
        # Cast to int64 before handing the slice to torch.
        return torch.from_numpy(source[offset:offset + block_size].astype(np.int64))

    inputs = torch.stack([window(start) for start in starts])
    targets = torch.stack([window(start + 1) for start in starts])
    return inputs, targets

In [7]:
x, y = get_batch('train')

In [8]:
x

tensor([[ 5025,  3703, 10124,    44,     0,     0,  3695,  7094,  7373, 10339,
          2125, 12953,  2903,  3189,  3240,  2208,  3628,    44,     0,  4770,
         10931, 10202,  3921,  1868, 12953,  6078,  4118, 10520,  4307,  5040,
            44,     0],
        [ 6489,  9283,  3728,  7692, 11391, 11432, 12953,  1920, 10237,  1434,
          3921,  8405,  4454, 11250,    44,     0, 10317,  2130,  2050,  3370,
          6444,  1788,  6428, 12953,  7299,  3696,  6109,  4968, 10143,  2956,
          1387,    44],
        [   44,     0,  1282,  1282,  5602,  8465,  2605,  4798,  3195,    44,
             0, 12631, 11737,  1267,  3835,  5088,    44,     0,  2146,  2947,
          3905,  1264,  6794,    44,     0,  8966,  3968,  8901,  1836,  4856,
            44,     0],
        [ 5095,    44,     0,  5917, 11593,  1337,  1267,  1774, 12953,  7298,
          1346,  1480,  4867,  2605, 12956,     0,  9078,  9447,  1361,  9098,
          2968, 12953, 10863,  5984,  5693,  4277,  5734,  

In [9]:
y

tensor([[ 3703, 10124,    44,     0,     0,  3695,  7094,  7373, 10339,  2125,
         12953,  2903,  3189,  3240,  2208,  3628,    44,     0,  4770, 10931,
         10202,  3921,  1868, 12953,  6078,  4118, 10520,  4307,  5040,    44,
             0,  1346],
        [ 9283,  3728,  7692, 11391, 11432, 12953,  1920, 10237,  1434,  3921,
          8405,  4454, 11250,    44,     0, 10317,  2130,  2050,  3370,  6444,
          1788,  6428, 12953,  7299,  3696,  6109,  4968, 10143,  2956,  1387,
            44,     0],
        [    0,  1282,  1282,  5602,  8465,  2605,  4798,  3195,    44,     0,
         12631, 11737,  1267,  3835,  5088,    44,     0,  2146,  2947,  3905,
          1264,  6794,    44,     0,  8966,  3968,  8901,  1836,  4856,    44,
             0,  2071],
        [   44,     0,  5917, 11593,  1337,  1267,  1774, 12953,  7298,  1346,
          1480,  4867,  2605, 12956,     0,  9078,  9447,  1361,  9098,  2968,
         12953, 10863,  5984,  5693,  4277,  5734,    44,  

In [10]:
n_embd = 384   # embedding width; input/output size of the attention and MLP layers
n_head = 6     # number of attention heads; n_embd must be divisible by it
n_layer = 6    # NOTE(review): not read in the visible cells, which use config.n_layer instead
dropout = 0.0  # probability handed to the nn.Dropout layers built from these globals
bias = False   # whether the nn.Linear layers built from these globals carry a bias term

### Attention forward pass

In [11]:
# Random activations standing in for a batch of embedded tokens: (B, T, C).
x1 = torch.randn(batch_size, block_size, n_embd)
# One fused linear layer produces queries, keys and values side by side.
attn = nn.Linear(n_embd, 3 * n_embd, bias=bias)
qkv = attn(x1)  # run the projection once and reuse the result below
print(qkv.shape)
print('Split the last dimension into q, k, v')
q, k, v = torch.split(qkv, split_size_or_sections=n_embd, dim=2)
print(f'Shape of q, k, v: {q.shape}, {k.shape}, {v.shape}')

torch.Size([16, 32, 1152])
Split the last dimension into q, k, v
Shape of q, k, v: torch.Size([16, 32, 384]), torch.Size([16, 32, 384]), torch.Size([16, 32, 384])


In [12]:
# Reshape q, k and v from (B, T, C) to (B, nh, T, hs): split the channel axis into heads,
# then swap the head and time axes so the heads sit next to the batch dimension.
head_size = n_embd // n_head
q, k, v = (
    proj.view(batch_size, block_size, n_head, head_size).transpose(1, 2)
    for proj in (q, k, v)
)
print(f'Shape of q, k, v: {q.shape}, {k.shape}, {v.shape}')

Shape of q, k, v: torch.Size([16, 6, 32, 64]), torch.Size([16, 6, 32, 64]), torch.Size([16, 6, 32, 64])


In [13]:
# Causal mask: entry (i, j) is 1 when j <= i, so a position only attends to itself
# and to positions on its left.
mask = torch.ones(block_size, block_size).tril()
mask = mask[None, None, :, :]  # prepend two size-1 axes: (1, 1, T, T)
mask[:, :, :block_size, :block_size]

tensor([[[[1., 0., 0.,  ..., 0., 0., 0.],
          [1., 1., 0.,  ..., 0., 0., 0.],
          [1., 1., 1.,  ..., 0., 0., 0.],
          ...,
          [1., 1., 1.,  ..., 1., 0., 0.],
          [1., 1., 1.,  ..., 1., 1., 0.],
          [1., 1., 1.,  ..., 1., 1., 1.]]]])

In [14]:
# Output projection plus the two dropout layers used by the attention walk-through below.
c_proj = nn.Linear(in_features=n_embd, out_features=n_embd, bias=bias)
attn_dropout = nn.Dropout(p=dropout)
resid_dropout = nn.Dropout(p=dropout)

In [15]:
import math

# Scaled dot-product scores: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
scale = 1.0 / math.sqrt(k.size(-1))
att = (q @ k.transpose(-2, -1)) * scale
# Positions where the mask is 0 get -inf so softmax assigns them zero weight.
hidden_positions = mask[:, :, :block_size, :block_size] == 0
att = att.masked_fill(mask=hidden_positions, value=float('-inf'))
att = attn_dropout(F.softmax(att, dim=-1))
# Weighted sum of values: (B, nh, T, T) x (B, nh, T, hs) => (B, nh, T, hs)
y1 = att @ v
# Merge the heads back: (B, nh, T, hs) => (B, T, nh, hs) => (B, T, C)
y1 = y1.transpose(1, 2).contiguous().view(batch_size, block_size, n_embd)
y1 = resid_dropout(c_proj(y1))

In [16]:
y1.shape

torch.Size([16, 32, 384])

### Building blocks of the transformer

In [17]:
class SelfAttention(nn.Module):
    """Multi-head causal self-attention.

    The layer sizes are read from the notebook-level variables ``n_embd``, ``n_head``,
    ``dropout``, ``bias`` and ``block_size`` at construction time. ``forward`` relies
    only on the module's own state, so later changes to those variables do not affect
    an already-built layer.
    """

    def __init__(self) -> None:
        super().__init__()
        assert n_embd % n_head == 0, "The remainder of embedding and head number should be zero."

        # Remember the head count so forward() does not depend on the global.
        self.n_head = n_head

        # Fused q/k/v projection, and the projection applied after the heads are merged.
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=bias)
        self.c_proj = nn.Linear(n_embd, n_embd, bias=bias)

        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        # Lower-triangular causal mask, stored as a buffer (not a parameter) named 'bias'.
        self.register_buffer('bias', torch.tril(torch.ones(block_size, block_size))
                                          .view(1, 1, block_size, block_size))

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape ``(B, T, C)`` with ``T`` no larger than the mask size and
                ``C`` equal to the embedding width the layer was built with.

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.shape  # batch size, sequence length, embedding dimension
        assert T <= self.bias.size(-1), "Sequence length exceeds the size of the causal mask."
        head_size = C // self.n_head

        q, k, v = torch.split(self.c_attn(x), split_size_or_sections=C, dim=2)
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # (B, nh, T, T)
        # Slice the mask to the actual sequence length so inputs shorter than
        # block_size (e.g. during generation) are handled.
        att = att.masked_fill(mask=(self.bias[:, :, :T, :T] == 0), value=float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)

        y = att @ v  # (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        # Use this module's own projection and dropout, not the notebook-level ones.
        return self.resid_dropout(self.c_proj(y))

In [18]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable scale and no bias."""

    def __init__(self, n_dim) -> None:
        super().__init__()
        # The scale starts at one, so the layer initially only normalises.
        initial_scale = torch.ones(n_dim)
        self.weight = nn.Parameter(initial_scale)

    def forward(self, input):
        """Normalise ``input`` over its trailing ``n_dim`` features and apply the scale."""
        normalized_shape = self.weight.shape
        return F.layer_norm(
            input,
            normalized_shape,
            weight=self.weight,
            bias=None,  # this model uses no additive bias
            eps=1e-5,
        )

In [19]:
class MLP(nn.Module):
    """Feed-forward sub-layer: expand to ``4 * n_embd``, GELU, project back, dropout.

    Sizes, the bias flag and the dropout probability come from the notebook-level
    ``n_embd``, ``bias`` and ``dropout`` variables at construction time.
    """

    def __init__(self) -> None:
        super().__init__()
        hidden_dim = 4 * n_embd
        self.c_fc    = nn.Linear(n_embd, hidden_dim, bias=bias)
        self.gelu    = nn.GELU()
        self.c_proj  = nn.Linear(hidden_dim, n_embd, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return a tensor with the same shape as ``x`` after the feed-forward transform."""
        expanded = self.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(expanded))

In [20]:
class SelfAttentionBlock(nn.Module):
    """Transformer block: attention then MLP, each applied to a layer-normalised input
    and added back to that input through a residual connection."""

    def __init__(self) -> None:
        super().__init__()
        # Sub-modules are created in the order they are applied.
        self.ln_1 = LayerNorm(n_embd)
        self.attn = SelfAttention()
        self.ln_2 = LayerNorm(n_embd)
        self.mlp = MLP()

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        after_attn = x + self.attn(self.ln_1(x))
        return after_attn + self.mlp(self.ln_2(after_attn))

### Forward pass of the transformer model

In [21]:
from dataclasses import dataclass

@dataclass
class GPTConfig:
    """Hyperparameters read by the model-assembly and sampling cells of this notebook."""
    block_size: int = 32  # rows of the position-embedding table; side length of the causal mask
    vocab_size: int = 12992  # rows of the token embedding; output width of the LM head
    n_layer: int = 12  # number of SelfAttentionBlock instances stacked
    n_head: int = 12  # head count used by the manual attention walk-through
    n_embd: int = 384  # embedding width
    dropout: float = 0.0  # probability passed to nn.Dropout

In [22]:
# Create a config instance (not a reference to the class itself) so the defaults live
# on an object that can be overridden without mutating GPTConfig for everyone.
config = GPTConfig()

In [23]:
# The token-embedding weight and the output-projection weight have the same shape,
# (vocab_size, n_embd), as the printout below confirms.
# That is what makes it possible to share one tensor between them (weight tying).
emb_w = nn.Embedding(config.vocab_size, config.n_embd).weight
proj_w = nn.Linear(config.n_embd, config.vocab_size, bias=False).weight
print(f'{emb_w.shape}, {proj_w.shape}')

torch.Size([12992, 384]), torch.Size([12992, 384])


In [24]:
# Model pieces, kept as separate variables so each step of the forward pass can be
# inspected on its own.
wte = nn.Embedding(config.vocab_size, config.n_embd)   # token embeddings
wpe = nn.Embedding(config.block_size, config.n_embd)   # learned position embeddings
drop = nn.Dropout(config.dropout)
self_attn = nn.ModuleList([SelfAttentionBlock() for _ in range(config.n_layer)])
ln_f = LayerNorm(config.n_embd)                        # final layer norm

# Bias-free output projection, matching the other bias-free layers in this notebook
# and keeping its weight the same shape as wte.weight.
lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

In [25]:
b, t = x.size()

In [26]:
tok_emb = wte(x)

In [27]:
# Position indices 0..t-1, looked up in the position-embedding table.
pos = torch.arange(0, t, dtype=torch.long)
pos_emb = wpe(pos)

In [28]:
h = drop(tok_emb + pos_emb)

In [29]:
for block in self_attn:
    h = block(h)

In [30]:
res = ln_f(h)

In [31]:
res

tensor([[[-0.7403,  0.2919,  0.8253,  ...,  1.3536,  0.6279,  0.9969],
         [-1.3320, -0.0372,  0.7929,  ...,  1.4257, -0.0236, -0.4465],
         [ 0.1600,  0.0606, -0.8494,  ...,  0.0821, -1.6503, -0.0502],
         ...,
         [-0.1120, -0.1676,  0.1251,  ...,  0.2305, -0.0765, -0.1127],
         [ 0.4834, -0.6434,  0.0100,  ...,  0.3718,  0.1212, -1.1368],
         [ 0.0526,  0.0383,  0.0407,  ..., -0.5134,  1.7667, -1.7377]],

        [[ 0.6691, -1.0205, -1.7012,  ..., -0.1609,  0.1108,  0.2928],
         [-0.3143,  0.9292,  0.2006,  ...,  0.7758, -0.1992, -0.7312],
         [-0.0777, -0.5686, -1.1394,  ..., -0.0989,  0.1733, -0.0890],
         ...,
         [-0.7618,  0.2117,  0.5519,  ..., -1.3541,  1.7957, -0.4803],
         [ 0.6352, -0.7841,  1.1753,  ..., -0.2185,  0.9883,  0.2905],
         [ 0.4276,  0.1279,  0.5564,  ...,  0.1594,  1.0094, -2.4802]],

        [[ 0.2992,  0.0563, -0.3367,  ..., -0.2081,  0.4746, -0.4614],
         [-1.0432,  0.4137,  0.3052,  ..., -0

In [32]:
res.shape

torch.Size([16, 32, 384])

In [33]:
logits = lm_head(res)

In [34]:
logits.size()

torch.Size([16, 32, 12992])

In [35]:
# Flatten batch and time into one axis so each row holds the scores for one position.
logits = logits.view(-1, logits.size(-1)) # (B, T, C) => (B * T, C)
logits

tensor([[-0.1369, -0.0214, -0.1410,  ...,  0.5903,  0.1987,  0.0593],
        [-0.3992, -0.6033, -0.8112,  ...,  0.5634,  0.0331,  0.7070],
        [ 0.6359, -0.2673, -0.8278,  ...,  0.5802,  0.8409,  0.7029],
        ...,
        [-0.3523, -0.1463, -0.6002,  ...,  0.1998, -0.2649, -0.1879],
        [ 0.4501, -0.5308,  0.7656,  ..., -0.4483, -0.6711,  0.7724],
        [ 0.0777, -0.5441, -1.1556,  ...,  0.2921,  0.2450, -0.0256]],
       grad_fn=<ViewBackward0>)

In [36]:
y.view(-1).shape # (B, T) => (B * T,)

torch.Size([512])

In [37]:
F.cross_entropy(logits, y.view(-1), ignore_index=-1)

tensor(9.6869, grad_fn=<NllLossBackward0>)

In [38]:
# The same pieces as above, grouped in one ModuleDict so they can be addressed by name.
transformer = nn.ModuleDict(dict(
    wte = nn.Embedding(config.vocab_size, config.n_embd),
    wpe = nn.Embedding(config.block_size, config.n_embd),
    drop = nn.Dropout(config.dropout),
    self_attn_block = nn.ModuleList([SelfAttentionBlock() for _ in range(config.n_layer)]),
    ln_f = LayerNorm(config.n_embd)
))

# Bias-free output projection, matching the other bias-free layers in this notebook
# and keeping its weight the same shape as transformer.wte.weight.
lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

In [39]:
# Full forward pass through the grouped modules.
b, t = x.size()
device = x.device
# Position indices live on the same device as the tokens.
pos = torch.arange(t, dtype=torch.long, device=device)

tok_emb = transformer.wte(x)
pos_emb = transformer.wpe(pos)
h = transformer.drop(tok_emb + pos_emb)
for layer in transformer.self_attn_block:
    h = layer(h)
res = transformer.ln_f(h)

In [40]:
# Compute the loss: project to vocabulary scores, flatten predictions and targets to one
# row per position, and take the cross-entropy (targets equal to -1 would be ignored).
logits = lm_head(res)
F.cross_entropy(logits.view(-1, logits.shape[-1]), y.view(-1), ignore_index=-1)

tensor(9.6859, grad_fn=<NllLossBackward0>)

In [41]:
# For generation at inference time only the last position is projected. Indexing with
# the list [-1] (rather than the scalar -1) keeps the time axis, as the shape below shows.
logits = lm_head(res[:, [-1], :])
logits.shape

torch.Size([16, 1, 12992])

### Sample from the model

In [42]:
# Start generation from a single sequence holding only token id 0: shape (1, 1).
x = torch.tensor([[0]], dtype=torch.long)
x

tensor([[0]])

In [43]:
b, t = x.shape
wte = nn.Embedding(config.vocab_size, config.n_embd)

x_tok_emb = wte(x)
x_tok_emb.shape

torch.Size([1, 1, 384])

In [44]:
# Fresh position-embedding table; look up the embeddings for positions 0..t-1.
wpe = nn.Embedding(config.block_size, config.n_embd)

x_pos_emb = wpe(torch.arange(0, t, dtype=torch.long))
x_pos_emb.shape

torch.Size([1, 384])

In [45]:
# The (t, C) position embeddings broadcast over the batch axis of the (B, t, C) token embeddings.
x_h = x_tok_emb + x_pos_emb
x_h.shape

torch.Size([1, 1, 384])

In [46]:
drop = nn.Dropout(config.dropout)

x_h = drop(x_h)

In [47]:
# Fresh attention layers for the single-token walk-through.
c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)  # fused q/k/v projection
c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)      # output projection
attn_dropout, proj_dropout = nn.Dropout(config.dropout), nn.Dropout(config.dropout)

# Lower-triangular causal mask with two leading size-1 axes: (1, 1, block_size, block_size).
mask = torch.ones(config.block_size, config.block_size).tril()[None, None]

In [70]:
B, T, C = x_h.shape
B, T, C

(1, 1, 384)

In [49]:
q, k, v = c_attn(x_h).split(config.n_embd, dim=2)

In [50]:
print(f'q, k, v shape: {q.shape}, {k.shape}, {v.shape}')

q, k, v shape: torch.Size([1, 1, 384]), torch.Size([1, 1, 384]), torch.Size([1, 1, 384])


In [51]:
# Split the channel axis into heads and move the head axis forward:
# (B, T, C) -> (B, T, nh, hs) -> (B, nh, T, hs)
head_size = C // config.n_head
k, q, v = (
    proj.view(B, T, config.n_head, head_size).transpose(1, 2)
    for proj in (k, q, v)
)

In [52]:
print(f'q, k, v shape: {q.shape}, {k.shape}, {v.shape}')

q, k, v shape: torch.Size([1, 12, 1, 32]), torch.Size([1, 12, 1, 32]), torch.Size([1, 12, 1, 32])


In [53]:
# Attention scores, scaled by 1 / sqrt(head size): (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att.shape

torch.Size([1, 12, 1, 1])

In [54]:
# Slice the mask to the current sequence length T and set masked-out scores to -inf.
att = att.masked_fill(mask[:, :, :T, :T] == 0, float('-inf'))
att.shape

torch.Size([1, 12, 1, 1])

In [55]:
att = F.softmax(att, dim=-1)
att.shape

torch.Size([1, 12, 1, 1])

In [56]:
att = attn_dropout(att)
att.shape

torch.Size([1, 12, 1, 1])

In [57]:
y = att @ v
y.shape

torch.Size([1, 12, 1, 32])

In [58]:
# Merge the heads back into one channel axis: (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
y = y.transpose(1, 2).contiguous().view(B, T, C)
y.shape

torch.Size([1, 1, 384])

In [59]:
y = proj_dropout(c_proj(y))
y.shape

torch.Size([1, 1, 384])

In [60]:
# Residual update with the attention output.
# NOTE(review): here the layer norm is applied to the attention output `y`, whereas
# SelfAttentionBlock.forward normalises the input first (x + attn(ln_1(x))). Confirm
# which ordering this walk-through is meant to show.
ln_1 = LayerNorm(config.n_embd)
x_h = x_h + ln_1(y)
x_h.shape

torch.Size([1, 1, 384])

In [61]:
c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
gelu = nn.GELU()
c_proj  = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
dropout = nn.Dropout(config.dropout)

In [62]:
x_h = c_fc(x_h)
x_h = gelu(x_h)
x_h = c_proj(x_h)
x_h = dropout(x_h)
x_h.shape

torch.Size([1, 1, 384])

In [63]:
# Layer norm for the MLP branch. Only `ln_2` is bound here, so the attention branch's
# `ln_1` is no longer silently replaced by the same module.
# NOTE(review): at this point `x_h` already holds the MLP output, so this adds ln_2 of
# the MLP output to itself. SelfAttentionBlock.forward instead computes
# x + mlp(ln_2(x)). Confirm which ordering this walk-through is meant to show.
ln_2 = LayerNorm(config.n_embd)
x_h = x_h + ln_2(x_h)
x_h.shape

torch.Size([1, 1, 384])

In [64]:
ln_f = LayerNorm(config.n_embd)

res = ln_f(x_h)
res.shape

torch.Size([1, 1, 384])

In [65]:
# Bias-free projection from the embedding width to one score per vocabulary entry.
lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

# Project only the last position; the list index [-1] keeps the time axis.
logits = lm_head(res[:, [-1], :])
logits.shape

torch.Size([1, 1, 12992])

In [67]:
# Dividing by the temperature rescales the scores before softmax; 1.0 leaves them unchanged.
temperature = 1.0
# Select the last position, which also drops the time axis: (B, 1, V) -> (B, V).
logits = logits[:, -1, :] / temperature
logits.shape

torch.Size([1, 12992])

In [68]:
# Turn the scores into a probability distribution over the vocabulary and draw one
# token id per row from it.
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
idx_next

tensor([[8486]])