# Experimental CNN-LSTM Models for Pediatric Pneumonia Detection

## Overview
This notebook explores experimental CNN-LSTM architectures that combine convolutional feature extraction with sequential processing:
- **Xception-LSTM**: Advanced CNN backbone with LSTM for spatial token processing
- **Custom CNN-LSTM**: Simple 3-layer CNN with LSTM for comparison
- **Grad-CAM Visualization**: Understanding what these hybrid models focus on

## Research Background
This work investigates whether adding sequential processing (LSTM) to CNNs can improve pneumonia detection by:
- **Spatial Token Analysis**: Treating image regions as sequential tokens
- **Contextual Understanding**: LSTM learns relationships between different lung regions
- **Attention Mechanisms**: LSTM can focus on relevant spatial areas

## Key Concepts
- **CNN-LSTM Hybrid**: Combines spatial feature extraction with sequential processing
- **Spatial Tokens**: Image regions (7x7 grid = 49 tokens) processed sequentially
- **Sequential Learning**: LSTM learns patterns across spatial locations
- **Feature Maps as Sequences**: Converting 2D feature maps to 1D token sequences

## Experimental Nature
These models are **research prototypes** exploring whether sequential processing adds value to medical image analysis. Results may vary and require further validation.

In [None]:
# Import required libraries
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import timm  # For advanced CNN backbones
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import cv2  # For Grad-CAM visualization
from collections import Counter
from sklearn.metrics import classification_report, confusion_matrix
from tqdm import tqdm

# Set device and random seeds
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(42)
np.random.seed(42)

print(f"Using device: {device}")
print("Libraries imported successfully!")

## 1. Experimental Configuration

**Research Parameters:**
- **Small Batch Size**: 8-32 for detailed gradient analysis
- **Extended Training**: 20 epochs to observe convergence patterns
- **Dropout Experimentation**: Various dropout rates for regularization
- **LSTM Hidden Units**: 512 for rich sequential representations

**Spatial Token Concept:**
- **7x7 Spatial Grid**: Divides feature maps into 49 spatial tokens
- **Sequential Processing**: LSTM processes tokens in spatial order
- **Contextual Learning**: LSTM learns relationships between lung regions

In [None]:
# Experimental configuration
CONFIG = {
    'batch_size': 8,           # Small batch for detailed analysis
    'test_batch_size': 64,     # Larger batch for efficient testing
    'epochs': 20,              # Extended training for research
    'image_size': 224,         # Standard input size
    'lstm_units': 512,         # Rich sequential representation
    'dense_units': 128,        # Classifier hidden units
    'dropout_rate': 0.3,       # Regularization
    'learning_rate': 1e-4,     # Conservative learning rate
    'seed': 42
}

# Data directories
TRAIN_DIR = '../data/training/augmented_train'
TEST_DIR = '../data/testing/augmented_test'

