# Malaria Geographic Origin Model Comparison

This notebook compares the performance of different models for classifying malaria parasites by their geographic origin:
1. **Multinomial Naive Bayes**: A classical machine learning approach
2. **Standard CNN**: A basic convolutional neural network
3. **Advanced CNN with Strand Symmetry**: Enhanced CNN with reverse complement equivariance
4. **Advanced CNN without Strand Symmetry**: Enhanced CNN without reverse complement equivariance

A key focus of this analysis is understanding the impact of reverse complement equivariance (strand symmetry) on model performance.

In [None]:
# Import necessary libraries
import sys
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import time
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import classification_report, confusion_matrix

# Put the parent directory on the import path; the `src.*` imports in the
# next cell presumably resolve against it (assumes the notebook sits one
# level below the project root -- confirm).
project_root = '..'
sys.path.append(project_root)

# One shared look for every figure in this notebook
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('viridis')

# Later cells save figures into this folder, so create it up front
figures_dir = '../reports/figures'
os.makedirs(figures_dir, exist_ok=True)

# Import Project Modules

In [None]:
# Import project modules
from src.models.cnn_standard import DNACNN as StandardCNN  
from src.models.cnn_advanced import DNACNN as AdvancedCNN
from src.models.naive_bayes import MultinomialNaiveBayes
from src.data.genomic_sequences import GenomicSequenceDataset
from src.evaluation.model_evaluator import evaluate_model_detailed
from src.evaluation.model_comparison import (
    compare_models, 
    calculate_roc_curves, 
    analyze_complexity_tradeoff,
    analyze_per_class_performance,
    identify_challenging_classes,
    identify_best_model_per_class,
    calculate_model_size
)
from src.evaluation.compare_strand_symmetry import (
    create_strand_symmetric_model,
    evaluate_strand_symmetry_effect,
    plot_strand_symmetry_comparison,
    run_strand_symmetry_analysis
)
from src.visualization.performance_visualizer import (
    plot_confusion_matrix, 
    plot_roc_curves, 
    plot_metrics_comparison, 
    plot_complexity_tradeoff,
    plot_per_class_performance,
    plot_training_history,
    plot_model_architecture_comparison,
    plot_challenging_classes
)

# Load Test Data

In [None]:
# Build the test split; the window/stride/cache settings are passed straight
# through to the project's dataset class.
dataset_options = dict(
    split_dir="../data/split",
    split_type="test",
    window_size=1000,
    stride=500,
    cache_size=128,
)
test_dataset = GenomicSequenceDataset(**dataset_options)

# No shuffling: every model below iterates the test set in the same order
test_loader = test_dataset.get_dataloader(batch_size=32, shuffle=False, num_workers=4)

# Class names are read from the dataset's label encoder
class_names = test_dataset.encoder.classes_
num_classes = len(class_names)
print(f"Number of classes: {num_classes}")
print(f"Class names: {class_names}")

# Load Trained Models

In [None]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


def _load_or_none(loader, path, description):
    """Load a saved model, or return None when its file does not exist.

    Args:
        loader: Callable that takes ``path`` and returns the loaded model.
        path: Location of the saved model file.
        description: Model name used in the status messages, written as it
            should appear mid-sentence (e.g. ``'standard CNN model'``).

    Returns:
        The loaded model, or ``None`` if ``path`` is missing. Any error other
        than ``FileNotFoundError`` propagates to the caller.
    """
    try:
        model = loader(path)
    except FileNotFoundError:
        # Capitalise only the first character so names such as "CNN" survive.
        print(f"{description[0].upper()}{description[1:]} not found. Please train the model first.")
        return None
    print(f"Loaded {description}")
    return model


def _load_torch_checkpoint(path):
    """Load a fully pickled PyTorch model onto the selected device.

    The loaded objects are handed to the evaluation code as models, so the
    files must contain whole modules rather than state dicts. torch >= 2.6
    defaults to ``weights_only=True``, which rejects such files, hence the
    explicit ``weights_only=False``.

    NOTE: this unpickles arbitrary Python objects -- only load checkpoints
    that you produced yourself or otherwise trust.
    """
    return torch.load(path, map_location=device, weights_only=False)


# Load CNN models
standard_cnn = _load_or_none(
    _load_torch_checkpoint, '../models/cnn_standard_best.pt', 'standard CNN model'
)

# Load Advanced CNN with strand symmetry
advanced_cnn_symmetric = _load_or_none(
    _load_torch_checkpoint,
    '../models/cnn_advanced_symmetric.pt',
    'advanced CNN model with strand symmetry',
)

