In [1]:
import torch

In [None]:
"""
Architecture Overview:
1. Embedding: Token IDs -> Vectors (wte / `token_embedding`), followed by RMSNorm
2. Stack of Blocks (repeated n_layers times), each a pre-norm residual layer:
   - RMSNorm
   - Attention (mixing info between tokens)
   - RMSNorm
   - MLP (processing info within a token)
3. Final RMSNorm
4. LMHead: Vectors -> Logits (unnormalized scores; softmax turns them into probabilities)
"""

import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass

@dataclass
class GPTConfig:
    """
    Hyperparameters for the model.

    Raises:
        ValueError: if n_heads is not positive or hidden_dim is not divisible
            by n_heads (attention splits hidden_dim into n_heads equal
            head_dim slices).
    """
    # ┌─────────────────────────────────────────────────────────┐
    # │           243M CONVERSATIONAL MODEL                     │
    # ├─────────────────────────────────────────────────────────┤
    # │  hidden_dim:        1024                                │
    # │  layers:            20                                  │
    # │  heads:             8                                   │
    # │  head_dim:          128                                 │
    # │  mlp_ratio:         3x                                  │
    # │  vocab_size:        32K                                 │
    # │  context_length:    1024                                │
    # │  embedding:         tied (input = output projection)    │
    # │  activation:        relu squared                        │
    # │  position encoding: RoPE (TODO: not implemented yet)    │
    # ├─────────────────────────────────────────────────────────┤
    # │  TOTAL PARAMETERS:  243,269,632                         │
    # └─────────────────────────────────────────────────────────┘
    # No KV cache
    # No GQA

    hidden_dim: int = 1024  # width of the residual stream
    n_layers: int = 20  # number of TransformerBlocks
    n_heads: int = 8  # head dimension = hidden_dim / n_heads = 128
    mlp_ratio: int = 3  # MLP inner width = hidden_dim * mlp_ratio
    vocab_size: int = 32*1024
    sequence_len: int = 1024  # context length

    def __post_init__(self):
        # Fail fast here instead of deep inside attention's .view() calls.
        if self.n_heads <= 0:
            raise ValueError(f"n_heads must be positive, got {self.n_heads}")
        if self.hidden_dim % self.n_heads != 0:
            raise ValueError(
                f"hidden_dim ({self.hidden_dim}) must be divisible by "
                f"n_heads ({self.n_heads})"
            )


def norm(x):
    """
    RMSNorm (Root Mean Square Layer Normalization).
    Used to stabilize training by normalizing activation magnitudes.
    """
    # Purely functional rmsnorm with no learnable params
    return F.rms_norm(x, (x.size(-1),))


# def apply_rotatory_positional_encoding(x: torch.Tensor, head_dim: int) -> torch.Tensor:
#     """
#     Apply RoPE to the input tensor.
#     """
#     B, T, C = x.size()
#     # TODO: Implement RoPE
    

In [30]:
# Default hyperparameters; the bare `config` on the last line displays the dataclass repr.
config = GPTConfig()
config

GPTConfig(hidden_dim=1024, n_layers=20, n_heads=8, mlp_ratio=3, vocab_size=32768, sequence_len=1024)

