In [4]:
# Cell 1: Import Libraries and Configure Environment
"""
ASV Selection Machine Learning Pipeline
Author: Sarawut Ounjai
Last Updated: 2024
"""

# Standard libraries
import os
import sys
import time
import logging
import traceback
import shutil
import tempfile
import json
import warnings
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Any, Union
from logging.handlers import RotatingFileHandler

from joblib import dump, load, parallel_backend

# Data processing and ML libraries
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    accuracy_score, classification_report, precision_score,
    recall_score, f1_score, roc_curve, auc, confusion_matrix,
    balanced_accuracy_score
)
from joblib import parallel_backend
from scipy.cluster.hierarchy import linkage, dendrogram
from scipy.spatial.distance import pdist, squareform
import networkx as nx
from scipy.stats import chi2_contingency

# Visualization libraries
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
from sklearn.decomposition import PCA

from IPython.display import display, HTML, Markdown
import plotly.graph_objects as go
from sklearn.model_selection import cross_val_score

from IPython.display import display, Markdown

# Suppress warnings
# RuntimeWarning and UserWarning are silenced process-wide from this point on.
warnings.filterwarnings('ignore', category=RuntimeWarning)
warnings.filterwarnings('ignore', category=UserWarning)

# Cell 2: Configuration and Setup
# File paths
# NOTE(review): BASE_DIR is an absolute, machine-specific path; the pipeline
# will only find its input on a machine with this directory layout.
BASE_DIR = Path("/Users/sarawut/Library/CloudStorage/OneDrive-ImperialCollegeLondon/2024_R/R_analysis")
CHAPTER_DIR = BASE_DIR / "Chapter2_Data_generation/Barcoding_Machine_Learning"
# CSV read by load_data().
INPUT_FILE = CHAPTER_DIR / "Barcoding_Machine_Learning_OR_020425.csv"
# OUTPUT_DIR and MODEL_SAVE_DIR currently resolve to the same directory.
OUTPUT_DIR = CHAPTER_DIR / "Barcoding_Machine_Learning_OR_200425"
MODEL_SAVE_DIR = CHAPTER_DIR / "Barcoding_Machine_Learning_OR_200425"

# Create output directories
# Side effect at import/cell-execution time: directories are created on disk
# (no error if they already exist).
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
MODEL_SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Configure logging
def setup_logging(output_dir: Path) -> None:
    """Configure the root logger with a rotating file handler and a console handler.

    A timestamped log file ``asv_selection_<YYYYmmdd_HHMMSS>.log`` is created
    under ``output_dir / 'logs'`` (the directory is created if missing).
    The file handler records DEBUG and above; the console handler writes
    INFO and above to ``sys.stdout``.

    Any handlers already attached to the root logger are removed *and closed*,
    so calling this function repeatedly (e.g. re-running a notebook cell) does
    not leak open log-file handles.

    Args:
        output_dir: Directory under which the ``logs`` sub-directory is created.
    """
    log_dir = output_dir / 'logs'
    log_dir.mkdir(parents=True, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = log_dir / f'asv_selection_{timestamp}.log'

    formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # File handler: rotates at 10 MiB, keeping up to 5 backups.
    file_handler = RotatingFileHandler(
        str(log_file),
        maxBytes=10*1024*1024,
        backupCount=5,
        encoding='utf-8'
    )
    file_handler.setFormatter(formatter)
    file_handler.setLevel(logging.DEBUG)

    # Console handler (using StreamHandler for Jupyter compatibility)
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    console_handler.setLevel(logging.INFO)

    # Configure root logger
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG)

    # Remove existing handlers and release their resources. Iterate over a
    # copy because removeHandler mutates root_logger.handlers.
    for handler in root_logger.handlers[:]:
        root_logger.removeHandler(handler)
        handler.close()

    # Add new handlers
    root_logger.addHandler(file_handler)
    root_logger.addHandler(console_handler)

    logging.info(f"Logging initialized. Log file: {log_file}")

# Initialize logging
# Runs at import/cell-execution time: creates OUTPUT_DIR/logs and replaces
# whatever handlers are currently attached to the root logger.
setup_logging(OUTPUT_DIR)

# Cell 3: Data Loading Functions
def load_data() -> pd.DataFrame:
    """Load the input CSV via a temporary copy and check that it is non-empty.

    ``INPUT_FILE`` is copied into a fresh temporary directory and parsed from
    there with ``pandas.read_csv``. The temporary directory is always removed,
    whether loading succeeds or fails.

    Returns:
        The parsed data as a DataFrame.

    Raises:
        ValueError: If the CSV contains no data rows.
        Exception: Any error raised while copying or parsing is logged and
            re-raised unchanged.
    """
    try:
        # The context manager guarantees cleanup and, unlike a manual
        # mkdtemp()/finally pair, cannot hit an unbound variable if creating
        # the directory itself fails.
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file = Path(temp_dir) / 'input_data.csv'

            logging.info(f"Copying file to temporary location: {temp_file}")
            shutil.copy2(INPUT_FILE, temp_file)

            logging.info("Reading data from temporary file")
            df = pd.read_csv(temp_file)

        if df.empty:
            raise ValueError("Empty dataframe loaded")

        logging.info(f"Successfully loaded {len(df)} records")
        return df

    except Exception as e:
        logging.error(f"Error in data loading: {str(e)}")
        raise

# Cell 4: Data Preprocessing Functions
def preprocess_data(df: pd.DataFrame) -> Tuple[pd.DataFrame, List[str]]:
    """Derive model features and the binary target from the raw ASV table.

    Works on a copy of ``df``; the caller's frame is not modified.

    The ``match`` column is normalised to ``'match'`` / ``'no_match'``
    (missing values become ``'no_match'``; spellings that are not in the alias
    table become NaN). Read-based, per-sample and indicator features are then
    added, plus ``target`` (1 where ``autopropose == 'select'``).

    Returns:
        A tuple ``(processed_frame, feature_columns)`` where
        ``feature_columns`` names the columns intended as model inputs.

    Raises:
        Exception: Any failure is logged and re-raised.
    """
    try:
        logging.info("Starting data preprocessing...")
        data = df.copy()

        # Normalise the various spellings of the match flag.
        match_aliases = ('match', 'Match', 'TRUE', True)
        no_match_aliases = ('no_match', 'No_Match', 'nomatch', 'NoMatch',
                            'FALSE', False, 'no')
        normalised = {alias: 'match' for alias in match_aliases}
        normalised.update({alias: 'no_match' for alias in no_match_aliases})
        data['match'] = data['match'].fillna('no_match').map(normalised)

        reads = data['reads']
        asv_count = data['asv_count']

        # Read-based features. A zero total yields NaN rather than a
        # division by zero; a zero ASV count is treated as 1.
        safe_total = data['total_asv_reads'].replace(0, np.nan)
        data['read_proportion'] = reads / safe_total
        data['log_reads'] = np.log1p(reads)
        data['read_density'] = reads / asv_count.replace(0, 1)
        data['is_single_asv'] = (asv_count == 1).astype(int)
        data['is_dominant_asv'] = (data['read_proportion'] > 0.9).astype(int)

        # Features computed within each sample (project_readfile_id).
        reads_by_sample = data.groupby('project_readfile_id')['reads']
        data['reads_rank'] = reads_by_sample.rank(ascending=False)
        data['relative_abundance'] = reads_by_sample.transform(
            lambda sample_reads: sample_reads / sample_reads.sum()
        )
        data['is_match'] = (data['match'] == 'match').astype(int)

        # Binary target taken from the expert's proposal.
        data['target'] = (data['autopropose'] == 'select').astype(int)

        raw_features = ['reads', 'total_asv_reads', 'asv_count', 'percentage_reads']
        derived_features = [
            'read_proportion', 'log_reads', 'read_density', 'is_single_asv',
            'is_dominant_asv', 'reads_rank', 'relative_abundance', 'is_match'
        ]
        feature_columns = raw_features + derived_features

        return data, feature_columns

    except Exception as e:
        logging.error(f"Error in preprocessing: {str(e)}")
        raise

# Cell 5: Model Training Functions
def train_model(X: pd.DataFrame, y: pd.Series, feature_columns: List[str]):
    """Tune and fit a Random Forest classifier with a cross-validated grid search.

    A ``GridSearchCV`` over a small parameter grid is run with 5-fold
    stratified CV (shuffled, ``random_state=42``), scoring F1, precision and
    recall and refitting on the best F1.

    Args:
        X: Feature matrix used for fitting.
        y: Binary target aligned with ``X``.
        feature_columns: Feature names, in the same order as the columns of
            ``X``; used to label the importance table.

    Returns:
        A tuple ``(best_model, accuracy, importance, X, y)`` where
        ``accuracy`` is the accuracy of ``best_model`` on the *training* data
        ``X`` (not a held-out estimate) and ``importance`` is a DataFrame of
        feature importances sorted in descending order.

    Raises:
        Exception: Any failure is logged and re-raised.
    """
    try:
        logging.info("Training Random Forest model...")

        # Hyper-parameter grid explored by the search.
        param_grid = {
            'n_estimators': [100, 200, 300],
            'max_depth': [None],
            'min_samples_split': [2, 5],
            'class_weight': ['balanced'],
            'min_samples_leaf': [1]
        }

        # Initialize base model
        # NOTE(review): both the forest and the grid search request
        # n_jobs=-1, so parallelism is nested; confirm this is intended.
        rf = RandomForestClassifier(
            random_state=42,
            n_jobs=-1,
            verbose=0
        )

        # Set up grid search with cross-validation
        grid_search = GridSearchCV(
            estimator=rf,
            param_grid=param_grid,
            cv=StratifiedKFold(n_splits=5, shuffle=True, random_state=42),
            scoring=['f1', 'precision', 'recall'],
            refit='f1',
            n_jobs=-1,
            verbose=1
        )

        # Fit model using parallel processing
        with parallel_backend('threading'):
            grid_search.fit(X, y)

        # Get best model
        best_model = grid_search.best_estimator_

        # Calculate feature importance
        importance = pd.DataFrame({
            'feature': feature_columns,
            'importance': best_model.feature_importances_
        }).sort_values('importance', ascending=False)

        # Accuracy on the data the model was fitted on. This is an optimistic
        # resubstitution figure, so it is logged as training accuracy rather
        # than test accuracy.
        y_pred = best_model.predict(X)
        accuracy = accuracy_score(y, y_pred)

        # Log results
        logging.info("\nModel Training Results:")
        logging.info(f"Best parameters: {grid_search.best_params_}")
        logging.info(f"Cross-validation score: {grid_search.best_score_:.4f}")
        logging.info(f"Training accuracy: {accuracy:.4f}")

        return best_model, accuracy, importance, X, y

    except Exception as e:
        logging.error(f"Error in model training: {str(e)}")
        raise

def evaluate_model_stability(X: pd.DataFrame, y: pd.Series, n_repeats: int = 5):
    """Measure how much the cross-validated F1 varies with the forest's seed.

    For each seed ``0 .. n_repeats - 1`` a Random Forest with fixed
    hyper-parameters is scored with 5-fold cross-validation (F1). The per-seed
    mean scores are then summarised.

    Returns:
        A dict with ``mean_f1`` (mean of per-seed means), ``std_f1`` (their
        standard deviation) and ``cv_stability`` (std divided by mean).
    """
    logging.info("\nEvaluating model stability...")

    def _score_with_seed(seed: int) -> dict:
        # Only random_state changes between runs.
        forest = RandomForestClassifier(
            n_estimators=300,
            max_depth=20,
            min_samples_split=5,
            min_samples_leaf=2,
            class_weight='balanced',
            random_state=seed
        )
        fold_f1 = cross_val_score(forest, X, y, cv=5, scoring='f1')
        return {
            'run': seed + 1,
            'mean_f1': fold_f1.mean(),
            'std_f1': fold_f1.std()
        }

    per_run = pd.DataFrame([_score_with_seed(seed) for seed in range(n_repeats)])

    # Summarise the spread of the per-run mean F1 scores.
    run_means = per_run['mean_f1']
    overall_stability = {
        'mean_f1': run_means.mean(),
        'std_f1': run_means.std(),
        'cv_stability': run_means.std() / run_means.mean()
    }

    logging.info(f"Model stability metrics: {overall_stability}")
    return overall_stability

def save_model_components(model, scaler: StandardScaler, feature_columns: List[str],
                        optimal_threshold: float, MODEL_SAVE_DIR: Path) -> None:
    """Persist the trained model, the scaler and a JSON configuration.

    Three files are written into ``MODEL_SAVE_DIR`` (created if missing):
    ``trained_model.joblib``, ``scaler.joblib`` and ``model_config.json``.
    The configuration holds the feature names, the decision threshold, an
    ISO timestamp and ``model.get_params()``.

    The configuration is serialised to a string *before* any file is written,
    so a non-serialisable value fails fast instead of leaving a truncated
    ``model_config.json`` next to freshly written model files.

    Args:
        model: Fitted estimator exposing ``get_params()``.
        scaler: Scaler to store alongside the model.
        feature_columns: Feature names; any iterable of names is accepted
            and stored as a plain list of strings.
        optimal_threshold: Decision threshold, stored as a float.
        MODEL_SAVE_DIR: Target directory.

    Raises:
        Exception: Any failure is logged and re-raised.
    """
    try:
        logging.info("\nSaving model components...")

        # Build and serialise the configuration first (see docstring).
        config = {
            # Plain str/list so numpy arrays or pandas Index objects work too.
            'feature_columns': [str(name) for name in feature_columns],
            'optimal_threshold': float(optimal_threshold),
            'timestamp': datetime.now().isoformat(),
            'model_parameters': model.get_params()
        }
        config_json = json.dumps(config, indent=4)

        # Ensure directory exists
        MODEL_SAVE_DIR.mkdir(parents=True, exist_ok=True)

        # Save the trained model
        model_path = MODEL_SAVE_DIR / 'trained_model.joblib'
        dump(model, model_path)
        logging.info(f"Saved model to: {model_path}")

        # Save the scaler
        scaler_path = MODEL_SAVE_DIR / 'scaler.joblib'
        dump(scaler, scaler_path)
        logging.info(f"Saved scaler to: {scaler_path}")

        # Save configuration (feature columns and threshold)
        config_path = MODEL_SAVE_DIR / 'model_config.json'
        with open(config_path, 'w', encoding='utf-8') as f:
            f.write(config_json)
        logging.info(f"Saved configuration to: {config_path}")

    except Exception as e:
        logging.error(f"Error saving model components: {str(e)}")
        raise

def analyze_model_decisions(df: pd.DataFrame, model: RandomForestClassifier, 
                          X: pd.DataFrame, feature_columns: List[str]):
    """Summarise feature importances and simple rule-of-thumb pattern rates.

    Side effect: a ``model_prediction`` column holding ``model.predict(X)``
    is written into the caller's ``df``.

    Returns:
        A tuple ``(importances, patterns)``: ``importances`` is a DataFrame
        of feature importances sorted in descending order; ``patterns`` maps
        each criterion name to the fraction of rows in ``df`` satisfying it.
    """
    # Importance table, most important feature first.
    ranked = pd.DataFrame({
        'feature': feature_columns,
        'importance': model.feature_importances_
    })
    importances = ranked.sort_values('importance', ascending=False)

    # Record the model's raw predictions on the caller's frame.
    df['model_prediction'] = model.predict(X)

    # Fraction of rows meeting each threshold-style criterion.
    reads = df['reads']
    pct_reads = df['percentage_reads']
    matched = df['match'] == 'match'

    patterns = {}
    patterns['read_ge_4'] = (reads >= 4).mean()
    patterns['percent_ge_4'] = (pct_reads >= 4).mean()
    patterns['is_match'] = matched.mean()
    patterns['high_read_low_percent'] = ((reads > 10) & (pct_reads < 4)).mean()

    return importances, patterns

# Cell 6: Threshold Optimization Functions
def evaluate_threshold(df: pd.DataFrame, threshold: float, 
                      probabilities: np.ndarray, y_true: np.ndarray) -> Optional[Dict]:
    """Compute classification metrics for one probability cut-off.

    Args:
        df: Not used by this function; kept for call-site compatibility.
        threshold: Rows with ``probabilities >= threshold`` are predicted
            positive.
        probabilities: Predicted positive-class probabilities.
        y_true: Ground-truth binary labels (0/1) aligned with
            ``probabilities``.

    Returns:
        A dict with ``threshold``, ``precision``, ``recall``, ``f1_score``,
        ``accuracy`` and the four confusion-matrix counts, or ``None`` if an
        error occurred (the error is logged).
    """
    try:
        # Make predictions using threshold
        y_pred = (probabilities >= threshold).astype(int)

        # Calculate performance metrics
        metrics = {
            'threshold': float(threshold),
            'precision': float(precision_score(y_true, y_pred, zero_division=0)),
            'recall': float(recall_score(y_true, y_pred, zero_division=0)),
            'f1_score': float(f1_score(y_true, y_pred, zero_division=0)),
            'accuracy': float(accuracy_score(y_true, y_pred))
        }

        # Add confusion matrix metrics. labels=[0, 1] forces a 2x2 matrix
        # even when only one class occurs in y_true and y_pred (e.g. at
        # threshold 0 on all-positive data); without it ravel() would not
        # unpack into four values.
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
        metrics.update({
            'true_positives': int(tp),
            'false_positives': int(fp),
            'false_negatives': int(fn),
            'true_negatives': int(tn)
        })

        return metrics

    except Exception as e:
        logging.error(f"Error evaluating threshold {threshold}: {str(e)}")
        return None

def _best_threshold_in_grid(df: pd.DataFrame, thresholds, probabilities: np.ndarray,
                            y_true) -> Tuple[List[Dict], float]:
    """Evaluate each threshold in a grid; return (valid metrics, best-F1 threshold)."""
    stage_metrics = []
    for threshold in thresholds:
        metrics = evaluate_threshold(
            df=df,
            threshold=float(threshold),
            probabilities=probabilities,
            y_true=y_true
        )
        # evaluate_threshold returns None on failure; such entries cannot be
        # ranked and would break a DataFrame built from the stage results.
        if metrics is not None:
            stage_metrics.append(metrics)

    if not stage_metrics:
        raise ValueError("No threshold in the search grid could be evaluated")

    # max() keeps the first of equally good candidates, i.e. the lowest
    # threshold in an ascending grid.
    best = max(stage_metrics, key=lambda m: m['f1_score'])
    return stage_metrics, float(best['threshold'])


def find_optimal_threshold(df: pd.DataFrame, model: RandomForestClassifier, 
                         scaler: StandardScaler, feature_columns: List[str]) -> Tuple[pd.DataFrame, float, Dict]:
    """Find the F1-maximising probability threshold with a three-stage search.

    Probabilities are computed once from the scaled features. The search
    then narrows in three passes: 11 points over [0, 1], 21 points within
    +/-0.1 of the coarse optimum, and 21 points within +/-0.01 of the medium
    optimum (windows are clipped to [0, 1]). Ground truth is
    ``df['autopropose'] == 'select'``.

    Side effect: a ``prediction_probability`` column is written into the
    caller's ``df``.

    Returns:
        A tuple ``(df, optimal_threshold, threshold_results)`` where
        ``threshold_results`` maps ``'coarse_search'``, ``'medium_search'``
        and ``'fine_search'`` to the list of metric dicts from each stage.

    Raises:
        Exception: Any failure (including a stage with no evaluable
            threshold) is logged and re-raised.
    """
    try:
        logging.info("\nFinding optimal threshold...")

        # Calculate prediction probabilities
        X = df[feature_columns].copy()
        X_scaled = pd.DataFrame(
            scaler.transform(X),
            columns=feature_columns,
            index=X.index
        )
        probabilities = model.predict_proba(X_scaled)[:, 1]

        # Ground truth is the same for every threshold, so build it once.
        y_true = (df['autopropose'] == 'select').astype(int)

        threshold_results = {}

        def _window(centre: float, half_width: float) -> np.ndarray:
            # 21 evenly spaced thresholds around `centre`, clipped to [0, 1].
            return np.linspace(
                max(0, centre - half_width),
                min(1, centre + half_width),
                21
            )

        # Stage 1: Coarse search (0-1 in 0.1 steps)
        logging.info("Performing coarse threshold search...")
        threshold_results['coarse_search'], best_coarse = _best_threshold_in_grid(
            df, np.linspace(0, 1, 11), probabilities, y_true
        )

        # Stage 2: Medium search (±0.1 around best coarse)
        logging.info("Performing medium-grain threshold search...")
        threshold_results['medium_search'], best_medium = _best_threshold_in_grid(
            df, _window(best_coarse, 0.1), probabilities, y_true
        )

        # Stage 3: Fine search (±0.01 around best medium)
        logging.info("Performing fine-grain threshold search...")
        threshold_results['fine_search'], optimal_threshold = _best_threshold_in_grid(
            df, _window(best_medium, 0.01), probabilities, y_true
        )

        logging.info(f"Optimal threshold found: {optimal_threshold:.4f}")

        # Add probabilities to dataframe
        df['prediction_probability'] = probabilities

        return df, optimal_threshold, threshold_results

    except Exception as e:
        logging.error(f"Error in threshold optimization: {str(e)}")
        raise

