In [1]:
import torch
from transformers import DistilBertModel

# Load pretrained DistilBERT model
model = DistilBertModel.from_pretrained("distilbert-base-uncased")

# Extract the first transformer's multi-head attention layer
layer = model.transformer.layer[0]
attention_layer = layer.attention

# Override the forward function to print intermediate steps
def debug_mha_forward(self, query, key, value, mask=None, head_mask=None):
    print("\n==== Multi-Head Attention Debugging ====\n")

    # Step 1: Print Input Shapes
    print(f"Step 1: Inputs to MHA")
    print(f"  Query Shape: {query.shape}")  # Expected: (batch, seq_len, hidden_size)
    print(f"  Key Shape: {key.shape}")
    print(f"  Value Shape: {value.shape}\n")

    batch_size, seq_length, hidden_size = query.shape

    # Step 2: Split into multi-heads
    def split_heads(tensor, num_heads, head_dim):
        return tensor.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2)

    num_heads = self.n_heads
    head_dim = hidden_size // num_heads

    query = split_heads(query, num_heads, head_dim)
    key = split_heads(key, num_heads, head_dim)
    value = split_heads(value, num_heads, head_dim)

    print(f"Step 2: Split into {num_heads} Heads")
    print(f"  Query Split Shape: {query.shape}")  # (batch, num_heads, seq_len, head_dim)
    print(f"  Key Split Shape: {key.shape}")
    print(f"  Value Split Shape: {value.shape}\n")

    # Step 3: Compute Scaled Dot-Product Attention
    attention_scores = torch.matmul(query, key.transpose(-2, -1)) / (head_dim ** 0.5)
    print(f"Step 3: Compute Attention Scores (QK^T / sqrt(d_h))")
    print(f"  Attention Scores Shape: {attention_scores.shape}")  # (batch, num_heads, seq_len, seq_len)\n

    if mask is not None:
        attention_scores = attention_scores.masked_fill(mask == 0, float("-inf"))

    attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
    print(f"Step 4: Apply Softmax to Get Attention Weights")
    print(f"  Attention Weights Shape: {attention_probs.shape}\n")  # (batch, num_heads, seq_len, seq_len)

    # Step 5: Compute SV = Attention Weights * V
    attention_output = torch.matmul(attention_probs, value)
    print(f"Step 5: Compute SV (Attention Weights * V)")
    print(f"  SV (Attention Output Before Merge) Shape: {attention_output.shape}\n")  # (batch, num_heads, seq_len, head_dim)

    # Step 6: Merge Heads Back
    attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_length, hidden_size)
    print(f"Step 6: Merge Heads Back")
    print(f"  Final MHA Output Shape: {attention_output.shape}\n")  # (batch, seq_len, hidden_size)

    return attention_output

# Attach the debug function to the attention layer
attention_layer.sa_forward = debug_mha_forward.__get__(attention_layer)

# Create dummy input
input_tensor = torch.rand(1, 32, 768)  # (batch_size=1, seq_len=64, hidden_size=768)

# Run forward pass and print all steps
attention_layer.sa_forward(input_tensor, input_tensor, input_tensor)



==== Multi-Head Attention Debugging ====

Step 1: Inputs to MHA
  Query Shape: torch.Size([1, 32, 768])
  Key Shape: torch.Size([1, 32, 768])
  Value Shape: torch.Size([1, 32, 768])

Step 2: Split into 12 Heads
  Query Split Shape: torch.Size([1, 12, 32, 64])
  Key Split Shape: torch.Size([1, 12, 32, 64])
  Value Split Shape: torch.Size([1, 12, 32, 64])

Step 3: Compute Attention Scores (QK^T / sqrt(d_h))
  Attention Scores Shape: torch.Size([1, 12, 32, 32])
Step 4: Apply Softmax to Get Attention Weights
  Attention Weights Shape: torch.Size([1, 12, 32, 32])

Step 5: Compute SV (Attention Weights * V)
  SV (Attention Output Before Merge) Shape: torch.Size([1, 12, 32, 64])

Step 6: Merge Heads Back
  Final MHA Output Shape: torch.Size([1, 32, 768])



