# Cross-Modal Attention Mechanism

In this notebook, we implement a Cross-Modal Attention mechanism where text features attend to image features, inspired by architectures like Flamingo.

This type of attention lets the model select the visual information relevant to each text token, much as a person answering a question about an image looks at the regions that specific words refer to.


## Imports

We import PyTorch for building the model and NumPy, which is used here only for the square root in the attention scaling.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


## CrossModalAttention Class

This class implements the cross-modal attention mechanism.

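- It uses multi-head attention to let text features (queries) attend to image features (keys and values).
- Three linear layers project the text features into queries and the image features into keys and values; a fourth projects the attended result back to `embed_dim`.
- Scaled dot-product attention scores each text token against each image patch.
- The output holds one vector per text token, each a text-conditioned mix of image values. The class applies no residual connection, so the original text features are combined with this output only if the caller does so.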
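
In equation form, the bullets above correspond to standard scaled dot-product attention applied across modalities. The symbols $X_{\text{text}}$, $X_{\text{img}}$, $W_Q$, $W_K$, $W_V$, and $d_k$ are notation introduced here for illustration; $d_k$ stands for the per-head dimension `embed_dim / num_heads`, and the bias terms of the linear layers are omitted for brevity:

$$
Q = X_{\text{text}} W_Q, \qquad K = X_{\text{img}} W_K, \qquad V = X_{\text{img}} W_V
$$

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
$$

Each head evaluates this expression on its own $d_k$-dimensional slice of $Q$, $K$, and $V$, and the softmax runs over the image-patch axis.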


In [None]:
class CrossModalAttention(nn.Module):
    def __init__(self, embed_dim=256, num_heads=8):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
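        # Each head gets an equal slice of the embedding, so embed_dim must
        # divide evenly; otherwise the .view() calls in forward would fail.
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = embed_dim // num_heads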

        # Linear layers to project text features to queries,
        # and image features to keys and values
        self.text_query = nn.Linear(embed_dim, embed_dim)
        self.image_key = nn.Linear(embed_dim, embed_dim)
        self.image_value = nn.Linear(embed_dim, embed_dim)

        self.output_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(0.1)
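
To get a feel for the constructor's defaults, the short sketch below recomputes the per-head dimension and counts the parameters of four `nn.Linear(embed_dim, embed_dim)` layers. It rebuilds the layers in a plain list (the name `layers` is only for illustration) rather than instantiating the class, so treat it as a rough size check and not as part of the model.

```python
import torch.nn as nn

embed_dim, num_heads = 256, 8

# The embedding is split evenly across heads.
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
head_dim = embed_dim // num_heads

# Stand-ins for the query, key, value, and output projections.
layers = [nn.Linear(embed_dim, embed_dim) for _ in range(4)]
num_params = sum(p.numel() for layer in layers for p in layer.parameters())

print(f"head_dim = {head_dim}")      # 256 // 8 = 32
print(f"parameters = {num_params}")  # 4 * (256 * 256 + 256) = 263168
```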


## Forward Pass

The `forward` method computes the attention output:

- Input shapes:
  - `text_features`: (batch_size, text_len, embed_dim)
  - `image_features`: (batch_size, image_len, embed_dim)

- Steps:
  1. Project inputs into query (Q), key (K), and value (V) vectors.
  2. Reshape and transpose for multi-head attention.
  3. Calculate scaled dot-product attention scores.
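  4. Apply an optional attention mask: positions where the mask equals 0 are excluded, and the mask must broadcast to `(batch_size, num_heads, text_len, image_len)`.
  5. Use softmax over the image-patch axis to get attention weights and apply dropout.
  6. Compute the weighted sum of the values.
  7. Merge the heads and project back to `embed_dim`.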
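
Step 4 leaves the layout of the mask open. One possible layout, sketched below under assumptions, is a padding mask over image patches that broadcasts across heads and text tokens. The names `valid_patches` and `patch_mask`, and the idea that some samples have fewer than 196 real patches, are hypothetical and only serve the illustration.

```python
import torch

batch_size, num_heads, text_len, image_len = 2, 8, 5, 196

# Hypothetical padding info: number of real image patches per sample.
valid_patches = torch.tensor([196, 150])

# (batch_size, image_len): 1 for real patches, 0 for padding.
patch_mask = (torch.arange(image_len)[None, :] < valid_patches[:, None]).long()

# Add singleton head and text axes so the mask broadcasts over them.
attention_mask = patch_mask[:, None, None, :]  # (batch_size, 1, 1, image_len)

# Random scores stand in for Q @ K^T / sqrt(head_dim).
scores = torch.randn(batch_size, num_heads, text_len, image_len)
masked_scores = scores.masked_fill(attention_mask == 0, -1e9)
weights = torch.softmax(masked_scores, dim=-1)

# Largest weight that any text token places on a padded patch of sample 1.
print(weights[1, :, :, 150:].max().item())
```

A mask of this shape can be passed as `attention_mask`, since `masked_fill` only requires it to broadcast against the score tensor.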


In [None]:
def forward(self, text_features, image_features, attention_mask=None):
    batch_size, text_len, _ = text_features.shape
    _, image_len, _ = image_features.shape

    # Linear projections
    Q = self.text_query(text_features)      # Queries from text
    K = self.image_key(image_features)      # Keys from image
    V = self.image_value(image_features)    # Values from image

    # Split embed_dim into heads: (batch_size, num_heads, seq_len, head_dim)
    Q = Q.view(batch_size, text_len, self.num_heads, self.head_dim).transpose(1, 2)
    K = K.view(batch_size, image_len, self.num_heads, self.head_dim).transpose(1, 2)
    V = V.view(batch_size, image_len, self.num_heads, self.head_dim).transpose(1, 2)

    # Scaled dot-product attention: (batch_size, num_heads, text_len, image_len)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.head_dim)
    if attention_mask is not None:
        # The mask must broadcast to the shape of scores; 0 marks ignored positions
        scores = scores.masked_fill(attention_mask == 0, -1e9)

    attention_weights = F.softmax(scores, dim=-1)
    attention_weights = self.dropout(attention_weights)

    # Weighted sum of values, then merge the heads back into embed_dim
    context = torch.matmul(attention_weights, V)
    context = context.transpose(1, 2).contiguous().view(batch_size, text_len, self.embed_dim)

    output = self.output_proj(context)
    return output, attention_weights


# The class body ended in the earlier cell, so an indented method cannot run
# here on its own; bind the function to the class instead.
CrossModalAttention.forward = forward
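
As a sanity check on the scaled dot-product step, the manual computation can be compared with PyTorch's fused `F.scaled_dot_product_attention`. This is a minimal sketch, assuming PyTorch 2.0 or later; the random tensors stand in for the per-head `Q`, `K`, and `V` produced by the reshape step, and masking and dropout are left out.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Per-head tensors shaped (batch_size, num_heads, seq_len, head_dim).
Q = torch.randn(2, 8, 20, 32)    # 20 text tokens
K = torch.randn(2, 8, 196, 32)   # 196 image patches
V = torch.randn(2, 8, 196, 32)

# Manual path: same arithmetic as the forward pass, without mask or dropout.
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Q.size(-1))
manual = torch.matmul(F.softmax(scores, dim=-1), V)

# Fused path (assumes PyTorch >= 2.0); default scale is 1 / sqrt(head_dim).
fused = F.scaled_dot_product_attention(Q, K, V)

max_diff = (manual - fused).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
```

The fused call returns only the attended values, not the weights, so it would not be a drop-in replacement for a `forward` that also returns `attention_weights`.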


## Example Usage

We create an instance of the model and pass simulated text and image features.

- Text features: batch of 4 samples, 20 tokens each, embedding size 256.
- Image features: batch of 4 samples, 196 image patches (14x14 grid), embedding size 256.

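The output includes the attended features, of shape `(4, 20, 256)`, and the attention weights, of shape `(4, 8, 20, 196)`: one weight per sample, head, text token, and image patch.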


In [None]:
# Create the model
cross_attention = CrossModalAttention(embed_dim=256, num_heads=8)

# Simulated input data
text_features = torch.randn(4, 20, 256)    # 4 samples, 20 text tokens
image_features = torch.randn(4, 196, 256)  # 4 samples, 196 image patches (14x14)

# Forward pass
attended_features, attention_weights = cross_attention(text_features, image_features)

# Inspect outputs
print(f"Cross-modal attention output shape: {attended_features.shape}")
print(f"Attention weights shape: {attention_weights.shape}")
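
The attention weights can also be turned into per-token spatial maps over the 14x14 patch grid. The sketch below is one way to do that, under two assumptions that the notebook does not state: the 196 patches are flattened in row-major order, and averaging over heads is an acceptable summary. A softmax over random scores stands in for the `attention_weights` returned above so that the block runs on its own.

```python
import torch

torch.manual_seed(0)

# Stand-in for the returned attention_weights:
# (batch_size, num_heads, text_len, image_len)
attention_weights = torch.softmax(torch.randn(4, 8, 20, 196), dim=-1)

# Average over heads, then unflatten the patch axis into the 14x14 grid
# (assumes row-major patch order).
head_avg = attention_weights.mean(dim=1)          # (4, 20, 196)
attention_maps = head_avg.reshape(4, 20, 14, 14)  # (4, 20, 14, 14)

# Most-attended patch for each text token, as (row, col) grid coordinates.
top_patch = head_avg.argmax(dim=-1)               # (4, 20)
top_row = torch.div(top_patch, 14, rounding_mode="floor")
top_col = top_patch % 14

print(attention_maps.shape)
print(top_row[0, 0].item(), top_col[0, 0].item())
```

When using the real module output, note that a freshly created module is in training mode, so the returned weights have already passed through dropout and their rows need not sum to 1. Calling `cross_attention.eval()` before the forward pass returns the plain softmax distributions.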
