<a href="https://colab.research.google.com/github/G0nkly/pytorch_sandbox/blob/main/gpts/GPT_dimensions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn

In [3]:
# Let's look at each of the layers
# 1) Encoding
# 2) Embedding
# 3) Positional Encoding
# 4) Attention: key, query, value
# 5) Feedforward
# 6) Block/Layernorm
# 7) Classification / LM Head

In [4]:
################
# DATA EXAMPLE #
################

!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

with open(mode="r", file="input.txt") as f:
  text = f.read()

vocab = list(sorted(set(text)))

--2025-06-11 20:21:00--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-06-11 20:21:01 (27.1 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [5]:
###################
# HYPERPARAMETERS #
###################

In [6]:
# Sizes shared by the layers built in the following cells.
block_size = 8            # NOTE(review): presumably the maximum context length - confirm
embedding_dim = 32        # width of every embedding vector
vocab_size = len(vocab)   # number of distinct entries in `vocab`

In [7]:
############
# ENCODING #
############

In [8]:
# Lookup tables between vocabulary entries and their integer ids (position in `vocab`).
stoi = {char: idx for idx, char in enumerate(vocab)}
itos = dict(enumerate(vocab))


def encode(seq):
  """Return the list of integer ids for the characters in `seq`.

  Raises KeyError if a character is not present in `stoi`.
  """
  return [stoi[char] for char in seq]


def decode(numbers):
  """Return the string obtained by joining the characters for the ids in `numbers`.

  Raises KeyError if an id is not present in `itos`.
  """
  return "".join(itos[num] for num in numbers)

def get_batch(split: str, batch_size: int = 4):
  """Sample a random batch of (input, target) windows for next-token prediction.

  Args:
    split: "train" selects the global `train`; any other value selects the
      global `val`.
    batch_size: number of windows to sample (default 4).

  Returns:
    A tuple `(x, y)` of stacked windows, each with `batch_size` rows of
    `block_size` elements; `y` is `x` shifted one position to the right.

  Raises:
    ValueError: if the selected dataset has no room for a window of
      `block_size + 1` elements.

  NOTE(review): `train` and `val` are not defined in the visible part of the
  notebook; this assumes they are 1-D tensors of token ids created in another
  cell before this function is called - confirm.
  """
  dataset = train if split == "train" else val

  # A window needs block_size inputs plus one extra element for the last target.
  max_start = len(dataset) - block_size
  if max_start <= 0:
    raise ValueError(
      f"dataset of length {len(dataset)} is too short for block_size={block_size}"
    )

  # randint's upper bound is exclusive, so start + block_size + 1 <= len(dataset).
  starts = torch.randint(max_start, (batch_size,)).tolist()
  x = torch.stack([dataset[s:s + block_size] for s in starts])
  y = torch.stack([dataset[s + 1:s + block_size + 1] for s in starts])
  return x, y

In [9]:
#############
# EMBEDDING #
#############

In [10]:
# Token embedding table: vocab_size rows, each of width embedding_dim.
embedding = nn.Embedding(vocab_size, embedding_dim)

In [11]:
# Encode a sample 5-character string into a tensor holding one integer id per character.
t = torch.tensor(encode("Haubi"))

In [12]:
# Look up one embedding row per id in `t`: 5 ids with embedding_dim = 32 give a (5, 32) tensor.
embedding(t)

tensor([[ 8.4160e-02,  1.2279e+00,  3.6452e-01,  1.1919e+00,  8.0834e-01,
          6.7032e-02, -2.1322e+00, -7.1546e-01,  4.9774e-01, -7.0984e-01,
          4.0165e-01,  6.0844e-01, -2.0802e-01,  2.2465e-02,  4.6868e-01,
         -1.5008e+00, -1.4298e+00, -1.1489e+00,  1.5756e+00, -1.8161e-01,
          1.0268e+00,  1.9409e+00,  9.2082e-04, -3.5998e-01, -2.5109e-01,
         -5.2287e-03,  2.5216e-01,  3.8218e-02, -1.7786e+00, -9.1354e-01,
         -1.7150e+00,  9.9102e-01],
        [ 1.2272e+00, -1.3678e+00, -9.7782e-02, -4.2614e-01, -8.0029e-01,
         -7.6366e-01, -1.6327e+00,  1.2827e-01, -5.5607e-01, -1.1568e+00,
          6.7547e-01, -1.0693e+00,  7.9825e-02,  1.6315e+00,  1.3325e+00,
         -1.0740e+00, -8.8620e-01,  5.0538e-01,  1.9118e+00,  5.1301e-01,
         -3.4828e-01, -3.7403e-01,  3.5642e+00,  5.7410e-01, -1.1791e-02,
          2.3124e-01, -5.6447e-01, -1.3990e+00,  3.3921e-01,  6.5988e-01,
         -9.4784e-01, -1.1806e+00],
        [-3.6709e-01,  1.4608e+00, -7.51

In [13]:
#########################
# POSITIONAL 'ENCODING' #
#########################

In [14]:
# Learned positional table: one row per position up to block_size, with the same
# width as the token embeddings so the two can be added element-wise.
# Sized from the shared hyperparameters instead of hard-coded 5 / 32.
# (The variable name keeps its original spelling because later cells reference it.)
postional_embedding = nn.Embedding(num_embeddings=block_size, embedding_dim=embedding_dim)

In [15]:
# Positions 0 .. len(t)-1, derived from the sample instead of a hard-coded 5.
postional_embedding(torch.arange(len(t))).shape

torch.Size([5, 32])

In [16]:
# Token embedding plus positional embedding, position by position.
# The position ids follow the length of `t` rather than a hard-coded 5;
# torch.arange on an int already yields int64 (long) indices.
embedding(t) + postional_embedding(torch.arange(len(t)))

tensor([[ 0.6987,  1.7524,  0.8935,  1.5398, -1.5348,  0.1344, -2.5978, -0.9702,
          0.5046,  0.2203, -0.4480,  1.1483, -1.0087,  0.0417,  0.2136,  0.4590,
         -2.1786, -0.1652,  1.7199,  1.3189,  0.2216,  0.7922, -0.8150,  0.0548,
         -0.1591, -1.8776, -0.7800, -0.3977, -3.1736, -1.5457, -2.9193,  0.6388],
        [ 2.2365, -0.7150, -1.8626, -1.2776, -1.7358,  0.9312, -0.7338, -0.3099,
         -0.7670, -1.9667, -0.4987, -1.8199, -1.3588,  2.4415,  2.1411, -0.3670,
          1.4963, -1.2768,  2.0183, -0.0157, -1.0846, -1.9387,  2.6798,  0.2521,
          1.3079,  0.3426, -0.3875,  0.8383, -2.3966,  1.2665, -2.0983, -0.7075],
        [ 0.0077,  0.5554,  1.2895,  1.3382,  0.8501, -2.5709, -2.2213, -2.4642,
         -0.7230, -1.1186, -0.4687, -0.5210, -1.0520,  1.1431, -0.9622,  0.2143,
         -1.5823,  2.4622,  4.4765,  3.3754,  0.3913,  0.5983, -2.3291, -0.3625,
          1.4057, -0.3774,  0.2868, -2.5753,  1.5666,  0.2353,  0.2035, -1.5420],
        [-2.2283, -1.0532

In [17]:
#############
# ATTENTION #
#############

In [19]:
# Since this is not MultiHeadAttention, every projection keeps the full width:
# input_dim == output_dim == embedding_dim.
# The generator builds the three layers in key, query, value order.
key_layer, query_layer, value_layer = (
  nn.Linear(embedding_dim, embedding_dim) for _ in range(3)
)