In [31]:
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Causal Self Attention.

    1. Projects input to Q, K, V.
    2. (TODO) Apply RoPE to Q, K for position info. Not implemented yet, so
       apart from the causal mask the attention gets no positional signal.
    3. Applies QK norm, then computes causal attention scores (Q @ K^T, scaled)
       to see how much each token cares about earlier tokens, and aggregates
       values (V) based on those scores.
    4. Projects output to mix information across heads.

    Raises:
        ValueError: if config.hidden_dim is not divisible by config.n_heads,
            since the heads split hidden_dim into equal head_dim slices.
    """
    def __init__(self, config: GPTConfig):
        super().__init__()
        # Without this check a bad config only fails later, in forward's .view().
        if config.hidden_dim % config.n_heads != 0:
            raise ValueError(
                f"hidden_dim ({config.hidden_dim}) must be divisible by "
                f"n_heads ({config.n_heads})"
            )
        self.n_heads = config.n_heads
        self.hidden_dim = config.hidden_dim
        self.head_dim = config.hidden_dim // config.n_heads

        # Linear projections for Query, Key, Value (hidden_dim -> n_heads * head_dim == hidden_dim)
        self.key = nn.Linear(self.hidden_dim, self.head_dim * self.n_heads, bias=False)
        self.query = nn.Linear(self.hidden_dim, self.head_dim * self.n_heads, bias=False)
        self.value = nn.Linear(self.hidden_dim, self.head_dim * self.n_heads, bias=False)

        # Output projection ("o"): mixes results from all heads back into hidden_dim
        self.proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, T, hidden_dim) input; the 3-way unpack below requires rank 3.

        Returns:
            (B, T, hidden_dim) tensor in which position t only attends to
            positions <= t (is_causal=True).
        """
        B, T, C = x.size()

        # 1. Project input to Q, K, V and split into heads: (B, T, n_heads, head_dim).
        k = self.key(x).view(B, T, self.n_heads, self.head_dim)
        q = self.query(x).view(B, T, self.n_heads, self.head_dim)
        v = self.value(x).view(B, T, self.n_heads, self.head_dim)

        # 2. TODO: apply RoPE to q and k here, before the QK norm, once implemented.

        # 3. QK norm: RMS-normalize each head's query/key vector over head_dim.
        q, k = norm(q), norm(k)
        # Make heads a batch dim: (B, T, H, D) -> (B, H, T, D), the layout SDPA expects.
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        # Causal attention: softmax(Q @ K^T / sqrt(head_dim)) @ V with future positions masked.
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        # Re-assemble the heads side by side: (B, H, T, D) -> (B, T, H*D) == (B, T, C).
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # 4. Output projection mixes information across heads.
        y = self.proj(y)
        return y



In [32]:
# One attention module built from the shared config; the last line displays its submodules.
attn = MultiHeadAttention(config)
attn

MultiHeadAttention(
  (key): Linear(in_features=1024, out_features=1024, bias=False)
  (query): Linear(in_features=1024, out_features=1024, bias=False)
  (value): Linear(in_features=1024, out_features=1024, bias=False)
  (proj): Linear(in_features=1024, out_features=1024, bias=False)
)

In [33]:
# List the attention module's registered Parameters: the key, query, value and proj weights (all bias=False).
for param in attn.parameters():
    print(type(param), param.size())

<class 'torch.nn.parameter.Parameter'> torch.Size([1024, 1024])
<class 'torch.nn.parameter.Parameter'> torch.Size([1024, 1024])
<class 'torch.nn.parameter.Parameter'> torch.Size([1024, 1024])
<class 'torch.nn.parameter.Parameter'> torch.Size([1024, 1024])


In [34]:
class FeedForward(nn.Module):
    """
    Position-wise MLP: widen, apply ReLU^2, narrow back.

    Every token is transformed on its own; nothing in this module mixes
    information between positions.
    """
    def __init__(self, config: GPTConfig):
        super().__init__()
        width = config.hidden_dim
        inner = width * config.mlp_ratio
        self.proj_up = nn.Linear(width, inner, bias=False)    # hidden_dim -> hidden_dim * mlp_ratio
        self.proj_down = nn.Linear(inner, width, bias=False)  # back to hidden_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # "ReLU squared": relu(h) ** 2, applied elementwise to the widened activations.
        activated = F.relu(self.proj_up(x)).square()
        return self.proj_down(activated)

In [35]:
# One MLP built from the shared config; the last line displays its submodules.
ff = FeedForward(config)
ff

FeedForward(
  (proj_up): Linear(in_features=1024, out_features=3072, bias=False)
  (proj_down): Linear(in_features=3072, out_features=1024, bias=False)
)

In [36]:
# %pip install torchinfo

In [37]:
# torchinfo is a third-party package (see the commented-out %pip cell above).
from torchinfo import summary

# Shape/parameter summary for one full-length float32 sequence: (batch=1, sequence_len, hidden_dim).
summary(attn, input_size=(1, config.sequence_len, config.hidden_dim), dtypes=[torch.float32])

