# In [1]:  (Jupyter cell marker left over from the notebook export)
#!/usr/bin/env python3
"""
Complete ASV Classification Pipeline
Version: 3.0 - Final with Intra-Species Variants & No Mixed

Key Features:
- Pre-filter Technical_Artifacts (reads < 4)
- Intra-Species Variants based on co-occurrence pattern
- Phase 1: Strong evidence only (no Uncertain in training)
- ML trained on strong evidence, predicts all including Uncertain
- Sequence Summary without Mixed (priority-based)
- All Uncertain classified by ML

Author: Scientific Pipeline
Date: 2025
"""

import pandas as pd
import numpy as np
from collections import Counter, defaultdict
import warnings
import sys
from pathlib import Path
warnings.filterwarnings('ignore')

# Progress bar
try:
    from tqdm import tqdm
    TQDM_AVAILABLE = True
except ImportError:
    TQDM_AVAILABLE = False
    print("⚠️  tqdm not available. Install with: pip install tqdm")

# Machine Learning
try:
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.preprocessing import StandardScaler, LabelEncoder
    from sklearn.model_selection import train_test_split, cross_val_score
    from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
    from sklearn.utils import class_weight
    import joblib
    SKLEARN_AVAILABLE = True
except ImportError:
    SKLEARN_AVAILABLE = False
    print("❌ ERROR: scikit-learn required")
    print("Install with: pip install scikit-learn")
    sys.exit(1)

# SMOTE for imbalanced data
try:
    from imblearn.over_sampling import SMOTE
    SMOTE_AVAILABLE = True
except ImportError:
    SMOTE_AVAILABLE = False
    print("⚠️  imbalanced-learn not available. Install with: pip install imbalanced-learn")

# ========================
# CONFIGURATION
# ========================

# NOTE: absolute, machine-specific locations; adjust before running elsewhere.
INPUT_FILE = "/Users/sarawut/Desktop/Manuscript_ASV_selection/data_analysis/sequences_analysis/ASV_Complete_Analysis.csv"
# Single separator between path components (the previous value contained a
# stray "//" that was carried into every derived output path).
OUTPUT_DIR = "/Users/sarawut/Desktop/Manuscript_ASV_selection/data_analysis/classification_analysis"

# Output files (all written inside OUTPUT_DIR)
OUTPUT_CLASSIFIED = f"{OUTPUT_DIR}/ASV_Final_Classification.csv"
OUTPUT_SEQUENCE_SUMMARY = f"{OUTPUT_DIR}/ASV_Sequence_Summary.csv"
OUTPUT_STATISTICS = f"{OUTPUT_DIR}/Classification_Statistics.csv"
OUTPUT_FEATURE_IMPORTANCE = f"{OUTPUT_DIR}/Feature_Importance.csv"
OUTPUT_REPORT = f"{OUTPUT_DIR}/Classification_Report.txt"
OUTPUT_MODEL = f"{OUTPUT_DIR}/Classification_Model.pkl"

# Create output directory (including missing parents) at import time;
# an already existing directory is not an error.
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

# ========================
# THRESHOLDS
# ========================

# Central table of cut-offs used by the rule-based classifier and the
# feature extractor. Keys are looked up by name, so they must not be renamed.
THRESHOLDS = dict(
    # --- Abundance: absolute read counts ---
    technical_artifacts_reads=4,
    very_low_reads=10,
    low_reads=100,
    medium_reads_min=100,
    medium_reads_max=1000,
    high_reads=1000,

    # --- Abundance: fraction of the sample's reads ---
    very_low_abundance=0.01,
    low_abundance=0.05,
    medium_abundance_min=0.05,
    medium_abundance_max=0.50,
    high_abundance=0.50,

    # --- Cross-contamination (abundance ceilings per confidence level) ---
    cross_cont_clear=0.05,
    cross_cont_likely=0.10,
    cross_cont_possible=0.20,

    # --- Environmental contamination ---
    env_order_mismatch=0.20,
    env_class_mismatch=0.30,
    env_phylo_distance=0.20,
    env_phylo_threshold=0.15,

    # --- Sequence quality score ---
    quality_low=60,
    quality_medium=70,
    quality_high=80,

    # --- Composition (for NUMTs) ---
    gc_normal_min=32,
    gc_normal_max=40,
    motif_score_degraded=70,

    # --- Phylogenetic distance ---
    phylo_same_species=0.02,
    phylo_same_genus=0.05,
    phylo_divergent=0.10,
    phylo_intra_max=0.15,
    phylo_env_min=0.05,

    # --- Distribution: samples needed to call an ASV widespread ---
    widespread_threshold=5,

    # --- Sequence length bounds (bp) ---
    length_min=200,
    length_max=900,

    # --- Intra-Species Variants: minimum co-occurrence count ---
    intra_species_min_cooccurrence=2,
)

# ========================
# HELPER FUNCTIONS
# ========================

def safe_float(val, default=0.0):
    """Convert ``val`` to ``float``, falling back to ``default``.

    Args:
        val: Value to convert (typically one cell of a pandas row).
        default: Returned unchanged when ``val`` is missing according to
            ``pd.isna`` or cannot be converted.

    Returns:
        ``float(val)`` on success, otherwise ``default``.
    """
    try:
        if pd.isna(val):
            return default
        return float(val)
    except (TypeError, ValueError, OverflowError):
        # TypeError / ValueError: non-numeric input, or an array-like whose
        # pd.isna() result has no single truth value.
        # OverflowError: an integer too large to be represented as a float.
        # Anything else (e.g. KeyboardInterrupt) is deliberately not swallowed.
        return default

def safe_int(val, default=0):
    """Convert ``val`` to ``int``, falling back to ``default``.

    Args:
        val: Value to convert (typically one cell of a pandas row).
        default: Returned unchanged when ``val`` is missing according to
            ``pd.isna`` or cannot be converted.

    Returns:
        ``int(val)`` on success (floats are truncated by ``int``), otherwise
        ``default``.
    """
    try:
        if pd.isna(val):
            return default
        return int(val)
    except (TypeError, ValueError, OverflowError):
        # ValueError also covers strings such as "3.7"; OverflowError covers
        # infinities. Other exceptions (e.g. KeyboardInterrupt) propagate.
        return default

def safe_str(val, default=''):
    """Convert ``val`` to ``str``, falling back to ``default``.

    Args:
        val: Value to convert (typically one cell of a pandas row).
        default: Returned unchanged when ``val`` is missing according to
            ``pd.isna`` or cannot be converted.

    Returns:
        ``str(val)`` on success, otherwise ``default``.
    """
    try:
        if pd.isna(val):
            return default
        return str(val)
    except (TypeError, ValueError):
        # ValueError is raised when pd.isna() returns an array with more than
        # one element (ambiguous truth value). Other exceptions propagate.
        return default

def safe_bool_to_int(val, default=False):
    """Interpret ``val`` as a boolean and return it as ``0`` or ``1``.

    Args:
        val: Value to interpret. Python and NumPy booleans and numbers are
            accepted, as are the strings ``'true'``, ``'1'``, ``'yes'`` and
            ``'t'`` (case-insensitive), which map to ``1``; any other string
            maps to ``0``.
        default: Boolean used when ``val`` is missing according to
            ``pd.isna``, has an unsupported type, or cannot be interpreted.

    Returns:
        ``1`` or ``0``.
    """
    try:
        if pd.isna(val):
            return int(default)
        # NumPy scalars are not instances of the built-in bool/int types, so
        # they are listed explicitly; otherwise they would fall through to
        # the default.
        if isinstance(val, (bool, np.bool_)):
            return int(val)
        if isinstance(val, (int, float, np.integer, np.floating)):
            return int(bool(val))
        if isinstance(val, str):
            return int(val.lower() in ('true', '1', 'yes', 't'))
        return int(default)
    except (TypeError, ValueError):
        # e.g. an array-like whose pd.isna() result has no single truth value
        return int(default)

# ========================
# STEP 1: BUILD AUTHENTICATED DATABASE
# ========================

