In [1]:
import numpy as np

# Memory-map the train/validation files read-only as flat uint16 arrays.
# Their values are later fed to the token embedding, i.e. they are token ids.
train_data = np.memmap('../data/train.bin', dtype=np.uint16, mode='r')
val_data = np.memmap('../data/val.bin', dtype=np.uint16, mode='r')

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
block_size = 32  # number of consecutive tokens in each sampled sequence (see get_batch)
batch_size = 16  # number of sequences sampled per batch (see get_batch)

In [4]:
def get_batch(split):
    """Sample a random batch of input/target sequences.

    Args:
        split: 'train' selects ``train_data``; any other value selects ``val_data``.

    Returns:
        A pair ``(x, y)`` of int64 tensors, each of shape (batch_size, block_size).
        ``y`` holds the same windows as ``x`` shifted one token to the right.
    """
    source = train_data if split == 'train' else val_data

    # One random start offset per sequence in the batch
    starts = torch.randint(len(source) - block_size, (batch_size,))

    def window(begin):
        # Copy block_size tokens starting at `begin` into an int64 tensor
        return torch.from_numpy(source[begin:begin + block_size].astype(np.int64))

    inputs = torch.stack([window(start) for start in starts])
    targets = torch.stack([window(start + 1) for start in starts])
    return inputs, targets

In [5]:
x, y = get_batch('train')

In [6]:
x

