<a href="https://colab.research.google.com/github/G0nkly/pytorch_sandbox/blob/main/GPT_dimensions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Lets look at each of the layers
# 1) Encoding
# 2) Embedding
# 3) Positional Encoding
# 4) Attention: key, query, value
# 5) Feedforward
# 6) Block/Layernorm
# 7) Classification / LM Head

In [3]:
################
# DATA EXAMPLE #
################

!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

with open(mode="r", file="input.txt") as f:
  text = f.read()

vocab = list(sorted(set(text)))

--2025-06-16 19:27:51--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-06-16 19:27:51 (13.4 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [4]:
##################
# HYPERPARAMTERS #
##################

In [5]:
vocab_size = len(vocab)
embedding_dim = 32
block_size = 8
n_heads = 4

In [6]:
############
# ENCODING #
############

In [7]:
stoi = {v:k for k,v in enumerate(vocab)}
itos = {k:v for k,v in enumerate(vocab)}
encode = lambda seq: [stoi[char] for char in seq]
decode = lambda numbers: "".join([itos[num] for num in numbers])

def get_batch(split: str):
  dataset = train if split == "train" else val

In [8]:
#############
# EMBEDDING #
#############

In [9]:
embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)

In [10]:
t = torch.tensor(encode("Haubi"))

In [11]:
embedding(t)

