## Cross Attention


Cross-Attention is a fundamental mechanism in modern AI, especially in Transformer models. It allows a neural network to dynamically relate and fuse information from different input sequences.

**Core Concepts**
- **Self-Attention** lets a sequence "attend" to itself, calculating how each element (e.g., a word) relates to every other element in the same sequence.

- **Cross-Attention** extends **Self-Attention**: it lets one sequence (the queries) attend to a different sequence (the key/value pairs).

- In Cross-Attention the queries come from one source, while the keys and values come from another.

- Cross-Attention is commonly used in *Multi-Modal Language Models*, where context from two different sequences is available.

![cross_attn.png](./imgs/cross_attn.png)

In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import math

- First we will implement Multi-Head Self-Attention (MHSA), and then Multi-Head Cross-Attention (`CrossMHSA`), so that the two can be compared.

In [12]:
class MHSA(nn.Module):
    """Multi-head self-attention over a single input sequence.

    Queries, keys and values are all produced from the same input by one
    fused linear projection, split into ``D_model // head_dim`` heads,
    combined with scaled dot-product attention and projected back to
    ``D_model`` features.

    Args:
        head_dim: Feature size of each attention head.
        D_model: Feature size of the input and of the output.
        causal: If True, position ``i`` may only attend to positions ``<= i``.

    Raises:
        AssertionError: If ``head_dim`` does not evenly divide ``D_model``.
    """

    def __init__(self, head_dim = 64, D_model = 512, causal = True):
        super().__init__()
        self.head_dim = head_dim
        self.D_model = D_model
        assert D_model % head_dim == 0, "Error: Head Dimension does not evenly divide residual stream dimension."
        self.num_heads = D_model // head_dim
        self.causal = causal

        # One fused projection yields Q, K and V in a single matmul.
        self.qkv_proj = nn.Linear(D_model, D_model * 3)
        self.out_proj = nn.Linear(D_model, D_model)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape [batch_size, seq_len, D_model].

        Returns:
            Tensor of shape [batch_size, seq_len, D_model].
        """
        B, S, _ = x.shape

        # [batch, seq, D_model] -> [batch, seq, 3 * D_model]
        qkv = self.qkv_proj(x)

        # [batch, seq, 3 * D_model] -> [batch, seq, 3, num_heads, head_dim]
        qkv = qkv.reshape(B, S, 3, self.num_heads, self.head_dim)

        # -> [3, batch, num_heads, seq, head_dim], then split the leading axis
        # into Q, K, V, each of shape [batch, num_heads, seq, head_dim].
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        # Q @ K^T:
        # [batch, num_heads, seq, head_dim] @ [batch, num_heads, head_dim, seq]
        #   -> [batch, num_heads, seq, seq]
        # Scaled by sqrt(head_dim) before the softmax.
        attn_weight = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)

        # Apply the causal mask.
        if self.causal:
            # Lower-triangular (incl. diagonal) mask of allowed positions.
            # It is created on the input's device so the module also works
            # when x does not live on the default device.
            allowed = torch.ones(S, S, dtype=torch.bool, device=x.device).tril()

            # Disallowed positions become -inf so that softmax assigns them
            # a weight of exactly 0.
            attn_weight = attn_weight.masked_fill(~allowed, float('-inf'))

        attn_scores = F.softmax(attn_weight, dim=-1)

        # [batch, num_heads, seq_q, seq_k] @ [batch, num_heads, seq_k, head_dim]
        #   -> [batch, num_heads, seq_q, head_dim]
        output = torch.einsum('bnij,bnjd->bnid', attn_scores, v)

        # Put the heads back next to each other:
        # [batch, num_heads, seq, head_dim] -> [batch, seq, D_model]
        output = output.transpose(1, 2).reshape(B, S, -1)

        # Final linear projection along the last dim: [batch, seq, D_model]
        return self.out_proj(output)


In [14]:
class CrossMHSA(nn.Module):
    """Multi-head cross-attention between two input sequences.

    Queries are projected from the first sequence, while keys and values
    are projected from the second one.

        Input_Sequence_1 Shape: [batch_dim, seq_len_1, d_model]
        Input_Sequence_2 Shape: [batch_dim, seq_len_2, d_model]
    """

    def __init__(self, head_dim = 64, D_model = 512):
        super().__init__()
        assert D_model % head_dim == 0, "Critical Error: Head Dimension does not evenly divide Dimension of Model"

        self.head_dim = head_dim
        self.D_model = D_model
        self.num_heads = D_model // head_dim

        # Separate projections: Q reads from x1, K and V read from x2.
        self.q_proj = nn.Linear(D_model, D_model)
        self.k_proj = nn.Linear(D_model, D_model)
        self.v_proj = nn.Linear(D_model, D_model)
        self.o_proj = nn.Linear(D_model, D_model)

    def forward(self, x1, x2):
        """Let ``x1`` attend to ``x2``.

        Returns a tuple ``(output, attn_scores)`` where ``output`` has shape
        [batch, seq1, d_model] and ``attn_scores`` has shape
        [batch, num_heads, seq1, seq2].
        """
        batch, len_q, _ = x1.shape
        _, len_kv, _ = x2.shape

        def split_heads(t, length):
            # [B, S, D] -> [B, S, NUM_HEADS, HEAD_DIM] -> [B, NUM_HEADS, S, HEAD_DIM]
            per_head = t.reshape(batch, length, self.num_heads, self.head_dim)
            return per_head.transpose(1, 2)

        queries = split_heads(self.q_proj(x1), len_q)    # [batch, num_heads, seq1, head_dim]
        keys = split_heads(self.k_proj(x2), len_kv)      # [batch, num_heads, seq2, head_dim]
        values = split_heads(self.v_proj(x2), len_kv)    # [batch, num_heads, seq2, head_dim]

        # Scaled dot products of every query with every key:
        # [batch, num_heads, seq1, seq2]
        logits = torch.einsum('bnid, bnjd->bnij', queries, keys)
        logits = logits / math.sqrt(self.head_dim)

        # Normalise over the key axis.
        attn_scores = F.softmax(logits, dim=-1)

        # Weighted sum of the values, then merge the heads back into one
        # vector of size d_model per query position.
        context = torch.einsum('bnij,bnjd->bnid', attn_scores, values)
        merged = context.transpose(1, 2).reshape(batch, len_q, -1)

        return self.o_proj(merged), attn_scores

In [18]:
# Sanity checks comparing MHSA (self-attention) with CrossMHSA.
batch_size = 2
seq_len = 5
dim = 512
head_dim = 64

# test 1: when both inputs of cross-attention are the same sequence
x = torch.randn(batch_size, seq_len, dim)

# initialize both attention modules
# for a fair comparison, MHSA must not apply the causal mask
mhsa = MHSA(head_dim=head_dim, D_model=dim, causal=False)
cross_attn = CrossMHSA(head_dim=head_dim, D_model=dim)

# make both modules share the same parameters
with torch.no_grad():
    # the fused qkv projection of MHSA is the row-wise stack of the
    # separate q/k/v projections of CrossMHSA
    mhsa.qkv_proj.weight.copy_(torch.cat([
        cross_attn.q_proj.weight,
        cross_attn.k_proj.weight,
        cross_attn.v_proj.weight
    ], dim=0))
    mhsa.qkv_proj.bias.copy_(torch.cat([
        cross_attn.q_proj.bias,
        cross_attn.k_proj.bias,
        cross_attn.v_proj.bias
    ], dim=0))

    # set output projections to be the same
    cross_attn.o_proj.weight.copy_(mhsa.out_proj.weight)
    cross_attn.o_proj.bias.copy_(mhsa.out_proj.bias)

# forward pass
mhsa_output = mhsa(x)
cross_attn_output, _ = cross_attn(x, x)  # same sequence for both inputs

# check if outputs are the same when inputs are the same
is_close = torch.allclose(mhsa_output, cross_attn_output, rtol=1e-4, atol=1e-4)
print(f"when sequences are identical: outputs match = {is_close}")

# test 2: when sequences are different (and of different lengths)
seq_len1 = 5
seq_len2 = 7
x1 = torch.randn(batch_size, seq_len1, dim)
x2 = torch.randn(batch_size, seq_len2, dim)

# forward pass with different sequences
cross_attn_diff, _ = cross_attn(x1, x2)

# the output follows the query sequence (x1), not the key/value sequence (x2)
expected_shape = (batch_size, seq_len1, dim)
assert cross_attn_diff.shape == expected_shape, f"expected shape {expected_shape}, got {cross_attn_diff.shape}"

# MHSA.forward takes a single sequence, so there is no self-attention
# counterpart to run here; the difference is only reported.
print("cross-attention allows different sequence lengths, while self-attention requires identical sequences")

# verify that cross-attention with different sequences gives different results than with identical sequences
x1_copy = x1.clone()
cross_attn_same, _ = cross_attn(x1_copy, x1_copy)
is_different = not torch.allclose(cross_attn_same, cross_attn_diff, rtol=1e-4, atol=1e-4)
print(f"cross-attention gives different results with different sequences: {is_different}")

print("all tests completed!")

when sequences are identical: outputs match = True
cross-attention allows different sequence lengths, while self-attention requires identical sequences
cross-attention gives different results with different sequences: True
all tests completed!


In [20]:
import plotly.graph_objects as go 

# Plot the cross-attention weights of a single head as a 3-D surface.
B, S1, S2 = 1, 6, 10
x1 = torch.randn(B, S1, 64)   # query sequence
x2 = torch.randn(B, S2, 64)   # key/value sequence

model = CrossMHSA(head_dim=16, D_model=64)
output, scores = model(x1, x2)

# Weights of the chosen head for the first batch element, as a NumPy array.
head = 0
attn_scores = scores[0, head].detach().cpu().numpy()

fig = go.Figure(
    data=[
        go.Surface(z=attn_scores, colorscale='Viridis'),
    ]
)
fig.update_layout(
    title=f'Cross Attention Weights (HEAD{head})',
    scene={
        'xaxis_title': 'Key Tokens(x2)',
        'yaxis_title': 'Query Tokens (x1)',
        'zaxis_title': 'Attention Weight',
    },
    autosize=True,
    margin={'l': 60, 'r': 60, 'b': 60, 't': 60},
)

fig.show()