# Advanced Multi-Modal AI Research Notebook

This notebook demonstrates cutting-edge research and experimentation with multi-modal AI models, including:
- Custom model architectures
- Advanced training techniques
- Multi-modal fusion strategies
- Performance optimization
- Research experiments

In [None]:
# Import required libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import cv2

from transformers import (
    AutoTokenizer, AutoModel, AutoConfig,
    CLIPProcessor, CLIPModel,
    BlipProcessor, BlipForConditionalGeneration
)

import wandb
from sklearn.metrics import accuracy_score, classification_report
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans

import random
import warnings

# Silence only noisy library FutureWarnings. A blanket 'ignore' would also hide useful
# warnings (e.g. AMP being disabled when CUDA is unavailable, sklearn parameter issues).
warnings.filterwarnings('ignore', category=FutureWarning)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Set random seeds for reproducibility: Python's `random`, NumPy's global RNG, and
# PyTorch (torch.manual_seed seeds the generators of all devices).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## 1. Advanced Multi-Modal Architecture Design

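Let's design a multi-modal architecture that pairs a pre-trained vision encoder with a pre-trained text encoder. It refines their pooled features with stacked cross-modal attention and fuses them at several levels (early, mid and late) before classification.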

In [None]:
class CrossModalAttention(nn.Module):
    """Bidirectional cross-modal attention for vision-language fusion.

    Vision tokens attend over text tokens (vision queries, text keys/values) and text
    tokens attend over vision tokens (text queries, vision keys/values). Each modality is
    then updated, through a residual connection followed by LayerNorm, with the
    information it gathered from the *other* modality.

    Inputs may be pooled features of shape (batch, dim) or token sequences of shape
    (batch, seq_len, dim). Vision and text may have different sequence lengths. The
    outputs always have the same shapes as the corresponding inputs.

    Args:
        vision_dim: feature size of the vision input.
        text_dim: feature size of the text input.
        hidden_dim: size of the shared attention space; must be divisible by num_heads.
        num_heads: number of attention heads.
        dropout: dropout applied to the projected attention output before the residual add.

    Raises:
        ValueError: if hidden_dim is not divisible by num_heads.
    """

    def __init__(self, vision_dim, text_dim, hidden_dim, num_heads=8, dropout=0.1):
        super().__init__()
        if hidden_dim % num_heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        # Vision projections (queries for vision->text, keys/values for text->vision)
        self.vision_query = nn.Linear(vision_dim, hidden_dim)
        self.vision_key = nn.Linear(vision_dim, hidden_dim)
        self.vision_value = nn.Linear(vision_dim, hidden_dim)

        # Text projections (queries for text->vision, keys/values for vision->text)
        self.text_query = nn.Linear(text_dim, hidden_dim)
        self.text_key = nn.Linear(text_dim, hidden_dim)
        self.text_value = nn.Linear(text_dim, hidden_dim)

        # Output projections back to each modality's own feature size
        self.vision_out = nn.Linear(hidden_dim, vision_dim)
        self.text_out = nn.Linear(hidden_dim, text_dim)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm_v = nn.LayerNorm(vision_dim)
        self.layer_norm_t = nn.LayerNorm(text_dim)

    def _split_heads(self, x):
        """(batch, seq, hidden_dim) -> (batch, num_heads, seq, head_dim)."""
        return x.view(x.size(0), -1, self.num_heads, self.head_dim).transpose(1, 2)

    def _merge_heads(self, x):
        """(batch, num_heads, seq, head_dim) -> (batch, seq, hidden_dim)."""
        return x.transpose(1, 2).contiguous().view(x.size(0), -1, self.hidden_dim)

    def _attend(self, queries, keys, values):
        """Multi-head scaled dot-product attention over (batch, seq, hidden_dim) projections."""
        q = self._split_heads(queries)
        k = self._split_heads(keys)
        v = self._split_heads(values)
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        weights = torch.softmax(scores, dim=-1)
        return self._merge_heads(torch.matmul(weights, v))

    def forward(self, vision_features, text_features):
        """Enrich each modality with information attended from the other.

        Args:
            vision_features: (batch, vision_dim) or (batch, vision_len, vision_dim).
            text_features: (batch, text_dim) or (batch, text_len, text_dim).

        Returns:
            (enhanced_vision, enhanced_text), shaped like the respective inputs.
        """
        # Treat pooled (batch, dim) inputs as length-1 sequences so one code path serves both.
        vision_pooled = vision_features.dim() == 2
        text_pooled = text_features.dim() == 2
        vision_seq = vision_features.unsqueeze(1) if vision_pooled else vision_features
        text_seq = text_features.unsqueeze(1) if text_pooled else text_features

        # Vision queries attend over text: the result is aligned with the vision sequence
        # and carries text information, so it is what updates the vision stream.
        v2t_out = self._attend(
            self.vision_query(vision_seq), self.text_key(text_seq), self.text_value(text_seq)
        )
        # Text queries attend over vision: the result is aligned with the text sequence.
        t2v_out = self._attend(
            self.text_query(text_seq), self.vision_key(vision_seq), self.vision_value(vision_seq)
        )

        # Residual update + LayerNorm, each stream receiving the other modality's information.
        enhanced_vision = self.layer_norm_v(vision_seq + self.dropout(self.vision_out(v2t_out)))
        enhanced_text = self.layer_norm_t(text_seq + self.dropout(self.text_out(t2v_out)))

        if vision_pooled:
            enhanced_vision = enhanced_vision.squeeze(1)
        if text_pooled:
            enhanced_text = enhanced_text.squeeze(1)
        return enhanced_vision, enhanced_text

class AdvancedMultiModalModel(nn.Module):
    """Advanced multi-modal model with cross-attention and hierarchical fusion.

    Forward pipeline:
      1. Encode images and text with pre-trained ``AutoModel`` encoders and take each
         encoder's ``pooler_output`` as the modality summary.
      2. Refine both summaries with ``fusion_layers`` stacked ``CrossModalAttention`` blocks.
      3. Project both to ``fusion_dim`` and build three fused views: early (concatenation),
         mid (element-wise product and sum) and late (fusion of early and mid).
      4. Run multi-head attention across the three views, average them and classify.

    Auxiliary heads classify the raw encoder summaries (taken before cross-attention).
    They are returned only when ``return_aux=True``.

    Args:
        vision_model_name: Hugging Face model id/path for the vision encoder.
        text_model_name: Hugging Face model id/path for the text encoder.
        num_classes: number of output classes.
        fusion_layers: number of stacked cross-modal attention blocks.
        cross_attention_dim: hidden size used inside each cross-modal attention block.
        fusion_dim: size of the shared fusion space.
        fusion_heads: heads of the attention-pooling layer; must divide ``fusion_dim``.

    NOTE(review): if a vision checkpoint was saved without a pooling layer, its
    ``pooler_output`` comes from a freshly initialised pooler. Check the transformers
    load warnings for the chosen checkpoint.
    """

    def __init__(self, vision_model_name, text_model_name, num_classes, fusion_layers=3,
                 cross_attention_dim=512, fusion_dim=256, fusion_heads=8):
        super().__init__()

        # Load pre-trained encoders
        self.vision_encoder = AutoModel.from_pretrained(vision_model_name)
        self.text_encoder = AutoModel.from_pretrained(text_model_name)

        # Feature sizes of the encoder outputs
        self.vision_dim = self.vision_encoder.config.hidden_size
        self.text_dim = self.text_encoder.config.hidden_size

        # Cross-modal attention layers, applied sequentially in forward()
        self.cross_attention_layers = nn.ModuleList([
            CrossModalAttention(self.vision_dim, self.text_dim, cross_attention_dim)
            for _ in range(fusion_layers)
        ])

        # Hierarchical fusion: project both modalities into a shared space
        self.fusion_dim = fusion_dim
        self.vision_proj = nn.Linear(self.vision_dim, self.fusion_dim)
        self.text_proj = nn.Linear(self.text_dim, self.fusion_dim)

        # Multi-scale fusion; each layer maps two concatenated fusion vectors to one
        self.early_fusion = nn.Linear(self.fusion_dim * 2, self.fusion_dim)
        self.mid_fusion = nn.Linear(self.fusion_dim * 2, self.fusion_dim)
        self.late_fusion = nn.Linear(self.fusion_dim * 2, self.fusion_dim)

        # Classification head with attention pooling over the three fused views
        self.attention_pool = nn.MultiheadAttention(self.fusion_dim, num_heads=fusion_heads, batch_first=True)
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(self.fusion_dim, self.fusion_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(self.fusion_dim // 2, num_classes)
        )

        # Auxiliary per-modality heads (regularisation signal during training)
        self.vision_aux_classifier = nn.Linear(self.vision_dim, num_classes)
        self.text_aux_classifier = nn.Linear(self.text_dim, num_classes)

    def forward(self, pixel_values, input_ids, attention_mask, return_aux=False):
        """Classify an image-text pair.

        Args:
            pixel_values: image input passed to the vision encoder.
            input_ids, attention_mask: tokenised text passed to the text encoder.
            return_aux: also return the auxiliary per-modality logits.

        Returns:
            ``logits``, or ``(logits, vision_aux_logits, text_aux_logits)`` if return_aux.
        """
        # Encode modalities and keep the pooled summaries
        vision_features = self.vision_encoder(pixel_values=pixel_values).pooler_output
        text_features = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask).pooler_output

        # Keep the un-fused encoder summaries for the auxiliary heads. The cross-attention
        # layers return new tensors rather than modifying their inputs, so no copy is needed.
        encoder_vision_features = vision_features
        encoder_text_features = text_features

        # Apply cross-modal attention layers
        for cross_attn in self.cross_attention_layers:
            vision_features, text_features = cross_attn(vision_features, text_features)

        # Project to fusion dimension
        vision_proj = self.vision_proj(vision_features)
        text_proj = self.text_proj(text_features)

        # Multi-scale fusion
        early_fused = torch.tanh(self.early_fusion(torch.cat([vision_proj, text_proj], dim=-1)))
        mid_fused = torch.tanh(self.mid_fusion(torch.cat([vision_proj * text_proj, vision_proj + text_proj], dim=-1)))
        late_fused = torch.tanh(self.late_fusion(torch.cat([early_fused, mid_fused], dim=-1)))

        # Attention pooling across the three views (weights are not used, so skip computing them)
        fusion_stack = torch.stack([early_fused, mid_fused, late_fused], dim=1)
        attended_features, _ = self.attention_pool(
            fusion_stack, fusion_stack, fusion_stack, need_weights=False
        )
        final_features = attended_features.mean(dim=1)

        # Main classification
        logits = self.classifier(final_features)

        if return_aux:
            vision_aux_logits = self.vision_aux_classifier(encoder_vision_features)
            text_aux_logits = self.text_aux_classifier(encoder_text_features)
            return logits, vision_aux_logits, text_aux_logits

        return logits

# Build the model from the "google/vit-base-patch16-224" image encoder and the
# "bert-base-uncased" text encoder, with 10 output classes and three stacked
# cross-modal attention layers, then move it to the selected device.
model = AdvancedMultiModalModel(
    vision_model_name="google/vit-base-patch16-224",
    text_model_name="bert-base-uncased",
    num_classes=10,
    fusion_layers=3,
)
model = model.to(device)

# Report model size: all parameters, then only those that will receive gradients.
print(f"Model parameters: {sum(param.numel() for param in model.parameters()):,}")
print(f"Trainable parameters: {sum(param.numel() for param in model.parameters() if param.requires_grad):,}")

## 2. Advanced Training Techniques

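Implement advanced training techniques, including:
- Curriculum learning (samples are filtered by a per-sample difficulty score; the allowed difficulty grows with the epoch)
- Contrastive learning (symmetric image–text loss over in-batch pairs)
- Knowledge distillation (planned; not implemented in the code below)
- Advanced optimization (AdamW, cosine annealing with warm restarts, gradient clipping, mixed precision)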

In [None]:
class ContrastiveLoss(nn.Module):
    """Symmetric (CLIP-style) contrastive loss for multi-modal representation learning.

    Row ``i`` of ``vision_features`` and row ``i`` of ``text_features`` form the positive
    pair; every other row in the batch acts as a negative. Both inputs are L2-normalised,
    their similarity matrix is divided by ``temperature``, and the loss is the mean of
    the vision->text and text->vision cross-entropies.

    Args:
        temperature: softmax temperature; smaller values sharpen the distribution.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, vision_features, text_features):
        """Return the scalar symmetric contrastive loss.

        Args:
            vision_features: (batch, dim) tensor.
            text_features: (batch, dim) tensor, paired row-by-row with vision_features.

        Raises:
            ValueError: if the two inputs have different batch sizes.
        """
        batch_size = vision_features.size(0)
        if text_features.size(0) != batch_size:
            raise ValueError(
                f"batch size mismatch: {batch_size} vision vs "
                f"{text_features.size(0)} text features"
            )

        # Normalize so the dot product is a cosine similarity
        vision_features = nn.functional.normalize(vision_features, dim=-1)
        text_features = nn.functional.normalize(text_features, dim=-1)

        # similarity_matrix[i, j] = cos(vision_i, text_j) / temperature
        similarity_matrix = torch.matmul(vision_features, text_features.T) / self.temperature

        # Positive pairs sit on the diagonal
        labels = torch.arange(batch_size, device=vision_features.device)

        loss_v2t = nn.functional.cross_entropy(similarity_matrix, labels)
        loss_t2v = nn.functional.cross_entropy(similarity_matrix.T, labels)

        return (loss_v2t + loss_t2v) / 2

class CurriculumLearningScheduler:
    """Curriculum learning scheduler for progressive difficulty.

    Levels run from 1 (easiest) to ``difficulty_levels`` and grow linearly with the
    fraction of training completed. Epoch ``e`` of ``total_epochs`` maps to level
    ``floor(e * difficulty_levels / total_epochs) + 1``, clamped to
    ``[1, difficulty_levels]``.

    Args:
        total_epochs: number of epochs the schedule spans.
        difficulty_levels: number of difficulty levels.
    """

    def __init__(self, total_epochs, difficulty_levels=5):
        self.total_epochs = total_epochs
        self.difficulty_levels = difficulty_levels
        # Most recent level returned by get_difficulty_level (1 before the first call).
        self.current_level = 1

    def get_difficulty_level(self, epoch):
        """Return the difficulty level for ``epoch`` and record it in ``current_level``.

        With a non-positive ``total_epochs`` there is no schedule to follow. In that case
        the maximum level is returned (nothing is held back) instead of dividing by zero.
        """
        if self.total_epochs <= 0:
            level = self.difficulty_levels
        else:
            # Floor division on the exact product avoids float round-off
            # (e.g. 57 / 100 * 100 evaluates to 56.99999999999999).
            level = epoch * self.difficulty_levels // self.total_epochs + 1
            level = max(1, min(level, self.difficulty_levels))
        self.current_level = level
        return level

    def should_include_sample(self, sample_difficulty, current_level):
        """Return True if a sample of ``sample_difficulty`` is allowed at ``current_level``."""
        return sample_difficulty <= current_level

class AdvancedTrainer:
    """Advanced trainer with multiple training techniques.

    Combines the following:
    - curriculum learning (per-sample difficulty filtering);
    - a multi-task loss: main classification, per-modality auxiliary heads and an
      image-text contrastive term;
    - gradient clipping and optional CUDA mixed precision;
    - cosine annealing with warm restarts;
    - validation-based early stopping and best-model checkpointing.

    Args:
        model: multi-modal model whose forward accepts
            ``(pixel_values, input_ids, attention_mask, return_aux=True)`` and returns
            ``(logits, vision_aux_logits, text_aux_logits)``. It must also expose
            ``vision_encoder`` / ``text_encoder`` outputs with a ``pooler_output``.
        train_loader, val_loader: iterables of dict batches with 'pixel_values',
            'input_ids', 'attention_mask' and 'labels'. Training batches may also carry a
            per-sample 'difficulty' entry used for curriculum filtering.
        config: dict with these keys:
            - required: 'learning_rate', 'weight_decay', 'warmup_epochs', 'epochs',
              'mixed_precision';
            - optional: 'patience' (default 10) and 'checkpoint_path'
              (default 'best_model.pt').
            'warmup_epochs' is the length in epochs of the first cosine cycle (T_0);
            no linear warm-up is performed.
    """

    # Weights of the auxiliary terms in the combined training loss.
    AUX_LOSS_WEIGHT = 0.3
    CONTRASTIVE_LOSS_WEIGHT = 0.2
    # Maximum global gradient L2 norm after clipping.
    MAX_GRAD_NORM = 1.0
    # Print a progress line every this many batches.
    LOG_EVERY = 100

    def __init__(self, model, train_loader, val_loader, config):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config

        # Optimizer
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config['weight_decay'],
            betas=(0.9, 0.999)
        )

        # Cosine annealing with warm restarts. T_0 is measured in epochs, and train_epoch
        # steps the scheduler with fractional epochs.
        self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer,
            T_0=config['warmup_epochs'],
            T_mult=2,
            eta_min=config['learning_rate'] * 0.01
        )

        # Loss functions
        self.classification_loss = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.contrastive_loss = ContrastiveLoss(temperature=0.07)

        # Curriculum learning
        self.curriculum_scheduler = CurriculumLearningScheduler(config['epochs'])

        # CUDA mixed precision. Without CUDA, train in fp32 instead of creating a
        # GradScaler that would only disable itself with a warning.
        use_amp = bool(config['mixed_precision']) and device.type == 'cuda'
        self.scaler = torch.cuda.amp.GradScaler() if use_amp else None

        # Metrics tracking
        self.train_losses = []
        self.val_losses = []
        self.val_accuracies = []

    @staticmethod
    def _to_device(batch):
        """Move the four model inputs of a dict batch to ``device``."""
        return (
            batch['pixel_values'].to(device),
            batch['input_ids'].to(device),
            batch['attention_mask'].to(device),
            batch['labels'].to(device),
        )

    def _apply_curriculum(self, batch, difficulty_level):
        """Drop samples harder than ``difficulty_level``.

        Returns the filtered batch, or None if no sample remains. Batches without a
        'difficulty' entry are returned unchanged.
        """
        # Collated batches are dicts: test for the key (hasattr() would always be False).
        if 'difficulty' not in batch:
            return batch
        keep = [
            i for i, diff in enumerate(batch['difficulty'])
            if self.curriculum_scheduler.should_include_sample(diff, difficulty_level)
        ]
        if not keep:
            return None
        filtered = {}
        for key, value in batch.items():
            if torch.is_tensor(value):
                filtered[key] = value[keep]
            elif isinstance(value, (list, tuple)):
                # Keep per-sample Python sequences aligned with the filtered tensors.
                filtered[key] = [value[i] for i in keep]
            else:
                filtered[key] = value
        return filtered

    def _compute_loss(self, pixel_values, input_ids, attention_mask, labels):
        """Combined loss: main CE + weighted per-modality auxiliary CE + contrastive term."""
        logits, vision_aux, text_aux = self.model(
            pixel_values, input_ids, attention_mask, return_aux=True
        )
        main_loss = self.classification_loss(logits, labels)
        aux_loss_v = self.classification_loss(vision_aux, labels)
        aux_loss_t = self.classification_loss(text_aux, labels)

        # NOTE: the model's forward does not expose its pooled encoder features, so both
        # encoders are run a second time here for the contrastive term.
        vision_features = self.model.vision_encoder(pixel_values=pixel_values).pooler_output
        text_features = self.model.text_encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).pooler_output
        contrastive_loss = self.contrastive_loss(vision_features, text_features)

        return (
            main_loss
            + self.AUX_LOSS_WEIGHT * aux_loss_v
            + self.AUX_LOSS_WEIGHT * aux_loss_t
            + self.CONTRASTIVE_LOSS_WEIGHT * contrastive_loss
        )

    def train_epoch(self, epoch):
        """Train for one epoch; return the mean loss over the batches actually used."""
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        # Current curriculum difficulty level
        difficulty_level = self.curriculum_scheduler.get_difficulty_level(epoch)
        # The scheduler's T_0 is in epochs, so it is stepped with the fractional epoch
        # (epoch + batch_idx / iters_per_epoch). Stepping once per batch would shrink
        # each cosine cycle to a handful of batches.
        iters_per_epoch = max(len(self.train_loader), 1)

        for batch_idx, batch in enumerate(self.train_loader):
            batch = self._apply_curriculum(batch, difficulty_level)
            if batch is None:
                continue

            pixel_values, input_ids, attention_mask, labels = self._to_device(batch)

            self.optimizer.zero_grad()

            # autocast is a no-op when AMP is disabled, so both paths share one forward.
            with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                total_batch_loss = self._compute_loss(pixel_values, input_ids, attention_mask, labels)

            if self.scaler is not None:
                self.scaler.scale(total_batch_loss).backward()
                # Unscale first so clipping sees the true gradient magnitudes.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.MAX_GRAD_NORM)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                total_batch_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.MAX_GRAD_NORM)
                self.optimizer.step()

            self.scheduler.step(epoch + batch_idx / iters_per_epoch)

            total_loss += total_batch_loss.item()
            num_batches += 1

            if batch_idx % self.LOG_EVERY == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {total_batch_loss.item():.4f}, '
                      f'Difficulty Level: {difficulty_level}, LR: {self.scheduler.get_last_lr()[0]:.6f}')

        avg_loss = total_loss / num_batches if num_batches > 0 else 0
        self.train_losses.append(avg_loss)
        return avg_loss

    def validate(self):
        """Evaluate on val_loader; return (mean batch loss, accuracy in percent).

        An empty loader yields (0.0, 0.0) instead of raising ZeroDivisionError.
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in self.val_loader:
                pixel_values, input_ids, attention_mask, labels = self._to_device(batch)

                logits = self.model(pixel_values, input_ids, attention_mask)
                total_loss += self.classification_loss(logits, labels).item()
                num_batches += 1

                predicted = logits.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        avg_loss = total_loss / num_batches if num_batches else 0.0
        accuracy = 100 * correct / total if total else 0.0

        self.val_losses.append(avg_loss)
        self.val_accuracies.append(accuracy)

        return avg_loss, accuracy

    def train(self):
        """Full training loop with early stopping; return the best validation accuracy.

        The best model (by validation accuracy) is saved after the first epoch and
        whenever accuracy improves. Training stops after ``patience`` epochs without
        improvement.
        """
        best_accuracy = None
        patience = 0
        max_patience = self.config.get('patience', 10)
        checkpoint_path = self.config.get('checkpoint_path', 'best_model.pt')

        for epoch in range(self.config['epochs']):
            print(f"\nEpoch {epoch + 1}/{self.config['epochs']}")
            print("-" * 50)

            train_loss = self.train_epoch(epoch)
            val_loss, val_accuracy = self.validate()

            print(f"Train Loss: {train_loss:.4f}")
            print(f"Val Loss: {val_loss:.4f}")
            print(f"Val Accuracy: {val_accuracy:.2f}%")

            # Early stopping. The first epoch always counts as an improvement, so a
            # checkpoint exists even if accuracy never rises above 0.
            if best_accuracy is None or val_accuracy > best_accuracy:
                best_accuracy = val_accuracy
                patience = 0
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'best_accuracy': best_accuracy,
                }, checkpoint_path)
            else:
                patience += 1
                if patience >= max_patience:
                    print(f"Early stopping at epoch {epoch + 1}")
                    break

        return best_accuracy if best_accuracy is not None else 0

# Hyper-parameters consumed by AdvancedTrainer.
# NOTE: 'warmup_epochs' is passed to CosineAnnealingWarmRestarts as T_0 (the length of the
# first cosine cycle); it is not a linear warm-up phase.
training_config = dict(
    epochs=50,
    learning_rate=2e-5,
    weight_decay=0.01,
    warmup_epochs=5,
    mixed_precision=True,
    patience=10,
)

print("Advanced training configuration loaded!")

## 3. Multi-Modal Data Analysis and Visualization

Analyze multi-modal data patterns and create advanced visualizations.

In [None]:
class MultiModalAnalyzer:
    """Advanced analyzer for multi-modal data patterns.

    Works on the wrapped model's pooled encoder outputs
    (``vision_encoder(...).pooler_output`` and ``text_encoder(...).pooler_output``) and on
    its end-to-end predictions. Dataloader batches are expected to be dicts with
    'pixel_values', 'input_ids', 'attention_mask' and 'labels' tensors.
    """

    def __init__(self, model):
        self.model = model
        self.model.eval()

    @staticmethod
    def _make_tsne(n_samples, perplexity=30, random_state=42):
        """Build a 2-D t-SNE, shrinking perplexity below ``n_samples`` as sklearn requires."""
        return TSNE(n_components=2, random_state=random_state,
                    perplexity=min(perplexity, max(1, n_samples - 1)))

    def extract_features(self, dataloader, max_samples=1000):
        """Collect pooled encoder features and labels for up to ``max_samples`` samples.

        Returns:
            (vision_features, text_features, labels) as NumPy arrays with equal row counts.

        Raises:
            ValueError: if no samples were collected (empty loader or max_samples <= 0).
        """
        # Re-enter eval mode: the model may have been switched to train mode since __init__.
        self.model.eval()
        vision_chunks, text_chunks, label_chunks = [], [], []
        collected = 0

        with torch.no_grad():
            for batch in dataloader:
                remaining = max_samples - collected
                if remaining <= 0:
                    break
                # Count rows actually seen and trim the last batch to max_samples.
                # dataloader.batch_size is None when a custom batch_sampler is used.
                pixel_values = batch['pixel_values'][:remaining].to(device)
                input_ids = batch['input_ids'][:remaining].to(device)
                attention_mask = batch['attention_mask'][:remaining].to(device)
                labels = batch['labels'][:remaining]

                vision_output = self.model.vision_encoder(pixel_values=pixel_values)
                text_output = self.model.text_encoder(input_ids=input_ids, attention_mask=attention_mask)

                vision_chunks.append(vision_output.pooler_output.cpu().numpy())
                text_chunks.append(text_output.pooler_output.cpu().numpy())
                label_chunks.append(labels.cpu().numpy())
                collected += labels.shape[0]

        if not label_chunks:
            raise ValueError("No samples extracted: the dataloader is empty or max_samples <= 0")

        return np.vstack(vision_chunks), np.vstack(text_chunks), np.hstack(label_chunks)

    def analyze_feature_correlation(self, vision_features, text_features, num_dims=50):
        """Correlate vision and text feature dimensions.

        The full ``np.corrcoef`` matrix over all (vision_dim + text_dim) columns is
        returned. Rows/columns ``[0, vision_dim)`` are vision dimensions; the remaining
        ones are text dimensions. The heatmap shows the cross-modal block: the first
        ``num_dims`` vision dimensions against the first ``num_dims`` text dimensions.
        Constant feature columns yield NaN entries.
        """
        vision_dim = vision_features.shape[1]
        combined_features = np.hstack([vision_features, text_features])
        correlation_matrix = np.corrcoef(combined_features.T)

        n_vision = min(num_dims, vision_dim)
        n_text = min(num_dims, text_features.shape[1])
        cross_block = correlation_matrix[:n_vision, vision_dim:vision_dim + n_text]

        plt.figure(figsize=(12, 10))
        sns.heatmap(cross_block, cmap='coolwarm', center=0, vmin=-1, vmax=1,
                    square=True, linewidths=0.5)
        plt.title(f'Vision-Text Feature Correlation (first {n_vision} x {n_text} dimensions)')
        plt.xlabel('Text feature dimension')
        plt.ylabel('Vision feature dimension')
        plt.tight_layout()
        plt.show()

        return correlation_matrix

    def visualize_feature_space(self, vision_features, text_features, labels, method='tsne'):
        """Plot 2-D embeddings of combined, vision-only and text-only features.

        Args:
            vision_features, text_features: per-sample feature matrices with matching rows.
            labels: per-sample labels used to colour the points.
            method: 'tsne' for t-SNE; any other value uses PCA.

        Returns:
            The 2-D embedding of the combined (vision + text) features.
        """
        combined_features = np.hstack([vision_features, text_features])

        if method == 'tsne':
            reducer = self._make_tsne(combined_features.shape[0])
        else:
            from sklearn.decomposition import PCA
            reducer = PCA(n_components=2, random_state=42)

        panels = [
            ('Combined Features', combined_features),
            ('Vision Features', vision_features),
            ('Text Features', text_features),
        ]
        # The same reducer is re-fitted independently on each feature set.
        embeddings = [reducer.fit_transform(features) for _, features in panels]

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        for ax, (title, _), embedded in zip(axes, panels, embeddings):
            scatter = ax.scatter(embedded[:, 0], embedded[:, 1], c=labels, cmap='tab10', alpha=0.7)
            fig.colorbar(scatter, ax=ax)
            ax.set_title(f'{title} ({method.upper()})')
            ax.set_xlabel('Component 1')
            ax.set_ylabel('Component 2')

        fig.tight_layout()
        plt.show()

        return embeddings[0]

    def analyze_modality_importance(self, dataloader, num_samples=100):
        """Compare accuracy using both modalities vs. each modality alone.

        A modality is "removed" by zeroing its inputs: zero pixels for vision; zero token
        ids plus an all-zero attention mask for text. At most ``num_samples`` samples are
        evaluated.

        Returns:
            dict with 'vision_accuracy', 'text_accuracy', 'combined_accuracy' (fractions)
            and 'synergy_effect' = combined - max(vision-only, text-only).

        Raises:
            ValueError: if no samples were evaluated.
        """
        self.model.eval()
        vision_only_correct = 0
        text_only_correct = 0
        combined_correct = 0
        total = 0

        with torch.no_grad():
            for batch in dataloader:
                remaining = num_samples - total
                if remaining <= 0:
                    break
                # Count samples rather than batches, so num_samples smaller than the batch
                # size still evaluates something.
                pixel_values = batch['pixel_values'][:remaining].to(device)
                input_ids = batch['input_ids'][:remaining].to(device)
                attention_mask = batch['attention_mask'][:remaining].to(device)
                labels = batch['labels'][:remaining].to(device)

                # Combined prediction
                combined_pred = torch.argmax(self.model(pixel_values, input_ids, attention_mask), dim=1)

                # Vision-only prediction (zero out text)
                vision_logits = self.model(pixel_values, torch.zeros_like(input_ids),
                                           torch.zeros_like(attention_mask))
                vision_pred = torch.argmax(vision_logits, dim=1)

                # Text-only prediction (zero out vision)
                text_logits = self.model(torch.zeros_like(pixel_values), input_ids, attention_mask)
                text_pred = torch.argmax(text_logits, dim=1)

                vision_only_correct += (vision_pred == labels).sum().item()
                text_only_correct += (text_pred == labels).sum().item()
                combined_correct += (combined_pred == labels).sum().item()
                total += labels.size(0)

        if total == 0:
            raise ValueError("No samples evaluated: the dataloader is empty or num_samples <= 0")

        vision_accuracy = vision_only_correct / total
        text_accuracy = text_only_correct / total
        combined_accuracy = combined_correct / total

        # Visualize results
        modalities = ['Vision Only', 'Text Only', 'Combined']
        accuracies = [vision_accuracy, text_accuracy, combined_accuracy]

        plt.figure(figsize=(10, 6))
        bars = plt.bar(modalities, accuracies, color=['skyblue', 'lightcoral', 'lightgreen'])
        plt.ylabel('Accuracy')
        plt.title('Modality Importance Analysis')
        plt.ylim(0, 1)

        # Value labels on bars
        for bar, acc in zip(bars, accuracies):
            plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                     f'{acc:.3f}', ha='center', va='bottom')

        plt.tight_layout()
        plt.show()

        return {
            'vision_accuracy': vision_accuracy,
            'text_accuracy': text_accuracy,
            'combined_accuracy': combined_accuracy,
            'synergy_effect': combined_accuracy - max(vision_accuracy, text_accuracy)
        }

    def cluster_analysis(self, vision_features, text_features, labels, n_clusters=5):
        """K-Means on concatenated features, visualised with t-SNE and scored against labels.

        Returns:
            dict with 'cluster_labels', 'cluster_centers' (in the original feature space),
            'ari_score' (adjusted Rand index) and 'nmi_score' (normalised mutual information).
        """
        combined_features = np.hstack([vision_features, text_features])

        kmeans = KMeans(n_clusters=n_clusters, random_state=42)
        cluster_labels = kmeans.fit_predict(combined_features)

        reduced_features = self._make_tsne(combined_features.shape[0]).fit_transform(combined_features)

        # t-SNE cannot embed new points after fitting. Each centre is therefore placed at
        # the mean 2-D position of its members, instead of running a separate, unrelated
        # t-SNE on the centres (which recent sklearn also rejects when there are fewer
        # centres than the perplexity).
        centers_2d = np.array([
            reduced_features[cluster_labels == cluster_id].mean(axis=0)
            for cluster_id in np.unique(cluster_labels)
        ])

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        panels = [
            (labels, 'True Labels'),
            (cluster_labels, 'K-Means Clusters'),
            (cluster_labels, 'Clusters with Centers'),
        ]
        for ax, (colours, title) in zip(axes, panels):
            scatter = ax.scatter(reduced_features[:, 0], reduced_features[:, 1],
                                 c=colours, cmap='tab10', alpha=0.7)
            fig.colorbar(scatter, ax=ax)
            ax.set_title(title)
            ax.set_xlabel('t-SNE 1')
            ax.set_ylabel('t-SNE 2')
        axes[2].scatter(centers_2d[:, 0], centers_2d[:, 1],
                        c='red', marker='x', s=200, linewidths=3)

        fig.tight_layout()
        plt.show()

        # Agreement between clusters and true labels
        from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score

        ari_score = adjusted_rand_score(labels, cluster_labels)
        nmi_score = normalized_mutual_info_score(labels, cluster_labels)

        print("Clustering Analysis Results:")
        print(f"Adjusted Rand Index: {ari_score:.3f}")
        print(f"Normalized Mutual Information: {nmi_score:.3f}")

        return {
            'cluster_labels': cluster_labels,
            'cluster_centers': kmeans.cluster_centers_,
            'ari_score': ari_score,
            'nmi_score': nmi_score
        }

print("Multi-modal analyzer ready for advanced data analysis!")

## 4. Model Interpretability and Explainability

Implement advanced techniques for understanding model decisions.

In [None]:
class MultiModalExplainer:
    """Attribution and probing tools for a vision-language classifier.

    Every method calls the wrapped model as
    ``model(pixel_values, input_ids, attention_mask) -> logits``; the probing
    methods additionally use ``model.vision_encoder`` / ``model.text_encoder``.
    Token ids are integers and cannot carry gradients, so token-level
    attributions are taken at the text encoder's word-embedding layer
    (``model.text_encoder.get_input_embeddings()``) when it exists, and are
    zero otherwise.
    """

    def __init__(self, model):
        self.model = model
        # Explanations should reflect inference behaviour (e.g. dropout disabled).
        self.model.eval()

    # ------------------------------------------------------------------ helpers
    def _model_device(self):
        """Device holding the model's first parameter (CPU for a parameterless model).

        Used instead of a global device so CPU-only models (e.g. dynamically
        quantized ones) receive inputs on the right device.
        """
        first_param = next(self.model.parameters(), None)
        return first_param.device if first_param is not None else torch.device('cpu')

    def _word_embedding_layer(self):
        """Return ``model.text_encoder.get_input_embeddings()``, or None if unavailable."""
        text_encoder = getattr(self.model, 'text_encoder', None)
        getter = getattr(text_encoder, 'get_input_embeddings', None)
        if getter is None:
            return None
        layer = getter()
        return layer if isinstance(layer, nn.Module) else None

    def _forward_with_embedding_capture(self, pixel_values, input_ids, attention_mask, scale=None):
        """Run the model while recording (and optionally scaling) the word embeddings.

        Returns ``(logits, captured)``. ``captured['raw']`` is the detached,
        unscaled embedding output and ``captured['used']`` is the tensor that
        actually flowed on through the model (multiplied by ``scale`` when given).
        Both keys are absent if no embedding layer was found or it was not called.
        NOTE(review): if the layer is called several times per forward pass, only
        the last call is kept.
        """
        captured = {}
        layer = self._word_embedding_layer()
        handle = None
        if layer is not None:
            def hook(module, inputs, output):
                if not torch.is_tensor(output):
                    return None
                captured['raw'] = output.detach()
                if scale is not None:
                    output = output * scale
                captured['used'] = output
                # Returning a tensor from a forward hook replaces the module output.
                return output
            handle = layer.register_forward_hook(hook)
        try:
            logits = self.model(pixel_values, input_ids, attention_mask)
        finally:
            # Always remove the hook so later forward passes are unaffected.
            if handle is not None:
                handle.remove()
        return logits, captured

    @staticmethod
    def _target_score(logits, target):
        """Scalar sum of each sample's logit for its target class.

        ``target`` may be a Python int (same class for every sample) or a
        per-sample index tensor of length ``logits.size(0)``.
        """
        target = torch.as_tensor(target, device=logits.device).reshape(-1)
        if target.numel() == 1:
            target = target.expand(logits.size(0))
        return logits.gather(1, target.unsqueeze(1)).sum()

    @staticmethod
    def _input_gradients(score, pixels, captured, input_ids):
        """Gradients of ``score`` w.r.t. ``pixels`` and the captured word embeddings.

        ``torch.autograd.grad`` is used instead of ``backward()`` so parameter
        ``.grad`` buffers are never written. Returns ``(pixel_grad, embedding_grad)``;
        ``embedding_grad`` is None when no differentiable embeddings aligned with
        ``input_ids`` (shape ``input_ids.shape + (hidden,)``) were captured.
        """
        embeddings = captured.get('used')
        use_text = (
            embeddings is not None
            and embeddings.requires_grad
            and tuple(embeddings.shape[:-1]) == tuple(input_ids.shape)
        )
        wrt = [pixels, embeddings] if use_text else [pixels]
        grads = torch.autograd.grad(score, wrt, allow_unused=True)
        pixel_grad = grads[0] if grads[0] is not None else torch.zeros_like(pixels)
        embedding_grad = grads[1] if use_text else None
        return pixel_grad, embedding_grad

    @staticmethod
    def _class_probability(logits, classes):
        """Softmax probability each sample assigns to its class in ``classes`` (shape (B, 1))."""
        return torch.softmax(logits, dim=1).gather(1, classes).squeeze(1)

    # ------------------------------------------------------------- attributions
    def generate_attention_maps(self, pixel_values, input_ids, attention_mask):
        """Plain-gradient saliency for both modalities, w.r.t. the predicted class.

        Args:
            pixel_values: image batch; assumed ``(B, C, H, W)`` (dim 1 is averaged).
            input_ids: integer token ids, passed through unchanged.
            attention_mask: mask passed through to the model.

        Returns:
            ``(vision_saliency, text_saliency)``: ``|d logit / d pixel|`` averaged
            over dim 1, and the L2 norm of the gradient w.r.t. each token's word
            embedding (shape of ``input_ids``; zeros when no embedding layer is
            reachable).
        """
        # Detached copy: the caller's tensor is not mutated and repeated calls
        # cannot accumulate into a stale ``.grad``.
        pixels = pixel_values.detach().requires_grad_(True)
        logits, captured = self._forward_with_embedding_capture(pixels, input_ids, attention_mask)
        target = logits.argmax(dim=1)
        pixel_grad, embedding_grad = self._input_gradients(
            self._target_score(logits, target), pixels, captured, input_ids)

        vision_gradients = pixel_grad.abs().mean(dim=1)
        if embedding_grad is None:
            text_gradients = torch.zeros(input_ids.shape, dtype=torch.float, device=input_ids.device)
        else:
            text_gradients = embedding_grad.norm(dim=-1)
        return vision_gradients, text_gradients

    def integrated_gradients(self, pixel_values, input_ids, attention_mask, steps=50, target=None):
        """Integrated Gradients with an all-zero baseline for both modalities.

        Pixels are interpolated directly. Token ids are discrete, so the text
        path is interpolated at the word-embedding layer instead (embeddings
        scaled by ``alpha``); each token's attribution is summed over the hidden
        dimension.

        Args:
            pixel_values: image batch (attributions have the same shape).
            input_ids: integer token ids (text attributions have the same shape).
            attention_mask: mask passed through to the model.
            steps: number of points on the path, ``alpha`` evenly spaced in [0, 1].
            target: class index (int or per-sample tensor) to explain. Defaults to
                the prediction on the un-interpolated input, so every path point
                explains the same class.

        Returns:
            ``(vision_attr, text_attr)``; ``text_attr`` is float and all zeros when
            no embedding layer is reachable.

        Raises:
            ValueError: if ``steps < 1``.
        """
        if steps < 1:
            raise ValueError("steps must be >= 1")

        pixels = pixel_values.detach()
        baseline_pixels = torch.zeros_like(pixels)
        if target is None:
            with torch.no_grad():
                target = self.model(pixels, input_ids, attention_mask).argmax(dim=1)

        vision_total = torch.zeros_like(pixels)
        embedding_total = None
        raw_embeddings = None
        for alpha in torch.linspace(0, 1, steps).tolist():
            interp_pixels = (baseline_pixels + alpha * (pixels - baseline_pixels)).requires_grad_(True)
            logits, captured = self._forward_with_embedding_capture(
                interp_pixels, input_ids, attention_mask, scale=alpha)
            pixel_grad, embedding_grad = self._input_gradients(
                self._target_score(logits, target), interp_pixels, captured, input_ids)

            vision_total += pixel_grad
            if embedding_grad is not None:
                raw_embeddings = captured['raw']
                embedding_total = embedding_grad if embedding_total is None else embedding_total + embedding_grad

        # Riemann average of the path gradients times (input - baseline).
        vision_attr = vision_total / steps * (pixels - baseline_pixels)
        if embedding_total is None:
            text_attr = torch.zeros(input_ids.shape, dtype=torch.float, device=input_ids.device)
        else:
            # Zero embedding baseline, so (input - baseline) is the raw embedding.
            text_attr = (embedding_total / steps * raw_embeddings).sum(dim=-1)
        return vision_attr, text_attr

    def visualize_attention(self, image, text, tokenizer, method='integrated_gradients'):
        """Plot vision and token attributions for a single image/text pair.

        Args:
            image: PIL image; converted to RGB before preprocessing.
            text: input string, tokenized with ``tokenizer``.
            tokenizer: callable tokenizer that also provides ``convert_ids_to_tokens``.
            method: ``'integrated_gradients'``; any other value uses plain gradients.

        Returns:
            ``(vision_attr, text_attr)`` as produced by the chosen method.
        """
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # Normalize() expects 3 channels; grayscale/RGBA PIL images would fail.
        if hasattr(image, 'convert'):
            image = image.convert('RGB')

        model_device = self._model_device()
        pixel_values = transform(image).unsqueeze(0).to(model_device)

        encoding = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
        input_ids = encoding['input_ids'].to(model_device)
        attention_mask = encoding['attention_mask'].to(model_device)

        if method == 'integrated_gradients':
            vision_attr, text_attr = self.integrated_gradients(pixel_values, input_ids, attention_mask)
        else:
            vision_attr, text_attr = self.generate_attention_maps(pixel_values, input_ids, attention_mask)

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))

        # Original image
        axes[0, 0].imshow(image)
        axes[0, 0].set_title('Original Image')
        axes[0, 0].axis('off')

        # Plain-gradient maps are already channel-averaged (H, W); IG maps are (C, H, W).
        vision_heatmap = vision_attr[0].detach().cpu().numpy()
        if vision_heatmap.ndim == 3:
            vision_heatmap = vision_heatmap.mean(axis=0)
        im = axes[0, 1].imshow(vision_heatmap, cmap='hot', interpolation='bilinear')
        axes[0, 1].set_title('Vision Attention Heatmap')
        axes[0, 1].axis('off')
        plt.colorbar(im, ax=axes[0, 1])

        # Text attention, without the tokenizer's special tokens.
        tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
        text_scores = text_attr[0].detach().cpu().numpy()
        special_tokens = set(getattr(tokenizer, 'all_special_tokens', None) or ['[CLS]', '[SEP]', '[PAD]'])
        valid_tokens = [(token, score) for token, score in zip(tokens, text_scores)
                        if token not in special_tokens]

        if valid_tokens:
            tokens_filtered, scores_filtered = zip(*valid_tokens)

            # Min-max normalise for display only.
            scores_normalized = np.array(scores_filtered)
            scores_normalized = (scores_normalized - scores_normalized.min()) / (scores_normalized.max() - scores_normalized.min() + 1e-8)

            axes[1, 0].barh(range(len(tokens_filtered)), scores_normalized)
            axes[1, 0].set_yticks(range(len(tokens_filtered)))
            axes[1, 0].set_yticklabels(tokens_filtered, fontsize=8)
            axes[1, 0].set_xlabel('Attention Score')
            axes[1, 0].set_title('Text Token Attention')

        # Overlay: stretch the 224x224 heatmap over the original image's pixel
        # extent, otherwise the second imshow rescales the axes and misaligns them.
        height, width = np.asarray(image).shape[:2]
        axes[1, 1].imshow(image, alpha=0.7)
        axes[1, 1].imshow(vision_heatmap, cmap='hot', alpha=0.3, interpolation='bilinear',
                          extent=(-0.5, width - 0.5, height - 0.5, -0.5))
        axes[1, 1].set_title('Image + Attention Overlay')
        axes[1, 1].axis('off')

        plt.tight_layout()
        plt.show()

        return vision_attr, text_attr

    # ------------------------------------------------------------------ probing
    def analyze_cross_modal_interactions(self, dataloader, num_samples=50):
        """Cosine similarity between pooled vision and text encoder outputs.

        Consumes batches until ``num_samples`` examples are scored (the excess of
        the last batch is dropped), plots the distribution and returns summary
        statistics.

        Raises:
            ValueError: if no sample was scored (empty loader or ``num_samples <= 0``).
        """
        model_device = self._model_device()
        interaction_scores = []

        with torch.no_grad():
            for batch in dataloader:
                if len(interaction_scores) >= num_samples:
                    break

                pixel_values = batch['pixel_values'].to(model_device)
                input_ids = batch['input_ids'].to(model_device)
                attention_mask = batch['attention_mask'].to(model_device)

                vision_features = self.model.vision_encoder(pixel_values=pixel_values).pooler_output
                text_features = self.model.text_encoder(input_ids=input_ids, attention_mask=attention_mask).pooler_output

                # NOTE(review): assumes both pooled outputs share a feature size,
                # which cosine similarity requires.
                interaction = torch.cosine_similarity(vision_features, text_features, dim=1)
                interaction_scores.extend(interaction.cpu().tolist())

        interaction_scores = interaction_scores[:num_samples]
        if not interaction_scores:
            raise ValueError("No samples were scored; check that the dataloader is non-empty and num_samples > 0.")

        plt.figure(figsize=(12, 4))

        plt.subplot(1, 2, 1)
        plt.hist(interaction_scores, bins=30, alpha=0.7, edgecolor='black')
        plt.xlabel('Cross-Modal Similarity')
        plt.ylabel('Frequency')
        plt.title('Distribution of Vision-Text Interactions')
        plt.axvline(np.mean(interaction_scores), color='red', linestyle='--',
                    label=f'Mean: {np.mean(interaction_scores):.3f}')
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.boxplot(interaction_scores)
        plt.ylabel('Cross-Modal Similarity')
        plt.title('Interaction Score Distribution')

        plt.tight_layout()
        plt.show()

        return {
            'mean_interaction': np.mean(interaction_scores),
            'std_interaction': np.std(interaction_scores),
            'min_interaction': np.min(interaction_scores),
            'max_interaction': np.max(interaction_scores)
        }

    def feature_importance_analysis(self, dataloader, num_samples=100):
        """Permutation importance of each modality.

        For every batch, one modality is shuffled across the batch (breaking the
        image-text pairing) and the drop in probability of the *originally
        predicted* class is recorded. Larger drops mean the model relies more on
        that modality. Batches of size 1 cannot be permuted and contribute 0.
        Batches are consumed until at least ``num_samples`` examples were seen.

        Raises:
            ValueError: if no batch was processed.
        """
        model_device = self._model_device()
        vision_importance = []
        text_importance = []
        seen = 0

        with torch.no_grad():
            for batch in dataloader:
                if seen >= num_samples:
                    break

                pixel_values = batch['pixel_values'].to(model_device)
                input_ids = batch['input_ids'].to(model_device)
                attention_mask = batch['attention_mask'].to(model_device)
                seen += pixel_values.size(0)

                original_logits = self.model(pixel_values, input_ids, attention_mask)
                predicted = original_logits.argmax(dim=1, keepdim=True)
                original_confidence = self._class_probability(original_logits, predicted)

                # Vision importance: shuffle images across the batch.
                vision_perm = torch.randperm(pixel_values.size(0), device=pixel_values.device)
                vision_logits = self.model(pixel_values[vision_perm], input_ids, attention_mask)
                vision_confidence = self._class_probability(vision_logits, predicted)
                vision_importance.append((original_confidence - vision_confidence).mean().item())

                # Text importance: ids and mask must share ONE permutation so each
                # sequence keeps its own padding mask.
                text_perm = torch.randperm(input_ids.size(0), device=input_ids.device)
                text_logits = self.model(pixel_values, input_ids[text_perm], attention_mask[text_perm])
                text_confidence = self._class_probability(text_logits, predicted)
                text_importance.append((original_confidence - text_confidence).mean().item())

        if not vision_importance:
            raise ValueError("No batches were processed; check that the dataloader is non-empty and num_samples > 0.")

        modalities = ['Vision', 'Text']
        importance_scores = [np.mean(vision_importance), np.mean(text_importance)]
        importance_std = [np.std(vision_importance), np.std(text_importance)]

        plt.figure(figsize=(10, 6))
        bars = plt.bar(modalities, importance_scores, yerr=importance_std,
                       capsize=5, color=['skyblue', 'lightcoral'])
        plt.ylabel('Importance Score (Confidence Drop)')
        plt.title('Feature Importance Analysis')

        for bar, score in zip(bars, importance_scores):
            plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                     f'{score:.3f}', ha='center', va='bottom')

        plt.tight_layout()
        plt.show()

        return {
            'vision_importance': np.mean(vision_importance),
            'text_importance': np.mean(text_importance),
            'vision_std': np.std(vision_importance),
            'text_std': np.std(text_importance)
        }

print("Multi-modal explainer ready for interpretability analysis!")

## 5. Performance Optimization and Benchmarking

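This section provides helpers for speeding up and shrinking the model (`torch.compile`, quantization, pruning, memory-saving options) and for benchmarking it (inference and training-step latency, throughput, and per-component profiling).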

In [None]:
import time
import psutil
import GPUtil
from contextlib import contextmanager

class PerformanceOptimizer:
    """Helpers that compile, quantize, prune or memory-tune a model.

    Each ``apply_*`` method replaces or mutates ``self.model``; read the
    optimised model back from ``optimizer.model``.
    """

    def __init__(self, model):
        self.model = model
        # Set by apply_torch_compile to the pre-compilation module.
        self.original_model = None

    def apply_torch_compile(self):
        """Apply PyTorch 2.0 compilation for speed optimization"""
        if hasattr(torch, 'compile'):
            print("Applying torch.compile optimization...")
            self.original_model = self.model
            self.model = torch.compile(self.model, mode='max-autotune')
            print("✅ Torch compilation applied")
        else:
            print("⚠️ torch.compile not available in this PyTorch version")

    def apply_quantization(self, quantization_type='dynamic', calibration_fn=None):
        """Quantize ``self.model``.

        Args:
            quantization_type: ``'dynamic'`` replaces ``nn.Linear`` layers with
                int8-weight versions in a returned copy (``quantize_dynamic``), or
                ``'static'`` for post-training static quantization (in place).
            calibration_fn: required for ``'static'``. Called once with the
                prepared model; it must run representative inputs through it so
                the inserted observers record activation ranges before conversion.
                NOTE(review): static quantization also needs QuantStub/DeQuantStub
                at the model's float/int8 boundaries, which is not visible here.

        Raises:
            ValueError: unknown ``quantization_type``, or ``'static'`` without
                ``calibration_fn`` (converting uncalibrated observers yields
                meaningless scales).
        """
        if quantization_type not in ('dynamic', 'static'):
            raise ValueError(
                f"Unsupported quantization_type {quantization_type!r}; expected 'dynamic' or 'static'")
        if quantization_type == 'static' and calibration_fn is None:
            raise ValueError(
                "Static quantization requires calibration_fn to feed representative data "
                "through the prepared model before conversion")

        print(f"Applying {quantization_type} quantization...")

        if quantization_type == 'dynamic':
            self.model = torch.quantization.quantize_dynamic(
                self.model, {torch.nn.Linear}, dtype=torch.qint8
            )
        else:
            # Post-training quantization: observers should see inference-mode activations.
            self.model.eval()
            self.model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
            torch.quantization.prepare(self.model, inplace=True)
            with torch.no_grad():
                calibration_fn(self.model)
            torch.quantization.convert(self.model, inplace=True)

        print(f"✅ {quantization_type.capitalize()} quantization applied")

    def apply_pruning(self, sparsity=0.3):
        """Apply per-layer L1 *unstructured* pruning to every ``nn.Linear`` weight.

        In each Linear layer the ``sparsity`` fraction of weights with the
        smallest magnitude is zeroed, then ``prune.remove`` makes it permanent so
        no mask re-parametrisation is left behind. Biases are untouched. The
        weights remain dense tensors (containing zeros).
        """
        print(f"Applying pruning with {sparsity*100}% sparsity...")

        import torch.nn.utils.prune as prune

        for module in self.model.modules():
            if isinstance(module, torch.nn.Linear):
                prune.l1_unstructured(module, name='weight', amount=sparsity)
                prune.remove(module, 'weight')

        print(f"✅ Pruning applied with {sparsity*100}% sparsity")

    def _enable_on_submodules(self, method_name):
        """Call ``method_name()`` on the outermost submodules that define it.

        The wrapper model itself may not expose such helpers while its encoders
        do, so the whole module tree is searched. Descendants of a module that
        was already enabled are skipped (assumes such helpers configure their own
        children — TODO confirm per library).

        Returns:
            ``(number_enabled, error_messages)``.
        """
        enabled_prefixes = []
        errors = []
        for name, module in self.model.named_modules():
            method = getattr(module, method_name, None)
            if not callable(method):
                continue
            if any(name.startswith(prefix) for prefix in enabled_prefixes):
                continue
            try:
                method()
            except Exception as exc:  # optional backends (e.g. xformers) may be missing
                errors.append(f"{name or '<root>'}: {exc}")
            else:
                enabled_prefixes.append(f"{name}." if name else "")
        return len(enabled_prefixes), errors

    def optimize_memory(self):
        """Enable gradient checkpointing / attention slicing / xformers where supported.

        Each feature is looked up on every submodule; successes and failures are
        reported instead of being silently swallowed.
        """
        print("Applying memory optimizations...")

        features = (
            ('gradient_checkpointing_enable', 'Gradient checkpointing'),
            ('enable_attention_slicing', 'Attention slicing'),
            ('enable_xformers_memory_efficient_attention', 'XFormers memory efficient attention'),
        )
        for method_name, label in features:
            enabled, errors = self._enable_on_submodules(method_name)
            if enabled:
                print(f"✅ {label} enabled on {enabled} module(s)")
            for error in errors:
                print(f"⚠️ {label} not available: {error}")

        print("✅ Memory optimizations applied")

class PerformanceBenchmark:
    """Latency/memory benchmarking and per-component profiling.

    Every ``measure_time`` block is stored in ``self.results[name]`` with
    ``execution_time`` (seconds), ``memory_before`` / ``memory_after`` snapshots
    (see ``get_memory_usage``) and ``memory_delta`` (change in ``gpu_used``, MB;
    0 without CUDA).
    """

    def __init__(self):
        self.results = {}

    @staticmethod
    def _model_device(model):
        """Device of the model's first parameter (CPU for a parameterless model)."""
        first_param = next(model.parameters(), None)
        return first_param.device if first_param is not None else torch.device('cpu')

    @staticmethod
    def _synchronize():
        """Wait for queued CUDA work; kernel launches are asynchronous, so
        wall-clock timing without this measures only the launch overhead."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    @contextmanager
    def measure_time(self, operation_name):
        """Record wall-clock time and memory around the ``with`` body.

        Memory snapshots are taken outside the timed window (the GPUtil path
        queries ``nvidia-smi``, which would otherwise inflate the timing), and
        CUDA is synchronised on both edges of the window.
        """
        start_memory = self.get_memory_usage()
        self._synchronize()
        start_time = time.perf_counter()

        yield

        self._synchronize()
        end_time = time.perf_counter()
        end_memory = self.get_memory_usage()

        self.results[operation_name] = {
            'execution_time': end_time - start_time,
            'memory_before': start_memory,
            'memory_after': end_memory,
            # 'gpu_used' is present only when CUDA is available.
            'memory_delta': end_memory.get('gpu_used', 0) - start_memory.get('gpu_used', 0),
        }

    def get_memory_usage(self):
        """Snapshot of host and GPU memory.

        Keys: ``cpu_percent``, ``cpu_used_gb`` and, with CUDA, ``gpu_used`` and
        ``gpu_total`` (both MB) plus ``gpu_percent``. GPUtil reports usage of its
        first listed GPU; the fallback reports this process's allocated tensor
        memory on the current CUDA device, converted to MB so both paths share a
        unit.
        """
        virtual_memory = psutil.virtual_memory()
        memory_info = {
            'cpu_percent': virtual_memory.percent,
            'cpu_used_gb': virtual_memory.used / (1024**3)
        }

        if torch.cuda.is_available():
            gpus = GPUtil.getGPUs()
            if gpus:
                gpu = gpus[0]
                memory_info.update({
                    'gpu_used': gpu.memoryUsed,
                    'gpu_total': gpu.memoryTotal,
                    'gpu_percent': (gpu.memoryUsed / gpu.memoryTotal) * 100
                })
            else:
                current = torch.cuda.current_device()
                allocated = torch.cuda.memory_allocated(current)
                total = torch.cuda.get_device_properties(current).total_memory
                memory_info.update({
                    'gpu_used': allocated / (1024**2),
                    'gpu_total': total / (1024**2),
                    'gpu_percent': (allocated / total) * 100
                })

        return memory_info

    def benchmark_inference(self, model, dataloader, num_batches=10, warmup_batches=0):
        """Time forward passes of ``model`` over up to ``num_batches`` batches.

        Args:
            model: called as ``model(pixel_values, input_ids, attention_mask)``.
            dataloader: iterable of dict batches with those three keys.
            num_batches: number of timed batches.
            warmup_batches: batches run untimed first, to exclude one-off costs
                such as ``torch.compile`` compilation. Default 0.

        Returns:
            Mean/std/min/max per-batch time (s) and ``throughput_samples_per_second``
            computed from the samples actually timed (handles a short last batch).

        Raises:
            ValueError: if no batch was timed.
        """
        model.eval()
        model_device = self._model_device(model)
        inference_times = []
        timed_samples = 0

        with torch.no_grad():
            for i, batch in enumerate(dataloader):
                if i >= warmup_batches + num_batches:
                    break

                pixel_values = batch['pixel_values'].to(model_device)
                input_ids = batch['input_ids'].to(model_device)
                attention_mask = batch['attention_mask'].to(model_device)

                if i < warmup_batches:
                    model(pixel_values, input_ids, attention_mask)
                    continue

                key = f'inference_batch_{i - warmup_batches}'
                with self.measure_time(key):
                    model(pixel_values, input_ids, attention_mask)

                inference_times.append(self.results[key]['execution_time'])
                timed_samples += pixel_values.size(0)

        if not inference_times:
            raise ValueError("No batches were timed; the dataloader is empty or num_batches <= 0.")

        return {
            'mean_inference_time': np.mean(inference_times),
            'std_inference_time': np.std(inference_times),
            'min_inference_time': np.min(inference_times),
            'max_inference_time': np.max(inference_times),
            'throughput_samples_per_second': timed_samples / np.sum(inference_times)
        }

    def benchmark_training_step(self, model, dataloader, optimizer, criterion, num_steps=5):
        """Time full training steps (zero_grad, forward, loss, backward, step).

        Note: this really updates ``model`` through ``optimizer``.

        Raises:
            ValueError: if no step was timed.
        """
        model.train()
        model_device = self._model_device(model)
        training_times = []

        for i, batch in enumerate(dataloader):
            if i >= num_steps:
                break

            pixel_values = batch['pixel_values'].to(model_device)
            input_ids = batch['input_ids'].to(model_device)
            attention_mask = batch['attention_mask'].to(model_device)
            labels = batch['labels'].to(model_device)

            with self.measure_time(f'training_step_{i}'):
                optimizer.zero_grad()
                logits = model(pixel_values, input_ids, attention_mask)
                loss = criterion(logits, labels)
                loss.backward()
                optimizer.step()

            training_times.append(self.results[f'training_step_{i}']['execution_time'])

        if not training_times:
            raise ValueError("No training steps were timed; the dataloader is empty or num_steps <= 0.")

        return {
            'mean_training_time': np.mean(training_times),
            'std_training_time': np.std(training_times),
            'min_training_time': np.min(training_times),
            'max_training_time': np.max(training_times)
        }

    def compare_optimizations(self, original_model, optimized_models, dataloader):
        """Benchmark the original model and each entry of ``optimized_models`` (name -> model), then plot."""
        results = {}

        print("Benchmarking original model...")
        results['original'] = self.benchmark_inference(original_model, dataloader)

        for name, model in optimized_models.items():
            print(f"Benchmarking {name} model...")
            results[name] = self.benchmark_inference(model, dataloader)

        self.visualize_benchmark_results(results)

        return results

    def visualize_benchmark_results(self, results):
        """Bar charts of mean inference time and throughput, plus speedups vs 'original'."""
        models = list(results.keys())
        inference_times = [results[model]['mean_inference_time'] for model in models]
        throughputs = [results[model]['throughput_samples_per_second'] for model in models]

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

        bars1 = ax1.bar(models, inference_times, color=['skyblue', 'lightcoral', 'lightgreen', 'orange'][:len(models)])
        ax1.set_ylabel('Inference Time (seconds)')
        ax1.set_title('Inference Time Comparison')
        ax1.tick_params(axis='x', rotation=45)

        for bar, time_val in zip(bars1, inference_times):
            ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.001,
                     f'{time_val:.3f}s', ha='center', va='bottom')

        bars2 = ax2.bar(models, throughputs, color=['skyblue', 'lightcoral', 'lightgreen', 'orange'][:len(models)])
        ax2.set_ylabel('Throughput (samples/second)')
        ax2.set_title('Throughput Comparison')
        ax2.tick_params(axis='x', rotation=45)

        for bar, throughput in zip(bars2, throughputs):
            ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                     f'{throughput:.1f}', ha='center', va='bottom')

        plt.tight_layout()
        plt.show()

        if 'original' in results:
            original_time = results['original']['mean_inference_time']
            print("\nSpeedup Analysis:")
            print("-" * 40)
            for model in models:
                if model != 'original':
                    speedup = original_time / results[model]['mean_inference_time']
                    print(f"{model}: {speedup:.2f}x speedup")

    def profile_model_components(self, model, sample_input):
        """Time the encoder, cross-attention and fusion/classifier stages once.

        ``sample_input`` is ``(pixel_values, input_ids, attention_mask)``, already on
        the model's device. The fusion stage re-implements the model's forward pass
        from its submodules (``vision_proj``, ``early_fusion``, ``attention_pool``,
        ...), so it must be kept in sync with the model definition. Runs in eval
        mode without autograd, like ``benchmark_inference``.
        """
        pixel_values, input_ids, attention_mask = sample_input
        model.eval()

        with torch.no_grad():
            with self.measure_time('vision_encoder'):
                vision_output = model.vision_encoder(pixel_values=pixel_values)

            with self.measure_time('text_encoder'):
                text_output = model.text_encoder(input_ids=input_ids, attention_mask=attention_mask)

            vision_features = vision_output.pooler_output
            text_features = text_output.pooler_output

            with self.measure_time('cross_attention'):
                for cross_attn in model.cross_attention_layers:
                    vision_features, text_features = cross_attn(vision_features, text_features)

            # The projections belong to the fusion head, so they are timed with it.
            with self.measure_time('classifier'):
                vision_proj = model.vision_proj(vision_features)
                text_proj = model.text_proj(text_features)
                early_fused = torch.tanh(model.early_fusion(torch.cat([vision_proj, text_proj], dim=-1)))
                mid_fused = torch.tanh(model.mid_fusion(torch.cat([vision_proj * text_proj, vision_proj + text_proj], dim=-1)))
                late_fused = torch.tanh(model.late_fusion(torch.cat([early_fused, mid_fused], dim=-1)))

                fusion_stack = torch.stack([early_fused, mid_fused, late_fused], dim=1)
                attended_features, _ = model.attention_pool(fusion_stack, fusion_stack, fusion_stack)
                final_features = attended_features.mean(dim=1)
                model.classifier(final_features)

        components = ['vision_encoder', 'text_encoder', 'cross_attention', 'classifier']
        times = [self.results[comp]['execution_time'] for comp in components]

        plt.figure(figsize=(10, 6))
        bars = plt.bar(components, times, color=['skyblue', 'lightcoral', 'lightgreen', 'orange'])
        plt.ylabel('Execution Time (seconds)')
        plt.title('Model Component Profiling')
        plt.xticks(rotation=45)

        total_time = sum(times)
        for bar, time_val in zip(bars, times):
            percentage = (time_val / total_time) * 100 if total_time > 0 else 0.0
            plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.001,
                     f'{time_val:.3f}s\n({percentage:.1f}%)', ha='center', va='bottom')

        plt.tight_layout()
        plt.show()

        return {comp: self.results[comp] for comp in components}

print("Performance optimization and benchmarking tools ready!")

## 6. Research Experiments and Future Directions

A recap of what this notebook covered, followed by open research directions and suggested next steps.

In [None]:
# Closing summary: recap of the notebook, open research directions, and next steps.
# Joining with "\n" and printing once yields exactly the same text as one print per line.
print("\n".join([
    "🎉 Advanced Multi-Modal AI Research Notebook Complete!",
    "\n📊 What we've covered:",
    "✅ Advanced multi-modal architecture with cross-attention",
    "✅ Cutting-edge training techniques (contrastive learning, curriculum learning)",
    "✅ Comprehensive data analysis and visualization",
    "✅ Model interpretability and explainability",
    "✅ Performance optimization and benchmarking",
    "\n🔬 Research Directions:",
    "• Multi-modal foundation models",
    "• Few-shot and zero-shot learning",
    "• Efficient attention mechanisms",
    "• Cross-modal knowledge distillation",
    "• Federated multi-modal learning",
    "\n🚀 Next Steps:",
    "1. Implement custom datasets for your specific use case",
    "2. Experiment with different fusion strategies",
    "3. Apply advanced optimization techniques",
    "4. Conduct thorough evaluation and analysis",
    "5. Deploy optimized models to production",
    "\n💡 Happy researching with Multi-Modal AI!",
]))