Layer (type:depth-idx)                   Output Shape              Param #
MultiHeadAttention                       [1, 1024, 1024]           --
├─Linear: 1-1                            [1, 1024, 1024]           1,048,576
├─Linear: 1-2                            [1, 1024, 1024]           1,048,576
├─Linear: 1-3                            [1, 1024, 1024]           1,048,576
├─Linear: 1-4                            [1, 1024, 1024]           1,048,576
Total params: 4,194,304
Trainable params: 4,194,304
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 4.19
Input size (MB): 4.19
Forward/backward pass size (MB): 33.55
Params size (MB): 16.78
Estimated Total Size (MB): 54.53

In [38]:
# Same summary for the MLP: one (1, sequence_len, hidden_dim) float32 input.
summary(ff, input_size=(1, config.sequence_len, config.hidden_dim), dtypes=[torch.float32])

Layer (type:depth-idx)                   Output Shape              Param #
FeedForward                              [1, 1024, 1024]           --
├─Linear: 1-1                            [1, 1024, 3072]           3,145,728
├─Linear: 1-2                            [1, 1024, 1024]           3,145,728
Total params: 6,291,456
Trainable params: 6,291,456
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 6.29
Input size (MB): 4.19
Forward/backward pass size (MB): 33.55
Params size (MB): 25.17
Estimated Total Size (MB): 62.91

In [39]:
class TransformerBlock(nn.Module):
    """
    One pre-norm Transformer layer.

    Attention runs first (tokens exchange information), then the MLP (each
    token is processed on its own). Each sub-layer receives an RMS-normalized
    view of the residual stream, and its output is added back onto the stream.
    """
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.attn = MultiHeadAttention(config)
        self.ff = FeedForward(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        stream = x
        stream = stream + self.attn(norm(stream))  # communication sub-layer + residual
        stream = stream + self.ff(norm(stream))    # computation sub-layer + residual
        return stream


In [40]:
# Build from the shared `config` (not a fresh GPTConfig()) so the block always matches
# the (1, config.sequence_len, config.hidden_dim) input used by the summary cell below.
block = TransformerBlock(config)
block

TransformerBlock(
  (attn): MultiHeadAttention(
    (key): Linear(in_features=1024, out_features=1024, bias=False)
    (query): Linear(in_features=1024, out_features=1024, bias=False)
    (value): Linear(in_features=1024, out_features=1024, bias=False)
    (proj): Linear(in_features=1024, out_features=1024, bias=False)
  )
  (ff): FeedForward(
    (proj_up): Linear(in_features=1024, out_features=3072, bias=False)
    (proj_down): Linear(in_features=3072, out_features=1024, bias=False)
  )
)

In [41]:
# Summary of one full block (attention + MLP) on a (1, sequence_len, hidden_dim) float32 input.
summary(block, input_size=(1, config.sequence_len, config.hidden_dim), dtypes=[torch.float32])

Layer (type:depth-idx)                   Output Shape              Param #
TransformerBlock                         [1, 1024, 1024]           --
├─MultiHeadAttention: 1-1                [1, 1024, 1024]           --
│    └─Linear: 2-1                       [1, 1024, 1024]           1,048,576
│    └─Linear: 2-2                       [1, 1024, 1024]           1,048,576
│    └─Linear: 2-3                       [1, 1024, 1024]           1,048,576
│    └─Linear: 2-4                       [1, 1024, 1024]           1,048,576
├─FeedForward: 1-2                       [1, 1024, 1024]           --
│    └─Linear: 2-5                       [1, 1024, 3072]           3,145,728
│    └─Linear: 2-6                       [1, 1024, 1024]           3,145,728
Total params: 10,485,760
Trainable params: 10,485,760
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 10.49
Input size (MB): 4.19
Forward/backward pass size (MB): 67.11
Params size (MB): 41.94
Estimated Total Size (MB): 113.25

In [42]:
class GPT(nn.Module):
    """
    The full GPT model.
    Contains:
    1. Token Embedding (followed by RMSNorm)
    2. Transformer Blocks (stacked, config.n_layers of them)
    3. Final Normalization
    4. Output Head (LM Head), whose weight is the very same Parameter as the
       token embedding (weight tying).

    forward maps token ids to logits: the embedding lookup requires integer
    ids, and the output gains a trailing vocab_size dimension.
    """
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_dim)
        # nn.Embedding initializes to N(0, 1). That weight is reused as the output
        # projection, and the final hidden state is RMS-normalized (norm ~ sqrt(hidden_dim)).
        # Together these give initial logits with std ~ sqrt(hidden_dim) (~32 here),
        # i.e. a near one-hot softmax and a huge initial loss. A small std (as in
        # GPT-2) keeps initial logits O(1). The input side is unaffected, because
        # forward() normalizes the embedding output anyway.
        nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02)
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        # Weight tying: lm_head.weight *is* token_embedding.weight (both are
        # (vocab_size, hidden_dim)). The Linear is created on the meta device so its
        # own, immediately-replaced weight is never materialized in memory.
        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False, device="meta")
        self.lm_head.weight = self.token_embedding.weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.token_embedding(x)  # token ids -> hidden_dim vectors
        x = norm(x)                  # normalize embeddings before the first block
        for block in self.blocks:
            x = block(x)
        x = norm(x)                  # final norm
        logits = self.lm_head(x)     # project onto the vocabulary via the tied weight
        # TODO: optional logit soft-capping, e.g. softcap * torch.tanh(logits / softcap)
        return logits