def build_authenticated_database(df):
    """Collect every authenticated ASV and the samples it was authenticated in.

    A row is treated as authenticated when ``autopropose == 'select'`` and
    ``match == 'match'``. Progress and summary figures are printed to stdout.

    Args:
        df: DataFrame with at least the ``autopropose`` and ``match``
            columns; ``asv_id``, ``project_readfile_id``, ``reads`` and the
            taxonomy columns are read per row when present.

    Returns:
        dict keyed by ASV id. Each value holds ``sample_ids``, ``reads_list``,
        the taxonomy of the first authenticated row seen for that ASV
        (``family``, ``subfamily``, ``order``, ``genus``) and the derived
        ``total_reads``, ``avg_reads`` and ``n_samples``. Empty when no row
        is authenticated.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("STEP 1: BUILDING AUTHENTICATED DATABASE")
    print(rule)

    print("\nCriteria: autopropose='select' AND match='match'")

    # Keep only rows that satisfy both authentication criteria
    is_selected = df['autopropose'] == 'select'
    is_matched = df['match'] == 'match'
    auth_rows = df[is_selected & is_matched].copy()
    n_instances = len(auth_rows)

    print(f"\n  Found {n_instances:,} authenticated instances")

    if n_instances == 0:
        print("  ⚠️  WARNING: No authenticated ASVs found!")
        return {}

    # One entry per ASV; taxonomy comes from the first row encountered
    auth_db = {}
    for _, record in auth_rows.iterrows():
        asv_id = safe_str(record.get('asv_id', ''))
        sample_id = safe_str(record.get('project_readfile_id', ''))

        # Rows lacking either identifier cannot be indexed
        if not (asv_id and sample_id):
            continue

        entry = auth_db.get(asv_id)
        if entry is None:
            entry = {
                'sample_ids': [],
                'reads_list': [],
                'family': safe_str(record.get('family', '')),
                'subfamily': safe_str(record.get('subfamily', '')),
                'order': safe_str(record.get('blast_tax_order', '')),
                'genus': safe_str(record.get('blast_tax_genus', '')),
            }
            auth_db[asv_id] = entry

        entry['sample_ids'].append(sample_id)
        entry['reads_list'].append(safe_float(record.get('reads', 0)))

    # Per-ASV summary statistics
    for entry in auth_db.values():
        per_sample_reads = entry['reads_list']
        entry['total_reads'] = sum(per_sample_reads)
        entry['avg_reads'] = np.mean(per_sample_reads) if per_sample_reads else 0
        entry['n_samples'] = len(entry['sample_ids'])

    print(f"\n  ✓ Unique authenticated ASVs: {len(auth_db):,}")
    print(f"  ✓ Total authenticated instances: {n_instances:,}")

    # How many ASVs were authenticated in 1, 2, 3, ... samples (first 10 rows)
    histogram = Counter(entry['n_samples'] for entry in auth_db.values())
    print("\n  Authenticated ASVs per sample:")
    for n_samples in sorted(histogram)[:10]:
        print(f"    {n_samples:3d} sample(s): {histogram[n_samples]:5,} ASVs")

    return auth_db

# ========================
# STEP 2: CALCULATE SAMPLE STATISTICS
# ========================

def calculate_sample_statistics(df):
    """Summarise, for each ASV, how it is distributed over all rows of ``df``.

    Args:
        df: DataFrame with ``asv_id`` and ``reads`` columns.

    Returns:
        dict keyed by ASV id with ``n_samples`` (number of rows for the ASV),
        ``total_reads_all_samples``, ``avg_reads_per_sample``, ``max_reads``,
        ``min_reads``, ``is_singleton`` (exactly one row) and
        ``is_widespread`` (row count reaches
        ``THRESHOLDS['widespread_threshold']``).
    """
    rule = "=" * 80
    print("\n" + rule)
    print("STEP 2: CALCULATING SAMPLE STATISTICS")
    print(rule)

    grouped = df.groupby('asv_id')
    # Show a progress bar when tqdm could be imported
    groups_iter = tqdm(grouped, desc="  Processing ASVs") if TQDM_AVAILABLE else grouped

    widespread_min = THRESHOLDS['widespread_threshold']
    asv_stats = {}

    for asv_id, rows in groups_iter:
        n_rows = len(rows)
        read_counts = rows['reads']
        asv_stats[asv_id] = {
            'n_samples': n_rows,
            'total_reads_all_samples': read_counts.sum(),
            'avg_reads_per_sample': read_counts.mean(),
            'max_reads': read_counts.max(),
            'min_reads': read_counts.min(),
            'is_singleton': n_rows == 1,
            'is_widespread': n_rows >= widespread_min,
        }

    print(f"\n  ✓ Calculated statistics for {len(asv_stats):,} unique ASVs")

    return asv_stats

# ========================
# STEP 3: DETECT INTRA-SPECIES VARIANTS
# ========================

def detect_intra_species_variants(df, auth_db):
    """
    Detect Intra-Species Variants based on co-occurrence pattern.

    For every sample that contains an authenticated ASV
    (``autopropose == 'select'`` and ``match == 'match'``; the one with the
    most reads when several qualify), each other ASV in that sample with at
    least ``THRESHOLDS['technical_artifacts_reads']`` reads is counted as
    co-occurring with it. A secondary ASV is reported once a pair reaches
    ``THRESHOLDS['intra_species_min_cooccurrence']`` co-occurrences.

    Args:
        df: DataFrame with ``project_readfile_id``, ``asv_id``, ``reads``,
            ``autopropose`` and ``match`` columns.
        auth_db: Authenticated-ASV database. Not used by this function;
            kept so existing callers keep working.

    Returns:
        dict mapping secondary ASV id ->
        ``(authenticated_asv_id, count, confidence, level)``. When a
        secondary ASV qualifies with several authenticated ASVs, the pair
        with the highest count is kept (the first one seen on ties).
    """

    min_cooccurrence = THRESHOLDS['intra_species_min_cooccurrence']
    min_reads = THRESHOLDS['technical_artifacts_reads']

    print("\n" + "="*80)
    print("STEP 3: DETECTING INTRA-SPECIES VARIANTS")
    print("="*80)

    print("\nMethod: Co-occurrence pattern analysis")
    print(f"Threshold: ≥{min_cooccurrence} co-occurrences with same authenticated ASV")

    # Build co-occurrence database: (authenticated ASV, secondary ASV) -> count
    co_occurrence = defaultdict(int)

    # A single groupby pass replaces re-filtering the whole frame once per
    # sample. sort=False visits samples in order of first appearance, which
    # keeps the pair insertion order (and therefore tie-breaking below)
    # unchanged; rows with a missing sample id are skipped, as before.
    for _, sample_df in df.groupby('project_readfile_id', sort=False):
        # Find authenticated ASV in this sample
        auth_asv = sample_df[
            (sample_df['autopropose'] == 'select') &
            (sample_df['match'] == 'match')
        ]

        if len(auth_asv) == 0:
            continue

        if len(auth_asv) > 1:
            # Multiple authenticated - take highest reads
            auth_asv = auth_asv.nlargest(1, 'reads')

        auth_asv_id = auth_asv.iloc[0]['asv_id']

        # All other ASVs in this sample that pass the minimum read count
        secondary_asvs = sample_df[
            (sample_df['asv_id'] != auth_asv_id) &
            (sample_df['reads'] >= min_reads)
        ]

        # Count co-occurrences (only the ASV id column is needed)
        for sec_asv_id in secondary_asvs['asv_id']:
            co_occurrence[(auth_asv_id, sec_asv_id)] += 1

    print(f"\n  Found {len(co_occurrence):,} unique (authenticated_ASV, secondary_ASV) pairs")

    # Build strong evidence set
    intra_species_db = {}

    for (auth_asv_id, sec_asv_id), count in co_occurrence.items():
        if count < min_cooccurrence:
            continue

        # Confidence grows with the number of co-occurrences
        if count >= 3:
            conf, level = 0.80, 'Strong'
        else:
            conf, level = 0.70, 'Moderate'

        # Keep the highest count if the secondary ASV appears with several
        # authenticated ASVs (strict '>' keeps the first pair on ties)
        current = intra_species_db.get(sec_asv_id)
        if current is None or count > current[1]:
            intra_species_db[sec_asv_id] = (auth_asv_id, count, conf, level)

    print(f"\n  Strong evidence Intra-Species Variants:")
    print(f"    Secondary ASVs: {len(intra_species_db):,}")

    if intra_species_db:
        counts = [v[1] for v in intra_species_db.values()]
        print(f"    Co-occurrence counts:")
        print(f"      Min: {min(counts)}")
        print(f"      Max: {max(counts)}")
        print(f"      Avg: {np.mean(counts):.1f}")

        # Distribution (ten smallest co-occurrence counts)
        count_dist = Counter(counts)
        print(f"\n    Distribution:")
        for cnt in sorted(count_dist)[:10]:
            print(f"      {cnt} co-occurrences: {count_dist[cnt]:,} secondary ASVs")

    return intra_species_db

# ========================
# STEP 4: HELPER FUNCTIONS FOR CLASSIFICATION
# ========================

def get_authenticated_for_sample(sample_id, auth_db):
    """Return the info dict of the first authenticated ASV listing ``sample_id``.

    Entries of ``auth_db`` are scanned in insertion order and each entry's
    ``'sample_ids'`` is tested for membership. ``None`` is returned when no
    entry lists the sample.
    """
    candidates = (
        info for info in auth_db.values()
        if sample_id in info['sample_ids']
    )
    return next(candidates, None)

def check_family_match(row, auth_info):
    """Tell whether the row's family equals the authenticated ASV's family.

    The row's family is the first non-empty value among ``asv_family_v1``,
    ``asv_family_v2``, ``asv_family_v3`` and ``family``. The result is
    ``False`` when ``auth_info`` is empty/``None`` or when either family is
    unknown.
    """
    if not auth_info:
        return False

    # First non-empty family column wins
    candidate = ''
    for column in ('asv_family_v1', 'asv_family_v2', 'asv_family_v3', 'family'):
        candidate = safe_str(row.get(column, ''))
        if candidate:
            break

    reference = auth_info.get('family', '')

    if candidate and reference:
        return candidate == reference
    return False

# ========================
# STEP 5: PHASE 1 - RULE-BASED CLASSIFICATION
# ========================

def classify_phase1_rule_based(row, auth_db, asv_stats, intra_species_db):
    """
    Phase 1: Rule-based classification (STRONG EVIDENCE ONLY)

    Philosophy:
    - Only classify when CONFIDENT
    - Use Uncertain when not sure
    - Let ML learn from confident cases
    - ML predicts Uncertain based on sequence patterns

    Rules are evaluated in a fixed order and the first one that fires
    decides the result: Authenticated, Cross_Contamination,
    Intra_Species_Variant (or co-occurring Environmental_Contamination),
    Failed, Technical_Artifacts, Environmental_Contamination, Uncertain.

    Args:
        row: Mapping-like record (e.g. a pandas row) read via ``row.get``.
        auth_db: Authenticated database; entries are read for
            ``'sample_ids'``, ``'family'`` and ``'order'``-style metadata.
        asv_stats: Per-ASV statistics with ``'n_samples'``,
            ``'is_singleton'`` and ``'is_widespread'``; an ASV absent from
            it is treated as a non-widespread singleton.
        intra_species_db: Mapping secondary ASV id ->
            ``(authenticated_asv_id, count, confidence, level)``.

    Returns:
        dict with at least ``'class'``, ``'confidence'``, ``'method'`` and
        ``'reason'``; some rules add ``'subtype'``,
        ``'co_occurrence_count'``, ``'hints'`` or ``'note'``.
    """

    # Shared cut-off for the "very low reads" artifact rules
    very_low_reads = THRESHOLDS['very_low_reads']

    # Extract variables
    asv_id = safe_str(row.get('asv_id', ''))
    sample_id = safe_str(row.get('project_readfile_id', ''))
    reads = safe_float(row.get('reads', 0))
    # A missing (NaN) total converts to 0, which makes abundance 0 below
    total_reads = safe_float(row.get('total_asv_reads', 1))
    abundance = reads / total_reads if total_reads > 0 else 0

    asv_stat = asv_stats.get(asv_id, {
        'n_samples': 1,
        'is_singleton': True,
        'is_widespread': False
    })

    quality_score = safe_float(row.get('ANALYSIS_sequence_quality_score', 0))
    # The default is handed to safe_int (not row.get) so that it also applies
    # to NaN cells; otherwise a missing length became 0 and the row was
    # reported as "Extreme length: 0bp".
    seq_length = safe_int(row.get('ANALYSIS_seq_length'), 417)
    phylo_dist = safe_float(row.get('Phylogenetic_distance', 0))
    match_status = safe_str(row.get('match', ''))
    autopropose = safe_str(row.get('autopropose', ''))

    # ================================================
    # CLASS 0: Authenticated (100% confident)
    # ================================================

    if autopropose == 'select' and match_status == 'match':
        return {
            'class': 'Authenticated',
            'confidence': 1.00,
            'method': 'Rule',
            'reason': 'autopropose=select AND match=match'
        }

    # ================================================
    # CLASS 1: Cross-Contamination (High confidence)
    # ================================================

    if asv_id in auth_db:
        if sample_id not in auth_db[asv_id]['sample_ids']:
            # In authenticated DB but not in this sample
            # → Cross-contamination from other samples

            if abundance < THRESHOLDS['cross_cont_clear']:
                conf = 0.95
                level = 'Very_Clear'
            elif abundance < THRESHOLDS['cross_cont_likely']:
                conf = 0.90
                level = 'Clear'
            elif abundance < THRESHOLDS['cross_cont_possible']:
                conf = 0.85
                level = 'Likely'
            else:
                conf = 0.75
                level = 'Possible_Shared_Species'

            n_auth_samples = len(auth_db[asv_id]['sample_ids'])

            return {
                'class': 'Cross_Contamination',
                'subtype': level,
                'confidence': conf,
                'method': 'Rule',
                'reason': f'Authenticated in {n_auth_samples} other sample(s), abundance={abundance:.1%}'
            }

    # ================================================
    # CLASS 2: Intra-Species Variant (Strong Evidence)
    # Co-occurrence ≥2 times with same authenticated ASV
    # ================================================

    if asv_id in intra_species_db:
        auth_asv_id, count, conf, level = intra_species_db[asv_id]

        # Taxonomy check: co-occurrence alone is not sufficient
        if match_status == 'match':
            return {
                'class': 'Intra_Species_Variant',
                'subtype': level,
                'confidence': conf,
                'method': 'Rule_CoOccurrence',
                'reason': f'Co-occurs {count} times with authenticated ASV {auth_asv_id}, match=match',
                'co_occurrence_count': count
            }

        # Co-occurs but match != 'match' → Environmental
        return {
            'class': 'Environmental_Contamination',
            'subtype': 'Co_occurring_Different_Species',
            'confidence': 0.80,
            'method': 'Rule_CoOccurrence',
            'reason': f'Co-occurs {count} times but match≠match (different species)',
            'co_occurrence_count': count
        }

    # ================================================
    # CLASS 3: Failed (100% confident)
    # ================================================

    sequence = safe_str(row.get('ANALYSIS_corrected_sequence_full', ''))
    if not sequence or reads == 0:
        return {
            'class': 'Failed',
            'confidence': 1.00,
            'method': 'Rule',
            'reason': 'No sequence data or zero reads'
        }

    # ================================================
    # CLASS 4: Technical Artifacts (High confidence)
    # ================================================

    # Extreme length
    if seq_length < THRESHOLDS['length_min'] or seq_length > THRESHOLDS['length_max']:
        return {
            'class': 'Technical_Artifacts',
            'subtype': 'Abnormal_Length',
            'confidence': 0.85,
            'method': 'Rule',
            'reason': f'Extreme length: {seq_length}bp'
        }

    # Low quality singleton
    if reads < very_low_reads and asv_stat['is_singleton'] and quality_score < THRESHOLDS['quality_medium']:
        return {
            'class': 'Technical_Artifacts',
            'subtype': 'Low_Quality_Singleton',
            'confidence': 0.80,
            'method': 'Rule',
            'reason': f'Singleton with {reads} reads and quality={quality_score:.1f}'
        }

    # Very low quality
    if reads < very_low_reads and quality_score < THRESHOLDS['quality_low']:
        return {
            'class': 'Technical_Artifacts',
            'subtype': 'Very_Low_Quality',
            'confidence': 0.75,
            'method': 'Rule',
            'reason': f'Low reads ({reads}) with very low quality ({quality_score:.1f})'
        }

    # ================================================
    # CLASS 5: Environmental Contamination
    # ONLY when CONFIDENT it's NOT Coleoptera
    # ================================================

    # 5A. Proposed but taxonomy mismatch
    if autopropose == 'select' and match_status != 'match':
        return {
            'class': 'Environmental_Contamination',
            'subtype': 'Proposed_No_Match',
            'confidence': 0.90,
            'method': 'Rule',
            'reason': 'Proposed as authenticated but taxonomy mismatch'
        }

    # 5B. Non-Insecta (very confident NOT Coleoptera)
    asv_class = safe_str(row.get('blast_tax_class', ''))
    if asv_class and asv_class != 'Insecta' and abundance < THRESHOLDS['env_class_mismatch']:
        return {
            'class': 'Environmental_Contamination',
            'subtype': 'Non_Insect',
            'confidence': 0.90,
            'method': 'Rule',
            'reason': f'Non-insect class: {asv_class}, abundance={abundance:.1%}'
        }

    # 5C. Family mismatch (confident it's different family)
    # The authenticated reference is looked up only now because it is a
    # linear scan of auth_db and no earlier rule needs it.
    auth_info = get_authenticated_for_sample(sample_id, auth_db)
    if auth_info:
        row_family = (
            safe_str(row.get('asv_family_v1', '')) or
            safe_str(row.get('asv_family_v2', '')) or
            safe_str(row.get('asv_family_v3', '')) or
            safe_str(row.get('family', ''))
        )
        auth_family = auth_info.get('family', '')

        # Only classify as Environmental if BOTH families are known AND different
        if row_family and auth_family and row_family != auth_family:
            # AND abundance is low (not dominant)
            if abundance < THRESHOLDS['env_order_mismatch']:
                return {
                    'class': 'Environmental_Contamination',
                    'subtype': 'Family_Mismatch',
                    'confidence': 0.85,
                    'method': 'Rule',
                    'reason': f'Confident family mismatch: ASV={row_family} vs Auth={auth_family}, abundance={abundance:.1%}'
                }

    # NOTE(review): a former rule 5D ("Different_Coleoptera_Family",
    # confidence 0.80, both orders == 'Coleoptera') was removed because it
    # could never fire: it required the same family-mismatch and abundance
    # condition on which 5C has already returned. If the more specific
    # subtype is wanted, it has to be evaluated before 5C.

    # 5E. Very high phylogenetic distance + low abundance
    # (likely different species/genus)
    if phylo_dist > THRESHOLDS['env_phylo_distance'] and \
       abundance < THRESHOLDS['env_phylo_threshold']:
        return {
            'class': 'Environmental_Contamination',
            'subtype': 'High_Divergence',
            'confidence': 0.80,
            'method': 'Rule',
            'reason': f'Very divergent: phylo_dist={phylo_dist:.3f}, abundance={abundance:.1%}'
        }

    # ================================================
    # CLASS 6: Uncertain
    # Not confident enough to classify
    # → Let ML predict based on sequence patterns
    # ================================================

    hints = []

    # Provide hints for debugging
    if reads < very_low_reads:
        hints.append(f'very_low_reads({reads})')
    elif reads < THRESHOLDS['low_reads']:
        hints.append(f'low_reads({reads})')

    if abundance < THRESHOLDS['low_abundance']:
        hints.append(f'low_abundance({abundance:.1%})')
    elif THRESHOLDS['medium_abundance_min'] <= abundance < THRESHOLDS['medium_abundance_max']:
        hints.append(f'medium_abundance({abundance:.1%})')

    if quality_score < THRESHOLDS['quality_medium']:
        hints.append(f'low_quality({quality_score:.1f})')

    if asv_stat['is_singleton']:
        hints.append('singleton')
    elif asv_stat['is_widespread']:
        hints.append(f'widespread({asv_stat["n_samples"]}_samples)')

    # NOTE(review): only the literal value 'mismatch' produces a hint here;
    # confirm which values the 'match' column actually takes.
    if match_status == 'match':
        hints.append('match=match')
    elif match_status == 'mismatch':
        hints.append('match=mismatch')

    return {
        'class': 'Uncertain',
        'confidence': 0.50,
        'method': 'Rule',
        'reason': f'No clear pattern: {", ".join(hints) if hints else "insufficient evidence"}',
        'hints': hints,
        'note': 'Will be classified by ML based on sequence patterns'
    }

# ========================
# STEP 6: EXTRACT ML FEATURES
# ========================

def extract_ml_features(row, auth_db, asv_stats):
    """Extract comprehensive features for ML

    Column-specific fallback values (e.g. GC content 36, sequence length
    417, phylogenetic distance 1.0) are passed as the *converter* default
    (``safe_float(row.get(col), fallback)``). They therefore apply both when
    the column is absent and when the cell is NaN. Previously they were
    passed to ``row.get`` only, so NaN cells silently became 0 and triggered
    flags such as ``phylo_very_close``, ``GC_unusual`` and
    ``length_abnormal``.

    Args:
        row: Mapping-like record (e.g. a pandas row) read via ``row.get``.
        auth_db: Authenticated database; entries are read for
            ``'sample_ids'``.
        asv_stats: Per-ASV statistics; an ASV absent from it is treated as
            a singleton seen only in this row.

    Returns:
        dict mapping feature name -> numeric value. Key insertion order is
        stable and unchanged, so it can be used as the column order.
    """

    asv_id = safe_str(row.get('asv_id', ''))
    sample_id = safe_str(row.get('project_readfile_id', ''))
    reads = safe_float(row.get('reads', 0))
    total_reads = safe_float(row.get('total_asv_reads', 1))
    abundance = reads / total_reads if total_reads > 0 else 0

    asv_stat = asv_stats.get(asv_id, {
        'n_samples': 1,
        'is_singleton': True,
        'is_widespread': False,
        'total_reads_all_samples': reads,
        'avg_reads_per_sample': reads,
        'max_reads': reads
    })

    auth_info = get_authenticated_for_sample(sample_id, auth_db)

    # Values reused by several feature groups are read once
    quality_score = safe_float(row.get('ANALYSIS_sequence_quality_score', 0))
    motif_score = safe_float(row.get('ANALYSIS_motif_score', 0))

    features = {}

    # GROUP 1: ABUNDANCE FEATURES
    features.update({
        'reads': float(reads),
        'log_reads': float(np.log10(reads + 1)),
        'abundance': float(abundance),
        # small offset keeps log10 finite when abundance is 0
        'log_abundance': float(np.log10(abundance + 1e-10)),
        'percentage_reads': safe_float(row.get('percentage_reads', 0)),
    })

    # GROUP 2: DISTRIBUTION FEATURES
    features.update({
        'n_samples_present': int(asv_stat['n_samples']),
        'is_singleton': int(asv_stat['is_singleton']),
        'is_widespread': int(asv_stat['is_widespread']),
        'total_reads_all_samples': float(asv_stat['total_reads_all_samples']),
        'log_total_reads_all': float(np.log10(asv_stat['total_reads_all_samples'] + 1)),
        'avg_reads_per_sample': float(asv_stat['avg_reads_per_sample']),
        'max_reads_in_sample': float(asv_stat['max_reads']),
    })

    # GROUP 3: AUTHENTICATED REFERENCE
    is_auth_elsewhere = int(asv_id in auth_db and sample_id not in auth_db[asv_id]['sample_ids'])
    n_auth_samples = len(auth_db[asv_id]['sample_ids']) if asv_id in auth_db else 0

    features.update({
        'is_authenticated_elsewhere': is_auth_elsewhere,
        'n_authenticated_samples': int(n_auth_samples),
        'has_authenticated_reference': int(auth_info is not None),
    })

    # GROUP 4: TAXONOMY FEATURES (match status is a key feature)
    family_match = int(check_family_match(row, auth_info)) if auth_info else 0
    match_status = safe_str(row.get('match', ''))

    features.update({
        'family_match': family_match,
        'taxonomy_match': int(match_status == 'match'),      # 1 if match, 0 otherwise
        # NOTE(review): 'no' is the only value counted as a mismatch here;
        # confirm against the values the 'match' column actually takes.
        'taxonomy_mismatch': int(match_status == 'no'),      # 1 if no, 0 otherwise
        'taxonomy_unknown': int(match_status not in ['match', 'no']),  # 1 if unknown
    })

    # GROUP 5: PHYLOGENETIC FEATURES
    # Missing distance is treated as 1.0 rather than 0 (= identical)
    phylo_dist = safe_float(row.get('Phylogenetic_distance'), 1.0)

    features.update({
        'phylo_distance': float(phylo_dist),
        'log_phylo_distance': float(np.log10(phylo_dist + 1e-10)),
        'phylo_very_close': int(phylo_dist < THRESHOLDS['phylo_same_species']),
        'phylo_close': int(phylo_dist < THRESHOLDS['phylo_same_genus']),
        'phylo_divergent': int(phylo_dist > THRESHOLDS['phylo_divergent']),
    })

    # GROUP 6: QUALITY FEATURES
    tga_validated = safe_bool_to_int(row.get('ANALYSIS_tga_trp_validated', False))
    coi_confidence = safe_str(row.get('ANALYSIS_coi_confidence', ''))

    features.update({
        'quality_score': quality_score,
        'motif_score': motif_score,
        'tga_trp_validated': tga_validated,
        'coi_confidence_high': int(coi_confidence == 'High'),
        'coi_confidence_medium': int(coi_confidence == 'Medium'),
        'coi_confidence_low': int(coi_confidence in ['Low', 'Very_Low']),
    })

    # GROUP 7: COMPOSITION FEATURES
    gc_content = safe_float(row.get('ANALYSIS_GC_content'), 36)
    gc_unusual = int(
        gc_content > THRESHOLDS['gc_normal_max'] or
        gc_content < THRESHOLDS['gc_normal_min']
    )

    features.update({
        'GC_content': float(gc_content),
        'AT_content': safe_float(row.get('ANALYSIS_AT_content'), 64),
        'GC_skew': safe_float(row.get('ANALYSIS_GC_skew', 0)),
        'AT_skew': safe_float(row.get('ANALYSIS_AT_skew', 0)),
        'shannon_entropy': safe_float(row.get('ANALYSIS_shannon_entropy', 0)),
        'GC_unusual': gc_unusual,
    })

    # GROUP 8: LENGTH FEATURES
    seq_length = safe_int(row.get('ANALYSIS_seq_length'), 417)

    features.update({
        'seq_length': float(seq_length),
        'protein_length': safe_float(row.get('ANALYSIS_protein_length'), 139),
        'length_abnormal': int(seq_length < THRESHOLDS['length_min'] or seq_length > THRESHOLDS['length_max']),
    })

    # GROUP 9: PROTEIN PROPERTIES
    features.update({
        'hydrophobic_percent': safe_float(row.get('ANALYSIS_hydrophobic_percent'), 50),
        'leucine_percent': safe_float(row.get('ANALYSIS_leucine_percent'), 15),
        'aromatic_percent': safe_float(row.get('ANALYSIS_aromatic_percent'), 10),
        'polar_percent': safe_float(row.get('ANALYSIS_polar_percent'), 25),
        'gravy_score': safe_float(row.get('ANALYSIS_gravy_score', 0)),
    })

    # GROUP 10: MOTIF FEATURES
    features.update({
        'dna_motif_count': safe_int(row.get('ANALYSIS_dna_motif_count', 0)),
        'protein_motif_count': safe_int(row.get('ANALYSIS_protein_motif_count', 0)),
        'dna_motif_coverage': safe_float(row.get('ANALYSIS_dna_motif_coverage', 0)),
        'protein_motif_coverage': safe_float(row.get('ANALYSIS_protein_motif_coverage', 0)),
    })

    # GROUP 11: CODON USAGE
    features.update({
        'codon_diversity': safe_float(row.get('ANALYSIS_codon_diversity'), 50),
        'total_codons': safe_int(row.get('ANALYSIS_total_codons'), 139),
        'unique_codons': safe_int(row.get('ANALYSIS_unique_codons'), 50),
    })

    # GROUP 12: DERIVED FEATURES
    # NUMT signal score: one point per suspicious indicator (0-4)
    numt_score = 0
    if gc_unusual:
        numt_score += 1
    # A missing motif score counts as intact (100) so that it does not add
    # a NUMT point by itself
    if safe_float(row.get('ANALYSIS_motif_score'), 100) < THRESHOLDS['motif_score_degraded']:
        numt_score += 1
    if tga_validated == 0:
        numt_score += 1
    if phylo_dist > THRESHOLDS['phylo_divergent']:
        numt_score += 1

    features['numt_signal_score'] = numt_score

    # Likely patterns
    features['likely_numt'] = int(
        abundance < THRESHOLDS['low_abundance'] and
        asv_stat['n_samples'] >= 2 and
        numt_score >= 2
    )

    features['likely_heteroplasmy'] = int(
        THRESHOLDS['medium_abundance_min'] <= abundance < THRESHOLDS['medium_abundance_max'] and
        quality_score >= THRESHOLDS['quality_high'] and
        motif_score >= THRESHOLDS['quality_high'] and
        phylo_dist < THRESHOLDS['phylo_same_genus']
    )

    features['likely_artifact'] = int(
        reads < THRESHOLDS['very_low_reads'] and
        (asv_stat['is_singleton'] or quality_score < THRESHOLDS['quality_medium'])
    )

    features['likely_cross_cont'] = int(is_auth_elsewhere and abundance < THRESHOLDS['cross_cont_possible'])

    features['likely_environmental'] = int(
        family_match == 0 and
        abundance < THRESHOLDS['env_phylo_threshold']
    )

    return features

# ========================
# STEP 7: TRAIN ML MODEL
# ========================

def train_ml_model(training_df, features_df):
    """
    Train Random Forest classifier
    NOTE: Excludes Uncertain from training

    This function does not filter labels itself: it trains on whatever
    'phase1_class' values it receives, so the caller must pass only the
    rows that should be used as training evidence.

    Pipeline: fill NaN with 0 -> encode labels -> drop classes with fewer
    than 2 samples -> 80/20 split -> SMOTE on the training split only ->
    standard scaling -> balanced class weights -> fit -> hold-out
    evaluation, cross-validation and feature-importance report (printed).

    Args:
        training_df: DataFrame with a 'phase1_class' column. Row i must
            describe the same observation as row i of ``features_df``.
        features_df: DataFrame of model features (one column per feature).

    Returns:
        Tuple ``(clf, scaler, le, feature_importance)`` where ``clf`` is the
        fitted RandomForestClassifier, ``scaler`` the fitted StandardScaler,
        ``le`` the fitted LabelEncoder and ``feature_importance`` a DataFrame
        with columns 'feature' and 'importance' sorted descending.
        ``(None, None, None, None)`` when fewer than two classes remain.

    Raises:
        ValueError: if ``features_df`` and ``training_df`` differ in length.
    """

    print("\n" + "="*80)
    print("STEP 7: TRAINING ML MODEL")
    print("="*80)

    # Prepare data
    X = features_df.copy()
    y = training_df['phase1_class'].values

    # Features and labels are paired purely by position; a length mismatch
    # would otherwise surface later as an obscure indexing error.
    if len(X) != len(y):
        raise ValueError(
            f"features_df has {len(X)} rows but training_df has {len(y)} labels"
        )

    print(f"\nTraining data: {len(X):,} samples (excluding Uncertain)")
    print(f"Features: {len(X.columns)}")

    # Check for NaN
    nan_count = X.isnull().sum().sum()
    if nan_count > 0:
        print(f"\n⚠️  Found {nan_count} NaN values, filling with 0...")
        X = X.fillna(0)

    # Encode labels
    le = LabelEncoder()
    y_encoded = le.fit_transform(y)

    print(f"\nClasses ({len(le.classes_)}):")
    class_counts = Counter(y)
    for cls in le.classes_:
        count = class_counts[cls]
        print(f"  {cls:35s}: {count:6,} ({count/len(y)*100:5.2f}%)")

    # Remove rare classes (< 2 samples): they cannot appear on both sides
    # of a stratified split.
    min_samples_required = 2
    rare_classes = [cls for cls, count in class_counts.items() if count < min_samples_required]

    if rare_classes:
        print(f"\n⚠️  Removing rare classes (< {min_samples_required} samples):")
        for cls in rare_classes:
            print(f"    - {cls}: {class_counts[cls]} sample(s)")

        keep = ~pd.Series(y).isin(rare_classes)
        X = X[keep.values].reset_index(drop=True)
        y = y[keep.values]

        # Re-fit so encoded ids are contiguous over the remaining classes
        le = LabelEncoder()
        y_encoded = le.fit_transform(y)

        print(f"\n  After filtering:")
        print(f"    Samples: {len(X):,}")
        print(f"    Classes: {len(le.classes_)}")

        class_counts = Counter(y)
        for cls in le.classes_:
            count = class_counts[cls]
            print(f"    {cls:35s}: {count:6,} ({count/len(y)*100:5.2f}%)")

    if len(le.classes_) < 2:
        print("\n❌ ERROR: Need at least 2 classes for ML training")
        return None, None, None, None

    # Split FIRST, on real observations only. Oversampling before the split
    # would place synthetic neighbours of test rows in the training set and
    # inflate the hold-out accuracy.
    split = None
    if min(class_counts.values()) >= 2:
        print(f"\nSplitting with stratification...")
        try:
            split = train_test_split(
                X, y_encoded, test_size=0.2, random_state=42, stratify=y_encoded
            )
        except ValueError as e:
            # e.g. the 20% test fold is smaller than the number of classes
            print(f"  ⚠️  Stratified split failed: {e}")

    if split is None:
        print(f"\n⚠️  Classes too small for stratification, splitting randomly...")
        split = train_test_split(
            X, y_encoded, test_size=0.2, random_state=42
        )

    X_train, X_test, y_train, y_test = split

    # Apply SMOTE if needed (training split only; the test split stays real)
    if SMOTE_AVAILABLE:
        min_class_count = min(Counter(y_train).values())
        if min_class_count < 100:
            print(f"\n  Applying SMOTE (min class: {min_class_count})...")
            # SMOTE needs at least one neighbour within the smallest class
            k_neighbors = min(5, min_class_count - 1)
            if k_neighbors > 0:
                try:
                    smote = SMOTE(random_state=42, k_neighbors=k_neighbors)
                    X_train, y_train = smote.fit_resample(X_train, y_train)
                    print(f"  After SMOTE: {len(X_train):,} samples")
                except Exception as e:
                    # Best effort: fall back to the un-resampled training split
                    print(f"  ⚠️  SMOTE failed: {e}")
            else:
                print(f"  ⚠️  SMOTE skipped: smallest training class has a single sample")

    print(f"\nSplit:")
    print(f"  Train: {len(X_train):,}")
    print(f"  Test:  {len(X_test):,}")

    # Scale (statistics come from the training split only)
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    # Class weights, computed from the data the model is actually fitted on
    classes = np.unique(y_train)
    weights = class_weight.compute_class_weight('balanced', classes=classes, y=y_train)
    class_weight_dict = dict(zip(classes, weights))

    print(f"\nClass weights:")
    for cls_name, weight in zip(le.inverse_transform(classes), weights):
        print(f"  {cls_name:35s}: {weight:.3f}")

    # Train Random Forest
    print(f"\nTraining Random Forest...")

    clf = RandomForestClassifier(
        n_estimators=200,
        max_depth=20,
        min_samples_split=10,
        min_samples_leaf=5,
        max_features='sqrt',
        class_weight=class_weight_dict,
        random_state=42,
        n_jobs=-1,
        verbose=0
    )

    clf.fit(X_train_scaled, y_train)
    print(f"  ✓ Training complete")

    # Evaluate on the untouched hold-out split
    y_pred = clf.predict(X_test_scaled)
    test_acc = accuracy_score(y_test, y_pred)
    print(f"\n✓ Test Accuracy: {test_acc:.3f}")

    # Cross-validation
    # NOTE: runs on the (possibly SMOTE-resampled) training matrix, so these
    # scores are likely optimistic; the hold-out accuracy above is the
    # unbiased figure.
    if len(X_train) >= 50:
        print(f"\nCross-validation (5-fold)...")
        try:
            cv_scores = cross_val_score(clf, X_train_scaled, y_train, cv=min(5, len(X_train)//10), n_jobs=-1)
            print(f"  ✓ CV Accuracy: {cv_scores.mean():.3f} ± {cv_scores.std():.3f}")
        except Exception as e:
            print(f"  ⚠️  Cross-validation failed: {e}")

    # Report over the full label set. Deriving names from y_test alone
    # raises ValueError as soon as the model predicts a class that happens
    # to be absent from the test split.
    label_ids = np.arange(len(le.classes_))

    # Classification report
    print(f"\nClassification Report:")
    print("-"*80)
    report = classification_report(
        y_test, y_pred, labels=label_ids, target_names=le.classes_, zero_division=0
    )
    print(report)

    # Confusion matrix
    cm = confusion_matrix(y_test, y_pred, labels=label_ids)
    cm_df = pd.DataFrame(cm, index=le.classes_, columns=le.classes_)

    print(f"\nConfusion Matrix:")
    print(cm_df)

    # Feature importance
    feature_importance = pd.DataFrame({
        'feature': X.columns,
        'importance': clf.feature_importances_
    }).sort_values('importance', ascending=False)

    print(f"\nTop 20 Important Features:")
    top = feature_importance.head(20)
    for feature_name, importance in zip(top['feature'], top['importance']):
        print(f"  {feature_name:35s}: {importance:.5f}")

    return clf, scaler, le, feature_importance

# ========================
# STEP 8: ML PREDICTION & FINAL DECISION
# ========================

def _decide_final_class(rule_class, rule_conf, ml_class, ml_conf):
    """Reconcile one Phase 1 (rule) call with one ML call.

    Returns:
        Tuple ``(final_class, final_confidence, decision_method)``.
    """
    if rule_class == ml_class:
        # Agreement: keep the shared label and boost the stronger
        # confidence by 5%, capped at 1.0
        if rule_conf >= ml_conf:
            return rule_class, min(1.0, rule_conf * 1.05), 'Rule_Confirmed_by_ML'
        return ml_class, min(1.0, ml_conf * 1.05), 'ML_Confirmed_by_Rule'

    if rule_conf >= 0.90:
        # High confidence rule wins, but a confident ML disagreement
        # lowers the confidence and is marked for review
        if ml_conf >= 0.80:
            return rule_class, 0.70, 'Rule_High_ML_Disagrees_REVIEW'
        return rule_class, rule_conf, 'Rule_High_ML_Uncertain'

    if ml_conf >= 0.80:
        # ML very confident, rule is not
        return ml_class, ml_conf, 'ML_High_Rule_Lower'

    if rule_conf >= 0.70 and ml_conf >= 0.70:
        # Both moderately confident but disagree: take the stronger one
        # with a 5% penalty (ties go to the rule)
        if ml_conf > rule_conf:
            return ml_class, ml_conf * 0.95, 'ML_Weighted'
        return rule_class, rule_conf * 0.95, 'Rule_Weighted'

    # Both uncertain → use ML
    return ml_class, ml_conf, 'ML_Both_Uncertain'


def validate_with_ml(df, clf, scaler, le, features_df):
    """Use ML to predict all samples (including Uncertain)

    Adds these columns to ``df`` in place and returns it:
    'ml_prediction', 'ml_confidence', one 'ml_prob_<class>' per label in
    ``le.classes_``, 'final_classification', 'final_confidence',
    'decision_method' and 'needs_review'.

    Args:
        df: DataFrame with 'phase1_class' and 'phase1_confidence'. Row i
            must describe the same observation as row i of ``features_df``.
        clf: fitted classifier exposing ``predict``, ``predict_proba`` and
            ``classes_``.
        scaler: fitted scaler exposing ``transform``.
        le: fitted label encoder exposing ``inverse_transform`` and
            ``classes_``.
        features_df: DataFrame of model features.

    Returns:
        The same ``df`` object with the columns listed above.
    """

    print("\n" + "="*80)
    print("STEP 8: ML PREDICTION & FINAL CLASSIFICATION")
    print("="*80)

    # Training fills NaN with 0; apply the same rule here so the model sees
    # inputs prepared the same way (the caller's frame is left untouched).
    nan_count = features_df.isnull().sum().sum()
    if nan_count > 0:
        print(f"\n⚠️  Found {nan_count} NaN values, filling with 0...")
        features_df = features_df.fillna(0)

    # Scale features
    X_scaled = scaler.transform(features_df)

    # Predict
    print("\nPredicting all biological samples...")
    ml_predictions = clf.predict(X_scaled)
    ml_probabilities = clf.predict_proba(X_scaled)

    # Decode
    ml_classes = le.inverse_transform(ml_predictions)
    ml_confidences = ml_probabilities.max(axis=1)

    # Add to dataframe
    df['ml_prediction'] = ml_classes
    df['ml_confidence'] = ml_confidences

    # Add probability for each class.
    # predict_proba columns follow clf.classes_, which can be a subset of
    # le.classes_ when a class never reached the training split; classes the
    # model never saw get probability 0.0 so the output schema stays stable.
    proba_by_class = dict(zip(le.inverse_transform(clf.classes_), ml_probabilities.T))
    for cls_name in le.classes_:
        df[f'ml_prob_{cls_name}'] = proba_by_class.get(cls_name, 0.0)

    print(f"  ✓ Predictions complete")

    # Agreement analysis
    print(f"\nAGREEMENT ANALYSIS:")

    agreement = (df['phase1_class'] == df['ml_prediction'])
    agreement_rate = agreement.sum() / len(df) * 100
    print(f"  Overall agreement: {agreement_rate:.1f}%")

    print(f"\n  Agreement by Phase 1 confidence:")
    phase1_conf = df['phase1_confidence']
    confidence_bands = [
        ('High (≥0.90)', phase1_conf >= 0.90),
        ('Medium (0.70-0.89)', (phase1_conf >= 0.70) & (phase1_conf < 0.90)),
        ('Low (<0.70)', phase1_conf < 0.70),
    ]
    for label, mask in confidence_bands:
        if mask.sum() > 0:
            agree_rate = agreement[mask].sum() / mask.sum() * 100
            print(f"    {label:20s}: {agree_rate:5.1f}% ({mask.sum():,} cases)")

    # Disagreements
    print(f"\nDISAGREEMENTS (Phase 1 ≠ ML):")
    disagree = df[~agreement].copy()

    if len(disagree) > 0:
        print(f"  Total: {len(disagree):,} ({len(disagree)/len(df)*100:.1f}%)")

        disagree_patterns = disagree.groupby(['phase1_class', 'ml_prediction']).size().sort_values(ascending=False)

        print(f"\n  Top 10 disagreement patterns:")
        for (rule_cls, ml_cls), count in disagree_patterns.head(10).items():
            pct = count / len(disagree) * 100
            print(f"    Phase1: {rule_cls:25s} → ML: {ml_cls:25s} : {count:5,} ({pct:4.1f}%)")

    # Final classification decision
    print(f"\nFINAL CLASSIFICATION DECISION:")

    # Iterating the four columns in lockstep avoids building a Series per
    # row (as iterrows does) while applying exactly the same decision logic.
    decisions = [
        _decide_final_class(rule_class, rule_conf, ml_class, ml_conf)
        for rule_class, rule_conf, ml_class, ml_conf in zip(
            df['phase1_class'], df['phase1_confidence'],
            df['ml_prediction'], df['ml_confidence']
        )
    ]
    final_classes = [decision[0] for decision in decisions]
    final_confidences = [decision[1] for decision in decisions]
    decision_methods = [decision[2] for decision in decisions]

    df['final_classification'] = final_classes
    df['final_confidence'] = final_confidences
    df['decision_method'] = decision_methods

    # Flag for review: explicit REVIEW decisions or low final confidence
    df['needs_review'] = [
        'REVIEW' in method or conf < 0.60
        for method, conf in zip(decision_methods, final_confidences)
    ]

    # Summary
    print(f"\n  Decision method distribution:")
    method_counts = pd.Series(decision_methods).value_counts()
    for method, count in method_counts.items():
        print(f"    {method:35s}: {count:6,} ({count/len(df)*100:5.2f}%)")

    return df

# ========================
# STEP 9: POST-VALIDATION FIX
# ========================

def _count_multi_auth_samples(df):
    """Return how many samples hold more than one 'Authenticated' row."""
    auth_samples = df.loc[df['final_classification'] == 'Authenticated', 'project_readfile_id']
    return int((auth_samples.value_counts() > 1).sum())


def _print_phylo_bins(label, class_mask, phylo, total):
    """Print the <0.05 / 0.05-0.15 / ≥0.15 distance breakdown for one class.

    Returns:
        Tuple ``(n_close, n_distant)`` so the caller can add its own note.
    """
    n_close = (class_mask & (phylo < 0.05)).sum()
    n_mod = (class_mask & (phylo >= 0.05) & (phylo < 0.15)).sum()
    n_dist = (class_mask & (phylo >= 0.15)).sum()

    print(f"  {label} ({total:,}):")
    print(f"    <0.05: {n_close:,} ({n_close/total*100:.1f}%)")
    print(f"    0.05-0.15: {n_mod:,} ({n_mod/total*100:.1f}%)")
    print(f"    ≥0.15: {n_dist:,} ({n_dist/total*100:.1f}%)")

    return n_close, n_dist


def post_validation_fix(df, auth_db):
    """Fix validation issues after classification

    Modifies ``df`` in place and returns it. Applied in order:

    1. Cross_Contamination rows whose asv_id is not a key of ``auth_db``
       become Environmental_Contamination (confidence x0.9).
    2. Authenticated rows that lack autopropose='select' AND match='match'
       and have confidence < 0.95 are flagged for review (not reclassified).
    3. At most one Authenticated row is kept per 'project_readfile_id':
       the highest-'reads' row among those with autopropose='select' and
       match='match', or the highest-'reads' row overall when none
       qualifies. The rest become Intra_Species_Variant (confidence x0.85,
       flagged for review).
    4. Intra_Species_Variant rows with match != 'match' become
       Environmental_Contamination (confidence x0.80, flagged);
       Environmental rows with match == 'match' are only flagged
       (confidence x0.85). Phylogenetic distance is reported, never acted on.

    Args:
        df: classified DataFrame with 'final_classification',
            'final_confidence', 'decision_method', 'needs_review', 'asv_id',
            'project_readfile_id' and 'reads'; 'autopropose', 'match' and
            'Phylogenetic_distance' are used when present. Row labels are
            assumed unique, since fixes are written back by index label.
        auth_db: mapping whose keys are the authenticated asv_ids.

    Returns:
        The same ``df`` object after the fixes.
    """

    print("\n" + "="*80)
    print("STEP 9: POST-VALIDATION FIXES")
    print("="*80)

    n_fixed = 0

    # ================================================
    # FIX 1: Cross-Contamination must be in auth_db
    # ================================================

    bad_cross = (df['final_classification'] == 'Cross_Contamination') & \
                (~df['asv_id'].isin(auth_db.keys()))

    if bad_cross.sum() > 0:
        df.loc[bad_cross, 'final_classification'] = 'Environmental_Contamination'
        df.loc[bad_cross, 'final_confidence'] *= 0.9
        df.loc[bad_cross, 'decision_method'] += '_AUTO_CORRECTED'
        n_fixed += bad_cross.sum()
        print(f"  ✓ Fixed {bad_cross.sum():,} Cross-Contamination → Environmental")

    # Phase 1 evidence for Authenticated, shared by FIX 2 and FIX 3.
    # Without both columns no row can be confirmed, which matches the
    # "warn and continue" handling FIX 4 applies to a missing 'match'.
    if 'autopropose' in df.columns and 'match' in df.columns:
        is_phase1 = (df['autopropose'] == 'select') & (df['match'] == 'match')
    else:
        print(f"  ⚠️  WARNING: 'autopropose'/'match' columns not found; "
              f"no Authenticated can be confirmed by Phase 1 criteria")
        is_phase1 = pd.Series(False, index=df.index)

    # ================================================
    # FIX 2: Flag suspicious Authenticated (ML-predicted)
    # ================================================

    is_auth = df['final_classification'] == 'Authenticated'
    suspicious_auth = is_auth & ~is_phase1 & (df['final_confidence'] < 0.95)

    if suspicious_auth.sum() > 0:
        df.loc[suspicious_auth, 'needs_review'] = True
        print(f"  ✓ Flagged {suspicious_auth.sum():,} suspicious Authenticated for review")

    # ================================================
    # FIX 3: Enforce 1 Authenticated per sample
    # CRITICAL: Multiple Authenticated in same sample
    # ================================================

    print(f"\n  Checking 1-Authenticated-per-sample rule...")

    violations = []
    demote_idx = []

    # One grouped pass over the Authenticated rows instead of re-filtering
    # the whole frame once per sample id.
    for sample_id, auth_asvs in df[is_auth].groupby('project_readfile_id', sort=False):
        if len(auth_asvs) <= 1:
            continue

        violations.append(sample_id)

        # Priority: keep a Phase 1 Authenticated row when one exists,
        # otherwise fall back to all Authenticated rows of the sample.
        phase1_auth = auth_asvs[is_phase1.loc[auth_asvs.index]]
        candidates = phase1_auth if len(phase1_auth) > 0 else auth_asvs
        keep_idx = candidates['reads'].idxmax()  # Keep highest reads

        demote_idx.extend(idx for idx in auth_asvs.index if idx != keep_idx)

    if len(violations) > 0:
        print(f"  ⚠️  Found {len(violations):,} samples with multiple Authenticated")
        print(f"  → Need to fix {len(demote_idx):,} ASVs")

        # Apply fixes
        if demote_idx:
            df.loc[demote_idx, 'final_classification'] = 'Intra_Species_Variant'
            df.loc[demote_idx, 'final_confidence'] *= 0.85  # Lower confidence
            df.loc[demote_idx, 'decision_method'] += '_AUTO_CORRECTED_MultiAuth'
            df.loc[demote_idx, 'needs_review'] = True
            n_fixed += len(demote_idx)

        print(f"  ✓ Fixed {len(demote_idx):,} cases:")
        print(f"    - Reclassified as Intra_Species_Variant")
        print(f"    - Kept one Authenticated per sample (Phase 1 priority)")
        print(f"    - Flagged for review")
    else:
        print(f"  ✓ All samples have ≤1 Authenticated")

    # ================================================
    # FINAL VALIDATION: Re-check rule
    # ================================================

    print(f"\n  Final validation:")

    violation_count = _count_multi_auth_samples(df)

    if violation_count == 0:
        print(f"  ✓ All samples have ≤1 Authenticated")
    else:
        print(f"  ❌ ERROR: Still have {violation_count} violations!")

    # ================================================
    # FIX 4: Taxonomy Consistency ONLY (No Phylo Loop)
    # Taxonomy = Ground Truth, Phylo = Flag only
    # ================================================

    print(f"\n  Checking taxonomy consistency...")

    if 'match' not in df.columns:
        print(f"  ⚠️  WARNING: 'match' column not found")
    else:
        # Masks are taken before any reclassification below (one pass)
        is_intra = df['final_classification'] == 'Intra_Species_Variant'
        is_env = df['final_classification'] == 'Environmental_Contamination'
        match = df['match']

        # Fix 1: Intra-Species MUST have match='match'
        intra_bad = is_intra & (match != 'match')
        n_intra_bad = intra_bad.sum()

        if n_intra_bad > 0:
            print(f"  ⚠️  Reclassifying {n_intra_bad:,} Intra-Species → Environmental")
            print(f"      Reason: match≠'match' (different species)")

            df.loc[intra_bad, 'final_classification'] = 'Environmental_Contamination'
            df.loc[intra_bad, 'final_confidence'] *= 0.80
            df.loc[intra_bad, 'decision_method'] += '_TAX_CORRECTED'
            df.loc[intra_bad, 'needs_review'] = True
            n_fixed += n_intra_bad

            print(f"  ✓ Fixed {n_intra_bad:,}")
        else:
            print(f"  ✓ All Intra-Species have match='match'")

        # Fix 2: Environmental with match='match' → flag only (don't reclassify)
        env_match = is_env & (match == 'match')
        n_env_match = env_match.sum()

        if n_env_match > 0:
            print(f"  ⚠️  Found {n_env_match:,} Environmental with match='match'")
            print(f"      → Keeping as Environmental but flagging for review")
            print(f"      (These are same species but distant populations/contamination)")

            df.loc[env_match, 'needs_review'] = True
            df.loc[env_match, 'final_confidence'] *= 0.85

            print(f"  ✓ Flagged {n_env_match:,} for review")

        # ================================================
        # VERIFICATION (on the labels as they stand now)
        # ================================================

        print(f"\n  Taxonomy verification:")

        final_intra = df['final_classification'] == 'Intra_Species_Variant'
        final_env = df['final_classification'] == 'Environmental_Contamination'

        intra_ok = (final_intra & (match == 'match')).sum()
        intra_total = final_intra.sum()

        env_no = (final_env & (match != 'match')).sum()
        env_yes = (final_env & (match == 'match')).sum()
        env_total = final_env.sum()

        if intra_total > 0:
            if intra_ok == intra_total:
                print(f"  ✅ Intra-Species ({intra_total:,}): 100% match='match'")
            else:
                print(f"  ❌ Intra-Species: {intra_total - intra_ok:,} still with match≠'match'")

        if env_total > 0:
            print(f"  ℹ️  Environmental ({env_total:,}):")
            print(f"      - match≠'match': {env_no:,} ({env_no/env_total*100:.1f}%)")
            print(f"      - match='match': {env_yes:,} ({env_yes/env_total*100:.1f}%) ← flagged")

        # ================================================
        # PHYLO ANALYSIS (informational only)
        # ================================================

        if 'Phylogenetic_distance' in df.columns:
            print(f"\n  Phylogenetic distance analysis (informational):")

            phylo = df['Phylogenetic_distance']

            if intra_total > 0:
                _, intra_dist = _print_phylo_bins('Intra-Species', final_intra, phylo, intra_total)
                if intra_dist > 0:
                    print(f"    Note: {intra_dist:,} distant but same species (kept as Intra)")

            if env_total > 0:
                env_close, _ = _print_phylo_bins('Environmental', final_env, phylo, env_total)
                if env_close > 0:
                    print(f"    Note: {env_close:,} close but different species (kept as Env)")

    # ================================================
    # SUMMARY
    # ================================================

    if n_fixed == 0 and len(violations) == 0 and suspicious_auth.sum() == 0:
        print(f"\n  ✓ No fixes needed")
    else:
        print(f"\n  ✓ Total fixes applied: {n_fixed:,}")

    return df  # ✅ CRITICAL!

# ========================
# STEP 10: CREATE SEQUENCE SUMMARY (NO MIXED)
# ========================

def create_sequence_summary_no_mixed(df):
    """
    Create sequence summary without Mixed category

    Produces one row per 'asv_id'. The overall class is the first label
    in this priority order that occurs among the ASV's rows:

        Authenticated
        > Cross_Contamination (reported as 'Authenticated': the note records
          that the ASV was only seen as cross-contamination)
        > Intra_Species_Variant
        > Environmental_Contamination
        > Technical_Artifacts
        > Failed

    An ASV with none of these labels is reported as 'Unknown'.

    Taxonomy, quality and sequence fields are copied from the row with the
    highest 'ANALYSIS_sequence_quality_score' (missing scores count as 0),
    or from the first row of the group when that column is absent.

    Args:
        df: per-sample classification table with at least 'asv_id',
            'final_classification' and 'final_confidence'.

    Returns:
        DataFrame with one summary row per ASV (no columns when ``df`` has
        no rows).
    """

    print("\n" + "="*80)
    print("STEP 10: CREATING SEQUENCE-LEVEL SUMMARY (No Mixed)")
    print("="*80)

    summary_data = []

    # Column presence is a property of the frame, not of each group
    has_reads = 'reads' in df.columns
    has_quality = 'ANALYSIS_sequence_quality_score' in df.columns

    asv_groups = df.groupby('asv_id')

    iterator = asv_groups
    if TQDM_AVAILABLE:
        iterator = tqdm(asv_groups, desc="  Processing ASVs")

    for asv_id, group in iterator:
        # Count by classification
        class_counts = group['final_classification'].value_counts()
        total_occurrences = len(group)

        # Priority decision (NO MIXED)
        if 'Authenticated' in class_counts.index:
            overall_class = 'Authenticated'
            n_auth = class_counts['Authenticated']
            note = f'Authenticated in {n_auth}/{total_occurrences} samples'

            # Check for cross-contamination
            if 'Cross_Contamination' in class_counts.index:
                n_cross = class_counts['Cross_Contamination']
                note += f', Cross-contamination in {n_cross} samples'

        elif 'Cross_Contamination' in class_counts.index:
            # Cross-Cont only → was authenticated elsewhere
            overall_class = 'Authenticated'
            n_cross = class_counts['Cross_Contamination']
            note = f'Cross-contamination only ({n_cross} samples) - authenticated elsewhere'

        elif 'Intra_Species_Variant' in class_counts.index:
            overall_class = 'Intra_Species_Variant'
            n_intra = class_counts['Intra_Species_Variant']
            note = f'Intra-Species Variant in {n_intra}/{total_occurrences} samples'

        elif 'Environmental_Contamination' in class_counts.index:
            overall_class = 'Environmental_Contamination'
            n_env = class_counts['Environmental_Contamination']
            note = f'Environmental in {n_env}/{total_occurrences} samples'

        elif 'Technical_Artifacts' in class_counts.index:
            overall_class = 'Technical_Artifacts'
            note = f'Technical artifacts only'

        elif 'Failed' in class_counts.index:
            overall_class = 'Failed'
            note = f'Failed'

        else:
            # Should not happen
            overall_class = 'Unknown'
            note = 'ERROR: Not classified'

        # Statistics
        if has_reads:
            reads = group['reads']
            total_reads, avg_reads, max_reads = reads.sum(), reads.mean(), reads.max()
        else:
            total_reads = avg_reads = max_reads = 0
        avg_confidence = group['final_confidence'].mean()

        # Per-class sample counts
        authenticated_samples = (group['final_classification'] == 'Authenticated').sum()
        cross_cont_samples = (group['final_classification'] == 'Cross_Contamination').sum()
        intra_species_samples = (group['final_classification'] == 'Intra_Species_Variant').sum()

        # Get best quality sample
        if has_quality:
            quality_scores = group['ANALYSIS_sequence_quality_score'].fillna(0)
            best_idx = quality_scores.idxmax()
        else:
            best_idx = group.index[0]

        best_row = group.loc[best_idx]

        # Plain ints keep the rendered dict readable: NumPy 2 would
        # otherwise print each count as "np.int64(n)".
        distribution = {cls: int(count) for cls, count in class_counts.items()}

        summary_data.append({
            'asv_id': asv_id,

            # Overall classification (NO MIXED)
            'overall_classification': overall_class,
            'classification_note': note,

            # Distribution
            'total_samples': total_occurrences,
            'authenticated_samples': authenticated_samples,
            'cross_contamination_samples': cross_cont_samples,
            'intra_species_variant_samples': intra_species_samples,

            # Per-sample breakdown
            'classification_distribution': str(distribution),

            # Abundance
            'total_reads_all_samples': total_reads,
            'avg_reads_per_sample': avg_reads,
            'max_reads_in_sample': max_reads,
            'avg_confidence': avg_confidence,

            # Taxonomy (from best sample)
            'family': safe_str(best_row.get('family', '')),
            'subfamily': safe_str(best_row.get('subfamily', '')),
            'blast_tax_family': safe_str(best_row.get('blast_tax_family', '')),
            'blast_tax_order': safe_str(best_row.get('blast_tax_order', '')),
            'blast_tax_class': safe_str(best_row.get('blast_tax_class', '')),

            # Quality (from best sample)
            'best_quality_score': safe_float(best_row.get('ANALYSIS_sequence_quality_score', 0)),
            'best_motif_score': safe_float(best_row.get('ANALYSIS_motif_score', 0)),
            'best_coi_confidence': safe_str(best_row.get('ANALYSIS_coi_confidence', '')),

            # Sequence info
            'sequence_length': safe_int(best_row.get('ANALYSIS_seq_length', 0)),
            'protein_length': safe_int(best_row.get('ANALYSIS_protein_length', 0)),
            'GC_content': safe_float(best_row.get('ANALYSIS_GC_content', 0)),
            'phylo_distance': safe_float(best_row.get('Phylogenetic_distance', 0)),

            # Sequence
            'sequence': safe_str(best_row.get('ANALYSIS_corrected_sequence_full', '')),
            'protein': safe_str(best_row.get('ANALYSIS_corrected_protein_full', ''))
        })

    summary_df = pd.DataFrame(summary_data)

    print(f"\n  ✓ Created summary for {len(summary_df):,} unique ASVs")

    # Check: No Mixed
    # An empty input yields a frame without columns, so guard the lookup
    # instead of raising KeyError.
    has_mixed = (
        not summary_df.empty
        and 'Mixed' in summary_df['overall_classification'].values
    )
    if has_mixed:
        print(f"  ⚠️  WARNING: Found Mixed in summary!")
    else:
        print(f"  ✓ Confirmed: No Mixed category")

    return summary_df

# ========================
# STEP 11: VALIDATION & STATISTICS
# ========================

def validate_and_create_statistics(df, summary_df, auth_db):
    """Validate the final classification and build per-class statistics.

    Prints one line per validation check, collects failed checks as
    human-readable strings, prints per-sample / per-sequence / confidence
    distributions, and returns a per-class statistics table.

    Args:
        df: Per-sample DataFrame. ``final_classification`` and
            ``final_confidence`` are required. ``autopropose``, ``match``,
            ``asv_id``, ``reads``, ``Phylogenetic_distance`` and
            ``needs_review`` are optional: a check or statistic that needs a
            missing column is skipped.
        summary_df: Per-ASV summary with an ``overall_classification`` column.
        auth_db: Mapping whose keys are the authenticated ASV ids (only the
            keys are used here).

    Returns:
        tuple: ``(stats_df, issues)`` where ``stats_df`` has one row per
        final class (``Classification``, ``Count``, ``Percentage``,
        ``Avg_Confidence`` plus read / review columns when available) and
        ``issues`` is the list of validation failure messages.
    """

    print("\n" + "="*80)
    print("STEP 11: VALIDATION & STATISTICS")
    print("="*80)

    issues = []
    n_rows = len(df)

    # VALIDATION CHECKS
    print("\n1. VALIDATION CHECKS:")

    # Check 1: every Authenticated row must satisfy the authentication rule
    auth = df[df['final_classification'] == 'Authenticated']

    if 'autopropose' in df.columns and 'match' in df.columns:
        valid_auth = auth[
            (auth['autopropose'] == 'select') &
            (auth['match'] == 'match')
        ]

        if len(valid_auth) == len(auth):
            print(f"  ✓ All {len(auth):,} Authenticated have autopropose='select' AND match='match'")
        else:
            invalid_count = len(auth) - len(valid_auth)
            issues.append(f"❌ {invalid_count} Authenticated without proper criteria")

    # Check 2: Cross-contamination rows must reference an authenticated ASV
    cross_cont = df[df['final_classification'] == 'Cross_Contamination']

    if 'asv_id' in df.columns:
        not_in_db = cross_cont[~cross_cont['asv_id'].isin(auth_db.keys())]

        if len(not_in_db) == 0:
            print(f"  ✓ All {len(cross_cont):,} Cross-Contamination in authenticated database")
        else:
            issues.append(f"❌ {len(not_in_db)} Cross-Contamination not in authenticated database")

    # Check 3: Technical Artifacts - report how many come from the read-count
    # pre-filter versus other rules (informational, never an issue)
    tech_art = df[df['final_classification'] == 'Technical_Artifacts']

    if 'reads' in df.columns and len(tech_art) > 0:
        min_reads = THRESHOLDS['technical_artifacts_reads']
        tech_low_reads = (tech_art['reads'] < min_reads).sum()
        if tech_low_reads == len(tech_art):
            print(f"  ✓ All {len(tech_art):,} Technical_Artifacts have reads < {min_reads}")
        else:
            other_tech = len(tech_art) - tech_low_reads
            # Use the configured threshold in the message rather than a literal 4
            print(f"  ✓ Technical_Artifacts: {tech_low_reads:,} with reads <{min_reads}, {other_tech:,} other types")

    # Check 4: No Uncertain in final
    uncertain_count = (df['final_classification'] == 'Uncertain').sum()
    if uncertain_count == 0:
        print(f"  ✓ No Uncertain in final classification (all classified by ML)")
    else:
        issues.append(f"⚠️  Still have {uncertain_count} Uncertain cases")

    # Check 5: Phylogenetic + Taxonomy consistency
    if 'match' in df.columns and 'Phylogenetic_distance' in df.columns:
        # Check 5A: Intra-Species validation
        intra = df[df['final_classification'] == 'Intra_Species_Variant']

        if len(intra) > 0:
            intra_max = THRESHOLDS['phylo_intra_max']
            intra_dist = intra['Phylogenetic_distance']

            # All should have match='match'
            intra_tax_bad = (intra['match'] != 'match').sum()

            # Distances at or above the intra-species ceiling are violations
            # (NaN distances compare False and are therefore not flagged)
            intra_phylo_bad = (intra_dist >= intra_max).sum()

            if intra_tax_bad == 0 and intra_phylo_bad == 0:
                print(f"  ✓ All {len(intra):,} Intra-Species valid:")
                print(f"      - match='match': 100%")
                print(f"      - phylo < {intra_max}: 100%")
            else:
                if intra_tax_bad > 0:
                    issues.append(f"❌ {intra_tax_bad} Intra-Species with match≠'match'")
                if intra_phylo_bad > 0:
                    issues.append(f"❌ {intra_phylo_bad} Intra-Species with phylo ≥ {intra_max}")

            # Show phylo distribution
            print(f"      Phylo distance distribution:")
            very_close = (intra_dist < 0.05).sum()
            moderate = ((intra_dist >= 0.05) & (intra_dist < intra_max)).sum()
            print(f"        <0.05: {very_close:,} ({very_close/len(intra)*100:.1f}%)")
            print(f"        0.05-{intra_max}: {moderate:,} ({moderate/len(intra)*100:.1f}%)")

        # Check 5B: Environmental validation
        env = df[df['final_classification'] == 'Environmental_Contamination']

        if len(env) > 0:
            env_min = THRESHOLDS['phylo_env_min']
            env_dist = env['Phylogenetic_distance']

            # All should have match≠'match'
            env_tax_bad = (env['match'] == 'match').sum()

            # Distances below the environmental floor are violations
            env_phylo_bad = (env_dist < env_min).sum()

            if env_tax_bad == 0 and env_phylo_bad == 0:
                print(f"  ✓ All {len(env):,} Environmental valid:")
                print(f"      - match≠'match': 100%")
                print(f"      - phylo ≥ {env_min}: 100%")
            else:
                if env_tax_bad > 0:
                    issues.append(f"❌ {env_tax_bad} Environmental with match='match'")
                if env_phylo_bad > 0:
                    issues.append(f"❌ {env_phylo_bad} Environmental with phylo < {env_min}")

            # Show phylo distribution
            print(f"      Phylo distance distribution:")
            close = (env_dist < env_min).sum()
            moderate_env = ((env_dist >= env_min) & (env_dist < 0.15)).sum()
            distant = (env_dist >= 0.15).sum()
            print(f"        <{env_min}: {close:,} ({close/len(env)*100:.1f}%)")
            print(f"        {env_min}-0.15: {moderate_env:,} ({moderate_env/len(env)*100:.1f}%)")
            print(f"        ≥0.15: {distant:,} ({distant/len(env)*100:.1f}%)")

    # Print issues
    if issues:
        print(f"\n  Issues found:")
        for issue in issues:
            print(f"  {issue}")

    # STATISTICS
    print(f"\n2. CLASSIFICATION STATISTICS:")

    print(f"\n  Per-Sample Level ({n_rows:,} total):")

    # Count rows with size() rather than counting non-null asv_id values:
    # this works when asv_id is absent and keeps the printed counts equal to
    # the 'Count' column of the returned statistics table.
    by_class = df.groupby('final_classification')
    class_dist = pd.DataFrame({
        'count': by_class.size(),
        'avg_confidence': by_class['final_confidence'].mean(),
    })

    has_reads = 'reads' in df.columns
    if has_reads:
        total_reads_all = df['reads'].sum()
        class_dist['total_reads'] = by_class['reads'].sum()
        class_dist['reads_percentage'] = (class_dist['total_reads'] / total_reads_all) * 100

    class_dist['percentage'] = (class_dist['count'] / n_rows) * 100
    class_dist = class_dist.sort_values('count', ascending=False)

    print("\n" + class_dist.to_string())

    # Sequence level
    print(f"\n  Sequence Level ({len(summary_df):,} unique ASVs):")

    seq_dist = summary_df['overall_classification'].value_counts()
    for cls, count in seq_dist.items():
        pct = count / len(summary_df) * 100
        print(f"    {cls:40s}: {count:6,} ({pct:5.2f}%)")

    # Confidence distribution
    print(f"\n3. CONFIDENCE DISTRIBUTION:")

    confidence = df['final_confidence']
    high_conf = (confidence >= 0.80).sum()
    med_conf = ((confidence >= 0.60) & (confidence < 0.80)).sum()
    low_conf = (confidence < 0.60).sum()

    print(f"  High (≥0.80): {high_conf:,} ({high_conf/n_rows*100:.1f}%)")
    print(f"  Medium (0.60-0.80): {med_conf:,} ({med_conf/n_rows*100:.1f}%)")
    print(f"  Low (<0.60): {low_conf:,} ({low_conf/n_rows*100:.1f}%)")

    # Review needed
    has_review = 'needs_review' in df.columns
    if has_review:
        review_needed = df['needs_review'].sum()
        print(f"\n  Flagged for review: {review_needed:,} ({review_needed/n_rows*100:.1f}%)")

    # Create statistics dataframe (rows follow the count-descending order
    # of class_dist)
    stats_data = []

    for cls in class_dist.index:
        class_df = df[df['final_classification'] == cls]

        row_data = {
            'Classification': cls,
            'Count': len(class_df),
            'Percentage': len(class_df) / n_rows * 100,
            'Avg_Confidence': class_df['final_confidence'].mean(),
        }

        if has_reads:
            class_reads = class_df['reads'].sum()
            row_data['Total_Reads'] = class_reads
            row_data['Reads_Percentage'] = class_reads / total_reads_all * 100
            row_data['Avg_Reads_Per_Sample'] = class_df['reads'].mean()

        if has_review:
            row_data['Needs_Review'] = class_df['needs_review'].sum()

        stats_data.append(row_data)

    stats_df = pd.DataFrame(stats_data)

    return stats_df, issues

# ========================
# STEP 12: EXPORT RESULTS
# ========================

def export_results(df, summary_df, stats_df, feature_importance, auth_db, issues):
    """Write every pipeline artefact to disk and announce each on stdout.

    Outputs, in order: the per-sample classification CSV, the per-ASV
    sequence summary CSV, the class statistics CSV, the feature-importance
    CSV (skipped when ``feature_importance`` is None) and the text report
    produced by ``create_detailed_report``. Destinations are the module-level
    ``OUTPUT_*`` paths.
    """

    rule = "=" * 80
    print("\n" + rule)
    print("STEP 12: EXPORTING RESULTS")
    print(rule)

    def _save_csv(heading, frame, destination):
        # Announce the section, write the table without its index, then
        # confirm the destination path.
        print(heading)
        frame.to_csv(destination, index=False)
        print(f"   ✓ {destination}")

    # 1. Main classified file
    _save_csv("\n1. Main Classification File:", df, OUTPUT_CLASSIFIED)
    row_total = len(df)
    column_total = len(df.columns)
    print(f"     {row_total:,} rows × {column_total} columns")

    # 2. Sequence summary
    _save_csv("\n2. Sequence Summary:", summary_df, OUTPUT_SEQUENCE_SUMMARY)
    asv_total = len(summary_df)
    print(f"     {asv_total:,} unique ASVs")

    # 3. Statistics
    _save_csv("\n3. Classification Statistics:", stats_df, OUTPUT_STATISTICS)

    # 4. Feature importance (only available when a model was trained)
    if feature_importance is not None:
        _save_csv("\n4. Feature Importance:", feature_importance, OUTPUT_FEATURE_IMPORTANCE)

    # 5. Detailed report
    print("\n5. Detailed Report:")
    create_detailed_report(df, summary_df, stats_df, auth_db, issues)
    print(f"   ✓ {OUTPUT_REPORT}")

# ========================
# CREATE DETAILED REPORT
# ========================

def create_detailed_report(df, summary_df, stats_df, auth_db, issues):
    """Write the comprehensive plain-text report to ``OUTPUT_REPORT``.

    Sections, in order: dataset summary, methodology, important notes,
    validation results, classification statistics (per-sample and
    per-sequence), per-class details, the ``THRESHOLDS`` used, and
    recommendations.

    Args:
        df: Per-sample DataFrame with ``final_classification`` and
            ``final_confidence``; ``project_readfile_id``, ``asv_id``,
            ``reads`` and ``needs_review`` are reported when present.
        summary_df: Per-ASV summary; its ``overall_classification`` column is
            tabulated when present.
        stats_df: Per-class statistics table with a ``Classification`` column.
        auth_db: Authenticated-ASV database (only its size is reported).
        issues: Validation failure messages; an empty list is reported as
            all checks passed.
    """

    # The report contains non-ASCII characters (≥, •, ✓, ❌ ...), so the
    # encoding is fixed to UTF-8 instead of relying on the locale default,
    # which raises UnicodeEncodeError on e.g. cp1252 systems.
    with open(OUTPUT_REPORT, 'w', encoding='utf-8') as f:
        f.write("="*80 + "\n")
        f.write("ASV CLASSIFICATION PIPELINE - COMPREHENSIVE REPORT\n")
        f.write("Version 3.0 - Intra-Species Variants & No Mixed\n")
        f.write("="*80 + "\n\n")

        f.write(f"Generated: {pd.Timestamp.now()}\n\n")

        # Dataset summary
        f.write("DATASET SUMMARY\n")
        f.write("-"*80 + "\n")
        f.write(f"Total ASV records: {len(df):,}\n")

        if 'project_readfile_id' in df.columns:
            f.write(f"Unique samples: {df['project_readfile_id'].nunique():,}\n")

        if 'asv_id' in df.columns:
            f.write(f"Unique ASV IDs: {df['asv_id'].nunique():,}\n")

        if 'reads' in df.columns:
            f.write(f"Total reads: {df['reads'].sum():,}\n")

        f.write(f"Authenticated ASVs in database: {len(auth_db):,}\n\n")

        # Methodology
        f.write("METHODOLOGY\n")
        f.write("-"*80 + "\n")
        f.write("Three-Phase Classification Strategy:\n\n")

        f.write("Pre-filter: Technical Artifacts (reads < 4)\n")
        f.write("  - Applied BEFORE Phase 1\n")
        f.write("  - Removes low-read technical noise\n\n")

        f.write("Phase 1: Rule-Based Classification\n")
        f.write("  - Authenticated: autopropose='select' AND match='match'\n")
        f.write("  - Cross-Contamination: In authenticated_db but not primary\n")
        f.write("  - Intra-Species Variants: Co-occurrence ≥2 times with same auth ASV\n")
        f.write("  - Environmental: Taxonomy mismatch or high divergence\n")
        f.write("  - Technical Artifacts: Abnormal length, low quality\n")
        f.write("  - Uncertain: No clear pattern (will be classified by ML)\n\n")

        f.write("Phase 2: ML Classification\n")
        f.write("  - Trained on strong evidence only (excludes Uncertain)\n")
        f.write("  - Predicts ALL samples including Uncertain\n")
        f.write("  - Random Forest with biological features\n")
        f.write("  - NO Uncertain in final output\n\n")

        f.write("Phase 3: Sequence Summary\n")
        f.write("  - Priority-based (no Mixed category)\n")
        f.write("  - Priority: Authenticated > Intra-Species > Environmental > Technical\n")
        f.write("  - Cross-Contamination counted as Authenticated (was authenticated elsewhere)\n\n")

        # Important notes
        f.write("IMPORTANT NOTES\n")
        f.write("-"*80 + "\n")
        f.write("  • Technical_Artifacts (reads < 4) pre-filtered before ML\n")
        f.write("  • Intra-Species Variants detected by co-occurrence pattern\n")
        f.write("  • Cannot distinguish Heteroplasmy vs NUMTs without genomic data\n")
        f.write("  • All sequences passed filtertranslate (no internal stops)\n")
        f.write("  • Same ASV can be different class in different samples\n")
        f.write("  • Cross-Contamination only for authenticated ASVs\n")
        f.write("  • NO Mixed category in sequence summary\n")
        f.write("  • NO Uncertain in final classification\n\n")

        # Validation
        f.write("VALIDATION CHECKS\n")
        f.write("-"*80 + "\n")
        if issues:
            f.write("Issues found:\n")
            for issue in issues:
                f.write(f"  {issue}\n")
        else:
            f.write("  ✓ All validation checks passed\n")
        f.write("\n")

        # Classification statistics
        f.write("CLASSIFICATION STATISTICS\n")
        f.write("-"*80 + "\n\n")

        f.write("Per-Sample Level:\n")
        f.write(stats_df.to_string(index=False))
        f.write("\n\n")

        f.write("Sequence Level:\n")
        if 'overall_classification' in summary_df.columns:
            seq_dist = summary_df['overall_classification'].value_counts()
            for cls, count in seq_dist.items():
                pct = count / len(summary_df) * 100
                f.write(f"  {cls:40s}: {count:6,} ({pct:5.2f}%)\n")
        f.write("\n")

        # Detailed class analysis (one sub-section per class, in the order
        # the classes appear in stats_df)
        f.write("DETAILED CLASS ANALYSIS\n")
        f.write("-"*80 + "\n\n")

        for cls in stats_df['Classification']:
            class_df = df[df['final_classification'] == cls]

            f.write(f"\n{cls}\n")
            f.write("-" * 60 + "\n")
            f.write(f"Count: {len(class_df):,}\n")
            f.write(f"Percentage: {len(class_df)/len(df)*100:.2f}%\n")
            f.write(f"Avg Confidence: {class_df['final_confidence'].mean():.3f}\n")

            if 'reads' in class_df.columns:
                f.write(f"Total Reads: {class_df['reads'].sum():,}\n")
                f.write(f"Avg Reads: {class_df['reads'].mean():.1f}\n")

            if 'needs_review' in class_df.columns:
                f.write(f"Needs Review: {class_df['needs_review'].sum():,}\n")

        # Thresholds
        f.write("\n" + "="*80 + "\n")
        f.write("THRESHOLDS USED\n")
        f.write("="*80 + "\n\n")

        for key, value in THRESHOLDS.items():
            f.write(f"  {key:35s}: {value}\n")

        # Recommendations
        f.write("\n" + "="*80 + "\n")
        f.write("RECOMMENDATIONS\n")
        f.write("="*80 + "\n\n")

        f.write("1. Use Authenticated sequences for:\n")
        f.write("   • Phylogenetic analysis\n")
        f.write("   • Diversity studies\n")
        f.write("   • Species identification\n\n")

        f.write("2. Intra-Species Variants:\n")
        f.write("   • Could be Heteroplasmy OR NUMTs\n")
        f.write("   • Require genomic validation to distinguish\n")
        f.write("   • Exclude from phylogenetic analyses\n\n")

        f.write("3. Cross-Contamination:\n")
        f.write("   • Review patterns to identify sources\n")
        f.write("   • Consider lab protocol improvements\n\n")

        f.write("4. Environmental Contamination:\n")
        f.write("   • Verify taxonomy assignments\n")
        f.write("   • May include prey, parasites, or environmental DNA\n\n")

        f.write("5. Technical Artifacts:\n")
        f.write("   • Exclude from all analyses\n")
        f.write("   • Represent sequencing noise\n\n")

        # Footer
        f.write("="*80 + "\n")
        f.write("END OF REPORT\n")
        f.write("="*80 + "\n")

# ========================
# MAIN PIPELINE
# ========================

def main() -> None:
    """Run the complete ASV classification pipeline end to end.

    Steps, in order:
      1. Load ``INPUT_FILE`` (exits with status 1 if loading fails).
      2. Build the authenticated-ASV database and per-ASV sample statistics.
      3. Pre-filter: label rows with ``reads`` below
         ``THRESHOLDS['technical_artifacts_reads']`` as Technical_Artifacts
         and exclude them from the remaining steps.
      4. Detect intra-species variants, then run the Phase 1 rule-based
         classifier on every remaining row.
      5. Phase 2: extract ML features, train on Phase 1 rows with
         confidence >= 0.70 that are not 'Uncertain', and predict all
         remaining rows. If training returns no model, Phase 1 labels are
         used as the final result.
      6. Apply the post-validation fix, build the sequence summary, run
         validation/statistics, export all outputs and print a final summary.
    """

    print("\n" + "="*80)
    print("ASV CLASSIFICATION PIPELINE")
    print("Version 3.0 - Intra-Species Variants & No Mixed")
    print("="*80)

    print(f"\nStarted: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}")

    # Load data
    print(f"\nLoading data from: {INPUT_FILE}")

    # Any failure while reading the CSV is fatal: report it and exit(1).
    try:
        df = pd.read_csv(INPUT_FILE, low_memory=False)
        print(f"  ✓ Loaded {len(df):,} rows")
        print(f"  ✓ Columns: {len(df.columns)}")
    except Exception as e:
        print(f"\n❌ ERROR loading data: {e}")
        sys.exit(1)

    # Build authenticated database
    auth_db = build_authenticated_database(df)

    # An empty database only triggers a warning; the pipeline continues.
    if len(auth_db) == 0:
        print("\n⚠️  WARNING: No authenticated ASVs found!")

    # Calculate sample statistics
    asv_stats = calculate_sample_statistics(df)

    # PRE-FILTER: TECHNICAL ARTIFACTS
    print("\n" + "="*80)
    print("PRE-FILTER: TECHNICAL ARTIFACTS (reads < 4)")
    print("="*80)

    # Boolean mask over df: True for rows below the minimum read count.
    tech_mask = df['reads'] < THRESHOLDS['technical_artifacts_reads']
    n_tech = tech_mask.sum()

    print(f"\nFiltering reads < {THRESHOLDS['technical_artifacts_reads']}:")
    print(f"  Technical_Artifacts: {n_tech:,} ({n_tech/len(df)*100:.1f}%)")
    print(f"  Biological samples:  {(~tech_mask).sum():,} ({(~tech_mask).sum()/len(df)*100:.1f}%)")

    # Mark Technical_Artifacts: both the final and the phase-1 columns are
    # filled here with a fixed confidence of 0.95, because these rows are
    # not part of clean_df and so never pass through Phase 1 / Phase 2 below.
    df.loc[tech_mask, 'final_classification'] = 'Technical_Artifacts'
    df.loc[tech_mask, 'final_confidence'] = 0.95
    df.loc[tech_mask, 'phase1_class'] = 'Technical_Artifacts'
    df.loc[tech_mask, 'phase1_confidence'] = 0.95
    df.loc[tech_mask, 'phase1_method'] = 'Pre-filter'
    df.loc[tech_mask, 'phase1_reason'] = f'reads < {THRESHOLDS["technical_artifacts_reads"]} (minimum threshold)'
    df.loc[tech_mask, 'decision_method'] = 'Pre-filter'
    df.loc[tech_mask, 'needs_review'] = False

    # Get clean data: an independent copy of the rows that passed the
    # pre-filter. Its index labels are kept so results can be written back
    # to the matching rows of df.
    clean_df = df[~tech_mask].copy()
    clean_indices = clean_df.index

    print(f"\n  ✓ Pre-filter complete")
    print(f"  ✓ Continuing with {len(clean_df):,} biological samples")

    # Detect Intra-Species Variants
    intra_species_db = detect_intra_species_variants(clean_df, auth_db)

    # PHASE 1: RULE-BASED CLASSIFICATION
    print("\n" + "="*80)
    print("PHASE 1: RULE-BASED CLASSIFICATION (BIOLOGICAL)")
    print("="*80)

    phase1_results = []

    # Use a progress bar when tqdm is installed; otherwise iterate plainly.
    iterator = clean_df.iterrows()
    if TQDM_AVAILABLE:
        iterator = tqdm(clean_df.iterrows(), total=len(clean_df), desc="  Classifying")

    # One result per row, in clean_df row order. Each result is read below
    # by the keys 'class', 'confidence', 'method', 'reason' and, optionally,
    # 'subtype'.
    for idx, row in iterator:
        result = classify_phase1_rule_based(row, auth_db, asv_stats, intra_species_db)
        phase1_results.append(result)

    # Add to clean dataframe (positional: list order matches row order)
    clean_df['phase1_class'] = [r['class'] for r in phase1_results]
    clean_df['phase1_confidence'] = [r['confidence'] for r in phase1_results]
    clean_df['phase1_method'] = [r['method'] for r in phase1_results]
    clean_df['phase1_reason'] = [r['reason'] for r in phase1_results]
    clean_df['phase1_subtype'] = [r.get('subtype', '') for r in phase1_results]

    # Update main dataframe: only the rows that passed the pre-filter are
    # overwritten; the pre-filtered rows keep the values set above.
    for col in ['phase1_class', 'phase1_confidence', 'phase1_method', 'phase1_reason', 'phase1_subtype']:
        df.loc[clean_indices, col] = clean_df[col]

    # Phase 1 summary
    print(f"\nPhase 1 Summary (Biological classes):")

    # Per-class row count (non-null asv_id values) and mean confidence.
    class_dist = clean_df.groupby('phase1_class').agg({
        'asv_id': 'count',
        'phase1_confidence': 'mean'
    }).rename(columns={'asv_id': 'count', 'phase1_confidence': 'avg_confidence'})

    class_dist['percentage'] = (class_dist['count'] / len(clean_df)) * 100
    class_dist = class_dist.sort_values('count', ascending=False)

    print("\n" + class_dist.to_string())

    # Confidence levels (Phase 1 uses 0.90 / 0.70 cut-offs)
    high_conf = (clean_df['phase1_confidence'] >= 0.90).sum()
    med_conf = ((clean_df['phase1_confidence'] >= 0.70) & (clean_df['phase1_confidence'] < 0.90)).sum()
    low_conf = (clean_df['phase1_confidence'] < 0.70).sum()

    print(f"\nConfidence Distribution (Biological):")
    print(f"  High (≥0.90):       {high_conf:6,} ({high_conf/len(clean_df)*100:5.2f}%)")
    print(f"  Medium (0.70-0.89): {med_conf:6,} ({med_conf/len(clean_df)*100:5.2f}%)")
    print(f"  Low (<0.70):        {low_conf:6,} ({low_conf/len(clean_df)*100:5.2f}%)")

    # PHASE 2: ML TRAINING & PREDICTION
    print("\n" + "="*80)
    print("PHASE 2: ML TRAINING & PREDICTION (BIOLOGICAL)")
    print("="*80)

    # Extract features
    print("\nExtracting ML features (biological samples)...")

    bio_features = []

    iterator = clean_df.iterrows()
    if TQDM_AVAILABLE:
        iterator = tqdm(clean_df.iterrows(), total=len(clean_df), desc="  Extracting features")

    for idx, row in iterator:
        features = extract_ml_features(row, auth_db, asv_stats)
        bio_features.append(features)

    # Give the feature table the same index as clean_df so boolean masks
    # built from clean_df can be applied to it directly.
    features_df = pd.DataFrame(bio_features)
    features_df.index = clean_df.index

    print(f"\n  ✓ Extracted {len(features_df.columns)} features")

    # Training data: confidence ≥0.70 AND class ≠ Uncertain
    training_mask = (
        (clean_df['phase1_confidence'] >= 0.70) &
        (clean_df['phase1_class'] != 'Uncertain')
    )

    training_df = clean_df[training_mask].copy()
    training_features = features_df[training_mask]

    print(f"\nTraining samples: {len(training_df):,} (confidence ≥0.70, excluding Uncertain)")
    print(f"  Excluded Uncertain: {(clean_df['phase1_class'] == 'Uncertain').sum():,}")

    # Small training sets only produce a warning; training is still attempted.
    if len(training_df) < 100:
        print("\n⚠️  WARNING: Very few training samples!")

    # Train model
    clf, scaler, le, feature_importance = train_ml_model(training_df, training_features)

    if clf is None:
        # Fallback: no model was returned, so Phase 1 labels become final.
        # Rows with Phase 1 confidence below 0.70 are flagged for review.
        print("\n❌ ML training failed. Using Phase 1 results only.")
        df.loc[clean_indices, 'final_classification'] = clean_df['phase1_class']
        df.loc[clean_indices, 'final_confidence'] = clean_df['phase1_confidence']
        df.loc[clean_indices, 'decision_method'] = 'Rule_Only'
        df.loc[clean_indices, 'needs_review'] = clean_df['phase1_confidence'] < 0.70
    else:
        # Predict all biological samples (including Uncertain)
        clean_df = validate_with_ml(clean_df, clf, scaler, le, features_df)

        # Update main dataframe with whichever of these columns
        # validate_with_ml produced.
        for col in ['ml_prediction', 'ml_confidence', 'final_classification', 
                    'final_confidence', 'decision_method', 'needs_review']:
            if col in clean_df.columns:
                df.loc[clean_indices, col] = clean_df[col]

        # Add ML probabilities (every column named 'ml_prob_*')
        for col in clean_df.columns:
            if col.startswith('ml_prob_'):
                df.loc[clean_indices, col] = clean_df[col]

        # Save model: best effort - a failure to persist the bundle is
        # reported but does not stop the pipeline.
        try:
            joblib.dump({
                'clf': clf,
                'scaler': scaler,
                'le': le,
                'feature_names': features_df.columns.tolist(),
                'thresholds': THRESHOLDS
            }, OUTPUT_MODEL)
            print(f"\n✓ Model saved: {OUTPUT_MODEL}")
        except Exception as e:
            print(f"\n⚠️  Could not save model: {e}")

    # Post-validation fix
    df = post_validation_fix(df, auth_db)

    # Create sequence summary (NO MIXED)
    summary_df = create_sequence_summary_no_mixed(df)

    # Validation and statistics
    stats_df, issues = validate_and_create_statistics(df, summary_df, auth_db)

    # Export results
    export_results(df, summary_df, stats_df, feature_importance, auth_db, issues)

    # FINAL SUMMARY
    print("\n" + "="*80)
    print("FINAL SUMMARY")
    print("="*80)

    print(f"\nDataset:")
    print(f"  Total records: {len(df):,}")

    if 'project_readfile_id' in df.columns:
        print(f"  Unique samples: {df['project_readfile_id'].nunique():,}")

    if 'asv_id' in df.columns:
        print(f"  Unique ASVs: {df['asv_id'].nunique():,}")

    if 'reads' in df.columns:
        print(f"  Total reads: {df['reads'].sum():,.0f}")

    # n_tech is the pre-filter count computed above; clean_df holds the
    # rows that passed the pre-filter.
    print(f"\nPre-filter Results:")
    print(f"  Technical_Artifacts: {n_tech:,} ({n_tech/len(df)*100:.1f}%)")
    print(f"  Biological samples:  {len(clean_df):,} ({len(clean_df)/len(df)*100:.1f}%)")

    print(f"\nFinal Classification Results:")
    if 'final_classification' in df.columns:
        final_dist = df['final_classification'].value_counts()
        for cls, count in final_dist.items():
            pct = count / len(df) * 100
            avg_conf = df[df['final_classification'] == cls]['final_confidence'].mean()
            print(f"  {cls:40s}: {count:6,} ({pct:5.2f}%) [conf={avg_conf:.3f}]")

    # Check for Uncertain
    uncertain_final = (df['final_classification'] == 'Uncertain').sum()
    if uncertain_final > 0:
        print(f"\n  ⚠️  WARNING: Still have {uncertain_final:,} Uncertain in final!")
    else:
        print(f"\n  ✓ No Uncertain in final classification")

    if 'needs_review' in df.columns:
        review_count = df['needs_review'].sum()
        print(f"\nQuality Control:")
        print(f"  Needs review: {review_count:,} ({review_count/len(df)*100:.1f}%)")

    if 'final_confidence' in df.columns:
        high_conf_final = (df['final_confidence'] >= 0.80).sum()
        print(f"  High confidence (≥0.80): {high_conf_final:,}")

    # The feature-importance and model files are listed only when the
    # corresponding object exists (not None).
    print(f"\nOutput Files:")
    print(f"  ✓ {OUTPUT_CLASSIFIED}")
    print(f"  ✓ {OUTPUT_SEQUENCE_SUMMARY}")
    print(f"  ✓ {OUTPUT_STATISTICS}")
    print(f"  ✓ {OUTPUT_REPORT}")
    if feature_importance is not None:
        print(f"  ✓ {OUTPUT_FEATURE_IMPORTANCE}")
    if clf is not None:
        print(f"  ✓ {OUTPUT_MODEL}")

    print(f"\n{'='*80}")
    print("✓ PIPELINE COMPLETE!")
    print("="*80)

    print(f"\nFinished: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}")

    print(f"\nNext steps:")
    print(f"  1. Review flagged sequences in: {OUTPUT_CLASSIFIED}")
    print(f"  2. Check detailed report: {OUTPUT_REPORT}")
    print(f"  3. Use Authenticated sequences for downstream analysis")
    print(f"  4. Intra-Species Variants require genomic validation")
    print(f"  5. Consider reads < 4 as technical noise (excluded)")

    print("\n" + "="*80 + "\n")

# ========================
# RUN PIPELINE
# ========================

if __name__ == "__main__":
    # Script entry point: run the pipeline and turn any failure into a
    # readable console report followed by exit status 1.
    try:
        main()
    except KeyboardInterrupt:
        # User pressed Ctrl-C: say so and stop with a non-zero status.
        print(
            "\n\n⚠️  Pipeline interrupted by user",
            "Exiting gracefully...",
            sep="\n",
        )
        sys.exit(1)
    except Exception as e:
        # Header banner
        print(
            "\n\n" + "=" * 80,
            "❌ CRITICAL ERROR OCCURRED",
            "=" * 80,
            sep="\n",
        )

        # What went wrong
        print(f"\nError Type: {e.__class__.__name__}")
        print(f"Error Message: {e}")

        # Full traceback, framed by dashed rules
        print("\nFull Traceback:")
        print("-" * 80)
        import traceback
        traceback.print_exc()
        print("-" * 80)

        # Generic hints for the most common failure causes
        print(
            "\nTroubleshooting Tips:",
            "  1. Check if input file exists and is readable",
            "  2. Verify all required columns are present",
            "  3. Ensure sufficient disk space for output",
            "  4. Check write permissions for output directory",
            "  5. Review error message above for specific issue",
            "\n" + "=" * 80 + "\n",
            sep="\n",
        )
        sys.exit(1)


ASV CLASSIFICATION PIPELINE
Version 3.0 - Intra-Species Variants & No Mixed

Started: 2025-10-24 12:49:55

Loading data from: /Users/sarawut/Desktop/Manuscript_ASV_selection/data_analysis/sequences_analysis/ASV_Complete_Analysis.csv
  ✓ Loaded 175,955 rows
  ✓ Columns: 149

STEP 1: BUILDING AUTHENTICATED DATABASE

Criteria: autopropose='select' AND match='match'

  Found 15,901 authenticated instances

  ✓ Unique authenticated ASVs: 12,922
  ✓ Total authenticated instances: 15,901

  Authenticated ASVs per sample:
      1 sample(s): 11,133 ASVs
      2 sample(s): 1,244 ASVs
      3 sample(s):   314 ASVs
      4 sample(s):   104 ASVs
      5 sample(s):    55 ASVs
      6 sample(s):    25 ASVs
      7 sample(s):    14 ASVs
      8 sample(s):     7 ASVs
      9 sample(s):     5 ASVs
     10 sample(s):     5 ASVs

STEP 2: CALCULATING SAMPLE STATISTICS


  Processing ASVs: 100%|██████████| 64544/64544 [00:02<00:00, 21687.30it/s]



  ✓ Calculated statistics for 64,544 unique ASVs

PRE-FILTER: TECHNICAL ARTIFACTS (reads < 4)

Filtering reads < 4:
  Technical_Artifacts: 84,551 (48.1%)
  Biological samples:  91,404 (51.9%)

  ✓ Pre-filter complete
  ✓ Continuing with 91,404 biological samples

STEP 3: DETECTING INTRA-SPECIES VARIANTS

Method: Co-occurrence pattern analysis
Threshold: ≥2 co-occurrences with same authenticated ASV

  Found 61,577 unique (authenticated_ASV, secondary_ASV) pairs

  Strong evidence Intra-Species Variants:
    Secondary ASVs: 1,790
    Co-occurrence counts:
      Min: 2
      Max: 15
      Avg: 2.5

    Distribution:
      2 co-occurrences: 1,378 secondary ASVs
      3 co-occurrences: 256 secondary ASVs
      4 co-occurrences: 57 secondary ASVs
      5 co-occurrences: 39 secondary ASVs
      6 co-occurrences: 18 secondary ASVs
      7 co-occurrences: 13 secondary ASVs
      8 co-occurrences: 11 secondary ASVs
      9 co-occurrences: 6 secondary ASVs
      10 co-occurrences: 4 secondary ASVs

  Classifying: 100%|██████████| 91404/91404 [00:16<00:00, 5708.10it/s] 



Phase 1 Summary (Biological classes):

                             count  avg_confidence  percentage
phase1_class                                                  
Uncertain                    43807        0.500000   47.926787
Authenticated                15901        1.000000   17.396394
Environmental_Contamination  13767        0.823524   15.061704
Cross_Contamination          13163        0.940682   14.400901
Intra_Species_Variant         3359        0.735368    3.674894
Technical_Artifacts           1407        0.797655    1.539320

Confidence Distribution (Biological):
  High (≥0.90):       29,596 (32.38%)
  Medium (0.70-0.89): 18,001 (19.69%)
  Low (<0.70):        43,807 (47.93%)

PHASE 2: ML TRAINING & PREDICTION (BIOLOGICAL)

Extracting ML features (biological samples)...


  Extracting features: 100%|██████████| 91404/91404 [00:27<00:00, 3334.94it/s]



  ✓ Extracted 57 features

Training samples: 47,597 (confidence ≥0.70, excluding Uncertain)
  Excluded Uncertain: 43,807

STEP 7: TRAINING ML MODEL

Training data: 47,597 samples (excluding Uncertain)
Features: 57

Classes (5):
  Authenticated                      : 15,901 (33.41%)
  Cross_Contamination                : 13,163 (27.66%)
  Environmental_Contamination        : 13,767 (28.92%)
  Intra_Species_Variant              :  3,359 ( 7.06%)
  Technical_Artifacts                :  1,407 ( 2.96%)

Splitting with stratification...

Split:
  Train: 38,077
  Test:  9,520

Class weights:
  Authenticated                      : 0.599
  Cross_Contamination                : 0.723
  Environmental_Contamination        : 0.691
  Intra_Species_Variant              : 2.834
  Technical_Artifacts                : 6.766

Training Random Forest...
  ✓ Training complete

✓ Test Accuracy: 0.995

Cross-validation (5-fold)...
  ✓ CV Accuracy: 0.996 ± 0.001

Classification Report:
------------------------

  Processing ASVs: 100%|██████████| 64544/64544 [00:21<00:00, 2987.96it/s]



  ✓ Created summary for 64,544 unique ASVs
  ✓ Confirmed: No Mixed category

STEP 11: VALIDATION & STATISTICS

1. VALIDATION CHECKS:
  ✓ All 15,901 Authenticated have autopropose='select' AND match='match'
  ✓ All 13,163 Cross-Contamination in authenticated database
  ✓ Technical_Artifacts: 84,551 with reads <4, 17,513 other types
  ✓ No Uncertain in final classification (all classified by ML)
      Phylo distance distribution:
        <0.05: 9,573 (48.3%)
        0.05-0.15: 7,129 (35.9%)
      Phylo distance distribution:
        <0.05: 3,704 (14.8%)
        0.05-0.15: 1,108 (4.4%)
        ≥0.15: 20,184 (80.7%)

  Issues found:
  ❌ 3129 Intra-Species with phylo ≥ 0.15
  ❌ 515 Environmental with match='match'
  ❌ 3704 Environmental with phylo < 0.05

2. CLASSIFICATION STATISTICS:

  Per-Sample Level (175,955 total):

                              count  avg_confidence  total_reads  reads_percentage  percentage
final_classification                                                       