# Cell 7: Model Application Functions
def apply_model_predictions(df: pd.DataFrame, model: RandomForestClassifier,
                          scaler: StandardScaler, feature_columns: List[str],
                          threshold: float) -> pd.DataFrame:
    """Select at most one ASV per sample using the model and a threshold.

    Works on a copy of ``df``. Missing feature values are filled with 0
    before scaling. Within each sample (``project_readfile_id``) only rows
    with ``match == 'match'`` and ``reads >= 4`` are eligible; the eligible
    row with the highest predicted probability is selected if that
    probability is at least ``threshold``.

    Added columns: ``prediction_probability``, ``model_prediction`` (0/1),
    ``model_decision`` ('select'/'unselect'), ``expert_decision`` (copy of
    ``autopropose``) and ``agreement`` ('agree'/'disagree').

    Returns:
        The annotated copy of ``df``.

    Raises:
        Exception: Any failure is logged and re-raised.
    """
    try:
        logging.info(f"\nApplying model predictions with threshold {threshold:.4f}")

        df = df.copy()

        # Calculate prediction probabilities
        X = df[feature_columns].fillna(0)
        X_scaled = pd.DataFrame(
            scaler.transform(X),
            columns=feature_columns,
            index=X.index
        )

        df['prediction_probability'] = model.predict_proba(X_scaled)[:, 1]

        # Initialize prediction columns
        df['model_prediction'] = 0
        df['model_decision'] = 'unselect'
        df['expert_decision'] = df['autopropose']

        # Pick the best eligible ASV of every sample in one grouped pass
        # instead of re-masking the whole frame once per sample.
        # Assumes the frame's index labels are unique.
        eligible = df[(df['match'] == 'match') & (df['reads'] >= 4)]
        if not eligible.empty:
            best_idx = (
                eligible
                .groupby('project_readfile_id', observed=True)['prediction_probability']
                .idxmax()
            )
            best_prob = df.loc[best_idx.to_numpy(), 'prediction_probability']
            chosen = best_prob.index[(best_prob >= threshold).to_numpy()]
            df.loc[chosen, 'model_prediction'] = 1
            df.loc[chosen, 'model_decision'] = 'select'

        # Calculate agreement between model and expert
        both_select = (df['model_decision'] == 'select') & (df['autopropose'] == 'select')
        both_unselect = (df['model_decision'] == 'unselect') & (df['autopropose'] == 'unselect')
        df['agreement'] = np.where(both_select | both_unselect, 'agree', 'disagree')

        # Log prediction statistics
        logging.info("\nPrediction Statistics:")
        logging.info(f"Total samples processed: {df['project_readfile_id'].nunique()}")
        logging.info(f"Total ASVs selected: {(df['model_decision'] == 'select').sum()}")
        logging.info(f"Agreement rate: {(df['agreement'] == 'agree').mean():.4f}")

        return df

    except Exception as e:
        logging.error(f"Error in model predictions: {str(e)}")
        raise

# Cell 8: Feature Analysis Functions
def analyze_features(df: pd.DataFrame, model: RandomForestClassifier, 
                    scaler: StandardScaler, feature_columns: List[str], 
                    threshold_results: Dict) -> Dict:
    """Build the PCA and per-feature distribution figures.

    The decision threshold is taken from
    ``threshold_results['optimal_threshold']`` when present. Otherwise it is
    recovered from the ``'fine_search'`` stage as the first entry with the
    highest ``f1_score``. Only if neither is available does it fall back
    to 0.5.

    Features are taken from ``model.feature_names_in_`` when the model
    exposes it, otherwise from ``feature_columns``.

    Returns:
        A dict of figures with optional keys ``'pca_analysis'`` and
        ``'feature_analysis'``. An empty dict is returned if anything fails
        (the error and traceback are logged).
    """
    try:
        logging.info("\nAnalyzing features...")
        plots = {}

        # Get optimal threshold
        optimal_threshold = threshold_results.get('optimal_threshold')
        if optimal_threshold is None:
            # The results dict may contain only the per-stage metric lists;
            # recover the optimum from the finest stage instead of silently
            # plotting with 0.5.
            fine_metrics = [m for m in (threshold_results.get('fine_search') or []) if m]
            if fine_metrics:
                best = max(fine_metrics, key=lambda m: m['f1_score'])
                optimal_threshold = float(best['threshold'])
        if optimal_threshold is None:
            logging.warning("Optimal threshold not found in results, using default 0.5")
            optimal_threshold = 0.5

        # Log the threshold being used in feature analysis
        logging.info(f"Using threshold {optimal_threshold} for feature analysis")

        # Validate features. Models fitted on arrays without column names
        # have no feature_names_in_, so fall back to the caller's list.
        model_features = getattr(model, 'feature_names_in_', None)
        if model_features is None:
            model_features = feature_columns
        model_features = list(model_features)
        if not all(f in df.columns for f in model_features):
            raise ValueError("Missing required features for model")

        # Create PCA visualization
        pca_fig = create_pca_plot(
            df=df,
            feature_columns=model_features,
            model=model,
            scaler=scaler,
            optimal_threshold=optimal_threshold
        )

        if pca_fig:
            plots['pca_analysis'] = [('PCA Analysis', pca_fig)]

        # Create feature distribution plots - explicitly passing optimal_threshold
        feature_plots = create_feature_distribution_plots(
            df=df,
            model=model,
            scaler=scaler,
            feature_columns=model_features,
            optimal_threshold=optimal_threshold
        )

        if feature_plots:
            plots['feature_analysis'] = feature_plots

        logging.info("Feature analysis complete!")
        return plots

    except Exception as e:
        logging.error(f"Error in feature analysis: {str(e)}")
        logging.error(traceback.format_exc())
        return {}

# Cell 9: Performance Analysis Functions
def analyze_correlation(df: pd.DataFrame, feature_columns: List[str]) -> Tuple[Optional[pd.DataFrame], Optional[np.ndarray]]:
    """Compute the feature correlation matrix and a hierarchical clustering of it.

    Features are clustered with Ward linkage on the distance
    ``1 - correlation``. Before clustering, the distance matrix is made
    valid: undefined correlations (e.g. from a constant feature) are treated
    as uncorrelated (distance 1), tiny negative values and asymmetries from
    floating-point error are removed, and the diagonal is set to exactly 0.
    The returned correlation matrix itself is left untouched, so it may
    still contain NaN.

    Returns:
        A tuple ``(corr_matrix, linkage_matrix)``, or ``(None, None)`` if the
        analysis fails (the error is logged).
    """
    try:
        # Calculate correlation matrix
        corr_matrix = df[feature_columns].corr()

        # Convert correlation to a well-formed distance matrix.
        distance = 1.0 - corr_matrix.to_numpy(dtype=float)
        distance = np.nan_to_num(distance, nan=1.0)
        distance = np.clip((distance + distance.T) / 2.0, 0.0, None)
        np.fill_diagonal(distance, 0.0)

        # Calculate feature clusters using hierarchical clustering.
        # The matrix was sanitised above, so squareform's own validity
        # checks are skipped.
        linkage_matrix = linkage(squareform(distance, checks=False), method='ward')

        return corr_matrix, linkage_matrix
    except Exception as e:
        logging.error(f"Error in correlation analysis: {str(e)}")
        return None, None

def analyze_taxonomic_patterns(df: pd.DataFrame) -> Optional[pd.DataFrame]:
    """Summarise how often ASVs of each taxonomy were selected by the model.

    Returns:
        A DataFrame indexed by ``taxonomy`` with columns ``total`` (rows per
        taxonomy), ``selected`` (rows with ``model_decision == 'select'``)
        and ``selection_rate`` (selected / total), rounded to 4 decimals.
        ``None`` is returned if ``df`` has no ``taxonomy`` column or if the
        analysis fails (the error is logged).
    """
    try:
        if 'taxonomy' not in df.columns:
            return None

        # Calculate taxonomic distribution
        tax_dist = df.groupby(['taxonomy', 'model_decision']).size().unstack(fill_value=0)

        # When the model selected nothing there is no 'select' column; a
        # zero count is the correct answer, not an error.
        if 'select' not in tax_dist.columns:
            tax_dist['select'] = 0

        totals = tax_dist.sum(axis=1)

        # Calculate selection rates by taxonomy
        tax_stats = pd.DataFrame({
            'total': totals,
            'selected': tax_dist['select'],
            'selection_rate': tax_dist['select'] / totals
        }).round(4)

        return tax_stats
    except Exception as e:
        logging.error(f"Error in taxonomic analysis: {str(e)}")
        return None

def analyze_error_patterns(df: pd.DataFrame) -> Dict:
    """Compare model decisions with expert decisions and profile the disagreements.

    Expert truth is ``autopropose == 'select'``; the model's prediction is
    ``model_decision == 'select'``.

    Returns:
        A dict with ``confusion_matrix``, ``classification_report`` (as a
        dict) and ``misclassified_stats`` describing the rows whose
        ``agreement`` is ``'disagree'``.

    Raises:
        Exception: Any failure is logged and re-raised.
    """
    try:
        expert_selected = df['autopropose'] == 'select'
        model_selected = df['model_decision'] == 'select'

        error_analysis = {}
        error_analysis['confusion_matrix'] = confusion_matrix(
            expert_selected, model_selected
        )
        error_analysis['classification_report'] = classification_report(
            expert_selected, model_selected, output_dict=True
        )

        # Profile the rows where model and expert disagree.
        disagreements = df.loc[df['agreement'] == 'disagree']
        error_analysis['misclassified_stats'] = dict(
            count=len(disagreements),
            mean_reads=disagreements['reads'].mean(),
            mean_percentage=disagreements['percentage_reads'].mean(),
            asv_count_stats=disagreements['asv_count'].describe().to_dict(),
        )

        return error_analysis

    except Exception as e:
        logging.error(f"Error analyzing error patterns: {str(e)}")
        raise

def analyze_asv_selection(df: pd.DataFrame) -> None:
    """Display summary statistics comparing selected ASVs with all ASVs.

    Renders (via IPython ``display``) an overview of the selection rate, a
    side-by-side ``describe()`` of numeric columns for all vs. selected rows,
    and - when the columns exist - a taxonomy crosstab and a match crosstab.
    An empty input gives a 0% selection rate, and a run in which nothing was
    selected gives 0% per-taxonomy rates, instead of an error.

    Errors are logged with a traceback and reported as a Markdown message;
    nothing is raised.
    """
    try:
        # Separate selected ASVs
        selected_df = df[df['model_decision'] == 'select']

        # Calculate basic statistics (guarding against an empty frame)
        total_asvs = len(df)
        selected_asvs = len(selected_df)
        selection_rate = (selected_asvs / total_asvs) * 100 if total_asvs else 0.0

        # Create summary statistics for numeric columns
        numeric_cols = ['reads', 'total_asv_reads', 'asv_count', 'percentage_reads', 
                       'read_proportion', 'log_reads', 'read_density']

        all_summary = df[numeric_cols].describe()
        selected_summary = selected_df[numeric_cols].describe()

        # Display results using markdown for better formatting
        display(Markdown(f"""
## ASV Selection Summary

### Overview Statistics
- Total ASVs analyzed: {total_asvs:,}
- ASVs selected: {selected_asvs:,}
- Selection rate: {selection_rate:.2f}%

### Comparison of All vs Selected ASVs
"""))

        # Display side-by-side comparison
        comparison = pd.concat([all_summary, selected_summary], axis=1, 
                             keys=['All ASVs', 'Selected ASVs'])
        display(comparison)

        # Additional categorical summaries if available
        if 'taxonomy' in df.columns:
            display(Markdown("\n### Taxonomic Distribution of Selected ASVs"))
            tax_dist = pd.crosstab(df['taxonomy'], df['model_decision'])
            # The crosstab has no 'select' column when nothing was selected.
            selected_counts = tax_dist['select'] if 'select' in tax_dist.columns else 0
            tax_dist['selection_rate'] = (selected_counts / tax_dist.sum(axis=1) * 100).round(2)
            display(tax_dist)

        if 'match' in df.columns:
            display(Markdown("\n### Match Statistics for Selected ASVs"))
            match_stats = pd.crosstab(selected_df['match'], selected_df['model_decision'])
            display(match_stats)

        # Log the summary
        logging.info(f"ASV Selection Summary - Total: {total_asvs}, Selected: {selected_asvs}, Rate: {selection_rate:.2f}%")

    except Exception as e:
        logging.error(f"Error in ASV selection analysis: {str(e)}")
        logging.error(traceback.format_exc())
        display(Markdown("❌ Error generating ASV selection summary"))

# Cell 10: Visualization Functions
def create_pca_plot(df: pd.DataFrame, feature_columns: List[str],
                   model: RandomForestClassifier, scaler: StandardScaler,
                   optimal_threshold: float) -> go.Figure:
    """Create a 2-component PCA scatter plot of the scaled features.

    Points are coloured by predicted selection probability; two overlay
    traces outline points as 'Unselected' (red) or 'Selected' (green)
    according to ``probability >= optimal_threshold``. The threshold value
    is shown as a text annotation (no boundary line is drawn).

    The scaled features are passed to the model as a DataFrame carrying the
    feature names, matching how the other prediction helpers in this file
    call ``predict_proba``.

    Returns:
        The figure, or ``None`` if plotting fails (the error is logged).
    """
    try:
        # Prepare data
        feature_list = list(feature_columns)
        X = df[feature_list]
        X_scaled = pd.DataFrame(
            scaler.transform(X),
            columns=feature_list,
            index=X.index
        )

        # Perform PCA
        pca = PCA(n_components=2)
        pca_result = pca.fit_transform(X_scaled.to_numpy())

        # Get predictions and probabilities
        probabilities = model.predict_proba(X_scaled)[:, 1]
        predictions = (probabilities >= optimal_threshold).astype(int)

        # Create figure with secondary y-axis for colorbar
        fig = make_subplots(specs=[[{"secondary_y": True}]])

        # Add scatter points colored by probability
        fig.add_trace(
            go.Scatter(
                x=pca_result[:, 0],
                y=pca_result[:, 1],
                mode='markers',
                marker=dict(
                    size=12,
                    color=probabilities,
                    colorscale='Viridis',
                    showscale=True,
                    colorbar=dict(
                        title='Selection Probability',
                        tickformat='.2f'
                    ),
                    line=dict(width=1, color='black')
                ),
                text=[f"Probability: {p:.3f}<br>Selected: {p >= optimal_threshold}" 
                      for p in probabilities],
                hoverinfo='text',
                name='ASVs'
            )
        )

        # Add threshold annotation
        fig.add_annotation(
            xref='paper',
            yref='paper',
            x=0.02,
            y=0.98,
            text=f'Selection Threshold: {optimal_threshold:.3f}',
            showarrow=False,
            font=dict(size=12),
            bgcolor='rgba(255, 255, 255, 0.8)',
            bordercolor='red',
            borderwidth=2
        )

        # Outline points based on selection (0 -> Unselected, 1 -> Selected)
        for i, (label, color) in enumerate(zip(['Unselected', 'Selected'], ['red', 'green'])):
            mask = predictions == i
            fig.add_trace(
                go.Scatter(
                    x=pca_result[mask, 0],
                    y=pca_result[mask, 1],
                    mode='markers',
                    marker=dict(
                        size=12,
                        opacity=0.5,
                        line=dict(color=color, width=2)
                    ),
                    name=label,
                    showlegend=True
                )
            )

        # Update layout
        fig.update_layout(
            title='PCA Analysis with Selection Boundary',
            xaxis_title=f'PC1 ({pca.explained_variance_ratio_[0]:.1%} variance)',
            yaxis_title=f'PC2 ({pca.explained_variance_ratio_[1]:.1%} variance)',
            plot_bgcolor='white',
            width=900,
            height=700,
            legend=dict(
                yanchor="top",
                y=0.99,
                xanchor="left",
                x=0.01,
                bgcolor="rgba(255,255,255,0.8)",
                bordercolor="black",
                borderwidth=1
            ),
            hovermode='closest'
        )

        # Add grid
        fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
        fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')

        return fig

    except Exception as e:
        logging.error(f"Error creating PCA plot: {str(e)}")
        return None

def _decision_stats_text(df: pd.DataFrame, feature: str) -> str:
    """Build the HTML summary of ``feature`` split by ``model_decision``.

    Each decision group is summarised independently, so a group with no rows
    is reported as ``n/a`` instead of raising ``KeyError`` the way indexing a
    ``groupby(...).describe()`` table with a missing label would.

    Args:
        df: Frame holding a ``model_decision`` column and the feature column.
        feature: Name of the numeric column to summarise.

    Returns:
        ``<br>``-separated text with the row count, mean and std per group.
    """
    lines = ["Feature Statistics:"]
    for heading, decision in (("Selected", "select"), ("Unselected", "unselect")):
        values = df.loc[df['model_decision'] == decision, feature]
        lines.append(f"{heading} (n={len(values)})")
        if values.empty:
            lines.append("• Mean: n/a")
            lines.append("• Std: n/a")
        else:
            lines.append(f"• Mean: {values.mean():.2f}")
            lines.append(f"• Std: {values.std():.2f}")
    return "<br>".join(lines)


def create_feature_distribution_plots(df: pd.DataFrame, model: RandomForestClassifier,
                                   scaler: StandardScaler, feature_columns: List[str],
                                   optimal_threshold: float) -> List[Tuple[str, go.Figure]]:
    """Create one two-panel figure per feature.

    The top panel overlays histograms of the feature for rows whose
    ``model_decision`` is ``'select'`` and ``'unselect'``. The bottom panel
    scatters the model's positive-class probability against the feature value
    and marks ``optimal_threshold`` with a horizontal line.

    Args:
        df: Data containing ``feature_columns`` and a ``model_decision`` column.
        model: Fitted classifier; ``predict_proba(...)[:, 1]`` is plotted.
        scaler: Fitted scaler applied to ``df[feature_columns]`` before prediction.
        feature_columns: Names of the feature columns to plot.
        optimal_threshold: Probability threshold drawn on every figure.

    Returns:
        A list of ``(plot_name, figure)`` tuples, or an empty list if anything
        fails (the error is logged).
    """
    plots = []
    logging.info(f"Creating feature plots with threshold: {optimal_threshold}")

    try:
        # Prepare data
        X = df[feature_columns].copy()
        X_scaled = pd.DataFrame(
            scaler.transform(X),
            columns=feature_columns,
            index=X.index
        )

        # Get predictions
        probabilities = model.predict_proba(X_scaled)[:, 1]

        # Everything below is identical for every feature, so compute it once.
        threshold_mask = probabilities >= optimal_threshold
        any_above_threshold = bool(threshold_mask.any())
        decision_masks = {
            decision: df['model_decision'] == decision
            for decision in ('select', 'unselect')
        }
        decision_colors = {'select': 'green', 'unselect': 'red'}

        if len(feature_columns) > 0:
            logging.info(f"Using threshold {optimal_threshold} in feature plots")

        # Create plots for each feature
        for feature in feature_columns:
            fig = make_subplots(
                rows=2, cols=1,
                subplot_titles=(
                    f'{feature} Distribution by Decision',
                    f'Selection Probability vs {feature}'
                ),
                vertical_spacing=0.20,
                row_heights=[0.6, 0.4]
            )

            # Distribution plot (top)
            for decision, mask in decision_masks.items():
                fig.add_trace(
                    go.Histogram(
                        x=df[feature][mask],
                        name=decision,
                        opacity=0.7,
                        marker_color=decision_colors[decision],
                        nbinsx=30,
                        showlegend=True
                    ),
                    row=1, col=1
                )

            # Smallest feature value among rows at/above the threshold; fall
            # back to the median when no row reaches the threshold.
            if any_above_threshold:
                threshold_value = df[feature][threshold_mask].min()
            else:
                threshold_value = df[feature].median()

            # Add vertical threshold line
            fig.add_vline(
                x=threshold_value,
                line_dash="dash",
                line_color="black",
                opacity=0.5,
                row=1, col=1,
                annotation=dict(
                    text=f"Feature Value at Threshold: {threshold_value:.3f}",
                    font=dict(size=10),
                    xanchor='left',
                    yanchor='top'
                )
            )

            # Probability scatter plot (bottom)
            fig.add_trace(
                go.Scatter(
                    x=df[feature],
                    y=probabilities,
                    mode='markers',
                    marker=dict(
                        color=probabilities,
                        colorscale='Viridis',
                        showscale=True,
                        colorbar=dict(
                            title='Probability',
                            y=0.2,
                            len=0.4
                        ),
                        size=8,
                        opacity=0.6
                    ),
                    name='Probability'
                ),
                row=2, col=1
            )

            # Add horizontal threshold line
            fig.add_hline(
                y=optimal_threshold,
                line_dash="dash",
                line_color="red",
                opacity=0.7,
                row=2, col=1,
                annotation=dict(
                    text=f"Selection Threshold: {optimal_threshold:.3f}",
                    font=dict(size=10, color='red'),
                    xanchor='left',
                    yanchor='bottom'
                )
            )

            # Update layout
            fig.update_layout(
                height=800,
                title=dict(
                    text=f'Feature Analysis: {feature}<br><sup>Optimal Threshold: {optimal_threshold:.3f}</sup>',
                    x=0.5,
                    y=0.95
                ),
                showlegend=True,
                legend=dict(
                    yanchor="top",
                    y=0.99,
                    xanchor="right",
                    x=0.99
                ),
                barmode='overlay'
            )

            # Update axes labels
            fig.update_xaxes(title_text=feature, row=1, col=1)
            fig.update_yaxes(title_text="Count", row=1, col=1)
            fig.update_xaxes(title_text="Feature Value", row=2, col=1)
            fig.update_yaxes(
                title_text="Selection Probability",
                range=[0, 1],  # Force y-axis range
                row=2, col=1
            )

            # Add feature statistics (robust to an absent decision class)
            fig.add_annotation(
                text=_decision_stats_text(df, feature),
                xref="paper", yref="paper",
                x=0.01, y=0.99,
                showarrow=False,
                font=dict(size=10),
                align="left",
                bgcolor="rgba(255,255,255,0.8)",
                bordercolor="black",
                borderwidth=1
            )

            plots.append((f'feature_distribution_{feature}', fig))

        return plots

    except Exception as e:
        logging.error(f"Error creating feature distribution plots: {str(e)}")
        return []

