In [None]:
# ==================== Section 4: Comprehensive Accuracy Metrics and Uncertainty Analysis ====================
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (confusion_matrix, classification_report,
                           accuracy_score, f1_score, precision_score, recall_score,
                           roc_auc_score, roc_curve, precision_recall_curve,
                           average_precision_score, cohen_kappa_score, matthews_corrcoef)
import warnings
warnings.filterwarnings('ignore')

print("Comprehensive Accuracy Metrics and Uncertainty Analysis")
print("="*60)

# ==================== Load Real Data ====================
# Predictions are taken from the first source that yields data:
#   1. a `trainer` object already defined in this session (e.g. by the previous
#      notebook section),
#   2. `trainer` imported from a module named `main` (adjust as needed),
#   3. a CSV file of saved predictions,
#   4. synthetic demonstration data.
print("Loading real data from trained models...")


def collect_trainer_predictions(trainer_obj):
    """Concatenate the per-ecoregion test predictions stored on a trainer.

    Args:
        trainer_obj: Object whose ``results`` attribute maps an ecoregion id to a
            dict holding 'y_test' and 'y_pred_proba' (and optionally 'y_pred').
            A missing ``results`` attribute is treated as "no results".

    Returns:
        Tuple ``(y_true, y_pred_proba, y_pred)`` of NumPy arrays, empty when no
        entry has both 'y_test' and 'y_pred_proba'. Entries without 'y_pred' get
        hard predictions by thresholding 'y_pred_proba' at 0.5.
    """
    true_parts, proba_parts, pred_parts = [], [], []
    for result in getattr(trainer_obj, 'results', {}).values():
        if 'y_test' not in result or 'y_pred_proba' not in result:
            continue
        proba = list(result['y_pred_proba'])
        true_parts.extend(result['y_test'])
        proba_parts.extend(proba)
        if 'y_pred' in result:
            pred_parts.extend(result['y_pred'])
        else:
            pred_parts.extend(int(p >= 0.5) for p in proba)
    return np.array(true_parts), np.array(proba_parts), np.array(pred_parts)


y_true = y_pred_proba = y_pred = None

# In a notebook the previous section leaves `trainer` in the global namespace;
# only try the module import when it is not already there.
trainer_obj = globals().get('trainer')
if trainer_obj is None:
    try:
        from main import trainer as trainer_obj  # Adjust import as needed
    except ImportError:
        trainer_obj = None

if trainer_obj is None:
    print("Warning: Could not import trainer object from previous section.")
else:
    y_true, y_pred_proba, y_pred = collect_trainer_predictions(trainer_obj)
    if len(y_true) == 0:
        # An empty result set would break every metric below; use the fallbacks.
        print("Warning: trainer object has no stored test predictions.")
        y_true = y_pred_proba = y_pred = None
    else:
        print(f"Loaded real data: {len(y_true)} samples")
        print(f"Class distribution: Stable={np.sum(y_true==0)}, Disturbance={np.sum(y_true==1)}")

if y_true is None:
    print("Loading data from CSV file instead...")

    # Alternative: Load from CSV if you have saved predictions
    try:
        # Modify this path to your actual data file
        data_path = '/content/drive/MyDrive/predictions.csv'
        df = pd.read_csv(data_path)

        # Expected columns: 'y_true', 'y_pred_proba', 'y_pred'
        y_true = df['y_true'].values
        y_pred_proba = df['y_pred_proba'].values
        y_pred = df['y_pred'].values

        print(f"Loaded data from CSV: {len(y_true)} samples")
        print(f"Class distribution: Stable={np.sum(y_true==0)}, Disturbance={np.sum(y_true==1)}")

    except Exception as e:
        # Deliberate best-effort fallback: any failure reading the CSV drops
        # through to synthetic demonstration data.
        print(f"Error loading data: {e}")
        print("Using sample data for demonstration...")

        np.random.seed(42)
        n_samples = 1000

        # Stable samples draw probabilities from Beta(2, 8), disturbance samples
        # from Beta(8, 2); draws are made in sample order (reproducible with seed 42).
        y_true = np.random.choice([0, 1], n_samples, p=[0.7, 0.3])
        y_pred_proba = np.array([
            np.random.beta(8, 2) if label == 1 else np.random.beta(2, 8)
            for label in y_true
        ])

        y_pred = (y_pred_proba >= 0.5).astype(int)
        print(f"Using sample data: {n_samples} samples")

