<a href="https://colab.research.google.com/github/G0nkly/pytorch_sandbox/blob/main/gpts/GPT_dimensions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Seed the global RNG so the randomly initialised layers below (and hence the
# printed tensors) are identical on every Restart & Run All.
torch.manual_seed(1337)

# Preferred compute device. Note that the exploratory cells below create their
# modules on the CPU unless they are moved explicitly.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Let's look at each of the layers in turn:
# 1) Encoding
# 2) Embedding
# 3) Positional Encoding
# 4) Attention: key, query, value
# 5) Feedforward
# 6) Block/Layernorm
# 7) Classification / LM Head

In [None]:
################
# DATA EXAMPLE #
################

!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

with open(mode="r", file="input.txt") as f:
  text = f.read()

vocab = list(sorted(set(text)))

--2025-06-19 08:56:16--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.5’


2025-06-19 08:56:16 (16.8 MB/s) - ‘input.txt.5’ saved [1115394/1115394]



In [None]:
###################
# HYPERPARAMETERS #
###################

In [None]:
# Model dimensions shared by the layers below.
vocab_size = len(vocab)  # number of distinct characters in the corpus
embedding_dim = 32       # width of every token vector (C)
block_size = 8           # side length of the causal mask, i.e. longest context (T)
n_heads = 4              # attention heads; embedding_dim must be divisible by it
dropout = 0.2            # drop probability used by the Dropout layers

In [None]:
############
# ENCODING #
############

In [None]:
# Character-level tokenizer: each character maps to its index in `vocab`.
stoi = {ch: i for i, ch in enumerate(vocab)}
itos = {i: ch for i, ch in enumerate(vocab)}
encode = lambda seq: [stoi[char] for char in seq]                 # str -> list[int]
decode = lambda numbers: "".join([itos[num] for num in numbers])  # list[int] -> str

# Encode the whole corpus once and hold out the last 10% for validation,
# so that `get_batch` below has the `train` / `val` tensors it reads.
data = torch.tensor(encode(text), dtype=torch.long)
n_train = int(0.9 * len(data))
train, val = data[:n_train], data[n_train:]


def get_batch(split: str, batch_size: int = 4):
  """Sample a random batch of (input, target) character sequences.

  Args:
    split: "train" selects the training split; any other value selects `val`.
    batch_size: number of sequences in the batch.

  Returns:
    (x, y): LongTensors of shape (batch_size, block_size) on `device`, where
    y is x shifted one position to the right (the next-character targets).
  """
  dataset = train if split == "train" else val
  # Start offsets chosen so that a full block plus its one-step-ahead target fits.
  ix = torch.randint(len(dataset) - block_size, (batch_size,))
  x = torch.stack([dataset[i:i + block_size] for i in ix])
  y = torch.stack([dataset[i + 1:i + block_size + 1] for i in ix])
  return x.to(device), y.to(device)

In [None]:
#############
# EMBEDDING #
#############

In [None]:
# Token lookup table: one learnable vector of width embedding_dim per vocabulary entry.
embedding = nn.Embedding(vocab_size, embedding_dim)

In [None]:
# Encoding a word yields one vocabulary index per character.
encode(list("Hausi"))

[20, 39, 59, 57, 47]

In [None]:
# Example token sequence as an int64 tensor (the dtype nn.Embedding indexes with).
t = torch.tensor(encode("Hausi"), dtype=torch.long)

In [None]:
# One embedding_dim-wide row per token -> shape (len(t), embedding_dim).
embedding(input=t)