# Data transformations
transform = transforms.Compose([
    transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

# Load datasets
print("Loading datasets for experimental analysis...")
train_dataset = datasets.ImageFolder(root=TRAIN_DIR, transform=transform)
test_dataset = datasets.ImageFolder(root=TEST_DIR, transform=transform)

# Create data loaders with experimental batch sizes
train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=CONFIG['test_batch_size'], shuffle=False)

# Dataset analysis
print(f"Training samples: {len(train_dataset)}")
print(f"Testing samples: {len(test_dataset)}")
print(f"Classes: {train_dataset.classes}")

# Analyze class distribution
train_counts = Counter(train_dataset.targets)
test_counts = Counter(test_dataset.targets)
print(f"\nTrain distribution: {dict(train_counts)}")
print(f"Test distribution: {dict(test_counts)}")

## 2. Xception-LSTM Architecture

### Advanced CNN-LSTM Hybrid

**Architecture Innovation:**
1. **Xception Backbone**: Extract rich spatial features (2048 channels)
2. **Spatial Tokenization**: Convert 7x7 feature map to 49 spatial tokens
3. **LSTM Processing**: Sequential analysis of spatial relationships
4. **Feature Integration**: Combine spatial and sequential information

**Key Research Questions:**
- Can LSTM improve spatial understanding beyond standard CNNs?
- Do sequential patterns exist in lung X-ray spatial features?
- How does spatial token order affect pneumonia detection?

**Technical Details:**
- **Input**: (Batch, 3, 224, 224) X-ray images
- **Xception Output**: (Batch, 2048, 7, 7) feature maps
- **Spatial Tokens**: (Batch, 49, 2048) sequential representation
- **LSTM Output**: (Batch, 512) contextual features
- **Final Prediction**: Binary classification (pneumonia/normal)

In [None]:
class XceptionLSTM(nn.Module):
    """
    Experimental Xception-LSTM architecture for pneumonia detection.
    
    This model combines the feature extraction power of Xception with
    the sequential processing capabilities of LSTM for spatial analysis.
    
    Research Innovation:
    - Treats spatial feature maps as sequences of tokens
    - Uses LSTM to learn relationships between different lung regions
    - Investigates whether sequential processing improves medical image analysis
    """
    
    def __init__(self, freeze_layers=100, use_channel_reduction=False, reduced_dim=512):
        """
        Initialize the experimental Xception-LSTM model.
        
        Args:
            freeze_layers: Number of early Xception layers to freeze
            use_channel_reduction: Whether to reduce channel dimensions for efficiency
            reduced_dim: Target dimensions if using channel reduction
        """
        super(XceptionLSTM, self).__init__()
        
        # ========== Xception Backbone ==========
        # Use Xception for feature extraction only (no classification head)
        self.xception = timm.create_model("xception", pretrained=True, features_only=True)
        
        # Freeze early layers for stable training on medical data
        if freeze_layers > 0:
            for name, param in list(self.xception.named_parameters())[:freeze_layers]:
                param.requires_grad = False
            print(f"Frozen first {freeze_layers} Xception layers for stability")
        
        # ========== Channel Reduction (Optional) ==========
        # Reduce computational complexity while maintaining performance
        self.use_channel_reduction = use_channel_reduction
        if use_channel_reduction:
            self.channel_reducer = nn.Sequential(
                nn.Conv2d(2048, reduced_dim, kernel_size=1, bias=False),
                nn.BatchNorm2d(reduced_dim),
                nn.ReLU(inplace=True)
            )
            lstm_input_dim = reduced_dim
            print(f"Using channel reduction: 2048 -> {reduced_dim} dimensions")
        else:
            lstm_input_dim = 2048
            print("Using full 2048-dimensional features")
        
        # ========== LSTM Sequential Processor ==========
        # Process spatial tokens sequentially to learn regional relationships
        self.lstm = nn.LSTM(
            input_size=lstm_input_dim,
            hidden_size=CONFIG['lstm_units'],
            batch_first=True,
            dropout=0.2 if CONFIG['epochs'] > 10 else 0.0  # LSTM dropout for longer training
        )
        
        # ========== Classification Head ==========
        # Transform LSTM output to final prediction
        self.classifier = nn.Sequential(
            nn.Linear(CONFIG['lstm_units'], CONFIG['dense_units']),
            nn.ReLU(),
            nn.Dropout(CONFIG['dropout_rate']),
            nn.Linear(CONFIG['dense_units'], 1)  # Binary classification
        )
        
        print(f"Model initialized with {self._count_parameters():,} trainable parameters")
    
    def forward(self, x):
        """
        Forward pass through the experimental architecture.
        
        Args:
            x: Input batch [batch_size, 3, 224, 224]
            
        Returns:
            logits: Raw prediction scores [batch_size, 1]
        """
        # ========== Feature Extraction ==========
        # Extract spatial features using Xception backbone
        features = self.xception(x)[-1]  # Take last feature stage: [batch, 2048, 7, 7]
        
        # ========== Optional Channel Reduction ==========
        if self.use_channel_reduction:
            features = self.channel_reducer(features)  # [batch, reduced_dim, 7, 7]
        
        # ========== Spatial Tokenization ==========
        # Convert 2D feature maps to sequence of spatial tokens
        batch_size, channels, height, width = features.shape
        
        # Flatten spatial dimensions and transpose for LSTM input
        # [batch, channels, height*width] -> [batch, height*width, channels]
        spatial_tokens = features.flatten(2).permute(0, 2, 1)
        # Result: [batch_size, 49, channels] - 49 spatial tokens per image
        
        # ========== Sequential Processing ==========
        # Process spatial tokens through LSTM to learn regional relationships
        lstm_output, (hidden_state, _) = self.lstm(spatial_tokens)
        
        # Use final hidden state as global representation
        # hidden_state shape: [1, batch_size, lstm_units]
        global_features = hidden_state.squeeze(0)  # [batch_size, lstm_units]
        
        # ========== Classification ==========
        # Transform global features to final prediction
        logits = self.classifier(global_features)  # [batch_size, 1]
        
        return logits
    
    def _count_parameters(self):
        """Count trainable parameters for model analysis."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
    
    def get_spatial_attention(self, x):
        """
        Experimental method to analyze spatial attention patterns.
        Returns LSTM attention weights across spatial tokens.
        """
        with torch.no_grad():
            features = self.xception(x)[-1]
            if self.use_channel_reduction:
                features = self.channel_reducer(features)
            
            spatial_tokens = features.flatten(2).permute(0, 2, 1)
            lstm_output, _ = self.lstm(spatial_tokens)
            
            # Calculate attention as magnitude of LSTM outputs
            attention = torch.norm(lstm_output, dim=2)  # [batch, num_tokens]
            attention = torch.softmax(attention, dim=1)  # Normalize to probabilities
            
            return attention.reshape(-1, 7, 7)  # Reshape to spatial dimensions

print("Xception-LSTM experimental architecture defined successfully!")

## 3. Custom CNN-LSTM Architecture

### Simple CNN-LSTM for Comparison

**Research Purpose:**
- Compare against the Xception-LSTM to understand contribution of backbone complexity
- Investigate whether simple CNN + LSTM can achieve competitive performance
- Analyze trade-offs between model complexity and accuracy

**Architecture:**
1. **3-Layer CNN**: Basic feature extraction (200->150->100 channels)
2. **Spatial Pooling**: Reduce to 7x7 to match Xception-LSTM token count
3. **LSTM Processing**: Same sequential approach as Xception-LSTM
4. **Lightweight Classifier**: Smaller hidden dimensions for efficiency

**Advantages:**
- **Faster Training**: Fewer parameters, quicker convergence
- **Lower Memory**: Suitable for resource-constrained environments
- **Interpretability**: Simpler architecture easier to understand
- **Baseline Comparison**: Establishes minimum performance threshold

In [None]:
class CustomCNNLSTM(nn.Module):
    """
    Simple 3-layer CNN + LSTM architecture for pneumonia detection.
    
    This lightweight model serves as a comparison baseline to understand
    whether complex backbones are necessary for CNN-LSTM architectures.
    
    Research Questions:
    - Can simple CNN + LSTM compete with advanced architectures?
    - What's the minimum complexity needed for effective pneumonia detection?
    - How do computational requirements scale with performance?
    """
    
    def __init__(self, dropout_rate=0.27, lstm_hidden=50):
        """
        Initialize the lightweight CNN-LSTM model.
        
        Args:
            dropout_rate: Dropout probability for regularization
            lstm_hidden: LSTM hidden units (smaller than Xception-LSTM)
        """
        super(CustomCNNLSTM, self).__init__()
        
        # ========== Simple CNN Backbone ==========
        # 3-layer CNN for basic feature extraction
        self.cnn_backbone = nn.Sequential(
            # First convolutional block
            nn.Conv2d(3, 200, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # Reduce spatial dimensions
            
            # Second convolutional block
            nn.Conv2d(200, 150, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # Further spatial reduction
            
            # Third convolutional block
            nn.Conv2d(150, 100, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            
            # Adaptive pooling to ensure 7x7 output (same as Xception-LSTM)
            nn.AdaptiveAvgPool2d((7, 7))
        )
        
        # ========== Sequential Processor ==========
        # LSTM for processing spatial tokens
        self.lstm = nn.LSTM(
            input_size=100,           # CNN output channels
            hidden_size=lstm_hidden,  # Smaller than Xception-LSTM
            batch_first=True
        )
        
        # ========== Classification Head ==========
        # Lightweight classifier
        self.classifier = nn.Sequential(
            nn.Linear(lstm_hidden, 64),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(64, 1)  # Binary classification
        )
        
        print(f"Lightweight CNN-LSTM initialized with {self._count_parameters():,} parameters")
        print(f"This is significantly fewer than Xception-LSTM for comparison")
    
    def forward(self, x):
        """
        Forward pass through the lightweight architecture.
        
        Args:
            x: Input batch [batch_size, 3, 224, 224]
            
        Returns:
            logits: Raw prediction scores [batch_size, 1]
        """
        # ========== Feature Extraction ==========
        # Extract features using simple CNN
        cnn_features = self.cnn_backbone(x)  # [batch_size, 100, 7, 7]
        
        # ========== Spatial Tokenization ==========
        # Convert to spatial tokens (same approach as Xception-LSTM)
        batch_size, channels, height, width = cnn_features.shape
        
        # Flatten and transpose: [batch, channels, H*W] -> [batch, H*W, channels]
        spatial_tokens = cnn_features.flatten(2).permute(0, 2, 1)
        # Result: [batch_size, 49, 100] - same token count, fewer features per token
        
        # ========== Sequential Processing ==========
        # Process spatial tokens through LSTM
        lstm_output, (hidden_state, _) = self.lstm(spatial_tokens)
        
        # Use final hidden state
        global_features = hidden_state.squeeze(0)  # [batch_size, lstm_hidden]
        
        # ========== Classification ==========
        logits = self.classifier(global_features)
        
        return logits
    
    def _count_parameters(self):
        """Count trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

print("Custom CNN-LSTM comparison architecture defined successfully!")

## 4. Experimental Training Framework

**Research-Oriented Training:**
- **Extended Epochs**: 20 epochs to observe convergence patterns
- **Detailed Logging**: Track loss, accuracy, and learning dynamics
- **Model Comparison**: Train both architectures with identical settings
- **Performance Analysis**: Compare efficiency vs accuracy trade-offs

**Training Considerations:**
- **BCEWithLogitsLoss**: Numerically stable for binary classification
- **Adam Optimizer**: Adaptive learning rates for complex architectures
- **Gradient Clipping**: Prevent exploding gradients in LSTM components
- **Learning Rate Scheduling**: Optional for fine-tuning

In [None]:
def train_experimental_model(model, train_loader, test_loader, num_epochs, learning_rate, model_name):
    """
    Train experimental CNN-LSTM models with detailed analysis.
    
    Args:
        model: CNN-LSTM model to train
        train_loader: Training data loader
        test_loader: Testing data loader
        num_epochs: Number of training epochs
        learning_rate: Learning rate for optimization
        model_name: Name for logging and identification
        
    Returns:
        Tuple of (trained_model, detailed_history)
    """
    
    # Move model to device
    model = model.to(device)
    
    # Setup training components
    criterion = nn.BCEWithLogitsLoss()  # Stable binary classification
    optimizer = optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()), 
        lr=learning_rate,
        weight_decay=1e-5  # L2 regularization
    )
    
    # Track detailed training history
    history = {
        'train_loss': [],
        'train_acc': [],
        'test_acc': [],
        'epoch_times': [],
        'parameter_count': sum(p.numel() for p in model.parameters() if p.requires_grad)
    }
    
    print(f"\nStarting experimental training: {model_name}")
    print(f"Model parameters: {history['parameter_count']:,}")
    print(f"Training configuration: {num_epochs} epochs, LR={learning_rate}")
    print("-" * 70)
    
    import time
    
    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        
        # ========== Training Phase ==========
        model.train()
        running_loss = 0.0
        correct_train = 0
        total_train = 0
        
        train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1:2d}/{num_epochs} [Train]', leave=False)
        
        for inputs, labels in train_pbar:
            # Prepare data for BCEWithLogitsLoss
            inputs = inputs.to(device)
            labels = labels.float().unsqueeze(1).to(device)  # [batch_size, 1]
            
            # Forward pass
            optimizer.zero_grad()
            outputs = model(inputs)  # Raw logits
            loss = criterion(outputs, labels)
            
            # Backward pass with gradient clipping
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Prevent exploding gradients
            optimizer.step()
            
            # Statistics
            running_loss += loss.item() * inputs.size(0)
            predictions = (torch.sigmoid(outputs) > 0.5).float()
            correct_train += (predictions == labels).sum().item()
            total_train += labels.size(0)
            
            # Update progress bar
            train_pbar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'Acc': f'{100.*correct_train/total_train:.1f}%'
            })
        
        # Calculate training metrics
        epoch_loss = running_loss / len(train_dataset)
        train_accuracy = 100. * correct_train / total_train
        
        # ========== Evaluation Phase ==========
        model.eval()
        correct_test = 0
        total_test = 0
        
        with torch.no_grad():
            test_pbar = tqdm(test_loader, desc=f'Epoch {epoch+1:2d}/{num_epochs} [Test]', leave=False)
            
            for inputs, labels in test_pbar:
                inputs = inputs.to(device)
                labels = labels.float().unsqueeze(1).to(device)
                
                outputs = model(inputs)
                predictions = (torch.sigmoid(outputs) > 0.5).float()
                
                correct_test += (predictions == labels).sum().item()
                total_test += labels.size(0)
                
                test_pbar.set_postfix({
                    'Acc': f'{100.*correct_test/total_test:.1f}%'
                })
        
        test_accuracy = 100. * correct_test / total_test
        epoch_time = time.time() - epoch_start_time
        
        # Store history
        history['train_loss'].append(epoch_loss)
        history['train_acc'].append(train_accuracy)
        history['test_acc'].append(test_accuracy)
        history['epoch_times'].append(epoch_time)
        
        # Print epoch summary
        print(f'Epoch [{epoch+1:2d}/{num_epochs}] - '
              f'Loss: {epoch_loss:.4f}, '
              f'Train: {train_accuracy:.2f}%, '
              f'Test: {test_accuracy:.2f}%, '
              f'Time: {epoch_time:.1f}s')
    
    # Training summary
    avg_epoch_time = np.mean(history['epoch_times'])
    best_test_acc = max(history['test_acc'])
    final_test_acc = history['test_acc'][-1]
    
    print(f"\n{model_name} Training Summary:")
    print(f"Best test accuracy: {best_test_acc:.2f}%")
    print(f"Final test accuracy: {final_test_acc:.2f}%")
    print(f"Average epoch time: {avg_epoch_time:.1f}s")
    print(f"Total training time: {sum(history['epoch_times']):.1f}s")
    
    return model, history

