# 🧪 Model Experiments & Ablation Studies

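This notebook contains experiments with different model architectures, hyperparameters, and ablation studies for the multimodal pill recognition system. Note: apart from parameter counts and the dummy forward passes, the accuracy, loss, and timing figures reported below are simulated placeholders, not results from actual training runs.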

## 🎯 Objectives
- Compare different encoder architectures
- Test various fusion mechanisms
- Hyperparameter optimization
- Ablation studies on multimodal components

In [None]:
import copy
import json
import os
import sys
import warnings
import zlib

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
import torch
import torch.nn as nn
import yaml
from plotly.subplots import make_subplots
from sklearn.metrics import accuracy_score, classification_report
warnings.filterwarnings('ignore')

# Add src to path
sys.path.append('../src')
from models.multimodal_transformer import MultimodalPillTransformer
from training.trainer import MultimodalTrainer
from utils.metrics import MetricsCalculator
from utils.utils import set_seed, get_device

# Resolve the compute device first, then seed via the project's set_seed helper
device = get_device()
set_seed(42)

status_lines = (
    f"🔧 Using device: {device}",
    "📦 All packages imported successfully!",
)
print(*status_lines, sep="\n")

## ⚙️ Experiment Configuration

In [None]:
# Load the base configuration that every experiment below starts from.
# Decode explicitly as UTF-8 so parsing does not depend on the platform's
# locale encoding (the default for open() without encoding=).
with open('../config/config.yaml', 'r', encoding='utf-8') as f:
    base_config = yaml.safe_load(f)

print("📋 Base configuration loaded")
print(f"Base model: {base_config['model']['name']}")
print(f"Visual encoder: {base_config['model']['visual_encoder']['model_name']}")
print(f"Text encoder: {base_config['model']['text_encoder']['model_name']}")
print(f"Fusion type: {base_config['model']['fusion']['type']}")

## 🏗️ Experiment 1: Visual Encoder Comparison

In [None]:
# Define visual encoder variants to test
visual_encoders = {
    "ViT-Base/16": {
        "model_name": "vit_base_patch16_224",
        "description": "Vision Transformer Base with 16x16 patches"
    },
    "ViT-Small/16": {
        "model_name": "vit_small_patch16_224",
        "description": "Vision Transformer Small with 16x16 patches"
    },
    "ResNet-50": {
        "model_name": "resnet50",
        "description": "ResNet-50 CNN architecture"
    },
    "EfficientNet-B3": {
        "model_name": "efficientnet_b3",
        "description": "EfficientNet-B3 architecture"
    }
}

print("🧪 Visual Encoder Experiments:")
visual_results = {}

for name, config in visual_encoders.items():
    print(f"\n🔍 Testing {name}: {config['description']}")

    # Deep-copy: base_config.copy() is shallow and shares the nested 'model'
    # dict, so the override below would be written back into base_config and
    # every later experiment would silently inherit the last encoder tried.
    exp_config = copy.deepcopy(base_config)
    exp_config['model']['visual_encoder']['model_name'] = config['model_name']

    try:
        # Initialize model (simplified for demonstration)
        model = MultimodalPillTransformer(exp_config['model'])
        model.to(device)

        # Count parameters
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

        # Simulated performance (in a real scenario, train and evaluate here).
        # The seed comes from a stable CRC32 of the name because hash(str) is
        # salted per interpreter process. A local Generator leaves NumPy's
        # global RNG state untouched.
        rng = np.random.default_rng(zlib.crc32(name.encode('utf-8')))
        simulated_accuracy = rng.uniform(0.85, 0.95)
        simulated_inference_time = rng.uniform(0.05, 0.2)

        visual_results[name] = {
            "total_params": total_params,
            "trainable_params": trainable_params,
            "accuracy": simulated_accuracy,
            "inference_time": simulated_inference_time,
            "description": config['description']
        }

        print(f"  ✅ Parameters: {total_params:,} ({trainable_params:,} trainable)")
        print(f"  📊 Simulated Accuracy: {simulated_accuracy:.3f}")
        print(f"  ⚡ Simulated Inference Time: {simulated_inference_time:.3f}s")

    except Exception as e:
        # Best-effort sweep: record the failure and continue with the next encoder
        print(f"  ❌ Error: {e}")
        visual_results[name] = {"error": str(e)}

print("\n📊 Visual Encoder Comparison Complete!")

In [None]:
# Plot and tabulate the visual-encoder comparison, skipping encoders that
# failed to build (those entries carry an 'error' key instead of metrics).
if not visual_results:
    print("❌ No results available")
else:
    valid_results = {
        enc: res for enc, res in visual_results.items() if 'error' not in res
    }

    if not valid_results:
        print("❌ No valid results to visualize")
    else:
        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=["Model Size (Parameters)", "Accuracy Comparison",
                            "Inference Time", "Accuracy vs Inference Time"]
        )

        models = list(valid_results)
        params, accuracies, times = (
            [valid_results[m][metric] for m in models]
            for metric in ('total_params', 'accuracy', 'inference_time')
        )

        # One bar chart per metric, placed in grid order
        bar_panels = [
            (params, "Parameters", 1, 1),
            (accuracies, "Accuracy", 1, 2),
            (times, "Inference Time", 2, 1),
        ]
        for panel_values, panel_name, panel_row, panel_col in bar_panels:
            fig.add_trace(
                go.Bar(x=models, y=panel_values, name=panel_name),
                row=panel_row, col=panel_col
            )

        # Trade-off view: each encoder as a labelled point
        trade_off = go.Scatter(
            x=times, y=accuracies, text=models, mode='markers+text',
            textposition="top center", name="Models"
        )
        fig.add_trace(trade_off, row=2, col=2)

        fig.update_layout(height=800, title_text="🏗️ Visual Encoder Comparison")
        fig.show()

        # Summary table, one row per encoder
        summary_columns = ['total_params', 'accuracy', 'inference_time', 'description']
        results_df = pd.DataFrame(valid_results).T[summary_columns]
        results_df.columns = ['Parameters', 'Accuracy', 'Inference Time (s)', 'Description']
        print("\n📋 Visual Encoder Results Summary:")
        print(results_df)

## 🔗 Experiment 2: Fusion Mechanism Comparison

In [None]:
# Define fusion mechanisms to test
fusion_mechanisms = {
    "Cross-Attention": {
        "type": "cross_attention",
        "description": "Multi-head cross-modal attention"
    },
    "Concatenation": {
        "type": "concat",
        "description": "Simple feature concatenation"
    },
    "Bilinear": {
        "type": "bilinear",
        "description": "Bilinear pooling fusion"
    }
}

print("🧪 Fusion Mechanism Experiments:")
fusion_results = {}

for name, config in fusion_mechanisms.items():
    print(f"\n🔗 Testing {name}: {config['description']}")

    # Deep-copy: base_config.copy() is shallow and shares the nested 'model'
    # dict, so the fusion override would be written back into base_config.
    exp_config = copy.deepcopy(base_config)
    exp_config['model']['fusion']['type'] = config['type']

    try:
        # Initialize model
        model = MultimodalPillTransformer(exp_config['model'])
        model.to(device)
        # Inference mode for the smoke-test forward pass: disables dropout and
        # freezes running statistics of normalization layers such as BatchNorm.
        model.eval()

        # Forward pass with dummy data: 224x224 RGB images and 128-token
        # sequences. Token ids are drawn from [0, 1000), which is assumed to be
        # within the text encoder's vocabulary (TODO confirm).
        batch_size = 4
        seq_len = 128
        dummy_images = torch.randn(batch_size, 3, 224, 224).to(device)
        dummy_input_ids = torch.randint(0, 1000, (batch_size, seq_len)).to(device)
        dummy_attention_mask = torch.ones(batch_size, seq_len).to(device)

        with torch.no_grad():
            outputs = model(dummy_images, dummy_input_ids, dummy_attention_mask)

        # Count parameters
        total_params = sum(p.numel() for p in model.parameters())

        # Simulated training metrics (placeholders). The seed comes from a
        # stable CRC32 of the name because hash(str) is salted per process.
        # A local Generator leaves NumPy's global RNG alone.
        rng = np.random.default_rng(zlib.crc32(name.encode('utf-8')))
        simulated_accuracy = rng.uniform(0.80, 0.95)
        simulated_loss = rng.uniform(0.1, 0.5)
        simulated_training_time = rng.uniform(100, 300)  # seconds per epoch

        fusion_results[name] = {
            "total_params": total_params,
            "accuracy": simulated_accuracy,
            "loss": simulated_loss,
            "training_time_per_epoch": simulated_training_time,
            "output_shape": list(outputs['logits'].shape),
            "description": config['description']
        }

        print(f"  ✅ Parameters: {total_params:,}")
        print(f"  📊 Simulated Accuracy: {simulated_accuracy:.3f}")
        print(f"  📉 Simulated Loss: {simulated_loss:.3f}")
        print(f"  ⏱️ Training Time/Epoch: {simulated_training_time:.1f}s")
        print(f"  📐 Output Shape: {outputs['logits'].shape}")

    except Exception as e:
        # Best-effort sweep: record the failure and continue with the next mechanism
        print(f"  ❌ Error: {e}")
        fusion_results[name] = {"error": str(e)}

print("\n📊 Fusion Mechanism Comparison Complete!")

In [None]:
# Visualize fusion mechanism results as a radar chart plus a summary table.


def _minmax_normalize(values):
    """Linearly rescale *values* to [0, 1] (minimum -> 0, maximum -> 1).

    When all values are equal (e.g. only one mechanism produced results) the
    range is zero. Every entry then maps to 1.0 instead of raising
    ZeroDivisionError.
    """
    lo, hi = min(values), max(values)
    if hi == lo:
        return [1.0] * len(values)
    span = hi - lo
    return [(v - lo) / span for v in values]


if fusion_results:
    valid_fusion_results = {k: v for k, v in fusion_results.items() if 'error' not in v}

    if valid_fusion_results:
        mechanisms = list(valid_fusion_results.keys())

        # Orient every metric so that higher is better before normalizing
        accuracies = [valid_fusion_results[m]['accuracy'] for m in mechanisms]
        losses = [1 / (valid_fusion_results[m]['loss'] + 0.001) for m in mechanisms]  # Inverse loss
        speeds = [1 / (valid_fusion_results[m]['training_time_per_epoch'] / 60) for m in mechanisms]  # Epochs per minute

        # Normalize to 0-1 scale (safe when all values tie)
        accuracies_norm = _minmax_normalize(accuracies)
        losses_norm = _minmax_normalize(losses)
        speeds_norm = _minmax_normalize(speeds)

        # The first category is repeated at the end to close each polygon
        categories = ['Accuracy', 'Loss (Inv)', 'Speed (Inv)', 'Accuracy']

        fig = go.Figure()

        for i, mechanism in enumerate(mechanisms):
            values = [accuracies_norm[i], losses_norm[i], speeds_norm[i], accuracies_norm[i]]
            fig.add_trace(go.Scatterpolar(
                r=values,
                theta=categories,
                fill='toself',
                name=mechanism
            ))

        fig.update_layout(
            polar=dict(
                radialaxis=dict(
                    visible=True,
                    range=[0, 1]
                )
            ),
            title="🔗 Fusion Mechanism Performance Comparison",
            showlegend=True
        )
        fig.show()

        # Results table
        fusion_df = pd.DataFrame(valid_fusion_results).T
        fusion_df = fusion_df[['total_params', 'accuracy', 'loss', 'training_time_per_epoch', 'description']]
        fusion_df.columns = ['Parameters', 'Accuracy', 'Loss', 'Training Time/Epoch (s)', 'Description']
        print("\n📋 Fusion Mechanism Results Summary:")
        print(fusion_df)
    else:
        print("❌ No valid fusion results to visualize")
else:
    print("❌ No fusion results available")

## ⚙️ Experiment 3: Hyperparameter Optimization

In [None]:
# Define hyperparameter search space
hyperparameter_space = {
    "learning_rate": [1e-5, 5e-5, 1e-4, 5e-4, 1e-3],
    "batch_size": [16, 32, 64, 128],
    "dropout": [0.1, 0.2, 0.3, 0.4],
    "weight_decay": [1e-6, 1e-5, 1e-4, 1e-3],
    "hidden_dim": [256, 512, 768, 1024]
}

print("🧪 Hyperparameter Optimization Experiment:")
print("Note: This is a simplified demonstration. In practice, use proper validation sets.")

# Simulated random search. A dedicated, seeded Generator keeps the draws
# reproducible without reseeding NumPy's global RNG.
rng = np.random.default_rng(42)
num_trials = 20
hyperparameter_results = []

for trial in range(num_trials):
    # Sample hyperparameters. Cast to plain Python scalars so the results
    # table and the JSON summary hold ints/floats rather than NumPy scalars.
    lr = float(rng.choice(hyperparameter_space["learning_rate"]))
    batch_size = int(rng.choice(hyperparameter_space["batch_size"]))
    dropout = float(rng.choice(hyperparameter_space["dropout"]))
    weight_decay = float(rng.choice(hyperparameter_space["weight_decay"]))
    hidden_dim = int(rng.choice(hyperparameter_space["hidden_dim"]))

    # Simulated accuracy: hand-picked effects for lr, dropout and batch size,
    # plus Gaussian noise (weight_decay and hidden_dim do not affect it here).
    base_accuracy = 0.85

    # Learning rate effects
    if lr < 1e-4:
        lr_bonus = 0.05  # Sweet spot
    elif lr > 5e-4:
        lr_bonus = -0.03  # Too high
    else:
        lr_bonus = 0.02

    # Dropout effects
    dropout_bonus = 0.03 if 0.1 <= dropout <= 0.3 else -0.02

    # Batch size effects
    batch_bonus = 0.02 if batch_size in (32, 64) else -0.01

    noise = rng.normal(0, 0.02)

    simulated_accuracy = base_accuracy + lr_bonus + dropout_bonus + batch_bonus + noise
    simulated_accuracy = float(np.clip(simulated_accuracy, 0.7, 0.98))  # Realistic bounds

    # Simulated training time: grows with hidden_dim, shrinks with batch size
    base_time = 200
    time_factor = (hidden_dim / 512) * (64 / batch_size)
    simulated_time = base_time * time_factor + rng.normal(0, 20)

    hyperparameter_results.append({
        "trial": trial + 1,
        "learning_rate": lr,
        "batch_size": batch_size,
        "dropout": dropout,
        "weight_decay": weight_decay,
        "hidden_dim": hidden_dim,
        "accuracy": simulated_accuracy,
        "training_time": max(float(simulated_time), 50.0)  # Minimum time
    })

# Convert to DataFrame
hp_df = pd.DataFrame(hyperparameter_results)

# Find the best configuration. Taking a row of an all-numeric frame upcasts
# every value to float64, so "Trial 7" would print as "Trial 7.0". Going
# through object dtype keeps each column's native int/float values.
best_idx = hp_df['accuracy'].idxmax()
best_trial = hp_df.astype(object).loc[best_idx]

print(f"\n🏆 Best Configuration (Trial {best_trial['trial']}):")
print(f"  Learning Rate: {best_trial['learning_rate']}")
print(f"  Batch Size: {best_trial['batch_size']}")
print(f"  Dropout: {best_trial['dropout']}")
print(f"  Weight Decay: {best_trial['weight_decay']}")
print(f"  Hidden Dim: {best_trial['hidden_dim']}")
print(f"  Accuracy: {best_trial['accuracy']:.4f}")
print(f"  Training Time: {best_trial['training_time']:.1f}s")

print("\n📊 Hyperparameter Search Complete!")

In [None]:
# Visualize hyperparameter optimization results: one accuracy scatter per
# hyperparameter plus accuracy vs. training time, with the best trial starred.
fig = make_subplots(
    rows=2, cols=3,
    subplot_titles=["Learning Rate vs Accuracy", "Batch Size vs Accuracy", "Dropout vs Accuracy",
                    "Weight Decay vs Accuracy", "Hidden Dim vs Accuracy", "Accuracy vs Training Time"]
)

# (x-axis column, trace name, subplot row, subplot col)
hp_panels = [
    ('learning_rate', 'LR vs Acc', 1, 1),
    ('batch_size', 'BS vs Acc', 1, 2),
    ('dropout', 'Dropout vs Acc', 1, 3),
    ('weight_decay', 'WD vs Acc', 2, 1),
    ('hidden_dim', 'HD vs Acc', 2, 2),
    ('training_time', 'Time vs Acc', 2, 3),
]

# idxmax returns an index label, so select the row with .loc
best_idx = hp_df['accuracy'].idxmax()
best_row = hp_df.loc[best_idx]

for column, trace_name, panel_row, panel_col in hp_panels:
    fig.add_trace(
        go.Scatter(x=hp_df[column], y=hp_df['accuracy'], mode='markers',
                   name=trace_name, text=hp_df['trial']),
        row=panel_row, col=panel_col
    )
    # Highlight the best trial in every panel
    fig.add_trace(
        go.Scatter(x=[best_row[column]], y=[best_row['accuracy']], mode='markers',
                   marker=dict(color='red', size=12, symbol='star'), name='Best',
                   showlegend=False),
        row=panel_row, col=panel_col
    )

# Learning rate and weight decay span several orders of magnitude
fig.update_xaxes(type='log', row=1, col=1)
fig.update_xaxes(type='log', row=2, col=1)

fig.update_layout(height=800, title_text="⚙️ Hyperparameter Optimization Results")
fig.show()

# Show top 5 configurations
print("\n🏆 Top 5 Configurations:")
top_5 = hp_df.nlargest(5, 'accuracy')[['trial', 'learning_rate', 'batch_size',
                                       'dropout', 'weight_decay', 'hidden_dim', 'accuracy']]
print(top_5.to_string(index=False))

## 🔬 Experiment 4: Ablation Study

In [None]:
# Define ablation study configurations
ablation_configs = {
    "Full Model": {
        "visual": True,
        "text": True,
        "fusion": "cross_attention",
        "description": "Complete multimodal model"
    },
    "Visual Only": {
        "visual": True,
        "text": False,
        "fusion": None,
        "description": "Only visual encoder"
    },
    "Text Only": {
        "visual": False,
        "text": True,
        "fusion": None,
        "description": "Only text encoder"
    },
    "Simple Fusion": {
        "visual": True,
        "text": True,
        "fusion": "concat",
        "description": "Multimodal with concatenation"
    },
    "No Pretrained Visual": {
        "visual": True,
        "text": True,
        "fusion": "cross_attention",
        "visual_pretrained": False,
        "description": "Cross-attention without pretrained visual encoder"
    },
    "No Pretrained Text": {
        "visual": True,
        "text": True,
        "fusion": "cross_attention",
        "text_pretrained": False,
        "description": "Cross-attention without pretrained text encoder"
    }
}

print("🔬 Ablation Study:")
ablation_results = {}

for name, config in ablation_configs.items():
    print(f"\n🧪 Testing {name}: {config['description']}")

    # Simulated metrics (placeholders). The local Generator is seeded from a
    # stable CRC32 of the name, which keeps results reproducible across
    # interpreter runs (hash(str) is salted per process). It also avoids
    # reseeding NumPy's global RNG.
    rng = np.random.default_rng(zlib.crc32(name.encode('utf-8')))

    # Hand-picked base accuracy per modality / fusion combination
    if config["visual"] and config["text"]:
        if config["fusion"] == "cross_attention":
            base_acc = 0.92  # Best performance
        elif config["fusion"] == "concat":
            base_acc = 0.88  # Good but not optimal
        else:
            base_acc = 0.85
    elif config["visual"] and not config["text"]:
        base_acc = 0.82  # Visual only
    elif not config["visual"] and config["text"]:
        base_acc = 0.75  # Text only (usually weaker)
    else:
        base_acc = 0.70  # Fallback

    # Penalties for encoders flagged as not pretrained (flags default to True)
    if not config.get("visual_pretrained", True):
        base_acc -= 0.08
    if not config.get("text_pretrained", True):
        base_acc -= 0.05

    # Add noise
    accuracy = base_acc + rng.normal(0, 0.01)
    accuracy = np.clip(accuracy, 0.6, 0.98)

    # Simulate other metrics
    inference_time = rng.uniform(0.05, 0.2)
    if not config["visual"]:
        inference_time *= 0.3  # Text only is faster
    elif not config["text"]:
        inference_time *= 0.7  # Visual only is medium

    # Model size simulation
    base_params = 100_000_000  # 100M base
    if not config["visual"]:
        base_params *= 0.3
    elif not config["text"]:
        base_params *= 0.7

    if config["fusion"] == "cross_attention":
        base_params *= 1.1  # Slightly more parameters

    ablation_results[name] = {
        "accuracy": accuracy,
        "inference_time": inference_time,
        "model_size": int(base_params),
        "description": config["description"],
        "config": config
    }

    print(f"  📊 Simulated Accuracy: {accuracy:.4f}")
    print(f"  ⚡ Inference Time: {inference_time:.3f}s")
    print(f"  📏 Model Size: {int(base_params):,} parameters")

print("\n📊 Ablation Study Complete!")

In [None]:
# Visualize ablation study results and print key insights
if ablation_results:
    # Create comprehensive comparison
    configs = list(ablation_results.keys())
    accuracies = [ablation_results[c]['accuracy'] for c in configs]
    times = [ablation_results[c]['inference_time'] for c in configs]
    sizes = [ablation_results[c]['model_size'] / 1e6 for c in configs]  # Convert to millions

    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=["Accuracy Comparison", "Inference Time", "Model Size (M params)", "Efficiency Plot"]
    )

    # Accuracy comparison: green = full model, red = unimodal, blue = other variants
    colors = ['green' if 'Full Model' in c else 'blue' if 'Only' not in c else 'red' for c in configs]
    fig.add_trace(
        go.Bar(x=configs, y=accuracies, marker_color=colors, name="Accuracy"),
        row=1, col=1
    )

    # Inference time
    fig.add_trace(
        go.Bar(x=configs, y=times, marker_color='orange', name="Inference Time"),
        row=1, col=2
    )

    # Model size
    fig.add_trace(
        go.Bar(x=configs, y=sizes, marker_color='purple', name="Model Size"),
        row=2, col=1
    )

    # Efficiency plot (Accuracy vs Speed, speed = 1 / inference time)
    fig.add_trace(
        go.Scatter(x=[1/t for t in times], y=accuracies, mode='markers+text',
                   text=configs, textposition="top center",
                   marker=dict(size=10, color=accuracies, colorscale='Viridis'),
                   name="Configs"),
        row=2, col=2
    )

    fig.update_layout(height=800, title_text="🔬 Ablation Study Results")
    # Tilt labels only on the categorical bar-chart axes; the efficiency plot's
    # x-axis is numeric and does not need rotated ticks.
    for axis_row, axis_col in ((1, 1), (1, 2), (2, 1)):
        fig.update_xaxes(tickangle=45, row=axis_row, col=axis_col)
    fig.show()

    # Results table
    ablation_df = pd.DataFrame(ablation_results).T
    ablation_df = ablation_df[['accuracy', 'inference_time', 'model_size', 'description']]
    ablation_df.columns = ['Accuracy', 'Inference Time (s)', 'Model Size', 'Description']
    ablation_df = ablation_df.sort_values('Accuracy', ascending=False)

    print("\n📋 Ablation Study Results (sorted by accuracy):")
    print(ablation_df.to_string())

    # Key insights
    best_config = ablation_df.index[0]
    visual_only_acc = ablation_results.get('Visual Only', {}).get('accuracy', 0)
    text_only_acc = ablation_results.get('Text Only', {}).get('accuracy', 0)
    full_model_acc = ablation_results.get('Full Model', {}).get('accuracy', 0)

    print("\n💡 Key Insights:")
    print(f"1. Best configuration: {best_config}")
    if visual_only_acc and text_only_acc and full_model_acc:
        multimodal_gain = full_model_acc - max(visual_only_acc, text_only_acc)
        # ':+' prints the actual sign, so a loss reads -0.010 rather than +-0.010
        print(f"2. Multimodal gain over best unimodal: {multimodal_gain:+.3f}")
        print(f"3. Visual-only accuracy: {visual_only_acc:.3f}")
        print(f"4. Text-only accuracy: {text_only_acc:.3f}")

    simple_fusion_acc = ablation_results.get('Simple Fusion', {}).get('accuracy', 0)
    if simple_fusion_acc and full_model_acc:
        attention_gain = full_model_acc - simple_fusion_acc
        print(f"5. Cross-attention gain over concatenation: {attention_gain:+.3f}")

else:
    print("❌ No ablation results available")

## 📋 Experiment Summary & Recommendations

In [None]:
# Compile experiment summary
experiment_summary = {
    "visual_encoders": {
        "tested": list(visual_encoders.keys()),
        "results": visual_results,
        "recommendation": "ViT-Base/16 for balance of accuracy and efficiency"
    },
    "fusion_mechanisms": {
        "tested": list(fusion_mechanisms.keys()),
        "results": fusion_results,
        "recommendation": "Cross-attention for best multimodal fusion"
    },
    "hyperparameters": {
        "best_config": best_trial.to_dict(),
        "search_space": hyperparameter_space,
        "recommendation": "Use found optimal hyperparameters for final training"
    },
    "ablation_study": {
        "configurations": list(ablation_configs.keys()),
        "results": ablation_results,
        "recommendation": "Multimodal approach with cross-attention provides best performance"
    }
}

# Generate recommendations
print("🎯 Experiment Summary & Recommendations:")
print("=" * 60)

print("\n1. 🏗️ Visual Encoder:")
if visual_results:
    valid_visual = {k: v for k, v in visual_results.items() if 'error' not in v}
    if valid_visual:
        best_visual = max(valid_visual.keys(), key=lambda x: valid_visual[x]['accuracy'])
        print(f"   Recommended: {best_visual}")
        print(f"   Accuracy: {valid_visual[best_visual]['accuracy']:.3f}")
        print(f"   Parameters: {valid_visual[best_visual]['total_params']:,}")

print("\n2. 🔗 Fusion Mechanism:")
if fusion_results:
    valid_fusion = {k: v for k, v in fusion_results.items() if 'error' not in v}
    if valid_fusion:
        best_fusion = max(valid_fusion.keys(), key=lambda x: valid_fusion[x]['accuracy'])
        print(f"   Recommended: {best_fusion}")
        print(f"   Accuracy: {valid_fusion[best_fusion]['accuracy']:.3f}")
        print(f"   Training Time: {valid_fusion[best_fusion]['training_time_per_epoch']:.1f}s/epoch")

print("\n3. ⚙️ Hyperparameters:")
print(f"   Learning Rate: {best_trial['learning_rate']}")
print(f"   Batch Size: {best_trial['batch_size']}")
print(f"   Dropout: {best_trial['dropout']}")
print(f"   Weight Decay: {best_trial['weight_decay']}")
print(f"   Hidden Dim: {best_trial['hidden_dim']}")

print("\n4. 🔬 Key Findings from Ablation:")
if ablation_results:
    visual_only_acc = ablation_results.get('Visual Only', {}).get('accuracy', 0)
    text_only_acc = ablation_results.get('Text Only', {}).get('accuracy', 0)
    full_model_acc = ablation_results.get('Full Model', {}).get('accuracy', 0)

    if all([visual_only_acc, text_only_acc, full_model_acc]):
        print(f"   Visual-only: {visual_only_acc:.3f}")
        print(f"   Text-only: {text_only_acc:.3f}")
        print(f"   Full multimodal: {full_model_acc:.3f}")
        best_unimodal_acc = max(visual_only_acc, text_only_acc)
        multimodal_gain = full_model_acc - best_unimodal_acc
        # ':+' prints the actual sign, so a loss reads -0.010 rather than +-0.010
        print(f"   Multimodal gain: {multimodal_gain:+.3f} "
              f"({multimodal_gain / best_unimodal_acc * 100:+.1f}%)")

print("\n🚀 Next Steps:")
print("1. Implement best configuration for full training")
print("2. Conduct longer training with early stopping")
print("3. Evaluate on held-out test set")
print("4. Consider ensemble methods for further improvement")
print("5. Deploy model and monitor performance")

# Save experiment results. default=str stringifies any value json cannot
# encode natively instead of raising.
results_path = '../results/model_experiments_summary.json'
os.makedirs(os.path.dirname(results_path), exist_ok=True)
with open(results_path, 'w', encoding='utf-8') as f:
    json.dump(experiment_summary, f, indent=2, default=str)

print(f"\n📄 Experiment results saved to '{results_path}'")