def create_correlation_plots(corr_matrix: pd.DataFrame, linkage_matrix: np.ndarray) -> List[Tuple[str, go.Figure]]:
    """Build the correlation heatmap and, when available, the feature dendrogram.

    Args:
        corr_matrix: Square correlation frame; its columns label both axes.
        linkage_matrix: Linkage matrix forwarded to ``create_dendrogram``.

    Returns:
        ``(name, figure)`` tuples: always ``'correlation_matrix'`` first, then
        ``'feature_dendrogram'`` if ``create_dendrogram`` returned a figure.
        An empty list is returned (and the error logged) on any failure.
    """
    try:
        feature_names = corr_matrix.columns

        # Heatmap cells are annotated with the coefficient rounded to 2 dp.
        heatmap_trace = go.Heatmap(
            z=corr_matrix,
            x=feature_names,
            y=feature_names,
            colorscale='RdBu',
            zmid=0,
            text=np.round(corr_matrix, 2),
            texttemplate='%{text}',
            textfont={"size": 10},
        )
        heatmap_fig = go.Figure(data=heatmap_trace)
        heatmap_fig.update_layout(
            title='Feature Correlation Matrix',
            height=800,
            width=800,
        )
        figures = [('correlation_matrix', heatmap_fig)]

        # The dendrogram helper returns None when it fails; skip it then.
        dendrogram_fig = create_dendrogram(linkage_matrix, feature_names)
        if dendrogram_fig:
            figures.append(('feature_dendrogram', dendrogram_fig))

        return figures
    except Exception as e:
        logging.error(f"Error creating correlation plots: {str(e)}")
        return []

def create_dendrogram(linkage_matrix: np.ndarray, labels: List[str]) -> Optional[go.Figure]:
    """Create a Plotly dendrogram of feature relationships.

    SciPy computes the tree geometry (``no_plot=True``); the links are drawn
    as a single line trace and the leaf labels are placed as x-axis ticks.

    Args:
        linkage_matrix: Hierarchical-clustering linkage matrix.
        labels: One label per original observation (any sequence is accepted
            and converted to a list).

    Returns:
        The dendrogram figure, or ``None`` if the inputs are empty or any
        step fails (the error and traceback are logged).
    """
    try:
        # Input validation
        if linkage_matrix is None or len(linkage_matrix) == 0:
            raise ValueError("Empty linkage matrix provided")
        if labels is None or len(labels) == 0:
            raise ValueError("Empty labels provided")

        # Callers pass a pandas Index; hand SciPy and Plotly a plain list.
        labels = list(labels)

        dendro = dendrogram(
            linkage_matrix,
            labels=labels,
            orientation='bottom',
            leaf_rotation=45,
            no_plot=True
        )

        # Each link is a 4-point "U"; a None entry breaks the line between links.
        x = []
        y = []
        for x_coords, y_coords in zip(dendro['icoord'], dendro['dcoord']):
            x.extend(x_coords)
            y.extend(y_coords)
            x.append(None)
            y.append(None)

        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=x,
            y=y,
            mode='lines',
            line=dict(
                color='#2c3e50',
                width=2
            ),
            hoverinfo='none'
        ))

        # SciPy places the i-th leaf (in ``ivl`` order) at x = 5 + 10 * i.
        # ``icoord`` holds link coordinates (n - 1 entries), so it must not be
        # used to locate the n leaves.
        leaf_labels = dendro['ivl']
        leaf_positions = [5 + 10 * i for i in range(len(leaf_labels))]

        # Dotted stub below each leaf, pointing at its tick label.
        for x_pos in leaf_positions:
            fig.add_shape(
                type="line",
                x0=x_pos,
                y0=0,
                x1=x_pos,
                y1=-5,
                line=dict(
                    color="#2c3e50",
                    width=1,
                    dash="dot"
                )
            )

        fig.update_layout(
            title={
                'text': 'Feature Clustering Dendrogram',
                'y': 0.95,
                'x': 0.5,
                'xanchor': 'center',
                'yanchor': 'top',
                'font': dict(size=20)
            },
            showlegend=False,
            xaxis=dict(
                ticktext=leaf_labels,
                tickvals=leaf_positions,
                tickmode="array",
                showticklabels=True,
                tickangle=45,
                title='Features',
                title_font=dict(size=16),
                tickfont=dict(size=12)
            ),
            yaxis=dict(
                title='Distance',
                title_font=dict(size=16),
                tickfont=dict(size=12)
            ),
            height=800,
            width=1200,
            plot_bgcolor='white',
            hoverlabel=dict(
                bgcolor="white",
                font_size=12,
                font_family="Arial"
            ),
            margin=dict(b=150)  # Increase bottom margin for labels
        )

        # Add grid
        fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
        fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')

        return fig

    except Exception as e:
        logging.error(f"Error creating dendrogram: {str(e)}")
        logging.error(traceback.format_exc())
        return None

def create_threshold_analysis_plots(threshold_results: Dict) -> List[Tuple[str, go.Figure]]:
    """Plot F1 score against threshold for each stage of the threshold search.

    Args:
        threshold_results: Mapping with ``'coarse_search'``, ``'medium_search'``
            and ``'fine_search'`` entries, each convertible to a DataFrame
            with ``threshold`` and ``f1_score`` columns.

    Returns:
        A single-element list ``[('threshold_optimization', figure)]``, or an
        empty list if anything fails (the error is logged).
    """
    search_stages = ('coarse_search', 'medium_search', 'fine_search')
    try:
        curve_fig = go.Figure()

        # One line per search stage, named e.g. "Coarse Search".
        for stage in search_stages:
            stage_frame = pd.DataFrame(threshold_results[stage])
            stage_trace = go.Scatter(
                x=stage_frame['threshold'],
                y=stage_frame['f1_score'],
                name=stage.replace('_', ' ').title(),
                mode='lines+markers',
            )
            curve_fig.add_trace(stage_trace)

        curve_fig.update_layout(
            title='Threshold Optimization',
            xaxis_title='Threshold',
            yaxis_title='F1 Score',
        )

        return [('threshold_optimization', curve_fig)]
    except Exception as e:
        logging.error(f"Error creating threshold plots: {str(e)}")
        return []

def create_agreement_analysis_plots(df: pd.DataFrame) -> List[Tuple[str, go.Figure]]:
    """Summarise how the model's decisions compare with the ``autopropose`` column.

    Two figures are produced: a donut chart of the ``agreement`` column's
    value counts, and a heatmap of the ``autopropose`` x ``model_decision``
    cross-tabulation.

    Args:
        df: Frame with ``agreement``, ``autopropose`` and ``model_decision`` columns.

    Returns:
        ``[('agreement_distribution', fig), ('decision_matrix', fig)]``, or an
        empty list if anything fails (the error is logged).
    """
    try:
        # 1. Agreement distribution donut
        counts = df['agreement'].value_counts()
        # NOTE(review): colors are assigned positionally in value_counts order
        # (most frequent first), not by label — confirm green always lands on
        # the intended category.
        donut = go.Pie(
            labels=counts.index,
            values=counts.values,
            hole=0.4,
            marker=dict(colors=['#2ecc71', '#e74c3c']),
            textinfo='value+percent',
            hovertemplate="Status: %{label}<br>Count: %{value}<br>Percentage: %{percent}<extra></extra>",
        )
        centre_note = {
            'text': f'Total ASVs:<br>{len(df):,}',
            'x': 0.5,
            'y': 0.5,
            'font_size': 12,
            'showarrow': False,
        }
        distribution_fig = go.Figure(data=[donut])
        distribution_fig.update_layout(
            title='Model-Expert Agreement Distribution',
            annotations=[centre_note],
            template='plotly_white',
        )

        # 2. Agreement matrix (rows: autopropose, columns: model_decision)
        crosstab = pd.crosstab(df['autopropose'], df['model_decision'])
        matrix_trace = go.Heatmap(
            z=crosstab.values,
            x=crosstab.columns,
            y=crosstab.index,
            text=crosstab.values,
            texttemplate="%{text}",
            textfont={"size": 14},
            colorscale='RdYlGn',
            showscale=True,
        )
        matrix_fig = go.Figure(data=matrix_trace)
        matrix_fig.update_layout(
            title='Decision Agreement Matrix',
            xaxis_title='Model Decision',
            yaxis_title='Expert Decision',
            template='plotly_white',
            height=500,
            width=600,
        )

        return [
            ('agreement_distribution', distribution_fig),
            ('decision_matrix', matrix_fig),
        ]
    except Exception as e:
        logging.error(f"Error creating agreement plots: {str(e)}")
        return []

def _safe_ratio(numerator: float, denominator: float) -> float:
    """Return ``numerator / denominator``, or NaN when the denominator is zero."""
    if denominator == 0:
        return float('nan')
    return float(numerator) / float(denominator)


def _feature_importance_figure(feature_importance: pd.DataFrame) -> go.Figure:
    """Horizontal bar chart of feature importances with a marker overlay."""
    # sort_values returns a new frame, so the caller's frame is not modified.
    sorted_importance = feature_importance.sort_values('importance', ascending=True)

    # Calculate percentages and cumulative importance
    total_importance = sorted_importance['importance'].sum()
    sorted_importance['percentage'] = (sorted_importance['importance'] / total_importance * 100)
    sorted_importance['cumulative'] = sorted_importance['percentage'].cumsum()

    fig = go.Figure()

    # Add main bar trace
    fig.add_trace(go.Bar(
        x=sorted_importance['importance'],
        y=sorted_importance['feature'],
        orientation='h',
        marker_color='#3498db',
        text=[f"{val:.3f} ({pct:.1f}%)" for val, pct in
              zip(sorted_importance['importance'], sorted_importance['percentage'])],
        textposition='outside',
        name='Feature Importance',
        hovertemplate="<b>%{y}</b><br>" +
                     "Importance: %{x:.3f}<br>" +
                     "Contribution: %{text}<extra></extra>"
    ))

    # Overlay markers at each importance value; the cumulative percentage is
    # exposed through the hover text only.
    # NOTE(review): the cumulative sum runs from least to most important
    # feature — confirm that ordering is intended.
    fig.add_trace(go.Scatter(
        x=sorted_importance['importance'],
        y=sorted_importance['feature'],
        mode='markers+lines',
        marker=dict(size=8, color='#e74c3c'),
        line=dict(color='#e74c3c', dash='dot'),
        name='Cumulative %',
        text=[f"{pct:.1f}%" for pct in sorted_importance['cumulative']],
        textposition='middle right',
        yaxis='y',
        xaxis='x2',
        hovertemplate="Cumulative: %{text}<extra></extra>"
    ))

    fig.update_layout(
        title={
            'text': 'Feature Importance Analysis',
            'y': 0.95,
            'x': 0.5,
            'xanchor': 'center',
            'yanchor': 'top',
            'font': dict(size=20)
        },
        xaxis_title="Importance Score",
        yaxis_title="Feature",
        height=800,
        width=1200,
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        ),
        plot_bgcolor='white',
        xaxis2=dict(
            overlaying='x',
            side='top',
            range=[0, max(sorted_importance['importance'])],
            showgrid=False,
            zeroline=False,
            showticklabels=False
        ),
        margin=dict(l=200, r=200)  # Adjust margins for labels
    )

    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
    return fig


def _roc_curve_figure(y_test: pd.Series, y_pred_proba: np.ndarray) -> go.Figure:
    """ROC curve with the chance diagonal and the max(TPR - FPR) point marked."""
    fpr, tpr, thresholds = roc_curve(y_test, y_pred_proba)
    roc_auc = auc(fpr, tpr)

    fig = go.Figure()

    # Add ROC curve
    fig.add_trace(go.Scatter(
        x=fpr,
        y=tpr,
        name=f'ROC (AUC = {roc_auc:.3f})',
        mode='lines',
        line=dict(color='#3498db', width=2),
        fill='tozeroy',
        fillcolor='rgba(52, 152, 219, 0.2)',
        hovertemplate="False Positive Rate: %{x:.3f}<br>" +
                     "True Positive Rate: %{y:.3f}<extra></extra>"
    ))

    # Add diagonal line
    fig.add_trace(go.Scatter(
        x=[0, 1],
        y=[0, 1],
        name='Random',
        mode='lines',
        line=dict(color='#95a5a6', dash='dash'),
        hovertemplate="Random Classifier<extra></extra>"
    ))

    # Threshold maximising TPR - FPR (Youden's J statistic)
    optimal_idx = np.argmax(tpr - fpr)
    optimal_threshold = thresholds[optimal_idx]

    fig.add_trace(go.Scatter(
        x=[fpr[optimal_idx]],
        y=[tpr[optimal_idx]],
        name=f'Optimal Threshold ({optimal_threshold:.3f})',
        mode='markers',
        marker=dict(
            size=12,
            color='#e74c3c',
            symbol='star'
        ),
        hovertemplate="Optimal Threshold<br>" +
                     f"Value: {optimal_threshold:.3f}<br>" +
                     "FPR: %{x:.3f}<br>" +
                     "TPR: %{y:.3f}<extra></extra>"
    ))

    fig.update_layout(
        title={
            'text': 'ROC Curve Analysis',
            'y': 0.95,
            'x': 0.5,
            'xanchor': 'center',
            'yanchor': 'top',
            'font': dict(size=20)
        },
        xaxis_title="False Positive Rate",
        yaxis_title="True Positive Rate",
        height=800,
        width=800,
        showlegend=True,
        legend=dict(
            yanchor="bottom",
            y=0.01,
            xanchor="right",
            x=0.99,
            bgcolor="rgba(255, 255, 255, 0.8)",
            bordercolor="black",
            borderwidth=1
        ),
        plot_bgcolor='white'
    )

    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='lightgray', range=[-0.01, 1.01])
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='lightgray', range=[-0.01, 1.01])
    return fig


def _confusion_matrix_figure(y_test: pd.Series, y_pred: np.ndarray, class_labels) -> go.Figure:
    """Annotated confusion-matrix heatmap with a side panel of derived metrics."""
    # Passing the model's classes keeps the matrix 2x2 even when the test
    # split or the predictions contain only one of the two classes.
    cm = confusion_matrix(y_test, y_pred, labels=class_labels)

    tn, fp, fn, tp = cm.ravel()
    total = np.sum(cm)
    sensitivity = _safe_ratio(tp, tp + fn)
    specificity = _safe_ratio(tn, tn + fp)
    ppv = _safe_ratio(tp, tp + fp)
    npv = _safe_ratio(tn, tn + fn)

    # Plotly collapses "\n" to a space; line breaks must be written as <br>.
    cell_text = [
        [f"Count: {val}<br>Percentage: {_safe_ratio(val, total) * 100:.1f}%" for val in row]
        for row in cm
    ]

    fig = go.Figure(data=go.Heatmap(
        z=cm,
        x=['Predicted Negative', 'Predicted Positive'],
        y=['Actual Negative', 'Actual Positive'],
        text=cell_text,
        texttemplate="%{text}",
        textfont={"size": 12},
        colorscale=[[0, '#f8d7da'], [1, '#c3e6cb']],
        showscale=False
    ))

    metrics_text = "<br>".join([
        "Model Metrics:",
        f"Sensitivity: {sensitivity:.3f}",
        f"Specificity: {specificity:.3f}",
        f"PPV: {ppv:.3f}",
        f"NPV: {npv:.3f}",
    ])

    # Anchor the panel just right of the plot area and reserve margin for it;
    # annotations do not enlarge the margins on their own.
    annotations = [
        dict(
            x=1.02,
            y=1,
            xref="paper",
            yref="paper",
            xanchor="left",
            yanchor="top",
            text=metrics_text,
            showarrow=False,
            font=dict(size=12),
            align="left",
            bgcolor="white",
            bordercolor="black",
            borderwidth=1,
            borderpad=4
        )
    ]

    fig.update_layout(
        title={
            'text': 'Confusion Matrix Analysis',
            'y': 0.95,
            'x': 0.5,
            'xanchor': 'center',
            'yanchor': 'top',
            'font': dict(size=20)
        },
        xaxis_title="Predicted Class",
        yaxis_title="Actual Class",
        height=800,
        width=1000,
        plot_bgcolor='white',
        margin=dict(r=220),
        annotations=annotations
    )
    return fig


def create_model_performance_plots(model: RandomForestClassifier, X_test: pd.DataFrame, 
                                y_test: pd.Series, feature_importance: pd.DataFrame) -> List[Tuple[str, go.Figure]]:
    """Create the feature-importance, ROC-curve and confusion-matrix figures.

    Args:
        model: Fitted binary classifier; ``predict``, ``predict_proba`` and
            ``classes_`` are used.
        X_test: Held-out features passed to the model as-is.
        y_test: True labels for ``X_test``.
        feature_importance: Frame with ``feature`` and ``importance`` columns.

    Returns:
        ``(name, figure)`` tuples named ``'feature_importance'``,
        ``'roc_curve'`` and ``'confusion_matrix'``, or an empty list if any
        step fails (the error is logged).
    """
    plots = []
    try:
        # 1. Feature importance
        plots.append(('feature_importance', _feature_importance_figure(feature_importance)))

        # 2. ROC curve from positive-class probabilities
        y_pred_proba = model.predict_proba(X_test)[:, 1]
        plots.append(('roc_curve', _roc_curve_figure(y_test, y_pred_proba)))

        # 3. Confusion matrix from hard predictions
        y_pred = model.predict(X_test)
        plots.append(('confusion_matrix',
                      _confusion_matrix_figure(y_test, y_pred, model.classes_)))

        return plots
    except Exception as e:
        logging.error(f"Error creating model performance plots: {str(e)}")
        return []

def create_taxonomic_analysis_plots(df: pd.DataFrame) -> List[Tuple[str, go.Figure]]:
    """Create the taxonomy-level figures (sunburst, bars, bubbles, heatmap, parcats).

    Missing ``taxonomy`` values are shown as ``'Unknown'``. Besides
    ``taxonomy`` the frame must provide ``model_decision``, ``reads``,
    ``project_readfile_id`` and ``match``.

    Args:
        df: Per-ASV results; the frame itself is not modified.

    Returns:
        ``(name, figure)`` tuples for the five plots; an empty list when the
        ``taxonomy`` column is absent (warning logged) or when any step fails
        (error logged).
    """
    plots = []
    try:
        if 'taxonomy' not in df.columns:
            logging.warning("No taxonomy column found in the data")
            return plots

        df_clean = df.copy()
        df_clean['taxonomy'] = df_clean['taxonomy'].fillna('Unknown')

        def decision_color(decision) -> str:
            return '#2ecc71' if decision == 'select' else '#e74c3c'

        # 1. Sunburst: decision (inner ring) -> taxonomy (outer ring).
        # A sunburst needs unique ids and every parent present as a node, so
        # reads are aggregated per (decision, taxonomy) and one root node is
        # added per decision. With branchvalues='total' each root carries the
        # sum of its children.
        leaf_reads = df_clean.groupby(
            ['model_decision', 'taxonomy'], as_index=False
        )['reads'].sum()
        root_reads = leaf_reads.groupby('model_decision', as_index=False)['reads'].sum()

        root_ids = root_reads['model_decision'].tolist()
        leaf_parents = leaf_reads['model_decision'].tolist()
        leaf_names = leaf_reads['taxonomy'].tolist()

        fig = go.Figure(go.Sunburst(
            ids=root_ids + [f"{parent}/{name}" for parent, name in zip(leaf_parents, leaf_names)],
            labels=root_ids + leaf_names,
            parents=[''] * len(root_ids) + leaf_parents,
            values=root_reads['reads'].tolist() + leaf_reads['reads'].tolist(),
            branchvalues='total',
            marker=dict(
                colors=[decision_color(d) for d in root_ids + leaf_parents]
            ),
            hovertemplate="Taxonomy: %{label}<br>Reads: %{value:,.0f}<br>Decision: %{parent}"
        ))
        fig.update_layout(title='Taxonomic Distribution Sunburst', width=800, height=800)
        plots.append(('taxonomy_sunburst', fig))

        # 2. Enhanced Stacked Bar Chart
        tax_stats = df_clean.groupby(['taxonomy', 'model_decision']).agg({
            'reads': 'sum',
            'project_readfile_id': 'count'
        }).reset_index()

        fig = go.Figure()
        for decision in ('select', 'unselect'):
            subset = tax_stats[tax_stats['model_decision'] == decision]
            fig.add_trace(go.Bar(
                name=decision.capitalize(),
                x=subset['taxonomy'],
                y=subset['project_readfile_id'],
                marker_color=decision_color(decision),
                text=subset['project_readfile_id'],
                textposition='auto',
                hovertemplate="Taxonomy: %{x}<br>" +
                             "Count: %{y}<br>" +
                             "Decision: " + decision + "<extra></extra>"
            ))

        fig.update_layout(
            title='Taxonomic Distribution by Decision',
            barmode='stack',
            xaxis_title='Taxonomy',
            yaxis_title='Count',
            height=600,
            showlegend=True
        )
        plots.append(('taxonomy_stacked_bar', fig))

        # 3. Bubble Chart with Multiple Metrics
        tax_metrics = df_clean.groupby('taxonomy').agg({
            'reads': 'sum',
            'model_decision': lambda x: (x == 'select').mean(),
            'project_readfile_id': 'nunique'
        }).reset_index()

        bubble_sizes = np.sqrt(tax_metrics['reads']) / 100
        largest_bubble = bubble_sizes.max() if len(bubble_sizes) else 0
        # A zero sizeref is invalid; fall back to 1 when every taxon has 0 reads.
        bubble_sizeref = 2. * largest_bubble / (40. ** 2) if largest_bubble > 0 else 1

        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=tax_metrics['project_readfile_id'],
            y=tax_metrics['model_decision'] * 100,
            mode='markers',
            marker=dict(
                size=bubble_sizes,
                sizemode='area',
                sizeref=bubble_sizeref,
                color=tax_metrics['model_decision'],
                colorscale='RdYlGn',
                showscale=True,
                colorbar=dict(title='Selection Rate')
            ),
            text=tax_metrics['taxonomy'],
            # Hover must report the real read total, not the scaled marker size.
            customdata=tax_metrics['reads'],
            hovertemplate="Taxonomy: %{text}<br>" +
                         "Samples: %{x}<br>" +
                         "Selection Rate: %{y:.1f}%<br>" +
                         "Total Reads: %{customdata:,.0f}<extra></extra>"
        ))

        fig.update_layout(
            title='Taxonomic Metrics Bubble Chart',
            xaxis_title='Number of Samples',
            yaxis_title='Selection Rate (%)',
            height=800,
            showlegend=False
        )
        plots.append(('taxonomy_bubble', fig))

        # 4. Heatmap of Read Distribution
        pivot_reads = pd.pivot_table(
            df_clean,
            values='reads',
            index='taxonomy',
            columns='model_decision',
            aggfunc='sum',
            fill_value=0
        )

        fig = go.Figure(go.Heatmap(
            # Color on log10(reads + 1); cell text shows the raw totals.
            z=np.log10(pivot_reads.values + 1),
            x=pivot_reads.columns,
            y=pivot_reads.index,
            colorscale='Viridis',
            text=pivot_reads.values,
            texttemplate='%{text:,.0f}',
            hovertemplate="Taxonomy: %{y}<br>" +
                         "Decision: %{x}<br>" +
                         "Reads: %{text:,.0f}<extra></extra>"
        ))

        fig.update_layout(
            title='Read Distribution Heatmap (log10 scale)',
            xaxis_title='Model Decision',
            yaxis_title='Taxonomy',
            height=800
        )
        plots.append(('taxonomy_heatmap', fig))

        # 5. Parallel Categories Diagram
        fig = go.Figure(go.Parcats(
            dimensions=[
                {
                    'label': 'Taxonomy',
                    'values': df_clean['taxonomy']
                },
                {
                    'label': 'Match',
                    'values': df_clean['match']
                },
                {
                    'label': 'Decision',
                    'values': df_clean['model_decision']
                }
            ],
            line=dict(
                color=np.log10(df_clean['reads'] + 1),
                colorscale='Viridis'
            ),
            hoveron='color',
            hovertemplate='Reads: %{color:.2f}<extra></extra>'
        ))

        fig.update_layout(
            title='Taxonomy-Match-Decision Relationships',
            height=800
        )
        plots.append(('taxonomy_parallel', fig))

        return plots
    except Exception as e:
        logging.error(f"Error creating taxonomic plots: {str(e)}")
        return []