tensor([[-9.5173e-01,  5.1780e-01, -7.9678e-01, -1.3175e+00, -1.0763e+00,
         -7.7784e-01,  2.5596e-01, -1.2730e+00,  6.0056e-01, -1.1393e-01,
          3.6116e-01, -7.1648e-01, -2.7236e-01, -1.5923e+00, -1.9456e-02,
         -6.7467e-01, -6.1299e-01,  1.2723e+00, -2.0015e-01, -7.8960e-01,
         -5.7475e-02, -8.3011e-01,  1.2651e-01, -2.8292e-01, -5.4609e-01,
          1.9741e-01, -4.4609e-02,  2.0074e+00, -1.3648e-02,  1.1672e+00,
          1.1550e+00, -3.4768e-01],
        [-1.5869e+00, -7.4273e-01, -2.0695e+00,  5.9974e-01, -1.6373e+00,
          2.2822e+00,  6.8794e-01, -5.4350e-01, -3.4103e-01, -5.8377e-01,
         -7.4508e-01,  4.3472e-01, -1.5385e+00,  1.0176e+00, -9.0352e-03,
         -5.5720e-01,  1.0464e-01,  5.8074e-01, -1.1895e+00,  2.0971e-01,
          1.0315e+00,  1.7061e-01, -9.8962e-01, -5.5562e-01, -9.4058e-01,
         -1.9713e-01, -7.2097e-02, -2.0620e-01, -9.0988e-01, -1.6105e+00,
          3.3987e-01, -1.0508e+00],
        [-1.5462e+00,  3.6795e-01,  1.58

In [12]:
#########################
# POSITIONAL 'ENCODING' #
#########################

In [13]:
postional_embedding = nn.Embedding(num_embeddings=5, embedding_dim=32)

In [14]:
postional_embedding(torch.arange(5)).shape

torch.Size([5, 32])

In [15]:
embedded_tensor = embedding(t) + postional_embedding(torch.arange(5, dtype=torch.long))
embedded_tensor

tensor([[ 0.4031,  1.5242, -1.4711, -1.4731, -2.3746,  0.7424, -1.4808, -0.6002,
          0.4449,  0.2017,  0.2811, -1.3812, -0.1261, -2.4049,  0.5560,  0.0561,
         -1.4717,  2.7740,  1.2857, -0.4726,  0.8303, -0.3711, -0.5518, -0.1830,
         -0.8156,  1.4427,  0.6201,  1.3365,  0.2396,  1.0576, -0.7943, -0.5432],
        [ 0.3579, -0.9022, -2.4643, -1.0618, -4.4340,  1.0900,  1.1752, -0.3345,
          2.4141,  0.3887, -0.9735, -0.8638, -0.8206, -0.5100,  0.1251,  0.6341,
         -0.2691, -0.1135, -1.0485,  0.6256,  2.8587,  0.3859, -1.0955, -1.3121,
         -1.0247, -0.9247, -1.5952, -0.8083,  0.7482, -1.4537,  0.1571, -1.3824],
        [ 0.9586,  1.1592, -0.3475,  1.7701,  1.8518, -0.5691, -0.9678, -2.2441,
         -0.7347, -0.6185,  0.3041,  0.1427, -0.1808, -0.0668, -0.8789, -1.6833,
          2.0690,  1.9846,  0.3055,  0.1877, -0.1492, -1.0446, -1.5930,  0.8403,
          0.3971, -3.0851,  3.0781,  0.0779, -2.3419,  0.7785,  0.1147,  1.9837],
        [-0.8560, -0.4435

In [16]:
#########################
# SINGLE HEAD ATTENTION #
#########################

In [17]:
# since its not MultiHeadAttention input_dim == output_dim
block_size = len(t) if t is not None else block_size

key_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
query_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
value_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
tril = torch.tril(torch.ones(block_size, block_size, dtype=torch.long))

In [18]:
k = key_layer(embedded_tensor)
q = query_layer(embedded_tensor)
wei = (q @ k.transpose(-2,-1)) / (embedding_dim ** -0.5)
wei = wei.masked_fill(tril == 0, -float("inf"))
wei = F.softmax(wei, dim=-1)
v = value_layer(embedded_tensor)

out = wei @ v
wei, out

(tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [1.0000e+00, 3.4261e-07, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [7.7772e-07, 1.2493e-27, 1.0000e+00, 0.0000e+00, 0.0000e+00],
         [5.5852e-10, 1.0000e+00, 2.5168e-21, 9.0518e-31, 0.0000e+00],
         [5.1570e-16, 1.6237e-16, 1.0000e+00, 1.7578e-06, 8.7914e-12]],
        grad_fn=<SoftmaxBackward0>),
 tensor([[-0.0947, -0.4619, -1.0668, -0.3539,  0.5693,  0.7228, -1.8971,  0.0692,
           1.0132, -0.1141, -0.1005,  0.1660, -0.4363,  0.1858, -0.2616,  0.0830,
           0.4446, -0.7536, -0.4261,  0.7051, -0.3212, -0.4210,  0.0957,  1.3890,
           0.0606,  0.6657,  0.7420, -0.5712, -0.3385, -0.5723,  0.7971, -0.4750],
         [-0.0947, -0.4619, -1.0668, -0.3539,  0.5693,  0.7228, -1.8971,  0.0692,
           1.0132, -0.1141, -0.1005,  0.1660, -0.4363,  0.1858, -0.2616,  0.0830,
           0.4446, -0.7536, -0.4261,  0.7051, -0.3212, -0.4210,  0.0957,  1.3890,
           0.0606,  0.6657,  0.7

In [19]:
########################
# MULTI HEAD ATTENTION #
########################

In [20]:
class AttentionHead(nn.Module):

  def __init__(self, embedding_dim, head_size):
    super().__init__()
    self.embedding_dim = embedding_dim
    self.query_layer = nn.Linear(in_features=embedding_dim, out_features=head_size)
    self.key_layer = nn.Linear(in_features=embedding_dim, out_features=head_size)
    self.value_layer = nn.Linear(in_features=embedding_dim, out_features=head_size)
    self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size, dtype=torch.long)))

  def forward(self, x):
    B, T, C = x.shape
    q = self.query_layer(x)
    k = self.key_layer(x)
    v = self.value_layer(x)
    wei = (q @ k.transpose(-2,-1)) / C ** -0.5
    wei = wei.masked_fill(self.tril == 0, float("-inf"))
    wei = F.softmax(wei, dim=-1)
    out = wei @ v

    return out

In [21]:
class MultiHeadAttention(nn.Module):

  def __init__(self, embedding_dim, n_heads):
    super().__init__()
    head_size = embedding_dim // n_heads
    self.heads = nn.ModuleList([AttentionHead(embedding_dim=embedding_dim, head_size=head_size) for _ in range(n_heads)])
    self.projection_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)

  def forward(self, x):
    x = torch.cat([head(x) for head in self.heads], dim=-1)
    out = self.projection_layer(x)
    return out

In [22]:
embedded_tensor = embedded_tensor.unsqueeze(0)
embedded_tensor.shape

torch.Size([1, 5, 32])

In [27]:
multi_head_attention = MultiHeadAttention(embedding_dim=embedding_dim, n_heads=n_heads)
t_ma = multi_head_attention(embedded_tensor)
t_ma

tensor([[[-3.3058e-01,  1.5609e-01,  6.0706e-01,  3.7778e-01,  2.2100e-02,
           1.7917e-02, -2.2404e-01, -1.8024e-01,  2.8647e-01, -3.9880e-01,
          -7.2126e-02, -2.7563e-01, -3.9359e-01,  2.0978e-01,  2.2150e-01,
          -4.3236e-01,  1.8456e-01,  1.7068e-01,  1.4273e-01,  4.8865e-01,
          -2.2285e-01, -3.8417e-01, -4.9840e-01, -3.3975e-01, -2.6650e-02,
          -9.8114e-02,  1.9251e-03, -4.4611e-01, -3.1706e-01,  4.2815e-01,
           8.0081e-02,  1.5024e-01],
         [-2.7230e-01,  4.2659e-01,  4.8545e-01, -2.8686e-01,  4.2604e-01,
           6.0631e-01, -5.2107e-01,  4.6679e-01, -1.9142e-01, -5.8465e-01,
          -2.1072e-01,  2.5014e-01, -6.6826e-02, -9.9127e-03,  2.3967e-01,
          -1.6353e-01,  2.7169e-01,  5.9880e-02, -1.1598e-01,  4.3161e-01,
           2.9813e-01, -8.5463e-01, -2.6795e-01, -1.8554e-01, -2.0781e-01,
           9.8717e-02, -1.2805e-01, -6.6290e-01, -4.2944e-01,  4.5616e-01,
          -1.3259e-01,  7.4801e-01],
         [-1.3898e-01, -1.

In [26]:
################
# FEED FORWARD #
################

In [30]:
feed_forward_layer = nn.Sequential(
    nn.Linear(32, 4 * 32),
    nn.Linear(4 * 32, 32)
)

t_ff = feed_forward_layer(t_ma)
t_ff

tensor([[[ 0.1047,  0.1062,  0.1296,  0.1487,  0.0168, -0.0017,  0.1265,
          -0.0800,  0.0205,  0.0741,  0.0610, -0.0471,  0.0959, -0.2016,
           0.1497,  0.0997,  0.1914,  0.0814, -0.1831,  0.1698,  0.2157,
           0.1279, -0.1241,  0.0661, -0.1072,  0.0110,  0.1009,  0.1230,
           0.0199, -0.0581, -0.0990, -0.0433],
         [ 0.0775,  0.0821,  0.2429, -0.0842, -0.0789, -0.1661,  0.3810,
          -0.0154, -0.2219,  0.2125,  0.0796, -0.0068, -0.0079, -0.3919,
           0.2658,  0.0450,  0.1625, -0.0895, -0.3173,  0.0266,  0.2420,
           0.0774, -0.1415, -0.0360, -0.1784,  0.0830,  0.1682, -0.0670,
          -0.0937,  0.0481, -0.0330, -0.0315],
         [-0.1755, -0.2030, -0.1351, -0.2008,  0.0319, -0.1532,  0.2315,
           0.1163,  0.0887,  0.3876,  0.0439, -0.0806, -0.0839, -0.2843,
          -0.1508, -0.0357,  0.1185, -0.0181, -0.2770,  0.1237,  0.2894,
          -0.0960, -0.0279,  0.0806,  0.1171, -0.0997, -0.0392,  0.0041,
          -0.2928,  0.0316, -0

In [29]:
#############
# (LM) HEAD #
#############

In [36]:
lm_head = nn.Linear(in_features=embedding_dim, out_features=vocab_size)
logits = lm_head(t_ff)
logits

tensor([[[-1.5506e-02, -6.6762e-02, -1.3066e-01,  6.5135e-02, -8.3698e-02,
          -5.1504e-02, -2.2719e-02,  1.6913e-01,  9.2439e-02,  2.4152e-02,
           1.6284e-01,  1.7046e-01, -4.5121e-02, -2.1280e-01,  9.1107e-02,
           1.7170e-01, -1.5197e-01,  2.9426e-01,  3.6097e-02, -3.9573e-02,
           2.5146e-01, -1.5428e-01, -1.7070e-01,  1.6593e-01,  7.0708e-02,
          -2.3348e-01,  4.0852e-03,  7.5131e-02,  3.8520e-02,  1.1388e-01,
          -1.0648e-01,  8.9193e-02, -9.9233e-02, -2.7843e-02,  2.6298e-01,
          -9.9775e-02,  6.4700e-02, -1.0298e-01,  1.4936e-01, -2.0188e-01,
           4.6250e-02, -8.0165e-02,  1.6471e-01,  9.2624e-02, -1.2229e-01,
           6.5251e-02,  7.2498e-03, -6.6246e-02,  1.7682e-01, -8.3913e-02,
          -6.7174e-02,  2.1854e-01, -1.4159e-01,  3.4919e-02,  5.0368e-02,
          -1.3656e-01,  1.5915e-01,  1.1124e-01,  1.2764e-01,  3.8893e-02,
          -6.7495e-02, -1.9108e-01, -1.3154e-02,  1.1746e-02,  1.5909e-01],
         [-6.5415e-02, -

In [35]:
###########
# Softmax #
###########

In [53]:
probs = F.softmax(logits[:, -1, :], dim=-1)
encoded_token = torch.argmax(probs)
probs, encoded_token
decode([encoded_token.item()])

't'

In [49]:
decode([58])

't'

In [24]:
test = torch.ones(4,8,3)
test

tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]])

In [25]:
torch.sum(test, dim=-1)

tensor([[3., 3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3., 3.]])