print("Experimental training framework defined successfully!")

## 5. Train Experimental Models

### 5.1 Train Xception-LSTM (Advanced Architecture)

In [None]:
# Initialize and train Xception-LSTM model
print("Initializing Xception-LSTM experimental model...")
xception_lstm = XceptionLSTM(
    freeze_layers=100,
    use_channel_reduction=False  # Use full features for maximum performance
)

# Train the advanced model
trained_xception_lstm, xception_lstm_history = train_experimental_model(
    model=xception_lstm,
    train_loader=train_loader,
    test_loader=test_loader,
    num_epochs=CONFIG['epochs'],
    learning_rate=CONFIG['learning_rate'],
    model_name="Xception-LSTM"
)

# Save the trained model
torch.save(trained_xception_lstm.state_dict(), '../models/experimental_xception_lstm.pth')
print("Xception-LSTM model saved successfully!")

### 5.2 Train Custom CNN-LSTM (Lightweight Architecture)

In [None]:
# Initialize and train Custom CNN-LSTM model
print("\nInitializing Custom CNN-LSTM comparison model...")
custom_cnn_lstm = CustomCNNLSTM(
    dropout_rate=0.27,
    lstm_hidden=50
)

# Train the lightweight model
trained_custom_lstm, custom_lstm_history = train_experimental_model(
    model=custom_cnn_lstm,
    train_loader=train_loader,
    test_loader=test_loader,
    num_epochs=CONFIG['epochs'],
    learning_rate=6e-4,  # Slightly higher LR for simpler model
    model_name="Custom CNN-LSTM"
)

