<a href="https://colab.research.google.com/github/G0nkly/pytorch_sandbox/blob/main/GPT_dimensions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Lets look at each of the layers
# 1) Encoding
# 2) Embedding
# 3) Positional Encoding
# 4) Attention: key, query, value
# 5) Feedforward
# 6) Block/Layernorm
# 7) Classification / LM Head

In [3]:
################
# DATA EXAMPLE #
################

!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

with open(mode="r", file="input.txt") as f:
  text = f.read()

vocab = list(sorted(set(text)))

--2025-06-15 07:51:22--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-06-15 07:51:23 (17.7 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [4]:
##################
# HYPERPARAMTERS #
##################

In [46]:
vocab_size = len(vocab)
embedding_dim = 32
block_size = 8
n_heads = 4

In [6]:
############
# ENCODING #
############

In [7]:
stoi = {v:k for k,v in enumerate(vocab)}
itos = {k:v for k,v in enumerate(vocab)}
encode = lambda seq: [stoi[char] for char in seq]
decode = lambda numbers: "".join([itos[num] for num in numbers])

def get_batch(split: str):
  dataset = train if split == "train" else val

In [8]:
#############
# EMBEDDING #
#############

In [9]:
embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)

In [10]:
t = torch.tensor(encode("Haubi"))

In [11]:
embedding(t)

