<a href="https://colab.research.google.com/github/G0nkly/pytorch_sandbox/blob/main/GPT_dimensions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [73]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [74]:
# Lets look at each of the layers
# 1) Encoding
# 2) Embedding
# 3) Positional Encoding
# 4) Attention: key, query, value
# 5) Feedforward
# 6) Block/Layernorm
# 7) Classification / LM Head

In [75]:
################
# DATA EXAMPLE #
################

!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

with open(mode="r", file="input.txt") as f:
  text = f.read()

vocab = list(sorted(set(text)))

--2025-06-18 22:48:18--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.3’


2025-06-18 22:48:18 (17.4 MB/s) - ‘input.txt.3’ saved [1115394/1115394]



In [76]:
##################
# HYPERPARAMTERS #
##################

In [77]:
vocab_size = len(vocab)
embedding_dim = 32
block_size = 8
n_heads = 4

In [78]:
############
# ENCODING #
############

In [79]:
stoi = {v:k for k,v in enumerate(vocab)}
itos = {k:v for k,v in enumerate(vocab)}
encode = lambda seq: [stoi[char] for char in seq]
decode = lambda numbers: "".join([itos[num] for num in numbers])

def get_batch(split: str):
  dataset = train if split == "train" else val

In [80]:
#############
# EMBEDDING #
#############

In [81]:
embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)

In [82]:
t = torch.tensor(encode("Haubi"))

In [83]:
embedding(t)

