This notebook reproduces the walkthrough at https://michaelwornow.net/2024/01/18/counting-params-in-transformer, counting the parameters of GPT-2 layer by layer.

In [28]:
import torch
import math

from transformers import AutoModel

In [29]:
from transformers import AutoModel

In [30]:
# Load the pretrained "gpt2" checkpoint; every later cell inspects this `model` object.
model = AutoModel.from_pretrained("gpt2")

def count_trainable_params(model, is_human: bool = False):
    """Count the parameters of ``model`` that have ``requires_grad=True``.

    Args:
        model: Any object whose ``parameters()`` yields tensors
            (e.g. a ``torch.nn.Module``).
        is_human: When True, return the count in millions formatted as a
            string such as ``"124.44M"``; otherwise return the raw integer.

    Returns:
        ``int`` by default, or ``str`` when ``is_human`` is True.
    """
    total = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            total += tensor.numel()
    if is_human:
        return f"{total / 1e6:.2f}M"
    return total
def count_untrainable_params(model, is_human: bool = False):
    """Count the parameters of ``model`` that have ``requires_grad=False``.

    Args:
        model: Any object whose ``parameters()`` yields tensors
            (e.g. a ``torch.nn.Module``).
        is_human: When True, return the count in millions formatted as a
            string such as ``"0.00M"``; otherwise return the raw integer.

    Returns:
        ``int`` by default, or ``str`` when ``is_human`` is True.
    """
    frozen_sizes = [tensor.numel() for tensor in model.parameters() if not tensor.requires_grad]
    total = sum(frozen_sizes)
    if is_human:
        return f"{total / 1e6:.2f}M"
    return total

In [31]:
# Print the module tree so each sub-layer and its dimensions can be read off below.
print(model)

GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2SdpaAttention(
        (c_attn): Conv1D(nf=2304, nx=768)
        (c_proj): Conv1D(nf=768, nx=768)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D(nf=3072, nx=768)
        (c_proj): Conv1D(nf=768, nx=3072)
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)


In [32]:
# Summarise the model size in millions, split by whether the parameters receive gradients.
trainable_total = count_trainable_params(model, is_human=True)
untrainable_total = count_untrainable_params(model, is_human=True)
print("Total # of trainable parameters:", trainable_total)
print("Total # of untrainable parameters:", untrainable_total)

Total # of trainable parameters: 124.44M
Total # of untrainable parameters: 0.00M


`(wte): Embedding(50257, 768)`

Word Token Embedding (WTE), which means: 
* $V (\text{total number of tokens in our vocabulary}) = 50257$

It is a matrix of size $V(50257)$ by $E(768)$.

In other words, our vocabulary has a total of $50257$ unique tokens, and each token is represented by a dense vector of $768$ floating point numbers.

Besides, $E (\text{size of the embedding vector}) = 768$

$Params = V ∗ E = 50257 ∗ 768 = 38,597,376$

`(wpe): Embedding(1024, 768)`

Word Position Embedding (WPE), which means:
* $P \text{ (the maximum sequence length that our model can handle)} = 1024$

It is a matrix of size $P (1024)$ by $E (768)$.

This means that the maximum sequence length that our model can handle is $1024$ tokens. This is also referred to as the “context window.”

$Params = P∗ E = 1024 ∗ 768 = 786,432$

The embeddings from these two layers will get added together to create “position-aware” embeddings of our input tokens.

In [33]:
# Read the three sizes straight from the model config so the arithmetic below
# always refers to the checkpoint that was actually loaded.
V: int = model.config.vocab_size   # number of tokens in the vocabulary
E: int = model.config.n_embd       # size of each embedding vector
P: int = model.config.n_positions  # maximum sequence length (context window)

# Each embedding table is a plain (rows x E) matrix.
expected_wte: int = V * E
expected_wpe: int = P * E

# Access the sub-modules through their public attributes (`model.wte`,
# `model.wpe`) rather than the private `_modules` dict.
print(f"wte | Expected: {expected_wte}")
print(f"wte | True:     {count_trainable_params(model.wte)}")
print(f"wpe | Expected: {expected_wpe}")
print(f"wpe | True:     {count_trainable_params(model.wpe)}")

wte | Expected: 38597376
wte | True:     38597376
wpe | Expected: 786432
wpe | True:     786432


### Transformer Layers

```bash
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2SdpaAttention(
        (c_attn): Conv1D(nf=2304, nx=768)
        (c_proj): Conv1D(nf=768, nx=768)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D(nf=3072, nx=768)
        (c_proj): Conv1D(nf=768, nx=3072)
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
```
Let's break down the components of a transformer layer.

