# 06 - Cross-Modal Attention Multimodal Detection

## Objective
Use transformer-based cross-attention to learn interactions between the image, audio, and video modalities.

## Architecture
- Multi-head cross-attention between modalities
- Self-attention within each modality
- Learnable modality embeddings

## Novel Contribution
Cross-modal attention is intended to surface relationships between modalities, such as:
- Audio-visual synchronization
- Temporal consistency across modalities
- Modality-specific artifact detection

In [None]:
class CrossModalAttention(nn.Module):
    """Post-norm transformer block where one modality attends to another.

    The block applies multi-head attention from ``query`` onto ``key`` /
    ``value``, adds the result back onto ``query`` and layer-normalizes it,
    then applies a position-wise feed-forward network with a second residual
    connection and layer norm.

    Args:
        dim: Embedding width shared by query, key and value.
        num_heads: Number of attention heads (``nn.MultiheadAttention``
            requires ``dim`` to be divisible by it).
        dropout: Dropout probability used inside the feed-forward network.
            The attention layer itself is built without dropout.
    """

    def __init__(self, dim: int = 512, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        # batch_first=True -> tensors are laid out as (batch, seq, dim).
        self.multihead_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, query, key, value):
        """Return ``query`` enriched with information attended from ``key``/``value``.

        Args:
            query: Tensor of shape (batch, query_len, dim).
            key: Tensor of shape (batch, kv_len, dim).
            value: Tensor of shape (batch, kv_len, dim).

        Returns:
            Tensor with the same shape as ``query``.
        """
        # The attention map is never used by this block, so skip computing
        # (and head-averaging) it.
        attn_out, _ = self.multihead_attn(query, key, value, need_weights=False)
        x = self.norm1(query + attn_out)
        # Position-wise feed-forward sub-layer with its own residual + norm.
        x = self.norm2(x + self.ffn(x))
        return x

class CrossModalDetector(nn.Module):
    """Classifier that fuses image, audio and video features via cross-attention.

    Each modality's feature vector is projected to a common width, tagged with
    a learnable modality embedding, and treated as a single token. Image
    attends to audio and to video, audio attends to video, and the resulting
    vectors are concatenated and passed to an MLP classifier.

    Args:
        feature_dim: Common embedding width used after projection.
        img_dim: Width of the incoming image feature vectors.
        aud_dim: Width of the incoming audio feature vectors.
        vid_dim: Width of the incoming video feature vectors.
        num_heads: Attention heads used by every cross-attention block.
        num_classes: Number of output logits.
    """

    def __init__(self, feature_dim: int = 512, *, img_dim: int = 512,
                 aud_dim: int = 768, vid_dim: int = 512,
                 num_heads: int = 8, num_classes: int = 2):
        super().__init__()
        # Per-modality projections into the shared feature space.
        self.img_proj = nn.Linear(img_dim, feature_dim)
        self.aud_proj = nn.Linear(aud_dim, feature_dim)
        self.vid_proj = nn.Linear(vid_dim, feature_dim)

        # One learnable vector per modality: row 0 image, 1 audio, 2 video.
        self.modality_embed = nn.Embedding(3, feature_dim)

        # Cross-attention blocks, named <query modality>_to_<key modality>.
        self.img_to_aud = CrossModalAttention(feature_dim, num_heads)
        self.img_to_vid = CrossModalAttention(feature_dim, num_heads)
        self.aud_to_vid = CrossModalAttention(feature_dim, num_heads)

        # MLP head over the three concatenated modality vectors.
        self.classifier = nn.Sequential(
            nn.Linear(feature_dim * 3, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, img_feat, aud_feat, vid_feat):
        """Compute class logits from one feature vector per modality.

        Args:
            img_feat: Image features of shape (B, img_dim).
            aud_feat: Audio features of shape (B, aud_dim).
            vid_feat: Video features of shape (B, vid_dim).

        Returns:
            Logits of shape (B, num_classes).
        """
        # Read the three modality vectors straight from the embedding table
        # instead of building a fresh index tensor on every call.
        img_tag, aud_tag, vid_tag = self.modality_embed.weight.unbind(0)

        # Project each modality and turn it into a one-token sequence (B, 1, D).
        img = self.img_proj(img_feat).unsqueeze(1) + img_tag
        aud = self.aud_proj(aud_feat).unsqueeze(1) + aud_tag
        vid = self.vid_proj(vid_feat).unsqueeze(1) + vid_tag

        # NOTE(review): every key sequence here has length 1, so the softmax
        # over keys is always 1 and the attention output does not depend on
        # the query. Multi-token inputs would be needed for the attention
        # weights to carry information.
        img_enhanced = self.img_to_aud(img, aud, aud) + self.img_to_vid(img, vid, vid)
        aud_enhanced = self.aud_to_vid(aud, vid, vid)

        # Video is only ever used as key/value, so its projected token is
        # passed to the classifier as-is.
        combined = torch.cat(
            [img_enhanced.squeeze(1), aud_enhanced.squeeze(1), vid.squeeze(1)],
            dim=1,
        )
        return self.classifier(combined)

# Notebook-cell confirmation that the model classes above were defined.
print("Cross-Modal Attention Model Defined!")