In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import time

# Fix both random number generators so every notebook run is reproducible.
np.random.seed(42)
torch.manual_seed(42)

# Use the GPU whenever one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

print("Ready to explore attention fusion strategies!")


In [None]:
class ShallowAttentionModule(nn.Module):
    """Shallow Attention Module (SAM) for efficient attention computation.

    Queries, keys and values are linearly projected from ``d_model`` down to
    ``d_reduced = d_model // reduction_factor`` features. Single-head scaled
    dot-product attention is computed in that reduced space. The result is
    projected back to ``d_model``, added to ``query`` (residual) and
    layer-normalized.

    Args:
        d_model: Feature size of the inputs and of the output.
        reduction_factor: Integer factor by which the attention width is
            reduced (default 4).

    Raises:
        ValueError: If ``reduction_factor < 1``, or if ``d_model //
            reduction_factor`` is 0. A zero-width projection would make the
            score scaling divide by ``sqrt(0)`` and silently yield NaNs.
    """

    def __init__(self, d_model, reduction_factor=4):
        super().__init__()
        if reduction_factor < 1:
            raise ValueError(
                f"reduction_factor must be >= 1, got {reduction_factor}"
            )
        d_reduced = d_model // reduction_factor
        if d_reduced < 1:
            raise ValueError(
                f"d_model ({d_model}) // reduction_factor ({reduction_factor}) "
                "must be at least 1"
            )

        self.d_model = d_model
        self.d_reduced = d_reduced

        # Reduced dimension projections for efficiency
        self.W_q = nn.Linear(d_model, self.d_reduced)
        self.W_k = nn.Linear(d_model, self.d_reduced)
        self.W_v = nn.Linear(d_model, self.d_reduced)
        self.W_o = nn.Linear(self.d_reduced, d_model)

        # Layer normalization applied after the residual connection
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, query, key, value):
        """Shallow attention forward pass.

        Args:
            query: [batch_size, tgt_len, d_model]
            key: [batch_size, src_len, d_model]
            value: [batch_size, src_len, d_model]

        Returns:
            output: [batch_size, tgt_len, d_model]
            attention_weights: [batch_size, tgt_len, src_len], softmax over
                the last (source) dimension.
        """
        # Project to the reduced attention width
        q = self.W_q(query)  # [batch_size, tgt_len, d_reduced]
        k = self.W_k(key)    # [batch_size, src_len, d_reduced]
        v = self.W_v(value)  # [batch_size, src_len, d_reduced]

        # Scaled dot-product scores; the scale uses the reduced width, which
        # is the actual dot-product length.
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_reduced ** 0.5)
        attention_weights = F.softmax(scores, dim=-1)

        # Weighted sum of values, then project back to d_model
        attended = torch.matmul(attention_weights, v)
        projected = self.W_o(attended)

        # Residual connection (requires query to already be d_model wide) + norm
        output = self.layer_norm(projected + query)

        return output, attention_weights

class DeepAttentionModule(nn.Module):
    """Deep Attention Module: a stack of multi-head cross-attention layers.

    Each of the ``num_layers`` layers attends from the running output
    (initialised with ``query``) to the same ``key``/``value`` tensors. The
    layer output is added back to its input and layer-normalized.

    Args:
        d_model: Feature size of inputs and output. Must be divisible by
            ``num_heads``; ``nn.MultiheadAttention`` enforces this.
        num_layers: Number of stacked attention layers (default 3).
        num_heads: Attention heads per layer (default 8).
    """

    def __init__(self, d_model, num_layers=3, num_heads=8):
        super().__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.num_heads = num_heads

        # One multi-head attention layer per depth level (batch-first tensors)
        self.attention_layers = nn.ModuleList([
            nn.MultiheadAttention(d_model, num_heads, batch_first=True)
            for _ in range(num_layers)
        ])

        # One post-residual layer normalization per attention layer
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(d_model) for _ in range(num_layers)
        ])

    def forward(self, query, key, value, key_padding_mask=None, attn_mask=None):
        """Apply every attention layer in sequence.

        Args:
            query: [batch_size, tgt_len, d_model]
            key: [batch_size, src_len, d_model]
            value: [batch_size, src_len, d_model]
            key_padding_mask: Optional mask forwarded unchanged to every
                layer. Typically [batch_size, src_len]; for a bool mask,
                True marks key positions to ignore (``nn.MultiheadAttention``
                convention).
            attn_mask: Optional attention mask forwarded unchanged to every
                layer.

        Returns:
            output: [batch_size, tgt_len, d_model]
            attention_weights_list: One weight tensor per layer. Each is
                [batch_size, tgt_len, src_len], averaged over heads, which is
                the ``nn.MultiheadAttention`` default.
        """
        output = query
        attention_weights_list = []

        for attn_layer, layer_norm in zip(self.attention_layers, self.layer_norms):
            # Query evolves layer by layer; key/value stay the original inputs
            attn_output, attn_weights = attn_layer(
                output,
                key,
                value,
                key_padding_mask=key_padding_mask,
                attn_mask=attn_mask,
            )
            attention_weights_list.append(attn_weights)

            # Residual connection and layer norm
            output = layer_norm(attn_output + output)

        return output, attention_weights_list

