<a href="https://colab.research.google.com/github/G0nkly/pytorch_sandbox/blob/main/GPT_dimensions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Lets look at each of the layers
# 1) Encoding
# 2) Embedding
# 3) Positional Encoding
# 4) Attention: key, query, value
# 5) Feedforward
# 6) Block/Layernorm
# 7) Classification / LM Head

In [3]:
################
# DATA EXAMPLE #
################

!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

with open(mode="r", file="input.txt") as f:
  text = f.read()

vocab = list(sorted(set(text)))

--2025-06-17 21:01:50--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-06-17 21:01:50 (18.3 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [4]:
##################
# HYPERPARAMTERS #
##################

In [5]:
vocab_size = len(vocab)
embedding_dim = 32
block_size = 8
n_heads = 4

In [6]:
############
# ENCODING #
############

In [7]:
stoi = {v:k for k,v in enumerate(vocab)}
itos = {k:v for k,v in enumerate(vocab)}
encode = lambda seq: [stoi[char] for char in seq]
decode = lambda numbers: "".join([itos[num] for num in numbers])

def get_batch(split: str):
  dataset = train if split == "train" else val

In [8]:
#############
# EMBEDDING #
#############

In [9]:
embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)

In [10]:
t = torch.tensor(encode("Haubi"))

In [11]:
embedding(t)

