# Module 10: Capstone - Build the Segmented WS Multi-Modal Architecture

Congratulations on reaching the capstone! In this notebook, you'll build the complete architecture that combines everything you've learned:

- Modality-specific encoders (visual, text, audio)
- Watts-Strogatz inter-module connector
- Learnable beta parameter
- Multi-modal fusion
- Dynamic sparse training

In [None]:
import sys
sys.path.insert(0, '../..')

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

from src.models import (
    SegmentedWSArchitecture,
    MultiModalWSNetwork,
)
from src.data import SyntheticMultiModal, create_dataloaders
from src.training import Trainer, TrainingConfig, SparseTrainer
from src.visualization import (
    plot_loss_curves,
    create_training_dashboard,
    plot_inter_module_connectivity,
)

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 10.1 The Architecture Overview

```
+-------------------------------------------------------------------+
|  +----------+   +----------+   +----------+                       |
|  |  Visual  |   |   Text   |   |  Audio   |   <- Modality         |
|  |  Module  |   |  Module  |   |  Module  |      Encoders         |
|  +----+-----+   +----+-----+   +----+-----+                       |
|       |              |              |                             |
|       +-------+------+------+-------+                             |
|               |                                                   |
|        +------v------+                                            |
|        |   WS Inter- |  <- Learnable Small-World                  |
|        |   Module    |     Connector with Dynamic                 |
|        |   Connector |     Rewiring (beta learnable)              |
|        +------+------+                                            |
|               |                                                   |
|        +------v------+                                            |
|        |    Fusion   |  <- Cross-Modal Integration                |
|        |    Module   |                                            |
|        +-------------+                                            |
+-------------------------------------------------------------------+
```
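
To make the connector box more concrete, here is a minimal sketch of how a Watts-Strogatz adjacency matrix can be generated: start from a ring lattice in which each node links to its `k` nearest neighbours, then rewire each lattice edge with probability `beta`. The function name `watts_strogatz_adjacency` and its arguments are illustrative assumptions, not the API of `src.models`, and the sketch uses a fixed `beta` rather than a learnable one.

```python
import numpy as np


def watts_strogatz_adjacency(n_nodes, k, beta, seed=0):
    """Return a symmetric 0/1 adjacency matrix for a Watts-Strogatz graph.

    n_nodes: number of nodes to connect
    k:       each node starts linked to its k nearest ring neighbours (k even)
    beta:    probability of rewiring each lattice edge to a random target
    """
    if k % 2 != 0 or not 0 < k < n_nodes:
        raise ValueError("k must be even and satisfy 0 < k < n_nodes")
    rng = np.random.default_rng(seed)
    adj = np.zeros((n_nodes, n_nodes), dtype=np.int64)

    # Step 1: regular ring lattice, k/2 neighbours on each side.
    for i in range(n_nodes):
        for offset in range(1, k // 2 + 1):
            j = (i + offset) % n_nodes
            adj[i, j] = adj[j, i] = 1

    # Step 2: visit every lattice edge once and rewire it with probability beta.
    for offset in range(1, k // 2 + 1):
        for i in range(n_nodes):
            j = (i + offset) % n_nodes
            if rng.random() >= beta:
                continue
            # Valid targets: not i itself and not already linked to i.
            candidates = np.flatnonzero(adj[i] == 0)
            candidates = candidates[candidates != i]
            if candidates.size == 0:
                continue  # node i is already connected to every other node
            new_j = rng.choice(candidates)
            adj[i, j] = adj[j, i] = 0
            adj[i, new_j] = adj[new_j, i] = 1
    return adj


adj = watts_strogatz_adjacency(n_nodes=12, k=4, beta=0.3, seed=1)
print(adj)
```

Rewiring moves edges without adding or removing any, so the total edge count stays at `n_nodes * k / 2`. With only three nodes and `k=2` the ring lattice is already fully connected, so in this sketch rewiring changes nothing; a small-world effect only appears when there are more nodes than `k + 1`.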

## 10.2 Create the Dataset

We'll use synthetic multi-modal data for demonstration:

In [None]:
# Create synthetic multi-modal dataset
train_dataset = SyntheticMultiModal(
    n_samples=5000,
    visual_dim=(1, 28, 28),
    text_seq_len=32,
    vocab_size=1000,
    audio_dim=(128, 64),
    n_classes=10,
    correlation=0.7,  # How correlated modalities are with labels
)

val_dataset = SyntheticMultiModal(
    n_samples=1000,
    visual_dim=(1, 28, 28),
    text_seq_len=32,
    vocab_size=1000,
    audio_dim=(128, 64),
    n_classes=10,
    correlation=0.7,
    seed=123,  # Different seed for validation
)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

# Check a sample
sample = train_dataset[0]
print(f"\nSample keys: {sample.keys()}")
print(f"Visual shape: {sample['visual'].shape}")
print(f"Text shape: {sample['text'].shape}")
print(f"Audio shape: {sample['audio'].shape}")
print(f"Label: {sample['label']}")
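
If you are curious how a `correlation` knob like the one above could work, the following is a rough sketch under stated assumptions. `ToyMultiModal` is a hypothetical stand-in, not the course's `SyntheticMultiModal`: it assumes each class has one random prototype per modality and that `correlation` sets how much of each sample comes from that prototype rather than from noise.

```python
import torch
from torch.utils.data import Dataset


class ToyMultiModal(Dataset):
    """Hypothetical stand-in for SyntheticMultiModal; not the course implementation."""

    def __init__(self, n_samples=100, visual_dim=(1, 28, 28), text_seq_len=32,
                 vocab_size=1000, audio_dim=(128, 64), n_classes=10,
                 correlation=0.7, seed=42):
        g = torch.Generator().manual_seed(seed)
        self.labels = torch.randint(0, n_classes, (n_samples,), generator=g)

        # Continuous modalities: class prototype (signal) blended with noise.
        self.visual = self._blend(visual_dim, n_classes, correlation, g)
        self.audio = self._blend(audio_dim, n_classes, correlation, g)

        # Text: each position shows its class token with probability
        # `correlation`, otherwise a uniformly random token.
        class_tokens = torch.randint(0, vocab_size, (n_classes, text_seq_len), generator=g)
        noise_tokens = torch.randint(0, vocab_size, (n_samples, text_seq_len), generator=g)
        keep = torch.rand(n_samples, text_seq_len, generator=g) < correlation
        self.text = torch.where(keep, class_tokens[self.labels], noise_tokens)

    def _blend(self, shape, n_classes, correlation, g):
        prototypes = torch.randn(n_classes, *shape, generator=g)
        noise = torch.randn(len(self.labels), *shape, generator=g)
        return correlation * prototypes[self.labels] + (1 - correlation) * noise

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            'visual': self.visual[idx],
            'text': self.text[idx],
            'audio': self.audio[idx],
            'label': self.labels[idx],
        }
```

In this sketch, `correlation=0.0` yields pure noise that carries no label information, while `correlation=1.0` makes every sample of a class identical, so values in between control how hard the task is.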

## 10.3 Build the Architecture

Now let's create our Segmented Watts-Strogatz Multi-Modal Architecture:

In [None]:
# Create the architecture
model = SegmentedWSArchitecture(
    visual_config={'input_shape': (1, 28, 28), 'hidden_dims': [256, 128]},
    text_config={'vocab_size': 1000, 'embed_dim': 64, 'hidden_dim': 128},
    audio_config={'input_dim': 128, 'hidden_dims': [256, 128]},
    segment_dim=64,         # Dimension of each module's output
    n_ws_layers=2,          # WS-connected processing layers
    ws_k=2,                 # Nearest neighbors per node in the initial WS ring lattice
    initial_beta=0.3,       # Starting rewiring probability
    use_moe=False,          # Use Mixture of Experts (try True!)
    sparse_layers=True,     # Use sparse connectivity
    layer_density=0.3,      # 30% density (70% sparse)
    output_dim=10,          # 10 classes
    dropout=0.1,
)

model = model.to(device)

# Print architecture stats
stats = model.get_architecture_stats()
print("Architecture Statistics:")
print(f"  Total parameters: {stats['total_params']:,}")
print(f"  Trainable parameters: {stats['trainable_params']:,}")
print(f"  WS layers: {stats['n_ws_layers']}")
print(f"  Beta values: {stats['betas']}")
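
The `layer_density=0.3` setting means that roughly 30% of the connections in a sparse layer are active. One common way to get this effect is a fixed binary mask over a dense weight matrix. The `MaskedLinear` class below is an illustrative assumption, not the layer used inside `SegmentedWSArchitecture`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedLinear(nn.Module):
    """Linear layer whose weights are multiplied by a fixed random 0/1 mask."""

    def __init__(self, in_features, out_features, density=0.3, seed=0):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        g = torch.Generator().manual_seed(seed)
        n_weights = in_features * out_features
        n_active = max(1, round(density * n_weights))
        # Pick exactly n_active weight positions to keep.
        perm = torch.randperm(n_weights, generator=g)
        mask = torch.zeros(n_weights)
        mask[perm[:n_active]] = 1.0
        # A buffer moves with .to(device) but is not trained.
        self.register_buffer("mask", mask.view(out_features, in_features))

    def forward(self, x):
        # Masked-out weights contribute nothing and receive zero gradient.
        return F.linear(x, self.linear.weight * self.mask, self.linear.bias)


layer = MaskedLinear(64, 32, density=0.3)
print(f"Active fraction: {layer.mask.mean().item():.3f}")
```

A mask like this still stores the full dense weight tensor, so a raw parameter count does not shrink with density; whether `stats['total_params']` counts dense or active weights depends on the course implementation.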

## 10.4 Test Forward Pass

In [None]:
# Test forward pass
batch = train_dataset[0]

# Add batch dimension and move to device
visual = batch['visual'].unsqueeze(0).to(device)
text = batch['text'].unsqueeze(0).to(device)
audio = batch['audio'].unsqueeze(0).to(device)

# Forward pass
with torch.no_grad():
    output = model(visual=visual, text=text, audio=audio)

print("Input shapes:")
print(f"  Visual: {visual.shape}")
print(f"  Text: {text.shape}")
print(f"  Audio: {audio.shape}")
print(f"\nOutput shape: {output.shape}")
print(f"Predictions: {torch.softmax(output, dim=-1).squeeze()}")

## 10.5 Training the Architecture

In [None]:
from torch.utils.data import DataLoader

# Create data loaders
def collate_fn(batch):
    """Custom collate function for multi-modal data."""
    return {
        'visual': torch.stack([b['visual'] for b in batch]),
        'text': torch.stack([b['text'] for b in batch]),
        'audio': torch.stack([b['audio'] for b in batch]),
        'label': torch.stack([b['label'] for b in batch]),
    }

train_loader = DataLoader(
    train_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn
)
val_loader = DataLoader(
    val_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn
)

In [None]:
# Training loop
n_epochs = 10  # Use more epochs for better results

optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()
# Match the cosine schedule to the run length so the learning rate anneals fully
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)

history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

for epoch in range(n_epochs):
    # Training
    model.train()
    train_loss = 0
    train_correct = 0
    train_total = 0
    
    for batch in train_loader:
        visual = batch['visual'].to(device)
        text = batch['text'].to(device)
        audio = batch['audio'].to(device)
        labels = batch['label'].to(device)
        
        optimizer.zero_grad()
        outputs = model(visual=visual, text=text, audio=audio)
        loss = criterion(outputs, labels)
        
        # Add the auxiliary loss if the model exposes one (e.g. with use_moe=True)
        aux_loss = getattr(model, 'aux_loss', None)
        if aux_loss is not None:
            loss = loss + aux_loss
        
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        train_total += labels.size(0)
        train_correct += predicted.eq(labels).sum().item()
    
    # Validation
    model.eval()
    val_loss = 0
    val_correct = 0
    val_total = 0
    
    with torch.no_grad():
        for batch in val_loader:
            visual = batch['visual'].to(device)
            text = batch['text'].to(device)
            audio = batch['audio'].to(device)
            labels = batch['label'].to(device)
            
            outputs = model(visual=visual, text=text, audio=audio)
            loss = criterion(outputs, labels)
            
            val_loss += loss.item()
            _, predicted = outputs.max(1)
            val_total += labels.size(0)
            val_correct += predicted.eq(labels).sum().item()
    
    scheduler.step()
    
    # Record history
    history['train_loss'].append(train_loss / len(train_loader))
    history['val_loss'].append(val_loss / len(val_loader))
    history['train_acc'].append(100 * train_correct / train_total)
    history['val_acc'].append(100 * val_correct / val_total)
    
    print(f"Epoch {epoch+1}/{n_epochs}: "
          f"Train Loss={history['train_loss'][-1]:.4f}, "
          f"Train Acc={history['train_acc'][-1]:.1f}%, "
          f"Val Acc={history['val_acc'][-1]:.1f}%")

In [None]:
# Plot training progress
fig = plot_loss_curves(history, title='Segmented WS Architecture Training')
plt.show()
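
The training loop in this section updates weights but leaves the sparsity pattern fixed. Dynamic sparse training additionally rewires that pattern during training. Below is a minimal sketch of one SET-style prune-and-regrow step on a single weight matrix; the function name `set_rewire` and the `prune_fraction` default are assumptions for illustration, and this is not the course's `SparseTrainer`.

```python
import torch


def set_rewire(weight, mask, prune_fraction=0.3, generator=None):
    """One SET-style step: drop the weakest active weights, regrow elsewhere.

    Returns a new boolean mask with the same number of active entries.
    """
    mask = mask.bool()
    n_active = int(mask.sum())
    n_prune = int(prune_fraction * n_active)
    if n_prune == 0:
        return mask.clone()

    flat_w = weight.detach().abs().flatten()
    flat_mask = mask.flatten().clone()

    # Prune candidates: active weights, smallest magnitude first.
    active_idx = flat_mask.nonzero(as_tuple=True)[0]
    order = flat_w[active_idx].argsort()
    pruned = active_idx[order[:n_prune]]

    # Regrow candidates: positions that were inactive before this step.
    inactive_idx = (~flat_mask).nonzero(as_tuple=True)[0]
    n_grow = min(n_prune, inactive_idx.numel())
    perm = torch.randperm(inactive_idx.numel(), generator=generator)
    grown = inactive_idx[perm[:n_grow]]

    # Swap equal numbers so the overall density stays constant.
    flat_mask[pruned[:n_grow]] = False
    flat_mask[grown] = True
    return flat_mask.view_as(mask)
```

A step like this would typically run once per epoch, after which the weights at newly grown positions are re-initialised before training continues.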

## 10.6 Analyze the Learned Architecture

Let's examine what the model learned:

In [None]:
# Check learned beta values
print("Learned WS parameters:")
print(f"  Beta values per layer: {model.betas}")

# Visualize inter-module connectivity
adj = model.segment_adj.cpu().numpy()
fig = plot_inter_module_connectivity(
    adj, 
    module_names=['Visual', 'Text', 'Audio']
)
plt.show()
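
A rewiring probability has to stay between 0 and 1 while it is being optimised. One possible way to achieve that, sketched here as an assumption rather than as the mechanism behind `model.betas`, is to store an unconstrained logit and pass it through a sigmoid. The helper `blend_adjacency` is likewise hypothetical: it replaces discrete rewiring with a soft mix of a lattice and a shortcut adjacency so that gradients can reach beta.

```python
import torch
import torch.nn as nn


class LearnableBeta(nn.Module):
    """Keeps beta in (0, 1) by learning its logit instead of beta itself."""

    def __init__(self, initial_beta=0.3):
        super().__init__()
        if not 0.0 < initial_beta < 1.0:
            raise ValueError("initial_beta must lie strictly between 0 and 1")
        p = torch.tensor(float(initial_beta))
        # Inverse sigmoid, so that forward() starts at initial_beta.
        self.logit = nn.Parameter(torch.log(p) - torch.log1p(-p))

    def forward(self):
        return torch.sigmoid(self.logit)


def blend_adjacency(lattice, shortcuts, beta):
    """Soft relaxation: weight lattice edges by (1 - beta), shortcuts by beta."""
    return (1 - beta) * lattice + beta * shortcuts


beta = LearnableBeta(initial_beta=0.3)
lattice = torch.tensor([[0., 1., 0.], [1., 0., 1.], [0., 1., 0.]])
shortcuts = torch.ones(3, 3) - torch.eye(3)
soft_adj = blend_adjacency(lattice, shortcuts, beta())
soft_adj.sum().backward()  # the gradient reaches beta.logit
print(f"beta = {beta().item():.2f}")
```

Comparing the printed beta values before and after training shows whether a layer drifted towards a more regular (lower beta) or more random (higher beta) topology.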

## 10.7 Exercise: Experiment with the Architecture

Try modifying the architecture:

1. Enable Mixture of Experts (`use_moe=True`)
2. Change the sparsity level (`layer_density`)
3. Adjust the initial beta value
4. Try different segment dimensions

Which configuration gives the best trade-off between accuracy and parameter efficiency?

In [None]:
# YOUR EXPERIMENTS HERE
# Example: Try with MoE enabled
# model_moe = SegmentedWSArchitecture(
#     ...
#     use_moe=True,
#     n_experts=4,
#     ...
# )
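
To compare experiments on the trade-off between accuracy and parameter efficiency, it helps to record parameter counts alongside validation accuracy. The helpers below are a small sketch that works with any `nn.Module`; the names `count_parameters` and `rank_runs`, and the "accuracy per million trainable parameters" score, are illustrative choices rather than part of the course code.

```python
import torch.nn as nn


def count_parameters(model):
    """Return (total, trainable) parameter counts for any nn.Module."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable


def rank_runs(runs):
    """Sort runs by validation accuracy per million trainable parameters.

    runs: dict mapping a run name to (model, best_val_acc_percent).
    Returns rows of (name, val_acc, trainable_params, score), best first.
    """
    rows = []
    for name, (model, val_acc) in runs.items():
        _, trainable = count_parameters(model)
        rows.append((name, val_acc, trainable, val_acc / (trainable / 1e6)))
    return sorted(rows, key=lambda row: row[3], reverse=True)
```

After each experiment you could add an entry such as `'baseline': (model, max(history['val_acc']))` to the `runs` dictionary and call `rank_runs(runs)`. Keep in mind that mask-based sparse layers may still store dense weight tensors, in which case the raw count will not reflect `layer_density`.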

## Congratulations!

You've built a complete Segmented Watts-Strogatz Multi-Modal Architecture!

**Key innovations in this architecture:**

1. **Heterogeneous modules** - Different encoder architectures for each modality
2. **WS topology** - Small-world connectivity for efficient inter-module communication
3. **Learnable beta** - The network can optimize its own topology
4. **Sparse connectivity** - Fewer active weights (30% density in this notebook), with the aim of preserving accuracy
5. **Optional MoE** - Expert routing for modality-specific processing

**What you've learned in this curriculum:**

- Neural network fundamentals
- Graph theory and network topology
- Sparse neural networks
- Dynamic sparse training (SET, DEEP R)
- Multi-modal learning
- How to combine all these into a novel architecture!