<a href="https://colab.research.google.com/github/G0nkly/pytorch_sandbox/blob/main/GPT_dimensions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [69]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [70]:
# Lets look at each of the layers
# 1) Encoding
# 2) Embedding
# 3) Positional Encoding
# 4) Attention: key, query, value
# 5) Feedforward
# 6) Block/Layernorm
# 7) Classification / LM Head

In [71]:
################
# DATA EXAMPLE #
################

!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

with open(mode="r", file="input.txt") as f:
  text = f.read()

vocab = list(sorted(set(text)))

--2025-06-18 21:59:00--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.2’


2025-06-18 21:59:00 (18.0 MB/s) - ‘input.txt.2’ saved [1115394/1115394]



In [72]:
##################
# HYPERPARAMTERS #
##################

In [39]:
vocab_size = len(vocab)
embedding_dim = 32
block_size = 8
n_heads = 4

In [40]:
############
# ENCODING #
############

In [41]:
stoi = {v:k for k,v in enumerate(vocab)}
itos = {k:v for k,v in enumerate(vocab)}
encode = lambda seq: [stoi[char] for char in seq]
decode = lambda numbers: "".join([itos[num] for num in numbers])

def get_batch(split: str):
  dataset = train if split == "train" else val

In [42]:
#############
# EMBEDDING #
#############

In [43]:
embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)

In [44]:
t = torch.tensor(encode("Haubi"))

In [45]:
embedding(t)