`ln_1`:

* This is a `LayerNorm` layer.
* This is responsible for "normalizing" the input before it is passed to the attention layer. It normalizes across the last dimension, which is the embedding dimension. This means that, before the learned gain and bias are applied, the values along the embedding dimension of each token are rescaled to have a mean of 0 and a standard deviation of 1 (they are not forced to follow a normal distribution).
* The `eps=1e-05` parameter is the value ϵ added to the variance inside the square root in the denominator. It is used for numerical stability, to prevent division by zero.
* The `elementwise_affine=True` parameter means that the layer will learn a bias β and gain γ for each embedding dimension.
* The formula for LayerNorm is as follows:

$$
\begin{aligned}
\text{LayerNorm}(x) &= \gamma \cdot \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} + \beta \\
\text{where:} \quad E[x] &= \frac{1}{n}\sum_{i=1}^n x_i \quad \text{(mean)} \\
Var[x] &= \frac{1}{n}\sum_{i=1}^n (x_i - E[x])^2 \quad \text{(variance)}
\end{aligned}
$$



`E[x]` and `Var[x]` are calculated on the fly as the mean and variance of the input (x) across the embedding dimension.
Thus, the only learnable parameters here are β and γ, which are vectors of size E (768). 

Params = 2 * E = 2 * 768 = 1536

In [34]:
# LayerNorm learns one gain (γ) and one bias (β) per embedding dimension.
expected_ln_1: int = 2 * E
# All 12 blocks are identical in structure, so block 0 is representative.
# `model.h` is the public handle on the ModuleList of blocks (no need to reach
# into the private `_modules` dict).
print(f"ln_1 | Expected: {expected_ln_1}")
print(f"ln_1 | True:     {count_trainable_params(model.h[0].ln_1)}")

ln_1 | Expected: 1536
ln_1 | True:     1536


In [35]:
from torch.nn import LayerNorm
# NLP example cf. https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
batch, sentence_length, embedding_dim = 20, 5, 10
esp = 1e-05
embedding = torch.randn(batch, sentence_length, embedding_dim)
layer_norm = LayerNorm(embedding_dim, eps=esp, elementwise_affine=True, bias=True)
print("NPL example: LayerNorm(embedding_dim) = ", layer_norm(embedding).shape)
# Parameters:
# - `normalized_shape`: The shape of the input features to be normalized. For embeddings, this is typically (embedding_dim,)
# - `eps`: A small value added for numerical stability when normalizing
# - `elementwise_affine`: If True, learns an affine transform (gamma, beta) per feature (Default: True)
# - `bias`: If True, adds a learnable bias term β after the affine transform. Only relevant if elementwise_affine is True. (Default: True)

# How LayerNorm works:
# 1. For each sample in the batch:
#    - Compute mean μ and variance σ² across the feature dimension
#    - Normalize: x̂ = (x - μ) / sqrt(σ² + eps)
#    - If elementwise_affine=True: y = γ * x̂ + β

# Let's calculate LayerNorm manually to understand how it works under the hood
x = embedding  # Shape: (batch, sentence_length, embedding_dim)
print("\nManual LayerNorm calculation:")
print("Input x shape:", x.shape)

# 1. Calculate mean for each position
mean = x.mean(dim=-1, keepdim=True)  # Shape: (batch, sentence_length, 1)
print("Mean μ shape:", mean.shape)

# 2. Calculate variance for each position
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)  # Shape: (batch, sentence_length, 1)
print("Variance σ² shape:", var.shape)

# 3. Normalize
eps = 1e-05
x_norm = (x - mean) / torch.sqrt(var + eps)
print("Normalized x̂ shape:", x_norm.shape)

# 4. Since weight=1 and bias=0, the output equals x_norm
y = x_norm  # In general case: y = weight * x_norm + bias
print("Output y shape:", y.shape)

# Verify with PyTorch's LayerNorm
ln = LayerNorm(embedding_dim, elementwise_affine=True)
# Set weight to 1 and bias to 0
ln.weight.data.fill_(1.0)
ln.bias.data.fill_(0.0)
y_torch = ln(x)
print("\nPyTorch LayerNorm output shape:", y_torch.shape)
print("Maximum difference:", (y - y_torch).abs().max().item())

# Print first batch, first position values for comparison
print("\nFirst batch, first position comparison:")
print("Manual output[0,0]:", y[0,0])
print("PyTorch output[0,0]:", y_torch[0,0])



