# Single self-attention implementation

- A batch of input sequences, each consisting of embeddings for tokens.
- Each embedding has a certain dimensionality (embed_dim).
- The input tensor dimensions are (batch_size, sequence_length, embed_dim).

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SingleHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim: int):
        super(SingleHeadSelfAttention, self).__init__()
        self.embed_dim = embed_dim

        # Linear layers to create queries, keys, and values
        self.query_linear = nn.Linear(embed_dim, embed_dim)
        self.key_linear = nn.Linear(embed_dim, embed_dim)
        self.value_linear = nn.Linear(embed_dim, embed_dim)

        # Final linear layer after attention
        self.fc_out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        batch_size, seq_length, embed_dim = x.shape

        # Compute queries, keys, values
        queries = self.query_linear(x)  # (batch_size, seq_length, embed_dim)
        keys = self.key_linear(x)       # (batch_size, seq_length, embed_dim)
        values = self.value_linear(x)   # (batch_size, seq_length, embed_dim)

        # Compute attention scores
        attention_scores = torch.matmul(queries, keys.transpose(-2, -1)) / (embed_dim ** 0.5)

        # Apply mask if provided (useful for ignoring padding tokens)
        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask.unsqueeze(1) == 0, float('-inf'))

        # Normalize scores into probabilities
        attention_weights = F.softmax(attention_scores, dim=-1)

        # Weighted sum of values
        attention_output = torch.matmul(attention_weights, values)

        # Final linear layer
        output = self.fc_out(attention_output)
        return output

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ==== Single-Head Self-Attention ====
class SingleHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim: int):
        super(SingleHeadSelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.query_linear = nn.Linear(embed_dim, embed_dim)
        self.key_linear = nn.Linear(embed_dim, embed_dim)
        self.value_linear = nn.Linear(embed_dim, embed_dim)
        self.fc_out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor, words=None, mask: torch.Tensor = None):
        batch_size, seq_length, embed_dim = x.shape
        print(f"\n--- INPUT EMBEDDINGS ---")
        for b in range(batch_size):
            print(f"Sentence {b+1}:")
            for i, word in enumerate(words[b]):
                print(f"  {word:10s} -> {x[b, i].detach().numpy()}")

        # Q, K, V
        queries = self.query_linear(x)
        keys = self.key_linear(x)
        values = self.value_linear(x)

        # Attention scores
        attention_scores = torch.matmul(queries, keys.transpose(-2, -1)) / (embed_dim ** 0.5)

        # Apply mask if any
        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask.unsqueeze(1) == 0, float('-inf'))

        # Softmax → attention weights
        attention_weights = F.softmax(attention_scores, dim=-1)

        # Pretty-print attention map
        print(f"\n--- ATTENTION MAP ---")
        for b in range(batch_size):
            print(f"Sentence {b+1}:")
            for i, word in enumerate(words[b]):
                attn_dist = attention_weights[b, i].detach().numpy()
                looked_at = [f"{words[b][j]} ({attn_dist[j]:.2f})" for j in range(seq_length)]
                print(f'  "{word}" looks at: ' + ", ".join(looked_at))

        # Weighted sum of values
        attention_output = torch.matmul(attention_weights, values)

        # Final projection
        output = self.fc_out(attention_output)

        print(f"\n--- FINAL OUTPUT ---")
        for b in range(batch_size):
            print(f"Sentence {b+1}:")
            for i, word in enumerate(words[b]):
                print(f"  {word:10s} -> {output[b, i].detach().numpy()}")

        return output



In [11]:

# ==== Vocabulary ====
vocab = {
    "i": 0,
    "love": 1,
    "pizza": 2,
    "you": 3,
    "hate": 4,
    "broccoli": 5
}
vocab_size = len(vocab)
embed_dim = 4

# ==== Embedding layer ====
embedding_layer = nn.Embedding(vocab_size, embed_dim)

