In [None]:
import sys
sys.path.append("../")
import torch

# Attention
from src.components.attention import SelfAttention

**1. Parameters**

In [8]:
batch_size = 2
seq_len = 8
d_model = 128
n_heads = 4
max_seq_len = 512

**2. Embedding creation**

*We create a random tensor to simulate the embeddings input to the attention mechanism.*

In [26]:
# Create input (from embeddings)
# 2 words / 8 tokens each / embedding size 128
x = torch.randn(batch_size, seq_len, d_model)
print(x.shape)
x[0]

torch.Size([2, 8, 128])


tensor([[ 1.4035,  0.3935,  0.5077,  ...,  0.0388,  1.7818,  0.0272],
        [ 0.3740,  0.1924,  0.7580,  ...,  0.4056, -1.3121, -1.6064],
        [-1.0050,  0.7753,  0.8475,  ...,  1.0398,  1.2283, -1.0327],
        ...,
        [-1.9996, -0.1667, -0.6915,  ..., -0.5849, -1.2648,  0.2952],
        [-1.2510,  2.2274, -0.0350,  ..., -0.6277, -0.8276,  0.6030],
        [-1.6712,  1.2298, -1.4076,  ..., -0.5431, -0.2476,  0.8017]])

**3. Self-Attention Layer**

In [9]:
# Create attention module
attn = SelfAttention(d_model, n_heads, max_seq_len=max_seq_len)

# Forward pass
out = attn(x)

[32m2026-02-03 11:17:48.992[0m | [34m[1mDEBUG   [0m | [36msrc.components.attention[0m:[36m__init__[0m:[36m54[0m - [34m[1mCausalSelfAttention: d_model=128, n_heads=4, head_dim=32, dropout=0.1[0m


In [11]:
out.shape

torch.Size([2, 8, 128])

In [24]:
out[0]

tensor([[ 0.2482,  0.6255,  0.4999,  ...,  0.8962, -0.7809, -0.0728],
        [-0.4772,  0.0000,  0.1295,  ...,  1.2026, -0.8311, -0.5821],
        [ 0.1638,  0.8480,  0.1075,  ...,  0.9776, -0.5605,  0.2416],
        ...,
        [-0.5093,  0.0854, -0.0293,  ..., -0.0854, -0.4644,  0.2742],
        [-0.0546,  0.7255, -0.3142,  ..., -0.3760, -0.3325,  0.0056],
        [-0.3602,  0.3267,  0.0000,  ..., -0.2337, -0.2548,  0.1144]],
       grad_fn=<SelectBackward0>)

<pre style="font-family: Menlo, Consolas, monospace; font-size: 14px; line-height: 1.35;">
┌──────────────────────────────────────────────┐
│ Input                                        │
│ (batch, seq, d_model)                        │
└──────────────────────────────────────────────┘
                    ↓
┌──────────────────────────────────────────────┐
│ Linear Projection                            │
│ Q, K, V                                      │
│ (batch, heads, seq, head_dim)                │
└──────────────────────────────────────────────┘
                    ↓
┌──────────────────────────────────────────────┐
│ Attention Scores                             │
│ Q @ Kᵀ / √d_k                                │
│ (batch, heads, seq, seq)                     │
└──────────────────────────────────────────────┘
                    ↓
┌──────────────────────────────────────────────┐
│ Causal Mask                                  │
│ Lower triangular mask                        │
└──────────────────────────────────────────────┘
                    ↓
┌──────────────────────────────────────────────┐
│ Softmax                                      │
│ Attention weights                            │
│ (batch, heads, seq, seq)                     │
│ Σ weights = 1 (along seq dimension)          │
└──────────────────────────────────────────────┘
                    ↓
┌──────────────────────────────────────────────┐
│ Weighted Sum                                 │
│ Weights @ V                                  │
│ (batch, heads, seq, head_dim)                │
└──────────────────────────────────────────────┘
                    ↓
┌──────────────────────────────────────────────┐
│ Reshape                                      │
│ Combine heads                                │
│ (batch, seq, d_model)                        │
└──────────────────────────────────────────────┘
                    ↓
┌──────────────────────────────────────────────┐
│ Output Projection                            │
│ Final linear                                 │
│ (batch, seq, d_model)                        │
└──────────────────────────────────────────────┘
                    ↓
┌──────────────────────────────────────────────┐
│ Output                                       │
│ (batch, seq, d_model)                        │
└──────────────────────────────────────────────┘
</pre>


---