# Trace colour per model decision, shared by every ASV-MMG figure below.
_MMG_DECISION_COLORS = {'select': '#2ecc71', 'unselect': '#e74c3c'}


def _mmg_title(text: str) -> Dict[str, Any]:
    """Return the centred, size-20 title layout shared by the ASV-MMG figures."""
    return {
        'text': text,
        'y': 0.95,
        'x': 0.5,
        'xanchor': 'center',
        'yanchor': 'top',
        'font': dict(size=20)
    }


def _mmg_distribution_figure(df_clean: pd.DataFrame) -> go.Figure:
    """Stacked bar chart of ASV counts per MMG ID, split by model decision."""
    mmg_stats = df_clean.groupby(['mt_id', 'model_decision']).agg({
        'reads': ['sum', 'mean', 'count'],
        'project_readfile_id': 'nunique'
    }).reset_index()
    # agg() with a list of functions yields two-level column labels; flatten them.
    mmg_stats.columns = ['mt_id', 'decision', 'total_reads', 'mean_reads', 'count', 'samples']

    fig = go.Figure()
    for decision, color in _MMG_DECISION_COLORS.items():
        subset = mmg_stats[mmg_stats['decision'] == decision]
        fig.add_trace(go.Bar(
            name=decision.capitalize(),
            x=subset['mt_id'],
            y=subset['count'],
            marker_color=color,
            text=subset['count'],
            textposition='auto',
            # One row per bar: [total_reads, mean_reads, samples] for the hover text.
            customdata=np.stack((
                subset['total_reads'],
                subset['mean_reads'],
                subset['samples']
            ), axis=-1),
            hovertemplate="<b>MMG ID: %{x}</b><br>" +
                         "ASV Count: %{y}<br>" +
                         "Total Reads: %{customdata[0]:,.0f}<br>" +
                         "Mean Reads: %{customdata[1]:,.1f}<br>" +
                         "Samples: %{customdata[2]}<extra></extra>"
        ))

    fig.update_layout(
        title=_mmg_title('ASV-MMG Distribution by Model Decision'),
        barmode='stack',
        xaxis_title="MMG ID",
        yaxis_title="Number of ASVs",
        height=700,
        template='plotly_white',
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )
    return fig


def _mmg_bubble_figure(df_clean: pd.DataFrame) -> go.Figure:
    """Bubble chart of selection rate vs match rate per MMG ID, sized by reads."""
    mmg_analysis = df_clean.groupby('mt_id').agg({
        'reads': ['sum', 'mean'],
        'model_decision': lambda x: (x == 'select').mean(),
        'project_readfile_id': 'nunique',
        'match': lambda x: (x == 'match').mean()
    }).reset_index()
    mmg_analysis.columns = ['mt_id', 'total_reads', 'mean_reads', 'selection_rate', 'sample_count', 'match_rate']

    bubble_size = np.sqrt(mmg_analysis['total_reads']) / 100
    largest = bubble_size.max()
    # Area scaling follows 2 * max(size) / (max marker px ** 2). When every
    # group has zero reads (or there are no groups) that would be 0 or NaN,
    # so fall back to a neutral reference of 1.
    sizeref = float(2.0 * largest / (40.0 ** 2)) if largest > 0 else 1.0

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=mmg_analysis['selection_rate'] * 100,
        y=mmg_analysis['match_rate'] * 100,
        mode='markers',
        marker=dict(
            size=bubble_size,
            sizemode='area',
            sizeref=sizeref,
            color=mmg_analysis['sample_count'],
            colorscale='Viridis',
            showscale=True,
            colorbar=dict(title='Number of Samples')
        ),
        text=mmg_analysis['mt_id'],
        customdata=np.stack((
            mmg_analysis['total_reads'],
            mmg_analysis['mean_reads'],
            mmg_analysis['sample_count']
        ), axis=-1),
        hovertemplate="<b>MMG ID: %{text}</b><br>" +
                     "Selection Rate: %{x:.1f}%<br>" +
                     "Match Rate: %{y:.1f}%<br>" +
                     "Total Reads: %{customdata[0]:,.0f}<br>" +
                     "Mean Reads: %{customdata[1]:,.1f}<br>" +
                     "Samples: %{customdata[2]}<extra></extra>"
    ))

    fig.update_layout(
        title=_mmg_title('ASV-MMG Selection vs Match Analysis'),
        xaxis_title="Selection Rate (%)",
        yaxis_title="Match Rate (%)",
        height=800,
        width=1000,
        template='plotly_white'
    )
    return fig


def _mmg_read_violin_figure(df_clean: pd.DataFrame) -> go.Figure:
    """Split violin plot of read counts per MMG ID, one half per decision."""
    fig = go.Figure()
    for decision, color in _MMG_DECISION_COLORS.items():
        subset = df_clean[df_clean['model_decision'] == decision]
        red, green, blue = px.colors.hex_to_rgb(color)
        fig.add_trace(go.Violin(
            x=subset['mt_id'],
            y=subset['reads'],
            name=decision.capitalize(),
            box_visible=True,
            meanline_visible=True,
            points='outliers',
            # 'select' is drawn on the positive side, 'unselect' on the negative side.
            side='positive' if decision == 'select' else 'negative',
            line_color=color,
            fillcolor=f'rgba({red}, {green}, {blue}, 0.3)',
            hovertemplate="MMG ID: %{x}<br>" +
                         "Reads: %{y:,.0f}<br>" +
                         f"Decision: {decision}<extra></extra>"
        ))

    fig.update_layout(
        title=_mmg_title('Read Distribution by MMG and Decision'),
        xaxis_title="MMG ID",
        yaxis_title="Number of Reads",
        height=700,
        yaxis_type='log',
        template='plotly_white',
        violingap=0,
        violinmode='overlay'
    )
    return fig


def _mmg_pattern_heatmap_figure(df_clean: pd.DataFrame) -> go.Figure:
    """Heatmap of summed reads per MMG ID for each decision/match combination."""
    pattern_matrix = pd.pivot_table(
        df_clean,
        values='reads',
        index='mt_id',
        columns=['model_decision', 'match'],
        aggfunc='sum',
        fill_value=0
    )

    fig = go.Figure(go.Heatmap(
        # Colour on log10(reads + 1); the raw totals are kept as cell text.
        z=np.log10(pattern_matrix.values + 1),
        x=[f"{col[0]}_{col[1]}" for col in pattern_matrix.columns],
        y=pattern_matrix.index,
        colorscale='RdYlBu',
        text=pattern_matrix.values,
        texttemplate='%{text:,.0f}',
        hovertemplate="MMG ID: %{y}<br>" +
                     "Category: %{x}<br>" +
                     "Reads: %{text:,.0f}<extra></extra>"
    ))

    fig.update_layout(
        title=_mmg_title('ASV-MMG Decision Pattern Matrix'),
        xaxis_title="Decision-Match Combination",
        yaxis_title="MMG ID",
        height=800,
        width=1000,
        template='plotly_white'
    )
    return fig


def _mmg_sankey_figure(df_clean: pd.DataFrame) -> go.Figure:
    """Sankey diagram of row counts flowing MMG ID -> match status -> decision."""
    mmg_ids = list(df_clean['mt_id'].unique())
    match_status = ['match', 'no_match']
    decisions = ['select', 'unselect']

    # Node indices are laid out as [MMG ids..., match statuses..., decisions...].
    mmg_node = {mmg: i for i, mmg in enumerate(mmg_ids)}
    match_node = {status: len(mmg_ids) + i for i, status in enumerate(match_status)}
    decision_node = {dec: len(mmg_ids) + len(match_status) + i
                     for i, dec in enumerate(decisions)}
    labels = [*mmg_ids, *match_status, *decisions]

    source: List[int] = []
    target: List[int] = []
    value: List[int] = []

    # Only rows with a recognised match status take part in the flow.
    flow_rows = df_clean[df_clean['match'].isin(match_status)]

    # First layer: one link per (MMG id, match status) pair that has rows.
    mmg_to_match = flow_rows.groupby(['mt_id', 'match'], sort=False, observed=True).size()
    for (mmg, status), n_rows in mmg_to_match.items():
        source.append(mmg_node[mmg])
        target.append(match_node[status])
        value.append(int(n_rows))

    # Second layer: a single aggregated link per (match status, decision) pair,
    # rather than one duplicate link for every MMG id.
    decided = flow_rows[flow_rows['model_decision'].isin(decisions)]
    match_to_decision = decided.groupby(['match', 'model_decision'], sort=False, observed=True).size()
    for (status, decision), n_rows in match_to_decision.items():
        source.append(match_node[status])
        target.append(decision_node[decision])
        value.append(int(n_rows))

    # Cycle the palette so every node gets a colour even when there are more
    # nodes than palette entries.
    palette = px.colors.qualitative.Set3
    node_colors = [palette[i % len(palette)] for i in range(len(labels))]

    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=labels,
            color=node_colors
        ),
        link=dict(
            source=source,
            target=target,
            value=value
        )
    )])

    fig.update_layout(
        title=_mmg_title('ASV Flow: MMG → Match Status → Model Decision'),
        height=800,
        width=1200,
        template='plotly_white'
    )
    return fig


def _mmg_time_series_figure(df_clean: pd.DataFrame) -> go.Figure:
    """Line chart of the number of rows per calendar date for each decision."""
    dates = pd.to_datetime(df_clean['timestamp']).dt.date
    # Aggregate per (date, decision) so each line has exactly one point per
    # date; keeping mt_id in the grouping would put several points on the same
    # date and make the line zigzag.
    daily = (
        df_clean.assign(date=dates)
        .groupby(['date', 'model_decision'])
        .size()
        .reset_index(name='count')
    )

    fig = go.Figure()
    for decision, color in _MMG_DECISION_COLORS.items():
        subset = daily[daily['model_decision'] == decision]
        fig.add_trace(go.Scatter(
            x=subset['date'],
            y=subset['count'],
            name=decision.capitalize(),
            mode='lines+markers',
            line=dict(color=color),
            hovertemplate="Date: %{x}<br>" +
                         "Count: %{y}<br>" +
                         f"Decision: {decision}<extra></extra>"
        ))

    fig.update_layout(
        title=_mmg_title('ASV Selection Patterns Over Time'),
        xaxis_title="Date",
        yaxis_title="Number of ASVs",
        height=600,
        template='plotly_white',
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )
    return fig


def create_asv_mmg_analysis_plots(df: pd.DataFrame) -> List[Tuple[str, go.Figure]]:
    """Create comprehensive ASV-MMG analysis plots.

    Missing ``mt_id`` values are replaced with ``'No Match'`` on a copy of
    ``df`` before any figure is built; ``df`` itself is not modified.

    Args:
        df: Results table. Besides ``mt_id`` the figures read the columns
            ``model_decision``, ``reads``, ``project_readfile_id`` and
            ``match``; ``timestamp`` is optional.

    Returns:
        ``(name, figure)`` pairs in this order: ``asv_mmg_distribution``,
        ``asv_mmg_bubble``, ``mmg_read_distribution``, ``mmg_pattern_matrix``,
        ``mmg_sankey_flow`` and, only when a ``timestamp`` column exists,
        ``mmg_time_series``. An empty list is returned when ``mt_id`` is
        absent, when ``df`` has no rows, or when building any figure raises
        (the error and traceback are logged).
    """
    plots = []
    try:
        if 'mt_id' not in df.columns:
            logging.warning("No mt_id column found in the data")
            return plots

        if df.empty:
            logging.warning("No rows available for ASV-MMG analysis")
            return plots

        df_clean = df.copy()
        df_clean['mt_id'] = df_clean['mt_id'].fillna('No Match')

        # 1. Enhanced Distribution Stacked Bar Plot
        plots.append(('asv_mmg_distribution', _mmg_distribution_figure(df_clean)))

        # 2. Enhanced Bubble Plot with Multiple Metrics
        plots.append(('asv_mmg_bubble', _mmg_bubble_figure(df_clean)))

        # 3. Enhanced Read Distribution Violin Plot
        plots.append(('mmg_read_distribution', _mmg_read_violin_figure(df_clean)))

        # 4. Decision Pattern Heatmap
        plots.append(('mmg_pattern_matrix', _mmg_pattern_heatmap_figure(df_clean)))

        # 5. Sankey Diagram of ASV Flow
        plots.append(('mmg_sankey_flow', _mmg_sankey_figure(df_clean)))

        # 6. Time Series Analysis (if timestamp available)
        if 'timestamp' in df_clean.columns:
            plots.append(('mmg_time_series', _mmg_time_series_figure(df_clean)))

        return plots
    except Exception as e:
        # All-or-nothing: a failure in any figure discards the partial list.
        logging.error(f"Error creating ASV-MMG plots: {str(e)}")
        logging.error(traceback.format_exc())
        return []

def display_preprocessing_results(df: pd.DataFrame, feature_columns: List[str]) -> None:
    """Display preprocessing results in Jupyter notebook.

    Shows the record and feature counts, ``describe()`` statistics for the
    feature columns, and the first rows of ``df``.

    Requested features that are not columns of ``df`` are listed in a note
    instead of raising ``KeyError``, and the statistics table is skipped when
    no requested feature is present (``DataFrame.describe`` raises
    ``ValueError`` on a frame without columns). Either way the data sample at
    the end is still shown.

    Args:
        df: Preprocessed data to summarise.
        feature_columns: Names of the feature columns to describe.
    """
    display(Markdown("## Data Preprocessing Results"))

    # Display basic dataset info
    display(Markdown("### Dataset Overview"))
    display(Markdown(f"- Total records: {len(df):,}"))
    display(Markdown(f"- Number of features: {len(feature_columns)}"))

    # Display feature summary
    display(Markdown("### Feature Summary"))
    present = [col for col in feature_columns if col in df.columns]
    missing = [col for col in feature_columns if col not in df.columns]
    if missing:
        display(Markdown(f"- Features not found in data: {', '.join(map(str, missing))}"))
    if present:
        display(df[present].describe())
    else:
        display(Markdown("- No feature columns available to summarise"))

    # Display sample of preprocessed data
    display(Markdown("### Sample of Preprocessed Data"))
    display(df.head())

def display_model_training_results(model, accuracy: float, importance: pd.DataFrame, 
                                 X_test: pd.DataFrame, y_test: pd.Series) -> None:
    """Display model training results in Jupyter notebook.

    Renders the model's hyper-parameters, the accuracy, and a horizontal bar
    chart of feature importances with the most important feature at the top.

    Args:
        model: Fitted estimator; only its ``get_params()`` method is called.
        accuracy: Accuracy value, printed with four decimals.
        importance: Table with ``feature`` and ``importance`` columns.
        X_test: Accepted for interface compatibility; not used here.
        y_test: Accepted for interface compatibility; not used here.
    """
    display(Markdown("## Model Training Results"))

    # Display model parameters as one Markdown list instead of one notebook
    # output element per parameter.
    display(Markdown("### Model Parameters"))
    param_lines = [f"- {param}: {value}" for param, value in model.get_params().items()]
    if param_lines:
        display(Markdown("\n".join(param_lines)))

    # Display accuracy metrics
    display(Markdown("### Model Performance Metrics"))
    display(Markdown(f"- Accuracy: {accuracy:.4f}"))

    # Display feature importance
    display(Markdown("### Feature Importance"))
    importance_sorted = importance.sort_values('importance', ascending=False)
    fig = go.Figure(go.Bar(
        x=importance_sorted['importance'],
        y=importance_sorted['feature'],
        orientation='h'
    ))
    fig.update_layout(
        title='Feature Importance',
        xaxis_title='Importance Score',
        yaxis_title='Feature',
        # Horizontal bars are drawn bottom-to-top in data order; reversing the
        # axis puts the first (most important) feature at the top.
        yaxis_autorange='reversed',
        height=400
    )
    fig.show()

def display_threshold_optimization_results(threshold_results: Dict, optimal_threshold: float) -> None:
    """Render the threshold search outcome in the notebook.

    Prints the chosen threshold and plots F1 score against threshold, with
    one line per search stage.

    Args:
        threshold_results: Mapping holding the entries ``'coarse_search'``,
            ``'medium_search'`` and ``'fine_search'``; each is converted to a
            DataFrame and must yield ``threshold`` and ``f1_score`` columns.
        optimal_threshold: Selected threshold, shown with four decimals.
    """
    display(Markdown("## Threshold Optimization Results"))
    display(Markdown(f"### Optimal Threshold: {optimal_threshold:.4f}"))

    # One F1-vs-threshold curve per search stage, in coarse-to-fine order.
    stage_keys = ('coarse_search', 'medium_search', 'fine_search')
    curve_fig = go.Figure()
    for stage_key in stage_keys:
        stage_frame = pd.DataFrame(threshold_results[stage_key])
        stage_trace = go.Scatter(
            x=stage_frame['threshold'],
            y=stage_frame['f1_score'],
            name=stage_key.replace('_', ' ').title(),
            mode='lines+markers'
        )
        curve_fig.add_trace(stage_trace)

    layout_options = dict(
        title='Threshold Optimization',
        xaxis_title='Threshold',
        yaxis_title='F1 Score',
        height=400
    )
    curve_fig.update_layout(**layout_options)
    curve_fig.show()

def display_prediction_results(df: pd.DataFrame) -> None:
    """Summarise model predictions in the notebook.

    Prints the number of distinct ``project_readfile_id`` values, the number
    of rows whose ``model_decision`` is ``'select'`` and the fraction of rows
    whose ``agreement`` is ``'agree'``, then shows a donut chart of the
    ``agreement`` value counts.

    Args:
        df: Prediction table with ``project_readfile_id``, ``model_decision``
            and ``agreement`` columns.
    """
    display(Markdown("## Model Prediction Results"))

    # Headline numbers
    n_samples = df['project_readfile_id'].nunique()
    n_selected = df['model_decision'].eq('select').sum()
    agree_fraction = df['agreement'].eq('agree').mean()

    display(Markdown("### Summary Statistics"))
    summary_lines = (
        f"- Total samples processed: {n_samples:,}",
        f"- Total ASVs selected: {n_selected:,}",
        f"- Agreement rate: {agree_fraction:.4f}",
    )
    for line in summary_lines:
        display(Markdown(line))

    # Donut chart of how often each agreement label occurs
    label_counts = df['agreement'].value_counts()
    donut = go.Pie(
        labels=label_counts.index,
        values=label_counts.values,
        hole=0.4
    )
    agreement_fig = go.Figure(data=[donut])
    agreement_fig.update_layout(
        title='Model-Expert Agreement Distribution',
        height=400
    )
    agreement_fig.show()