tensor([[-0.9273,  0.3304,  0.6638, -0.4366,  0.4378, -0.2666, -0.0975,  0.1392,
         -0.0128, -0.0692,  0.5741,  0.4418, -1.3849,  1.0975,  0.5832, -0.9658,
         -1.6483, -0.4084,  0.8306, -1.4172,  1.6174, -0.6841, -0.9252, -0.5369,
          0.9811,  0.0101, -0.7452,  0.7741,  0.2763, -1.3349, -0.2718, -0.0312],
        [-2.2190, -0.3504, -1.6768,  1.9151,  0.4400, -0.1605,  0.4397, -2.0158,
         -0.6278,  0.4262,  0.4565, -1.2894,  1.0605, -1.9683, -0.7313,  1.2038,
          0.3054, -1.1260,  0.5817,  1.1115,  0.7540, -1.0459,  0.2068, -0.8568,
         -0.4302, -1.0945, -0.6069,  0.7064, -1.3334, -1.1765, -2.6555,  0.8990],
        [ 0.6363,  0.4486, -1.2197,  1.4263, -0.2736,  0.5772, -0.3240,  0.7389,
          0.4402,  1.2262,  0.3455, -0.4795, -0.6717, -0.3513, -1.7074,  0.2935,
         -1.6747, -0.1381, -0.6014,  0.8597,  0.8348,  0.7566, -0.8566, -0.8768,
         -0.9643, -1.0129,  1.2567,  1.5986,  0.5272, -1.0388, -0.5194,  0.6921],
        [-0.7371,  0.6034

In [12]:
#########################
# POSITIONAL 'ENCODING' #
#########################

In [13]:
postional_embedding = nn.Embedding(num_embeddings=5, embedding_dim=32)

In [14]:
postional_embedding(torch.arange(5)).shape

torch.Size([5, 32])

In [15]:
embedded_tensor = embedding(t) + postional_embedding(torch.arange(5, dtype=torch.long))
embedded_tensor

tensor([[-1.5310,  0.2781,  0.3253,  0.6672,  0.9284,  0.3864, -0.1435, -2.7747,
         -1.3929, -0.6017,  1.2140, -0.2050, -1.5046,  3.3244,  0.3451, -1.3694,
         -0.5394, -2.4792,  0.3771, -1.1443,  0.1556, -0.8633, -0.0871, -1.5353,
          0.2487,  0.4565, -0.9097, -0.9768,  1.6561, -1.6015, -2.0356, -0.3675],
        [-1.4176,  0.1038, -4.5760,  0.6590,  0.8875,  0.1559,  1.2824, -2.1702,
          1.0330,  2.2309,  1.3258, -1.8062,  2.2117, -1.1957, -0.7506,  0.1473,
         -1.1387,  0.1314, -0.4494, -0.0161,  0.6809, -1.1862,  0.0420, -1.5893,
          0.6788,  0.5010, -0.5437,  0.7290, -1.6392, -1.6903, -3.4132,  1.2522],
        [ 0.1742,  0.2318, -0.1092,  2.4308,  0.5409,  0.5441, -1.9386,  0.3867,
          1.1274,  0.6398,  0.8996,  0.4444,  0.0757,  0.2784, -1.9436,  1.0691,
         -1.5470,  0.9550,  0.2697,  1.8185,  1.8194,  0.8487, -1.3616, -0.9076,
         -1.0161, -2.3691, -0.0755,  3.9784, -1.4509, -1.5981, -1.1311, -0.5237],
        [-0.9978,  3.2911

In [16]:
#########################
# SINGLE HEAD ATTENTION #
#########################

In [63]:
# since its not MultiHeadAttention input_dim == output_dim
block_size = len(t) if t is not None else block_size

key_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
query_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
value_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
tril = torch.tril(torch.ones(block_size, block_size, dtype=torch.long))

In [41]:
k = key_layer(embedded_tensor)
q = query_layer(embedded_tensor)
wei = (q @ k.transpose(-2,-1)) / (embedding_dim ** -0.5)
wei = wei.masked_fill(tril == 0, -float("inf"))
wei = F.softmax(wei, dim=-1)
v = value_layer(embedded_tensor)

out = wei @ v
wei, out

(tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [4.4617e-10, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [3.4815e-26, 9.9050e-08, 1.0000e+00, 0.0000e+00, 0.0000e+00],
         [4.8059e-18, 3.7613e-24, 2.9675e-08, 1.0000e+00, 0.0000e+00],
         [9.9995e-01, 5.4872e-05, 2.8917e-22, 1.4990e-29, 6.6640e-10]],
        grad_fn=<SoftmaxBackward0>),
 tensor([[ 0.4012,  0.1096,  1.2421, -0.1876, -0.6656,  0.4108, -0.2936, -0.5350,
           0.0802,  0.0615,  0.0282, -0.6259,  0.9548,  0.7146, -0.3993, -0.5750,
          -0.4476, -0.1339, -0.5695, -1.1833,  0.4406, -0.4054, -0.1534, -0.1182,
           0.3593, -0.2678, -0.6013, -1.0542,  0.3007, -0.0448, -0.1898,  0.6341],
         [-0.3699, -1.8706,  0.2445, -0.9308, -0.1890,  0.6169, -0.5355, -0.8250,
           1.4217, -1.2424,  0.8778, -0.7089,  0.2601, -0.4499, -1.1164,  0.6310,
           0.5790, -0.7322,  0.8378, -1.2645,  1.5278,  1.0927, -0.7256, -0.0085,
           0.1970,  1.0582,  0.0

In [42]:
########################
# MULTI HEAD ATTENTION #
########################

In [71]:
class AttentionHead(nn.Module):

  def __init__(self, embedding_dim, head_size):
    super().__init__()
    self.embedding_dim = embedding_dim
    self.query_layer = nn.Linear(in_features=embedding_dim, out_features=head_size)
    self.key_layer = nn.Linear(in_features=embedding_dim, out_features=head_size)
    self.value_layer = nn.Linear(in_features=embedding_dim, out_features=head_size)
    self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size, dtype=torch.long)))

  def forward(self, x):
    B, T, C = x.shape
    q = self.query_layer(x)
    k = self.key_layer(x)
    v = self.value_layer(x)
    wei = (q @ k.transpose(-2,-1)) / C ** -0.5
    wei = wei.masked_fill(self.tril == 0, float("-inf"))
    wei = F.softmax(wei, dim=-1)
    out = wei @ v

    return out

In [72]:
class MultiHeadAttention(nn.Module):

  def __init__(self, embedding_dim, n_heads):
    super().__init__()
    head_size = embedding_dim // n_heads
    self.heads = nn.ModuleList([AttentionHead(embedding_dim=embedding_dim, head_size=head_size) for _ in range(n_heads)])
    self.projection_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)

  def forward(self, x):
    x = torch.cat([head(x) for head in self.heads], dim=-1)
    out = self.projection_layer(x)
    return out

In [61]:
embedded_tensor = embedded_tensor.unsqueeze(0)
embedded_tensor.shape

torch.Size([1, 5, 32])

In [73]:
multi_head_attention = MultiHeadAttention(embedding_dim=embedding_dim, n_heads=n_heads)
multi_head_attention(embedded_tensor)

tensor([[[-0.3960,  0.1228,  0.1425,  0.2320,  0.8935,  0.0596, -0.0883,
          -0.2751,  0.0022, -0.0870,  0.5670, -0.2230,  0.6527, -0.2754,
          -0.3198,  0.2378, -0.8303, -0.1637, -0.4679, -0.4889, -0.5391,
           0.5433,  0.1891,  1.0591,  0.1788, -0.5837, -0.3828, -0.1633,
          -0.3994, -0.1508,  0.9394,  0.7019],
         [-0.5248, -0.1877, -0.0716, -0.5373,  0.7200, -0.4733, -0.2121,
          -0.4697, -0.5646, -0.4163,  0.2752, -0.0155,  0.1934, -0.1055,
           0.4372,  0.4479, -0.4577,  0.0885, -0.7893, -0.7398, -0.4251,
           0.0441,  0.2180,  0.5846,  0.5314, -0.0373, -0.3818, -0.3136,
           0.0224,  0.2330,  0.3222,  0.6853],
         [-1.0075,  0.1129, -0.0110, -0.3201,  0.3604,  0.4122,  0.0969,
          -0.6387,  0.4003, -0.4121,  0.6565, -0.6626,  0.1750, -0.6181,
          -0.1952, -0.3119,  0.7549,  0.5559, -0.9156, -0.8544, -0.5024,
           0.0740,  0.6636,  0.2284,  0.8663,  0.1074,  0.5612, -0.5977,
           0.5530,  1.0429,  0

In [50]:
test = torch.ones(4,8,3)
test

tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]])

In [31]:
torch.sum(test, dim=-1)

tensor([[3., 3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3., 3.]])