# Multi-Modal Models Tutorial
## HSE 2025 - Introduction to Multi-Modal Learning

This notebook introduces the concepts and practical implementation of multi-modal machine learning using the MMM framework.

### Learning Objectives
- Understand multi-modal learning principles
- Learn to work with different data modalities
- Implement and train multi-modal classifiers
- Explore different fusion strategies
- Evaluate multi-modal models

## 1. Setup and Imports

In [None]:
# Install required packages (uncomment if needed)
# !pip install torch torchvision transformers matplotlib seaborn scikit-learn

import sys
import os

# Add src to path for imports
sys.path.append('../src')

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report

# MMM imports
from models.multimodal_classifier import MultiModalClassifier
from data.multimodal_dataset import MultiModalDataset

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("✅ Setup complete!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()} (this notebook runs everything on the CPU)")

## 2. Understanding Multi-Modal Data

Multi-modal learning involves processing and learning from multiple types of data simultaneously:

- **Text**: Natural language (reviews, descriptions, captions)
- **Images**: Visual content (photos, diagrams, charts)
- **Audio**: Sound data (speech, music, environmental sounds)
- **Video**: Temporal visual sequences
- **Tabular**: Structured numerical data

The key challenge is learning meaningful representations that capture relationships **across** modalities.
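
A common first step is to project each modality into a shared hidden size, so that representations with very different widths (768, 2048, 512) can be compared and fused. The sketch below assumes plain `nn.Linear` encoders; the class name `ModalityProjector` is hypothetical and not part of the MMM framework, whose internal encoders are not shown in this notebook.

```python
import torch
import torch.nn as nn

class ModalityProjector(nn.Module):
    """Hypothetical helper: map each modality to the same hidden size."""

    def __init__(self, input_dims: dict, hidden_dim: int = 256):
        super().__init__()
        # One small encoder per modality; ModuleDict registers their parameters
        self.encoders = nn.ModuleDict({
            name: nn.Sequential(nn.Linear(dim, hidden_dim), nn.ReLU())
            for name, dim in input_dims.items()
        })

    def forward(self, inputs: dict) -> dict:
        # Project only the modalities that are actually present in `inputs`
        return {name: self.encoders[name](x) for name, x in inputs.items()}

projector = ModalityProjector({'text': 768, 'image': 2048, 'audio': 512}, hidden_dim=256)
batch = {'text': torch.randn(4, 768), 'image': torch.randn(4, 2048), 'audio': torch.randn(4, 512)}
projected = projector(batch)
for name, h in projected.items():
    print(name, tuple(h.shape))  # every modality becomes (4, 256)
```

Once every modality lives in the same space, the fusion step only has to decide how to combine vectors of equal width.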

In [None]:
# Let's create synthetic multi-modal data to understand the concepts

def create_synthetic_dataset(n_samples=1000, n_classes=5):
    """
    Create synthetic multi-modal dataset for demonstration.
    
    In real scenarios, you would:
    - Extract text features using BERT/RoBERTa
    - Extract image features using ResNet/ViT
    - Extract audio features using spectrograms/MFCC
    """
    
    # Text features (simulating BERT embeddings)
    text_features = torch.randn(n_samples, 768)
    
    # Image features (simulating ResNet features)
    image_features = torch.randn(n_samples, 2048)
    
    # Audio features (simulating mel-spectrogram features)
    audio_features = torch.randn(n_samples, 512)
    
    # Draw labels uniformly at random; class information is injected into the features below
    labels = torch.randint(0, n_classes, (n_samples,))
    
    # Shift each class by a modality-specific offset plus extra noise.
    # The per-dimension offset is largest for text (i), then image (0.5 * i),
    # then audio (0.3 * i), so text carries the strongest class signal per feature.
    for i in range(n_classes):
        mask = labels == i
        n_i = int(mask.sum())  # number of samples in class i
        text_features[mask] += torch.randn(n_i, 768) * 0.5 + i
        image_features[mask] += torch.randn(n_i, 2048) * 0.3 + i * 0.5
        audio_features[mask] += torch.randn(n_i, 512) * 0.4 + i * 0.3
    
    return {
        'text': text_features,
        'image': image_features, 
        'audio': audio_features,
        'labels': labels
    }

# Create datasets
train_data = create_synthetic_dataset(n_samples=800, n_classes=5)
val_data = create_synthetic_dataset(n_samples=200, n_classes=5)

print("Dataset created:")
print(f"Training samples: {len(train_data['labels'])}")
print(f"Validation samples: {len(val_data['labels'])}")
print(f"Number of classes: 5")
print(f"Text feature dim: {train_data['text'].shape[1]}")
print(f"Image feature dim: {train_data['image'].shape[1]}")
print(f"Audio feature dim: {train_data['audio'].shape[1]}")

## 3. Multi-Modal Fusion Strategies

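This notebook compares three fusion methods supported by `MultiModalClassifier`: `concat`, `attention`, and `sum`. The following cell builds one model per method and prints its parameter count.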

In [None]:
# Let's explore different fusion strategies

fusion_methods = ['concat', 'attention', 'sum']
models = {}

for fusion_method in fusion_methods:
    model = MultiModalClassifier(
        num_classes=5,
        text_dim=768,
        image_dim=2048, 
        audio_dim=512,
        hidden_dim=256,
        fusion_method=fusion_method,
        dropout=0.1
    )
    models[fusion_method] = model
    
    # Count parameters
    param_count = model.get_parameter_count()
    print(f"{fusion_method.capitalize()} Fusion:")
    print(f"  Parameters: {param_count['total']:,}")
    print()
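
Sections 3.1 and 3.2 cover `concat` and `attention`, but the `sum` option has no subsection of its own. A common way to implement sum fusion is to project every modality to the same hidden size and add the vectors element-wise. The sketch below assumes that design (MMM's actual implementation may differ; `SumFusion` and `n_params` are hypothetical names) and also shows why the fusion choice changes the classifier-head size: summing keeps the fused width at `hidden_dim`, while concatenation multiplies it by the number of modalities.

```python
import torch
import torch.nn as nn

class SumFusion(nn.Module):
    """Hypothetical sum fusion: element-wise sum of projected modalities."""

    def __init__(self, input_dims, hidden_dim=256, num_classes=5):
        super().__init__()
        self.encoders = nn.ModuleDict({m: nn.Linear(d, hidden_dim) for m, d in input_dims.items()})
        # Summing keeps the fused width at hidden_dim regardless of modality count
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, inputs):
        projected = [self.encoders[m](x) for m, x in inputs.items()]
        fused = torch.stack(projected, dim=0).sum(dim=0)  # (batch, hidden_dim)
        return {'fused_features': fused, 'logits': self.classifier(fused)}

def n_params(module):
    return sum(p.numel() for p in module.parameters())

# Classifier-head size: sum fusion vs. a concat head over three 256-d slots
print(n_params(nn.Linear(256, 5)))      # 1285  (256 * 5 weights + 5 biases)
print(n_params(nn.Linear(256 * 3, 5)))  # 3845  (768 * 5 weights + 5 biases)

model = SumFusion({'text': 768, 'image': 2048, 'audio': 512})
out = model({'text': torch.randn(4, 768), 'image': torch.randn(4, 2048), 'audio': torch.randn(4, 512)})
print(out['fused_features'].shape)  # torch.Size([4, 256])
```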

### 3.1 Concatenation Fusion
Concatenates the per-modality representations into a single vector and passes it through a classifier head.
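
As a rough illustration of the idea, here is a minimal concat-fusion head in plain PyTorch. It is a sketch under stated assumptions, not MMM's implementation: `ConcatFusion` is a hypothetical name, each modality gets one linear encoder, and the output dictionary mimics the `fused_features`/`logits` keys used in this notebook.

```python
import torch
import torch.nn as nn

class ConcatFusion(nn.Module):
    """Hypothetical concat-fusion head (not MMM's actual class)."""

    def __init__(self, input_dims, hidden_dim=256, num_classes=5, dropout=0.1):
        super().__init__()
        self.order = list(input_dims)  # fixed modality order for concatenation
        self.encoders = nn.ModuleDict({m: nn.Linear(d, hidden_dim) for m, d in input_dims.items()})
        # The classifier sees one hidden_dim-sized slot per modality
        self.classifier = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * len(self.order), num_classes),
        )

    def forward(self, inputs):
        # Every modality in self.order must be present, otherwise the width would not match
        fused = torch.cat([self.encoders[m](inputs[m]) for m in self.order], dim=-1)
        return {'fused_features': fused, 'logits': self.classifier(fused)}

model = ConcatFusion({'text': 768, 'image': 2048, 'audio': 512})
batch = {'text': torch.randn(8, 768), 'image': torch.randn(8, 2048), 'audio': torch.randn(8, 512)}
out = model(batch)
print(out['fused_features'].shape)  # torch.Size([8, 768]), i.e. 3 modalities x 256
print(out['logits'].shape)          # torch.Size([8, 5])
```

Because the classifier input has a fixed width, a concat head cannot simply drop a modality at inference time; it needs a placeholder (for example zeros) in the missing slot.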

In [None]:
# Test concatenation fusion
concat_model = models['concat']
concat_model.eval()

# Sample batch
batch_size = 32
sample_inputs = {
    'text': train_data['text'][:batch_size],
    'image': train_data['image'][:batch_size],
    'audio': train_data['audio'][:batch_size]
}

with torch.no_grad():
    outputs = concat_model(sample_inputs)
    
print("Concatenation Fusion Results:")
print(f"Input shapes:")
for modality, data in sample_inputs.items():
    print(f"  {modality}: {data.shape}")
    
print(f"\nOutput shapes:")
print(f"  Fused features: {outputs['fused_features'].shape}")
print(f"  Logits: {outputs['logits'].shape}")
print(f"  Predictions: {outputs['predictions'].shape}")

### 3.2 Attention Fusion
Uses an attention mechanism to learn which modalities are most important for each sample.
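
One simple form of attention fusion scores each projected modality with a shared linear layer and normalises the scores with a softmax over the modality axis, so every sample gets its own weights that sum to 1. The sketch below assumes that design; `AttentionFusion` and its `weights` output are hypothetical and do not describe how MMM's `attention` option is built internally.

```python
import torch
import torch.nn as nn

class AttentionFusion(nn.Module):
    """Hypothetical attention fusion head (a sketch, not MMM's implementation)."""

    def __init__(self, input_dims, hidden_dim=256, num_classes=5):
        super().__init__()
        self.encoders = nn.ModuleDict({m: nn.Linear(d, hidden_dim) for m, d in input_dims.items()})
        self.score = nn.Linear(hidden_dim, 1)  # shared scorer: hidden vector -> scalar relevance
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, inputs):
        # Project every present modality and stack: (batch, n_modalities, hidden_dim)
        h = torch.stack([torch.tanh(self.encoders[m](x)) for m, x in inputs.items()], dim=1)
        # Softmax over the modality axis gives each sample its own weights summing to 1
        weights = torch.softmax(self.score(h).squeeze(-1), dim=1)  # (batch, n_modalities)
        fused = (weights.unsqueeze(-1) * h).sum(dim=1)              # (batch, hidden_dim)
        return {'fused_features': fused, 'logits': self.classifier(fused), 'weights': weights}

fusion = AttentionFusion({'text': 768, 'image': 2048, 'audio': 512})
batch = {'text': torch.randn(8, 768), 'image': torch.randn(8, 2048), 'audio': torch.randn(8, 512)}
out = fusion(batch)
print(out['weights'].shape)       # torch.Size([8, 3])
print(out['weights'].sum(dim=1))  # each row sums to 1 (up to float rounding)

# The same head also accepts a subset of modalities; the softmax then runs over 2 slots
partial = fusion({'text': batch['text'], 'audio': batch['audio']})
print(partial['weights'].shape)   # torch.Size([8, 2])
```

Because the fused vector is a weighted average rather than a concatenation, its width does not depend on how many modalities are present.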

In [None]:
# Test attention fusion
attention_model = models['attention']
attention_model.eval()

with torch.no_grad():
    outputs = attention_model(sample_inputs)
    
print("Attention Fusion Results:")
print(f"Fused features shape: {outputs['fused_features'].shape}")
print(f"Predictions shape: {outputs['predictions'].shape}")

# The model is still untrained here, so any modality weights it computes are not yet meaningful
print("\nAfter training, the attention layer can weight the modalities differently for each sample.")

## 4. Training Multi-Modal Models

Let's train our models and compare their performance.

In [None]:
def train_model(model, train_data, val_data, epochs=20, lr=0.001):
    """
    Simple training loop for demonstration.
    In practice, you'd use more sophisticated training utilities.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()
    
    train_losses = []
    val_accuracies = []
    
    batch_size = 64
    n_train = len(train_data['labels'])
    n_val = len(val_data['labels'])
    
    for epoch in range(epochs):
        # Training
        model.train()
        epoch_loss = 0
        perm = torch.randperm(n_train)  # reshuffle the training samples every epoch
        
        for i in range(0, n_train, batch_size):
            idx = perm[i:i + batch_size]  # slicing stops at n_train automatically
            
            inputs = {
                'text': train_data['text'][idx],
                'image': train_data['image'][idx],
                'audio': train_data['audio'][idx]
            }
            labels = train_data['labels'][idx]
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs['logits'], labels)
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
        
        # Validation
        model.eval()
        correct = 0
        
        with torch.no_grad():
            for i in range(0, n_val, batch_size):
                end_idx = min(i + batch_size, n_val)
                
                inputs = {
                    'text': val_data['text'][i:end_idx],
                    'image': val_data['image'][i:end_idx], 
                    'audio': val_data['audio'][i:end_idx]
                }
                labels = val_data['labels'][i:end_idx]
                
                predictions = model.predict(inputs)
                correct += (predictions == labels).sum().item()
        
        avg_loss = epoch_loss / ((n_train + batch_size - 1) // batch_size)  # ceil: include the last partial batch
        accuracy = correct / n_val
        
        train_losses.append(avg_loss)
        val_accuracies.append(accuracy)
        
        if epoch % 5 == 0:
            print(f"Epoch {epoch:2d}: Loss = {avg_loss:.4f}, Val Acc = {accuracy:.4f}")
    
    return train_losses, val_accuracies

print("Training models with different fusion strategies...")
print("This may take a few minutes.\n")
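
Averaging the epoch loss should divide by the number of batches actually processed, `ceil(n_train / batch_size)`, not `n_train // batch_size`, which leaves the final partial batch out of the count. The small sketch below (plain Python; `minibatch_indices` is a hypothetical helper) generates shuffled mini-batch indices and shows the difference for the 800-sample training set with `batch_size=64`.

```python
import math
import random

def minibatch_indices(n_samples, batch_size, seed=None):
    """Yield shuffled index lists covering every sample exactly once (last batch may be smaller)."""
    order = list(range(n_samples))
    random.Random(seed).shuffle(order)
    for start in range(0, n_samples, batch_size):
        yield order[start:start + batch_size]

batches = list(minibatch_indices(800, 64, seed=0))
print(len(batches), math.ceil(800 / 64), 800 // 64)  # 13 13 12
print(len(batches[-1]))                              # 32 samples in the last batch
```

Note that averaging per-batch losses still weights the smaller last batch the same as full batches; for an exact per-sample mean, accumulate `loss.item() * batch_len` and divide by `n_train`.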

In [None]:
# Train all models
results = {}

for fusion_method, model in models.items():
    print(f"Training {fusion_method} fusion model...")
    train_losses, val_accuracies = train_model(model, train_data, val_data, epochs=15)
    results[fusion_method] = {
        'train_losses': train_losses,
        'val_accuracies': val_accuracies,
        'final_accuracy': val_accuracies[-1]
    }
    print(f"Final validation accuracy: {val_accuracies[-1]:.4f}\n")

print("Training completed!")

## 5. Model Evaluation and Comparison

In [None]:
# Plot training curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Training losses
for fusion_method, result in results.items():
    ax1.plot(result['train_losses'], label=f'{fusion_method.capitalize()} Fusion')
ax1.set_title('Training Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.legend()
ax1.grid(True)

# Validation accuracies
for fusion_method, result in results.items():
    ax2.plot(result['val_accuracies'], label=f'{fusion_method.capitalize()} Fusion')
ax2.set_title('Validation Accuracy')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy')
ax2.legend()
ax2.grid(True)

plt.tight_layout()
plt.show()

# Final accuracy comparison
print("Final Validation Accuracies:")
for fusion_method, result in results.items():
    print(f"{fusion_method.capitalize():>10}: {result['final_accuracy']:.4f}")
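
A single accuracy number can hide classes the model consistently gets wrong. A per-class view comes from a confusion matrix with rows as true labels and columns as predictions. The sketch below is plain Python with hypothetical helper names and toy labels, so it runs without the trained models.

```python
def confusion_counts(y_true, y_pred, n_classes):
    """Rows = true class, columns = predicted class."""
    matrix = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix

def per_class_accuracy(matrix):
    # Diagonal entry over row sum; None for classes that never occur in y_true
    return [row[i] / sum(row) if sum(row) else None for i, row in enumerate(matrix)]

y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
cm = confusion_counts(y_true, y_pred, 3)
print(cm)                      # [[1, 1, 0], [0, 2, 0], [1, 0, 1]]
print(per_class_accuracy(cm))  # [0.5, 1.0, 0.5]
```

With the trained models, `confusion_matrix(val_data['labels'].numpy(), preds.numpy())` from the already-imported `sklearn.metrics` uses the same rows-are-true-labels layout, where `preds` is a hypothetical tensor you would obtain from `model.predict` on the validation inputs.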

## 6. Handling Missing Modalities

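A well-designed multi-modal model can still produce predictions when some modalities are missing at inference time. Here we check whether the trained attention model does so by passing it input dictionaries that omit one or two modalities.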
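
Whether `MultiModalClassifier` accepts a dictionary with missing keys depends on its implementation, which this notebook does not show. One common approach is to average only the modalities that were observed, using a per-sample mask. A minimal sketch, assuming features are already projected to a shared width and missing slots hold finite placeholder values (such as zeros); `masked_mean_fusion` is a hypothetical name.

```python
import torch

def masked_mean_fusion(features, present):
    """
    features: (batch, n_modalities, hidden) tensor; missing slots hold finite placeholders.
    present:  (batch, n_modalities) bool mask, True where the modality was observed.
    Returns the mean over observed modalities only, shape (batch, hidden).
    """
    mask = present.unsqueeze(-1).float()     # (batch, n_modalities, 1)
    summed = (features * mask).sum(dim=1)    # placeholder slots contribute zero
    counts = mask.sum(dim=1).clamp(min=1.0)  # avoid division by zero if nothing is present
    return summed / counts

features = torch.tensor([[[1.0, 1.0], [3.0, 3.0], [100.0, 100.0]],
                         [[2.0, 0.0], [4.0, 2.0], [6.0, 4.0]]])
present = torch.tensor([[True, True, False],   # sample 0: third modality missing
                        [True, True, True]])
result = masked_mean_fusion(features, present).tolist()
print(result)  # [[2.0, 2.0], [4.0, 2.0]]; the 100.0 placeholder is ignored
```

A mask like this also covers the case where different samples in the same batch are missing different modalities, which dropping whole dictionary keys cannot express.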

In [None]:
# Test with different modality combinations
best_model = models['attention']  # the attention model (not necessarily the top scorer in Section 5)
best_model.eval()

# Use the full validation set: 10 samples would give very noisy accuracy estimates
test_batch = {
    'text': val_data['text'],
    'image': val_data['image'],
    'audio': val_data['audio']
}
test_labels = val_data['labels']

modality_combinations = [
    (['text', 'image', 'audio'], "All Modalities"),
    (['text', 'image'], "Text + Image"),
    (['text', 'audio'], "Text + Audio"),
    (['image', 'audio'], "Image + Audio"),
    (['text'], "Text Only"),
    (['image'], "Image Only"),
    (['audio'], "Audio Only")
]

print("Performance with different modality combinations:")
print("=" * 50)

with torch.no_grad():
    for modalities, description in modality_combinations:
        # Create input with only specified modalities
        partial_input = {mod: test_batch[mod] for mod in modalities}
        
        # Get predictions
        predictions = best_model.predict(partial_input)
        accuracy = (predictions == test_labels).float().mean().item()
        
        print(f"{description:>20}: {accuracy:.3f} accuracy")

print("\n💡 Compare each row with 'All Modalities' to see how much accuracy drops when a modality is missing.")

## 7. Feature Analysis and Visualization

In [None]:
# Analyze the learned representations
from sklearn.manifold import TSNE

def visualize_features(model, data, labels, title="Feature Visualization"):
    """Visualize learned features using t-SNE."""
    model.eval()
    
    with torch.no_grad():
        inputs = {
            'text': data['text'][:300],  # Use subset for visualization
            'image': data['image'][:300],
            'audio': data['audio'][:300]
        }
        outputs = model(inputs)
        features = outputs['fused_features'].numpy()
        labels_subset = labels[:300].numpy()
    
    # Apply t-SNE for 2D visualization
    tsne = TSNE(n_components=2, random_state=42, perplexity=30)
    features_2d = tsne.fit_transform(features)
    
    # Plot
    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(features_2d[:, 0], features_2d[:, 1], 
                         c=labels_subset, cmap='tab10', alpha=0.7)
    plt.colorbar(scatter)
    plt.title(f'{title} - Learned Feature Space (t-SNE)')
    plt.xlabel('t-SNE Component 1')
    plt.ylabel('t-SNE Component 2')
    plt.grid(True, alpha=0.3)
    plt.show()

# Visualize features from the attention model
visualize_features(best_model, val_data, val_data['labels'], 
                  "Multi-Modal Attention Fusion")

## 8. Real-World Applications

Multi-modal learning has many practical applications:

### 8.1 Sentiment Analysis
- **Text**: Review text
- **Image**: Product images
- **Audio**: Voice tone in video reviews

### 8.2 Medical Diagnosis
- **Text**: Patient symptoms and history
- **Image**: X-rays, MRI scans
- **Audio**: Heart sounds, breathing patterns

### 8.3 Content Moderation
- **Text**: Captions and comments
- **Image**: Visual content
- **Audio**: Speech in videos

### 8.4 Autonomous Driving
- **Image**: Camera feeds
- **Audio**: Engine sounds, sirens
- **Sensor**: LiDAR, GPS data

In [None]:
# Example: Simulate a content moderation scenario
print("🔍 Content Moderation Example")
print("=" * 40)

# Simulate different types of content
content_types = {
    0: "Safe Content",
    1: "Questionable Content", 
    2: "Inappropriate Content",
    3: "Spam Content",
    4: "Harmful Content"
}

# Create a specialized model for content moderation
moderation_model = MultiModalClassifier(
    num_classes=5,
    text_dim=768,   # Text from captions/comments
    image_dim=2048, # Visual content features
    audio_dim=512,  # Audio from videos
    fusion_method="attention",
    hidden_dim=512
)

print(f"Content Moderation Model:")
print(f"- Classes: {list(content_types.values())}")
print(f"- Parameters: {moderation_model.get_parameter_count()['total']:,}")
print(f"- Fusion Method: Attention (learns which modality is most indicative)")

# In practice, you would:
# 1. Extract text features using BERT from captions/comments
# 2. Extract image features using Vision Transformers
# 3. Extract audio features from video soundtracks
# 4. Train on labeled content moderation datasets
# 5. Deploy for real-time content screening

print("\n✨ This cell only builds an untrained model; it still needs labeled moderation data.")
print("   Each modality can supply complementary evidence for the final decision.")

## 9. Key Takeaways and Next Steps

### What We Learned:
1. **Multi-modal Fusion**: Different strategies (concat, attention, sum) for combining modalities
2. **Robustness**: Models can keep working with missing modalities, provided the architecture supports it
3. **Flexibility**: The framework supports various architectures and use cases
4. **Performance**: Attention fusion can learn per-sample modality weights, but compare the measured accuracies from Section 5 before assuming it performs best

### Best Practices:
- **Data Quality**: Ensure all modalities are properly preprocessed
- **Feature Engineering**: Use domain-appropriate feature extractors
- **Fusion Strategy**: Choose based on your specific task and data
- **Evaluation**: Test with different modality combinations
- **Interpretability**: Use attention weights to understand model decisions
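
To make "use attention weights" concrete: given a per-sample weight matrix, two useful summaries are the average weight each modality receives and how often it is the dominant one. The sketch below assumes the weights are available as a `(batch, n_modalities)` tensor whose rows sum to 1; whether and how MMM's `MultiModalClassifier` exposes such weights is not shown in this notebook, and `summarize_modality_weights` is a hypothetical helper. The weight values are made up for illustration.

```python
import torch

def summarize_modality_weights(weights, names):
    """
    weights: (batch, n_modalities) tensor of per-sample modality weights (rows sum to 1).
    Returns, per modality, the mean weight and the number of samples it dominated.
    """
    mean_weight = weights.mean(dim=0)
    dominant = torch.bincount(weights.argmax(dim=1), minlength=len(names))
    return {
        name: {'mean_weight': mean_weight[i].item(), 'times_dominant': int(dominant[i])}
        for i, name in enumerate(names)
    }

# Made-up weights for four samples, purely for illustration
weights = torch.tensor([[0.6, 0.3, 0.1],
                        [0.2, 0.5, 0.3],
                        [0.7, 0.2, 0.1],
                        [0.5, 0.25, 0.25]])
summary = summarize_modality_weights(weights, ['text', 'image', 'audio'])
for name, stats in summary.items():
    print(f"{name:>6}: mean weight {stats['mean_weight']:.3f}, dominant in {stats['times_dominant']} samples")
```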

### Next Steps for HSE Students:
1. **Practice**: Try the framework with real datasets
2. **Experiment**: Create custom fusion mechanisms
3. **Apply**: Use for your course projects
4. **Extend**: Add new modalities (video, sensor data)
5. **Research**: Explore latest multi-modal architectures

In [None]:
# Final summary
print("🎓 Congratulations! You've completed the Multi-Modal Learning Tutorial")
print("=" * 70)
print("\n📚 What you've accomplished:")
print("  ✅ Understood multi-modal learning concepts")
print("  ✅ Implemented different fusion strategies")
print("  ✅ Trained and evaluated multi-modal models")
print("  ✅ Explored handling missing modalities")
print("  ✅ Visualized learned representations")
print("  ✅ Considered real-world applications")

print("\n🚀 Ready for advanced multi-modal AI research!")
print("\n📖 Additional Resources:")
print("  - MMM Documentation: docs/")
print("  - More Examples: examples/")
print("  - Configuration: configs/")
print("  - HSE AI Lab: ai-lab@hse.ru")

print("\n🎯 HSE 2025 - Multi-Modal Models Framework")
print("   Happy Learning! 🧠✨")