# ==== Sentences ====
sentences = [
    ["i", "love", "pizza"],
    ["i", "hate", "broccoli"]
]
token_ids = torch.tensor([[vocab[w] for w in sent] for sent in sentences])
mask = torch.ones(token_ids.shape, dtype=torch.int)

# ==== Convert to embeddings ====
embedded_sentences = embedding_layer(token_ids)

# ==== Run attention ====
torch.manual_seed(0)
attn = SingleHeadSelfAttention(embed_dim=embed_dim)
attn(embedded_sentences, words=sentences, mask=mask)



--- INPUT EMBEDDINGS ---
Sentence 1:
  i          -> [-0.09334823  0.6870502  -0.83831537  0.00089182]
  love       -> [ 0.8418941  -0.40003455  1.039462    0.3581531 ]
  pizza      -> [0.07324605 1.1133184  0.28226727 0.43422565]
Sentence 2:
  i          -> [-0.09334823  0.6870502  -0.83831537  0.00089182]
  hate       -> [ 0.20641631 -0.33344787 -0.42883     0.23291828]
  broccoli   -> [ 0.79688716 -0.18484132 -0.3701471  -1.2102813 ]

--- ATTENTION MAP ---
Sentence 1:
  "i" looks at: i (0.29), love (0.33), pizza (0.38)
  "love" looks at: i (0.25), love (0.43), pizza (0.32)
  "pizza" looks at: i (0.24), love (0.39), pizza (0.36)
Sentence 2:
  "i" looks at: i (0.36), hate (0.32), broccoli (0.32)
  "hate" looks at: i (0.32), hate (0.31), broccoli (0.37)
  "broccoli" looks at: i (0.38), hate (0.35), broccoli (0.27)

--- FINAL OUTPUT ---
Sentence 1:
  i          -> [ 0.00859527 -0.4691878  -0.10628858 -0.708125  ]
  love       -> [ 0.02552723 -0.4547485  -0.09126757 -0.6452503 ]
  pizza

tensor([[[ 0.0086, -0.4692, -0.1063, -0.7081],
         [ 0.0255, -0.4547, -0.0913, -0.6453],
         [ 0.0176, -0.4606, -0.1001, -0.6596]],

        [[ 0.1605, -0.3840,  0.0678, -1.0530],
         [ 0.1808, -0.3686,  0.0786, -1.0545],
         [ 0.1398, -0.3999,  0.0583, -1.0473]]], grad_fn=<ViewBackward0>)

# Multi-head attention

Suppose you’re given:
- A batch of sequences of token embeddings
- The number of attention heads (num_heads)
- The dimension of embeddings (embed_dim), which must be divisible by num_heads
- Input tensor dimensions: (batch_size, sequence_length, embed_dim)