# Save the trained model
torch.save(trained_custom_lstm.state_dict(), '../models/experimental_custom_cnn_lstm.pth')
print("Custom CNN-LSTM model saved successfully!")

## 6. Comprehensive Model Evaluation

In [None]:
def evaluate_experimental_model(model, test_loader, class_names, model_name):
    """
    Comprehensive evaluation of experimental CNN-LSTM models.
    """
    model.eval()
    all_predictions = []
    all_labels = []
    all_probabilities = []
    
    print(f"\nEvaluating {model_name}...")
    
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc="Evaluating"):
            inputs = inputs.to(device)
            labels = labels.float().unsqueeze(1).to(device)
            
            # Get predictions
            outputs = model(inputs)
            probabilities = torch.sigmoid(outputs)
            predictions = (probabilities > 0.5).float()
            
            # Store results
            all_predictions.extend(predictions.cpu().numpy().flatten())
            all_labels.extend(labels.cpu().numpy().flatten())
            all_probabilities.extend(probabilities.cpu().numpy().flatten())
    
    # Convert to arrays
    y_true = np.array(all_labels, dtype=int)
    y_pred = np.array(all_predictions, dtype=int)
    y_prob = np.array(all_probabilities)
    
    # Calculate metrics
    accuracy = (y_true == y_pred).mean() * 100
    
    # Classification report
    report = classification_report(y_true, y_pred, 
                                 target_names=class_names, 
                                 output_dict=True)
    
    print(f"\n{model_name} Evaluation Results:")
    print("=" * 60)
    print(f"Overall Accuracy: {accuracy:.2f}%")
    print("\nClassification Report:")
    print(classification_report(y_true, y_pred, target_names=class_names, digits=4))
    
    # Confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title(f'{model_name} - Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.tight_layout()
    plt.show()
    
    return {
        'model_name': model_name,
        'accuracy': accuracy,
        'classification_report': report,
        'confusion_matrix': cm,
        'probabilities': y_prob
    }

