# Ensemble Methods for Pneumonia Detection

This notebook demonstrates ensemble methods that combine multiple models to improve pneumonia detection accuracy. The approach integrates:

1. **Deep Learning Models**: Xception and Xception-LSTM architectures
2. **Traditional Machine Learning**: SVM with statistical feature extraction
3. **Ensemble Fusion**: Equal-weight averaging of the predictions from whichever models are available

The ensemble approach leverages the strengths of different model types to achieve better generalization and robustness in medical image classification.

## 1. Setup and Configuration

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
import joblib
import timm
import pandas as pd
import glob
import random
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
from pathlib import Path

# Configuration
IMAGE_SIZE = 224  # side length (pixels) that images are resized to for the CNN models
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

# Seed NumPy, PyTorch and the stdlib RNG with the same value for reproducibility
np.random.seed(42)
torch.manual_seed(42)
random.seed(42)

## 2. Model Architectures

### 2.1 Xception-LSTM Model

In [None]:
class XceptionLSTM(nn.Module):
    """
    Xception backbone followed by an LSTM and a small classification head.

    The last feature map returned by the backbone is globally average-pooled,
    reshaped to ``(batch, steps, channels)`` and passed through an LSTM; the
    final LSTM output goes through two linear layers and yields one raw logit
    per image (callers apply the sigmoid).

    Note: the feature map is pooled to 1x1 before the LSTM, so the sequence
    handed to the LSTM has a single step.

    Args:
        freeze_layers: Number of leading backbone parameter tensors (in
            ``parameters()`` order) whose gradients are disabled.
        pretrained: Passed to ``timm.create_model``; set to False to skip
            loading pretrained backbone weights (e.g. when a checkpoint is
            loaded right afterwards).
    """
    def __init__(self, freeze_layers=100, pretrained=True):
        super().__init__()
        self.xception = timm.create_model("xception", pretrained=pretrained, features_only=True)

        # Freeze early layers for transfer learning
        for index, param in enumerate(self.xception.parameters()):
            if index < freeze_layers:
                param.requires_grad = False

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.reshape = nn.Flatten(2)

        # LSTM over the pooled features; input_size assumes the last
        # backbone feature map has 2048 channels
        self.lstm = nn.LSTM(input_size=2048, hidden_size=256, batch_first=True)

        # Classification head
        self.fc1 = nn.Linear(256, 64)
        self.dropout = nn.Dropout(0.46)
        self.fc2 = nn.Linear(64, 1)

    @staticmethod
    def transpose(x):
        """Swap the last two axes: (batch, channels, steps) -> (batch, steps, channels).

        Defined as a method rather than a lambda stored on the instance,
        because a lambda attribute makes the whole module unpicklable
        (e.g. ``torch.save(model)`` would fail).
        """
        return x.permute(0, 2, 1)

    def forward(self, x):
        """Return one raw logit per image, shape (batch, 1)."""
        # Extract features using the Xception backbone; keep the final feature map
        features = self.xception(x)[-1]
        pooled = self.pool(features)                       # (batch, channels, 1, 1)
        sequence = self.transpose(self.reshape(pooled))    # (batch, 1, channels)

        # Process through LSTM and keep the output of the last step
        lstm_out, _ = self.lstm(sequence)
        last_step = lstm_out[:, -1, :]

        # Classification
        hidden = self.fc1(last_step)
        hidden = self.dropout(hidden)
        return self.fc2(hidden)

### 2.2 Xception Fine-tuned Model

In [None]:
class XceptionFineTune(nn.Module):
    """
    Xception backbone with its own pooling/classifier replaced by a custom head.

    The backbone's ``global_pool`` and ``fc`` are swapped for identities, its
    output is average-pooled to 1x1, flattened, and passed through a
    dropout/linear head that produces one raw logit per image.
    """
    def __init__(self):
        super().__init__()
        backbone = timm.create_model('xception', pretrained=True)

        # Strip the default pooling and classification layers
        backbone.global_pool = nn.Identity()
        backbone.fc = nn.Identity()
        self.xception = backbone

        # Custom pooling and classification head
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        head_layers = [
            nn.Dropout(0.5),
            nn.Linear(2048, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 1),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the head's output for a batch of images."""
        feature_map = self.xception(x)
        batch_size = feature_map.size(0)
        flat = self.pool(feature_map).view(batch_size, -1)
        return self.classifier(flat)

## 3. Model Loading Functions

In [None]:
def load_xception_lstm(weights_path="../models/xception_lstm_weights.pth"):
    """
    Build an XceptionLSTM and return it on DEVICE in eval mode.

    If ``weights_path`` exists its state dict is loaded into the model;
    otherwise a warning is printed and the model keeps the weights its
    constructor produced.
    """
    net = XceptionLSTM()

    if not os.path.exists(weights_path):
        print(f"Warning: XceptionLSTM weights not found at {weights_path}")
        print("Using randomly initialized weights")
    else:
        state_dict = torch.load(weights_path, map_location=DEVICE)
        net.load_state_dict(state_dict)
        print(f"Loaded XceptionLSTM weights from {weights_path}")

    net.to(DEVICE)
    net.eval()
    return net

def load_xception(weights_path="../models/xception_weights.pth"):
    """
    Build an XceptionFineTune model and return it on DEVICE in eval mode.

    If ``weights_path`` exists its state dict is loaded into the model;
    otherwise a warning is printed and the model keeps the weights its
    constructor produced.
    """
    net = XceptionFineTune()

    if not os.path.exists(weights_path):
        print(f"Warning: Xception weights not found at {weights_path}")
        print("Using randomly initialized weights")
    else:
        state_dict = torch.load(weights_path, map_location=DEVICE)
        net.load_state_dict(state_dict)
        print(f"Loaded Xception weights from {weights_path}")

    net.to(DEVICE)
    net.eval()
    return net

def load_svm(model_path="../models/svm_model.pkl", 
             scaler_path="../models/feature_scaler.pkl",
             features_file="../data/test_features.csv"):
    """
    Load the SVM model, feature scaler, and pre-extracted features

    Args:
        model_path: Path to the SVM model file
        scaler_path: Path to the feature scaler file
        features_file: Path to the CSV file containing pre-extracted features.
            Every column other than 'image_id' and 'label' is treated as a
            feature.

    Returns:
        Tuple ``(svm_model, scaler, features_dict)``. ``features_dict`` maps
        the lower-cased file stem of each 'image_id' to its scaled float32
        feature vector; it is empty when the features file is missing.
        Returns ``(None, None, {})`` when the model or scaler file is missing.
    """
    if not (os.path.exists(model_path) and os.path.exists(scaler_path)):
        print("Warning: SVM model or scaler not found")
        return None, None, {}

    # NOTE: joblib.load unpickles arbitrary objects - only load trusted files.
    svm_model = joblib.load(model_path)
    scaler = joblib.load(scaler_path)
    print(f"Loaded SVM model from {model_path}")
    print(f"Loaded feature scaler from {scaler_path}")

    features_dict = {}
    if not os.path.exists(features_file):
        print(f"Warning: Features file not found at {features_file}")
        return svm_model, scaler, features_dict

    print(f"Loading features from {features_file}...")
    test_df = pd.read_csv(features_file, float_precision='high', low_memory=False)

    feature_columns = [col for col in test_df.columns if col not in ('image_id', 'label')]

    # Guard the empty case: scalers reject arrays with zero samples
    if len(test_df) > 0:
        # Key by lower-cased file stem so lookups match paths regardless of
        # directory, extension or case
        image_ids = [
            os.path.splitext(os.path.basename(str(raw_id)))[0].lower()
            for raw_id in test_df['image_id']
        ]
        # Scale all rows in one call instead of one scaler call per row
        raw_features = test_df[feature_columns].values.astype(np.float32)
        scaled_features = np.asarray(scaler.transform(raw_features))
        # On duplicate ids the last row wins, as with a row-by-row assignment
        features_dict = dict(zip(image_ids, scaled_features))

    print(f"Loaded features for {len(features_dict)} images")
    return svm_model, scaler, features_dict

## 4. Image Preprocessing

In [None]:
def preprocess_image(image_path):
    """
    Preprocess image for CNN models

    Opens the image, converts it to RGB, resizes it to IMAGE_SIZE x IMAGE_SIZE,
    converts it to a tensor and normalizes each channel with mean 0.5 and
    std 0.5.

    Args:
        image_path: Path to the image file

    Returns:
        Tensor of shape (1, 3, IMAGE_SIZE, IMAGE_SIZE) on DEVICE
    """
    transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])

    # Image.open keeps the underlying file open; the context manager closes it
    # as soon as the RGB copy has been made
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')
    return transform(image).unsqueeze(0).to(DEVICE)

## 5. Ensemble Prediction System

In [None]:
def predict_ensemble(image_path, models=None, fallback_to_cnn=True):
    """
    Classify one image by averaging the probabilities of the available models

    Args:
        image_path: Path to the image file
        models: Dictionary with keys 'xception_lstm', 'xception', 'svm_model'
            and 'features_dict'; when None, all models are loaded for this call
        fallback_to_cnn: If True, silently skip the SVM when it (or the
            image's pre-extracted features) is unavailable; if False, raise

    Returns:
        Dictionary with 'final_prediction' ('PNEUMONIA' or 'NORMAL'),
        'confidence', 'prob_pneumonia', 'model_predictions' (per-model
        probabilities) and 'num_models'

    Raises:
        ValueError: If the SVM cannot be used and fallback_to_cnn is False,
            or if no model produced a prediction
    """
    if models is None:
        # Nothing pre-loaded: load every component for this single call
        lstm_net = load_xception_lstm()
        xception_net = load_xception()
        svm_model, _, svm_features = load_svm()
        models = {
            'xception_lstm': lstm_net,
            'xception': xception_net,
            'svm_model': svm_model,
            'features_dict': svm_features,
        }

    image_tensor = preprocess_image(image_path)

    # The lower-cased file stem is the key into the pre-extracted SVM features
    stem, _ = os.path.splitext(os.path.basename(image_path))
    image_id = stem.lower()

    # CNN probabilities (sigmoid of the single output logit)
    model_predictions = {}
    with torch.no_grad():
        for cnn_key in ('xception_lstm', 'xception'):
            network = models[cnn_key]
            if network is None:
                continue
            logit = network(image_tensor)
            model_predictions[cnn_key] = torch.sigmoid(logit).item()

    # SVM probability, only when both the model and this image's features exist
    svm_usable = (models['svm_model'] is not None
                  and image_id in models['features_dict'])
    if svm_usable:
        feature_row = models['features_dict'][image_id].reshape(1, -1)
        # Column 1 of predict_proba is used as the pneumonia probability
        model_predictions['svm'] = models['svm_model'].predict_proba(feature_row)[0][1]
    elif not fallback_to_cnn:
        raise ValueError(f"Pre-extracted features not found for image {image_id}")

    valid_predictions = [prob for prob in model_predictions.values() if prob is not None]
    if not valid_predictions:
        raise ValueError("No valid predictions from any model")

    # Plain mean: every available model gets the same weight
    avg_pred = np.mean(valid_predictions)
    if avg_pred > 0.5:
        label = 'PNEUMONIA'
        confidence = avg_pred
    else:
        label = 'NORMAL'
        confidence = 1 - avg_pred

    return {
        'final_prediction': label,
        'confidence': confidence,
        'prob_pneumonia': avg_pred,
        'model_predictions': model_predictions,
        'num_models': len(valid_predictions)
    }

## 6. Evaluation Framework

In [None]:
def evaluate_ensemble(test_dir="../data/test", 
                     features_file="../data/test_features.csv",
                     fallback_to_cnn=True):
    """
    Evaluate the ensemble model on a test set

    Loads all models once, runs predict_ensemble on every image, then prints
    a classification report, the confusion matrix and detailed metrics, and
    plots the confusion matrix.

    Args:
        test_dir: Directory containing test images in 'pneumonia' and 'normal' subfolders
        features_file: Path to CSV file containing pre-extracted features
        fallback_to_cnn: If True, use only CNNs for images with missing SVM features

    Returns:
        Dictionary with 'accuracy', 'precision', 'recall', 'f1_score',
        'confusion_matrix' (2x2, rows = true NORMAL/PNEUMONIA),
        'model_counts' (images each model contributed to) and 'num_failed'
        (images whose prediction raised). Returns None if no images are found.
    """
    # Load all models once
    print("Loading models...")
    models = {
        'xception_lstm': load_xception_lstm(),
        'xception': load_xception(),
        'svm_model': None,
        'features_dict': {}
    }
    models['svm_model'], _, models['features_dict'] = load_svm(features_file=features_file)

    # Collect image paths and true labels
    image_paths = []
    true_labels = []

    # Look for images in subdirectories
    for class_name, label in [('normal', 0), ('pneumonia', 1)]:
        class_dir = os.path.join(test_dir, class_name)
        if not os.path.exists(class_dir):
            continue
        # Sorted so the processing order is deterministic across runs
        class_images = sorted(
            glob.glob(os.path.join(class_dir, '*.jpg'))
            + glob.glob(os.path.join(class_dir, '*.jpeg'))
            + glob.glob(os.path.join(class_dir, '*.png'))
        )
        image_paths.extend(class_images)
        true_labels.extend([label] * len(class_images))

    if not image_paths:
        print(f"Error: No images found in {test_dir}")
        return

    print(f"\nEvaluating on {len(image_paths)} test images...")

    # Get predictions
    prob_pneumonia_list = []
    model_counts = {'xception_lstm': 0, 'xception': 0, 'svm': 0}
    num_failed = 0

    for i, img_path in enumerate(image_paths):
        if i % 100 == 0:
            print(f"Processing image {i+1}/{len(image_paths)}...")

        try:
            result = predict_ensemble(img_path, models, fallback_to_cnn)
            prob_pneumonia_list.append(result['prob_pneumonia'])

            # Count model usage
            for model_name in result['model_predictions']:
                model_counts[model_name] += 1

        except Exception as e:
            print(f"Error processing {img_path}: {e}")
            # Keep the arrays aligned with true_labels; 0.5 is not > 0.5,
            # so a failed image is scored as NORMAL
            prob_pneumonia_list.append(0.5)  # Default prediction
            num_failed += 1

    # Convert to numpy arrays
    true_labels = np.array(true_labels)
    prob_pneumonia = np.array(prob_pneumonia_list)
    pred_labels = (prob_pneumonia > 0.5).astype(int)

    # Calculate and display results
    print("\n" + "="*60)
    print("Ensemble Model Evaluation Results")
    print("="*60)

    print(f"\nModel Usage Statistics:")
    for model_name, count in model_counts.items():
        print(f"- {model_name.replace('_', ' ').title()}: {count}/{len(image_paths)} images")

    if num_failed:
        print(f"\nWarning: {num_failed} image(s) failed and were scored with the "
              f"default probability 0.5 (counted as NORMAL)")

    # Pin the label set so the report and the matrix keep both classes even
    # when only one class occurs in the labels/predictions
    class_ids = [0, 1]

    print("\nClassification Report:")
    print(classification_report(true_labels, pred_labels,
                                labels=class_ids,
                                target_names=['NORMAL', 'PNEUMONIA'],
                                digits=4,
                                zero_division=0))

    # Confusion Matrix (always 2x2 because labels are pinned)
    cm = confusion_matrix(true_labels, pred_labels, labels=class_ids)
    print("\nConfusion Matrix:")
    print(cm)

    # Plot confusion matrix on an explicit axes; without ax= the display
    # creates its own figure and a separately created figure stays empty
    _, ax = plt.subplots(figsize=(8, 6))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, 
                                  display_labels=['NORMAL', 'PNEUMONIA'])
    disp.plot(ax=ax, cmap=plt.cm.Blues, values_format='d')
    ax.set_title('Ensemble Model - Confusion Matrix')
    plt.show()

    # Detailed metrics
    tn, fp, fn, tp = cm.ravel()
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    print("\nDetailed Metrics:")
    print(f"{'Accuracy:':<15} {accuracy:.4f}")
    print(f"{'Precision:':<15} {precision:.4f}")
    print(f"{'Recall:':<15} {recall:.4f}")
    print(f"{'F1-Score:':<15} {f1:.4f}")
    print(f"{'True Positives:':<15} {tp}")
    print(f"{'False Positives:':<15} {fp}")
    print(f"{'True Negatives:':<15} {tn}")
    print(f"{'False Negatives:':<15} {fn}")

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'confusion_matrix': cm,
        'model_counts': model_counts,
        'num_failed': num_failed
    }

## 7. Single Image Prediction Demo

In [None]:
def demo_single_prediction(image_path):
    """
    Print the ensemble prediction and per-model probabilities for one image

    Prints an error and returns if the file does not exist; any exception
    raised while predicting is caught and printed. Models are loaded inside
    predict_ensemble because no pre-loaded models are passed.
    """
    if not os.path.exists(image_path):
        print(f"Error: Image file {image_path} not found")
        return None

    print(f"Making prediction for: {os.path.basename(image_path)}")
    print("-" * 50)

    try:
        result = predict_ensemble(image_path, fallback_to_cnn=True)

        # Overall verdict
        print(f"Final Prediction: {result['final_prediction']}")
        print(f"Confidence: {result['confidence']:.4f}")
        print(f"Pneumonia Probability: {result['prob_pneumonia']:.4f}")
        print(f"Models Used: {result['num_models']}")

        # Per-model breakdown
        print("\nIndividual Model Predictions:")
        for key, probability in result['model_predictions'].items():
            display_name = key.replace('_', ' ').title()
            shown = "Not Available" if probability is None else f"{probability:.4f}"
            print(f"- {display_name}: {shown}")

    except Exception as e:
        print(f"Error making prediction: {e}")

## 8. Model Comparison Analysis

In [None]:
def compare_individual_vs_ensemble(test_dir="../data/test", sample_size=100):
    """
    Compare performance of individual models vs ensemble

    Runs predict_ensemble once per sampled image and reuses the per-model
    probabilities it reports, so each network is evaluated only once per
    image. Accuracy is computed per model over the images for which that
    model produced a prediction.

    Args:
        test_dir: Directory containing 'normal' and 'pneumonia' subfolders
        sample_size: Total number of images to sample; at most
            ``sample_size // 2`` are taken from each class

    Returns:
        Dictionary mapping 'xception_lstm', 'xception', 'svm' and 'ensemble'
        to ``{'accuracy': ..., 'samples': ...}`` (models with no predictions
        are omitted). Returns None if no images are found.
    """
    print("Comparing Individual Models vs Ensemble")
    print("=" * 50)

    # Load models
    models = {
        'xception_lstm': load_xception_lstm(),
        'xception': load_xception(),
        'svm_model': None,
        'features_dict': {}
    }
    models['svm_model'], _, models['features_dict'] = load_svm()

    # Get test images (same extensions as evaluate_ensemble; sorted so the
    # sample is deterministic)
    image_paths = []
    true_labels = []
    per_class = sample_size // 2

    for class_name, label in [('normal', 0), ('pneumonia', 1)]:
        class_dir = os.path.join(test_dir, class_name)
        if not os.path.exists(class_dir):
            continue
        class_images = sorted(
            glob.glob(os.path.join(class_dir, '*.jpg'))
            + glob.glob(os.path.join(class_dir, '*.jpeg'))
            + glob.glob(os.path.join(class_dir, '*.png'))
        )[:per_class]
        image_paths.extend(class_images)
        true_labels.extend([label] * len(class_images))

    if not image_paths:
        print("No test images found")
        return

    print(f"Testing on {len(image_paths)} images...")

    # Collect predictions from each model; None marks "no prediction"
    individual_models = ('xception_lstm', 'xception', 'svm')
    predictions = {
        'xception_lstm': [],
        'xception': [],
        'svm': [],
        'ensemble': []
    }

    for img_path in image_paths:
        try:
            result = predict_ensemble(img_path, models, fallback_to_cnn=True)
        except Exception as e:
            # One bad image must not abort the comparison
            print(f"Error processing {img_path}: {e}")
            for model_name in individual_models:
                predictions[model_name].append(None)
            # 0.5 is not > 0.5, so a failed image counts as NORMAL for the ensemble
            predictions['ensemble'].append(0.5)
            continue

        per_model = result['model_predictions']
        for model_name in individual_models:
            predictions[model_name].append(per_model.get(model_name))
        predictions['ensemble'].append(result['prob_pneumonia'])

    # Calculate metrics for each model
    results = {}
    true_labels = np.array(true_labels)

    for model_name, preds in predictions.items():
        valid_indices = [i for i, p in enumerate(preds) if p is not None]
        if not valid_indices:
            continue
        valid_preds = np.array([preds[i] for i in valid_indices])
        pred_labels = (valid_preds > 0.5).astype(int)

        results[model_name] = {
            'accuracy': np.mean(pred_labels == true_labels[valid_indices]),
            'samples': len(valid_indices)
        }

    # Display results
    print("\nModel Performance Comparison:")
    print("-" * 40)
    for model_name, metrics in results.items():
        # Label with a single trailing colon, left-aligned in 15 columns
        label = model_name.replace('_', ' ').title() + ':'
        print(f"{label:<15} Accuracy: {metrics['accuracy']:.4f} ({metrics['samples']} samples)")

    return results

## 9. Running the Evaluation

Uncomment the cells below to run the ensemble evaluation:

In [None]:
# Example: Evaluate ensemble on test set
# Note: Adjust paths according to your data structure.
# test_dir must contain 'normal' and 'pneumonia' subfolders with
# .jpg/.jpeg/.png images; features_file is the CSV of pre-extracted SVM features.

# results = evaluate_ensemble(
#     test_dir="../data/test",
#     features_file="../data/test_features.csv",
#     fallback_to_cnn=True
# )

print("Ensemble evaluation framework is ready.")
print("Uncomment the above code to run evaluation when models and data are available.")

In [None]:
# Example: Single image prediction
# Note: demo_single_prediction loads all models on each call.
# demo_single_prediction("../data/test/pneumonia/sample_image.jpg")

print("Single image prediction demo is ready.")
print("Uncomment the above code and provide an image path to test.")

In [None]:
# Example: Compare individual models vs ensemble
# Note: sample_size is split evenly between the two classes
# (sample_size // 2 images per class).
# comparison_results = compare_individual_vs_ensemble(
#     test_dir="../data/test",
#     sample_size=200
# )

print("Model comparison framework is ready.")
print("Uncomment the above code to compare individual models vs ensemble.")

## 10. Key Insights and Benefits

### Ensemble Advantages:

1. **Complementary Strengths**: 
   - CNNs excel at spatial feature detection
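   - LSTM layer models the pooled Xception features as a sequence (in this implementation the features are globally pooled first, so the sequence has a single step)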
   - SVM provides robust statistical classification

2. **Improved Robustness**:
   - Reduces individual model weaknesses
   - Better generalization to unseen data
   - More reliable predictions in clinical settings

3. **Confidence Estimation**:
   - Agreement between models indicates higher confidence
   - Disagreement suggests uncertain cases requiring human review

4. **Flexibility**:
   - Graceful degradation when some models are unavailable
   - Adaptable weighting schemes for different scenarios

### Implementation Notes:

- **Feature Extraction**: SVM requires pre-computed statistical features from ROI analysis
- **Model Loading**: Each component model can be trained independently
- **Inference Speed**: Trade-off between accuracy and computational cost
- **Clinical Integration**: Ensemble provides interpretable confidence scores

This ensemble approach demonstrates how multiple AI techniques can be combined to create a more robust and reliable pneumonia detection system suitable for clinical applications.