def display_taxonomic_results(df: pd.DataFrame) -> None:
    """Show a grouped bar chart of row counts per taxonomy and decision.

    When ``df`` has no ``taxonomy`` column only a notice is displayed.
    Otherwise rows are counted per ``(taxonomy, model_decision)`` pair and
    one bar series is drawn for each of ``'select'`` and ``'unselect'`` that
    actually occurs in the data.

    Args:
        df: Results table with ``model_decision`` and, optionally,
            ``taxonomy`` columns.
    """
    has_taxonomy = 'taxonomy' in df.columns
    if not has_taxonomy:
        display(Markdown("## Taxonomic Analysis not available - No taxonomy column found"))
        return

    display(Markdown("## Taxonomic Analysis Results"))

    # Rows: taxonomy, columns: decision, cells: number of records (0 if none).
    counts_by_decision = (
        df.groupby(['taxonomy', 'model_decision'])
        .size()
        .unstack(fill_value=0)
    )

    # Only draw a series for decisions that are present as columns.
    available = [d for d in ('select', 'unselect') if d in counts_by_decision.columns]
    bar_fig = go.Figure()
    for decision_label in available:
        series_bar = go.Bar(
            name=decision_label.capitalize(),
            x=counts_by_decision.index,
            y=counts_by_decision[decision_label]
        )
        bar_fig.add_trace(series_bar)

    bar_fig.update_layout(
        title='Taxonomic Distribution by Decision',
        xaxis_title='Taxonomy',
        yaxis_title='Count',
        barmode='group',
        height=500,
        xaxis_tickangle=45
    )
    bar_fig.show()

# Cell 11: Report Generation Functions

def generate_enhanced_html_report(results_df: pd.DataFrame, model_metrics: Dict, 
                                plots: Dict[str, List[Tuple[str, go.Figure]]], 
                                OUTPUT_DIR: Path) -> Path:
    """Generate comprehensive HTML report with all analyses."""
    try:
        report_dir = OUTPUT_DIR / 'reports'
        report_dir.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        
        # Save plots to separate files
        plot_refs = {}
        plots_dir = report_dir / 'plots'
        plots_dir.mkdir(exist_ok=True)
        
        for section_name, section_plots in plots.items():
            section_refs = []
            for plot_title, fig in section_plots:
                plot_filename = f"{section_name}_{plot_title.lower().replace(' ', '_')}.html"
                plot_path = plots_dir / plot_filename
                
                # Save plot with minimal HTML wrapper
                fig.write_html(
                    plot_path,
                    include_plotlyjs='cdn',
                    full_html=False,
                    config={'responsive': True}
                )
                section_refs.append((plot_title, plot_filename))
            plot_refs[section_name] = section_refs

        # Define sections
        sections = {
            'overview': {
                'title': 'Overview',
                'description': 'Summary of model performance and key metrics',
                'icon': 'fas fa-chart-line',
                'order': 1
            },
            'model_performance': {
                'title': 'Model Performance',
                'description': 'Analysis of model performance including ROC curve, feature importance, and confusion matrix',
                'icon': 'fas fa-brain',
                'order': 2
            },
            'correlation_analysis': {
                'title': 'Correlation Analysis',
                'description': 'Feature correlation patterns and relationships',
                'icon': 'fas fa-project-diagram',
                'order': 3
            },
            'threshold_analysis': {
                'title': 'Threshold Analysis',
                'description': 'Selection threshold optimization analysis',
                'icon': 'fas fa-sliders-h',
                'order': 4
            },
            'feature_analysis': {
                'title': 'Feature Analysis',
                'description': 'Detailed analysis of individual features',
                'icon': 'fas fa-columns',
                'order': 5
            },
            'pca_analysis': {
                'title': 'PCA Analysis',
                'description': 'Principal Component Analysis visualization',
                'icon': 'fas fa-cube',
                'order': 6
            },
            'agreement_analysis': {
                'title': 'Agreement Analysis',
                'description': 'Analysis of model-expert agreement patterns and distributions',
                'icon': 'fas fa-handshake',
                'order': 7
            },
            'taxonomic_analysis': {
                'title': 'Taxonomic Analysis',
                'description': 'Analysis of taxonomic patterns and distributions',
                'icon': 'fas fa-sitemap',
                'order': 8
            },
            'asv_mmg_analysis': {
                'title': 'ASV-MMG Analysis',
                'description': 'Analysis of ASV-MMG relationships and patterns',
                'icon': 'fas fa-dna',
                'order': 9
            },
            'summary': {
                'title': 'Summary',
                'description': 'Overall findings and recommendations',
                'icon': 'fas fa-clipboard-check',
                'order': 10
            }
        }

        # Generate HTML content
        html_content = f"""
        <!DOCTYPE html>
        <html lang="en">
        <head>
            <meta charset="UTF-8">
            <meta name="viewport" content="width=device-width, initial-scale=1.0">
            <title>ASV Selection Analysis Report</title>
            <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css" rel="stylesheet">
            <link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css" rel="stylesheet">
            <script src="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/js/bootstrap.bundle.min.js"></script>
            <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
            <style>
                body {{
                    margin: 0;
                    padding: 0;
                    background-color: #f8f9fa;
                }}
                
                .sidebar {{
                    position: fixed;
                    top: 0;
                    left: 0;
                    width: 280px;
                    height: 100vh;
                    background: #2c3e50;
                    padding: 20px 0;
                    z-index: 1000;
                    overflow-y: auto;
                    color: white;
                }}
                
                .content-wrapper {{
                    margin-left: 280px;
                    padding: 20px 40px;
                    min-height: 100vh;
                }}
                
                .nav-link {{
                    color: rgba(255,255,255,0.8);
                    margin: 5px 15px;
                    padding: 10px 15px;
                    border-radius: 5px;
                    transition: all 0.3s;
                    display: flex;
                    align-items: center;
                }}
                
                .nav-link i {{
                    margin-right: 10px;
                    width: 20px;
                    text-align: center;
                }}
                
                .nav-link:hover,
                .nav-link.active {{
                    background: #3498db;
                    color: white;
                    transform: translateX(5px);
                }}
                
                .section {{
                    background: white;
                    margin: 30px 0;
                    padding: 30px;
                    border-radius: 15px;
                    box-shadow: 0 2px 4px rgba(0,0,0,.05);
                }}
                
                .section-header {{
                    margin-bottom: 25px;
                    padding-bottom: 15px;
                    border-bottom: 2px solid #3498db;
                }}
                
                .metric-card {{
                    background: linear-gradient(135deg,#2c3e50,#3498db);
                    color: white;
                    padding: 25px;
                    border-radius: 15px;
                    text-align: center;
                    transition: transform 0.3s;
                    height: 100%;
                }}
                
                .metric-card:hover {{
                    transform: translateY(-5px);
                }}
                
                .plot-container {{
                    width: 100%;
                    margin: 25px 0;
                    background: white;
                    border-radius: 10px;
                    box-shadow: 0 2px 4px rgba(0,0,0,.05);
                    overflow: visible;
                }}
                
                .plot-frame {{
                    width: 100%;
                    min-height: 1000px;
                    border: none;
                    overflow: visible;
                }}
                
                .summary-section {{
                    background: #f8f9fa;
                    padding: 20px;
                    border-radius: 10px;
                    margin-top: 20px;
                }}
                
                .summary-card {{
                    background: white;
                    padding: 20px;
                    border-radius: 10px;
                    margin-bottom: 20px;
                    box-shadow: 0 2px 4px rgba(0,0,0,.05);
                }}
                
                @media print {{
                    .sidebar {{
                        display: none;
                    }}
                    .content-wrapper {{
                        margin-left: 0;
                    }}
                    .section {{
                        break-inside: avoid;
                    }}
                }}
            </style>
        </head>
        <body>
        """

        # Add navigation
        nav_html = """
        <div class="sidebar">
            <div class="nav flex-column">
        """
        
        for section_id, info in sorted(sections.items(), key=lambda x: x[1]['order']):
            nav_html += f"""
                <a class="nav-link" href="#{section_id}">
                    <i class="{info['icon']}"></i>{info['title']}
                </a>
            """
        
        nav_html += """
            </div>
        </div>
        """
        
        html_content += nav_html + '<div class="content-wrapper">'

        # Add content sections
        for section_id, info in sorted(sections.items(), key=lambda x: x[1]['order']):
            if section_id == 'overview':
                # Overview section
                html_content += f"""
                <div id="{section_id}" class="section">
                    <div class="section-header">
                        <h2><i class="{info['icon']} me-2"></i>{info['title']}</h2>
                        <p class="text-muted">{info['description']}</p>
                        <small class="text-muted">Generated on: {timestamp}</small>
                    </div>
                    <div class="row g-4">
                        <div class="col-md-3">
                            <div class="metric-card">
                                <i class="fas fa-bullseye fa-2x mb-3"></i>
                                <div class="h2">{model_metrics.get('accuracy', 0):.2%}</div>
                                <div>Model Accuracy</div>
                            </div>
                        </div>
                        <div class="col-md-3">
                            <div class="metric-card">
                                <i class="fas fa-check-circle fa-2x mb-3"></i>
                                <div class="h2">{(results_df['agreement'] == 'agree').mean():.2%}</div>
                                <div>Agreement Rate</div>
                            </div>
                        </div>
                        <div class="col-md-3">
                            <div class="metric-card">
                                <i class="fas fa-database fa-2x mb-3"></i>
                                <div class="h2">{len(results_df):,}</div>
                                <div>Total ASVs</div>
                            </div>
                        </div>
                        <div class="col-md-3">
                            <div class="metric-card">
                                <i class="fas fa-check fa-2x mb-3"></i>
                                <div class="h2">{(results_df['model_decision'] == 'select').sum():,}</div>
                                <div>Selected ASVs</div>
                            </div>
                        </div>
                    </div>
                </div>
                """
            elif section_id == 'summary':
                # Summary section
                html_content += f"""
                <div id="{section_id}" class="section">
                    <div class="section-header">
                        <h2><i class="{info['icon']} me-2"></i>{info['title']}</h2>
                        <p class="text-muted">{info['description']}</p>
                    </div>
                    
                    <div class="summary-section">
                        <div class="summary-card">
                            <h4>Key Findings</h4>
                            <ul>
                                <li>Model achieved {model_metrics.get('accuracy', 0):.2%} accuracy</li>
                                <li>{(results_df['agreement'] == 'agree').mean():.2%} agreement rate with expert decisions</li>
                                <li>Selected {(results_df['model_decision'] == 'select').sum():,} ASVs from {len(results_df):,} total ASVs</li>
                                <li>Optimal selection threshold determined at {model_metrics.get('optimal_threshold', 0):.3f}</li>
                                <li>Perfect agreement between model predictions and expert decisions</li>
                            </ul>
                        </div>
                        
                        <div class="summary-card">
                            <h4>Model Performance</h4>
                            <ul>
                                <li>Random Forest model with optimal parameters</li>
                                <li>Cross-validation score: {model_metrics.get('accuracy', 0):.2%}</li>
                                <li>Balanced handling of select/unselect cases</li>
                                <li>Robust feature importance analysis</li>
                            </ul>
                        </div>
                        
                        <div class="summary-card">
                            <h4>Recommendations</h4>
                            <ul>
                                <li>Monitor model performance with new data</li>
                                <li>Validate results with domain experts</li>
                                <li>Consider periodic model retraining as new data becomes available</li>
                                <li>Maintain threshold optimization for different datasets</li>
                                <li>Keep track of taxonomic and MMG patterns for validation</li>
                            </ul>
                        </div>
                        
                        <div class="summary-card">
                            <h4>Future Improvements</h4>
                            <ul>
                                <li>Expand feature set for more robust predictions</li>
                                <li>Implement continuous validation pipeline</li>
                                <li>Enhance taxonomic analysis capabilities</li>
                                <li>Develop more advanced MMG matching algorithms</li>
                                <li>Create automated reporting system for regular updates</li>
                            </ul>
                        </div>
                    </div>
                </div>
                """
            else:
                # Other sections with plots
                if section_id in plot_refs:
                    plot_html = []
                    for plot_title, plot_file in plot_refs[section_id]:
                        plot_html.append(f"""
                            <div class="plot-container">
                                <h4 class="ps-3 pt-3">{plot_title}</h4>
                                <iframe class="plot-frame"
                                        src="plots/{plot_file}"
                                        loading="lazy"
                                        onload="this.style.height = Math.max(1000, this.contentWindow.document.body.scrollHeight + 50) + 'px'">
                                </iframe>
                            </div>
                        """)
                    
                    if plot_html:
                        html_content += f"""
                        <div id="{section_id}" class="section">
                            <div class="section-header">
                                <h2><i class="{info['icon']} me-2"></i>{info['title']}</h2>
                                <p class="text-muted">{info['description']}</p>
                            </div>
                            {''.join(plot_html)}
                        </div>
                        """

        # Add scripts and close HTML
        html_content += """
            </div>
            <script>
                // Resize iframes
                function resizeIframe(iframe) {
                    iframe.style.height = 'auto';
                    const newHeight = Math.max(1000, iframe.contentWindow.document.body.scrollHeight + 50);
                    iframe.style.height = newHeight + 'px';
                }

                // Handle document load
                document.addEventListener('DOMContentLoaded', function() {
                    // Handle iframes
                    const iframes = document.querySelectorAll('.plot-frame');
                    iframes.forEach(iframe => {
                        iframe.onload = function() {
                            resizeIframe(this);
                        };
                    });

                    // Handle navigation
                    const navLinks = document.querySelectorAll('.nav-link');
                    const sections = document.querySelectorAll('.section');
                    
                    // Initialize tooltips
                    const tooltipTriggerList = [].slice.call(document.querySelectorAll('[data-bs-toggle="tooltip"]'));
                    tooltipTriggerList.map(function (tooltipTriggerEl) {
                        return new bootstrap.Tooltip(tooltipTriggerEl);
                    });
                    
                    // Smooth scrolling for navigation links
                    navLinks.forEach(link => {
                        link.addEventListener('click', function(e) {
                            e.preventDefault();
                            const targetId = this.getAttribute('href');
                            const targetSection = document.querySelector(targetId);
                            if (targetSection) {
                                targetSection.scrollIntoView({
                                    behavior: 'smooth'
                                });
                            }
                        });
                    });

                    // Active navigation highlighting
                    window.addEventListener('scroll', () => {
                        let current = '';
                        sections.forEach(section => {
                            const sectionTop = section.offsetTop - 100;
                            const sectionHeight = section.offsetHeight;
                            
                            if (window.scrollY >= sectionTop && 
                                window.scrollY < sectionTop + sectionHeight) {
                                current = section.getAttribute('id');
                            }
                        });
                        
                        navLinks.forEach(link => {
                            link.classList.remove('active');
                            if (link.getAttribute('href').slice(1) === current) {
                                link.classList.add('active');
                            }
                        });
                    });

                    // Add scroll to top button functionality
                    const scrollButton = document.createElement('button');
                    scrollButton.innerHTML = '<i class="fas fa-arrow-up"></i>';
                    scrollButton.className = 'btn btn-primary position-fixed';
                    scrollButton.style.cssText = `
                        bottom: 20px;
                        right: 20px;
                        display: none;
                        z-index: 1000;
                        width: 40px;
                        height: 40px;
                        border-radius: 20px;
                        padding: 0;
                        box-shadow: 0 2px 5px rgba(0,0,0,0.2);
                    `;
                    document.body.appendChild(scrollButton);

                    window.onscroll = function() {
                        if (document.body.scrollTop > 20 || document.documentElement.scrollTop > 20) {
                            scrollButton.style.display = "block";
                        } else {
                            scrollButton.style.display = "none";
                        }
                    };

                    scrollButton.onclick = function() {
                        window.scrollTo({
                            top: 0,
                            behavior: 'smooth'
                        });
                    };
                });

                // Handle window resize
                window.addEventListener('resize', function() {
                    const iframes = document.querySelectorAll('.plot-frame');
                    iframes.forEach(iframe => {
                        resizeIframe(iframe);
                    });
                });

                // Print optimization
                window.addEventListener('beforeprint', function() {
                    document.body.style.paddingLeft = '0';
                });
                
                window.addEventListener('afterprint', function() {
                    document.body.style.paddingLeft = '280px';
                });

                // Plot responsiveness
                window.addEventListener('resize', function() {
                    document.querySelectorAll('.js-plotly-plot').forEach(plot => {
                        if (plot && plot.id) {
                            Plotly.Plots.resize(plot.id);
                        }
                    });
                });
            </script>
        </body>
        </html>
        """

        # Save report
        report_path = report_dir / 'analysis_report.html'
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write(html_content)
            
        logging.info(f"Generated enhanced HTML report: {report_path}")
        return report_path
        
    except Exception as e:
        logging.error(f"Error generating HTML report: {str(e)}")
        raise

def create_navigation_html(sections: Dict) -> str:
    """Build the fixed sidebar navigation for the report.

    Args:
        sections: Mapping of section id -> section info. Every info mapping
            must provide 'description', 'icon' and 'title'; a missing key
            raises KeyError. Links are emitted in the mapping's iteration
            order.

    Returns:
        Sidebar HTML containing one ``<a class="nav-link">`` per section that
        targets ``#<section_id>`` and carries the description as a Bootstrap
        tooltip (``title`` attribute).
    """
    # Local import keeps the block self-contained; the file does not import
    # the stdlib ``html`` module at the top.
    from html import escape

    sidebar_open = """
    <div class="sidebar">
        <div class="sidebar-header">
            <h4>Analysis Navigation</h4>
            <p class="small">Click section to navigate</p>
        </div>
        <div class="sidebar-content">
            <div class="nav flex-column nav-pills">
    """

    sidebar_close = """
            </div>
        </div>
    </div>
    """

    # The description lands inside a double-quoted attribute, so it must be
    # escaped: a raw '"' would terminate title="..." early and corrupt the
    # rest of the link markup.
    links = [
        f"""
                <a class="nav-link" href="#{section_id}" 
                   data-bs-toggle="tooltip" 
                   data-bs-placement="right"
                   title="{escape(str(section_info['description']), quote=True)}">
                    <i class="{section_info['icon']}"></i>
                    {section_info['title']}
                </a>
        """
        for section_id, section_info in sections.items()
    ]

    return "".join([sidebar_open, *links, sidebar_close])

def create_overview_section(section_info: Dict, results_df: pd.DataFrame, model_metrics: Dict, timestamp: str) -> str:
    """Render the "overview" section: header plus four headline metric cards.

    Args:
        section_info: Provides 'icon', 'title' and 'description' for the header.
        results_df: Must contain the 'agreement' and 'model_decision' columns.
        model_metrics: 'accuracy' is read with a default of 0 when absent.
        timestamp: Text shown after "Generated on:".

    Returns:
        The HTML fragment for ``<div id="overview">``.
    """
    icon_class = section_info['icon']
    heading = section_info['title']
    blurb = section_info['description']

    accuracy = model_metrics.get('accuracy', 0)
    agreement_rate = (results_df['agreement'] == 'agree').mean()
    total_asvs = len(results_df)
    selected_asvs = (results_df['model_decision'] == 'select').sum()

    def metric_card(icon_name: str, value_text: str, label: str) -> str:
        # One Bootstrap column holding a single metric card.
        return (
            '                <div class="col-md-3">\n'
            '                    <div class="metric-card">\n'
            f'                        <i class="fas {icon_name} metric-icon"></i>\n'
            f'                        <div class="metric-value">{value_text}</div>\n'
            f'                        <div class="metric-label">{label}</div>\n'
            '                    </div>\n'
            '                </div>'
        )

    cards = "\n".join([
        metric_card("fa-bullseye", f"{accuracy:.2%}", "Model Accuracy"),
        metric_card("fa-check-circle", f"{agreement_rate:.2%}", "Agreement Rate"),
        metric_card("fa-database", f"{total_asvs:,}", "Total ASVs"),
        metric_card("fa-check", f"{selected_asvs:,}", "Selected ASVs"),
    ])

    return f"""
        <div id="overview" class="section">
            <div class="section-header">
                <div class="section-title">
                    <i class="{icon_class} fa-2x"></i>
                    <h2>{heading}</h2>
                </div>
                <p class="section-description">{blurb}</p>
                <small class="text-muted">Generated on: {timestamp}</small>
            </div>
            
            <div class="row">
{cards}
            </div>
        </div>
    """

def create_section_html(section_id: str, section_info: Dict, section_plots: List[Tuple[str, go.Figure]]) -> str:
    """Render one report section: a header followed by its embedded plots.

    Args:
        section_id: Used as the ``id`` of the wrapping ``<div>``.
        section_info: Provides 'icon', 'title' and 'description'.
        section_plots: ``(plot_title, figure)`` pairs; each figure is embedded
            through ``fig.to_html`` as a fragment that does not bundle
            plotly.js (``include_plotlyjs=False``).

    Returns:
        The HTML fragment for the whole section.
    """
    fragments = [f"""
        <div id="{section_id}" class="section">
            <div class="section-header">
                <div class="section-title">
                    <i class="{section_info['icon']} fa-2x"></i>
                    <h2>{section_info['title']}</h2>
                </div>
                <p class="section-description">{section_info['description']}</p>
            </div>
    """]

    embed_options = dict(
        full_html=False,
        include_plotlyjs=False,
        config={'responsive': True},
    )
    for plot_title, fig in section_plots:
        embedded_plot = fig.to_html(**embed_options)
        fragments.append(f"""
            <div class="plot-container">
                <h4>{plot_title}</h4>
                {embedded_plot}
            </div>
        """)

    fragments.append("</div>")
    return "".join(fragments)