# Load Advanced CNN without strand symmetry
advanced_cnn_standard = _load_or_none(
    _load_torch_checkpoint,
    '../models/cnn_advanced_standard.pt',
    'advanced CNN model without strand symmetry',
)

# Load Naive Bayes model
naive_bayes = _load_or_none(
    MultinomialNaiveBayes.load, '../models/naive_bayes_model.pkl', 'Naive Bayes model'
)

# Evaluate CNN Models

In [None]:
# Loss passed to the evaluator for every neural model
criterion = torch.nn.CrossEntropyLoss()

# One result record per evaluated model; later cells read from this list
all_model_results = []

# Positional arguments shared by every evaluation call
shared_eval_args = (test_loader, criterion, device)

# Standard CNN
if standard_cnn is not None:
    print("Evaluating Standard CNN...")
    standard_cnn_results = evaluate_model_detailed(
        standard_cnn, *shared_eval_args, model_name='CNN (Standard)'
    )
    all_model_results.append(standard_cnn_results)
    standard_accuracy = standard_cnn_results['accuracy']
    print(f"Standard CNN Accuracy: {standard_accuracy:.4f}")

# Advanced CNN, strand-symmetric variant
if advanced_cnn_symmetric is not None:
    print("\nEvaluating Advanced CNN with Strand Symmetry...")
    advanced_symmetric_results = evaluate_model_detailed(
        advanced_cnn_symmetric, *shared_eval_args, model_name='CNN (Advanced, Symmetric)'
    )
    all_model_results.append(advanced_symmetric_results)
    symmetric_accuracy = advanced_symmetric_results['accuracy']
    print(f"Advanced CNN (Symmetric) Accuracy: {symmetric_accuracy:.4f}")

# Advanced CNN, variant without strand symmetry
if advanced_cnn_standard is not None:
    print("\nEvaluating Advanced CNN without Strand Symmetry...")
    advanced_standard_results = evaluate_model_detailed(
        advanced_cnn_standard, *shared_eval_args, model_name='CNN (Advanced, Standard)'
    )
    all_model_results.append(advanced_standard_results)
    non_symmetric_accuracy = advanced_standard_results['accuracy']
    print(f"Advanced CNN (Standard) Accuracy: {non_symmetric_accuracy:.4f}")

# Evaluate Naive Bayes Model

In [None]:
# Naive Bayes is not a PyTorch model, so it is scored here directly with
# scikit-learn metrics and packed into the same result layout that the
# comparison cells read for the other models.
if naive_bayes is not None:
    print("\nEvaluating Naive Bayes...")

    # Collect the test set batch by batch. Each sequence is flattened to a
    # single feature row, a whole batch at a time rather than per sample.
    feature_batches = []
    label_batches = []
    for batch in test_loader:
        sequences = batch['sequence'].cpu().numpy()
        feature_batches.append(sequences.reshape(len(sequences), -1))
        label_batches.append(batch['label'].cpu().numpy())

    if not feature_batches:
        # Nothing to score; also avoids dividing by zero when timing below.
        print("Test loader yielded no batches; skipping Naive Bayes evaluation.")
    else:
        X_test = np.concatenate(feature_batches)
        y_test = np.concatenate(label_batches)

        # Time the Naive Bayes inference (perf_counter is monotonic, so the
        # measurement cannot be skewed by system clock adjustments)
        start_time = time.perf_counter()
        nb_predictions = naive_bayes.predict(X_test)
        nb_inference_time = (time.perf_counter() - start_time) / len(X_test)

        # Calculate metrics for Naive Bayes
        nb_accuracy = accuracy_score(y_test, nb_predictions)
        nb_precision = precision_score(y_test, nb_predictions, average='weighted', zero_division=0)
        nb_recall = recall_score(y_test, nb_predictions, average='weighted', zero_division=0)
        nb_f1 = f1_score(y_test, nb_predictions, average='weighted', zero_division=0)
        nb_class_report = classification_report(
            y_test, nb_predictions, output_dict=True, zero_division=0
        )
        nb_conf_matrix = confusion_matrix(y_test, nb_predictions)

        # Probabilities are optional: keep going without them if the model
        # cannot provide them, but say why. Catching Exception (not a bare
        # except) lets KeyboardInterrupt / SystemExit through.
        try:
            nb_probabilities = naive_bayes.predict_proba(X_test)
        except Exception as exc:
            print(f"Naive Bayes probabilities unavailable ({type(exc).__name__}: {exc})")
            nb_probabilities = None

        # Store Naive Bayes results
        naive_bayes_results = {
            'model_name': 'Naive Bayes',
            'accuracy': nb_accuracy,
            'precision': nb_precision,
            'recall': nb_recall,
            'f1_score': nb_f1,
            'test_loss': 0.0,  # Naive Bayes doesn't have a loss
            'class_report': nb_class_report,
            'confusion_matrix': nb_conf_matrix,
            'predictions': nb_predictions,
            'true_labels': y_test,
            'probabilities': nb_probabilities,
            'avg_inference_time': nb_inference_time,
            # Approximation: prediction is timed once for the whole set, so
            # every sample is assigned the same average.
            'inference_times': [nb_inference_time] * len(y_test),
        }

        all_model_results.append(naive_bayes_results)
        print(f"Naive Bayes Accuracy: {nb_accuracy:.4f}")