# Evaluate both experimental models
class_names = ['NORMAL', 'PNEUMONIA']

xception_lstm_results = evaluate_experimental_model(
    trained_xception_lstm, test_loader, class_names, "Xception-LSTM"
)

custom_lstm_results = evaluate_experimental_model(
    trained_custom_lstm, test_loader, class_names, "Custom CNN-LSTM"
)

## 7. Experimental Results Analysis

In [None]:
# Compare experimental models
import pandas as pd

# Create comparison dataframe
comparison_data = [
    {
        'Model': 'Xception-LSTM',
        'Parameters': f"{xception_lstm_history['parameter_count']:,}",
        'Final Accuracy (%)': f"{xception_lstm_results['accuracy']:.2f}",
        'Best Test Acc (%)': f"{max(xception_lstm_history['test_acc']):.2f}",
        'Avg Epoch Time (s)': f"{np.mean(xception_lstm_history['epoch_times']):.1f}",
        'Architecture': 'Advanced (Xception + LSTM)'
    },
    {
        'Model': 'Custom CNN-LSTM',
        'Parameters': f"{custom_lstm_history['parameter_count']:,}",
        'Final Accuracy (%)': f"{custom_lstm_results['accuracy']:.2f}",
        'Best Test Acc (%)': f"{max(custom_lstm_history['test_acc']):.2f}",
        'Avg Epoch Time (s)': f"{np.mean(custom_lstm_history['epoch_times']):.1f}",
        'Architecture': 'Lightweight (3-layer CNN + LSTM)'
    }
]