tensor([[-7.5366e-01, -1.3355e+00, -9.5343e-01, -1.2788e-01, -5.9261e-01,
         -6.8150e-02,  2.6740e+00, -5.5221e-01,  3.3798e-01,  1.0125e+00,
          5.0238e-01,  1.6657e+00,  8.6034e-01, -3.5557e-01,  1.4691e+00,
          2.4734e-04,  3.9077e-01,  2.7046e-01, -4.7595e-01, -1.3449e+00,
          3.2421e-01,  2.5409e-01,  1.6422e-01, -2.2238e-01,  2.0846e+00,
          9.2306e-01,  1.3367e+00,  1.1470e+00,  9.1751e-01, -6.8443e-01,
         -2.2108e-01,  4.9571e-01],
        [ 1.1345e+00, -2.0509e+00,  1.2738e+00,  2.6450e-01,  9.2227e-01,
          1.3787e+00, -1.6864e-01,  5.2769e-01,  3.8739e-01,  1.0233e-01,
          3.4869e-01,  4.8322e-01, -7.7627e-01, -1.0864e+00, -3.2532e-01,
          1.2600e+00, -7.4000e-01, -6.0799e-01, -7.3361e-01, -2.3849e+00,
         -2.8244e-01,  2.5775e-01, -7.0811e-02, -2.0738e-02, -6.2029e-01,
         -5.9893e-01, -5.6213e-01,  5.3727e-01, -4.8552e-01, -5.8281e-01,
          7.6324e-02,  1.2060e-01],
        [ 4.9665e-01,  5.9644e-01,  3.17

In [12]:
#########################
# POSITIONAL 'ENCODING' #
#########################

In [13]:
postional_embedding = nn.Embedding(num_embeddings=5, embedding_dim=32)

In [14]:
postional_embedding(torch.arange(5)).shape

torch.Size([5, 32])

In [15]:
embedded_tensor = embedding(t) + postional_embedding(torch.arange(5, dtype=torch.long))
embedded_tensor

tensor([[-1.2532e+00, -1.6064e+00,  1.1530e+00,  4.4417e-01, -4.4237e-01,
          1.0037e+00,  2.7590e+00, -4.1486e-01,  3.3742e-01,  1.6255e+00,
          4.4659e-01,  1.7538e-01,  7.9587e-01,  1.0253e+00,  2.0597e+00,
          8.1405e-01, -9.0300e-02, -6.0260e-01,  4.0990e-01, -1.5487e+00,
         -7.1260e-01,  3.9477e-02,  8.1003e-01, -8.8245e-01,  2.6071e+00,
          4.6272e-01,  1.8757e+00,  1.1595e+00, -1.6260e-01, -8.3013e-01,
         -2.1645e-01,  5.1799e-02],
        [-1.1807e+00, -2.1285e+00,  5.0259e-01,  6.6119e-01,  2.4258e+00,
          4.7797e-01, -6.9167e-01,  1.2287e+00, -8.6285e-01, -1.9747e-01,
         -7.3997e-01,  2.3542e-01,  7.5654e-01, -5.6673e-01, -3.3865e-01,
          8.2985e-01, -1.5717e+00, -2.4658e+00,  1.2715e+00, -3.1994e+00,
         -1.2450e+00,  3.6129e-01, -5.4724e-01, -6.9233e-01, -8.9199e-01,
         -1.3559e+00, -7.4846e-01, -9.3394e-01, -1.1047e+00, -1.0573e+00,
          7.3455e-01, -2.7298e-01],
        [-1.3896e-01, -3.6571e-01, -2.07

In [16]:
#########################
# SINGLE HEAD ATTENTION #
#########################

In [17]:
# since its not MultiHeadAttention input_dim == output_dim
block_size = len(t) if t is not None else block_size

key_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
query_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
value_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
tril = torch.tril(torch.ones(block_size, block_size, dtype=torch.long))

In [18]:
k = key_layer(embedded_tensor)
q = query_layer(embedded_tensor)
wei = (q @ k.transpose(-2,-1)) / (embedding_dim ** -0.5)
wei = wei.masked_fill(tril == 0, -float("inf"))
wei = F.softmax(wei, dim=-1)
v = value_layer(embedded_tensor)

out = wei @ v
wei, out

(tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [1.5129e-04, 9.9985e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [3.3234e-02, 9.6677e-01, 5.4140e-10, 0.0000e+00, 0.0000e+00],
         [9.9939e-01, 2.0449e-09, 3.1674e-11, 6.0595e-04, 0.0000e+00],
         [4.0029e-05, 9.5305e-01, 3.8786e-17, 4.6906e-02, 5.5173e-11]],
        grad_fn=<SoftmaxBackward0>),
 tensor([[ 0.0761, -0.0034, -0.3397, -0.7339, -0.4812,  0.0147, -0.4112,  0.2808,
           0.3040, -0.2096,  0.7364, -0.0332, -0.5945, -0.0239, -0.1495, -0.3491,
          -0.0406,  0.9874, -0.6307,  0.6437, -0.3366,  0.1362, -0.4442, -0.0879,
           1.2862,  0.9153, -0.9047,  0.2046, -0.2290, -0.6694,  0.8536, -0.1262],
         [ 1.2878, -0.4434, -0.4567,  0.6370,  0.3212,  0.5554,  0.4604,  0.2023,
          -0.0785, -0.1363,  0.6354, -1.0524, -0.4914, -1.0078,  1.2723,  0.0829,
          -0.5597,  0.4052,  0.1319,  0.8915, -0.2411, -0.7821, -0.4712,  0.4452,
          -0.2944, -1.2060, -1.1

In [19]:
########################
# MULTI HEAD ATTENTION #
########################

In [20]:
class AttentionHead(nn.Module):

  def __init__(self, embedding_dim, head_size):
    super().__init__()
    self.embedding_dim = embedding_dim
    self.query_layer = nn.Linear(in_features=embedding_dim, out_features=head_size)
    self.key_layer = nn.Linear(in_features=embedding_dim, out_features=head_size)
    self.value_layer = nn.Linear(in_features=embedding_dim, out_features=head_size)
    self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size, dtype=torch.long)))

  def forward(self, x):
    B, T, C = x.shape
    q = self.query_layer(x)
    k = self.key_layer(x)
    v = self.value_layer(x)
    wei = (q @ k.transpose(-2,-1)) / C ** -0.5
    wei = wei.masked_fill(self.tril == 0, float("-inf"))
    wei = F.softmax(wei, dim=-1)
    out = wei @ v

    return out

In [21]:
class MultiHeadAttention(nn.Module):

  def __init__(self, embedding_dim, n_heads):
    super().__init__()
    head_size = embedding_dim // n_heads
    self.heads = nn.ModuleList([AttentionHead(embedding_dim=embedding_dim, head_size=head_size) for _ in range(n_heads)])
    self.projection_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)

  def forward(self, x):
    x = torch.cat([head(x) for head in self.heads], dim=-1)
    out = self.projection_layer(x)
    return out

In [22]:
embedded_tensor = embedded_tensor.unsqueeze(0)
embedded_tensor.shape

torch.Size([1, 5, 32])

In [23]:
multi_head_attention = MultiHeadAttention(embedding_dim=embedding_dim, n_heads=n_heads)
t_ma = multi_head_attention(embedded_tensor)
t_ma

tensor([[[ 0.6135,  0.4676, -0.8231, -0.3076,  0.2564,  0.0790,  0.4323,
          -0.9637,  0.1887, -0.2450, -0.0472,  0.4426, -0.0499,  0.0243,
          -0.0822, -0.1053, -0.1355, -0.1296, -0.1779,  0.2735,  0.0200,
          -0.0737, -0.6891, -0.1153,  0.0649, -0.3029, -0.0251, -0.2023,
           0.0402,  0.2548,  0.0284,  0.4415],
         [ 0.2549,  0.2140, -0.3102, -0.0268,  0.6124,  0.0846,  0.2621,
          -0.5599,  0.4099, -0.4623,  0.1836,  0.2203, -0.0319,  0.0119,
          -0.1157,  0.0531, -0.3506, -0.0069,  0.1406, -0.0348, -0.0272,
          -0.2335, -0.5661, -0.1182,  0.1854, -0.6343, -0.2388, -0.3064,
           0.2775, -0.4665,  0.1687, -0.1624],
         [ 0.4415, -0.3585, -0.4773, -0.0803, -0.1515, -0.3374, -0.1755,
           0.1813, -0.2888, -0.3172,  0.2831, -0.0774,  0.2015,  0.8621,
           0.3630,  0.0156, -0.1749, -0.7266,  0.8180, -0.1115,  0.0387,
          -0.3817, -0.0071,  0.0643, -0.2760, -0.2032, -0.1750, -0.1988,
           0.3453,  0.0054,  0

In [None]:
########################
# BLOCK AND LAYER-NORM #
########################

In [24]:
################
# FEED FORWARD #
################

In [25]:
feed_forward_layer = nn.Sequential(
    nn.Linear(32, 4 * 32),
    nn.Linear(4 * 32, 32)
)

t_ff = feed_forward_layer(t_ma)
t_ff

tensor([[[ 0.0322, -0.0222, -0.1197, -0.1735, -0.0272, -0.1213,  0.0260,
           0.0487, -0.0979, -0.1702, -0.0087,  0.0167,  0.3371,  0.1032,
          -0.1167, -0.1177,  0.0526,  0.2495,  0.0627,  0.0480,  0.1684,
           0.0268,  0.1032, -0.0744, -0.2038,  0.0711, -0.0529, -0.2110,
          -0.0500, -0.0478, -0.0237,  0.0801],
         [ 0.0509, -0.1033, -0.1740, -0.1748, -0.0296, -0.1390,  0.0074,
          -0.0642, -0.1643, -0.0703,  0.0762, -0.0052,  0.2806, -0.0990,
          -0.1343, -0.1219, -0.0541,  0.1315, -0.0105,  0.0103,  0.0883,
           0.1194, -0.0304, -0.1036, -0.2488,  0.0967, -0.0571, -0.1384,
          -0.0286, -0.0931, -0.0866, -0.1186],
         [-0.1277,  0.0565, -0.1293,  0.0694,  0.2484, -0.1059, -0.1790,
          -0.2926, -0.3403, -0.2118,  0.1978, -0.1498,  0.0350, -0.1406,
          -0.0943, -0.1315,  0.0386,  0.1323,  0.2245,  0.1076,  0.1844,
           0.1840,  0.0637, -0.0875,  0.0692, -0.0242, -0.0996, -0.0533,
          -0.1709, -0.2883,  0

In [26]:
#############
# (LM) HEAD #
#############

In [27]:
lm_head = nn.Linear(in_features=embedding_dim, out_features=vocab_size)
logits = lm_head(t_ff)
logits

tensor([[[-1.7702e-01, -2.1801e-02,  3.5937e-02,  1.4307e-02, -2.9243e-01,
          -1.9164e-02, -2.1975e-02,  8.3152e-02, -4.2611e-02, -5.8562e-02,
          -4.6714e-03,  1.0765e-01,  7.3775e-02,  2.0142e-01,  4.2462e-02,
          -1.0779e-02, -1.0092e-03, -1.0248e-01, -4.8939e-02,  8.8126e-02,
           2.2856e-01, -5.8215e-02, -2.9996e-02,  2.4623e-01,  1.6888e-01,
           2.1802e-01,  2.6469e-01, -2.2612e-02,  1.0754e-01, -1.0461e-01,
          -7.1441e-02,  2.0041e-01,  6.6664e-02,  7.8956e-03, -1.6338e-01,
           1.3341e-01, -4.6898e-02,  1.1171e-01, -1.1750e-01, -8.0280e-02,
           1.8590e-01, -1.7282e-03,  7.8408e-02, -5.0425e-03,  2.1106e-02,
           5.2028e-02,  1.8085e-01,  1.7908e-01, -1.3047e-01,  1.0643e-01,
          -2.6701e-01, -1.3622e-01,  3.9973e-02,  5.5455e-02,  1.1953e-01,
           9.2822e-02, -2.7424e-02,  8.1813e-02, -1.7213e-01, -4.9791e-02,
          -1.2629e-01, -8.1584e-02,  9.6156e-02,  7.6072e-02,  9.8191e-02],
         [-9.2503e-02, -

In [28]:
###########
# Softmax #
###########

In [29]:
probs = F.softmax(logits[:, -1, :], dim=-1)
encoded_token = torch.argmax(probs)
probs, encoded_token
decode([encoded_token.item()])

'b'

In [30]:
decode([58])

't'

In [31]:
test = torch.ones(4,8,3)
test

tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]])

In [32]:
torch.sum(test, dim=-1)

tensor([[3., 3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3., 3.]])