NPL example: LayerNorm(embedding_dim) =  torch.Size([20, 5, 10])

Manual LayerNorm calculation:
Input x shape: torch.Size([20, 5, 10])
Mean μ shape: torch.Size([20, 5, 1])
Variance σ² shape: torch.Size([20, 5, 1])
Normalized x̂ shape: torch.Size([20, 5, 10])
Output y shape: torch.Size([20, 5, 10])

PyTorch LayerNorm output shape: torch.Size([20, 5, 10])
Maximum difference: 3.5762786865234375e-07

First batch, first position comparison:
Manual output[0,0]: tensor([ 0.9932, -0.3416, -0.4775, -1.2895,  0.0837, -2.1132,  0.9516,  0.7095,
         0.8504,  0.6334])
PyTorch output[0,0]: tensor([ 0.9932, -0.3416, -0.4775, -1.2895,  0.0837, -2.1132,  0.9516,  0.7095,
         0.8504,  0.6334], grad_fn=<SelectBackward0>)


```bash
(attn): GPT2SdpaAttention(
    (c_attn): Conv1D(nf=2304, nx=768)
    (c_proj): Conv1D(nf=768, nx=768)
    (attn_dropout): Dropout(p=0.1, inplace=False)
    (resid_dropout): Dropout(p=0.1, inplace=False)
)
```

`attn`:
* This is a GPT2Attention layer, aka “self-attention”.
* This computes the self-attention scores between each token in the input sequence.
* It is comprised of four sub-layers:
    * `c_attn`:
        * This is a Conv1D layer.
        * This confused me for a while. What was this Conv1D layer doing in the middle of a transformer layer? I thought it was supposed to be an MLP? My understanding is that it is basically a linear layer, but with the weights transposed. I’m not sure what motivated this design decision, so if anyone knows please leave a comment.
        * It is responsible for transforming the input into the query, key, and value matrices for the attention calculation.
        * It is a matrix of size E (768) by 3 * E (2304) plus a bias vector of size 3 * E (2304). The 3 * E is because we have 3 inputs to the attention layer: the query, the key, and the value. Each of these inputs is a vector of size E (768), so we have to generate a total of 3 * E (2304) elements.
    * `c_proj`:
        * This is a Conv1D layer.
        * It is responsible for combining the outputs of the attention heads (in our case, there are 12 heads amongst which 768 dims are equally divided, which gives each head a 64-dim output).
        * It is a matrix of size E (768) by E (768) plus a bias vector of size E (768).
    * `attn_dropout`:
        * This is a Dropout layer.  It is responsible for dropping out a fraction (p=0.1) of activations post-attention calculation during training. __This has no trainable parameters__.
    * `resid_dropout` is a Dropout layer.  It is responsible for dropping out a fraction (p=0.1) of activations post-projection during training. __This has no trainable parameters.__

> **Residual Dropout** We apply dropout [33] to the output of each sub-layer, before it is added to the sub-layer input and normalized. In addition, we apply dropout to the sums of the embeddings and the positional encodings in both the encoder and decoder stacks. For the base model, we use a rate of $P_{drop} = 0.1$.
>
> — Vaswani et al., *Attention Is All You Need* (2017), §5.4

https://d2l.ai/chapter_attention-mechanisms-and-transformers/attention-scoring-functions.html#scaled-dot-product-attention

In [43]:
# Grab one block's attention module through the public `model.h` ModuleList
# (instead of the private `_modules` dict).
attn = model.h[1].attn
# Conv1D stores its weight as (nx, nf), i.e. (input size, output size).
print(f"c_attn shape: {attn.c_attn.weight.shape}")
print(f"c_proj shape: {attn.c_proj.weight.shape}")


c_attn shape: torch.Size([768, 2304])
c_proj shape: torch.Size([768, 768])
c_attn shape: torch.Size([768, 2304])
c_proj shape: torch.Size([768, 768])


In [47]:
## Understanding the attention layer

In [46]:
import inspect

# Show the actual source of the attention module's forward pass for reference.
print("GPT2Attention forward function:")
forward_source = inspect.getsource(attn.forward)
print(forward_source)

GPT2Attention forward function:
    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        if output_attentions or head_mask is not None:
                "`GPT2SdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
                "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
                "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
            

# Self-attention vs Cross-attention

Self-attention and cross-attention are two key variants of the attention mechanism:

## Self-attention
- Takes a single sequence as input and learns relationships between all positions within that sequence
- Query, Key, and Value matrices are all derived from the same input sequence
- Used in encoder and decoder self-attention layers
- Helps model understand internal relationships and dependencies within a sequence
- Example: In "The cat sat on the mat", self-attention helps relate "cat" to "sat" to understand who performed the action