tensor([[[0.4388, 0.5095, 0.5035,  ..., 0.5491, 0.5187, 0.5283],
         [0.4288, 0.4965, 0.5083,  ..., 0.5577, 0.5238, 0.5086],
         [0.4437, 0.5094, 0.5236,  ..., 0.5610, 0.5125, 0.5377],
         ...,
         [0.4654, 0.5045, 0.5149,  ..., 0.5473, 0.5233, 0.5203],
         [0.4599, 0.5065, 0.5262,  ..., 0.5363, 0.5304, 0.5120],
         [0.4452, 0.5029, 0.5358,  ..., 0.5510, 0.5219, 0.5326]]])

In [3]:
import torch
from transformers import DistilBertModel

# Load DistilBERT model (pretrained)
model = DistilBertModel.from_pretrained("distilbert-base-uncased")

# Define input tensor (batch_size=1, sequence_length=64, hidden_size=768)
batch_size = 1
sequence_length = 32
hidden_size = 768

input_tensor = torch.rand(batch_size, sequence_length, hidden_size)  # Simulated input embeddings

# Extract first transformer layer
layer = model.transformer.layer[0]

# Define a helper function to track computation inside q_lin, k_lin, v_lin
def check_projection(layer, tensor, name):
    print(f"\n==== Checking {name} ====")
    print(f"Input Shape: {tensor.shape}")  # Should be (batch_size, S, d_model)

    # Forward pass through the projection layer
    output = layer.attention.__getattr__(name)(tensor)  # Equivalent to layer.attention.q_lin(tensor), etc.

    print(f"Weight Shape: {layer.attention.__getattr__(name).weight.shape}")  # Should be (768, 768)
    print(f"Bias Shape: {layer.attention.__getattr__(name).bias.shape if layer.attention.__getattr__(name).bias is not None else 'None'}")
    print(f"Output Shape: {output.shape}")  # Should be (batch_size, S, d_model)

    return output

# Check input and output shapes for Q, K, V projections
q_output = check_projection(layer, input_tensor, "q_lin")
k_output = check_projection(layer, input_tensor, "k_lin")
v_output = check_projection(layer, input_tensor, "v_lin")
o_output = check_projection(layer, input_tensor, "out_lin")

# Confirm output shape consistency
assert q_output.shape == k_output.shape == v_output.shape == o_output.shape, "Q, K, V, O output shapes do not match!"
print("\n✅ Q, K, V, O, projections produce consistent shapes.")



==== Checking q_lin ====
Input Shape: torch.Size([1, 32, 768])
Weight Shape: torch.Size([768, 768])
Bias Shape: torch.Size([768])
Output Shape: torch.Size([1, 32, 768])

==== Checking k_lin ====
Input Shape: torch.Size([1, 32, 768])
Weight Shape: torch.Size([768, 768])
Bias Shape: torch.Size([768])
Output Shape: torch.Size([1, 32, 768])

==== Checking v_lin ====
Input Shape: torch.Size([1, 32, 768])
Weight Shape: torch.Size([768, 768])
Bias Shape: torch.Size([768])
Output Shape: torch.Size([1, 32, 768])

==== Checking out_lin ====
Input Shape: torch.Size([1, 32, 768])
Weight Shape: torch.Size([768, 768])
Bias Shape: torch.Size([768])
Output Shape: torch.Size([1, 32, 768])

✅ Q, K, V, O, projections produce consistent shapes.


In [4]:
print(layer.attention)

DistilBertSdpaAttention(
  (dropout): Dropout(p=0.1, inplace=False)
  (q_lin): Linear(in_features=768, out_features=768, bias=True)
  (k_lin): Linear(in_features=768, out_features=768, bias=True)
  (v_lin): Linear(in_features=768, out_features=768, bias=True)
  (out_lin): Linear(in_features=768, out_features=768, bias=True)
)


In [5]:
print(layer.ffn)

FFN(
  (dropout): Dropout(p=0.1, inplace=False)
  (lin1): Linear(in_features=768, out_features=3072, bias=True)
  (lin2): Linear(in_features=3072, out_features=768, bias=True)
  (activation): GELUActivation()
)