In [19]:
# gpt.to('cuda')

In [44]:
# Build the model here: no earlier cell defines `gpt`, so without this line the
# notebook fails with NameError under Restart & Run All.
gpt = GPT(config)
summary(gpt, input_size=(1, config.sequence_len), dtypes=[torch.long])

Layer (type:depth-idx)                   Output Shape              Param #
GPT                                      [1, 1024, 32768]          --
├─Embedding: 1-1                         [1, 1024, 1024]           33,554,432
├─ModuleList: 1-2                        --                        --
│    └─TransformerBlock: 2-1             [1, 1024, 1024]           --
│    │    └─MultiHeadAttention: 3-1      [1, 1024, 1024]           4,194,304
│    │    └─FeedForward: 3-2             [1, 1024, 1024]           6,291,456
│    └─TransformerBlock: 2-2             [1, 1024, 1024]           --
│    │    └─MultiHeadAttention: 3-3      [1, 1024, 1024]           4,194,304
│    │    └─FeedForward: 3-4             [1, 1024, 1024]           6,291,456
│    └─TransformerBlock: 2-3             [1, 1024, 1024]           --
│    │    └─MultiHeadAttention: 3-5      [1, 1024, 1024]           4,194,304
│    │    └─FeedForward: 3-6             [1, 1024, 1024]           6,291,456
│    └─TransformerBlock: 2-4       

In [45]:
# The weight tying is working correctly — torchinfo just doesn't detect shared parameters by default.
# It counts each layer's parameters independently.

# This counts UNIQUE parameters (correct count with tying):
# Module.parameters() yields the shared embedding / lm_head weight only once.
real_params = sum(p.numel() for p in gpt.parameters())
print(f"Actual unique parameters: {real_params:,}")

# Expected count derived from the config (assumes head_dim * n_heads == hidden_dim):
#   tied embedding:             vocab_size * hidden_dim
#   per block, attention:       4 * hidden_dim^2           (key, query, value, proj)
#   per block, MLP:             2 * hidden_dim^2 * mlp_ratio (proj_up, proj_down)
hidden = config.hidden_dim
expected_params = config.vocab_size * hidden + config.n_layers * (
    4 * hidden * hidden + 2 * hidden * hidden * config.mlp_ratio
)
assert real_params == expected_params, f"{real_params:,} != expected {expected_params:,}"

Actual unique parameters: 243,269,632


In [46]:
# These should all be True: lm_head and token_embedding must be one Parameter
# backed by one storage. The assert makes a broken tie fail loudly on re-run.
same_object = gpt.lm_head.weight is gpt.token_embedding.weight
same_memory = gpt.lm_head.weight.data_ptr() == gpt.token_embedding.weight.data_ptr()
print("Same object:", same_object)
print("Same memory:", same_memory)
assert same_object and same_memory, "lm_head.weight is not tied to token_embedding.weight"

Same object: True
Same memory: True


In [47]:
# CUDA memory currently allocated / reserved by PyTorch ("Cached" = reserved memory).
# NOTE(review): `gpt` is never moved to the GPU in this notebook (the `gpt.to('cuda')`
# line is commented out), so 0.00 GB is expected here; these numbers do not reflect model size.
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
print(f"Cached: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

Allocated: 0.00 GB
Cached: 0.00 GB