def create_summary_section(section_info: Dict, results_df: pd.DataFrame, model_metrics: Dict) -> str:
    """Render the "summary" section: key findings and fixed recommendations.

    Args:
        section_info: Provides 'icon', 'title' and 'description' for the header.
        results_df: Must contain the 'agreement' and 'model_decision' columns.
        model_metrics: 'accuracy' and 'optimal_threshold' are read with a
            default of 0 when absent.

    Returns:
        The HTML fragment for ``<div id="summary">``.
    """
    icon_class = section_info['icon']
    heading = section_info['title']
    blurb = section_info['description']

    accuracy = model_metrics.get('accuracy', 0)
    agreement_rate = (results_df['agreement'] == 'agree').mean()
    selected_asvs = (results_df['model_decision'] == 'select').sum()
    total_asvs = len(results_df)
    threshold = model_metrics.get('optimal_threshold', 0)

    findings = (
        f"Model achieved {accuracy:.2%} accuracy",
        f"{agreement_rate:.2%} agreement rate with expert decisions",
        f"Selected {selected_asvs:,} ASVs from {total_asvs:,} total ASVs",
        f"Optimal threshold determined at {threshold:.3f}",
    )
    findings_html = "\n".join(
        f"                    <li>{finding}</li>" for finding in findings
    )

    return f"""
        <div id="summary" class="section">
            <div class="section-header">
                <div class="section-title">
                    <i class="{icon_class} fa-2x"></i>
                    <h2>{heading}</h2>
                </div>
                <p class="section-description">{blurb}</p>
            </div>
            
            <div class="summary-content">
                <h3>Key Findings</h3>
                <ul>
{findings_html}
                </ul>
                
                <h3>Recommendations</h3>
                <ul>
                    <li>Continue monitoring model performance with new data</li>
                    <li>Validate results with domain experts</li>
                    <li>Consider periodic model retraining as new data becomes available</li>
                </ul>
            </div>
        </div>
    """

def create_html_header(plotly_js_version: str = "2.35.2") -> str:
    """Create the report's ``<!DOCTYPE>``/``<head>`` markup with embedded CSS.

    The returned text ends after ``</head>``; it does not open ``<body>``.

    Args:
        plotly_js_version: plotly.js release to load from cdn.plot.ly. The
            previous ``plotly-latest.min.js`` URL is frozen by the CDN at an
            old 1.x release, so an explicit version is pinned instead.

    Returns:
        HTML text containing the document head: external dependencies
        (plotly.js, Bootstrap 5.1.3, Font Awesome 6.0.0) and the inline
        stylesheet for the sidebar, sections, metric cards and plots.
    """
    # Plain (non-f) string: the CSS is full of literal braces.
    css_styles = """
        <style>
            :root {
                --primary-color: #2c3e50;
                --secondary-color: #34495e;
                --accent-color: #3498db;
                --background-color: #f8f9fa;
                --text-color: #2c3e50;
            }
            
            body {
                padding-left: 280px;
                background-color: var(--background-color);
                color: var(--text-color);
                font-family: 'Segoe UI', Arial, sans-serif;
            }
            
            .sidebar {
                height: 100%;
                width: 280px;
                position: fixed;
                z-index: 1;
                top: 0;
                left: 0;
                background-color: var(--primary-color);
                overflow-x: hidden;
                padding: 20px 0;
                color: white;
                box-shadow: 2px 0 5px rgba(0,0,0,0.1);
            }
            
            .sidebar-header {
                padding: 20px;
                text-align: center;
                border-bottom: 1px solid rgba(255,255,255,0.1);
                margin-bottom: 20px;
            }
            
            .nav-link {
                color: rgba(255,255,255,0.8) !important;
                padding: 12px 20px !important;
                margin: 5px 15px;
                border-radius: 5px;
                transition: all 0.3s;
                display: flex !important;
                align-items: center;
            }
            
            .nav-link i {
                margin-right: 10px;
                width: 20px;
                text-align: center;
            }
            
            .nav-link:hover {
                background-color: var(--accent-color);
                color: white !important;
                transform: translateX(5px);
            }
            
            .nav-link.active {
                background-color: var(--accent-color) !important;
                color: white !important;
                box-shadow: 0 2px 5px rgba(0,0,0,0.2);
            }
            
            .section {
                background: white;
                margin: 30px;
                padding: 30px;
                border-radius: 15px;
                box-shadow: 0 4px 6px rgba(0,0,0,0.05);
                transition: transform 0.3s ease-in-out;
            }
            
            .section:hover {
                transform: translateY(-5px);
                box-shadow: 0 6px 12px rgba(0,0,0,0.1);
            }
            
            .section-header {
                border-bottom: 2px solid var(--accent-color);
                padding-bottom: 15px;
                margin-bottom: 25px;
            }
            
            .section-title {
                display: flex;
                align-items: center;
                gap: 10px;
                color: var(--primary-color);
            }
            
            .section-title i {
                color: var(--accent-color);
            }
            
            .metric-card {
                background: linear-gradient(135deg, var(--primary-color), var(--secondary-color));
                color: white;
                border-radius: 15px;
                padding: 25px;
                margin: 10px;
                text-align: center;
                transition: all 0.3s;
                box-shadow: 0 4px 6px rgba(0,0,0,0.1);
            }
            
            .metric-card:hover {
                transform: translateY(-5px) scale(1.02);
                box-shadow: 0 6px 12px rgba(0,0,0,0.15);
            }
            
            .metric-icon {
                font-size: 2.5em;
                margin-bottom: 15px;
                color: rgba(255,255,255,0.9);
            }
            
            .metric-value {
                font-size: 2.2em;
                font-weight: bold;
                margin: 10px 0;
            }
            
            .metric-label {
                font-size: 1.1em;
                opacity: 0.9;
            }
            
            .plot-container {
                margin: 25px 0;
                padding: 20px;
                border: 1px solid #eee;
                border-radius: 10px;
                background: white;
                transition: all 0.3s;
            }
            
            .plot-container:hover {
                box-shadow: 0 4px 8px rgba(0,0,0,0.1);
            }
            
            .section-description {
                color: #666;
                font-style: italic;
                margin: 15px 0;
                padding: 10px;
                background: #f8f9fa;
                border-radius: 5px;
                border-left: 4px solid var(--accent-color);
            }
            
            .scroll-to-top {
                position: fixed;
                bottom: 30px;
                right: 30px;
                width: 50px;
                height: 50px;
                border-radius: 25px;
                background-color: var(--accent-color);
                color: white;
                border: none;
                box-shadow: 0 2px 5px rgba(0,0,0,0.2);
                display: none;
                z-index: 1000;
                transition: all 0.3s;
            }
            
            .scroll-to-top:hover {
                transform: translateY(-3px);
                box-shadow: 0 4px 8px rgba(0,0,0,0.3);
            }
            
            @media print {
                body {
                    padding-left: 0;
                }
                .sidebar, .scroll-to-top {
                    display: none;
                }
                .section {
                    margin: 15px 0;
                    padding: 15px;
                    break-inside: avoid;
                }
            }
            
            @media (max-width: 768px) {
                body {
                    padding-left: 0;
                }
                .sidebar {
                    transform: translateX(-100%);
                }
                .section {
                    margin: 15px;
                }
            }
        </style>
    """
    return f"""
        <!DOCTYPE html>
        <html lang="en">
        <head>
            <meta charset="UTF-8">
            <meta name="viewport" content="width=device-width, initial-scale=1.0">
            <title>ASV Selection Analysis Report</title>
            
            <!-- External Dependencies -->
            <script src="https://cdn.plot.ly/plotly-{plotly_js_version}.min.js"></script>
            <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css" rel="stylesheet">
            <script src="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/js/bootstrap.bundle.min.js"></script>
            <link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css" rel="stylesheet">
            {css_styles}
        </head>
    """

def create_footer_html() -> str:
    """Return the closing part of the report page.

    The text is assembled from three fixed pieces, in order: the scroll-to-top
    button plus one closing ``</div>``, the inline ``<script>`` (tooltip
    initialisation, scroll-to-top behaviour, active-link highlighting, Plotly
    resize on load), and the closing ``</body></html>`` tags.
    """
    scroll_button_markup = """
                <button class="btn scroll-to-top" id="scrollToTop">
                    <i class="fas fa-arrow-up"></i>
                </button>
            </div>
            
"""

    page_script = """            <script>
                // Initialize tooltips
                var tooltipTriggerList = [].slice.call(document.querySelectorAll('[data-bs-toggle="tooltip"]'))
                var tooltipList = tooltipTriggerList.map(function (tooltipTriggerEl) {
                    return new bootstrap.Tooltip(tooltipTriggerEl)
                });
                
                // Scroll to top functionality
                const scrollToTopBtn = document.getElementById("scrollToTop");
                
                window.onscroll = function() {
                    if (document.body.scrollTop > 20 || document.documentElement.scrollTop > 20) {
                        scrollToTopBtn.style.display = "block";
                    } else {
                        scrollToTopBtn.style.display = "none";
                    }
                };
                
                scrollToTopBtn.addEventListener("click", function() {
                    window.scrollTo({
                        top: 0,
                        behavior: 'smooth'
                    });
                });
                
                // Active navigation highlighting
                const sections = document.querySelectorAll(".section");
                const navLinks = document.querySelectorAll(".nav-link");
                
                window.addEventListener("scroll", () => {
                    let current = "";
                    sections.forEach(section => {
                        const sectionTop = section.offsetTop - 100;
                        const sectionHeight = section.offsetHeight;
                        const sectionId = section.getAttribute("id");
                        
                        if (window.scrollY >= sectionTop && 
                            window.scrollY < sectionTop + sectionHeight) {
                            current = sectionId;
                        }
                    });
                    
                    navLinks.forEach(link => {
                        link.classList.remove("active");
                        if (link.getAttribute("href").slice(1) === current) {
                            link.classList.add("active");
                        }
                    });
                });

                // Initialize plots as responsive
                document.addEventListener('DOMContentLoaded', function() {
                    const plots = document.querySelectorAll('.js-plotly-plot');
                    plots.forEach(plot => {
                        Plotly.Plots.resize(plot.id);
                    });
                });
            </script>
"""

    document_close = """        </body>
        </html>
    """

    return "".join((scroll_button_markup, page_script, document_close))

# Cell 12: Save Results Function
def save_results(results_df: pd.DataFrame, plots: Dict[str, List[Tuple[str, go.Figure]]], 
                OUTPUT_DIR: Path) -> None:
    """Write the results table and every plot under OUTPUT_DIR (best effort).

    Files produced:
        OUTPUT_DIR/analysis_results.csv
        OUTPUT_DIR/visualizations/<section>/<plot>.html
        OUTPUT_DIR/visualizations/<section>/<plot>.json

    Args:
        results_df: Table written to CSV without its index.
        plots: Mapping of section name -> list of ``(plot_title, figure)``.
        OUTPUT_DIR: Destination directory; created (with parents) if missing.

    Returns:
        None. This function never raises: a failure for a single plot is
        logged as a warning and the remaining plots are still written; any
        other failure is logged as an error together with its traceback.
    """
    def numpy_encoder(obj):
        # json's ``default=`` hook: convert NumPy values the encoder rejects.
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        raise TypeError(f"Object of type {type(obj)} is not JSON serializable")

    def filename_stem(plot_title):
        # Keep only filename-safe characters, then normalise to snake_case.
        kept = "".join(
            ch for ch in plot_title
            if ch.isalnum() or ch in (' ', '-', '_')
        ).rstrip()
        # A title made only of punctuation would otherwise yield ".html".
        return kept.replace(' ', '_').lower() or "plot"

    try:
        # Create directories
        OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
        vis_dir = OUTPUT_DIR / 'visualizations'
        vis_dir.mkdir(exist_ok=True)

        # Save DataFrame
        csv_path = OUTPUT_DIR / 'analysis_results.csv'
        results_df.to_csv(csv_path, index=False)
        logging.info(f"Saved results to CSV: {csv_path}")

        # Save plots
        for section_name, section_plots in plots.items():
            section_dir = vis_dir / section_name
            section_dir.mkdir(exist_ok=True)

            for plot_title, fig in section_plots:
                try:
                    clean_title = filename_stem(plot_title)

                    # Save HTML version
                    html_path = section_dir / f"{clean_title}.html"
                    fig.write_html(str(html_path))
                    logging.info(f"Saved HTML plot: {html_path}")

                    # Serialise fully BEFORE opening the file, so a value the
                    # encoder cannot handle does not leave a truncated .json.
                    json_path = section_dir / f"{clean_title}.json"
                    json_text = json.dumps(fig.to_dict(), default=numpy_encoder)
                    with open(json_path, 'w', encoding='utf-8') as f:
                        f.write(json_text)
                    logging.info(f"Saved JSON plot: {json_path}")

                except Exception as plot_error:
                    # Best effort per plot: report and move on to the next one.
                    logging.warning(f"Error saving plot '{plot_title}': {str(plot_error)}")

        logging.info(f"Results saved to: {OUTPUT_DIR}")

    except Exception as e:
        # Deliberately not re-raised (existing contract), but keep the
        # traceback so the failure is diagnosable from the log.
        logging.exception(f"Error saving results: {str(e)}")

# Cell 13: Main Execution Cell
def _section_entry(section_key, section_plots):
    """Return ``{section_key: section_plots}``, or ``{}`` when nothing was built."""
    return {section_key: section_plots} if section_plots else {}


def _run_plot_step(plots, done_message, build_step):
    """Run one plot builder and merge the sections it returns into *plots*.

    A failure is logged and displayed but not propagated, so one broken plot
    group cannot prevent the remaining groups from being generated.

    Returns:
        True when the builder finished without raising, otherwise False.
    """
    try:
        new_sections = build_step()
    except Exception as plot_error:
        logging.error(f"Error generating plots: {str(plot_error)}")
        display(Markdown(f"⚠️ **Error generating plots**: {str(plot_error)}"))
        return False
    if new_sections:
        plots.update(new_sections)
    display(Markdown(done_message))
    return True


def main():
    """Run the ASV selection pipeline end to end, displaying progress as Markdown.

    Steps, in order: load and preprocess the data, scale the features, train
    the model, search for the decision threshold, save the model components,
    apply the predictions, build the plot groups, save the results and
    generate the HTML report, then display a final summary.

    Plot groups are generated independently of each other, and the final
    summary marks a component with "✗" when it failed instead of always
    reporting success.

    Raises:
        Exception: any error outside plot generation and result/report saving
            is logged, displayed and re-raised.
    """
    try:
        # Step 1: Load and preprocess data
        logging.info("Starting ASV Selection Analysis...")
        display(Markdown("# ASV Selection Analysis Pipeline"))
        
        df = load_data()
        df, feature_columns = preprocess_data(df)
        
        # Display preprocessing results
        display(Markdown(f"""
## Data Preprocessing Results
- Total records: {len(df):,}
- Number of features: {len(feature_columns)}

### Feature List:
{', '.join(feature_columns)}

### Data Summary:
```
{df[feature_columns].describe().to_string()}
```
        """))
        
        # Step 2: Scale features
        X = df[feature_columns]
        y = df['target']
        
        scaler = StandardScaler()
        X_scaled = pd.DataFrame(
            scaler.fit_transform(X),
            columns=feature_columns,
            index=X.index
        )
        
        # Step 3: Train model
        model, accuracy, feature_importance, X_test, y_test = train_model(
            X_scaled, y, feature_columns
        )
        
        # Display model results
        sorted_features = feature_importance.sort_values('importance', ascending=False)
        display(Markdown(f"""
## Model Training Results
### Model Parameters:
```
{pd.Series(model.get_params()).to_string()}
```

### Feature Importance:
```
{sorted_features.to_string(index=False)}
```

### Model Performance:
- Accuracy: {accuracy:.4f}
        """))
        
        # Step 4: Find optimal threshold
        df, optimal_threshold, threshold_results = find_optimal_threshold(
            df, model, scaler, feature_columns
        )
        
        # Display threshold results (row of the fine search with the best F1)
        fine_search_df = pd.DataFrame(threshold_results['fine_search'])
        best_metrics = fine_search_df.loc[fine_search_df['f1_score'].idxmax()]
        display(Markdown(f"""
## Threshold Optimization Results
### Optimal Threshold: {optimal_threshold:.4f}

### Performance at Optimal Threshold:
- F1 Score: {best_metrics['f1_score']:.4f}
- Precision: {best_metrics['precision']:.4f}
- Recall: {best_metrics['recall']:.4f}
- True Positives: {int(best_metrics['true_positives']):,}
- False Positives: {int(best_metrics['false_positives']):,}
- False Negatives: {int(best_metrics['false_negatives']):,}
- True Negatives: {int(best_metrics['true_negatives']):,}
        """))

        # Step 5: Save model components
        save_model_components(
            model=model,
            scaler=scaler,
            feature_columns=feature_columns,
            optimal_threshold=optimal_threshold,
            MODEL_SAVE_DIR=MODEL_SAVE_DIR
        )
        
        # Step 6: Apply model predictions
        results_df = apply_model_predictions(
            df, model, scaler, feature_columns, optimal_threshold
        )
        
        analyze_asv_selection(results_df)

        # Display prediction results
        agreement_stats = results_df['agreement'].value_counts()
        selection_stats = results_df['model_decision'].value_counts()
        display(Markdown(f"""
## Model Prediction Results
### Summary Statistics:
- Total samples processed: {results_df['project_readfile_id'].nunique():,}
- Total ASVs analyzed: {len(results_df):,}
- ASVs selected: {selection_stats.get('select', 0):,}
- ASVs not selected: {selection_stats.get('unselect', 0):,}

### Agreement Analysis:
- Total agreements: {agreement_stats.get('agree', 0):,}
- Total disagreements: {agreement_stats.get('disagree', 0):,}
- Agreement rate: {(agreement_stats.get('agree', 0) / len(results_df)):.4f}
        """))

        # Step 7: Generate all plots
        display(Markdown("## Generating Plots and Analysis"))
        plots = {}

        def build_correlation_section():
            # Correlation plots need the matrices; skip when none was produced.
            corr_matrix, linkage_matrix = analyze_correlation(results_df, feature_columns)
            if corr_matrix is None:
                return {}
            return _section_entry(
                'correlation_analysis',
                create_correlation_plots(corr_matrix, linkage_matrix)
            )

        # (step key, message shown on completion, builder returning sections)
        plot_steps = [
            (
                'feature_analysis',
                "✓ Feature analysis completed",
                lambda: analyze_features(
                    results_df, model, scaler, feature_columns,
                    {'optimal_threshold': optimal_threshold}
                ),
            ),
            (
                'model_performance',
                "✓ Model performance plots generated",
                lambda: _section_entry(
                    'model_performance',
                    create_model_performance_plots(
                        model, X_test, y_test, feature_importance
                    )
                ),
            ),
            (
                'correlation_analysis',
                "✓ Correlation analysis completed",
                build_correlation_section,
            ),
            (
                'threshold_analysis',
                "✓ Threshold analysis plots generated",
                lambda: _section_entry(
                    'threshold_analysis',
                    create_threshold_analysis_plots(threshold_results)
                ),
            ),
            (
                'agreement_analysis',
                "✓ Agreement analysis plots generated",
                lambda: _section_entry(
                    'agreement_analysis',
                    create_agreement_analysis_plots(results_df)
                ),
            ),
        ]

        # Optional groups: only attempted when their source column exists.
        if 'taxonomy' in results_df.columns:
            plot_steps.append((
                'taxonomic_analysis',
                "✓ Taxonomic analysis completed",
                lambda: _section_entry(
                    'taxonomic_analysis',
                    create_taxonomic_analysis_plots(results_df)
                ),
            ))
        if 'mt_id' in results_df.columns:
            plot_steps.append((
                'asv_mmg_analysis',
                "✓ ASV-MMG analysis completed",
                lambda: _section_entry(
                    'asv_mmg_analysis',
                    create_asv_mmg_analysis_plots(results_df)
                ),
            ))

        # Each step runs in isolation so a failure does not skip later groups.
        step_ok = {
            step_key: _run_plot_step(plots, done_message, build_step)
            for step_key, done_message, build_step in plot_steps
        }

        # Step 8: Save results and generate report
        display(Markdown("## Saving Results and Generating Report"))
        report_path = None
        try:
            # Save results including plots
            save_results(results_df, plots, OUTPUT_DIR)
            display(Markdown(f"✅ Results saved to: {OUTPUT_DIR}"))
            
            # Generate HTML report
            report_path = generate_enhanced_html_report(
                results_df=results_df,
                model_metrics={
                    'accuracy': accuracy,
                    'optimal_threshold': optimal_threshold
                },
                plots=plots,
                OUTPUT_DIR=OUTPUT_DIR
            )
            display(Markdown(f"✅ Report generated: {report_path}"))
            
        except Exception as save_error:
            logging.error(f"Error saving results and generating report: {str(save_error)}")
            display(Markdown(f"⚠️ **Error saving results**: {str(save_error)}"))
        
        # Final summary: report what actually succeeded.
        def mark(succeeded):
            return "✓" if succeeded else "✗"

        report_generated = report_path is not None
        report_location = (
            report_path if report_generated
            else "not generated (see errors above)"
        )
        display(Markdown(f"""
## Final Analysis Summary
### Model Performance:
- Final accuracy: {accuracy:.4f}
- Optimal threshold: {optimal_threshold:.4f}
- Overall agreement rate: {(results_df['agreement'] == 'agree').mean():.4f}

### Dataset Statistics:
- Total ASVs processed: {len(results_df):,}
- ASVs selected: {(results_df['model_decision'] == 'select').sum():,}
- Samples analyzed: {results_df['project_readfile_id'].nunique():,}

### File Locations:
- Results saved to: {OUTPUT_DIR}
- Model saved to: {MODEL_SAVE_DIR}
- Detailed report at: {report_location}

### Analysis Components Completed:
- ✓ Data preprocessing
- ✓ Model training
- ✓ Threshold optimization
- ✓ Model saving
- ✓ Prediction application
- {mark(step_ok['feature_analysis'])} Feature analysis
- {mark(step_ok['model_performance'])} Performance analysis
- {mark(all(step_ok.values()))} Visualization generation
- {mark(report_generated)} Report generation
        """))
            
    except Exception as e:
        logging.error(f"Error in main execution: {str(e)}")
        display(Markdown(f"❌ **Critical Error in Execution**: {str(e)}"))
        raise