In [None]:
import torch
import torch.nn as nn

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int):
        """
        Initialize the multi-head self-attention layer.

        Args:
            embed_dim: Dimension of input embeddings.
            num_heads: Number of attention heads.
        """
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Validate embedding dimension divisibility
        if self.head_dim * num_heads != embed_dim:
            raise ValueError("embed_dim must be divisible by num_heads.")

        # TODO: Define linear layers for queries, keys, values, and output projection

        # Linear projections for queries, keys, and values
        self.query_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.key_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.value_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        
        # Final linear layer to combine heads
        self.fc_out = nn.Linear(embed_dim, embed_dim)
        
    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        """
        Perform multi-head self-attention.

        Args:
            x: Input tensor of shape (batch_size, seq_length, embed_dim).
            mask: Optional tensor to mask specific tokens.

        Returns:
            Tensor of shape (batch_size, seq_length, embed_dim), attention output.
        """
        # TODO: Implement multi-head self-attention logic here
        batch_size, seq_length, embed_dim = x.shape

        # Project inputs to queries, keys, and values, and split embeddings into multiple heads
        queries = self.query_proj(x).reshape(batch_size, seq_length, self.num_heads, self.head_dim)
        keys = self.key_proj(x).reshape(batch_size, seq_length, self.num_heads, self.head_dim)
        values = self.value_proj(x).reshape(batch_size, seq_length, self.num_heads, self.head_dim)

        # Transpose for multi-head attention computation (batch_size, num_heads, seq_length, head_dim)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention scores per head
        attention_scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # Apply mask if provided
        if mask is not None:
            expanded_mask = mask.unsqueeze(1).unsqueeze(2)  # shape: (batch_size, 1, 1, seq_length)
            attention_scores = attention_scores.masked_fill(expanded_mask == 0, float('-inf'))

        # Normalize scores into attention probabilities
        attention_weights = F.softmax(attention_scores, dim=-1)

        # Compute weighted sum of values
        attention_output = torch.matmul(attention_weights, values)

        # Concatenate heads' outputs back to original embedding dimension
        attention_output = attention_output.transpose(1, 2).contiguous()
        concatenated_output = attention_output.reshape(batch_size, seq_length, embed_dim)

        # Final linear layer to refine concatenated outputs
        output = self.fc_out(concatenated_output)

        return output


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Validate embedding dimension divisibility
        if self.head_dim * num_heads != embed_dim:
            raise ValueError("embed_dim must be divisible by num_heads.")

        # Linear projections for queries, keys, and values
        self.query_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.key_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.value_proj = nn.Linear(embed_dim, embed_dim, bias=False)

        # Final linear layer to combine heads
        self.fc_out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        batch_size, seq_length, embed_dim = x.shape

        # Project inputs to queries, keys, and values, and split embeddings into multiple heads
        queries = self.query_proj(x).reshape(batch_size, seq_length, self.num_heads, self.head_dim)
        keys = self.key_proj(x).reshape(batch_size, seq_length, self.num_heads, self.head_dim)
        values = self.value_proj(x).reshape(batch_size, seq_length, self.num_heads, self.head_dim)

        # Transpose for multi-head attention computation (batch_size, num_heads, seq_length, head_dim)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention scores per head
        attention_scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # Apply mask if provided
        if mask is not None:
            expanded_mask = mask.unsqueeze(1).unsqueeze(2)  # shape: (batch_size, 1, 1, seq_length)
            attention_scores = attention_scores.masked_fill(expanded_mask == 0, float('-inf'))

        # Normalize scores into attention probabilities
        attention_weights = F.softmax(attention_scores, dim=-1)

        # Compute weighted sum of values
        attention_output = torch.matmul(attention_weights, values)

        # Concatenate heads' outputs back to original embedding dimension
        attention_output = attention_output.transpose(1, 2).contiguous()
        concatenated_output = attention_output.reshape(batch_size, seq_length, embed_dim)

        # Final linear layer to refine concatenated outputs
        output = self.fc_out(concatenated_output)

        return output

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ==== Your Multi-Head Self-Attention ====
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        if self.head_dim * num_heads != embed_dim:
            raise ValueError("embed_dim must be divisible by num_heads.")

        self.query_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.key_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.value_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.fc_out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        batch_size, seq_length, embed_dim = x.shape

        Q = self.query_proj(x).reshape(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.key_proj(x).reshape(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.value_proj(x).reshape(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if mask is not None:
            # mask: (batch, seq) -> (batch, 1, 1, seq) to broadcast over heads and query positions
            expanded_mask = mask.unsqueeze(1).unsqueeze(2)
            scores = scores.masked_fill(expanded_mask == 0, float('-inf'))

        weights = F.softmax(scores, dim=-1)   # (batch, heads, seq, seq)
        out = torch.matmul(weights, V)        # (batch, heads, seq, head_dim)

        out = out.transpose(1, 2).contiguous().reshape(batch_size, seq_length, embed_dim)
        out = self.fc_out(out)
        return out, weights  # return weights so we can pretty-print



In [15]:

# ==== A tiny driver that prints per-head attention nicely ====
def run_demo_pretty_print_mha():
    # Tiny vocabulary
    vocab = {"i":0, "love":1, "pizza":2, "you":3, "hate":4, "broccoli":5}
    sentences = [
        ["i", "love", "pizza"],
        ["i", "hate", "broccoli"]
    ]
    token_ids = torch.tensor([[vocab[w] for w in s] for s in sentences])  # (batch=2, seq=3)
    mask = torch.ones_like(token_ids, dtype=torch.int)  # keep it simple: no padding masked out

    # Small, readable dims
    embed_dim = 8
    num_heads = 2  # head_dim = 4
    torch.manual_seed(0)

    # Embedding + attention
    embedding = nn.Embedding(len(vocab), embed_dim)
    x = embedding(token_ids)  # (batch, seq, embed_dim)

    print("\n=== INPUT WORDS & EMBEDDINGS ===")
    for b, sent in enumerate(sentences):
        print(f"Sentence {b+1}:")
        for i, w in enumerate(sent):
            print(f"  {w:10s} -> {x[b, i].detach().numpy()}")

    mha = MultiHeadSelfAttention(embed_dim=embed_dim, num_heads=num_heads)
    out, attn_weights = mha(x, mask=mask)  # out: (batch, seq, embed_dim); attn_weights: (batch, heads, seq, seq)

    # Pretty print attention per head
    print("\n=== PER-HEAD ATTENTION MAPS (probabilities) ===")
    B, H, S, _ = attn_weights.shape
    for b in range(B):
        print(f"\nSentence {b+1}: {' '.join(sentences[b])}")
        for h in range(H):
            print(f"  Head {h}:")
            for i in range(S):
                dist = attn_weights[b, h, i].detach().numpy()
                looked_at = [f"{sentences[b][j]} ({dist[j]:.2f})" for j in range(S)]
                print(f'    "{sentences[b][i]}" looks at: ' + ", ".join(looked_at))

    # Final outputs
    print("\n=== FINAL OUTPUT VECTORS (after concat + fc_out) ===")
    for b, sent in enumerate(sentences):
        print(f"Sentence {b+1}:")
        for i, w in enumerate(sent):
            print(f"  {w:10s} -> {out[b, i].detach().numpy()}")

    # Optional: top-1 attention target per token per head
    print("\n=== TOP-1 TARGET PER TOKEN PER HEAD ===")
    top_idx = attn_weights.argmax(dim=-1)  # (batch, heads, seq)
    for b in range(B):
        print(f"Sentence {b+1}:")
        for h in range(H):
            picks = [sentences[b][top_idx[b, h, i].item()] for i in range(S)]
            srcs  = sentences[b]
            pairs = [f'{srcs[i]}→{picks[i]}' for i in range(S)]
            print(f"  Head {h}: " + ", ".join(pairs))


# Run it
if __name__ == "__main__":
    run_demo_pretty_print_mha()



=== INPUT WORDS & EMBEDDINGS ===
Sentence 1:
  i          -> [-1.1258398  -1.1523602  -0.25057858 -0.43387884  0.84871036  0.6920092
 -0.31601277 -2.1152196 ]
  love       -> [ 0.32227492 -1.2633348   0.3499832   0.3081339   0.11984151  1.2376579
  1.1167772  -0.24727765]
  pizza      -> [-1.3526537  -1.6959313   0.5666505   0.7935084   0.59883946 -1.5550951
 -0.3413603   1.8530061 ]
Sentence 2:
  i          -> [-1.1258398  -1.1523602  -0.25057858 -0.43387884  0.84871036  0.6920092
 -0.31601277 -2.1152196 ]
  hate       -> [-0.6135831   0.03159274 -0.49267703  0.24841475  0.43969586  0.11241119
  0.64079237  0.44115627]
  broccoli   -> [-0.10230965  0.792444   -0.28966758  0.05250749  0.5228604   2.3022053
 -1.4688939  -1.5866888 ]

=== PER-HEAD ATTENTION MAPS (probabilities) ===

Sentence 1: i love pizza
  Head 0:
    "i" looks at: i (0.33), love (0.31), pizza (0.36)
    "love" looks at: i (0.33), love (0.30), pizza (0.37)
    "pizza" looks at: i (0.36), love (0.30), pizza (0.34)
  H