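# Cross-attention
A minimal NumPy walkthrough of how queries, keys, and values interact in cross-attention: decoder-side queries are scored against encoder-side keys, the scores are softmax-normalised into attention weights, and those weights mix the encoder-side values into one context vector per query.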

In [None]:
import numpy as np

def cross_attention(query, key, value):
    """
    Scaled dot-product cross-attention: softmax(query @ key^T / sqrt(d_k)) @ value.

    query: shape (..., num_queries, d_k)
    key:   shape (..., num_keys, d_k)
    value: shape (..., num_keys, d_v)

    The leading "..." batch axes are optional and broadcast like ``np.matmul``;
    with plain 2-D inputs this is the ordinary single-sequence case.

    returns: shape (..., num_queries, d_v)
    """
    query = np.asarray(query)
    key = np.asarray(key)
    value = np.asarray(value)
    d_k = query.shape[-1]

    # Step 1: Compute raw attention scores (dot product between queries and keys).
    # swapaxes transposes only the last two axes; ``key.T`` would reverse every
    # axis and scramble any batch dimensions.
    scores = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(d_k)  # (..., num_queries, num_keys)

    # Step 2: Softmax over the key axis. Subtracting the per-row max keeps np.exp
    # from overflowing without changing the result.
    exp_scores = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
    attention_weights = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)  # (..., num_queries, num_keys)

    # Step 3: Compute weighted sum of value vectors
    output = np.matmul(attention_weights, value)  # (..., num_queries, d_v)

    return output
    
# Toy example: 3 decoder-side queries attend over 4 encoder-side key/value pairs.
# Draw order is fixed for reproducibility: Q (3x8), then K (4x8), then V (4x16, d_v = 16).
np.random.seed(42)
Q, K, V = (np.random.rand(rows, cols) for rows, cols in ((3, 8), (4, 8), (4, 16)))

output = cross_attention(Q, K, V)
print(f"Output shape: {output.shape}")
print(f"Output: {output}")

Output shape: (3, 16)
Output: [[0.47424909 0.44825225 0.35960412 0.65028308 0.52966772 0.2177133
  0.52158821 0.25817116 0.69138563 0.62032681 0.31718724 0.46422458
  0.34938322 0.4858956  0.69959381 0.49094542]
 [0.48639217 0.42460981 0.32800314 0.63567388 0.52019432 0.23346849
  0.54607276 0.27313906 0.66431665 0.61254396 0.30268994 0.48476288
  0.33229508 0.51576049 0.69228928 0.47398688]
 [0.47873282 0.43109256 0.3423324  0.64023396 0.51838087 0.22245792
  0.53300005 0.2641912  0.6771243  0.61896404 0.31349208 0.48526171
  0.34509942 0.50194426 0.70135753 0.47967725]]


In [8]:
import numpy as np

def cross_attention(query, key, value):
    """
    Scaled dot-product cross-attention that also hands back its intermediates.

    query: shape (num_queries, d_k)
    key:   shape (num_keys, d_k)
    value: shape (num_keys, d_v)

    returns: (output, attention_weights, scores) with shapes
             (num_queries, d_v), (num_queries, num_keys), (num_queries, num_keys)
    """
    # Raw query/key similarities, scaled by sqrt of the feature width.
    scores = np.dot(query, key.T) / np.sqrt(query.shape[-1])

    # Row-wise softmax; shifting by the row max avoids overflow in exp.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)

    # Each output row is the attention-weighted mix of the value rows.
    return np.dot(weights, value), weights, scores

# ----------------------------
# Word-labeled, interpretable setup
# ----------------------------
doc_tokens = ["the", "cat", "sat", "on", "the", "mat"]      # encoder tokens
sum_tokens = ["cat", "on", "mat"]                           # decoder tokens (summary so far)

# Tiny 4-D “semantic” basis so we get intuitive alignments:
#   cat -> [1,0,0,0], sat -> [0,1,0,0], on -> [0,0,1,0], mat -> [0,0,0,1]
#   "the" is a weak, uniform vector: its dot product with every content-word query
#   is 0.2, versus 1.0 for the matching key and 0.0 for the other content keys.
#   So each query attends most to its own word, but "the" still gets slightly
#   MORE weight than the non-matching content words.
B = {
    "cat": np.array([1.,0.,0.,0.]),
    "sat": np.array([0.,1.,0.,0.]),
    "on" : np.array([0.,0.,1.,0.]),
    "mat": np.array([0.,0.,0.,1.]),
    "the": np.array([0.2,0.2,0.2,0.2]),
}

# Keys & Values from the encoder (values can have different dim; here 6 for variety)
d_k, d_v = 4, 6
K = np.stack([B[w] for w in doc_tokens], axis=0)            # (num_keys=6, d_k=4)

rng = np.random.default_rng(0)
V = rng.normal(size=(len(doc_tokens), d_v)).astype(np.float32)  # random values carry content

# Queries from the decoder (one per summary token)
Q = np.stack([B[w] for w in sum_tokens], axis=0)            # (num_queries=3, d_k=4)

# Guard the shape contract cross_attention relies on: queries and keys share d_k,
# and there is exactly one value row per key.
assert Q.shape[1] == K.shape[1] == d_k, (Q.shape, K.shape, d_k)
assert V.shape == (K.shape[0], d_v), (V.shape, K.shape)

# Run cross-attention
context, attn, scores = cross_attention(Q, K, V)

# ----------------------------
# Pretty-print the attention map
# ----------------------------
def format_matrix(attn, row_labels, col_labels, precision=3,
                  corner_label="(query→) \\ (doc↓)"):
    """Render a 2-D matrix as a plain-text table with labelled rows and columns.

    Args:
        attn: 2-D array-like; entry [i, j] is written in row i, column j.
        row_labels: one label per row of ``attn`` (converted with ``str``).
        col_labels: one label per column of ``attn`` (converted with ``str``).
        precision: number of digits after the decimal point for each entry.
        corner_label: text of the top-left header cell.

    Returns:
        A newline-joined string: a header line, a ``-+-`` separator line, and one
        line per row. Each column is left-justified to the width of its widest cell.
        With zero rows, only the header and separator lines are returned.

    Raises:
        ValueError: if ``attn`` is not 2-D or the label counts do not match its shape.
    """
    attn = np.asarray(attn)
    if attn.ndim != 2:
        raise ValueError(f"expected a 2-D matrix, got shape {attn.shape}")
    n_rows, n_cols = attn.shape
    if len(row_labels) != n_rows or len(col_labels) != n_cols:
        # Mismatched labels would otherwise silently drop rows or misalign columns.
        raise ValueError(
            f"got {len(row_labels)} row labels and {len(col_labels)} column labels "
            f"for a matrix of shape {attn.shape}"
        )

    col_head = [corner_label] + [str(c) for c in col_labels]
    rows = [
        [str(rlab)] + [f"{attn[i, j]:.{precision}f}" for j in range(n_cols)]
        for i, rlab in enumerate(row_labels)
    ]
    # default=0 keeps a matrix with no rows from crashing max() on an empty sequence.
    widths = [
        max(len(col_head[c]), max((len(r[c]) for r in rows), default=0))
        for c in range(len(col_head))
    ]

    def fmt_line(cells):
        return " | ".join(cell.ljust(w) for cell, w in zip(cells, widths))

    lines = [fmt_line(col_head), "-+-".join("-" * w for w in widths)]
    lines += [fmt_line(r) for r in rows]
    return "\n".join(lines)

print("Document tokens:", doc_tokens)
print("Summary tokens :", sum_tokens, "\n")

print("Attention weights (rows = summary queries, cols = document tokens):")
print(format_matrix(attn, sum_tokens, doc_tokens))

# For each summary query, report the document token it attends to most strongly.
best_cols = attn.argmax(axis=1)
for row, (query_tok, best_col) in enumerate(zip(sum_tokens, best_cols)):
    print(f"\nTop for '{query_tok}': doc token '{doc_tokens[best_col]}' "
          f"(weight={attn[row, best_col]:.3f})")

# Context vectors the decoder would consume to predict the next summary tokens.
print("\nContext vectors shape:", context.shape)
print(context)


Document tokens: ['the', 'cat', 'sat', 'on', 'the', 'mat']
Summary tokens : ['cat', 'on', 'mat'] 

Attention weights (rows = summary queries, cols = document tokens):
(query→) \ (doc↓) | the   | cat   | sat   | on    | the   | mat  
------------------+-------+-------+-------+-------+-------+------
cat               | 0.161 | 0.240 | 0.146 | 0.146 | 0.161 | 0.146
on                | 0.161 | 0.146 | 0.146 | 0.240 | 0.161 | 0.146
mat               | 0.161 | 0.146 | 0.146 | 0.146 | 0.161 | 0.240

Top for 'cat': doc token 'cat' (weight=0.240)

Top for 'on': doc token 'on' (weight=0.240)

Top for 'mat': doc token 'mat' (weight=0.240)

Context vectors shape: (3, 6)
[[ 0.05312044  0.31110953 -0.40936367 -0.26447013 -0.45491226  0.16061891]
 [-0.0312787   0.3201354  -0.35496194 -0.01554987 -0.45887702  0.18995572]
 [-0.16569856  0.20175229 -0.3578646  -0.09363573 -0.37566159  0.19032104]]


In [10]:
import numpy as np

np.random.seed(7)

# ----------------------------
# Toy vocab & document
# ----------------------------
Vocab = ["<BOS>", "<EOS>", "the", "cat", "sat", "on", "mat"]
tok2id = dict(zip(Vocab, range(len(Vocab))))
id2tok = {idx: tok for tok, idx in tok2id.items()}
doc_tokens = ["the", "cat", "sat", "on", "the", "mat"]  # source document

# ----------------------------
# Dimensions (all equal, so embeddings, keys, values and states share one width)
# ----------------------------
d_model = d_k = d_v = 16

# ----------------------------
# Embeddings and projections
# ----------------------------
def _scaled_randn(n_rows, n_cols, fan):
    """Standard-normal (n_rows, n_cols) draw from the global NumPy RNG, divided by sqrt(fan)."""
    return np.random.randn(n_rows, n_cols) / np.sqrt(fan)

# The draw order below fixes the exact values produced under seed 7.
n_vocab = len(Vocab)
E_src = _scaled_randn(n_vocab, d_model, d_model)  # encoder token embeddings
E_dec = _scaled_randn(n_vocab, d_model, d_model)  # decoder token embeddings
W_k = _scaled_randn(d_model, d_k, d_model)
W_v = _scaled_randn(d_model, d_v, d_model)
W_q = _scaled_randn(d_model, d_k, d_model)
W_out = _scaled_randn(d_v, n_vocab, d_v)
b_out = np.zeros(n_vocab)

def softmax(x, axis=-1):
    """Numerically stable softmax of ``x`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating so large logits
    cannot overflow ``np.exp``; the result is mathematically unchanged.
    """
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

def cross_attention(query, key, value):
    """Single-query scaled dot-product cross-attention.

    query: (d_k,)   key: (Tsrc, d_k)   value: (Tsrc, d_v)
    Returns (ctx, attn, scores) with shapes (d_v,), (Tsrc,), (Tsrc,).

    The scale is taken from the query's own width rather than a module-level
    ``d_k``, so the function does not depend on hidden notebook state.
    """
    scale = np.sqrt(np.shape(query)[-1])
    scores = (key @ query) / scale        # (Tsrc,)
    attn = softmax(scores, axis=-1)       # softmax handles 1-D input directly
    ctx = attn @ value                    # (d_v,)
    return ctx, attn, scores

# ----------------------------
# Encode the document (build K,V)
# ----------------------------
# Embed every source token once; keys and values are two linear projections of it.
src_ids = [tok2id[tok] for tok in doc_tokens]
src_emb = np.take(E_src, src_ids, axis=0)       # (Tsrc, d_model)
K, V = (src_emb @ W for W in (W_k, W_v))        # (Tsrc, d_k), (Tsrc, d_v)

# ----------------------------
# Decode autoregressively with EOS masked early
# ----------------------------
max_len = 6
generated = []
attn_history = []

# Step 4 below averages a decoder embedding (width d_model) with a context vector
# (width d_v); that only works when the two widths match.
assert d_model == d_v, (d_model, d_v)

prev_id = tok2id["<BOS>"]
dec_state = E_dec[prev_id]

for t in range(max_len):
    # 1) Make a query from the current decoder state
    q = dec_state @ W_q             # (d_k,)

    # 2) Cross-attend to the source
    ctx, attn, scores = cross_attention(q, K, V)
    attn_history.append(attn)

    # 3) Predict next token from context
    logits = ctx @ W_out + b_out

    # --- Decoding constraints to avoid early stop / degeneracy ---
    # disallow <BOS> always
    logits[tok2id["<BOS>"]] = -1e9
    # disallow <EOS> for the first 3 steps
    if t < 3:
        logits[tok2id["<EOS>"]] = -1e9
    # Discourage repeating the token emitted on the previous step. Only the most
    # recent token is penalised, so A-B-A patterns can still appear.
    if generated:
        logits[generated[-1]] -= 2.0

    # Greedy decoding: pick the most probable token.
    probs = softmax(logits[None, :], axis=-1)[0]
    next_id = int(np.argmax(probs))

    generated.append(next_id)
    if next_id == tok2id["<EOS>"]:
        break

    # 4) Update decoder state (toy): combine next token embedding + context
    dec_state = 0.5 * E_dec[next_id] + 0.5 * ctx

# ----------------------------
# Print results
# ----------------------------
gen_tokens = [id2tok[tok_id] for tok_id in generated]
print("Document: ", " ".join(doc_tokens))
print("Summary (generated): ", " ".join(gen_tokens))

print("\nCross-attention weights per step (rows = steps, cols = doc tokens):")
header = "step  | " + " | ".join(f"{tok:>6}" for tok in doc_tokens)
print(header, "-" * len(header), sep="\n")
for step, weights in enumerate(attn_history, start=1):
    cells = " | ".join(format(w, ">6.3f") for w in weights)
    print(f"{step:>4}  | {cells}")


Document:  the cat sat on the mat
Summary (generated):  cat on cat <EOS>

Cross-attention weights per step (rows = steps, cols = doc tokens):
step  |    the |    cat |    sat |     on |    the |    mat
-----------------------------------------------------------
   1  |  0.170 |  0.167 |  0.156 |  0.174 |  0.170 |  0.162
   2  |  0.170 |  0.168 |  0.163 |  0.165 |  0.170 |  0.164
   3  |  0.164 |  0.167 |  0.166 |  0.170 |  0.164 |  0.169
   4  |  0.170 |  0.168 |  0.163 |  0.165 |  0.170 |  0.164


In [11]:
import numpy as np

# ----------------------------
# Vocab & document
# ----------------------------
Vocab = ["<BOS>", "<EOS>", "the", "cat", "sat", "on", "mat"]
tok2id = dict(zip(Vocab, range(len(Vocab))))
id2tok = {idx: tok for tok, idx in tok2id.items()}

doc_tokens = ["the", "cat", "sat", "on", "the", "mat"]  # source document

# ----------------------------
# Dimensions
# ----------------------------
# decoder "state" dim (one state per step), key/query dim, value dim (same basis as content words)
d_model, d_k, d_v = 5, 4, 4

# ----------------------------
# Semantic basis (orthogonal) for content words
# ----------------------------
e_cat, e_sat, e_on, e_mat = (row.copy() for row in np.eye(4))  # one-hot rows
e_the = np.full(4, 0.2)  # weak, uniform

# ----------------------------
# Encoder: keys/values from doc tokens
# ----------------------------
def enc_vec(w):
    """Return the encoder key/value vector for token ``w``.

    The four content words each get their own basis vector; any other token
    falls back to the weak uniform vector ``e_the``.
    """
    lookup = {"cat": e_cat, "sat": e_sat, "on": e_on, "mat": e_mat}
    return lookup[w] if w in lookup else e_the

K = np.stack([enc_vec(tok) for tok in doc_tokens])  # (Tsrc, d_k); stacked along a new axis 0
V = K.copy()                                        # (Tsrc, d_v): values reuse the key vectors

# ----------------------------
# Decoder states & query projection
# We design states so W_q maps:
#   s0(<BOS>)->cat, s1(cat)->sat, s2(sat)->on, s3(on)->mat, s4(mat)->mat (to then produce EOS)
# ----------------------------
# Five independent 1-hot decoder states:
#   s0 = <BOS>, s1 = after "cat", s2 = after "sat", s3 = after "on", s4 = after "mat"
s0, s1, s2, s3, s4 = map(np.copy, np.eye(5))

# Decoder "embedding": the state reached after emitting each token. Tokens with no
# designed state (<EOS>, "the") map to all-zero vectors.
E_dec = dict(zip(("<BOS>", "cat", "sat", "on", "mat"), (s0, s1, s2, s3, s4)))
E_dec["<EOS>"] = np.zeros_like(s0)
E_dec["the"] = np.zeros_like(s0)

# W_q has shape (d_model, d_k); row i is exactly the query produced by 1-hot state s_i.
# s3 and s4 both ask for "mat" — the step from s4 is where <EOS> is allowed to win.
W_q = np.vstack([e_cat, e_sat, e_on, e_mat, e_mat])  # (5,4)

# ----------------------------
# Output head: context -> logits over vocab
# Make the next-token obvious:
#   e_cat -> "cat", e_sat -> "sat", e_on -> "on", e_mat -> "mat"
# Also: when EOS is allowed, make it beat "mat" if context = e_mat
# ----------------------------
W_out = np.zeros((d_v, len(Vocab)))
# Each content word's column is its own basis vector, so its logit equals the
# matching component of the context vector.
for tok, basis in (("cat", e_cat), ("sat", e_sat), ("on", e_on), ("mat", e_mat)):
    W_out[:, tok2id[tok]] = basis
# The <EOS> logit is twice the "mat" logit, so once unmasked it wins whenever the
# context has a positive "mat" component.
W_out[:, tok2id["<EOS>"]] = 2.0 * e_mat
b_out = np.zeros(len(Vocab))

def softmax(x, axis=-1):
    """Numerically stable softmax of ``x`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating so large logits
    cannot overflow ``np.exp``; the result is mathematically unchanged.
    """
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

def cross_attention(query, key, value):
    """Single-query scaled dot-product cross-attention.

    query: (d_k,)   key: (Tsrc, d_k)   value: (Tsrc, d_v)
    Returns (ctx, attn, scores) with shapes (d_v,), (Tsrc,), (Tsrc,).

    The scale is taken from the query's own width rather than a module-level
    ``d_k``, so the function does not depend on hidden notebook state.
    """
    scale = np.sqrt(np.shape(query)[-1])
    scores = (key @ query) / scale        # (Tsrc,)
    attn = softmax(scores, axis=-1)       # softmax handles 1-D input directly
    ctx = attn @ value                    # (d_v,)
    return ctx, attn, scores

# ----------------------------
# Autoregressive decoding
# ----------------------------
max_len = 6
generated, attn_history = [], []

state = E_dec["<BOS>"]  # decoding starts from the <BOS> state
for t in range(max_len):
    # Build a query from the current state and attend over the encoded document.
    ctx, attn, _ = cross_attention(state @ W_q, K, V)
    attn_history.append(attn)

    # Score the vocabulary. <BOS> is never allowed; <EOS> only from step index 4 on.
    logits = ctx @ W_out + b_out
    blocked = ("<BOS>",) if t >= 4 else ("<BOS>", "<EOS>")
    for tok in blocked:
        logits[tok2id[tok]] = -1e9

    # Greedy choice of the next token.
    next_id = int(np.argmax(softmax(logits[None, :], axis=-1)[0]))
    generated.append(next_id)

    next_tok = id2tok[next_id]
    if next_tok == "<EOS>":
        break
    # The next state is simply the decoder embedding of the token just emitted.
    state = E_dec[next_tok]

# ----------------------------
# Display
# ----------------------------
gen_tokens = [id2tok[tok_id] for tok_id in generated]
print(f"Document:  {' '.join(doc_tokens)}")
print(f"Summary  :  {' '.join(gen_tokens)}")

print("\nCross-attention weights per step (rows=steps, cols=document tokens):")
header = "step  | " + " | ".join(map("{:>6}".format, doc_tokens))
print(header)
print("-" * len(header))
for step, weights in enumerate(attn_history, start=1):
    print(f"{step:>4}  | " + " | ".join(map("{:>6.3f}".format, weights)))


Document:  the cat sat on the mat
Summary  :  cat sat on mat <EOS>

Cross-attention weights per step (rows=steps, cols=document tokens):
step  |    the |    cat |    sat |     on |    the |    mat
-----------------------------------------------------------
   1  |  0.161 |  0.240 |  0.146 |  0.146 |  0.161 |  0.146
   2  |  0.161 |  0.146 |  0.240 |  0.146 |  0.161 |  0.146
   3  |  0.161 |  0.146 |  0.146 |  0.240 |  0.161 |  0.146
   4  |  0.161 |  0.146 |  0.146 |  0.146 |  0.161 |  0.240
   5  |  0.161 |  0.146 |  0.146 |  0.146 |  0.161 |  0.240