# Compare Model Metrics

In [None]:
# Compare models
# metrics_df stays None when no model was evaluated; the cells that consume
# it are guarded by checks on all_model_results.
metrics_df = None
if all_model_results:
    comparison_results = compare_models(all_model_results)
    metrics_df = comparison_results['metrics_table']

# Display metrics table. Jupyter only auto-renders an expression when it is
# the last top-level statement of the cell, so it must sit outside the `if`
# (a None value renders nothing).
metrics_df

# Analyze Strand Symmetry Effect

This section measures the impact of reverse complement equivariance (strand symmetry) on model performance.

In [None]:
# Run strand symmetry analysis using dedicated function
symmetry_results = run_strand_symmetry_analysis(test_loader, device)

# Table column label -> key in each model's result record
metric_keys = {
    'Accuracy': 'accuracy',
    'Precision': 'precision',
    'Recall': 'recall',
    'F1 Score': 'f1_score',
}

# Stays None when fewer than two models came back, so the final display
# line renders nothing in that case.
sym_comparison = None
if len(symmetry_results) >= 2:
    # NOTE(review): assumes the analysis returns the symmetric model first
    # and the standard model second -- confirm against
    # run_strand_symmetry_analysis.
    sym_model, std_model = list(symmetry_results)[:2]

    # Create comparison table: one row per metric, one column per model
    sym_comparison = pd.DataFrame({
        'Metric': list(metric_keys),
        sym_model: [symmetry_results[sym_model][key] for key in metric_keys.values()],
        std_model: [symmetry_results[std_model][key] for key in metric_keys.values()],
    })

    # Relative change of the symmetric model over the standard one, in percent.
    # A zero baseline yields inf/NaN here rather than raising.
    sym_comparison['Improvement (%)'] = (
        (sym_comparison[sym_model] - sym_comparison[std_model])
        / sym_comparison[std_model] * 100
    ).round(2)
else:
    print("Not enough models to compare strand symmetry effect")

# Display table. Jupyter only auto-renders an expression when it is the last
# top-level statement of the cell, so it must sit outside the if/else.
sym_comparison

# Analyze Class-Specific Strand Symmetry Impact

Let's examine whether strand symmetry has a larger impact on certain geographic regions.

In [None]:
# Compare per-class performance between symmetry and non-symmetry models

# Rows of a scikit-learn classification report that summarise all classes
# rather than describing a single one.
REPORT_SUMMARY_ROWS = {'accuracy', 'macro avg', 'weighted avg', 'micro avg', 'samples avg'}


def _class_label(class_key):
    """Map a classification-report row key to a readable class name.

    Numeric keys are looked up in ``class_names``. Keys that are not numeric
    (the report was already keyed by label) are returned unchanged, and
    numeric keys outside ``class_names`` become ``"Class <key>"``.
    """
    try:
        class_index = int(class_key)
    except (TypeError, ValueError):
        return str(class_key)
    # The lower bound stops a negative key from silently indexing from the end.
    if class_names is not None and 0 <= class_index < len(class_names):
        return class_names[class_index]
    return f"Class {class_key}"


