In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [15]:
# --- Cross-Attention Module ---
class CrossAttention(nn.Module):
    """Single-head scaled dot-product cross-attention.

    Queries are projected from the decoder hidden states, while keys and
    values are both projected from the encoder output. The three projections
    are bias-free linear maps from ``d_model`` to ``d_model``.
    """

    def __init__(self, d_model):
        """Create the query, key and value projections.

        Args:
            d_model: Size of the last (feature) dimension of both inputs.
        """
        super().__init__()
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

    def forward(self, decoder_hidden, encoder_output, mask=None):
        """Attend from decoder positions (queries) to encoder positions (keys).

        Args:
            decoder_hidden: Tensor of shape ``[..., num_queries, d_model]``.
            encoder_output: Tensor of shape ``[..., num_keys, d_model]``.
                Leading (batch) dimensions, if any, must be compatible with
                those of ``decoder_hidden`` under ``torch.matmul`` broadcasting.
            mask: Optional tensor broadcastable to
                ``[..., num_queries, num_keys]``. Positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            A tuple ``(output, attention_weights)`` where ``output`` has shape
            ``[..., num_queries, d_model]`` and ``attention_weights`` has shape
            ``[..., num_queries, num_keys]`` with each row summing to 1.
        """
        Q = self.W_q(decoder_hidden)   # [..., num_queries, d_model]
        K = self.W_k(encoder_output)   # [..., num_keys, d_model]
        V = self.W_v(encoder_output)   # [..., num_keys, d_model]

        # Swap only the last two dims: `K.T` reverses *all* dims, which is
        # wrong (and deprecated) for anything other than a 2-D tensor.
        d_k = K.size(-1)
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)

        # Optional masking. Use the most negative finite value of the score
        # dtype so the fill is representable in reduced precision as well.
        if mask is not None:
            fill_value = torch.finfo(attention_scores.dtype).min
            attention_scores = attention_scores.masked_fill(mask == 0, fill_value)

        # Softmax over encoder tokens (keys)
        attention_weights = F.softmax(attention_scores, dim=-1)

        # Weighted sum of values
        output = torch.matmul(attention_weights, V)  # [..., num_queries, d_model]

        return output, attention_weights

    


In [16]:
# --- Simulated Input ---
# Small hand-written example: 2 decoder positions attend over 3 encoder tokens.

# Suppose encoder output has 3 tokens, each with 4-dim embeddings
encoder_output = torch.tensor([
    [1.0, 0.5, 0.2, 0.1],
    [0.4, 0.6, 0.9, 0.3],
    [0.3, 0.2, 0.5, 0.7]
])

# Decoder hidden state is 2 tokens, each with 4-dim embeddings
decoder_hidden = torch.tensor([
    [0.9, 0.1, 0.3, 0.5],
    [0.2, 0.8, 0.6, 0.4]
])

# Set random seed for reproducibility (the projection weights are randomly
# initialised when the module is constructed below)
torch.manual_seed(42)

# Create cross-attention module
cross_attention = CrossAttention(d_model=4)

# Run cross-attention. Call the module itself rather than `.forward(...)` so
# that nn.Module's call machinery (including any registered hooks) runs.
output, attn_weights = cross_attention(decoder_hidden, encoder_output)

# Sanity checks: one output row per query, one weight per (query, key) pair,
# and each query's weights form a probability distribution over the keys.
assert output.shape == decoder_hidden.shape
assert attn_weights.shape == (decoder_hidden.size(0), encoder_output.size(0))
assert torch.allclose(attn_weights.sum(dim=-1), torch.ones(decoder_hidden.size(0)))

print("Cross-Attention Output:\n", output)
print("\nAttention Weights:\n", attn_weights)


Cross-Attention Output:
 tensor([[0.2845, 0.1918, 0.2908, 0.2556],
        [0.2850, 0.1922, 0.2902, 0.2556]], grad_fn=<MmBackward0>)

Attention Weights:
 tensor([[0.3462, 0.3255, 0.3283],
        [0.3487, 0.3244, 0.3269]], grad_fn=<SoftmaxBackward0>)