tensor([[ 2.5800, -0.0949,  0.7592, -0.3650,  1.3564, -0.8533,  0.5991,  1.1836,
          1.7810,  1.2654,  2.4589,  1.5647,  0.1009, -0.3783,  0.6628, -0.0811,
         -0.3990,  0.3540, -0.8643,  0.9216, -1.4456, -0.1202, -0.7636,  0.7841,
         -0.0887, -0.1448, -0.1449, -0.2748, -1.0291, -0.1812,  0.2785, -2.5819],
        [ 0.8975, -1.0679,  0.0434,  0.5751, -0.5031, -0.9826, -1.0476,  0.1031,
          0.0763,  1.2359, -0.6472,  0.7668, -0.2575, -0.6466,  0.3242,  0.1444,
         -2.3190,  0.5099,  1.3419,  0.1118,  0.5495, -0.0386,  0.2451,  0.9431,
          0.3770,  1.8286, -0.7861,  0.3404,  0.4163,  0.3839, -0.6465,  0.4309],
        [ 0.3578, -0.7341, -0.3830, -0.1290, -2.2323, -1.4891,  0.1240, -1.2331,
         -1.1088,  1.0324, -0.3613, -1.8569, -1.6749,  0.8790, -0.5744,  0.3798,
         -0.9525,  1.1148,  0.2572, -0.1452, -0.4806, -0.6876,  1.1803, -0.8684,
          0.9645,  1.1917, -1.2154,  0.1223, -0.0257, -0.7730, -1.7858, -0.9728],
        [-1.0123,  0.2110

In [46]:
#########################
# POSITIONAL 'ENCODING' #
#########################

In [47]:
postional_embedding = nn.Embedding(num_embeddings=5, embedding_dim=32)

In [48]:
postional_embedding(torch.arange(5)).shape

torch.Size([5, 32])

In [49]:
embedded_tensor = embedding(t) + postional_embedding(torch.arange(5, dtype=torch.long))
embedded_tensor

tensor([[ 4.4155, -0.2162,  0.0065, -1.7307,  1.6822, -0.4931,  1.4304,  1.0572,
          3.9727,  0.7057,  1.8026,  1.3689,  0.3920, -0.1285,  2.1611, -0.8483,
          1.0468,  1.0871, -0.8970,  1.0322, -1.9415, -2.1131, -1.2541,  0.5004,
         -0.4549, -0.2385, -0.5177,  0.1364, -0.9234, -0.0941,  2.1807, -2.1845],
        [ 2.8370, -0.9739,  1.0963,  0.9426, -1.3841,  0.2568, -0.8726, -1.0413,
         -0.4496,  2.0147, -1.6400,  0.6019, -0.3357, -2.0018,  0.2246,  1.3742,
         -0.3641,  0.2372,  0.7284,  1.2351,  0.1472, -0.3036,  0.8625, -0.4231,
         -1.2266,  2.6165, -0.8946, -0.1381, -0.1918,  0.6834, -0.9202, -0.5171],
        [ 0.8210, -0.7542, -2.1300,  0.9705, -0.4383, -1.9120, -1.9551, -1.0408,
         -0.4011,  2.3656, -0.4639, -1.8330, -1.5049, -0.5385, -0.5345,  1.8234,
         -0.0545,  1.8965,  0.3158, -1.3977,  0.4461, -2.2009,  0.7559, -0.4497,
          2.6115,  0.7996, -1.5283,  0.6124, -0.6363,  0.6133, -2.0057, -1.5494],
        [-0.0797, -0.1761

In [50]:
#########################
# SINGLE HEAD ATTENTION #
#########################

In [51]:
# since its not MultiHeadAttention input_dim == output_dim
block_size = len(t) if t is not None else block_size

key_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
query_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
value_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
tril = torch.tril(torch.ones(block_size, block_size, dtype=torch.long))

In [52]:
k = key_layer(embedded_tensor)
q = query_layer(embedded_tensor)
wei = (q @ k.transpose(-2,-1)) / (embedding_dim ** -0.5)
wei = wei.masked_fill(tril == 0, -float("inf"))
wei = F.softmax(wei, dim=-1)
v = value_layer(embedded_tensor)

out = wei @ v
wei, out

(tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [1.0000e+00, 1.9345e-14, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [4.3111e-18, 4.2640e-03, 9.9574e-01, 0.0000e+00, 0.0000e+00],
         [2.5450e-09, 3.5147e-10, 4.9606e-08, 1.0000e+00, 0.0000e+00],
         [9.4689e-01, 5.1852e-02, 2.2750e-10, 1.2546e-03, 6.3807e-17]],
        grad_fn=<SoftmaxBackward0>),
 tensor([[-0.6631, -0.8471,  0.3540, -0.3995,  0.4347, -0.8918,  0.2366,  1.5138,
           0.5878,  1.5658, -0.8060,  0.0741,  0.1215, -1.0118,  0.3655,  0.9200,
          -0.3271, -0.1265,  0.1571, -1.8154, -1.6308, -0.6471,  0.9462,  0.6098,
           1.0935, -0.8272,  0.3581,  0.2497,  0.1220, -0.7735,  0.3773, -0.3455],
         [-0.6631, -0.8471,  0.3540, -0.3995,  0.4347, -0.8918,  0.2366,  1.5138,
           0.5878,  1.5658, -0.8060,  0.0741,  0.1215, -1.0118,  0.3655,  0.9200,
          -0.3271, -0.1265,  0.1571, -1.8154, -1.6308, -0.6471,  0.9462,  0.6098,
           1.0935, -0.8272,  0.3

In [53]:
########################
# MULTI HEAD ATTENTION #
########################

In [54]:
class AttentionHead(nn.Module):

  def __init__(self, embedding_dim, head_size):
    super().__init__()
    self.embedding_dim = embedding_dim
    self.query_layer = nn.Linear(in_features=embedding_dim, out_features=head_size)
    self.key_layer = nn.Linear(in_features=embedding_dim, out_features=head_size)
    self.value_layer = nn.Linear(in_features=embedding_dim, out_features=head_size)
    self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size, dtype=torch.long)))

  def forward(self, x):
    B, T, C = x.shape
    q = self.query_layer(x)
    k = self.key_layer(x)
    v = self.value_layer(x)
    wei = (q @ k.transpose(-2,-1)) / C ** -0.5
    wei = wei.masked_fill(self.tril == 0, float("-inf"))
    wei = F.softmax(wei, dim=-1)
    out = wei @ v

    return out

In [55]:
class MultiHeadAttention(nn.Module):

  def __init__(self, embedding_dim, n_heads):
    super().__init__()
    head_size = embedding_dim // n_heads
    self.heads = nn.ModuleList([AttentionHead(embedding_dim=embedding_dim, head_size=head_size) for _ in range(n_heads)])
    self.projection_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)

  def forward(self, x):
    x = torch.cat([head(x) for head in self.heads], dim=-1)
    out = self.projection_layer(x)
    return out

In [56]:
embedded_tensor = embedded_tensor.unsqueeze(0)
embedded_tensor.shape

torch.Size([1, 5, 32])

In [57]:
multi_head_attention = MultiHeadAttention(embedding_dim=embedding_dim, n_heads=n_heads)
t_ma = multi_head_attention(embedded_tensor)
t_ma

tensor([[[-1.0545, -0.8248, -1.2653,  0.3533, -0.3357,  1.2724,  1.2836,
          -0.3933, -0.5026,  0.0551,  0.1272,  0.6789,  0.0463, -0.2834,
           0.4041,  0.7685,  0.2765,  0.6673, -0.0652, -0.4075, -0.0361,
           0.6307, -0.2322, -0.1729, -0.2958,  0.4914, -0.3151, -0.2354,
          -1.1973, -0.3631,  0.1999, -0.7662],
         [-0.5489, -0.2506, -1.0980,  0.1862, -0.6716,  1.2653,  1.0010,
          -0.3168, -0.3356, -0.2549, -0.0885, -0.0576, -0.0037, -0.2394,
           0.4813,  0.6366,  0.0957, -0.3787,  0.2402, -0.1543, -0.2183,
           1.1449, -0.2358,  1.0939,  0.3573,  0.9340,  0.2921,  0.0878,
          -1.2566, -0.1301, -0.6037,  0.0780],
         [-0.5327, -0.2657, -1.2460,  0.0581, -0.1193,  1.0900,  1.2358,
          -0.1291, -0.4546, -0.0655,  0.1892,  0.0221,  0.3257, -0.7979,
          -0.0376,  0.6193, -0.0521,  0.2909,  0.3014,  0.0581, -0.5964,
           1.3661, -0.2233,  0.8258,  0.3469,  0.7413,  0.5981, -0.3863,
          -1.0663, -0.3041, -0

In [58]:
################
# FEED-FORWARD #
################

In [59]:
feed_forward_layer = nn.Sequential(
    nn.Linear(32, 4 * 32),
    nn.Linear(4 * 32, 32)
)

t_ff = feed_forward_layer(t_ma)
t_ff

tensor([[[-0.0211,  0.0744,  0.1264,  0.1699, -0.2576, -0.0369,  0.1442,
           0.1003,  0.1561, -0.0962,  0.3293,  0.0320,  0.1998, -0.2574,
           0.0230,  0.2018,  0.1788,  0.1161,  0.0851,  0.2169, -0.1008,
          -0.1533, -0.0895, -0.1598, -0.2883,  0.1405, -0.0723,  0.4079,
          -0.3392, -0.0624,  0.2001, -0.3019],
         [-0.1276,  0.2321, -0.0021, -0.0290, -0.3454,  0.3835,  0.1918,
          -0.2028,  0.1626, -0.2369,  0.1337, -0.0267,  0.2928, -0.1264,
           0.2216,  0.1580, -0.0204,  0.1128, -0.1065,  0.2821, -0.2652,
           0.1285, -0.0313, -0.3092, -0.2041,  0.2325,  0.0369,  0.3849,
          -0.1925, -0.0475,  0.1249, -0.3190],
         [-0.1152,  0.0384,  0.0619,  0.0358, -0.3171,  0.3231,  0.0837,
          -0.1357,  0.0153, -0.1977,  0.0671, -0.0552,  0.2185, -0.2442,
           0.0798,  0.1346, -0.0491,  0.2840, -0.0619,  0.2442, -0.2030,
          -0.0539, -0.0870, -0.2087, -0.3249,  0.2875, -0.0526,  0.4013,
          -0.2529, -0.0071,  0

In [60]:
########################
# BLOCK AND LAYER-NORM #
########################

In [63]:
class AttentionBlock(nn.Module):

  def __init__(self, embedding_dim, n_heads):
    super().__init__()
    self.multi_head_attention = MultiHeadAttention(embedding_dim=embedding_dim, n_heads=n_heads)
    self.feed_forward = nn.Sequential(
        nn.Linear(in_features=embedding_dim, out_features=4 * embedding_dim),
        nn.Linear(4 * embedding_dim, embedding_dim)
    )

  def forward(self, x):
    x = self.multi_head_attention(x)
    x = self.feed_forward(x)
    return x

In [64]:
attention_block = AttentionBlock(embedding_dim=embedding_dim, n_heads=n_heads)
t_block_2 = attention_block(t_ff)
t_block_2

tensor([[[ 1.3628e-01, -1.1693e-01, -1.4897e-02,  1.0651e-02,  4.1129e-02,
           1.2110e-01,  7.0368e-02, -3.7389e-02,  1.1911e-01, -6.5945e-02,
           3.6065e-04, -7.6318e-02,  1.4899e-01, -4.1852e-02,  5.1185e-02,
          -6.3043e-02, -2.6743e-02, -4.6493e-02,  1.7397e-01,  1.6558e-01,
          -2.9394e-02,  1.1385e-01, -6.5799e-02,  1.6398e-01, -8.4144e-02,
          -1.0490e-01, -3.0380e-02,  3.7267e-03,  1.3677e-01,  1.5944e-01,
          -3.5852e-02,  1.4431e-01],
         [ 1.2629e-01, -1.0874e-01, -1.4859e-02,  5.1874e-03,  2.8196e-02,
           1.1750e-01,  6.6046e-02, -4.3696e-02,  1.1677e-01, -6.5234e-02,
           6.1439e-03, -8.1933e-02,  1.5270e-01, -3.5471e-02,  4.9201e-02,
          -4.9665e-02, -1.9192e-02, -4.9481e-02,  1.5990e-01,  1.6782e-01,
          -2.6687e-02,  1.0790e-01, -6.1787e-02,  1.5872e-01, -7.1192e-02,
          -1.0990e-01, -2.3950e-02, -1.9822e-05,  1.4497e-01,  1.4763e-01,
          -4.0706e-02,  1.4467e-01],
         [ 1.2695e-01, -1.

In [27]:
#############
# (LM) HEAD #
#############

In [65]:
lm_head = nn.Linear(in_features=embedding_dim, out_features=vocab_size)
logits = lm_head(t_block_2)
logits

tensor([[[ 1.3321e-01, -1.1161e-01,  1.9875e-01, -5.9498e-02, -2.0814e-01,
           2.3002e-02, -3.7312e-02, -1.8708e-02,  9.9739e-02, -7.2411e-02,
           1.4606e-01,  1.8686e-02,  2.4535e-01, -9.6089e-02, -8.1794e-04,
           2.9721e-02,  2.2889e-01,  7.6256e-02, -2.2849e-02, -8.3181e-02,
           1.8891e-01, -3.8253e-03,  1.1956e-02,  2.4068e-01,  1.3616e-02,
          -7.8111e-02, -6.2523e-02, -1.3010e-01, -1.8175e-01, -9.9558e-02,
          -6.5696e-02, -5.6963e-02, -1.8788e-02, -1.2138e-02,  8.6316e-02,
          -4.3037e-02, -1.7990e-01, -3.7059e-02, -6.8783e-02, -1.0026e-01,
           1.5803e-01, -4.7147e-02,  1.4728e-01,  9.6631e-02,  1.1522e-01,
          -1.4173e-01, -4.7850e-04,  4.6883e-02, -1.8378e-01,  3.6787e-03,
          -3.7405e-02,  5.9665e-02,  4.7141e-02,  1.0253e-01,  9.6597e-02,
           1.4454e-01,  5.5597e-03,  1.0009e-02,  1.0874e-01, -7.9356e-03,
           1.5096e-01, -3.5419e-02,  4.2991e-02,  1.7127e-01, -2.0562e-01],
         [ 1.3553e-01, -

In [66]:
###########
# Softmax #
###########

In [67]:
probs = F.softmax(logits[:, -1, :], dim=-1)
encoded_token = torch.argmax(probs)
probs, encoded_token
decode([encoded_token.item()])

'?'

In [32]:
test = torch.ones(4,8,3)
test

tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]])

In [33]:
torch.sum(test, dim=-1)

tensor([[3., 3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3., 3.]])