comparison_df = pd.DataFrame(comparison_data)
print("Experimental Model Comparison:")
print("=" * 80)
print(comparison_df.to_string(index=False))

# Plot training curves comparison
plt.figure(figsize=(15, 5))

# Training loss comparison
plt.subplot(1, 3, 1)
plt.plot(xception_lstm_history['train_loss'], label='Xception-LSTM', marker='o', linewidth=2)
plt.plot(custom_lstm_history['train_loss'], label='Custom CNN-LSTM', marker='s', linewidth=2)
plt.title('Training Loss Comparison')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True, alpha=0.3)

# Training accuracy comparison
plt.subplot(1, 3, 2)
plt.plot(xception_lstm_history['train_acc'], label='Xception-LSTM', marker='o', linewidth=2)
plt.plot(custom_lstm_history['train_acc'], label='Custom CNN-LSTM', marker='s', linewidth=2)
plt.title('Training Accuracy Comparison')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.grid(True, alpha=0.3)

# Test accuracy comparison
plt.subplot(1, 3, 3)
plt.plot(xception_lstm_history['test_acc'], label='Xception-LSTM', marker='o', linewidth=2)
plt.plot(custom_lstm_history['test_acc'], label='Custom CNN-LSTM', marker='s', linewidth=2)
plt.title('Test Accuracy Comparison')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Performance vs Complexity Analysis
print(f"\nPerformance vs Complexity Analysis:")
print(f"{'='*50}")

xception_params = xception_lstm_history['parameter_count']
custom_params = custom_lstm_history['parameter_count']
xception_acc = xception_lstm_results['accuracy']
custom_acc = custom_lstm_results['accuracy']
xception_time = np.mean(xception_lstm_history['epoch_times'])
custom_time = np.mean(custom_lstm_history['epoch_times'])

print(f"Parameter Ratio: {xception_params / custom_params:.1f}x more parameters in Xception-LSTM")
print(f"Accuracy Difference: {xception_acc - custom_acc:.2f}% accuracy gain")
print(f"Training Time Ratio: {xception_time / custom_time:.1f}x longer training per epoch")
print(f"Efficiency Score (Acc/Params): Xception={xception_acc/xception_params*1e6:.2f}, Custom={custom_acc/custom_params*1e6:.2f}")