# Entry point: run the pipeline only when this module is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()

2025-04-24 10:12:59 - INFO - Logging initialized. Log file: /Users/sarawut/Library/CloudStorage/OneDrive-ImperialCollegeLondon/2024_R/R_analysis/Chapter2_Data_generation/Barcoding_Machine_Learning/Barcoding_Machine_Learning_OR_200425/logs/asv_selection_20250424_101259.log
2025-04-24 10:12:59 - INFO - Starting ASV Selection Analysis...


# ASV Selection Analysis Pipeline

2025-04-24 10:12:59 - INFO - Copying file to temporary location: /var/folders/5b/1yjp54491txfqjty7rqyyyww0000gn/T/tmpcfz3fgec/input_data.csv
2025-04-24 10:12:59 - INFO - Reading data from temporary file
2025-04-24 10:12:59 - INFO - Successfully loaded 63451 records
2025-04-24 10:12:59 - INFO - Starting data preprocessing...



## Data Preprocessing Results
- Total records: 63,451
- Number of features: 12

### Feature List:
reads, total_asv_reads, asv_count, percentage_reads, read_proportion, log_reads, read_density, is_single_asv, is_dominant_asv, reads_rank, relative_abundance, is_match

### Data Summary:
```
              reads  total_asv_reads     asv_count  percentage_reads  read_proportion     log_reads  read_density  is_single_asv  is_dominant_asv    reads_rank  relative_abundance      is_match
count  63451.000000     63451.000000  63451.000000      63451.000000     63451.000000  63451.000000  63451.000000   63451.000000     63451.000000  63451.000000        63451.000000  63451.000000
mean     117.661046      1856.864636     17.663851         37.223370         0.372234      2.068871     67.391087       0.284298         0.285244     17.562954            0.163827      0.389340
std      548.753746      9759.898990     39.965420         43.069924         0.430699      1.813360    336.504131       0.451083         0.451534     34.407694            0.329146      0.487604
min        1.000000         4.000000      1.000000          0.000000         0.000005      0.693147      0.003333       0.000000         0.000000      1.000000            0.000028      0.000000
25%        1.000000        11.000000      1.000000          0.530000         0.005319      0.693147      0.133333       0.000000         0.000000      2.000000            0.002326      0.000000
50%        3.000000        79.000000      4.000000         10.920000         0.109244      1.386294      0.666667       0.000000         0.000000      6.000000            0.009009      0.000000
75%       11.000000       911.000000     13.000000         99.860000         0.998568      2.484907      7.000000       1.000000         1.000000     17.000000            0.060241      1.000000
max    21111.000000    207505.000000    300.000000        100.000000         1.000000      9.957597  15535.000000       1.000000         1.000000    359.500000            1.000000      1.000000
```
        

2025-04-24 10:13:00 - INFO - Training Random Forest model...
Fitting 5 folds for each of 6 candidates, totalling 30 fits


Python(29336) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


2025-04-24 10:13:35 - INFO - 
Model Training Results:
2025-04-24 10:13:35 - INFO - Best parameters: {'class_weight': 'balanced', 'max_depth': None, 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 300}
2025-04-24 10:13:35 - INFO - Cross-validation score: 0.9674
2025-04-24 10:13:35 - INFO - Test accuracy: 0.9980



## Model Training Results
### Model Parameters:
```
bootstrap                       True
ccp_alpha                        0.0
class_weight                balanced
criterion                       gini
max_depth                       None
max_features                    sqrt
max_leaf_nodes                  None
max_samples                     None
min_impurity_decrease            0.0
min_samples_leaf                   1
min_samples_split                  2
min_weight_fraction_leaf         0.0
monotonic_cst                   None
n_estimators                     300
n_jobs                            -1
oob_score                      False
random_state                      42
verbose                            0
warm_start                     False
```

### Feature Importance:
```
           feature  importance
        reads_rank    0.253677
relative_abundance    0.251830
             reads    0.132977
         log_reads    0.125857
          is_match    0.102416
      read_density    0.088970
   read_proportion    0.019090
  percentage_reads    0.010759
   total_asv_reads    0.007403
   is_dominant_asv    0.003809
         asv_count    0.002796
     is_single_asv    0.000416
```

### Model Performance:
- Accuracy: 0.9980
        

2025-04-24 10:13:35 - INFO - 
Finding optimal threshold...
2025-04-24 10:13:36 - INFO - Performing coarse threshold search...
2025-04-24 10:13:36 - INFO - Performing medium-grain threshold search...
2025-04-24 10:13:37 - INFO - Performing fine-grain threshold search...
2025-04-24 10:13:38 - INFO - Optimal threshold found: 0.3300



## Threshold Optimization Results
### Optimal Threshold: 0.3300

### Performance at Optimal Threshold:
- F1 Score: 0.9926
- Precision: 0.9854
- Recall: 1.0000
- True Positives: 8,559
- False Positives: 127
- False Negatives: 0
- True Negatives: 54,765
        

2025-04-24 10:13:38 - INFO - 
Saving model components...
2025-04-24 10:13:38 - INFO - Saved model to: /Users/sarawut/Library/CloudStorage/OneDrive-ImperialCollegeLondon/2024_R/R_analysis/Chapter2_Data_generation/Barcoding_Machine_Learning/Barcoding_Machine_Learning_OR_200425/trained_model.joblib
2025-04-24 10:13:38 - INFO - Saved scaler to: /Users/sarawut/Library/CloudStorage/OneDrive-ImperialCollegeLondon/2024_R/R_analysis/Chapter2_Data_generation/Barcoding_Machine_Learning/Barcoding_Machine_Learning_OR_200425/scaler.joblib
2025-04-24 10:13:38 - INFO - Saved configuration to: /Users/sarawut/Library/CloudStorage/OneDrive-ImperialCollegeLondon/2024_R/R_analysis/Chapter2_Data_generation/Barcoding_Machine_Learning/Barcoding_Machine_Learning_OR_200425/model_config.json
2025-04-24 10:13:38 - INFO - 
Applying model predictions with threshold 0.3300
2025-04-24 10:14:02 - INFO - 
Prediction Statistics:
2025-04-24 10:14:02 - INFO - Total samples processed: 10395
2025-04-24 10:14:02 - INFO - Tot


## ASV Selection Summary

### Overview Statistics
- Total ASVs analyzed: 63,451
- ASVs selected: 8,559
- Selection rate: 13.49%

### Comparison of All vs Selected ASVs


Unnamed: 0_level_0,All ASVs,All ASVs,All ASVs,All ASVs,All ASVs,All ASVs,All ASVs,Selected ASVs,Selected ASVs,Selected ASVs,Selected ASVs,Selected ASVs,Selected ASVs,Selected ASVs
Unnamed: 0_level_1,reads,total_asv_reads,asv_count,percentage_reads,read_proportion,log_reads,read_density,reads,total_asv_reads,asv_count,percentage_reads,read_proportion,log_reads,read_density
count,63451.0,63451.0,63451.0,63451.0,63451.0,63451.0,63451.0,8559.0,8559.0,8559.0,8559.0,8559.0,8559.0,8559.0
mean,117.661046,1856.864636,17.663851,37.22337,0.372234,2.068871,67.391087,728.896951,1352.797406,3.243837,81.942977,0.81943,5.487176,425.482371
std,548.753746,9759.89899,39.96542,43.069924,0.430699,1.81336,336.504131,1201.656707,4479.431041,9.201264,31.492992,0.31493,1.684948,773.740952
min,1.0,4.0,1.0,0.0,5e-06,0.693147,0.003333,4.0,4.0,1.0,0.01,8.2e-05,1.609438,0.087591
25%,1.0,11.0,1.0,0.53,0.005319,0.693147,0.133333,71.0,105.0,1.0,79.04,0.790373,4.276666,40.0
50%,3.0,79.0,4.0,10.92,0.109244,1.386294,0.666667,290.0,448.0,1.0,99.94,0.999364,5.673323,156.0
75%,11.0,911.0,13.0,99.86,0.998568,2.484907,7.0,892.0,1346.5,3.0,100.0,1.0,6.794586,490.2
max,21111.0,207505.0,300.0,100.0,1.0,9.957597,15535.0,18049.0,207505.0,300.0,100.0,1.0,9.800901,15535.0



### Match Statistics for Selected ASVs

model_decision,select
match,Unnamed: 1_level_1
match,8559


2025-04-24 10:14:03 - INFO - ASV Selection Summary - Total: 63451, Selected: 8559, Rate: 13.49%



## Model Prediction Results
### Summary Statistics:
- Total samples processed: 10,395
- Total ASVs analyzed: 63,451
- ASVs selected: 8,559
- ASVs not selected: 54,892

### Agreement Analysis:
- Total agreements: 63,451
- Total disagreements: 0
- Agreement rate: 1.0000
        

## Generating Plots and Analysis

2025-04-24 10:14:03 - INFO - 
Analyzing features...
2025-04-24 10:14:03 - INFO - Using threshold 0.33 for feature analysis
2025-04-24 10:14:03 - INFO - Creating feature plots with threshold: 0.33
2025-04-24 10:14:03 - INFO - Using threshold 0.33 in feature plots
2025-04-24 10:14:04 - INFO - Feature analysis complete!


✓ Feature analysis completed

✓ Model performance plots generated

✓ Correlation analysis completed

✓ Threshold analysis plots generated

✓ Agreement analysis plots generated

✓ ASV-MMG analysis completed

## Saving Results and Generating Report

2025-04-24 10:14:19 - INFO - Saved results to CSV: /Users/sarawut/Library/CloudStorage/OneDrive-ImperialCollegeLondon/2024_R/R_analysis/Chapter2_Data_generation/Barcoding_Machine_Learning/Barcoding_Machine_Learning_OR_200425/analysis_results.csv
2025-04-24 10:14:23 - INFO - Saved HTML plot: /Users/sarawut/Library/CloudStorage/OneDrive-ImperialCollegeLondon/2024_R/R_analysis/Chapter2_Data_generation/Barcoding_Machine_Learning/Barcoding_Machine_Learning_OR_200425/visualizations/pca_analysis/pca_analysis.html
2025-04-24 10:14:23 - INFO - Saved JSON plot: /Users/sarawut/Library/CloudStorage/OneDrive-ImperialCollegeLondon/2024_R/R_analysis/Chapter2_Data_generation/Barcoding_Machine_Learning/Barcoding_Machine_Learning_OR_200425/visualizations/pca_analysis/pca_analysis.json
2025-04-24 10:14:23 - INFO - Saved HTML plot: /Users/sarawut/Library/CloudStorage/OneDrive-ImperialCollegeLondon/2024_R/R_analysis/Chapter2_Data_generation/Barcoding_Machine_Learning/Barcoding_Machine_Learning_OR_200425/vi

✅ Results saved to: /Users/sarawut/Library/CloudStorage/OneDrive-ImperialCollegeLondon/2024_R/R_analysis/Chapter2_Data_generation/Barcoding_Machine_Learning/Barcoding_Machine_Learning_OR_200425

2025-04-24 10:14:28 - INFO - Generated enhanced HTML report: /Users/sarawut/Library/CloudStorage/OneDrive-ImperialCollegeLondon/2024_R/R_analysis/Chapter2_Data_generation/Barcoding_Machine_Learning/Barcoding_Machine_Learning_OR_200425/reports/analysis_report.html


✅ Report generated: /Users/sarawut/Library/CloudStorage/OneDrive-ImperialCollegeLondon/2024_R/R_analysis/Chapter2_Data_generation/Barcoding_Machine_Learning/Barcoding_Machine_Learning_OR_200425/reports/analysis_report.html


## Final Analysis Summary
### Model Performance:
- Final accuracy: 0.9980
- Optimal threshold: 0.3300
- Overall agreement rate: 1.0000

### Dataset Statistics:
- Total ASVs processed: 63,451
- ASVs selected: 8,559
- Samples analyzed: 10,395

### File Locations:
- Results saved to: /Users/sarawut/Library/CloudStorage/OneDrive-ImperialCollegeLondon/2024_R/R_analysis/Chapter2_Data_generation/Barcoding_Machine_Learning/Barcoding_Machine_Learning_OR_200425
- Model saved to: /Users/sarawut/Library/CloudStorage/OneDrive-ImperialCollegeLondon/2024_R/R_analysis/Chapter2_Data_generation/Barcoding_Machine_Learning/Barcoding_Machine_Learning_OR_200425
- Detailed report at: /Users/sarawut/Library/CloudStorage/OneDrive-ImperialCollegeLondon/2024_R/R_analysis/Chapter2_Data_generation/Barcoding_Machine_Learning/Barcoding_Machine_Learning_OR_200425/reports/analysis_report.html

### Analysis Components Completed:
- ✓ Data preprocessing
- ✓ Model training
- ✓ Threshold optimization
- ✓ Model saving
- ✓ Prediction application
- ✓ Feature analysis
- ✓ Performance analysis
- ✓ Visualization generation
- ✓ Report generation
        

In [6]:
#!/usr/bin/env python
"""
ASV Selection Prediction Pipeline
Author: Sarawut Ounjai
Updated: 20‑Apr‑2025
"""

# ──────────────────────────────────────────────────────
# 1. Imports
# ──────────────────────────────────────────────────────
import os, sys, json, shutil, tempfile, warnings, logging, traceback
from pathlib import Path
from datetime import datetime
from typing import List, Tuple, Any, Dict

import numpy as np
import pandas as pd

from joblib import load
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_curve, auc            # (reserved)
import plotly.express as px
import plotly.graph_objects as go

warnings.filterwarnings("ignore")

# ──────────────────────────────────────────────────────
# 2. Paths & Logging
# ──────────────────────────────────────────────────────
# Root of the analysis tree; every other location below is derived from it.
BASE_DIR      = Path("/Users/sarawut/Library/CloudStorage/OneDrive-ImperialCollegeLondon/2024_R/R_analysis")
CHAPTER_DIR   = BASE_DIR / "Chapter2_Data_generation/Barcoding_Machine_Learning/old"
INPUT_FILE    = CHAPTER_DIR / "Barcoding_Machine_Learning_Thailand.csv"
# Folder holding trained_model.joblib / scaler.joblib / model_config.json.
# It lives next to (not inside) CHAPTER_DIR, so it is built from BASE_DIR.
# The previous form `CHAPTER_DIR / "/Users/..."` only worked because pathlib
# silently discards the left operand when the right one is absolute.
MODEL_DIR     = BASE_DIR / "Chapter2_Data_generation/Barcoding_Machine_Learning/Barcoding_Machine_Learning_OR_200425"
OUTPUT_DIR    = CHAPTER_DIR / "Barcoding_Machine_Learning_Thailand_result_200425"
# Optional copy of the training table, used only for the distribution-shift check.
TRAINING_DATA_PATH = MODEL_DIR / "training_data.csv"                 # ← NEW

# folders (created eagerly so logging and result writers can assume they exist)
(OUTPUT_DIR / "results").mkdir(parents=True, exist_ok=True)
(OUTPUT_DIR / "logs").mkdir(parents=True, exist_ok=True)

def setup_logging() -> None:
    """Configure the root logger to write to a timestamped file and to stdout.

    A new file ``asv_prediction_<YYYYmmdd_HHMMSS>.log`` is created under
    ``OUTPUT_DIR / "logs"``.  The file handler records DEBUG and above, the
    stdout handler INFO and above.  Handlers already attached to the root
    logger are removed *and closed* first, so re-running this (for example by
    re-executing a notebook cell) neither duplicates output nor leaks the
    previous log file's descriptor.
    """
    fmt = "%(asctime)s - %(levelname)s - %(message)s"
    datefmt = "%Y-%m-%d %H:%M:%S"
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file  = OUTPUT_DIR / "logs" / f"asv_prediction_{timestamp}.log"
    # Do not rely on the folder having been created elsewhere.
    log_file.parent.mkdir(parents=True, exist_ok=True)

    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    for h in logger.handlers[:]:  # iterate over a copy: the list is mutated
        logger.removeHandler(h)
        h.close()                 # releases the file held by a FileHandler

    # Plain ASCII "utf-8": the earlier literal contained a non-breaking hyphen
    # and only worked because codec names are normalised on lookup.
    fh = logging.FileHandler(log_file, encoding="utf-8")
    fh.setFormatter(logging.Formatter(fmt, datefmt))
    fh.setLevel(logging.DEBUG)

    ch = logging.StreamHandler(sys.stdout)
    ch.setFormatter(logging.Formatter(fmt, datefmt))
    ch.setLevel(logging.INFO)

    logger.addHandler(fh)
    logger.addHandler(ch)
    logging.info(f"Logging to {log_file}")

setup_logging()

# ──────────────────────────────────────────────────────
# 3. I/O helpers
# ──────────────────────────────────────────────────────
def load_data() -> pd.DataFrame:
    """Read ``INPUT_FILE`` into a DataFrame via a temporary local copy.

    The CSV is first copied into a fresh temporary directory and parsed from
    there (presumably to avoid reading straight from the cloud-synced source
    folder -- TODO confirm).  The temporary directory is always removed.

    Returns:
        The parsed table.

    Raises:
        ValueError: if the CSV contains no rows.
        OSError: if the input file cannot be copied or read.
    """
    # Create the directory *before* entering the try block: if mkdtemp() itself
    # fails there is nothing to clean up, and the finally clause must not hit
    # an unbound name that would mask the real error.
    tmp_dir = Path(tempfile.mkdtemp())
    try:
        tmp_file = tmp_dir / "input.csv"
        shutil.copy2(INPUT_FILE, tmp_file)
        df = pd.read_csv(tmp_file)
        if df.empty:
            raise ValueError("Input CSV is empty")
        logging.info(f"Loaded {len(df):,} rows")
        return df
    finally:
        # Best-effort cleanup; a failure here must not hide the result/error.
        shutil.rmtree(tmp_dir, ignore_errors=True)

def load_model_components() -> Tuple[Any, StandardScaler, List[str], float]:
    """Load the trained model, its scaler and the saved configuration.

    ``MODEL_DIR`` is used when it is an existing directory.  Otherwise the
    most recently modified *directory* under ``CHAPTER_DIR`` whose name starts
    with ``Barcoding_Machine_Learning_`` and which contains all three required
    files is used instead.  Plain files and folders without model artefacts
    (such as the input CSV or the output folder, which match the same name
    pattern) are ignored.

    Returns:
        ``(model, scaler, feature_columns, optimal_threshold)``.

    Raises:
        FileNotFoundError: no usable model directory, or a required file is
            missing from ``MODEL_DIR``.
        KeyError: the configuration lacks ``feature_columns`` or
            ``optimal_threshold``.
    """
    required = ("trained_model.joblib", "scaler.joblib", "model_config.json")

    model_dir = MODEL_DIR
    if not model_dir.is_dir():
        candidates = sorted(
            (p for p in CHAPTER_DIR.glob("Barcoding_Machine_Learning_*")
             if p.is_dir() and all((p / name).exists() for name in required)),
            key=lambda p: p.stat().st_mtime,
            reverse=True,
        )
        if not candidates:
            raise FileNotFoundError("No model directory found.")
        model_dir = candidates[0]
        logging.warning(f"MODEL_DIR not found – using {model_dir}")

    model_path, scaler_path, cfg_path = (model_dir / name for name in required)
    for p in (model_path, scaler_path, cfg_path):
        if not p.exists():
            raise FileNotFoundError(f"Missing required file: {p}")

    # NOTE: joblib.load unpickles arbitrary objects -- only point this at
    # model files produced by a trusted training run.
    model  = load(model_path)
    scaler = load(scaler_path)
    cfg    = json.loads(cfg_path.read_text(encoding="utf-8"))

    missing = [k for k in ("feature_columns", "optimal_threshold") if k not in cfg]
    if missing:
        raise KeyError(f"{cfg_path} is missing required key(s): {', '.join(missing)}")

    logging.info("Model components loaded")
    return model, scaler, cfg["feature_columns"], cfg["optimal_threshold"]


# ──────────────────────────────────────────────────────
# 4. Pre‑processing utilities
# ──────────────────────────────────────────────────────
def preprocess_data(df: pd.DataFrame, feat_cols: List[str]) -> pd.DataFrame:
    """Return a copy of *df* with a normalised ``match`` column and the
    engineered feature columns added.

    The input frame is not modified.  Columns read: ``match``, ``reads``,
    ``total_asv_reads``, ``asv_count`` and ``project_readfile_id``.  Missing
    values in the columns listed in *feat_cols* are replaced by 0 at the end.
    """
    out = df.copy()

    # Collapse the recognised spellings of the match flag to 'match' or
    # 'no_match'; anything else (including missing values) becomes 'no_match'.
    positive = ('match', 'Match', 'TRUE', True)
    negative = ('no_match', 'NoMatch', 'FALSE', False)
    lookup = {**dict.fromkeys(positive, 'match'),
              **dict.fromkeys(negative, 'no_match')}
    out['match'] = out['match'].map(lookup).fillna('no_match')

    reads = out['reads']
    # A zero total yields NaN (not a division error); a zero ASV count is
    # treated as one.
    nonzero_total = out['total_asv_reads'].replace(0, np.nan)
    safe_asv_count = out['asv_count'].replace(0, 1)

    out['read_proportion'] = reads / nonzero_total
    out['log_reads'] = np.log1p(reads)
    out['read_density'] = reads / safe_asv_count
    out['is_single_asv'] = out['asv_count'].eq(1).astype(int)
    out['is_dominant_asv'] = out['read_proportion'].gt(0.9).astype(int)
    out['is_match'] = out['match'].eq('match').astype(int)

    # Per-sample statistics: rank 1 is the most abundant ASV of a sample.
    reads_by_sample = out.groupby('project_readfile_id')['reads']
    out['reads_rank'] = reads_by_sample.rank(method='min', ascending=False)
    out['relative_abundance'] = reads_by_sample.transform(
        lambda grp: grp / grp.sum()
    )

    out[feat_cols] = out[feat_cols].fillna(0)
    logging.info("Pre‑processing complete")
    return out

