In [1]:
# Cross-attention demo: the attention used in encoder-decoder models (e.g., T5: Text-to-Text Transfer Transformer)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy encoder output: three source positions, each a 4-dimensional vector.
# These rows are what the key and value projections consume further down.
encoder_outputs = torch.tensor(
    [
        [1.0, 0.0, 0.5, 0.5],
        [0.5, 1.0, 0.0, 0.5],
        [0.0, 0.5, 1.0, 0.5],
    ]
)

In [3]:
# The decoder's hidden state for the step being decoded. It is kept 2-D
# (one row, four features) so it can be used directly as the attention query.
decoder_hidden = torch.tensor(
    [
        [0.5, 0.5, 0.0, 1.0],
    ]
)

In [4]:
# Sanity check: expect a single query row with four features.
decoder_hidden.size()

torch.Size([1, 4])

In [5]:
# --- Step 1: Create Q, K, V ---
# nn.Linear draws its initial weights at random. Fix the RNG seed before the
# first layer is built so that a fresh "Restart & Run All" reproduces the same
# scores, weights, and outputs every time.
torch.manual_seed(0)

# Bias-free 4 -> 4 projections for queries, keys, and values.
W_q = nn.Linear(4, 4, bias=False)
W_k = nn.Linear(4, 4, bias=False)
W_v = nn.Linear(4, 4, bias=False)

# Cross-attention: the query is projected from the decoder state, while the
# keys and values are both projected from the encoder outputs.
Q = W_q(decoder_hidden)   # (1, 4)
K = W_k(encoder_outputs)  # (3, 4)
V = W_v(encoder_outputs)  # (3, 4)

In [6]:
# --- Step 2: Attention scores ---
# Raw dot product between the query and every key (one score per encoder row).
# Note that no scaling factor is applied to the scores here.
attn_scores = Q @ K.T  # (1, 4) x (4, 3) -> (1, 3)

In [7]:
# --- Step 3: Softmax over encoder tokens ---
# Normalise along the last axis so the weights over the encoder rows sum to 1.
attn_weights = attn_scores.softmax(dim=-1)  # (1, 3)

In [8]:
# --- Step 4: Compute attention context vector ---
# Weighted sum of the value rows, using the attention weights as coefficients.
context = attn_weights @ V  # (1, 3) x (3, 4) -> (1, 4)

In [9]:
# --- Step 5: Linear projection of context vector ---
# Bias-free output projection that keeps the feature width at 4.
output_proj = nn.Linear(in_features=4, out_features=4, bias=False)
attn_output = output_proj(context)  # (1, 4)

In [10]:
# --- Step 6: Residual connection + LayerNorm ---
# Normalisation runs over the 4 features of the row.
layer_norm = nn.LayerNorm(normalized_shape=4)

# The skip path carries the original decoder state around the attention block;
# it is added to the projected attention output before normalising.
residual = decoder_hidden
output_embedding = layer_norm(attn_output + residual)  # (1, 4)

In [11]:
# --- Final Output ---
# Show each intermediate result under its own heading, in pipeline order.
for label, value in (
    ("Attention Scores:\n", attn_scores),
    ("Attention Weights:\n", attn_weights),
    ("Context Vector:\n", context),
    ("Final Output Embedding (for next layer):\n", output_embedding),
):
    print(label, value)

Attention Scores:
 tensor([[0.1654, 0.1047, 0.0918]], grad_fn=<MmBackward0>)
Attention Weights:
 tensor([[0.3484, 0.3279, 0.3237]], grad_fn=<SoftmaxBackward0>)
Context Vector:
 tensor([[ 0.7012, -0.0544,  0.0585,  0.2132]], grad_fn=<MmBackward0>)
Final Output Embedding (for next layer):
 tensor([[-0.0396, -0.3584, -1.1777,  1.5756]],
       grad_fn=<NativeLayerNormBackward0>)