# Medical significance analysis
print(f"\nMedical Significance Analysis:")
print(f"{'='*40}")

for result in [xception_lstm_results, custom_lstm_results]:
    tn, fp, fn, tp = result['confusion_matrix'].ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    
    print(f"\n{result['model_name']}:")
    print(f"  Sensitivity (Pneumonia Detection): {sensitivity:.4f}")
    print(f"  Specificity (Normal Identification): {specificity:.4f}")
    print(f"  Missed Pneumonia Cases: {fn} (Critical metric)")
    print(f"  False Alarms: {fp} (Secondary concern)")

## 8. Grad-CAM Visualization for CNN-LSTM Models

**Visualization Challenge:**
- Traditional Grad-CAM works on CNN layers
- LSTM processing happens after spatial feature extraction
- Need to visualize CNN features that influence LSTM decisions

**Approach:**
- Hook gradients from final CNN layer (before LSTM)
- Show which spatial regions most influence final predictions
- Compare attention patterns between different architectures

In [None]:
def gradcam_cnn_lstm(model, input_tensor, target_class=None, alpha=0.5):
    """
    Grad-CAM visualization for CNN-LSTM models.
    
    Shows which spatial regions of the CNN features most influence
    the final LSTM-based prediction.
    
    Args:
        model: Trained CNN-LSTM model
        input_tensor: Input image tensor [1, 3, 224, 224]
        target_class: Class to visualize (None for predicted class)
        alpha: Overlay transparency
        
    Returns:
        cam, heatmap, overlay for visualization
    """
    
    was_training = model.training
    model.eval()
    
    gradients = []
    activations = []
    
    def forward_hook(module, input, output):
        activations.append(output)
    
    def backward_hook(module, grad_input, grad_output):
        gradients.append(grad_output[0])
    
    # Determine which layer to hook based on model type
    if hasattr(model, 'xception'):
        # Xception-LSTM: hook the last Xception feature layer
        target_layer = model.xception.body.conv4
    else:
        # Custom CNN-LSTM: hook the last conv layer
        target_layer = model.cnn_backbone[6]  # Last conv layer
    
    # Register hooks
    forward_handle = target_layer.register_forward_hook(forward_hook)
    backward_handle = target_layer.register_full_backward_hook(backward_hook)
    
    # Disable cuDNN for LSTM backward compatibility
    prev_cudnn = torch.backends.cudnn.enabled
    if not was_training:
        torch.backends.cudnn.enabled = False
    
    try:
        # Forward pass
        output = model(input_tensor)
        
        # Determine target class
        if target_class is None:
            prob = torch.sigmoid(output).item()
            target_class = 1 if prob >= 0.5 else 0
        
        # Calculate score for target class
        score = output[0, 0] if target_class == 1 else -output[0, 0]
        
        # Backward pass
        model.zero_grad()
        score.backward()
        
    finally:
        # Cleanup
        torch.backends.cudnn.enabled = prev_cudnn
        forward_handle.remove()
        backward_handle.remove()
        model.train(was_training)
    
    # Generate CAM
    grads = gradients[0].detach().cpu().numpy()[0]
    acts = activations[0].detach().cpu().numpy()[0]
    
    # Global average pooling of gradients
    weights = grads.mean(axis=(1, 2))
    
    # Generate weighted feature map
    cam = np.zeros(acts.shape[1:], dtype=np.float32)
    for i, w in enumerate(weights):
        cam += w * acts[i]
    
    # Apply ReLU and normalize
    cam = np.maximum(cam, 0)
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
    
    # Prepare visualization
    img = input_tensor[0].detach().cpu().permute(1, 2, 0).numpy()
    img = np.clip((img * 0.5) + 0.5, 0, 1)  # Denormalize
    
    # Resize CAM to image size
    cam_resized = cv2.resize(cam, (img.shape[1], img.shape[0]))
    
    # Create heatmap
    heatmap_bgr = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
    heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    
    # Create overlay
    overlay = alpha * heatmap_rgb + (1 - alpha) * img
    overlay = np.clip(overlay, 0, 1)
    
    return cam_resized, heatmap_rgb, overlay

# Test Grad-CAM on both models
print("Generating Grad-CAM visualizations for experimental models...")