tensor([[ 0.3424,  0.1892, -1.6773, -0.2433,  0.6982, -0.0476,  1.9028, -0.2640,
          1.3070, -0.0214, -0.2487,  0.2917,  0.5369,  0.0881,  1.0006, -0.9031,
         -0.2213, -0.6106, -0.2551, -0.0557,  0.0485,  0.2072, -0.7861, -1.8906,
         -0.4887, -0.2578, -1.5979,  1.5447,  0.2110, -0.4547, -0.9896,  1.3935],
        [ 0.2197, -2.0798, -0.1218,  1.0459, -0.8955, -0.6874, -0.5732, -0.9338,
          0.1792,  0.7945, -0.7944, -0.6989,  0.9299,  0.0338,  0.3088,  2.5791,
         -1.3677, -0.9097,  1.1450,  0.1276, -0.9261,  1.0553, -0.3282, -0.1189,
          1.5103, -1.3650, -0.3967,  1.0384,  0.2777,  0.3414,  0.9193, -0.3074],
        [-0.0748, -0.5401,  0.0269, -0.1526, -0.3768,  0.0337,  2.0632,  0.1679,
         -0.1358,  1.9571,  0.1846, -0.4047, -0.7504, -0.3845, -0.0075,  0.4695,
          0.9000,  0.0553,  0.5035, -0.2207, -0.7480, -0.7641, -0.3434, -0.7415,
          0.1172,  0.7904,  1.4750,  0.4144, -1.8454, -0.5371,  1.5499,  0.5853],
        [-0.2454,  1.1008

In [None]:
#########################
# POSITIONAL 'ENCODING' #
#########################

In [None]:
# Learned positional table: one vector per position up to the maximum context
# length, as wide as the token embeddings so the two can be summed.
# (The variable keeps its original spelling because later cells refer to it.)
postional_embedding = nn.Embedding(num_embeddings=block_size, embedding_dim=embedding_dim)

In [None]:
# Build the position indices on the same device as the embedding table. The
# module was created on the CPU, so `device=device` would fail on a GPU runtime.
postional_embedding(torch.arange(5, device=postional_embedding.weight.device)).shape

torch.Size([5, 32])

In [None]:
# Add a position vector to every token vector, using positions 0..T-1 where
# T is the actual length of `t` (not a hard-coded 5).
embedded_tensor = embedding(t) + postional_embedding(torch.arange(t.shape[0], device=t.device))
embedded_tensor

tensor([[ 7.8445e-01, -7.3034e-02, -1.2135e+00, -6.9916e-01, -1.3627e+00,
          9.4943e-01,  1.4097e+00,  7.7847e-02,  7.7477e-01,  8.2858e-01,
         -1.8097e+00,  1.3255e+00, -8.3703e-01, -9.8567e-01,  2.3934e+00,
         -2.2901e+00, -3.5064e-02, -5.5237e-01,  4.5043e-01, -1.5222e+00,
          4.0932e-01,  1.7884e+00, -1.1933e+00, -1.1935e+00, -1.3306e+00,
         -2.6845e-01, -2.0549e+00,  9.1813e-01, -1.6675e-01, -1.1061e+00,
         -8.9780e-01,  2.9494e+00],
        [ 4.6685e-03, -2.2583e+00, -6.0004e-01,  3.0041e-01,  1.0476e-01,
         -1.3330e+00, -9.6764e-01, -1.9530e+00,  2.0536e+00,  9.8330e-01,
         -4.1896e-02, -1.0755e+00,  1.5022e+00,  5.9721e-01,  5.1236e-02,
          1.5949e+00, -2.2966e+00,  2.9182e-01,  2.7345e-01,  1.2579e+00,
         -2.1147e-01,  2.5808e+00, -4.8137e-01, -2.1678e-01,  1.9560e+00,
          3.4885e-02, -1.4200e+00, -4.0003e-01,  1.7166e-01,  2.7913e-01,
          2.4373e+00, -1.7765e-01],
        [-1.1957e+00, -1.7523e+00, -8.62

In [None]:
#########################
# SINGLE HEAD ATTENTION #
#########################

In [None]:
# Single head whose size equals embedding_dim, so every projection maps
# embedding_dim -> embedding_dim (unlike the multi-head version below).
# The causal mask is sized to the example sequence via a separate name, so the
# global `block_size` hyperparameter (read later by AttentionHead) is not clobbered.
seq_len = t.shape[0]

key_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
query_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
value_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
# Lower-triangular mask: position i may only attend to positions <= i.
tril = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.long))

In [None]:
k = key_layer(embedded_tensor)    # (T, C)
q = query_layer(embedded_tensor)  # (T, C)
# Scaled dot-product attention multiplies by 1/sqrt(d_k). The previous
# `/ (embedding_dim ** -0.5)` multiplied by sqrt(d_k) instead, which pushed the
# softmax into near one-hot saturation.
wei = (q @ k.transpose(-2, -1)) * (k.shape[-1] ** -0.5)
wei = wei.masked_fill(tril == 0, float("-inf"))
wei = F.softmax(wei, dim=-1)
v = value_layer(embedded_tensor)

out = wei @ v
wei, out

(tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [1.0000e+00, 1.5445e-14, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [1.8295e-01, 1.2970e-07, 8.1705e-01, 0.0000e+00, 0.0000e+00],
         [9.9945e-01, 6.1713e-08, 7.3411e-09, 5.5259e-04, 0.0000e+00],
         [2.6260e-17, 6.8185e-03, 9.8692e-01, 3.8821e-06, 6.2569e-03]],
        grad_fn=<SoftmaxBackward0>),
 tensor([[ 0.2753,  0.7028,  0.6595,  0.3956, -1.5970,  0.8216, -0.1322, -0.6431,
          -0.4727, -0.7208, -0.7353,  0.1581, -0.4889,  1.3761,  0.2936,  0.6707,
           0.3895,  0.0179, -1.7495, -1.0783,  0.0419, -0.1996,  0.4298, -0.9294,
          -0.2900, -0.1771,  0.3229,  0.6348, -1.5061,  0.1425,  1.2432,  0.4923],
         [ 0.2753,  0.7028,  0.6595,  0.3956, -1.5970,  0.8216, -0.1322, -0.6431,
          -0.4727, -0.7208, -0.7353,  0.1581, -0.4889,  1.3761,  0.2936,  0.6707,
           0.3895,  0.0179, -1.7495, -1.0783,  0.0419, -0.1996,  0.4298, -0.9294,
          -0.2900, -0.1771,  0.3

In [None]:
########################
# MULTI HEAD ATTENTION #
########################

In [None]:
class AttentionHead(nn.Module):
  """One causal self-attention head.

  Projects inputs of width `embedding_dim` to queries, keys and values of width
  `head_size`. Each position attends only to itself and earlier positions.

  Reads the notebook-level hyperparameters `block_size` (side of the causal
  mask, i.e. the longest sequence `forward` accepts) and `dropout`.
  """

  def __init__(self, embedding_dim, head_size):
    super().__init__()
    self.embedding_dim = embedding_dim
    self.query_layer = nn.Linear(in_features=embedding_dim, out_features=head_size, bias=False)
    self.key_layer = nn.Linear(in_features=embedding_dim, out_features=head_size, bias=False)
    self.value_layer = nn.Linear(in_features=embedding_dim, out_features=head_size, bias=False)
    # A buffer, not a parameter: it follows .to(device) but is not trained.
    self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))
    self.dropout = nn.Dropout(p=dropout)

  def forward(self, x):
    """Apply the head to x of shape (B, T, embedding_dim), with T <= block_size.

    Returns:
      Tensor of shape (B, T, head_size).
    """
    B, T, C = x.shape
    q = self.query_layer(x)  # (B, T, head_size)
    k = self.key_layer(x)    # (B, T, head_size)
    v = self.value_layer(x)  # (B, T, head_size)
    # Scale by 1/sqrt(head_size), the width of the vectors being dotted. The
    # input width C would over-shrink the logits whenever n_heads > 1.
    wei = (q @ k.transpose(-2, -1)) * k.shape[-1] ** -0.5  # (B, T, T)
    wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
    wei = F.softmax(wei, dim=-1)
    wei = self.dropout(wei)
    out = wei @ v  # (B, T, head_size)

    return out

In [None]:
class MultiHeadAttention(nn.Module):
  """`n_heads` AttentionHeads run in parallel, concatenated, then projected.

  Each head has head_size = embedding_dim // n_heads, so the concatenated
  output is again `embedding_dim` wide and fits the projection layer.
  Uses the notebook-level `dropout` hyperparameter.

  Raises:
    ValueError: if n_heads is not a positive divisor of embedding_dim (the
      concatenated heads would not match the projection's input width).
  """

  def __init__(self, embedding_dim, n_heads):
    super().__init__()
    if n_heads <= 0 or embedding_dim % n_heads != 0:
      raise ValueError(
          f"n_heads ({n_heads}) must be a positive divisor of embedding_dim ({embedding_dim})"
      )
    head_size = embedding_dim // n_heads
    self.heads = nn.ModuleList([AttentionHead(embedding_dim=embedding_dim, head_size=head_size) for _ in range(n_heads)])
    self.projection_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
    self.dropout = nn.Dropout(p=dropout)

  def forward(self, x):
    """Map x of shape (B, T, embedding_dim) to a tensor of the same shape."""
    # Concatenate per-head outputs along the feature axis:
    # n_heads * head_size == embedding_dim.
    x = torch.cat([head(x) for head in self.heads], dim=-1)
    out = self.projection_layer(x)
    out = self.dropout(out)
    return out

In [None]:
# Add a leading batch dimension (B=1) so the tensor is (B, T, C), as the
# attention modules expect. The guard keeps the cell idempotent: re-running it
# does not keep stacking extra dimensions.
if embedded_tensor.dim() == 2:
  embedded_tensor = embedded_tensor.unsqueeze(0)
embedded_tensor.shape

torch.Size([1, 5, 32])

In [None]:
# n_heads heads of width embedding_dim // n_heads, concatenated and projected;
# the result keeps the input's (B, T, embedding_dim) shape.
multi_head_attention = MultiHeadAttention(n_heads=n_heads, embedding_dim=embedding_dim)
t_ma = multi_head_attention(embedded_tensor)
t_ma

tensor([[[-1.9018e+00, -6.1596e-01, -3.6287e-01, -2.5908e-01, -5.0330e-01,
          -7.6079e-01,  4.2476e-02, -1.1918e+00,  1.1374e-01,  3.1972e-02,
           1.2544e+00, -0.0000e+00,  4.7994e-01, -4.9977e-01,  0.0000e+00,
           3.8351e-01,  2.4643e-01, -4.8315e-01,  6.7684e-01,  2.1209e-01,
          -1.4506e-01,  9.6667e-01, -9.6921e-02,  1.6085e+00,  1.0439e+00,
          -7.8131e-01,  9.3170e-01, -2.3281e-01,  3.0226e-01, -2.2466e-01,
           5.5134e-01, -1.2402e+00],
         [-1.4795e+00, -0.0000e+00, -9.2388e-03, -5.4363e-01, -5.6743e-01,
          -3.6272e-01,  3.2711e-01, -6.8869e-01,  1.9443e-01,  2.3666e-01,
           6.1336e-01, -3.3735e-01,  6.2848e-01, -0.0000e+00, -5.8455e-02,
           1.5836e-01,  0.0000e+00, -0.0000e+00,  5.5047e-02,  0.0000e+00,
          -1.3570e-01,  5.4271e-01, -2.4354e-01,  0.0000e+00,  3.1379e-01,
          -8.5224e-01,  5.2351e-01, -5.9338e-01, -1.1881e-01,  2.6497e-01,
           0.0000e+00, -8.4700e-01],
         [-1.0377e+00,  8.

In [None]:
################
# FEED-FORWARD #
################

In [None]:
# Position-wise MLP: expand to 4x the width, apply a non-linearity, project back.
# Without the ReLU the two Linear layers would collapse into one linear map.
# Matches the feed-forward sub-layer inside AttentionBlock (minus its Dropout).
feed_forward_layer = nn.Sequential(
    nn.Linear(embedding_dim, 4 * embedding_dim),
    nn.ReLU(),
    nn.Linear(4 * embedding_dim, embedding_dim)
)

t_ff = feed_forward_layer(t_ma)
t_ff

tensor([[[-6.9561e-02, -1.4250e-01,  3.1070e-02, -1.8291e-01,  2.4161e-01,
          -3.7675e-01,  3.2730e-02,  1.8323e-02, -4.9823e-01, -2.0057e-01,
          -1.2709e-01,  3.3223e-01,  7.4427e-02, -1.7740e-01,  6.0727e-01,
           1.7553e-01, -2.5639e-01,  1.9310e-01,  1.2692e-01,  4.8282e-02,
          -2.8849e-01, -5.7440e-03, -3.6806e-02, -1.4971e-01,  3.0018e-01,
           3.5638e-01, -1.1089e-01, -9.4508e-02,  4.0093e-01,  2.0028e-01,
           5.2127e-02, -1.3019e-01],
         [-2.9404e-02, -1.3121e-01, -8.4553e-02, -1.8186e-01,  1.9564e-01,
          -1.9973e-01,  3.8148e-03,  3.2353e-03, -2.4055e-01, -1.9473e-01,
           5.5860e-02,  3.2112e-01,  5.1088e-02, -2.8851e-02,  3.5166e-01,
           2.5772e-02, -4.2427e-03, -4.6352e-02,  2.1989e-01,  9.1748e-02,
          -2.2866e-01,  9.1045e-02, -1.9676e-01, -1.7347e-01,  1.5584e-01,
           2.1007e-01,  4.9586e-02, -3.3962e-02,  1.9153e-01,  2.2269e-01,
          -1.6246e-02, -5.1122e-02],
         [-5.6009e-02,  1.

In [None]:
########################
# BLOCK AND LAYER-NORM #
########################

In [None]:
class AttentionBlock(nn.Module):
  """Pre-LayerNorm transformer block.

  Self-attention and then a feed-forward MLP, each applied to a LayerNorm-ed
  input and added back through a residual connection. Input and output are
  (B, T, embedding_dim). Uses the notebook-level `dropout` hyperparameter.
  """

  def __init__(self, embedding_dim, n_heads):
    super().__init__()
    hidden_dim = 4 * embedding_dim  # 4x expansion inside the MLP
    # Submodules are created in this order on purpose: it fixes the order in
    # which their weights draw from the RNG.
    self.multi_head_attention = MultiHeadAttention(embedding_dim=embedding_dim, n_heads=n_heads)
    self.feed_forward = nn.Sequential(
        nn.Linear(embedding_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, embedding_dim),
        nn.Dropout(dropout),
    )
    self.ln_1 = nn.LayerNorm(embedding_dim)
    self.ln_2 = nn.LayerNorm(embedding_dim)

  def forward(self, x):
    attended = x + self.multi_head_attention(self.ln_1(x))
    return attended + self.feed_forward(self.ln_2(attended))

In [None]:
attention_block = AttentionBlock(n_heads=n_heads, embedding_dim=embedding_dim)
# NOTE(review): this feeds in t_ff (the standalone feed-forward output) rather
# than embedded_tensor. That is fine for a shape demo, but confirm it is intended.
t_block_2 = attention_block(t_ff)
t_block_2

tensor([[[-4.6705e-02, -2.2203e-01,  5.1181e-01,  1.3837e+00, -5.3916e-01,
          -6.9596e-01,  1.3463e-01,  1.2832e-01, -8.7962e-01,  4.9494e-01,
          -1.5083e-01,  2.6517e-01, -1.3036e-01,  8.1951e-02,  1.5051e+00,
          -5.4201e-01, -2.7874e-01,  1.1454e+00,  8.0541e-02, -1.8282e-01,
          -3.2900e-01,  9.9515e-02, -3.8351e-01, -2.0474e-01,  6.5760e-01,
          -8.0057e-02, -4.3272e-01, -5.6177e-02,  3.3964e-01, -5.2800e-01,
           7.3702e-02, -4.5352e-01],
         [-4.6152e-01, -6.1738e-02, -1.6903e-01,  2.7621e-01, -1.6193e-01,
          -5.6841e-01,  1.1046e-01, -9.5739e-02, -2.7522e-01, -2.1133e-02,
          -2.3048e-01,  2.9836e-01, -2.5770e-01, -1.8070e-01,  5.7188e-01,
           4.0308e-02,  1.4410e-01,  2.8084e-01,  3.3511e-01, -6.1585e-01,
          -4.9877e-01,  4.1035e-01, -5.4315e-01, -4.6382e-02,  7.8154e-02,
          -7.0194e-01, -2.4573e-01,  2.1918e-02,  3.4481e-01, -1.3541e-01,
          -2.4991e-01, -5.2442e-01],
         [-2.5963e-02,  1.

In [None]:
#############
# (LM) HEAD #
#############

In [None]:
# Language-model head: map each position's embedding_dim features to one logit
# per vocabulary entry -> (B, T, vocab_size).
lm_head = nn.Linear(embedding_dim, vocab_size)
logits = lm_head(t_block_2)
logits

tensor([[[-6.0516e-01,  1.5655e-01,  2.6236e-01,  1.8331e-01, -2.7751e-01,
          -1.2413e-01, -2.6983e-03, -1.3559e-01, -2.5169e-02, -3.4957e-01,
          -3.7576e-01, -4.8344e-01, -3.9256e-01,  1.7772e-01, -3.6845e-01,
           1.9012e-01,  1.5632e-01, -3.3650e-01, -3.7067e-01,  2.2451e-01,
          -2.8597e-01,  1.5639e-01,  1.6548e-01, -2.7145e-01, -6.3642e-01,
          -3.3663e-02, -3.7468e-01, -1.4023e-01,  9.9101e-02, -1.1596e-01,
          -3.2395e-01, -2.2003e-01, -9.6624e-01,  4.5272e-02, -6.0836e-02,
          -1.4989e-01,  2.2971e-01, -2.6963e-01, -1.5948e-01,  2.7682e-01,
           4.6036e-01,  4.0381e-01, -6.1989e-01, -2.8557e-01, -5.4751e-01,
           4.7636e-01,  1.6476e-01,  6.0592e-01, -6.0123e-02, -6.1876e-02,
           9.4437e-02, -4.7597e-01, -1.3291e-01,  5.8116e-01, -2.9419e-01,
           6.9870e-02, -7.4187e-03,  3.0580e-03, -8.4652e-01,  1.6101e-01,
          -4.9168e-01, -1.6414e-01, -2.0172e-02, -2.7949e-02,  4.8693e-02],
         [-3.7021e-01, -

In [None]:
###########
# Softmax #
###########

In [None]:
# Only the last position's logits predict the next token: (B, vocab_size).
probs = F.softmax(logits[:, -1, :], dim=-1)
# Greedy decoding: take the argmax over the vocabulary axis for each batch row,
# not over the flattened tensor, so this also works when B > 1.
encoded_token = torch.argmax(probs, dim=-1)  # (B,)
decode(encoded_token.tolist())

't'

In [None]:
#################
# TRAINING LOOP #
#################

In [None]:
#####################
# GENERATE FUNCTION #
#####################

In [None]:
# Dummy (B=4, T=8, C=3) tensor of ones for checking which axis a reduction collapses.
test = torch.ones((4, 8, 3))
test

tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]])

In [None]:
# Reducing over the last axis (C) collapses (4, 8, 3) -> (4, 8).
test.sum(dim=-1)

tensor([[3., 3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3., 3.]])