tensor([[ 0.9655,  2.1420,  0.1026, -0.1567,  0.2709, -0.0079,  0.5604, -1.3238,
          0.6949, -0.6346, -0.4504,  0.3740,  0.3065,  0.6298,  0.2285,  0.3683,
          0.2725, -0.6920,  0.1760,  0.1325, -0.0071,  0.9312, -1.2538,  0.7064,
          0.7876, -0.0338, -0.1693,  1.4249, -0.5741, -1.2141, -0.7215, -1.2245],
        [ 0.3530, -1.6809,  1.4108,  1.1818,  1.0998, -0.9324, -0.8904, -0.4201,
         -1.1354,  1.4010, -1.3822,  0.9621,  1.0063, -1.5000, -0.4825, -0.3301,
          0.5177, -1.4639,  1.0956,  1.1664,  1.4591,  2.3299, -1.0775,  0.1901,
          0.4501,  0.3235, -0.3753,  0.0261,  0.7775, -0.3224, -0.2147,  0.0438],
        [ 1.2298,  0.4247,  0.4889, -0.6393, -2.7005, -0.2637, -0.8545, -0.5603,
         -0.0794,  0.6169,  0.2654, -0.0371,  0.4269,  0.9110, -0.1900,  0.2717,
         -0.2002,  0.7572, -0.0031,  2.3400,  0.1033, -0.1757, -1.5082, -0.6113,
          0.1707, -0.9771,  0.4434,  0.1899, -0.7733,  0.9318, -1.2209,  0.2367],
        [ 0.0068,  0.7628

In [84]:
#########################
# POSITIONAL 'ENCODING' #
#########################

In [85]:
postional_embedding = nn.Embedding(num_embeddings=5, embedding_dim=32)

In [86]:
postional_embedding(torch.arange(5)).shape

torch.Size([5, 32])

In [87]:
embedded_tensor = embedding(t) + postional_embedding(torch.arange(5, dtype=torch.long))
embedded_tensor

tensor([[-0.0962,  1.1098, -0.1119, -0.8120, -0.8776, -1.4015,  1.3507, -0.2814,
          1.4955, -0.1668, -1.4774,  0.2952, -0.3414, -0.3457,  1.9336,  0.8511,
          0.7062, -0.9392,  0.4825,  0.9020,  2.0514, -0.0618,  0.4034,  1.4067,
          0.6454, -1.2013,  0.8931,  1.6418, -0.5308, -0.2597, -0.2019, -2.2324],
        [ 2.1937, -2.6700,  0.7154,  1.7961,  0.5056, -0.6785, -0.5441, -0.4413,
          0.2473,  1.4805, -1.9954, -0.9950,  0.6226,  0.4948, -0.9403, -1.8405,
         -1.6159, -1.7384,  0.8802,  1.6368,  1.9047,  1.4315, -0.1333,  1.2675,
         -1.6029, -0.6302, -1.3378, -0.3097, -0.7311, -0.2391,  0.4505,  1.4686],
        [ 1.9625, -0.0598,  0.5681, -1.0882, -3.1110,  0.7178, -1.7550, -2.1459,
          0.2455, -1.6774,  0.1030, -0.1173,  0.0333,  1.5377,  0.8426,  1.1725,
          0.9124, -0.6408,  0.9636,  2.1868,  0.4553,  1.2068, -1.7151,  0.1319,
          1.7400, -1.0255, -0.2647,  1.5574, -1.7029,  1.1208,  0.1079, -1.2516],
        [-1.8695,  1.2554

In [88]:
#########################
# SINGLE HEAD ATTENTION #
#########################

In [89]:
# since its not MultiHeadAttention input_dim == output_dim
block_size = len(t) if t is not None else block_size

key_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
query_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
value_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
tril = torch.tril(torch.ones(block_size, block_size, dtype=torch.long))

In [90]:
k = key_layer(embedded_tensor)
q = query_layer(embedded_tensor)
wei = (q @ k.transpose(-2,-1)) / (embedding_dim ** -0.5)
wei = wei.masked_fill(tril == 0, -float("inf"))
wei = F.softmax(wei, dim=-1)
v = value_layer(embedded_tensor)

out = wei @ v
wei, out

(tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [1.0000e+00, 1.2141e-11, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [2.0567e-05, 9.9998e-01, 4.6030e-07, 0.0000e+00, 0.0000e+00],
         [1.6771e-14, 1.0000e+00, 6.4606e-17, 9.8931e-20, 0.0000e+00],
         [3.6904e-13, 1.8030e-06, 1.3597e-26, 6.5784e-08, 1.0000e+00]],
        grad_fn=<SoftmaxBackward0>),
 tensor([[-1.0764,  0.5365,  0.1440,  0.5200, -0.3727, -0.3480, -0.6684, -1.0946,
           0.1395,  0.3309, -0.4356,  0.4809, -0.5565,  0.3512,  0.1686,  0.4983,
           0.6177,  0.5295, -0.7163,  0.1302, -0.9579, -0.2277, -0.5855, -0.4906,
          -0.8270, -0.1705, -1.7216,  0.3792, -0.0156, -0.1709, -0.0939, -0.1256],
         [-1.0764,  0.5365,  0.1440,  0.5200, -0.3727, -0.3480, -0.6684, -1.0946,
           0.1395,  0.3309, -0.4356,  0.4809, -0.5565,  0.3512,  0.1686,  0.4983,
           0.6177,  0.5295, -0.7163,  0.1302, -0.9579, -0.2277, -0.5855, -0.4906,
          -0.8270, -0.1705, -1.7

In [91]:
########################
# MULTI HEAD ATTENTION #
########################

In [92]:
class AttentionHead(nn.Module):

  def __init__(self, embedding_dim, head_size):
    super().__init__()
    self.embedding_dim = embedding_dim
    self.query_layer = nn.Linear(in_features=embedding_dim, out_features=head_size)
    self.key_layer = nn.Linear(in_features=embedding_dim, out_features=head_size)
    self.value_layer = nn.Linear(in_features=embedding_dim, out_features=head_size)
    self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size, dtype=torch.long)))

  def forward(self, x):
    B, T, C = x.shape
    q = self.query_layer(x)
    k = self.key_layer(x)
    v = self.value_layer(x)
    wei = (q @ k.transpose(-2,-1)) / C ** -0.5
    wei = wei.masked_fill(self.tril == 0, float("-inf"))
    wei = F.softmax(wei, dim=-1)
    out = wei @ v

    return out

In [93]:
class MultiHeadAttention(nn.Module):

  def __init__(self, embedding_dim, n_heads):
    super().__init__()
    head_size = embedding_dim // n_heads
    self.heads = nn.ModuleList([AttentionHead(embedding_dim=embedding_dim, head_size=head_size) for _ in range(n_heads)])
    self.projection_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)

  def forward(self, x):
    x = torch.cat([head(x) for head in self.heads], dim=-1)
    out = self.projection_layer(x)
    return out

In [94]:
embedded_tensor = embedded_tensor.unsqueeze(0)
embedded_tensor.shape

torch.Size([1, 5, 32])

In [95]:
multi_head_attention = MultiHeadAttention(embedding_dim=embedding_dim, n_heads=n_heads)
t_ma = multi_head_attention(embedded_tensor)
t_ma

tensor([[[ 3.1181e-01,  1.2491e-01, -1.5138e-01,  4.1956e-01, -5.3279e-01,
           1.7668e-01, -7.4490e-02, -6.5640e-02,  4.7382e-02, -1.0647e+00,
           3.6990e-01, -4.2196e-01,  8.8334e-02,  3.0116e-01, -4.6426e-02,
          -3.9286e-02,  4.6723e-02,  2.9490e-01, -2.1028e-01, -2.8957e-01,
          -8.2012e-02,  4.2746e-01, -4.5696e-01,  2.1170e-01,  1.4161e-03,
           2.0678e-01,  4.9818e-02,  5.9980e-01,  1.3458e-01,  4.1500e-01,
          -6.1290e-01, -3.8942e-02],
         [ 1.0971e-01,  1.6192e-02, -1.7401e-01,  3.2227e-01, -3.1322e-01,
           6.0930e-01, -1.6385e-01, -1.5799e-01,  1.7411e-01, -7.8808e-01,
           5.8384e-01, -4.1368e-01, -1.1081e-01,  4.9320e-01,  1.5624e-01,
           1.4238e-01, -4.6980e-02,  2.7358e-01,  1.0587e-01, -4.0231e-01,
          -1.5190e-01,  2.2247e-01, -3.7877e-01,  2.3561e-01,  4.1086e-02,
           1.5494e-01,  8.9608e-02,  4.6367e-01,  1.3130e-01,  3.4374e-01,
          -4.2633e-01, -3.9975e-02],
         [ 8.2690e-02,  5.

In [96]:
################
# FEED-FORWARD #
################

In [97]:
feed_forward_layer = nn.Sequential(
    nn.Linear(32, 4 * 32),
    nn.Linear(4 * 32, 32)
)

t_ff = feed_forward_layer(t_ma)
t_ff

tensor([[[ 0.1855,  0.1406, -0.0098,  0.1068, -0.0608,  0.1157,  0.0850,
          -0.1138, -0.1459, -0.0182,  0.1400,  0.1149,  0.2611, -0.0882,
          -0.2060,  0.1214, -0.2081,  0.1017, -0.1771, -0.0876,  0.2088,
          -0.0775, -0.1045, -0.0029,  0.1517, -0.0692, -0.0889, -0.1469,
          -0.2707, -0.0705,  0.1229,  0.0825],
         [ 0.1039,  0.1608,  0.1023,  0.0901, -0.0978,  0.1811, -0.0161,
          -0.0798, -0.0830,  0.0105,  0.1548,  0.0488,  0.2979, -0.1161,
          -0.1387,  0.1327, -0.2271,  0.1563, -0.1706, -0.0578,  0.1238,
          -0.0939, -0.0981, -0.0268,  0.1190, -0.0958,  0.0378, -0.1055,
          -0.2925, -0.0595,  0.1302,  0.1399],
         [ 0.0218,  0.1157, -0.0471,  0.2463, -0.1132, -0.1050,  0.3254,
          -0.1132, -0.0178, -0.2035,  0.1378,  0.0899,  0.3808,  0.1055,
          -0.0078, -0.0925, -0.2688,  0.1456, -0.2976, -0.2060,  0.4078,
           0.0199, -0.4583,  0.0058,  0.2251,  0.0942, -0.1280, -0.2469,
          -0.1817, -0.1317,  0

In [98]:
########################
# BLOCK AND LAYER-NORM #
########################

In [99]:
class AttentionBlock(nn.Module):

  def __init__(self, embedding_dim, n_heads):
    super().__init__()
    self.multi_head_attention = MultiHeadAttention(embedding_dim=embedding_dim, n_heads=n_heads)
    self.feed_forward = nn.Sequential(
        nn.Linear(in_features=embedding_dim, out_features=4 * embedding_dim),
        nn.Linear(4 * embedding_dim, embedding_dim)
    )
    self.ln_1 = nn.LayerNorm(embedding_dim)
    self.ln_2 = nn.LayerNorm(embedding_dim)

  def forward(self, x):
    x = x + self.multi_head_attention(self.ln_1(x))
    x = x + self.feed_forward(self.ln_2(x))
    return x

In [100]:
attention_block = AttentionBlock(embedding_dim=embedding_dim, n_heads=n_heads)
t_block_2 = attention_block(t_ff)
t_block_2

tensor([[[-0.3593,  0.4308, -0.0341, -0.0758, -0.6614,  0.4628,  0.6237,
          -0.6133, -0.8757,  0.5346,  0.0156,  0.2852,  0.2849,  0.0760,
          -0.6488,  0.4055, -0.3469, -0.0933, -0.2119, -0.2616,  0.0562,
           0.4644, -0.1922,  0.9401,  0.7545, -0.6049, -0.7730, -0.0098,
           0.2786,  0.2299, -0.2957,  0.8753],
         [-0.4488,  0.4442,  0.1905, -0.6152, -0.3231,  0.5525,  0.5207,
          -0.4730, -0.8173,  0.5421, -0.0208,  0.3799,  0.2891, -0.0610,
          -0.4813,  0.3213, -0.2833,  0.1647, -0.2629, -0.4480,  0.0668,
           0.2828, -0.2775,  0.5457,  0.5247, -0.5027, -0.5147, -0.0790,
           0.3471,  0.3628, -0.3878,  0.9191],
         [-0.7207,  0.4525, -0.2702,  0.2677, -0.2619,  0.2868,  0.9852,
          -0.0617, -1.0268,  0.4289, -0.0144,  0.4526,  0.4052,  0.3802,
           0.0147,  0.4973, -0.0544,  0.3741,  0.0168, -0.4563,  0.0778,
           0.2597, -0.9783,  0.5024,  0.8982,  0.0421, -0.8502,  0.2195,
           0.3973,  0.1194, -0

In [101]:
#############
# (LM) HEAD #
#############

In [102]:
lm_head = nn.Linear(in_features=embedding_dim, out_features=vocab_size)
logits = lm_head(t_block_2)
logits

tensor([[[-0.0692, -0.0310, -0.0866, -0.2803, -0.4678, -0.5694,  0.1613,
          -0.1169,  0.1581, -0.3940,  0.0285,  0.5748,  0.3131, -0.6899,
           0.0111,  0.2994,  0.0951, -0.4830, -0.0173, -0.3108, -0.2294,
          -0.1531,  0.0642, -0.0333, -0.0726, -0.6515,  0.3778, -0.0038,
          -0.2874,  0.5143,  0.1092, -0.0300,  0.0368,  0.6291, -0.1461,
          -0.2808,  0.1075,  0.2912, -0.5865,  0.1630,  0.0140, -0.0544,
          -0.4655, -0.3081, -0.1966, -0.0066,  0.0783, -0.4945, -0.1965,
          -0.4655,  0.1096, -0.5573, -0.0839,  0.3681,  0.2428, -0.1737,
           0.0528, -0.4109, -0.1231, -0.3206, -0.1888,  0.0827,  0.0030,
           0.2019, -0.1586],
         [ 0.0247, -0.0553, -0.1140, -0.2454, -0.3870, -0.5276,  0.0851,
          -0.0589,  0.2005, -0.3367,  0.0100,  0.7434,  0.3223, -0.5889,
          -0.0238,  0.3554,  0.1088, -0.3334, -0.0434, -0.3590, -0.2698,
          -0.2529, -0.0556,  0.1088, -0.1715, -0.5449,  0.2591,  0.1274,
          -0.1987,  0.

In [103]:
###########
# Softmax #
###########

In [104]:
probs = F.softmax(logits[:, -1, :], dim=-1)
encoded_token = torch.argmax(probs)
probs, encoded_token
decode([encoded_token.item()])

'J'

In [105]:
test = torch.ones(4,8,3)
test

tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]])

In [106]:
torch.sum(test, dim=-1)

tensor([[3., 3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3., 3.]])