def _time_forward(model, inputs, num_iters):
    """Call ``model(*inputs)`` ``num_iters`` times; return (last result, seconds).

    One untimed warm-up call is made first, so one-off first-call costs
    (lazy initialisation, allocator growth) are not charged to the timed
    loop. This matters because the models are timed one after another.
    """
    model(*inputs)  # warm-up, excluded from the measurement
    result = None
    start = time.perf_counter()  # monotonic, high-resolution clock
    for _ in range(num_iters):
        result = model(*inputs)
    return result, time.perf_counter() - start


def compare_attention_efficiency(batch_size=8, seq_len=128, d_model=512, num_iters=10):
    """Compare parameter count and forward-pass time of shallow vs. deep attention.

    Builds one random query/key/value triple with ``torch.randn``. It creates
    a ``ShallowAttentionModule`` (reduction_factor=4) and a
    ``DeepAttentionModule`` (num_layers=3), times ``num_iters`` forward
    passes of each and prints a summary.

    Everything runs on the default (CPU) device, because no ``.to(device)``
    is applied. Forward passes run with autograd enabled, so the timings
    include graph-building overhead.

    Args:
        batch_size: Batch size of the random inputs.
        seq_len: Sequence length used for query, key and value.
        d_model: Feature size. It must satisfy both modules' constraints,
            e.g. divisible by the deep module's head count and at least 4.
        num_iters: Number of timed forward passes per model (>= 1).

    Returns:
        (shallow_attn, deep_attn, shallow_output, deep_output): the two
        models and the outputs of their last timed forward pass.

    Raises:
        ValueError: If ``num_iters < 1``.
    """
    if num_iters < 1:
        raise ValueError(f"num_iters must be >= 1, got {num_iters}")

    # Create test data
    query = torch.randn(batch_size, seq_len, d_model)
    key = torch.randn(batch_size, seq_len, d_model)
    value = torch.randn(batch_size, seq_len, d_model)
    inputs = (query, key, value)

    # Initialize models
    shallow_attn = ShallowAttentionModule(d_model, reduction_factor=4)
    deep_attn = DeepAttentionModule(d_model, num_layers=3)

    print("ATTENTION EFFICIENCY COMPARISON")
    print("=" * 40)

    (shallow_output, shallow_weights), shallow_time = _time_forward(
        shallow_attn, inputs, num_iters
    )
    (deep_output, deep_weights), deep_time = _time_forward(
        deep_attn, inputs, num_iters
    )

    # Calculate parameters
    shallow_params = sum(p.numel() for p in shallow_attn.parameters())
    deep_params = sum(p.numel() for p in deep_attn.parameters())

    print("Shallow Attention:")
    print(f"  Parameters: {shallow_params:,}")
    print(f"  Time ({num_iters} iterations): {shallow_time:.4f}s")
    print(f"  Output shape: {shallow_output.shape}")
    print()

    print("Deep Attention:")
    print(f"  Parameters: {deep_params:,}")
    print(f"  Time ({num_iters} iterations): {deep_time:.4f}s")
    print(f"  Output shape: {deep_output.shape}")
    print(f"  Attention layers: {len(deep_weights)}")
    print()

    print("Efficiency Comparison:")
    print(f"  Parameter ratio (Deep/Shallow): {deep_params/shallow_params:.1f}x")
    print(f"  Time ratio (Deep/Shallow): {deep_time/shallow_time:.1f}x")

    return shallow_attn, deep_attn, shallow_output, deep_output

# Run the benchmark once and unpack both models and their final outputs
# into module-level names.
(
    shallow_model,
    deep_model,
    shallow_out,
    deep_out,
) = compare_attention_efficiency()