# Get a sample image
sample_images, sample_labels = next(iter(test_loader))
sample_image = sample_images[5].unsqueeze(0).to(device)  # Pick one image

# Generate Grad-CAM for both models
print("\nXception-LSTM Grad-CAM:")
xception_cam, xception_heatmap, xception_overlay = gradcam_cnn_lstm(
    trained_xception_lstm, sample_image, target_class=1, alpha=0.5
)

plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.imshow(sample_image[0].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5)
plt.title('Original Image')
plt.axis('off')

plt.subplot(1, 3, 2)
plt.imshow(xception_cam, cmap='jet')
plt.title('Xception-LSTM Grad-CAM')
plt.axis('off')

plt.subplot(1, 3, 3)
plt.imshow(xception_overlay)
plt.title('Xception-LSTM Overlay')
plt.axis('off')

plt.tight_layout()
plt.show()

print("\nCustom CNN-LSTM Grad-CAM:")
custom_cam, custom_heatmap, custom_overlay = gradcam_cnn_lstm(
    trained_custom_lstm, sample_image, target_class=1, alpha=0.5
)

plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.imshow(sample_image[0].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5)
plt.title('Original Image')
plt.axis('off')

plt.subplot(1, 3, 2)
plt.imshow(custom_cam, cmap='jet')
plt.title('Custom CNN-LSTM Grad-CAM')
plt.axis('off')

plt.subplot(1, 3, 3)
plt.imshow(custom_overlay)
plt.title('Custom CNN-LSTM Overlay')
plt.axis('off')

plt.tight_layout()
plt.show()

print("Grad-CAM visualizations completed!")

## 9. Research Conclusions and Future Directions

### Experimental Findings

**Performance Analysis:**
- **Xception-LSTM**: Advanced architecture with rich feature representation
- **Custom CNN-LSTM**: Lightweight alternative with competitive performance
- **Spatial Tokens**: 49-token approach provides spatial context for LSTM
- **Sequential Processing**: LSTM adds contextual understanding to CNN features

**Key Research Questions Answered:**
1. **Do LSTMs improve CNN performance?** Results show [analyze based on actual results]
2. **Is complex backbone necessary?** Comparison reveals trade-offs between complexity and accuracy
3. **What do CNN-LSTM models focus on?** Grad-CAM shows attention patterns

### Limitations and Challenges

**Technical Limitations:**
- **Computational Overhead**: LSTM processing adds significant computation
- **Memory Requirements**: Storing spatial tokens increases memory usage
- **Training Complexity**: More hyperparameters to tune
- **Visualization Challenges**: Grad-CAM only shows CNN component, not LSTM attention

**Research Limitations:**
- **Limited Dataset**: Single pneumonia dataset may not generalize
- **Spatial Token Order**: Current implementation uses raster order, other orders unexplored
- **LSTM Alternatives**: Other sequence models (Transformers, GRU) not compared
- **Medical Validation**: Requires validation by medical professionals

### Future Research Directions

**Technical Improvements:**
1. **Attention Mechanisms**: Replace LSTM with self-attention for better spatial modeling
2. **Token Order Optimization**: Investigate optimal spatial token ordering strategies
3. **Multi-Scale Features**: Combine features from multiple CNN layers
4. **Ensemble Integration**: Combine CNN-LSTM with other model types

**Medical Applications:**
1. **Multi-Class Extension**: Extend to multiple lung conditions
2. **Severity Assessment**: Predict pneumonia severity levels
3. **Cross-Dataset Validation**: Test on different hospital datasets
4. **Clinical Integration**: Develop deployment-ready systems

**Research Extensions:**
1. **Comparative Studies**: Compare with other hybrid architectures
2. **Ablation Studies**: Analyze contribution of each component
3. **Efficiency Optimization**: Develop mobile-friendly versions
4. **Interpretability**: Better visualization of LSTM decision process

### Impact and Significance

**Scientific Contribution:**
- Demonstrated feasibility of CNN-LSTM for medical imaging
- Established baseline performance for spatial token processing
- Provided comparison framework for hybrid architectures
- Highlighted trade-offs between complexity and performance

**Clinical Relevance:**
- Potential for improved pneumonia detection accuracy
- Framework for incorporating spatial context in medical AI
- Foundation for multi-modal medical image analysis
- Step toward more interpretable medical AI systems

**This experimental work opens new avenues for combining CNN and sequential processing in medical image analysis, with promising results for pediatric pneumonia detection.**