tensor([[ 9437,  9393,    44,     0,  1666,  8986,  4756, 12410,  8558, 12953,
         11718,  6188,  5512, 11399,  5465,    44,     0,  6445,  6487,  5047,
          8135,  6428, 12953,  4822,  5024,  1257, 11416,  4808,    44,     0,
             0,  6087],
        [ 1473,    44,     0,  4620, 12110,  6728,  5619,  2208,  1267,  2087,
         12953,  9283,  4093, 12016, 11759,  2966,  4822, 11984,    44,     0,
          2033,  1342, 12807,  3303,  8966, 11083,  6531, 12953,  9233,  8829,
          7826,  9048],
        [11205, 11165,  5736,  7373, 10988,    44,     0,  8985,  2911, 11250,
          6032,  1845,  1885,  6223, 12953,  7420,  3719, 12718,  1346, 11703,
          5602,  7925,    44,     0, 12113,  1267,  4819, 12953,  2241, 11735,
          7922,    44],
        [ 9723,  2087,  9903,  9752, 12953,  1276,  1264,  9214,  5045,  3783,
          3368,  9418,    44,     0, 10595,  3960,  4889, 11612,  5663,  3645,
          1952, 12953,  2968,  7768,  1387, 10159,  6489,  

In [7]:
y

tensor([[ 9393,    44,     0,  1666,  8986,  4756, 12410,  8558, 12953, 11718,
          6188,  5512, 11399,  5465,    44,     0,  6445,  6487,  5047,  8135,
          6428, 12953,  4822,  5024,  1257, 11416,  4808,    44,     0,     0,
          6087,  5870],
        [   44,     0,  4620, 12110,  6728,  5619,  2208,  1267,  2087, 12953,
          9283,  4093, 12016, 11759,  2966,  4822, 11984,    44,     0,  2033,
          1342, 12807,  3303,  8966, 11083,  6531, 12953,  9233,  8829,  7826,
          9048, 11481],
        [11165,  5736,  7373, 10988,    44,     0,  8985,  2911, 11250,  6032,
          1845,  1885,  6223, 12953,  7420,  3719, 12718,  1346, 11703,  5602,
          7925,    44,     0, 12113,  1267,  4819, 12953,  2241, 11735,  7922,
            44,     0],
        [ 2087,  9903,  9752, 12953,  1276,  1264,  9214,  5045,  3783,  3368,
          9418,    44,     0, 10595,  3960,  4889, 11612,  5663,  3645,  1952,
         12953,  2968,  7768,  1387, 10159,  6489,  2901,  

In [8]:
# Notebook-level hyperparameters; the SelfAttention, MLP and SelfAttentionBlock classes read these globals.
n_embd = 384  # embedding / channel width
n_head = 6  # number of attention heads; n_embd must be divisible by it
n_layer = 6  # NOTE(review): not referenced by any visible cell (the block stack uses config.n_layer) — confirm intent
dropout = 0.0  # dropout probability passed to every nn.Dropout
bias = False  # whether the Linear and LayerNorm layers get a bias term

In [9]:
# Demo: a single fused linear layer produces q, k and v together.
x1 = torch.randn(batch_size, block_size, n_embd)
attn = nn.Linear(n_embd, 3 * n_embd, bias=bias)
# Run the projection once and reuse it for both the shape print and the split.
qkv = attn(x1)
print(qkv.shape)
print('Split the last dimension into q, k, v')
q, k, v = torch.split(qkv, split_size_or_sections=n_embd, dim=2)
print(f'Shape of q, k, v: {q.shape}, {k.shape}, {v.shape}')

torch.Size([16, 32, 1152])
Split the last dimension into q, k, v
Shape of q, k, v: torch.Size([16, 32, 384]), torch.Size([16, 32, 384]), torch.Size([16, 32, 384])


In [10]:
# Reshape q, k and v from (B, T, C) to (B, nh, T, hs) so each head is handled like a batch entry
head_size = n_embd // n_head
q, k, v = (
    t.view(batch_size, block_size, n_head, head_size).transpose(1, 2)
    for t in (q, k, v)
)
print(f'Shape of q, k, v: {q.shape}, {k.shape}, {v.shape}')

Shape of q, k, v: torch.Size([16, 6, 32, 64]), torch.Size([16, 6, 32, 64]), torch.Size([16, 6, 32, 64])


In [11]:
# Causal mask: lower-triangular ones, so position t only attends to positions <= t
mask = torch.ones(block_size, block_size).tril()
# Prepend singleton batch and head dimensions -> (1, 1, block_size, block_size)
mask = mask[None, None, :, :]
mask[:, :, :block_size, :block_size]

tensor([[[[1., 0., 0.,  ..., 0., 0., 0.],
          [1., 1., 0.,  ..., 0., 0., 0.],
          [1., 1., 1.,  ..., 0., 0., 0.],
          ...,
          [1., 1., 1.,  ..., 1., 0., 0.],
          [1., 1., 1.,  ..., 1., 1., 0.],
          [1., 1., 1.,  ..., 1., 1., 1.]]]])

In [12]:
# Output projection and the two dropout layers used by the attention walk-through below
c_proj = nn.Linear(in_features=n_embd, out_features=n_embd, bias=bias)
attn_dropout, resid_dropout = nn.Dropout(p=dropout), nn.Dropout(p=dropout)

In [13]:
import math

# Scaled dot-product scores: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
scale = 1.0 / math.sqrt(k.size(-1))
att = (q @ k.transpose(-2, -1)) * scale
# Positions where the causal mask is 0 get -inf, so softmax gives them zero weight
future = mask[:, :, :block_size, :block_size] == 0
att = att.masked_fill(mask=future, value=float('-inf'))
att = attn_dropout(F.softmax(att, dim=-1))
# Weighted sum of values: (B, nh, T, T) x (B, nh, T, hs) => (B, nh, T, hs)
y1 = att @ v
# Merge the heads back: (B, nh, T, hs) => (B, T, nh, hs) => (B, T, C)
y1 = y1.transpose(1, 2).contiguous().view(batch_size, block_size, n_embd)
y1 = resid_dropout(c_proj(y1))

In [14]:
y1.shape

torch.Size([16, 32, 384])

In [15]:
class SelfAttention(nn.Module):
    """Multi-head causal self-attention.

    The hyperparameters ``n_embd``, ``n_head``, ``bias``, ``dropout`` and ``block_size``
    are read from the notebook-level globals when the module is constructed.
    """

    def __init__(self) -> None:
        super().__init__()
        assert n_embd % n_head == 0, "The remainder of embedding and head number should be zero."

        # Remember the head count used at construction time so forward() stays consistent with it.
        self.n_head = n_head

        # Fused projection that yields q, k and v in one matmul.
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=bias)
        # Output projection applied after the heads are concatenated.
        self.c_proj = nn.Linear(n_embd, n_embd, bias=bias)

        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        # Lower-triangular causal mask of shape (1, 1, block_size, block_size).
        # Registered as a buffer (not a parameter) under the name 'bias'.
        self.register_buffer('bias', torch.tril(torch.ones(block_size, block_size))
                                          .view(1, 1, block_size, block_size))

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape (B, T, C), where C is the embedding width the module was
                built with and T is at most the ``block_size`` it was built with.

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C = x.shape # batch_size, sequence length, embedding dimension
        head_size = C // self.n_head

        # Use this module's own layers (not notebook-level globals), so every
        # instance owns, uses and trains its own weights.
        q, k, v = torch.split(self.c_attn(x), split_size_or_sections=C, dim=2)
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)

        # (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # Slice the mask to the actual sequence length so inputs shorter than block_size work.
        att = att.masked_fill(mask=(self.bias[:, :, :T, :T] == 0), value=float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)

        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        return self.resid_dropout(self.c_proj(y))

In [16]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with an optional bias term.

    Args:
        n_dim: size of the normalised (last) dimension.
        bias: when truthy a learnable zero-initialised bias is created;
            otherwise ``self.bias`` is ``None`` and no shift is applied.
    """

    def __init__(self, n_dim, bias) -> None:
        super().__init__()
        # Scale starts at one, so the layer initially only normalises.
        self.weight = nn.Parameter(torch.ones(n_dim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(n_dim))
        else:
            self.bias = None

    def forward(self, input):
        """Normalise ``input`` over its trailing dimension(s) matching ``self.weight``."""
        return F.layer_norm(
            input,
            normalized_shape=self.weight.shape,
            weight=self.weight,
            bias=self.bias,
            eps=1e-5,
        )

In [17]:
class MLP(nn.Module):
    """Feed-forward sub-layer: expand to 4 * n_embd, GELU, project back, dropout.

    Layer sizes, bias and dropout come from the notebook-level globals
    ``n_embd``, ``bias`` and ``dropout`` at construction time.
    """

    def __init__(self) -> None:
        super().__init__()
        hidden_dim = 4 * n_embd
        self.c_fc = nn.Linear(n_embd, hidden_dim, bias=bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden_dim, n_embd, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map a (..., n_embd) tensor to a tensor of the same shape."""
        activated = self.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(activated))

In [18]:
class SelfAttentionBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each wrapped in a residual connection."""

    def __init__(self) -> None:
        super().__init__()
        # Sub-modules are created in execution order: norm -> attention -> norm -> MLP
        self.ln_1 = LayerNorm(n_embd, bias)
        self.attn = SelfAttention()
        self.ln_2 = LayerNorm(n_embd, bias)
        self.mlp = MLP()

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates (shape unchanged)."""
        attended = self.attn(self.ln_1(x))
        x = x + attended
        transformed = self.mlp(self.ln_2(x))
        return x + transformed

In [19]:
from dataclasses import dataclass

@dataclass
class GPTConfig:
    """Default hyperparameters for the GPT-style model assembled in the following cells."""
    block_size: int = 32  # number of rows in the positional embedding table
    vocab_size: int = 50304  # number of rows in the token embedding / width of the output logits
    n_layer: int = 12  # number of SelfAttentionBlock instances stacked
    n_head: int = 12  # NOTE(review): SelfAttention reads the global n_head, not this field — confirm which should win
    n_embd: int = 384  # embedding width
    dropout: float = 0.0  # probability for the embedding dropout layer
    bias: bool = True  # bias flag passed to the final LayerNorm and to lm_head

In [20]:
# Instantiate the dataclass. Binding the bare class only worked because every
# field has a plain default that is also readable as a class attribute.
config = GPTConfig()

In [21]:
# Compare the weight shape of a token embedding with that of a bias-free vocabulary projection.
# Both print as (vocab_size, n_embd). Presumably a check that the two could share one tensor
# (weight tying); nothing in the visible cells actually ties them.
emb_w = nn.Embedding(config.vocab_size, config.n_embd).weight
proj_w = nn.Linear(config.n_embd, config.vocab_size, bias=False).weight
print(f'{emb_w.shape}, {proj_w.shape}')

torch.Size([50304, 384]), torch.Size([50304, 384])


In [22]:
# Stand-alone model pieces, built one by one so each step can be inspected.
wte = nn.Embedding(config.vocab_size, config.n_embd)  # token embedding table
wpe = nn.Embedding(config.block_size, config.n_embd)  # positional embedding table, one row per position
drop = nn.Dropout(config.dropout)
# NOTE(review): SelfAttentionBlock() takes no arguments and reads the notebook globals
# (n_embd, n_head, bias, dropout), so only config.n_layer influences this stack.
self_attn = nn.ModuleList([SelfAttentionBlock() for _ in range(config.n_layer)])
ln_f = LayerNorm(config.n_embd, bias=config.bias)  # final layer norm

# Projects hidden states of width n_embd to vocab_size logits.
lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=config.bias)

In [24]:
b, t = x.size()

In [26]:
tok_emb = wte(x)

In [45]:
pos = torch.arange(0, t, dtype=torch.long)
pos_emb = wpe(pos)

In [28]:
h = drop(tok_emb + pos_emb)

In [30]:
for block in self_attn:
    h = block(h)

In [31]:
res = ln_f(h)

In [32]:
res

tensor([[[ 0.0030, -0.5735, -0.4376,  ...,  0.9856,  0.5710,  0.4209],
         [ 0.6976, -1.5383, -0.1592,  ...,  1.1898, -0.7068,  0.2978],
         [-0.7431, -0.7057,  0.2619,  ...,  0.2048, -0.3446,  1.0029],
         ...,
         [ 0.6525, -0.0219,  0.2504,  ..., -0.7797,  0.7056,  1.6049],
         [ 0.4847,  0.5181,  0.0750,  ..., -0.4665, -1.7071,  1.8080],
         [ 1.0355,  1.2601, -0.5118,  ..., -0.3321,  0.2384, -0.3286]],

        [[ 0.2986,  2.1307,  0.2097,  ...,  0.3658,  0.9099,  0.6240],
         [ 0.6353,  0.2701,  0.6960,  ...,  0.3434,  0.3542,  0.7578],
         [-0.6637,  0.8668,  1.2509,  ..., -0.6774,  0.9275,  1.9576],
         ...,
         [ 0.4944,  0.2033,  0.6882,  ..., -1.0972,  2.5491,  0.6758],
         [ 1.2667,  0.2470, -0.3598,  ..., -1.3504,  2.0244,  0.8600],
         [ 0.5276,  1.5340, -0.3205,  ..., -0.8318,  1.2741,  0.3036]],

        [[-0.1550, -0.3403,  0.4283,  ..., -0.7204,  0.3657,  1.2489],
         [ 0.9159, -0.4431,  0.4249,  ..., -0

In [33]:
res.shape

torch.Size([16, 32, 384])

In [34]:
last_step = res[:, [-1], :] # note: using list [-1] to preserve the time dim
last_step.shape

torch.Size([16, 1, 384])

In [35]:
inference_logits = lm_head(last_step)
print(inference_logits.shape)
inference_logits

torch.Size([16, 1, 50304])


tensor([[[-0.0696, -0.2707, -0.2855,  ..., -0.0197,  0.2865,  0.3227]],

        [[-0.3631, -0.3368,  0.5571,  ...,  0.9414,  0.3481,  0.8444]],

        [[-0.6672, -1.0062,  0.0357,  ...,  0.4258,  0.2164, -0.3530]],

        ...,

        [[-0.4421, -0.4289, -0.6132,  ...,  0.7328,  0.0855,  1.0328]],

        [[-0.1924, -0.5834, -0.0645,  ...,  1.2259,  0.2577,  0.6021]],

        [[-0.4020, -0.8917,  0.5011,  ..., -0.0258,  0.3603,  0.7311]]],
       grad_fn=<ViewBackward0>)

In [36]:
temperature = 1.0

In [37]:
# Drop the singleton time dimension and apply temperature scaling.
# The scaled logits are stored (not just displayed) so the sampling step uses them.
# They are recomputed from `last_step` so re-running this cell never slices an already-2D tensor.
inference_logits = lm_head(last_step)[:, -1, :] / temperature
print(inference_logits.shape)
inference_logits

torch.Size([16, 50304])


tensor([[-0.0696, -0.2707, -0.2855,  ..., -0.0197,  0.2865,  0.3227],
        [-0.3631, -0.3368,  0.5571,  ...,  0.9414,  0.3481,  0.8444],
        [-0.6672, -1.0062,  0.0357,  ...,  0.4258,  0.2164, -0.3530],
        ...,
        [-0.4421, -0.4289, -0.6132,  ...,  0.7328,  0.0855,  1.0328],
        [-0.1924, -0.5834, -0.0645,  ...,  1.2259,  0.2577,  0.6021],
        [-0.4020, -0.8917,  0.5011,  ..., -0.0258,  0.3603,  0.7311]],
       grad_fn=<DivBackward0>)

In [38]:
probs = F.softmax(inference_logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
idx_next

tensor([[ 3181],
        [24246],
        [30504],
        [23911],
        [ 6131],
        [47783],
        [28942],
        [18260],
        [43682],
        [14037],
        [ 2503],
        [26882],
        [22515],
        [12443],
        [38892],
        [34382]])

In [39]:
logits = lm_head(res)

In [40]:
logits.size()

torch.Size([16, 32, 50304])

In [41]:
logits = logits.view(-1, logits.size(-1)) # (B, T, C) => (B * T, C)
logits

tensor([[-0.7096,  0.0828, -0.3921,  ...,  0.1466, -0.0737, -0.3202],
        [ 0.3355,  0.1282,  0.3670,  ...,  0.0767,  0.3652, -0.6138],
        [ 0.1141,  1.1284,  0.5404,  ..., -0.1263, -0.0282, -0.6873],
        ...,
        [ 0.4092, -0.5623,  0.2820,  ..., -0.8355,  0.2703, -0.1455],
        [-0.0415, -0.6487,  0.3678,  ...,  0.1901, -0.3452, -0.1307],
        [-0.4020, -0.8917,  0.5011,  ..., -0.0258,  0.3603,  0.7311]],
       grad_fn=<ViewBackward0>)

In [42]:
y

tensor([[12016,  1328,  1561,  3306,  4851,  2171,    44,     0,  3696,  3696,
          4505, 12682, 11070, 11701,  5695, 12953,  1387,  3696,  3882, 12682,
          3240,  4851,  1846,    44,     0, 12682,  2656,  2039, 10019,  5944,
          5734,  3783],
        [ 5024, 12953, 11206, 11213,  1480,  2880,  5881,  5847,  3621,    44,
             0,     0,  9140, 11832,  2901,  1264,  7762,  1398,  3222, 12953,
         11956,  4905,  1375, 11605,  3306,  7373,  3259,    44,     0,  4122,
          3882,  3836],
        [ 2020,  6763,  4822,  2184,  1268, 10571,  2262,    44,     0,  2033,
          3696, 11513,  4301,  1299,  1480,  1341, 12953,  1257,  3234,  4851,
         10520,  7922,  3715, 12089,    44,     0,  2954, 12016,  1676, 10970,
         10321,  5619],
        [12953,  3225,  3905, 11713,  2880,  4905,  1578,  9957,    44,     0,
          3316,  3140,  5047,  1894,  4851,  1299,  2651, 12953, 11194,  6647,
          5039, 11409,  4980,  7696,  2873,    44,     0,  

In [43]:
y.view(-1).shape # (B, T) => (B * T)

torch.Size([512])

In [44]:
F.cross_entropy(logits, y.view(-1), ignore_index=-1)

tensor(10.9734, grad_fn=<NllLossBackward0>)

In [27]:
# Same components as the stand-alone pieces above, grouped in one ModuleDict
# so they can be accessed as attributes (transformer.wte, transformer.wpe, ...).
transformer = nn.ModuleDict(dict(
    wte = nn.Embedding(config.vocab_size, config.n_embd),  # token embeddings
    wpe = nn.Embedding(config.block_size, config.n_embd),  # positional embeddings
    drop = nn.Dropout(config.dropout),
    # NOTE(review): SelfAttentionBlock() reads notebook globals, not `config`
    self_attn_block = nn.ModuleList([SelfAttentionBlock() for _ in range(config.n_layer)]),
    ln_f = LayerNorm(config.n_embd, bias=config.bias)  # final layer norm
))

# Output head kept outside the ModuleDict; maps n_embd features to vocab_size logits.
lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=config.bias)

In [30]:
# Manual forward pass through the ModuleDict model for the batch `x` of token ids.
tok_emb = transformer.wte(x)  # look up one n_embd vector per token
device = x.device
b, t = x.size()  # x is 2-D: (batch, sequence length)
# Position indices 0..t-1, created on the same device as x
pos = torch.arange(0, t, dtype=torch.long, device=device)
pos_emb = transformer.wpe(pos)
# pos_emb has no batch dimension and broadcasts across the batch in the sum
h = transformer.drop(tok_emb + pos_emb)
for block in transformer.self_attn_block:
    h = block(h)
res = transformer.ln_f(h)  # final layer norm before the output head

In [34]:
# compute the loss
logits = lm_head(res)
F.cross_entropy(logits.view(-1, logits.shape[-1]), y.view(-1), ignore_index=-1)

tensor(10.9427, grad_fn=<NllLossBackward0>)

In [36]:
# for generation in inference
logits = lm_head(res[:, [-1], :])
logits.shape

torch.Size([16, 1, 50304])