## Cross-attention 
- Takes two different sequences as input and learns relationships between positions across the sequences
- Query matrix comes from one sequence, while Key and Value matrices come from another sequence
- Used in encoder-decoder attention layers to relate decoder states to encoder outputs
- Helps model connect and align information between different sequences
- Example: In translation, cross-attention helps align words in source language to corresponding words in target language

The key difference is that self-attention operates within a single sequence, while cross-attention operates across two different sequences. GPT-2 uses only self-attention since it's a decoder-only model.
So let's extract the relevant part of the forward function above.

```python
# attn forward function (simplified)
def forward(
    self,
    hidden_states: Optional[Tuple[torch.FloatTensor]],
    layer_past: Optional[Tuple[torch.Tensor]] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = False,
    output_attentions: Optional[bool] = False,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
    """Simplified self-attention forward pass (SDPA path only).

    The signature is kept identical to the full method, but this trimmed
    version ignores ``layer_past``, ``head_mask``, ``encoder_attention_mask``,
    ``use_cache`` and ``output_attentions``, and returns only the projected
    attention output tensor of shape ``(bsz, q_len, embed_dim)``.
    """
    bsz, q_len, _ = hidden_states.size()

    # Needed by the causal-mask decision below. Cross-attention is signalled by
    # the presence of encoder states.
    is_cross_attention = encoder_hidden_states is not None

    # Initial attention projections: one fused matmul yields Q, K and V,
    # which are then split along the feature dimension.
    query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

    # (bsz, q_len, embed_dim) -> per-head layout expected by SDPA
    query = self._split_heads(query, self.num_heads, self.head_dim)
    key = self._split_heads(key, self.num_heads, self.head_dim)
    value = self._split_heads(value, self.num_heads, self.head_dim)

    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` flag instead of an inline conditional
    # in the SDPA call, to support both torch.compile's dynamic shapes and full graph options.
    # Causal masking is only requested when no explicit mask was given, there is
    # more than one query position, and this is not cross-attention.
    is_causal = attention_mask is None and q_len > 1 and not is_cross_attention

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=attention_mask,
        dropout_p=self.attn_dropout.p if self.training else 0.0,  # dropout only while training
        is_causal=is_causal,
    )

    # Reshape outputs: merge the heads back into a single embed_dim axis
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.view(bsz, q_len, self.embed_dim)

    # Final projection
    attn_output = self.c_proj(attn_output)
    attn_output = self.resid_dropout(attn_output)

    return attn_output
```

In [27]:

# Reference (unfused) implementation of scaled dot-product attention, adapted from the
# pseudo-code in the `torch.nn.functional.scaled_dot_product_attention` documentation.
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
        is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
    """Compute ``softmax(Q @ K^T * scale + bias) @ V`` step by step.

    Args:
        query: Tensor whose last two dims are ``(L, head_dim)``.
        key: Tensor whose last two dims are ``(S, head_dim)``.
        value: Tensor whose second-to-last dim is ``S``.
        attn_mask: Optional mask broadcastable to the attention scores. A
            boolean mask keeps positions that are True; any other dtype is
            added to the scores as-is.
        dropout_p: Dropout probability applied to the attention weights.
            Dropout is always applied in training mode here, so pass ``0.0``
            for deterministic output.
        is_causal: If True, each query position attends only to key positions
            at or before its own index. Must not be combined with ``attn_mask``.
        scale: Optional score multiplier; defaults to ``1 / sqrt(head_dim)``.
        enable_gqa: If True, repeat key/value heads (dim ``-3``) so they match
            the number of query heads.

    Returns:
        Tensor with the attention-weighted sum of ``value``.
    """
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        # Build the mask on the same device as `attn_bias`; otherwise the
        # in-place fill below fails when `query` is not on the default device.
        causal_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(causal_mask.logical_not(), float("-inf"))

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Turn the boolean mask into an additive bias of its own shape, then
            # broadcast-add it. An in-place fill on the (L, S) bias would reject
            # masks that carry leading batch/head dimensions.
            mask_bias = torch.zeros(attn_mask.shape, dtype=query.dtype, device=attn_mask.device)
            mask_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            attn_bias = attn_bias + mask_bias
        else:
            attn_bias = attn_mask + attn_bias

    if enable_gqa:
        key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    # Out-of-place add so a bias with extra leading dims can broadcast.
    attn_weight = attn_weight + attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value