# Apply different thresholds
# A sample is predicted as Disturbance (class 1) when its probability is
# >= the threshold; class 0 is Stable.
thresholds = [0.3, 0.5, 0.7]
results = {}

for threshold in thresholds:
    y_pred_thresh = (y_pred_proba >= threshold).astype(int)

    # labels=[0, 1] always yields a 2x2 matrix laid out [[TN, FP], [FN, TP]].
    # Without it, input containing a single class gives a 1x1 matrix, and an
    # all-Disturbance input would have its TP count misread as TN.
    cm = confusion_matrix(y_true, y_pred_thresh, labels=[0, 1])
    TN, FP, FN, TP = cm.ravel()

    # Basic metrics (class 1 is the positive class)
    accuracy = accuracy_score(y_true, y_pred_thresh)
    precision = precision_score(y_true, y_pred_thresh, zero_division=0)
    recall = recall_score(y_true, y_pred_thresh, zero_division=0)
    f1 = f1_score(y_true, y_pred_thresh, zero_division=0)
    kappa = cohen_kappa_score(y_true, y_pred_thresh)
    mcc = matthews_corrcoef(y_true, y_pred_thresh)

    # Producer's Accuracy = 1 - Omission Error (per-class recall)
    pa_0 = TN / (TN + FP) if (TN + FP) > 0 else 0  # Stable class
    pa_1 = TP / (TP + FN) if (TP + FN) > 0 else 0  # Disturbance class

    # User's Accuracy = 1 - Commission Error (per-class precision)
    ua_0 = TN / (TN + FN) if (TN + FN) > 0 else 0  # Stable class
    ua_1 = TP / (TP + FP) if (TP + FP) > 0 else 0  # Disturbance class

    # Omission Error
    oe_0 = 1 - pa_0  # Stable class
    oe_1 = 1 - pa_1  # Disturbance class

    # Commission Error
    ce_0 = 1 - ua_0  # Stable class
    ce_1 = 1 - ua_1  # Disturbance class

    # Class-averaged accuracies (mean PA is the balanced accuracy)
    mean_pa = (pa_0 + pa_1) / 2
    mean_ua = (ua_0 + ua_1) / 2

    results[threshold] = {
        'threshold': threshold,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'kappa': kappa,
        'mcc': mcc,
        'pa_0': pa_0,
        'pa_1': pa_1,
        'ua_0': ua_0,
        'ua_1': ua_1,
        'oe_0': oe_0,
        'oe_1': oe_1,
        'ce_0': ce_0,
        'ce_1': ce_1,
        'mean_pa': mean_pa,
        'mean_ua': mean_ua,
        'cm': cm,
        'y_pred': y_pred_thresh
    }