# Stays None when there is nothing to compare, so the final display line
# renders nothing in that case.
most_improved_preview = None
if len(symmetry_results) >= 2:
    # Get per-class F1 scores from both models, in long format
    per_class_sym = [
        {
            'Model': model_name,
            'Class': _class_label(class_key),
            'F1 Score': metrics['f1-score'],
        }
        for model_name, result in symmetry_results.items()
        for class_key, metrics in result['class_report'].items()
        if class_key not in REPORT_SUMMARY_ROWS
    ]

    # Convert to DataFrame
    per_class_df = pd.DataFrame(per_class_sym)

    # Pivot to wide format for comparison: one row per class, one column per model
    pivot_df = per_class_df.pivot(index='Class', columns='Model', values='F1 Score')

    # NOTE(review): assumes the symmetric model is the first key and the
    # standard model the second -- confirm against run_strand_symmetry_analysis.
    sym_model, std_model = list(symmetry_results)[:2]

    # Relative F1 change in percent. A class whose baseline F1 is 0 yields
    # inf/NaN here rather than raising.
    pivot_df['Improvement (%)'] = (
        (pivot_df[sym_model] - pivot_df[std_model]) / pivot_df[std_model] * 100
    ).round(2)

    # Sort by improvement percentage
    pivot_df = pivot_df.sort_values('Improvement (%)', ascending=False)

    # Top 10 most improved classes
    print("Classes with largest improvement from strand symmetry:")
    most_improved_preview = pivot_df.head(10)
else:
    print("Not enough models to compare per-class strand symmetry effect")

# Jupyter only auto-renders an expression when it is the last top-level
# statement of the cell, so the table is shown from here.
most_improved_preview

# Visualize Strand Symmetry Effect

In [None]:
# Draw the strand symmetry comparison, then keep a copy on disk
has_symmetry_pair = len(symmetry_results) >= 2
if has_symmetry_pair:
    fig = plot_strand_symmetry_comparison(symmetry_results)
    plt.show()

    # Written at 300 dpi into the figures folder
    fig.savefig(
        '../reports/figures/strand_symmetry_effect.png',
        dpi=300,
    )

# Analyze Improvement Percentages

In [None]:
# Display performance improvement
# Stays None unless at least two models were evaluated, so the final
# display line renders nothing in that case.
improvement_metrics = None
if len(all_model_results) > 1:
    improvement_columns = ['Model', 'Accuracy Improvement (%)', 'F1 Score Improvement (%)']
    # Rows with a missing improvement value are dropped before rounding
    improvement_metrics = metrics_df[improvement_columns].dropna().round(2)

# Jupyter only auto-renders an expression when it is the last top-level
# statement of the cell, so it must sit outside the `if`.
improvement_metrics

# Visualize Confusion Matrices

In [None]:
# One confusion matrix per evaluated model, shown and then saved.

# Turns a display name into a file stem: spaces become underscores,
# parentheses and commas are removed.
filename_table = str.maketrans({' ': '_', '(': None, ')': None, ',': None})

for result in all_model_results:
    matrix_title = f'{result["model_name"]} Confusion Matrix'
    fig = plot_confusion_matrix(
        result['true_labels'],
        result['predictions'],
        class_names=class_names,
        title=matrix_title,
    )
    plt.show()

    # e.g. 'CNN (Advanced, Symmetric)' -> 'cnn_advanced_symmetric'
    model_name = result['model_name'].translate(filename_table).lower()
    fig.savefig(f'../reports/figures/{model_name}_confusion_matrix.png', dpi=300)

# Analyze Complexity vs. Performance Tradeoff

In [None]:
# Complexity vs. performance needs at least two models to compare
several_models = len(all_model_results) > 1
if several_models:
    tradeoff_results = analyze_complexity_tradeoff(all_model_results)

    # Show the tradeoff figure
    fig = plot_complexity_tradeoff(tradeoff_results)
    plt.show()

    # Keep a 300 dpi copy in the figures folder
    fig.savefig(
        '../reports/figures/complexity_tradeoff.png',
        dpi=300,
    )

# Visualize Performance Metrics Comparison

In [None]:
# Chart the metrics table; only meaningful with two or more models
several_models = len(all_model_results) > 1
if several_models:
    fig = plot_metrics_comparison(metrics_df)
    plt.show()

    # Keep a 300 dpi copy in the figures folder
    fig.savefig(
        '../reports/figures/metrics_comparison.png',
        dpi=300,
    )

# Analyze Per-Class Performance

In [None]:
# Analyze per-class performance
# Stays None when no model was evaluated, so the final display line renders
# nothing in that case.
per_class_preview = None
if all_model_results:
    per_class_results = analyze_per_class_performance(all_model_results, class_names)

    # Plot per-class F1 scores
    fig = plot_per_class_performance(per_class_results, metric='F1 Score')
    plt.show()

    # Save figure
    fig.savefig('../reports/figures/per_class_f1.png', dpi=300)

    # First 10 rows of the per-class results
    per_class_preview = per_class_results.head(10)

# Jupyter only auto-renders an expression when it is the last top-level
# statement of the cell, so the table is shown from here (after the figure).
per_class_preview

# Identify Best Model for Each Class

