# Multi-modal Integration for Disease Prediction

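This notebook demonstrates how to integrate tabular and image data for unified disease prediction. It covers feature fusion, unified model building, and evaluation of the multi-modal system. For demonstration purposes the per-modality features and labels are simulated, so the reported metrics are illustrative only; replace the simulation with features extracted from your trained models to obtain meaningful results.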

In [None]:
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, Dense, Concatenate, Dropout
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
import os

# --- 1. Load Pre-trained Models ---
# Assuming you have already trained individual models using the respective notebooks

def load_pretrained_models(
    diabetes_model_path='diabetes_logistic_regression_model.pkl',
    diabetes_preprocessor_path='diabetes_preprocessor.pkl',
    pneumonia_model_path='pneumonia_detection_model.h5',
):
    """Load the individually trained per-disease models, skipping missing ones.

    Args:
        diabetes_model_path: Path to the joblib-serialized diabetes model.
        diabetes_preprocessor_path: Path to the joblib-serialized preprocessor
            that must accompany the diabetes model.
        pneumonia_model_path: Path to the saved Keras pneumonia image model.

    Returns:
        (models, preprocessors): two dicts keyed by disease name. A disease is
        registered only when all of its artifacts were found and loaded, so a
        tabular model is never present without its preprocessor.
    """
    models = {}
    preprocessors = {}

    print("Loading pre-trained models...")

    # Tabular model (example for diabetes). The model and its preprocessor are
    # only usable together, so check for both before loading either one.
    # NOTE(review): joblib.load unpickles arbitrary objects -- only point these
    # paths at artifacts you produced yourself.
    missing = [path for path in (diabetes_model_path, diabetes_preprocessor_path)
               if not os.path.exists(path)]
    if missing:
        print(f"Diabetes model not found ({', '.join(missing)}). "
              "Please train it first using the diabetes_prediction.ipynb notebook.")
    else:
        diabetes_model = joblib.load(diabetes_model_path)
        diabetes_preprocessor = joblib.load(diabetes_preprocessor_path)
        models['diabetes'] = diabetes_model
        preprocessors['diabetes'] = diabetes_preprocessor
        print("Diabetes model loaded.")

    # Image model (example for pneumonia). Distinguish "file absent" from
    # "file present but unloadable" (corrupt file, Keras/TF version mismatch).
    if not os.path.exists(pneumonia_model_path):
        print(f"Pneumonia image model not found: {pneumonia_model_path}. "
              "Please train it first using the chest_xray_pneumonia_detection.ipynb notebook.")
    else:
        try:
            models['pneumonia'] = load_model(pneumonia_model_path)
        except Exception as e:  # best-effort: report and continue without it
            print(f"Pneumonia image model at {pneumonia_model_path} could not be loaded: {e}")
        else:
            print("Pneumonia image model loaded.")

    return models, preprocessors

# Load whatever pre-trained per-disease artifacts are available; missing ones
# are reported and skipped rather than raising.
# NOTE(review): `models` and `preprocessors` are not referenced by the
# simulated-data cells below -- they only matter once real extracted features
# replace the simulation.
models, preprocessors = load_pretrained_models()

# --- 2. Simulate Multi-modal Data ---
# In a real scenario, you would load your actual datasets
# For demonstration, we'll create synthetic data that represents features extracted from pre-trained models

def simulate_multimodal_data(num_samples=1000, tabular_dim=64, image_dim=128,
                             random_state=None):
    """Generate synthetic per-modality feature matrices and binary labels.

    The features stand in for embeddings extracted by pre-trained tabular and
    image models. They are uniform noise in [0, 1), and the labels are drawn
    independently of them, so no model trained on this data can be expected to
    beat chance.

    Args:
        num_samples: Number of rows (patients) to generate.
        tabular_dim: Width of the simulated tabular feature vector.
        image_dim: Width of the simulated image feature vector.
        random_state: Optional integer seed for reproducible data. When None,
            NumPy's global random state is used (the original behaviour).

    Returns:
        (tabular_features, image_features, unified_target) with shapes
        (num_samples, tabular_dim), (num_samples, image_dim) and
        (num_samples,); the target holds 0/1 integers.
    """
    print("Simulating multi-modal data...")

    # The np.random module and RandomState expose the same rand/randint API,
    # so the seeded and unseeded paths draw values in the same order.
    rng = np.random if random_state is None else np.random.RandomState(random_state)

    # Simulated feature vectors from each modality's feature extractor.
    tabular_features = rng.rand(num_samples, tabular_dim)
    image_features = rng.rand(num_samples, image_dim)

    # In reality this would be the ground-truth diagnosis for each patient.
    unified_target = rng.randint(0, 2, num_samples)

    return tabular_features, image_features, unified_target

tabular_features, image_features, unified_target = simulate_multimodal_data()

# Report the shape of every simulated array.
for _caption, _array in (
    ("Tabular features shape", tabular_features),
    ("Image features shape", image_features),
    ("Unified target shape", unified_target),
):
    print(f"{_caption}: {_array.shape}")

# --- 3. Split Data for Training ---
# All three arrays are split with one call so their rows stay aligned;
# stratifying on the target keeps the class balance equal in train and test.
_split = train_test_split(
    tabular_features,
    image_features,
    unified_target,
    test_size=0.2,
    random_state=42,
    stratify=unified_target,
)
X_tab_train, X_tab_test, X_img_train, X_img_test, y_train, y_test = _split

print("Data split for multi-modal training:")
for _caption, _array in (
    ("Training tabular features", X_tab_train),
    ("Training image features", X_img_train),
    ("Training targets", y_train),
):
    print(f"{_caption}: {_array.shape}")

# --- 4. Build Unified Multi-modal Model ---
def build_unified_model(tabular_features_dim, image_features_dim, num_classes=2,
                        *, sparse_labels=False):
    """
    Builds and compiles a late-fusion model over tabular and image features.

    Each modality's feature vector passes through its own Dense+Dropout branch.
    The two branch outputs are concatenated (feature fusion) and fed to a
    shared classification head.

    Args:
        tabular_features_dim: Width of the tabular feature vector.
        image_features_dim: Width of the image feature vector.
        num_classes: Number of target classes. 2 (or fewer) builds a single
            sigmoid output trained with binary cross-entropy; more than 2
            builds a softmax output.
        sparse_labels: Only used when num_classes > 2. If True, compile with
            sparse_categorical_crossentropy so the model accepts integer class
            ids (like the 0..k-1 labels this notebook generates). If False (the
            default, original behaviour), labels must be one-hot encoded.

    Returns:
        A compiled Keras Model whose inputs are [tabular_features, image_features].
    """
    # Input layers for each modality's features
    input_tabular = Input(shape=(tabular_features_dim,), name='tabular_features_input')
    input_image = Input(shape=(image_features_dim,), name='image_features_input')

    # Modality-specific processing branches
    tabular_processed = Dense(32, activation='relu')(input_tabular)
    tabular_processed = Dropout(0.3)(tabular_processed)

    image_processed = Dense(64, activation='relu')(input_image)
    image_processed = Dropout(0.3)(image_processed)

    # Concatenate features (Feature Fusion)
    merged_features = Concatenate()([tabular_processed, image_processed])

    # Shared classification head
    x = Dense(128, activation='relu')(merged_features)
    x = Dropout(0.5)(x)
    x = Dense(64, activation='relu')(x)
    x = Dropout(0.5)(x)

    # Output layer: the loss must match both the activation and the label encoding.
    if num_classes > 2:
        output = Dense(num_classes, activation='softmax')(x)
        loss = 'sparse_categorical_crossentropy' if sparse_labels else 'categorical_crossentropy'
    else:
        output = Dense(1, activation='sigmoid')(x)
        loss = 'binary_crossentropy'

    unified_model = Model(inputs=[input_tabular, input_image], outputs=output)
    unified_model.compile(optimizer='adam', loss=loss, metrics=['accuracy'])

    return unified_model

# Build the unified model from the per-modality feature widths.
unified_model = build_unified_model(tabular_features.shape[1], image_features.shape[1])
unified_model.summary()

# --- 5. Train the Unified Model ---
print("\nTraining unified multi-modal model...")

# Monitor training on a slice of the *training* data. Passing the test set as
# validation_data would mean the metrics later reported as "test" come from
# data that was already used to watch training.
# validation_split takes the last 20% of the training arrays without shuffling;
# that is fine here because train_test_split already shuffled the rows.
history = unified_model.fit(
    [X_tab_train, X_img_train], y_train,
    epochs=20,
    batch_size=32,
    validation_split=0.2,
    verbose=1,
)

print("Unified multi-modal model training complete.")

# --- 6. Plot Training History ---
def plot_training_history(history):
    """Plot per-epoch accuracy and loss curves from a Keras training run.

    Args:
        history: The History object returned by ``Model.fit``. Its ``history``
            dict must contain 'accuracy' and 'loss'. The 'val_accuracy' and
            'val_loss' curves are drawn only when present, since they are
            absent if ``fit`` ran without validation data.
    """
    curves = history.history
    _fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    # Left panel: accuracy. Right panel: loss.
    for ax, metric, label in ((ax1, 'accuracy', 'Accuracy'), (ax2, 'loss', 'Loss')):
        ax.plot(curves[metric], label=f'Training {label}')
        val_key = f'val_{metric}'
        if val_key in curves:
            ax.plot(curves[val_key], label=f'Validation {label}')
        ax.set_title(f'Multi-modal Model {label}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(label)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.show()

plot_training_history(history)

# --- 7. Evaluate the Unified Model ---
print("\nEvaluating unified multi-modal model...")

# Evaluate on test set
loss, accuracy = unified_model.evaluate([X_tab_test, X_img_test], y_test, verbose=0)
print(f"Test Loss: {loss:.4f}")
print(f"Test Accuracy: {accuracy:.4f}")

# The single sigmoid output has shape (n, 1); flatten it to (n,) P(class 1).
y_pred_proba = unified_model.predict([X_tab_test, X_img_test], verbose=0).flatten()
y_pred = (y_pred_proba > 0.5).astype(int)

# zero_division=0: on chance-level data the model may predict one class for every
# sample, which makes precision 0/0. Report 0 explicitly instead of emitting an
# UndefinedMetricWarning.
precision = precision_score(y_test, y_pred, zero_division=0)
recall = recall_score(y_test, y_pred, zero_division=0)
f1 = f1_score(y_test, y_pred, zero_division=0)
roc_auc = roc_auc_score(y_test, y_pred_proba)

print("\nDetailed Metrics:")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1-Score: {f1:.4f}")
print(f"ROC-AUC: {roc_auc:.4f}")

# Classification Report
print("\nClassification Report:")
print(classification_report(y_test, y_pred, zero_division=0))

# Confusion Matrix (rows = actual class, columns = predicted class)
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(6, 4))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title("Confusion Matrix - Multi-modal Model")
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.show()

# ROC Curve (the dashed diagonal is the chance-level reference)
from sklearn.metrics import roc_curve
fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
plt.figure(figsize=(6, 4))
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve - Multi-modal Model')
plt.legend(loc='lower right')
plt.grid(True)
plt.show()

# --- 8. Compare with Individual Modalities ---
print("\n--- Comparison with Individual Modalities ---")

# Train models on individual modalities for comparison
from sklearn.linear_model import LogisticRegression

# Tabular-only model
tabular_only_model = LogisticRegression(random_state=42)
tabular_only_model.fit(X_tab_train, y_train)
tabular_pred = tabular_only_model.predict(X_tab_test)
tabular_accuracy = accuracy_score(y_test, tabular_pred)
print(f"Tabular-only model accuracy: {tabular_accuracy:.4f}")

# Image-only model
image_only_model = LogisticRegression(random_state=42)
image_only_model.fit(X_img_train, y_train)
image_pred = image_only_model.predict(X_img_test)
image_accuracy = accuracy_score(y_test, image_pred)
print(f"Image-only model accuracy: {image_accuracy:.4f}")

# The single-modality baselines are logistic regressions but the unified model is
# a neural network, so comparing them directly mixes two effects: model family and
# modality fusion. A logistic regression on the concatenated features isolates the
# effect of fusing the modalities.
X_fused_train = np.hstack([X_tab_train, X_img_train])
X_fused_test = np.hstack([X_tab_test, X_img_test])
fused_lr_model = LogisticRegression(random_state=42)
fused_lr_model.fit(X_fused_train, y_train)
fused_lr_accuracy = accuracy_score(y_test, fused_lr_model.predict(X_fused_test))
print(f"Multi-modal logistic regression accuracy: {fused_lr_accuracy:.4f}")

print(f"Multi-modal model accuracy: {accuracy:.4f}")

# Visualization of comparison
comparison_data = {
    'Model': ['Tabular Only', 'Image Only', 'Multi-modal (LogReg)', 'Multi-modal'],
    'Accuracy': [tabular_accuracy, image_accuracy, fused_lr_accuracy, accuracy]
}
comparison_df = pd.DataFrame(comparison_data)

plt.figure(figsize=(8, 5))
sns.barplot(data=comparison_df, x='Model', y='Accuracy')
plt.title('Model Performance Comparison')
plt.ylabel('Accuracy')
plt.ylim(0, 1)
# Annotate each bar with its accuracy value.
for i, v in enumerate(comparison_df['Accuracy']):
    plt.text(i, v + 0.01, f'{v:.3f}', ha='center')
plt.show()

# --- 9. Save the Unified Model ---
# NOTE(review): recent Keras releases treat the .h5 extension as a legacy HDF5
# format and recommend .keras -- confirm against the installed version.
_unified_model_path = 'unified_multimodal_model.h5'
unified_model.save(_unified_model_path)
print(f"\nUnified multi-modal model saved as '{_unified_model_path}'")

# --- 10. Prediction Function for New Data ---
def predict_multimodal(tabular_features, image_features, model, threshold=0.5):
    """
    Makes a single-sample prediction using the unified multi-modal model.

    Args:
        tabular_features: Preprocessed tabular features, either one sample as a
            1-D vector (n_features,) or a batch (n, n_features). For a batch,
            only the first row's prediction is returned.
        image_features: Extracted image features, same conventions.
        model: Trained unified model (anything with a Keras-style ``predict``
            that accepts ``[tabular, image]``).
        threshold: Decision threshold on P(class 1) for single-output
            (sigmoid) models. Ignored for multi-class (softmax) outputs.

    Returns:
        (prediction, confidence): the predicted class as an int and the model's
        probability for that class as a float.
    """
    # Keras expects a leading batch axis; also accept a bare feature vector.
    tabular_batch = np.atleast_2d(np.asarray(tabular_features))
    image_batch = np.atleast_2d(np.asarray(image_features))

    # verbose=0 suppresses the per-call progress bar.
    raw = model.predict([tabular_batch, image_batch], verbose=0)
    # ravel handles both (n, k) outputs and already-flat (n,) outputs.
    row = np.ravel(np.asarray(raw)[0])

    if row.size == 1:
        # Sigmoid head: the single value is P(class 1).
        positive_proba = float(row[0])
        prediction = 1 if positive_proba > threshold else 0
        confidence = positive_proba if prediction == 1 else 1.0 - positive_proba
    else:
        # Softmax head: pick the most probable class.
        prediction = int(np.argmax(row))
        confidence = float(row[prediction])

    return prediction, confidence

# Example usage: slicing with [:1] keeps the batch axis the model expects.
sample_tabular = X_tab_test[:1]  # First test sample
sample_image = X_img_test[:1]    # First test sample
pred, conf = predict_multimodal(sample_tabular, sample_image, unified_model)
print(f"\nSample prediction: {pred} (Confidence: {conf:.4f})")
print(f"Actual label: {y_test[0]}")

print("\nMulti-modal integration complete!")
# The labels generated by simulate_multimodal_data are random, so every accuracy
# above is chance-level; this notebook demonstrates the pipeline, not a real gain.
print("This notebook demonstrates how to combine multiple data modalities in a single model. "
      "With the simulated data used here the metrics are chance-level; "
      "plug in real extracted features and labels to measure any accuracy gain.")