class ComprehensiveAccuracyAnalysis:
    """Threshold-based accuracy and uncertainty analysis for a binary classifier.

    Class 0 is reported as "Stable"; class 1 ("Disturbance") is the positive
    class for precision/recall/F1 and for the probability scores.

    Args:
        y_true: Reference labels (0/1), array-like.
        y_pred_proba: Predicted probability of class 1 per sample, array-like
            aligned with ``y_true``.

    Attributes:
        results: ``threshold -> metrics dict`` filled by ``analyze_thresholds``.
    """

    # Forces confusion_matrix to return a 2x2 matrix (rows = true class,
    # columns = predicted class) even when a class is absent from the inputs.
    LABELS = (0, 1)
    # Thresholds whose F1 spread generate_report() summarises when
    # analyze_thresholds() has not been run.
    DEFAULT_SENSITIVITY_THRESHOLDS = (0.3, 0.5, 0.7)

    def __init__(self, y_true, y_pred_proba):
        # Coerce to arrays so boolean masks such as y_pred_proba[y_true == 0]
        # also work when plain lists are passed in.
        self.y_true = np.asarray(y_true)
        self.y_pred_proba = np.asarray(y_pred_proba)
        self.results = {}

    def analyze_thresholds(self, thresholds=None):
        """Compute metrics at each threshold and store them in ``self.results``.

        Args:
            thresholds: Iterable of probability cut-offs; a sample is predicted
                as class 1 when its probability is >= the cut-off. Defaults to
                ``np.arange(0.1, 1.0, 0.05)``.

        Returns:
            ``self.results`` (threshold -> dict from ``calculate_all_metrics``).
        """
        if thresholds is None:
            thresholds = np.arange(0.1, 1.0, 0.05)

        for threshold in thresholds:
            y_pred = (self.y_pred_proba >= threshold).astype(int)
            self.results[threshold] = self.calculate_all_metrics(self.y_true, y_pred, self.y_pred_proba)

        return self.results

    @staticmethod
    def _class_accuracies(cm):
        """Return per-class accuracies and errors derived from a 2x2 confusion matrix.

        ``cm`` uses sklearn's layout (rows = true class, columns = predicted
        class), so ``cm.ravel()`` is ``(TN, FP, FN, TP)`` with class 1 positive.
        Ratios whose denominator is 0 are reported as 0.
        """
        TN, FP, FN, TP = np.asarray(cm).ravel()

        def ratio(num, den):
            return num / den if den > 0 else 0

        acc = {
            # Producer's Accuracy = 1 - omission error (per-class recall)
            'pa_0': ratio(TN, TN + FP),
            'pa_1': ratio(TP, TP + FN),
            # User's Accuracy = 1 - commission error (per-class precision)
            'ua_0': ratio(TN, TN + FN),
            'ua_1': ratio(TP, TP + FP),
        }

        # Omission Error
        acc['oe_0'] = 1 - acc['pa_0']
        acc['oe_1'] = 1 - acc['pa_1']

        # Commission Error
        acc['ce_0'] = 1 - acc['ua_0']
        acc['ce_1'] = 1 - acc['ua_1']

        # Class-averaged accuracies
        acc['mean_pa'] = (acc['pa_0'] + acc['pa_1']) / 2
        acc['mean_ua'] = (acc['ua_0'] + acc['ua_1']) / 2

        # Balanced accuracy is the mean of per-class recalls, i.e. the mean PA.
        acc['balanced_accuracy'] = (acc['pa_0'] + acc['pa_1']) / 2
        return acc

    def calculate_all_metrics(self, y_true, y_pred, y_pred_proba):
        """Compute classification, ranking and per-class accuracy metrics.

        Args:
            y_true: Reference labels (0/1).
            y_pred: Hard predictions (0/1).
            y_pred_proba: Probability of class 1, used for AUC-ROC and AP.

        Returns:
            dict with 'accuracy', 'precision', 'recall', 'f1', 'kappa', 'mcc',
            'auc_roc', 'average_precision', 'cm' (2x2 ndarray) and the keys from
            ``_class_accuracies`` ('pa_*', 'ua_*', 'oe_*', 'ce_*', 'mean_pa',
            'mean_ua', 'balanced_accuracy').
        """
        metrics = {}

        # Basic metrics
        metrics['accuracy'] = accuracy_score(y_true, y_pred)
        metrics['precision'] = precision_score(y_true, y_pred, zero_division=0)
        metrics['recall'] = recall_score(y_true, y_pred, zero_division=0)
        metrics['f1'] = f1_score(y_true, y_pred, zero_division=0)
        metrics['kappa'] = cohen_kappa_score(y_true, y_pred)
        metrics['mcc'] = matthews_corrcoef(y_true, y_pred)

        # AUC metrics are undefined when y_true holds a single class; 0.5 is
        # stored as a placeholder in that case.
        if len(np.unique(y_true)) > 1:
            metrics['auc_roc'] = roc_auc_score(y_true, y_pred_proba)
            metrics['average_precision'] = average_precision_score(y_true, y_pred_proba)
        else:
            metrics['auc_roc'] = 0.5
            metrics['average_precision'] = 0.5

        # Fixed labels keep the matrix 2x2, so the per-class metrics always exist.
        cm = confusion_matrix(y_true, y_pred, labels=self.LABELS)
        metrics['cm'] = cm
        metrics.update(self._class_accuracies(cm))

        return metrics

    def plot_comprehensive_analysis(self, optimal_threshold=0.5):
        """Draw the 3x3 analysis figure for ``optimal_threshold`` and return it.

        Panels: confusion matrix with PA/UA/OE/CE text, ROC curve, PR curve,
        accuracy/F1 versus threshold, probability histograms per true class,
        a radar chart of scalar metrics, and the uncertainty pie (bottom row).
        """
        # Use results with optimal threshold
        y_pred_opt = (self.y_pred_proba >= optimal_threshold).astype(int)
        metrics_opt = self.calculate_all_metrics(self.y_true, y_pred_opt, self.y_pred_proba)

        fig = plt.figure(figsize=(18, 12))

        # Create subplot grid
        gs = fig.add_gridspec(3, 3)

        # 1. Confusion Matrix (always 2x2, matching the two tick labels)
        ax1 = fig.add_subplot(gs[0, 0])
        cm = metrics_opt['cm']
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax1,
                   xticklabels=['Stable', 'Disturbance'], yticklabels=['Stable', 'Disturbance'])
        ax1.set_title(f'Confusion Matrix (Threshold={optimal_threshold})', fontsize=12, fontweight='bold')
        ax1.set_xlabel('Predicted Label')
        ax1.set_ylabel('True Label')

        # Add accuracy metrics text
        cm_text = f"""
Producer's Accuracy (PA):
Stable: {metrics_opt.get('pa_0', 0):.3f}
Disturbance: {metrics_opt.get('pa_1', 0):.3f}

User's Accuracy (UA):
Stable: {metrics_opt.get('ua_0', 0):.3f}
Disturbance: {metrics_opt.get('ua_1', 0):.3f}

Omission Error (OE):
Stable: {metrics_opt.get('oe_0', 0):.3f}
Disturbance: {metrics_opt.get('oe_1', 0):.3f}

Commission Error (CE):
Stable: {metrics_opt.get('ce_0', 0):.3f}
Disturbance: {metrics_opt.get('ce_1', 0):.3f}
"""
        # Heatmap data coordinates: the matrix spans x in [0, 2], so x=2.8 is to its right.
        ax1.text(2.8, 0.5, cm_text, fontsize=9, verticalalignment='center',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

        # 2. ROC Curve
        ax2 = fig.add_subplot(gs[0, 1])
        fpr, tpr, _ = roc_curve(self.y_true, self.y_pred_proba)
        ax2.plot(fpr, tpr, label=f'ROC Curve (AUC={metrics_opt.get("auc_roc", 0):.3f})', linewidth=2)
        ax2.plot([0, 1], [0, 1], 'k--', label='Random Guess')
        ax2.set_xlabel('False Positive Rate (FPR)')
        ax2.set_ylabel('True Positive Rate (TPR)')
        ax2.set_title('ROC Curve', fontsize=12, fontweight='bold')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        # 3. PR Curve
        ax3 = fig.add_subplot(gs[0, 2])
        precision_curve, recall_curve, _ = precision_recall_curve(self.y_true, self.y_pred_proba)
        ax3.plot(recall_curve, precision_curve,
                label=f'PR Curve (AP={metrics_opt.get("average_precision", 0):.3f})',
                linewidth=2)
        ax3.set_xlabel('Recall')
        ax3.set_ylabel('Precision')
        ax3.set_title('Precision-Recall Curve', fontsize=12, fontweight='bold')
        ax3.legend()
        ax3.grid(True, alpha=0.3)

        # 4. Threshold Analysis
        ax4 = fig.add_subplot(gs[1, 0])
        thresholds = np.arange(0.1, 1.0, 0.05)
        accuracies = []
        f1_scores = []

        for thresh in thresholds:
            y_pred = (self.y_pred_proba >= thresh).astype(int)
            accuracies.append(accuracy_score(self.y_true, y_pred))
            f1_scores.append(f1_score(self.y_true, y_pred, zero_division=0))

        ax4.plot(thresholds, accuracies, 'b-', label='Accuracy', linewidth=2)
        ax4.plot(thresholds, f1_scores, 'r-', label='F1 Score', linewidth=2)
        ax4.axvline(x=optimal_threshold, color='g', linestyle='--',
                   label=f'Optimal Threshold={optimal_threshold}')
        ax4.set_xlabel('Threshold')
        ax4.set_ylabel('Score')
        ax4.set_title('Threshold Sensitivity Analysis', fontsize=12, fontweight='bold')
        ax4.legend()
        ax4.grid(True, alpha=0.3)

        # 5. Prediction Probability Distribution
        ax5 = fig.add_subplot(gs[1, 1])
        ax5.hist(self.y_pred_proba[self.y_true==0], bins=30, alpha=0.7,
                label='Stable (True)', color='blue', density=True)
        ax5.hist(self.y_pred_proba[self.y_true==1], bins=30, alpha=0.7,
                label='Disturbance (True)', color='red', density=True)
        ax5.axvline(x=optimal_threshold, color='g', linestyle='--',
                   label=f'Threshold={optimal_threshold}')
        ax5.set_xlabel('Predicted Probability')
        ax5.set_ylabel('Density')
        ax5.set_title('Prediction Probability Distribution', fontsize=12, fontweight='bold')
        ax5.legend()

        # 6. Radar Chart of Accuracy Metrics
        # NOTE: the radial axis is limited to [0, 1], so negative Kappa/MCC
        # values are clipped in this chart.
        ax6 = fig.add_subplot(gs[1, 2], polar=True)
        metrics_names = ['Accuracy', 'F1', 'Precision', 'Recall', 'Kappa', 'MCC']
        metrics_values = [metrics_opt.get('accuracy', 0), metrics_opt.get('f1', 0),
                         metrics_opt.get('precision', 0), metrics_opt.get('recall', 0),
                         metrics_opt.get('kappa', 0), metrics_opt.get('mcc', 0)]

        # Repeat the first point so the polygon closes.
        angles = np.linspace(0, 2*np.pi, len(metrics_names), endpoint=False).tolist()
        metrics_values = metrics_values + [metrics_values[0]]
        angles = angles + [angles[0]]

        ax6.plot(angles, metrics_values, 'o-', linewidth=2)
        ax6.fill(angles, metrics_values, alpha=0.25)
        ax6.set_thetagrids(np.degrees(angles[:-1]), metrics_names)
        ax6.set_ylim(0, 1)
        ax6.set_title('Accuracy Metrics Radar Chart', fontsize=12, fontweight='bold', y=1.1)

        # 7. Uncertainty Analysis
        ax7 = fig.add_subplot(gs[2, :])
        self.plot_uncertainty_analysis(ax7, optimal_threshold)

        plt.suptitle('Comprehensive Accuracy Metrics and Uncertainty Analysis',
                    fontsize=16, fontweight='bold', y=0.98)
        plt.tight_layout()
        plt.show()

        return fig

    def _uncertainty_counts(self, threshold, high_confidence=0.6):
        """Count samples in each confidence/correctness zone.

        Confidence is ``|p - 0.5| * 2`` (0 at p=0.5, 1 at p=0 or p=1) and does not
        depend on ``threshold``; ``threshold`` only sets the predicted class used
        to judge correctness. Samples with confidence >= ``high_confidence`` are
        split into correct/error; all others count as 'Low Confidence'.
        """
        proba = self.y_pred_proba
        confidence = np.abs(proba - 0.5) * 2
        is_correct = (proba >= threshold).astype(int) == self.y_true
        is_high = confidence >= high_confidence
        return {
            'High Confidence Correct': int(np.count_nonzero(is_high & is_correct)),
            'High Confidence Error': int(np.count_nonzero(is_high & ~is_correct)),
            'Low Confidence': int(np.count_nonzero(~is_high)),
        }

    def plot_uncertainty_analysis(self, ax, threshold):
        """Draw a confidence-zone pie chart and an uncertainty-source text box on ``ax``.

        Args:
            ax: Matplotlib Axes to draw into.
            threshold: Cut-off deciding the predicted class when judging
                correctness (see ``_uncertainty_counts``).
        """
        uncertainty_zones = self._uncertainty_counts(threshold)
        total = sum(uncertainty_zones.values())

        if total == 0:
            # No samples: nothing to plot, and the percentages below would divide by zero.
            ax.set_axis_off()
            ax.set_title('Prediction Uncertainty Analysis', fontsize=12, fontweight='bold')
            return

        # Plot pie chart
        labels = list(uncertainty_zones.keys())
        sizes = list(uncertainty_zones.values())
        colors = ['lightgreen', 'lightcoral', 'lightblue']
        explode = (0.1, 0.1, 0.1)

        _, _, autotexts = ax.pie(sizes, explode=explode, labels=labels,
                                 colors=colors, autopct='%1.1f%%',
                                 shadow=True, startangle=90)

        for autotext in autotexts:
            autotext.set_color('black')
            autotext.set_fontweight('bold')

        ax.set_title('Prediction Uncertainty Analysis', fontsize=12, fontweight='bold')

        # Add uncertainty sources text
        uncertainty_text = f"""
Uncertainty Source Analysis:

1. Training Sample Uncertainty:
   • Some samples from Hansen data have positional and classification errors
   • Uneven sample distribution across ecoregions

2. Image Data Uncertainty:
   • Cloud and snow contamination in annual composite images
   • Sensor differences and temporal inconsistencies

3. Model Uncertainty:
   • Sensitivity to threshold selection
   • Feature differences across ecoregions
   • Time series discontinuity issues

4. Application Uncertainty:
   • Mixed pixel effects
   • Progressive disturbance year attribution
   • Small patch detection limitations

Total Samples: {total}
High Confidence Correct: {uncertainty_zones['High Confidence Correct']} ({uncertainty_zones['High Confidence Correct']/total*100:.1f}%)
High Confidence Error: {uncertainty_zones['High Confidence Error']} ({uncertainty_zones['High Confidence Error']/total*100:.1f}%)
Low Confidence Regions: {uncertainty_zones['Low Confidence']} ({uncertainty_zones['Low Confidence']/total*100:.1f}%)
"""

        # Right of the pie in axes coordinates, vertically centred on it
        # (y=0 with va='center' would hang half the box below the axes).
        ax.text(1.8, 0.5, uncertainty_text, transform=ax.transAxes, fontsize=9,
               verticalalignment='center',
               bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

    def generate_report(self, optimal_threshold=0.5):
        """Return a plain-text accuracy report for ``optimal_threshold``.

        "Optimal Threshold" in the report is simply the value passed in; no
        optimisation is performed here. The "F1 Score Variation with Threshold"
        line is the standard deviation of F1 over ``self.results`` when
        ``analyze_thresholds`` has been run, otherwise over
        ``DEFAULT_SENSITIVITY_THRESHOLDS``.
        """
        y_pred_opt = (self.y_pred_proba >= optimal_threshold).astype(int)
        metrics_opt = self.calculate_all_metrics(self.y_true, y_pred_opt, self.y_pred_proba)

        # Self-contained: does not depend on any module-level results dict.
        sensitivity = self.results or {
            t: self.calculate_all_metrics(self.y_true,
                                          (self.y_pred_proba >= t).astype(int),
                                          self.y_pred_proba)
            for t in self.DEFAULT_SENSITIVITY_THRESHOLDS
        }
        f1_std = np.std([m.get('f1', 0) for m in sensitivity.values()])

        report = f"""
Comprehensive Accuracy Analysis Report
{'='*60}

1. Overall Performance
{'='*30}
Overall Accuracy (OA): {metrics_opt.get('accuracy', 0):.4f}
F1 Score: {metrics_opt.get('f1', 0):.4f}
Kappa Coefficient: {metrics_opt.get('kappa', 0):.4f}
Matthews Correlation Coefficient: {metrics_opt.get('mcc', 0):.4f}
AUC-ROC: {metrics_opt.get('auc_roc', 0):.4f}
Average Precision (AP): {metrics_opt.get('average_precision', 0):.4f}

2. Classification Accuracy Breakdown
{'='*30}
2.1 Producer's Accuracy (Complement of Omission Error):
   • Stable class: {metrics_opt.get('pa_0', 0):.4f} (Omission Error: {metrics_opt.get('oe_0', 0):.4f})
   • Disturbance class: {metrics_opt.get('pa_1', 0):.4f} (Omission Error: {metrics_opt.get('oe_1', 0):.4f})
   • Mean Producer's Accuracy: {metrics_opt.get('mean_pa', 0):.4f}

2.2 User's Accuracy (Complement of Commission Error):
   • Stable class: {metrics_opt.get('ua_0', 0):.4f} (Commission Error: {metrics_opt.get('ce_0', 0):.4f})
   • Disturbance class: {metrics_opt.get('ua_1', 0):.4f} (Commission Error: {metrics_opt.get('ce_1', 0):.4f})
   • Mean User's Accuracy: {metrics_opt.get('mean_ua', 0):.4f}

3. Threshold Sensitivity Analysis
{'='*30}
Optimal Threshold: {optimal_threshold}
F1 Score Variation with Threshold: ±{f1_std:.4f}

4. Uncertainty Assessment
{'='*30}
4.1 Data Uncertainty:
   • Training sample positional errors
   • Image data quality issues
   • Ecoregion feature differences

4.2 Model Uncertainty:
   • Threshold selection sensitivity
   • Time series analysis uncertainty
   • Small patch detection limitations

4.3 Application Uncertainty:
   • Progressive disturbance year attribution
   • Mixed pixel effects
   • Spatial scale conversion issues

5. Recommendations
{'='*30}
1. Report both Producer's and User's Accuracy, not just overall accuracy
2. Provide confusion matrices as supplementary material
3. Discuss the impact of threshold selection on results
4. Analyze main sources of uncertainty
5. Consider providing uncertainty maps

{'='*60}
Report Generated: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}
Total Samples Analyzed: {len(self.y_true)}
"""

        return report

# Execute comprehensive analysis
print("\nExecuting comprehensive accuracy analysis...")
analyzer = ComprehensiveAccuracyAnalysis(y_true, y_pred_proba)

# Threshold used for the plots and the report (a fixed choice, not optimised here).
optimal_threshold = 0.5

# Analyze different thresholds. analyze_thresholds() also keeps the metrics on
# the analyzer; the matching module-level `results` entries are replaced by
# these metric dicts (which do not carry the per-sample 'y_pred' arrays).
thresholds = [0.3, 0.5, 0.7]
results.update(analyzer.analyze_thresholds(thresholds))

# Plot comprehensive analysis visualizations
print("\nGenerating analysis visualizations...")
fig = analyzer.plot_comprehensive_analysis(optimal_threshold=optimal_threshold)

# Generate report
print("\nGenerating analysis report...")
report = analyzer.generate_report(optimal_threshold=optimal_threshold)
print(report)

# Save results
print("\nSaving analysis results...")
import json
import math
import os

# Create directory if it doesn't exist
output_dir = 'analysis_results'
os.makedirs(output_dir, exist_ok=True)


def to_jsonable(value):
    """Recursively convert NumPy values into plain Python objects for ``json.dump``.

    Arrays become (nested) lists and NumPy scalars become Python scalars.
    Non-finite floats (e.g. a NaN kappa) become None, so the file is strict JSON.
    Dicts, lists and tuples are converted element-wise.
    """
    if isinstance(value, dict):
        return {str(k): to_jsonable(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [to_jsonable(v) for v in value]
    if isinstance(value, np.ndarray):
        return to_jsonable(value.tolist())
    if isinstance(value, np.generic):
        value = value.item()
    if isinstance(value, float) and not math.isfinite(value):
        return None
    return value


# Save metrics results ('cm' is written after the scalar metrics in each entry)
metrics_summary = {}
for threshold, metrics in results.items():
    save_metrics = {k: v for k, v in metrics.items() if k != 'cm'}
    if 'cm' in metrics:
        save_metrics['cm'] = metrics['cm']
    metrics_summary[str(threshold)] = to_jsonable(save_metrics)

with open(os.path.join(output_dir, 'comprehensive_metrics.json'), 'w', encoding='utf-8') as f:
    json.dump(metrics_summary, f, indent=2)

# Save report (UTF-8 because it contains non-ASCII characters such as '•' and '±')
with open(os.path.join(output_dir, 'accuracy_analysis_report.txt'), 'w', encoding='utf-8') as f:
    f.write(report)

# Save figure
fig.savefig(os.path.join(output_dir, 'comprehensive_analysis.png'), dpi=300, bbox_inches='tight')

print("✓ Analysis results saved:")
print("  - analysis_results/comprehensive_metrics.json")
print("  - analysis_results/accuracy_analysis_report.txt")
print("  - analysis_results/comprehensive_analysis.png")

# Closing guidance for the manuscript revision, framed by 60-character rules.
separator = "=" * 60

revision_suggestions = """
For addressing reviewer comments on accuracy metrics:

1. Report in Results Section:
   • Producer's Accuracy and User's Accuracy
   • Omission Error and Commission Error
   • Balanced Accuracy

2. Specify in Methods Section:
   • Threshold used and justification for threshold selection
   • Calculation methods for accuracy metrics

3. Analyze in Discussion Section:
   • Main sources of error
   • Impact of uncertainty on results
   • Sensitivity to threshold selection

4. Provide in Supplementary Materials:
   • Complete confusion matrices
   • Detailed accuracy metrics for each ecoregion
   • Uncertainty analysis results
   • Threshold sensitivity analysis
"""

for line in (f"\n{separator}", "Manuscript Revision Suggestions", separator, revision_suggestions):
    print(line)

for line in (f"\n{separator}", "Comprehensive Accuracy Analysis Completed!", separator):
    print(line)