In [None]:
# Calculate which model performs best on each class
# Stays None unless at least two models were evaluated; the export cell only
# writes it under the same condition.
best_model_per_class = None
if len(all_model_results) > 1:
    best_model_per_class = identify_best_model_per_class(per_class_results)

# Jupyter only auto-renders an expression when it is the last top-level
# statement of the cell, so it must sit outside the `if`.
best_model_per_class

# Identify Challenging Classes

In [None]:
# Flag the classes that fall below the 0.7 threshold
if all_model_results:
    challenging_classes = identify_challenging_classes(
        per_class_results,
        threshold=0.7,
    )
    print(f"Challenging classes (F1 < 0.7): {challenging_classes}")

    # Only plot when at least one class was flagged
    if challenging_classes:
        fig = plot_challenging_classes(per_class_results, challenging_classes)
        plt.show()

        # Keep a 300 dpi copy in the figures folder
        fig.savefig(
            '../reports/figures/challenging_classes.png',
            dpi=300,
        )

# Summary of Findings

In [None]:
# Print summary of findings
if all_model_results:
    print("\n=== MODEL EVALUATION SUMMARY ===")
    for result in all_model_results:
        print(f"\n{result['model_name']}:")
        print(f"  - Accuracy: {result['accuracy']:.4f}")
        print(f"  - F1 Score: {result['f1_score']:.4f}")
        print(f"  - Inference Time: {result['avg_inference_time']*1000:.2f} ms per sample")

    # Cross-model comparisons only make sense with two or more models
    if len(all_model_results) > 1:
        # Find best model by accuracy (first one wins on ties)
        best_acc_idx = np.argmax([r['accuracy'] for r in all_model_results])
        best_acc_result = all_model_results[best_acc_idx]
        best_acc_model = best_acc_result['model_name']

        # Find fastest model by average per-sample inference time
        fastest_idx = np.argmin([r['avg_inference_time'] for r in all_model_results])
        fastest_result = all_model_results[fastest_idx]
        fastest_model = fastest_result['model_name']

        print(f"\nBest accuracy: {best_acc_model} ({best_acc_result['accuracy']:.4f})")
        print(f"Fastest inference: {fastest_model} ({fastest_result['avg_inference_time']*1000:.2f} ms)")

        # Print challenging classes
        if challenging_classes:
            print(f"\nChallenging classes: {', '.join(challenging_classes)}")

        # Print summary of strand symmetry effect
        if len(symmetry_results) >= 2:
            # NOTE(review): assumes the symmetric model is the first key and
            # the standard model the second -- confirm against
            # run_strand_symmetry_analysis.
            sym_model, std_model = list(symmetry_results)[:2]

            baseline_accuracy = symmetry_results[std_model]['accuracy']
            acc_diff = symmetry_results[sym_model]['accuracy'] - baseline_accuracy

            # A relative change is undefined for a zero baseline; report that
            # instead of raising ZeroDivisionError.
            if baseline_accuracy:
                acc_pct = acc_diff / baseline_accuracy * 100
                relative_change = f"{acc_pct:.2f}%"
            else:
                acc_pct = None
                relative_change = "n/a, baseline accuracy is 0"

            print("\nImpact of Strand Symmetry (Reverse Complement Equivariance):")
            print(f"  - Accuracy improvement: {acc_diff*100:.2f} percentage points ({relative_change})")

            # Show top 3 most improved classes; pivot_df comes from the
            # per-class strand symmetry cell.
            if 'Improvement (%)' in pivot_df.columns:
                top_improved = pivot_df.head(3)
                print("  - Classes with largest improvement from strand symmetry:")
                for idx, row in top_improved.iterrows():
                    print(f"    * {idx}: {row['Improvement (%)']:.2f}% improvement")

# Export Results for Report

In [None]:
# Write the result tables to CSV for the report
if all_model_results:
    # Headline metrics, one table for all models
    metrics_df.to_csv(
        '../reports/model_metrics.csv',
        index=False,
    )

    # Per-class breakdown
    per_class_results.to_csv(
        '../reports/per_class_metrics.csv',
        index=False,
    )

    # Best model per class exists only when several models were evaluated
    if len(all_model_results) > 1:
        best_model_per_class.to_csv(
            '../reports/best_model_per_class.csv',
            index=False,
        )

    # Strand symmetry tables need both variants to have been analysed
    if len(symmetry_results) >= 2:
        sym_comparison.to_csv(
            '../reports/strand_symmetry_comparison.csv',
            index=False,
        )
        # The index holds the class names here, so it is written out
        pivot_df.to_csv('../reports/strand_symmetry_per_class.csv')

    print("Results exported to CSV files in '../reports/' directory")