# Module 09: Multi-Modal Architectures

Multi-modal learning combines information from different data modalities (vision, text, audio).

## Learning Objectives
- Understand modality-specific encoders
- Learn fusion strategies (early, late, hybrid)
- Implement cross-modal attention
- Connect multi-modal fusion with Watts-Strogatz topology

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# Seed the PyTorch and NumPy random number generators so that the random
# tensors and weight initialisations created below are repeatable across runs.
torch.manual_seed(42)
np.random.seed(42)

print("[OK] Libraries loaded")

## 1. Multi-Modal Data

Different modalities capture different aspects of the world.

In [None]:
# Build one synthetic batch per modality; random tensors stand in for real data.
batch_size = 8

# Images: raw RGB input, 32x32 pixels per sample.
visual = torch.randn((batch_size, 3, 32, 32))

# Text: 20 token embeddings per sample, 128 dimensions each.
text = torch.randn((batch_size, 20, 128))

# Audio: single-channel mel spectrogram, 64 mel bands by 100 time frames.
audio = torch.randn((batch_size, 1, 64, 100))

# Report the layout of every modality.
print("Multi-Modal Input Shapes:")
print(f"  Visual: {visual.shape} (batch, channels, height, width)")
print(f"  Text: {text.shape} (batch, sequence, embedding)")
print(f"  Audio: {audio.shape} (batch, channels, freq, time)")

## 2. Modality-Specific Encoders

In [None]:
class VisualEncoder(nn.Module):
    """Small CNN that turns a batch of 3-channel images into fixed-size vectors.

    Two 3x3 convolutions (with a 2x2 max-pool in between) are followed by
    global average pooling and a linear projection to ``output_dim``.
    """

    def __init__(self, output_dim=256):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        # Global average pooling collapses whatever spatial size is left to 1x1.
        self.pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.fc = nn.Linear(in_features=64, out_features=output_dim)

    def forward(self, x):
        """Map ``x`` of shape (batch, 3, H, W) to a (batch, output_dim) embedding."""
        stage1 = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=2)
        stage2 = self.pool(F.relu(self.conv2(stage1)))
        # (batch, 64, 1, 1) -> (batch, 64) before the projection.
        return self.fc(stage2.flatten(start_dim=1))


class TextEncoder(nn.Module):
    """Encode a token-embedding sequence with one self-attention layer.

    The sequence attends to itself (4 heads), the result is mean-pooled over
    the token axis, and a linear layer projects it to ``output_dim``.
    ``input_dim`` is used as the attention embedding size.
    """

    def __init__(self, input_dim=128, output_dim=256):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=input_dim, num_heads=4, batch_first=True
        )
        self.fc = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq, input_dim) to (batch, output_dim)."""
        # Query, key and value are all the same sequence (self-attention);
        # the attention weights are not needed here.
        contextualised = self.attention(x, x, x)[0]
        # Average over the token axis to get one vector per sample.
        return self.fc(contextualised.mean(dim=1))


class AudioEncoder(nn.Module):
    """Small CNN that turns single-channel spectrograms into fixed-size vectors.

    Same layout as the visual encoder in this file, except that the first
    convolution takes 1 input channel instead of 3.
    """

    def __init__(self, output_dim=256):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        # Global average pooling over the remaining (freq, time) grid.
        self.pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.fc = nn.Linear(in_features=64, out_features=output_dim)

    def forward(self, x):
        """Map ``x`` of shape (batch, 1, freq, time) to (batch, output_dim)."""
        hidden = self.conv1(x)
        hidden = F.max_pool2d(F.relu(hidden), kernel_size=2)
        hidden = F.relu(self.conv2(hidden))
        pooled = self.pool(hidden).flatten(start_dim=1)
        return self.fc(pooled)


# Instantiate one encoder per modality, each projecting to 256 dimensions.
visual_enc = VisualEncoder(output_dim=256)
text_enc = TextEncoder(input_dim=128, output_dim=256)
audio_enc = AudioEncoder(output_dim=256)

# Push the synthetic batch through each encoder.
v_feat, t_feat, a_feat = visual_enc(visual), text_enc(text), audio_enc(audio)

# Show the resulting embedding shapes.
print("Encoded Feature Shapes:")
print(f"  Visual: {v_feat.shape}")
print(f"  Text: {t_feat.shape}")
print(f"  Audio: {a_feat.shape}")
print("\n[OK] All modalities encoded to same dimension (256)")

## 3. Fusion Strategies

In [None]:
class EarlyFusion(nn.Module):
    """Fuse modalities by concatenating their features and running a shared MLP.

    Args:
        modality_dims: Feature width of each modality, in the order the
            tensors will be passed to ``forward``.
        hidden_dim: Width of the hidden layer.
        output_dim: Width of the fused output.
    """

    def __init__(self, modality_dims, hidden_dim, output_dim):
        super().__init__()
        fused_width = sum(modality_dims)
        layers = [
            nn.Linear(fused_width, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.fusion = nn.Sequential(*layers)

    def forward(self, features):
        """Concatenate ``features`` along the last axis and apply the MLP.

        ``features`` is a sequence of tensors (e.g. ``[visual, text, audio]``)
        whose last dimensions add up to ``sum(modality_dims)``.
        """
        joined = torch.cat(features, dim=-1)
        return self.fusion(joined)


class LateFusion(nn.Module):
    """Fuse modalities at the prediction level.

    Every modality gets its own linear head; the per-modality predictions are
    then blended with learnable mixing coefficients. The raw coefficients are
    initialised uniformly and passed through a softmax in ``forward`` so the
    blend weights stay positive and sum to one.

    Args:
        modality_dim: Feature width shared by all modalities.
        num_modalities: Number of heads / mixing coefficients to create.
        output_dim: Width of each prediction.
    """

    def __init__(self, modality_dim, num_modalities, output_dim):
        super().__init__()
        per_modality_heads = []
        for _ in range(num_modalities):
            per_modality_heads.append(nn.Linear(modality_dim, output_dim))
        self.heads = nn.ModuleList(per_modality_heads)
        self.weights = nn.Parameter(torch.ones(num_modalities) / num_modalities)

    def forward(self, features):
        """Return the softmax-weighted sum of the per-modality predictions."""
        mixing = F.softmax(self.weights, dim=0)
        combined = 0
        # Accumulate weight * prediction one modality at a time.
        for coeff, head, feat in zip(mixing, self.heads, features):
            combined = combined + coeff * head(feat)
        return combined


class HybridFusion(nn.Module):
    """Combine modality-specific refinement with joint (early-fusion) processing.

    Each modality is first passed through its own linear + ReLU refinement;
    the refined features are then concatenated and fed to a shared MLP.

    Args:
        modality_dim: Feature width shared by all modalities.
        hidden_dim: Width of the joint hidden layer.
        output_dim: Width of the fused output.
        num_modalities: Number of modalities to fuse. Defaults to 3, which
            matches the visual/text/audio setup used in this file.
    """

    def __init__(self, modality_dim, hidden_dim, output_dim, num_modalities=3):
        super().__init__()
        self.num_modalities = num_modalities
        # Per-modality refinement
        self.refine = nn.ModuleList([
            nn.Linear(modality_dim, modality_dim) for _ in range(num_modalities)
        ])
        # Joint processing over the concatenated refined features
        self.joint = nn.Sequential(
            nn.Linear(modality_dim * num_modalities, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, features):
        """Refine each modality, concatenate, and apply the joint MLP.

        Args:
            features: Iterable of ``num_modalities`` tensors whose last
                dimension is ``modality_dim``.

        Raises:
            ValueError: If the number of feature tensors does not match the
                number of refinement layers.
        """
        features = list(features)
        # zip() would silently drop surplus modalities, so check up front.
        if len(features) != len(self.refine):
            raise ValueError(
                f"expected {len(self.refine)} modality features, got {len(features)}"
            )
        refined = [F.relu(r(f)) for r, f in zip(self.refine, features)]
        concat = torch.cat(refined, dim=-1)
        return self.joint(concat)


# Gather the encoded modalities in the order the fusion modules expect.
features = [v_feat, t_feat, a_feat]

# One instance of each fusion strategy, all producing 10 outputs.
early = EarlyFusion(modality_dims=[256, 256, 256], hidden_dim=256, output_dim=10)
late = LateFusion(modality_dim=256, num_modalities=3, output_dim=10)
hybrid = HybridFusion(modality_dim=256, hidden_dim=256, output_dim=10)

# Run each strategy once and show the output shape.
print("Fusion Output Shapes:")
print(f"  Early Fusion: {early(features).shape}")
print(f"  Late Fusion: {late(features).shape}")
print(f"  Hybrid Fusion: {hybrid(features).shape}")

## 4. Cross-Modal Attention

In [None]:
class CrossModalAttention(nn.Module):
    """
    Attention from one modality to another.
    Query from modality A, Key/Value from modality B.

    Args:
        dim: Embedding size of both modalities (must be divisible by
            ``num_heads``).
        num_heads: Number of attention heads.
    """
    def __init__(self, dim, num_heads=4):
        super().__init__()
        self.attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, query_modality, key_value_modality):
        """Attend from ``query_modality`` to ``key_value_modality``.

        Either input may be a pooled ``(batch, dim)`` vector or a
        ``(batch, seq, dim)`` sequence. Pooled inputs are treated as
        length-1 sequences.

        Returns:
            A tuple ``(attn_out, attn_weights)``. ``attn_out`` has the same
            rank as the query that was passed in; ``attn_weights`` is returned
            as produced by ``nn.MultiheadAttention``.
        """
        # Remember whether the sequence axis was added here, so that only
        # that axis is removed again afterwards.
        query_was_pooled = query_modality.dim() == 2
        if query_was_pooled:
            query_modality = query_modality.unsqueeze(1)
        if key_value_modality.dim() == 2:
            key_value_modality = key_value_modality.unsqueeze(1)

        attn_out, attn_weights = self.attention(
            query_modality,
            key_value_modality,
            key_value_modality
        )
        # Squeezing unconditionally would turn a caller-supplied (batch, 1, dim)
        # query into (batch, dim), and a residual add with the original query
        # would then broadcast to the wrong shape.
        if query_was_pooled:
            attn_out = attn_out.squeeze(1)
        return attn_out, attn_weights


class MultiModalTransformer(nn.Module):
    """Fuse three modalities with pairwise cross-attention and a linear head.

    Every modality queries each of the other two through its own
    ``CrossModalAttention`` module (six in total, named
    ``<query>_to_<key/value>``). The attended results are added to the original
    features as residuals, and the three enhanced features are concatenated and
    projected to ``output_dim``.
    """

    def __init__(self, dim=256, num_heads=4, output_dim=10):
        super().__init__()
        # Register the six directed attention links in a fixed order.
        for link in ("v_to_t", "v_to_a", "t_to_v", "t_to_a", "a_to_v", "a_to_t"):
            setattr(self, link, CrossModalAttention(dim, num_heads))

        # Final projection over the concatenated enhanced features.
        self.fusion = nn.Linear(3 * dim, output_dim)

    def forward(self, visual, text, audio):
        """Return the fused prediction for one batch of modality features."""
        # Each call returns (output, weights); only the output is used.
        v_enhanced = visual + self.v_to_t(visual, text)[0] + self.v_to_a(visual, audio)[0]
        t_enhanced = text + self.t_to_v(text, visual)[0] + self.t_to_a(text, audio)[0]
        a_enhanced = audio + self.a_to_v(audio, visual)[0] + self.a_to_t(audio, text)[0]

        stacked = torch.cat((v_enhanced, t_enhanced, a_enhanced), dim=-1)
        return self.fusion(stacked)


# Smoke-test the cross-attention fusion on the encoded features.
mm_transformer = MultiModalTransformer(dim=256, num_heads=4, output_dim=10)
output = mm_transformer(visual=v_feat, text=t_feat, audio=a_feat)
print(f"Multi-Modal Transformer output: {output.shape}")

## 5. Complete Multi-Modal Network

In [None]:
class MultiModalNetwork(nn.Module):
    """Complete multi-modal classification network.

    Encodes visual, text and audio inputs with their modality-specific
    encoders, then fuses the three embeddings with cross-modal attention to
    produce class logits.

    Args:
        hidden_dim: Shared embedding size produced by every encoder. It is
            passed to the attention modules, so it must be divisible by
            ``num_heads``.
        num_classes: Number of output logits.
        text_dim: Embedding size of the incoming text tokens. Defaults to 128,
            the value used by the synthetic data in this file. The text
            encoder attends with 4 heads, so this must be divisible by 4.
        num_heads: Number of heads in the cross-modal attention modules.
    """
    def __init__(self, hidden_dim=256, num_classes=10, text_dim=128, num_heads=4):
        super().__init__()
        # Encoders
        self.visual_encoder = VisualEncoder(hidden_dim)
        self.text_encoder = TextEncoder(text_dim, hidden_dim)
        self.audio_encoder = AudioEncoder(hidden_dim)

        # Cross-modal fusion (also acts as the classifier head)
        self.cross_attention = MultiModalTransformer(hidden_dim, num_heads, num_classes)

    def forward(self, visual, text, audio):
        """Return class logits for one batch of (visual, text, audio) inputs."""
        # Encode each modality
        v_feat = self.visual_encoder(visual)
        t_feat = self.text_encoder(text)
        a_feat = self.audio_encoder(audio)

        # Fuse and classify
        return self.cross_attention(v_feat, t_feat, a_feat)


# Build the end-to-end model.
model = MultiModalNetwork(hidden_dim=256, num_classes=10)

# Report the model size.
total_params = sum(weight.numel() for weight in model.parameters())
print(f"Total parameters: {total_params:,}")

# Single forward pass on the synthetic batch.
output = model(visual, text, audio)
print(f"Output shape: {output.shape}")

# Cross-entropy against random labels (loss computation only, no optimiser step).
labels = torch.randint(low=0, high=10, size=(batch_size,))
loss = F.cross_entropy(input=output, target=labels)
print(f"Loss: {loss.item():.4f}")

## 6. Watts-Strogatz Connectivity for Multi-Modal Fusion

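We can use a small-world (Watts-Strogatz) topology as a sparse connectivity mask over the concatenated modality features: most connections are local, linking neighbouring features that usually belong to the same modality, while a few rewired long-range links carry information across modalities.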

In [None]:
import networkx as nx

def create_ws_cross_modal_mask(n_modalities, features_per_modality, k=4, beta=0.3, seed=42):
    """
    Create a Watts-Strogatz connectivity pattern between modality features.

    All ``n_modalities * features_per_modality`` features are treated as the
    nodes of one Watts-Strogatz graph, and the graph's adjacency matrix is
    returned as a 0/1 mask.

    Args:
        n_modalities: Number of modalities.
        features_per_modality: Number of features contributed by each modality.
        k: Number of nearest ring neighbours each node is joined to before
            rewiring (passed to ``nx.watts_strogatz_graph``).
        beta: Probability of rewiring each edge.
        seed: Seed for the graph generator. Defaults to 42 so repeated calls
            with the same arguments produce the same mask.

    Returns:
        A float32 tensor of shape ``(total_features, total_features)`` holding
        the graph's adjacency matrix.
    """
    total_features = n_modalities * features_per_modality

    # Create WS graph
    graph = nx.watts_strogatz_graph(total_features, k, beta, seed=seed)
    adjacency = nx.to_numpy_array(graph)

    # Convert with an explicit dtype instead of the legacy torch.FloatTensor
    # constructor; the NumPy adjacency matrix is float64 by default.
    return torch.as_tensor(adjacency, dtype=torch.float32)


class WSMultiModalFusion(nn.Module):
    """Bias-free linear fusion layer gated by a fixed Watts-Strogatz mask.

    The layer owns a dense learnable ``weight`` matrix over the concatenated
    modality features. A 0/1 ``mask`` built by ``create_ws_cross_modal_mask``
    is stored as a buffer (not trained) and multiplied into the weight on
    every forward pass, so only masked-in connections contribute.
    """

    def __init__(self, n_modalities=3, features_per_modality=64, k=4, beta=0.3):
        super().__init__()
        self.n_modalities = n_modalities
        self.features_per_modality = features_per_modality
        width = n_modalities * features_per_modality

        # Dense trainable weights, small random initialisation.
        self.weight = nn.Parameter(0.01 * torch.randn(width, width))

        # Fixed connectivity pattern, registered as a buffer.
        self.register_buffer(
            'mask',
            create_ws_cross_modal_mask(n_modalities, features_per_modality, k, beta),
        )

    def forward(self, features):
        """Concatenate ``features`` on the last axis and apply the masked weights.

        ``features`` is a list of ``[batch, features_per_modality]`` tensors;
        the result has shape ``[batch, total]``.
        """
        stacked = torch.cat(features, dim=-1)
        return F.linear(stacked, self.weight * self.mask)

    def get_sparsity(self):
        """Return the fraction of mask entries that are zero, as a float."""
        active = self.mask.sum()
        density = (active / self.mask.numel()).item()
        return 1 - density


# Test
ws_fusion = WSMultiModalFusion(n_modalities=3, features_per_modality=64, k=6, beta=0.3)

# Create features: one random [batch, features_per_modality] tensor per
# modality, sized from the module so the demo follows the configuration above.
features = [
    torch.randn(8, ws_fusion.features_per_modality)
    for _ in range(ws_fusion.n_modalities)
]
output = ws_fusion(features)

print(f"WS Fusion output shape: {output.shape}")
print(f"Connectivity sparsity: {ws_fusion.get_sparsity():.1%}")

# Visualize mask
plt.figure(figsize=(8, 6))
plt.imshow(ws_fusion.mask.cpu().numpy(), cmap='Blues')
plt.title('Watts-Strogatz Cross-Modal Connectivity')
# The mask gates an F.linear weight, whose rows index outputs and whose columns
# index inputs; imshow draws rows along y and columns along x.
plt.xlabel('Input Features')
plt.ylabel('Output Features')

# Add modality boundaries
for i in range(1, ws_fusion.n_modalities):
    # imshow centres pixel j on coordinate j, so the edge between the last
    # feature of one modality and the first of the next lies half a pixel
    # before the multiple of features_per_modality.
    pos = i * ws_fusion.features_per_modality - 0.5
    plt.axhline(pos, color='red', linestyle='--', alpha=0.5)
    plt.axvline(pos, color='red', linestyle='--', alpha=0.5)

plt.colorbar(label='Connection')
plt.show()

## Summary

Key concepts covered:

1. **Modality-Specific Encoders**: CNNs for vision/audio, a self-attention (Transformer-style) layer for text
2. **Fusion Strategies**: Early (concatenate), Late (combine predictions), Hybrid
3. **Cross-Modal Attention**: Let each modality attend to others
4. **Complete Pipeline**: Encoders + Fusion + Classifier
5. **WS Topology**: Use small-world connectivity for efficient cross-modal fusion

## Why Multi-Modal + Sparse WS?

- Multi-modal captures complementary information
- WS topology provides efficient information routing
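- Sparse connections reduce the number of active weights while aiming to maintain performance (note: the masked layer above still stores a dense weight matrix, so the saving is in effective connectivity, not memory)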
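- The rewiring probability `beta` sets the balance between local and long-range links; it is a fixed hyperparameter in this module, but a learnable variant could adapt connectivity during training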

## Next Steps

- [->] Module 10: Capstone - Build Your Own Segmented WS Multi-Modal Architecture