# ──────────────────────────────────────────────────────
# 5. Distribution & threshold helpers
# ──────────────────────────────────────────────────────
def analyze_dataset_differences(train: pd.DataFrame,
                                new: pd.DataFrame,
                                feat_cols: List[str],
                                shift_threshold_pct: float = 20.0) -> Dict[str, Dict]:
    """Compare per-feature means between the training data and new data.

    Args:
        train: Training table.
        new: Table about to be scored.
        feat_cols: Feature names to compare.  Features absent from either
            table, or not numeric in both, are skipped.
        shift_threshold_pct: Relative difference of the means (in percent of
            the training mean) above which a feature is flagged.

    Returns:
        ``{feature: {'train_mean', 'new_mean', 'percent_diff',
        'significant_shift'}}`` with plain Python ``float``/``bool`` values.
    """
    is_numeric = pd.api.types.is_numeric_dtype
    diff: Dict[str, Dict] = {}
    for f in feat_cols:
        if f not in train.columns or f not in new.columns:
            continue
        # A mean is meaningless for text columns; skip instead of crashing.
        if not (is_numeric(train[f]) and is_numeric(new[f])):
            continue
        train_mean = float(train[f].mean())
        new_mean = float(new[f].mean())
        # The floor on the denominator avoids dividing by a zero training mean.
        pct = abs(train_mean - new_mean) / max(abs(train_mean), 1e-9) * 100
        diff[f] = {'train_mean': train_mean, 'new_mean': new_mean,
                   'percent_diff': pct,
                   'significant_shift': bool(pct > shift_threshold_pct)}
    return diff

def adjust_threshold(probas: np.ndarray, base: float) -> float:
    """Nudge the decision threshold according to the mean predicted probability.

    * mean below 0.3 -> threshold lowered by 10 %
    * mean above 0.7 -> threshold raised by 10 %, capped at 1.0
    * otherwise      -> threshold unchanged

    The cap matters because probabilities never exceed 1.0, so an uncapped
    threshold above 1.0 could never be met.  Empty input returns *base*
    unchanged (and avoids the NaN-mean warning).
    """
    probas = np.asarray(probas, dtype=float)
    if probas.size == 0:
        return base

    mean = probas.mean()
    if mean < 0.3:
        return base * 0.9
    if mean > 0.7:
        return min(base * 1.1, 1.0)
    return base

# ──────────────────────────────────────────────────────
# 6. Validation helper
# ──────────────────────────────────────────────────────
def validate_predictions(df: pd.DataFrame) -> List[str]:
    """Return human-readable sanity-check failures for a scored table.

    Reads ``model_decision`` and ``project_readfile_id``.  An empty list means
    both checks passed.
    """
    n_selected = df['model_decision'].eq('select').sum()
    n_samples = df['project_readfile_id'].nunique()

    # (condition that signals a problem, message reported for it)
    checks = (
        (n_selected == 0, "No ASVs selected"),
        (n_selected > n_samples, "More selections than samples"),
    )
    return [message for failed, message in checks if failed]

# ──────────────────────────────────────────────────────
# 7. Prediction step
# ──────────────────────────────────────────────────────
def make_predictions(df: pd.DataFrame,
                     model: Any,
                     scaler: StandardScaler,
                     feat_cols: List[str],
                     base_thr: float,
                     *,
                     min_reads: int = 4,
                     confidence_margin: float = 0.20) -> Tuple[pd.DataFrame, float]:
    """
    Select **at most one** ASV per sample:
        • must be a taxonomy match AND have at least ``min_reads`` reads
        • probability ≥ adjusted threshold
        • highest‑probability candidate per sample only

    Args:
        df: Pre-processed table; must contain *feat_cols*, ``match``,
            ``reads`` and ``project_readfile_id``.
        model: Fitted classifier exposing ``predict_proba``.
        scaler: Fitted scaler exposing ``transform``.
        feat_cols: Feature columns, in the order the model expects.
        base_thr: Threshold before adjustment by ``adjust_threshold``.
        min_reads: Minimum read count for a row to be a candidate.
        confidence_margin: A winner whose probability exceeds the runner-up's
            by more than this is 'high' confidence, otherwise 'medium'.  A
            sole candidate is always 'high'.

    Returns:
        ``(scored_copy_of_df, adjusted_threshold)``.  The copy gains the
        columns ``prediction_probability``, ``model_prediction``,
        ``model_decision``, ``selection_confidence`` and, when an
        ``autopropose`` column exists, ``agreement``.

    Rows are addressed by position rather than by index label, so a
    non-unique index cannot cause several rows to be marked for one sample.
    Ties on probability are resolved with a stable sort, so the candidate
    appearing first in *df* wins.
    """
    # ── score all ASVs ──────────────────────────────────
    X_scaled = scaler.transform(df[feat_cols].fillna(0))
    probs    = np.asarray(model.predict_proba(X_scaled))[:, 1]
    adj_thr  = adjust_threshold(probs, base_thr)

    n_rows = len(df)

    # ── candidate table (one row per eligible ASV) ─────
    eligible = ((df['match'] == 'match') & (df['reads'] >= min_reads)).to_numpy()
    cand = pd.DataFrame({
        'pos':  np.arange(n_rows),                       # row position in df
        'sid':  df['project_readfile_id'].to_numpy(),
        'prob': probs,
    })
    # Rows without a sample id belong to no sample and are never selected.
    cand = cand[eligible & cand['sid'].notna().to_numpy()]
    cand = cand.sort_values('prob', ascending=False, kind='mergesort')

    # 0 = best candidate of its sample, 1 = runner-up, ...
    rank_in_sample = cand.groupby('sid', sort=False).cumcount()
    best      = cand[rank_in_sample == 0].set_index('sid')
    runner_up = cand[rank_in_sample == 1].set_index('sid')['prob']

    # ── confidence from margin ─────────────────────────
    best_prob     = best['prob'].to_numpy()
    has_runner_up = best.index.isin(runner_up.index)
    margin        = best_prob - runner_up.reindex(best.index).to_numpy()
    confidence    = np.where(~has_runner_up | (margin > confidence_margin),
                             'high', 'medium').astype(object)

    # ── apply threshold ────────────────────────────────
    passes  = best_prob >= adj_thr
    sel_pos = best['pos'].to_numpy()[passes]

    prediction = np.zeros(n_rows, dtype=np.int64)
    decision   = np.full(n_rows, 'unselect', dtype=object)
    conf_col   = np.full(n_rows, 'low', dtype=object)
    prediction[sel_pos] = 1
    decision[sel_pos]   = 'select'
    conf_col[sel_pos]   = confidence[passes]

    df = df.copy()
    df['prediction_probability'] = probs
    df['model_prediction']       = prediction
    df['model_decision']         = decision
    df['selection_confidence']   = conf_col

    # optional agreement flag
    if 'autopropose' in df.columns:
        df['agreement'] = np.where(
            ((df['model_decision'] == 'select')     & (df['autopropose'] == 'select')) |
            ((df['model_decision'] == 'unselect')   & (df['autopropose'] == 'unselect')),
            'agree', 'disagree'
        )

    # validation warnings (kept as a cheap invariant check)
    sel_per_sample = df[df['model_decision'] == 'select'].groupby('project_readfile_id').size()
    if (sel_per_sample > 1).any():
        logging.warning("Some samples have multiple selections (should not happen).")

    logging.info(f"Total ASVs selected: {(df['model_decision']=='select').sum()}")
    return df, adj_thr


# ──────────────────────────────────────────────────────
# 8. Analysis helpers (NEW minimal versions)
# ──────────────────────────────────────────────────────
def analyze_asv_selection(df: pd.DataFrame) -> None:
    """Log the share of rows whose ``model_decision`` is 'select'."""
    is_selected = df['model_decision'] == 'select'
    rate = is_selected.sum() / len(df)
    logging.info(f"Selection rate = {rate:.2%}")

def analyze_features(df, model, scaler, feat_cols, cfg) -> Dict:
    """Plot the model's feature importances and save the chart as HTML.

    Only *model* and *feat_cols* are used; *df*, *scaler* and *cfg* are
    accepted but not read.  Returns ``{'feature_importance': figure}``, or an
    empty dict when the model exposes no ``feature_importances_``.
    The chart is written to ``OUTPUT_DIR / "results" / "feature_importances.html"``.
    """
    if not hasattr(model, 'feature_importances_'):
        return {}

    ranking = pd.DataFrame({
        'feature': feat_cols,
        'importance': model.feature_importances_,
    })
    ranking = ranking.sort_values('importance', ascending=False)

    fig = px.bar(ranking, x='feature', y='importance',
                 title="Feature importances")
    fig = fig.update_xaxes(tickangle=-45)
    fig.write_html(OUTPUT_DIR / "results" / "feature_importances.html")
    logging.info("Feature importance plot saved")
    return {'feature_importance': fig}

def analyze_correlation(df:pd.DataFrame, feat_cols:List[str]):
    """Return the Spearman correlation matrix of *feat_cols*.

    Best-effort: any failure is logged as a warning and ``None`` is returned
    instead of raising.
    """
    try:
        features = df[feat_cols]
        matrix = features.corr(method='spearman')
    except Exception as exc:
        logging.warning(f"Correlation analysis skipped: {exc}")
        return None
    return matrix

def create_correlation_plots(corr:pd.DataFrame) -> Dict:
    """Render *corr* as a heatmap and save it as HTML.

    Returns ``{'corr_heatmap': figure}``, or an empty dict when *corr* is
    ``None``.  The colour scale is fixed to the range [-1, 1].
    """
    if corr is None:
        return {}

    heatmap = px.imshow(
        corr,
        title="Spearman correlation",
        aspect="auto",
        color_continuous_scale="RdBu_r",
        zmin=-1,
        zmax=1,
    )
    target = OUTPUT_DIR / "results" / "correlation_heatmap.html"
    heatmap.write_html(target)
    logging.info("Correlation heatmap saved")
    return {'corr_heatmap': heatmap}

def create_agreement_analysis_plots(df) -> Dict:
    """
    Build a pie chart of the ``agreement`` column (model vs autopropose) and
    save it as HTML.

    Returns ``{'agreement_pie': figure}``, or an empty dict when the column is
    absent.  The counts table is built with explicit column names so it works
    on both pandas 1.x and 2.x.
    """
    if 'agreement' not in df.columns:
        logging.info("Agreement column not present – skipping plot.")
        return {}

    # Missing values are counted as their own slice.
    counts = df['agreement'].value_counts(dropna=False)
    counts = counts.rename_axis('agreement')
    table = counts.reset_index(name='count')

    pie = px.pie(
        table,
        names='agreement',
        values='count',
        title="Model vs Autopropose – Agreement",
    )
    pie.write_html(OUTPUT_DIR / "results" / "agreement_pie.html")
    logging.info("Agreement pie chart saved")
    return {'agreement_pie': pie}


def create_taxonomic_analysis_plots(df) -> Dict:
    """Bar chart of the 20 most frequent ``taxonomy`` values among selected rows.

    Returns ``{'taxa': figure}``, or an empty dict when there is no
    ``taxonomy`` column.  The chart is saved as ``taxa_bar.html`` in the
    results folder.
    """
    if 'taxonomy' not in df.columns:
        return {}

    selected_rows = df[df['model_decision'] == 'select']
    top_taxa = selected_rows['taxonomy'].value_counts().head(20)

    chart = px.bar(top_taxa, x=top_taxa.index, y=top_taxa.values,
                   title="Top selected taxa")
    chart = chart.update_xaxes(tickangle=-45)
    chart.write_html(OUTPUT_DIR / "results" / "taxa_bar.html")
    return {'taxa': chart}

def create_asv_mmg_analysis_plots(df) -> Dict:
    """Bar chart of the number of distinct ``mt_id`` values per model decision.

    Returns ``{'mmg': figure}``, or an empty dict when there is no ``mt_id``
    column.  The chart is saved as ``mmg_bar.html`` in the results folder.
    """
    if 'mt_id' not in df.columns:
        return {}

    unique_ids = df.groupby('model_decision')['mt_id'].nunique()
    table = unique_ids.reset_index()

    chart = px.bar(table, x='model_decision', y='mt_id',
                   title="Unique MMG IDs by decision")
    chart.write_html(OUTPUT_DIR / "results" / "mmg_bar.html")
    return {'mmg': chart}

# ──────────────────────────────────────────────────────
# 8 b.  Robust summary tables / quick EDA
# ──────────────────────────────────────────────────────
def summarise_data_analysis(df: pd.DataFrame, out_dir: Path,
                            top_n_taxa: int = 10) -> None:
    """
    Write quick-look summary tables for a scored table into *out_dir*.

    Files produced (the folder is created if needed):
        overall_summary.csv      total rows, selected rows, selection rate
        per_sample_selected.csv  number of selected rows per sample
        top_taxa_selected.csv    / .html  (only with a ``taxonomy`` column)
        agreement_stats.csv      (only with an ``agreement`` column)

    The agreement rate is also logged, and the tables are shown via IPython
    when it is available; display problems are ignored.
    """
    def _show_in_notebook(*sections) -> None:
        # Best-effort rich output: each section is (markdown heading, table).
        try:
            from IPython.display import display, Markdown
            for heading, table in sections:
                display(Markdown(heading))
                display(table)
        except Exception:
            pass

    out_dir.mkdir(parents=True, exist_ok=True)

    is_selected = df['model_decision'] == 'select'
    chosen = df[is_selected]

    # ── overall counts ────────────────────────────────
    n_total = len(df)
    n_selected = is_selected.sum()
    rate_pct = n_selected / n_total * 100
    overall_tbl = pd.DataFrame({
        'Metric': ['Total ASVs', 'Selected ASVs', 'Selection rate %'],
        'Value' : [n_total, n_selected, f"{rate_pct:.2f}"],
    })
    overall_tbl.to_csv(out_dir / "overall_summary.csv", index=False)

    # ── per-sample selection counts (largest first) ───
    counts = chosen.groupby('project_readfile_id').size()
    per_sample = counts.rename('selected_ASVs').reset_index()
    per_sample = per_sample.sort_values('selected_ASVs', ascending=False)
    per_sample.to_csv(out_dir / "per_sample_selected.csv", index=False)

    # ── top-N taxonomy table and bar chart ────────────
    if 'taxonomy' in df.columns:
        taxa_counts = chosen['taxonomy'].value_counts().head(top_n_taxa)
        tax_tbl = taxa_counts.rename_axis('taxonomy').reset_index(name='selected_count')
        tax_tbl.to_csv(out_dir / "top_taxa_selected.csv", index=False)

        fig = px.bar(tax_tbl, x='taxonomy', y='selected_count',
                     title=f"Top {top_n_taxa} selected taxa")
        fig.update_xaxes(tickangle=-45)
        fig.write_html(out_dir / "top_taxa_selected.html")

    # ── agreement stats ───────────────────────────────
    if 'agreement' in df.columns:
        agreement = df['agreement']
        agree_cnt = (agreement == 'agree').sum()
        disagree_cnt = (agreement == 'disagree').sum()
        agr_rate = agree_cnt / (agree_cnt + disagree_cnt) * 100
        agree_tbl = pd.DataFrame({
            'Metric': ['Agreements', 'Disagreements', 'Agreement rate %'],
            'Value' : [agree_cnt, disagree_cnt, f"{agr_rate:.2f}"],
        })
        agree_tbl.to_csv(out_dir / "agreement_stats.csv", index=False)
        logging.info(f"Agreement rate = {agr_rate:.2f}%  "
                     f"({agree_cnt} agree / {disagree_cnt} disagree)")
        _show_in_notebook(("### Agreement statistics", agree_tbl))

    # ── overall table and per-sample head ─────────────
    _show_in_notebook(
        ("## **Pipeline Summary**", overall_tbl),
        ("### Per‑sample selections (top 10)", per_sample.head(10)),
    )

    logging.info("Summary tables generated & saved")


# ──────────────────────────────────────────────────────
# 9. MAIN
# ──────────────────────────────────────────────────────
def main():
    """Run the ASV prediction pipeline from loading to saved outputs.

    The steps run in this order:

    1. ``load_data`` and ``load_model_components`` provide the input frame,
       the model, the scaler, the feature column list and the base threshold.
    2. If ``TRAINING_DATA_PATH`` exists, the training CSV is read and passed
       to ``analyze_dataset_differences``; a warning is logged for every
       feature whose entry has a truthy ``'significant_shift'``.
    3. ``preprocess_data``, ``make_predictions`` and ``analyze_asv_selection``
       are applied to the input frame.
    4. ``OUTPUT_DIR / "results"`` is created, the plotting helpers are
       called, and ``summarise_data_analysis`` writes its tables there.
    5. ``predictions.csv``, ``feature_stats.csv`` and
       ``prediction_summary.json`` are written to the same directory.

    Returns:
        None.

    Raises:
        Exception: any error raised by a step is logged together with its
            traceback and then re-raised unchanged.
    """
    try:
        # load
        df = load_data()
        model, scaler, feat_cols, base_thr = load_model_components()

        # optional distribution shift check (skipped when no training file)
        if TRAINING_DATA_PATH.exists():
            train_df = pd.read_csv(TRAINING_DATA_PATH)
            shifts = analyze_dataset_differences(train_df, df, feat_cols)
            for feature, shift_info in shifts.items():
                if shift_info['significant_shift']:
                    logging.warning(
                        f"Shift in {feature}: {shift_info['percent_diff']:.1f}%"
                    )

        # preprocess + predict
        df = preprocess_data(df, feat_cols)
        df, adj_thr = make_predictions(df, model, scaler, feat_cols, base_thr)
        analyze_asv_selection(df)

        # Output directory must exist BEFORE summaries / plots are produced.
        # It is defined once here and reused for every file written below.
        res_dir = OUTPUT_DIR / "results"
        res_dir.mkdir(parents=True, exist_ok=True)

        # ── plots
        # NOTE(review): the merged dict is not read again in this function;
        # the calls are kept because the helpers presumably save their
        # figures as a side effect — confirm against their definitions.
        plots = {}
        plots.update(analyze_features(df, model, scaler, feat_cols, {'thr': adj_thr}))
        plots.update(create_correlation_plots(analyze_correlation(df, feat_cols)))
        plots.update(create_agreement_analysis_plots(df))
        plots.update(create_taxonomic_analysis_plots(df))
        plots.update(create_asv_mmg_analysis_plots(df))

        # create summary tables
        summarise_data_analysis(df, res_dir)

        # ── save outputs
        df.to_csv(res_dir / "predictions.csv", index=False)
        df[feat_cols].describe().to_csv(res_dir / "feature_stats.csv")

        # Values are cast to built-in int/float so json.dumps can serialise
        # them even when pandas/NumPy scalar types are returned.
        summary = {
            'timestamp': datetime.now().isoformat(),
            'samples': int(df['project_readfile_id'].nunique()),
            'asvs': len(df),
            'selected': int((df['model_decision'] == 'select').sum()),
            'base_threshold': float(base_thr),
            'adjusted_threshold': float(adj_thr)
        }
        (res_dir / "prediction_summary.json").write_text(json.dumps(summary, indent=4))
        logging.info("✓ All outputs saved")
        logging.info("PIPELINE FINISHED SUCCESSFULLY")
    except Exception as e:
        # logging.exception records the message at ERROR level and appends
        # the active traceback, then the error is propagated to the caller.
        logging.exception(f"Fatal error: {e}")
        raise

# Entry point: run the pipeline only when this code is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


2025-04-24 10:20:20 - INFO - Logging to /Users/sarawut/Library/CloudStorage/OneDrive-ImperialCollegeLondon/2024_R/R_analysis/Chapter2_Data_generation/Barcoding_Machine_Learning/old/Barcoding_Machine_Learning_Thailand_result_200425/logs/asv_prediction_20250424_102020.log
2025-04-24 10:20:20 - INFO - Loaded 10,150 rows
2025-04-24 10:20:20 - INFO - Model components loaded
2025-04-24 10:20:20 - INFO - Pre‑processing complete
2025-04-24 10:20:21 - INFO - Total ASVs selected: 1016
2025-04-24 10:20:21 - INFO - Selection rate = 10.01%
2025-04-24 10:20:21 - INFO - Feature importance plot saved
2025-04-24 10:20:21 - INFO - Correlation heatmap saved
2025-04-24 10:20:21 - INFO - Agreement pie chart saved
2025-04-24 10:20:21 - INFO - Agreement rate = 100.00%  (10150 agree / 0 disagree)


### Agreement statistics

Unnamed: 0,Metric,Value
0,Agreements,10150.0
1,Disagreements,0.0
2,Agreement rate %,100.0


## **Pipeline Summary**

Unnamed: 0,Metric,Value
0,Total ASVs,10150.0
1,Selected ASVs,1016.0
2,Selection rate %,10.01


### Per‑sample selections (top 10)

Unnamed: 0,project_readfile_id,selected_ASVs
0,BIBC_THA101_A01,1
682,BIBC_THA114_E08_2nd,1
669,BIBC_THA114_D12_2nd,1
670,BIBC_THA114_E02,1
671,BIBC_THA114_E02_2nd,1
672,BIBC_THA114_E03,1
673,BIBC_THA114_E03_2nd,1
674,BIBC_THA114_E04,1
675,BIBC_THA114_E04_2nd,1
676,BIBC_THA114_E05,1


2025-04-24 10:20:21 - INFO - Summary tables generated & saved
2025-04-24 10:20:21 - INFO - ✓ All outputs saved
2025-04-24 10:20:21 - INFO - PIPELINE FINISHED SUCCESSFULLY
