# Classification: Batch-Aware Model Training and Evaluation

## Overview

This notebook demonstrates the classification pipeline for DNA methylation-based prediction of HIIT response. We train batch-aware classifiers and evaluate them with repeated stratified cross-validation, ROC and confusion-matrix analysis, feature importance, and comparisons across preprocessing versions.

### Key Components

1. **Batch-Aware Classifier**: Models that handle batch effects as covariates
2. **Binary Classification**: HIIT intervention vs Control/Baseline
3. **Multiclass Classification**: Training duration (4W/8W/12W)
4. **Multi-Version Comparison**: Evaluate models across different preprocessing versions

### Learning Objectives

By the end of this notebook, you will be able to:

1. Configure and train batch-aware classifiers
2. Perform stratified cross-validation with proper sample handling
3. Evaluate model performance with multiple metrics
4. Compare models across different data preprocessing versions

## 1. Environment Setup

In [None]:
# Standard library imports
import sys
import logging
from pathlib import Path

# Scientific computing
import numpy as np
import pandas as pd

# Machine learning
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, confusion_matrix, classification_report
)

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Add project root to path
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

# Project-specific imports - Classification
from src.models import (
    ClassifierConfig,
    BatchAwareClassifier,
    HIITClassificationPipeline,
    ModelEvaluator,
    CrossValidationStrategy,
    MultiVersionComparator
)

# Visualization
from src.visualization import (
    plot_roc_curve,
    plot_roc_curves,
    plot_confusion_matrix,
    plot_feature_importance
)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Visualization settings
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('colorblind')

print(f"Project root: {project_root}")

## 2. Load Data and Features

In [None]:
import pickle

# Define paths
processed_dir = project_root / 'data' / 'processed'
features_dir = processed_dir / 'features'
models_dir = project_root / 'models'
figures_dir = project_root / 'data' / 'figures' / 'binary'

# Create output directories
models_dir.mkdir(parents=True, exist_ok=True)
figures_dir.mkdir(parents=True, exist_ok=True)

# Load preprocessed data versions
with open(processed_dir / 'methyl_data_versions.pkl', 'rb') as f:
    data_versions = pickle.load(f)

# Load sample mapping
sample_mapping = pd.read_csv(
    project_root / 'data' / 'raw' / 'GSE171140_sample_mapping.csv'
)

print("Loaded data versions:")
for name, data in data_versions.items():
    print(f"  {name}: {data.shape}")

In [None]:
# Load selected features
binary_features = pd.read_csv(
    features_dir / 'binary_features_L5_moderate.csv'
)['probe_id'].tolist()

multiclass_features = pd.read_csv(
    features_dir / 'multiclass_features_L5_moderate.csv'
)['probe_id'].tolist()

print(f"Binary features: {len(binary_features)}")
print(f"Multiclass features: {len(multiclass_features)}")

## 3. Prepare Classification Data

In [None]:
# Use the standardized version for classification
methylation_data = data_versions['standardized']

# Align samples
sample_ids = methylation_data.columns.tolist()
sample_info = sample_mapping.set_index('sample_id').loc[sample_ids].reset_index()

# Prepare binary classification data
binary_mask = sample_info['binary_class'].isin(['HIIT', 'Control'])
binary_samples = sample_info[binary_mask]['sample_id'].tolist()
binary_labels = (sample_info[binary_mask]['binary_class'] == 'HIIT').astype(int).values

# Extract batch information for batch-aware modeling
batch_info = sample_info[binary_mask]['study_group'].values

print(f"Binary classification samples: {len(binary_samples)}")
print(f"  HIIT: {sum(binary_labels)}, Control: {len(binary_labels) - sum(binary_labels)}")
print(f"  Batches: {np.unique(batch_info)}")
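The alignment line above (`sample_mapping.set_index('sample_id').loc[sample_ids]`) fails with a bare `KeyError` if any methylation column is missing from the sample mapping. A minimal sketch of a more explicit check follows, using a hypothetical helper `align_sample_info` and toy data. It is not part of `src`.

```python
import pandas as pd


def align_sample_info(methyl: pd.DataFrame, mapping: pd.DataFrame,
                      id_col: str = "sample_id") -> pd.DataFrame:
    """Hypothetical helper: return mapping rows in the order of methyl.columns.

    Reports sample IDs missing from the mapping instead of raising a bare KeyError.
    """
    indexed = mapping.set_index(id_col)
    missing = [s for s in methyl.columns if s not in indexed.index]
    if missing:
        raise ValueError(f"{len(missing)} samples missing from mapping: {missing[:5]}")
    # .loc with a list keeps the column order of the methylation matrix
    return indexed.loc[list(methyl.columns)].rename_axis(id_col).reset_index()


# Toy example: probes x samples, with columns in a different order than the mapping
methyl = pd.DataFrame([[0.1, 0.2], [0.3, 0.4]], index=["cg1", "cg2"], columns=["S2", "S1"])
mapping = pd.DataFrame({"sample_id": ["S1", "S2", "S3"],
                        "binary_class": ["HIIT", "Control", "HIIT"]})
print(align_sample_info(methyl, mapping))
```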

In [None]:
# Create feature matrix for binary classification
# Filter to selected features
available_features = [f for f in binary_features if f in methylation_data.index]
X_binary = methylation_data.loc[available_features, binary_samples].T.values
y_binary = binary_labels

print(f"Feature matrix shape: {X_binary.shape}")
print(f"  Samples: {X_binary.shape[0]}")
print(f"  Features: {X_binary.shape[1]}")
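The probe list used here comes from `binary_features_L5_moderate.csv`, which was produced outside this notebook. If those probes were chosen using all samples, the held-out folds used later have already influenced feature selection. The scikit-learn sketch below (not the project's `src.models` API) illustrates the effect on pure noise. Selecting 20 of 2,000 random features on the full data before cross-validation gives a high AUC, while refitting the selection inside each training fold with a `Pipeline` gives an AUC near 0.5. Sample sizes and `k=20` are arbitrary illustration choices.

```python
import numpy as np
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import make_pipeline

rng = np.random.default_rng(0)
X_noise = rng.normal(size=(80, 2000))   # 80 samples, 2000 pure-noise "probes"
y_noise = np.repeat([0, 1], 40)         # labels carry no real signal

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)

# Leaky: pick the 20 "best" features using ALL samples, then cross-validate
X_leaky = SelectKBest(f_classif, k=20).fit_transform(X_noise, y_noise)
leaky_auc = cross_val_score(LogisticRegression(max_iter=1000), X_leaky, y_noise,
                            cv=cv, scoring="roc_auc").mean()

# Honest: selection is refit on the training folds only, inside the pipeline
honest_model = make_pipeline(SelectKBest(f_classif, k=20),
                             LogisticRegression(max_iter=1000))
honest_auc = cross_val_score(honest_model, X_noise, y_noise,
                             cv=cv, scoring="roc_auc").mean()

print(f"leaky AUC:  {leaky_auc:.2f}")   # typically far above 0.5
print(f"honest AUC: {honest_auc:.2f}")  # typically close to 0.5
```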

## 4. Configure Classifier

The `ClassifierConfig` defines the classifier type and hyperparameters.
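`ClassifierConfig` lives in the project's `src.models` package, and its internals are not shown here. The sketch below illustrates what such a config typically resolves to by mapping a `classifier_type` string to scikit-learn estimators. The `make_estimator` helper is hypothetical. The `'xgboost'` option is omitted because it needs the separate `xgboost` package.

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC


def make_estimator(classifier_type: str, random_state: int = 42, **params):
    """Hypothetical factory: turn a config-style type string into an sklearn estimator."""
    if classifier_type == "random_forest":
        return RandomForestClassifier(random_state=random_state, **params)
    if classifier_type == "logistic":
        return LogisticRegression(max_iter=1000, random_state=random_state, **params)
    if classifier_type == "svm":
        # probability=True enables predict_proba, which this notebook uses for ROC curves
        return SVC(probability=True, random_state=random_state, **params)
    raise ValueError(f"Unsupported classifier_type: {classifier_type!r}")


rf = make_estimator("random_forest", n_estimators=100, max_depth=10)
print(rf.get_params()["max_depth"])  # 10
```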

In [None]:
# Configure classifier
config = ClassifierConfig(
    classifier_type='random_forest',  # Options: 'logistic', 'svm', 'random_forest', 'xgboost'
    n_estimators=100,
    max_depth=10,
    random_state=42
)

print("Classifier Configuration:")
print(f"  Type: {config.classifier_type}")
print(f"  Parameters: n_estimators={config.n_estimators}, max_depth={config.max_depth}")

## 5. Batch-Aware Classification

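The `BatchAwareClassifier` takes batch labels (here, the `study_group` column) alongside the features to reduce confounding between batch and outcome. The strategy is selected with `batch_handling`.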
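How `BatchAwareClassifier` uses batch labels is internal to `src.models`. One common reading of `batch_handling='covariate'` is to append one-hot batch indicators to the feature matrix, so batch membership becomes an explicit model input. The sketch shows that idea with a hypothetical `add_batch_covariates` helper. This is an assumption, not the project's implementation. Note that when batch is strongly associated with the class label, a batch covariate can let the model predict from batch alone.

```python
import numpy as np
import pandas as pd


def add_batch_covariates(X: np.ndarray, batch, categories=None) -> np.ndarray:
    """Hypothetical helper: append one-hot batch indicators as extra columns.

    Passing a fixed `categories` list keeps the column layout identical
    between training and prediction data.
    """
    batch_cat = pd.Categorical(batch, categories=categories)
    dummies = pd.get_dummies(batch_cat, dtype=float).to_numpy()
    return np.hstack([X, dummies])


X_demo = np.zeros((4, 2))
batch_demo = np.array(["A", "B", "A", "C"])
X_aug = add_batch_covariates(X_demo, batch_demo, categories=["A", "B", "C"])
print(X_aug.shape)  # (4, 5)
```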

In [None]:
# Initialize batch-aware classifier
batch_classifier = BatchAwareClassifier(
    config=config,
    batch_handling='covariate'  # Options: 'covariate', 'stratified', 'none'
)

print("Batch-Aware Classifier initialized")
print(f"  Batch handling strategy: {batch_classifier.batch_handling}")

In [None]:
# Fit the classifier with batch information
batch_classifier.fit(X_binary, y_binary, batch=batch_info)

# Predictions on the training samples (resubstitution): optimistic, use only as a sanity check
train_predictions = batch_classifier.predict(X_binary)
train_probabilities = batch_classifier.predict_proba(X_binary)[:, 1]

print("\nTraining Performance (resubstitution, optimistic):")
print(f"  Accuracy: {accuracy_score(y_binary, train_predictions):.3f}")
print(f"  AUC-ROC: {roc_auc_score(y_binary, train_probabilities):.3f}")
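Training-set metrics measure how well the model fits the samples it has already seen, not how well it generalizes. A Random Forest with deep trees can reach a near-perfect training AUC even when the labels are pure noise, as this small scikit-learn sketch on synthetic data shows. Sizes and settings are arbitrary.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold, cross_val_score

rng = np.random.default_rng(42)
X_rand = rng.normal(size=(120, 200))               # features unrelated to labels
y_rand = rng.permutation(np.repeat([0, 1], 60))    # balanced random labels

rf_noise = RandomForestClassifier(n_estimators=100, max_depth=10, random_state=42)
rf_noise.fit(X_rand, y_rand)
train_auc = roc_auc_score(y_rand, rf_noise.predict_proba(X_rand)[:, 1])

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
cv_auc = cross_val_score(rf_noise, X_rand, y_rand, cv=cv, scoring="roc_auc").mean()

print(f"training AUC: {train_auc:.2f}")  # close to 1.0
print(f"CV AUC:       {cv_auc:.2f}")     # close to 0.5
```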

## 6. Cross-Validation Evaluation

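We use repeated stratified cross-validation to estimate performance on held-out samples, which is far less optimistic than the training-set metrics above. If the input features were selected using all samples, however, these estimates remain optimistically biased.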
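`CrossValidationStrategy` is project code. As an approximation of the summary it returns (`accuracy_mean`, `auc_mean`, `f1_mean` and per-split `*_scores`), the sketch below runs 5-fold stratified cross-validation repeated 10 times with scikit-learn's `RepeatedStratifiedKFold` and `cross_validate`. The `repeated_cv_summary` function and its keys are assumptions modeled on the keys used in this notebook. Because repeated splits share samples, the standard deviation across them understates the true uncertainty of the estimate.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RepeatedStratifiedKFold, cross_validate


def repeated_cv_summary(model, X, y, n_splits=5, n_repeats=10, random_state=42):
    """Hypothetical stand-in: mean/std of accuracy, ROC AUC and F1 over repeated stratified folds."""
    cv = RepeatedStratifiedKFold(n_splits=n_splits, n_repeats=n_repeats,
                                 random_state=random_state)
    scores = cross_validate(model, X, y, cv=cv,
                            scoring={"accuracy": "accuracy", "auc": "roc_auc", "f1": "f1"})
    summary = {}
    for name in ("accuracy", "auc", "f1"):
        values = scores[f"test_{name}"]   # one score per (fold, repeat)
        summary[f"{name}_scores"] = values
        summary[f"{name}_mean"] = float(np.mean(values))
        summary[f"{name}_std"] = float(np.std(values))
    return summary


X_demo, y_demo = make_classification(n_samples=100, n_features=20, random_state=0)
res = repeated_cv_summary(RandomForestClassifier(n_estimators=50, random_state=0), X_demo, y_demo)
print(len(res["auc_scores"]))  # 50 = 5 folds x 10 repeats
```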

In [None]:
# Configure cross-validation strategy
cv_strategy = CrossValidationStrategy(
    n_splits=5,
    n_repeats=10,
    random_state=42
)

print("Cross-Validation Configuration:")
print(f"  Folds: {cv_strategy.n_splits}")
print(f"  Repeats: {cv_strategy.n_repeats}")
print(f"  Total iterations: {cv_strategy.n_splits * cv_strategy.n_repeats}")

In [None]:
# Run cross-validation
print("\nRunning cross-validation...")

cv_results = cv_strategy.evaluate(
    batch_classifier,
    X_binary,
    y_binary,
    batch=batch_info
)

print("\nCross-Validation Results:")
print(f"  Accuracy: {cv_results['accuracy_mean']:.3f} +/- {cv_results['accuracy_std']:.3f}")
print(f"  AUC-ROC: {cv_results['auc_mean']:.3f} +/- {cv_results['auc_std']:.3f}")
print(f"  F1 Score: {cv_results['f1_mean']:.3f} +/- {cv_results['f1_std']:.3f}")

In [None]:
# Visualize cross-validation performance distribution
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

metrics = ['accuracy', 'auc', 'f1']
titles = ['Accuracy', 'AUC-ROC', 'F1 Score']

for ax, metric, title in zip(axes, metrics, titles):
    values = cv_results[f'{metric}_scores']
    ax.hist(values, bins=20, edgecolor='black', alpha=0.7)
    ax.axvline(x=np.mean(values), color='red', linestyle='--', 
               label=f'Mean: {np.mean(values):.3f}')
    ax.set_xlabel(title)
    ax.set_ylabel('Frequency')
    ax.set_title(f'{title} Distribution')
    ax.legend()

plt.tight_layout()
plt.savefig(figures_dir / 'cv_performance_distribution.png', dpi=150, bbox_inches='tight')
plt.show()

## 7. Model Evaluation with ROC Curves

In [None]:
# Initialize model evaluator
evaluator = ModelEvaluator()

# Get detailed evaluation metrics (computed on the training data, so optimistic)
eval_results = evaluator.evaluate(
    batch_classifier,
    X_binary,
    y_binary
)

print("Detailed Evaluation Metrics:")
for metric, value in eval_results.items():
    if isinstance(value, float):
        print(f"  {metric}: {value:.4f}")

In [None]:
# Plot ROC curve from training-set probabilities (optimistic; compare with the CV results above)
fig, ax = plot_roc_curve(
    y_binary,
    train_probabilities,
    title='ROC Curve (Training Set): HIIT vs Control Classification',
    figsize=(8, 8)
)

plt.savefig(figures_dir / 'roc_curve_binary.png', dpi=150, bbox_inches='tight')
plt.show()
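The ROC curve above uses `train_probabilities`, which are predictions on the same samples the model was fit on. A less optimistic curve can be built from out-of-fold probabilities, where each sample is scored by a model that did not see it during training. The sketch uses scikit-learn's `cross_val_predict` and `roc_curve` with a plain `RandomForestClassifier` on synthetic data. Batch handling is not reproduced here.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.model_selection import StratifiedKFold, cross_val_predict

X_demo, y_demo = make_classification(n_samples=120, n_features=30, random_state=1)
rf_demo = RandomForestClassifier(n_estimators=100, max_depth=10, random_state=42)

# cross_val_predict needs a true partition (each sample tested once), so no repeats here
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
oof_proba = cross_val_predict(rf_demo, X_demo, y_demo, cv=cv,
                              method="predict_proba")[:, 1]

fpr, tpr, thresholds = roc_curve(y_demo, oof_proba)  # plot with plt.plot(fpr, tpr)
oof_auc = roc_auc_score(y_demo, oof_proba)
print(f"Out-of-fold AUC: {oof_auc:.3f}")
```

The pooled out-of-fold AUC can differ slightly from the mean of per-fold AUCs, since it ranks samples scored by different fold models against each other.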

In [None]:
# Plot confusion matrix from training-set predictions (optimistic)
fig, ax = plot_confusion_matrix(
    y_binary,
    train_predictions,
    class_names=['Control', 'HIIT'],
    title='Confusion Matrix (Training Set): HIIT vs Control',
    figsize=(8, 6)
)

plt.savefig(figures_dir / 'confusion_matrix_binary.png', dpi=150, bbox_inches='tight')
plt.show()
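With imbalanced classes, raw counts in a confusion matrix can hide poor recall on the smaller class. Row-normalizing (dividing each row by the number of true samples in that class) turns the diagonal into per-class recall. A small worked example with scikit-learn's `confusion_matrix(..., normalize='true')` and toy labels:

```python
import numpy as np
from sklearn.metrics import confusion_matrix

y_true = np.array([0, 0, 0, 0, 1, 1])   # 0 = Control, 1 = HIIT (toy labels)
y_pred = np.array([0, 0, 1, 1, 1, 1])

counts = confusion_matrix(y_true, y_pred)
# [[2 2]
#  [0 2]]
recall_matrix = confusion_matrix(y_true, y_pred, normalize="true")
# [[0.5 0.5]
#  [0.  1. ]]
print(np.diag(recall_matrix))  # per-class recall: [0.5 1. ]
```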

## 8. Feature Importance Analysis

In [None]:
# Get feature importances from the trained model
importances = batch_classifier.get_feature_importance()

# Create importance DataFrame
importance_df = pd.DataFrame({
    'feature': available_features,
    'importance': importances
}).sort_values('importance', ascending=False)

print("Top 10 Most Important Features:")
print(importance_df.head(10).to_string(index=False))
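If `get_feature_importance()` returns the Random Forest's impurity-based importances (an assumption, since the method is internal to `src.models`), those scores are computed from the training data. Permutation importance on held-out data is a common alternative: shuffle one feature at a time and measure the drop in score. The sketch below uses scikit-learn's `permutation_importance` on synthetic data where only the first column carries signal.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(300, 5))
y_demo = (X_demo[:, 0] > 0).astype(int)   # only feature 0 is informative

X_tr, X_te, y_tr, y_te = train_test_split(X_demo, y_demo, test_size=0.3,
                                          stratify=y_demo, random_state=0)
rf_demo = RandomForestClassifier(n_estimators=100, random_state=0).fit(X_tr, y_tr)

# Mean drop in held-out ROC AUC when each column is shuffled
result = permutation_importance(rf_demo, X_te, y_te, scoring="roc_auc",
                                n_repeats=20, random_state=0)
ranking = np.argsort(result.importances_mean)[::-1]
print(ranking[0])  # 0: the informative feature ranks first
```

With strongly correlated CpG probes, permuting one probe leaves its correlated partners intact, so permutation importance can understate each of them individually.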

In [None]:
# Plot feature importance
fig, ax = plot_feature_importance(
    importance_df.head(20)['feature'].tolist(),
    importance_df.head(20)['importance'].tolist(),
    title='Top 20 Feature Importances',
    figsize=(10, 8)
)

plt.savefig(figures_dir / 'feature_importance_top20.png', dpi=150, bbox_inches='tight')
plt.show()

## 9. Multi-Version Model Comparison

Compare classifier performance across different preprocessing versions to assess robustness.
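Reusing the same cross-validation splits for every version makes the comparison paired, so each fold's AUC difference reflects the preprocessing rather than a different random partition. A minimal scikit-learn sketch follows, assuming each version is a probes × samples `DataFrame` like those in `data_versions`. The `compare_versions` helper and the toy versions are hypothetical.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score


def compare_versions(versions, features, samples, y, n_splits=5, random_state=42):
    """Hypothetical helper: per-fold AUC for each version using identical splits."""
    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    splits = list(cv.split(np.zeros(len(y)), y))   # materialize once, reuse for all versions
    fold_aucs = {}
    for name, df in versions.items():
        cols = [f for f in features if f in df.index]
        X = df.loc[cols, samples].T.values          # samples x features
        model = RandomForestClassifier(n_estimators=100, random_state=random_state)
        fold_aucs[name] = cross_val_score(model, X, y, cv=splits, scoring="roc_auc")
    return pd.DataFrame(fold_aucs)                   # rows = folds, columns = versions


# Toy data: one informative probe; the "noisy" version adds measurement noise
rng = np.random.default_rng(0)
samples = [f"S{i}" for i in range(60)]
probes = [f"cg{i}" for i in range(10)]
y = np.repeat([0, 1], 30)
base = pd.DataFrame(rng.normal(size=(10, 60)), index=probes, columns=samples)
base.loc["cg0"] += y * 2.0
versions = {"raw": base, "noisy": base + rng.normal(scale=3.0, size=base.shape)}

fold_table = compare_versions(versions, probes, samples, y)
print(fold_table.mean())
print((fold_table["raw"] - fold_table["noisy"]).mean())  # mean paired difference
```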

In [None]:
# Initialize multi-version comparator
comparator = MultiVersionComparator(
    classifier_config=config,
    cv_strategy=cv_strategy
)

print("Multi-Version Comparator initialized")

In [None]:
# Compare across data versions
print("Comparing across data versions...")

version_results = {}
for version_name, version_data in data_versions.items():
    print(f"\nEvaluating: {version_name}")
    
    # Extract features for this version
    available = [f for f in binary_features if f in version_data.index]
    X_version = version_data.loc[available, binary_samples].T.values
    
    # Run cross-validation
    results = comparator.evaluate_version(
        X_version,
        y_binary,
        batch=batch_info
    )
    
    version_results[version_name] = results
    print(f"  AUC: {results['auc_mean']:.3f} +/- {results['auc_std']:.3f}")

In [None]:
# Visualize version comparison
fig, ax = plt.subplots(figsize=(10, 6))

versions = list(version_results.keys())
means = [version_results[v]['auc_mean'] for v in versions]
stds = [version_results[v]['auc_std'] for v in versions]

bars = ax.bar(range(len(versions)), means, yerr=stds, 
              color='steelblue', edgecolor='black', capsize=5)
ax.set_xticks(range(len(versions)))
ax.set_xticklabels([v.replace('_', '\n') for v in versions])
ax.set_ylabel('AUC-ROC')
ax.set_title('Classification Performance Across Data Versions')
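# Leave headroom above the tallest error bar for the value labels
ax.set_ylim(0.5, max(1.0, max(m + s for m, s in zip(means, stds)) + 0.05))

# Add value labels just above each error bar
for bar, mean, std in zip(bars, means, stds):
    ax.text(bar.get_x() + bar.get_width()/2, mean + std + 0.01,
            f'{mean:.3f}', ha='center', va='bottom')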

plt.tight_layout()
plt.savefig(figures_dir / 'version_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

## 10. Multiclass Classification: HIIT Duration

Train a classifier to distinguish between 4W, 8W, and 12W training durations.

In [None]:
# Prepare multiclass data
multi_mask = sample_info['multi_class'].notna()
multi_samples = sample_info[multi_mask]['sample_id'].tolist()
multi_labels_str = sample_info[multi_mask]['multi_class'].values
multi_batch = sample_info[multi_mask]['study_group'].values

# Encode labels
label_encoder = {'4W': 0, '8W': 1, '12W': 2}
multi_labels = np.array([label_encoder[l] for l in multi_labels_str])

# Extract feature matrix
available_multi = [f for f in multiclass_features if f in methylation_data.index]
X_multi = methylation_data.loc[available_multi, multi_samples].T.values

print("Multiclass classification:")
print(f"  Samples: {X_multi.shape[0]}")
print(f"  Features: {X_multi.shape[1]}")
print(f"  Class distribution: {np.bincount(multi_labels)}")
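The explicit `label_encoder` dictionary above matters. scikit-learn's `LabelEncoder` sorts class labels lexicographically, so duration strings would be encoded out of chronological order, and the `class_names=['4W', '8W', '12W']` used in the confusion matrix below would no longer match. A quick check:

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder

durations = np.array(["4W", "8W", "12W", "4W"])

le = LabelEncoder().fit(durations)
print(le.classes_)               # ['12W' '4W' '8W']  (string sort, not time order)
print(le.transform(durations))   # [1 2 0 1]

# An explicit mapping keeps 4W < 8W < 12W
order = {"4W": 0, "8W": 1, "12W": 2}
print(np.array([order[d] for d in durations]))  # [0 1 2 0]
```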

In [None]:
# Train multiclass classifier
multi_config = ClassifierConfig(
    classifier_type='random_forest',
    n_estimators=100,
    max_depth=10,
    random_state=42
)

multi_classifier = BatchAwareClassifier(
    config=multi_config,
    batch_handling='covariate'
)

# Fit, then report training-set (resubstitution) accuracy; held-out performance is estimated by CV below
multi_classifier.fit(X_multi, multi_labels, batch=multi_batch)
multi_predictions = multi_classifier.predict(X_multi)

print("\nMulticlass Training Performance (resubstitution, optimistic):")
print(f"  Accuracy: {accuracy_score(multi_labels, multi_predictions):.3f}")

In [None]:
# Cross-validation for multiclass
multi_cv_results = cv_strategy.evaluate(
    multi_classifier,
    X_multi,
    multi_labels,
    batch=multi_batch
)

print("\nMulticlass Cross-Validation Results:")
print(f"  Accuracy: {multi_cv_results['accuracy_mean']:.3f} +/- {multi_cv_results['accuracy_std']:.3f}")
print(f"  F1 (macro): {multi_cv_results['f1_mean']:.3f} +/- {multi_cv_results['f1_std']:.3f}")
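The "F1 (macro)" label assumes `cv_strategy.evaluate` averages F1 with `average='macro'` for multiclass targets, which is not confirmed here. Macro F1 is the unweighted mean of the per-class F1 scores, so each duration counts equally regardless of class size. A worked example with scikit-learn's `f1_score` and toy labels:

```python
import numpy as np
from sklearn.metrics import f1_score

y_true = [0, 0, 0, 0, 1, 1, 2, 2]   # 0=4W, 1=8W, 2=12W (toy labels)
y_pred = [0, 0, 0, 1, 1, 1, 2, 0]

# Per class: 4W P=3/4 R=3/4 -> 0.75; 8W P=2/3 R=1 -> 0.8; 12W P=1 R=1/2 -> 0.667
per_class = f1_score(y_true, y_pred, average=None)
macro = f1_score(y_true, y_pred, average="macro")
print(round(macro, 3))  # 0.739, the plain mean of the three per-class scores
```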

In [None]:
# Multiclass confusion matrix (training-set predictions, so optimistic)
multiclass_figures_dir = project_root / 'data' / 'figures' / 'multiclass'
multiclass_figures_dir.mkdir(parents=True, exist_ok=True)

fig, ax = plot_confusion_matrix(
    multi_labels,
    multi_predictions,
    class_names=['4W', '8W', '12W'],
    title='Confusion Matrix (Training Set): HIIT Duration Classification',
    figsize=(8, 6)
)

plt.savefig(multiclass_figures_dir / 'confusion_matrix_multiclass.png', dpi=150, bbox_inches='tight')
plt.show()

## 11. HIIT Classification Pipeline

The `HIITClassificationPipeline` wraps classifier configuration, cross-validation, and model saving into a single `run()` call.

In [None]:
# Initialize the complete pipeline
pipeline = HIITClassificationPipeline(
    classifier_config=config,
    cv_strategy=cv_strategy,
    output_dir=str(models_dir)
)

print("HIIT Classification Pipeline initialized")

In [None]:
# Run the complete pipeline
pipeline_results = pipeline.run(
    X_binary,
    y_binary,
    batch=batch_info,
    feature_names=available_features
)

print("\nPipeline Results:")
print(f"  Best Model AUC: {pipeline_results['best_auc']:.3f}")
print(f"  Model saved to: {pipeline_results['model_path']}")

## 12. Save Models and Results

In [None]:
import json

# Save binary classifier
binary_model_path = models_dir / 'binary_classifier.pkl'
with open(binary_model_path, 'wb') as f:
    pickle.dump(batch_classifier, f)
print(f"Binary classifier saved: {binary_model_path}")

# Save multiclass classifier
multi_model_path = models_dir / 'multiclass_classifier.pkl'
with open(multi_model_path, 'wb') as f:
    pickle.dump(multi_classifier, f)
print(f"Multiclass classifier saved: {multi_model_path}")

In [None]:
# Save evaluation results (cast to float so json.dump accepts NumPy scalar types)
results_summary = {
    'binary_classification': {
        'accuracy': float(cv_results['accuracy_mean']),
        'accuracy_std': float(cv_results['accuracy_std']),
        'auc': float(cv_results['auc_mean']),
        'auc_std': float(cv_results['auc_std']),
        'f1': float(cv_results['f1_mean']),
        'f1_std': float(cv_results['f1_std'])
    },
    'multiclass_classification': {
        'accuracy': float(multi_cv_results['accuracy_mean']),
        'accuracy_std': float(multi_cv_results['accuracy_std']),
        'f1': float(multi_cv_results['f1_mean']),
        'f1_std': float(multi_cv_results['f1_std'])
    },
    'version_comparison': {
        v: {'auc_mean': float(r['auc_mean']), 'auc_std': float(r['auc_std'])}
        for v, r in version_results.items()
    }
}

results_path = models_dir / 'classification_results.json'
with open(results_path, 'w') as f:
    json.dump(results_summary, f, indent=2)

print(f"Results saved: {results_path}")
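`json.dump` only handles built-in Python types. `np.float64` happens to subclass Python `float`, but `np.float32`, `np.int64`, and arrays raise `TypeError`. If any result values come back as such types (for example, per-fold score arrays), a `default=` hook can convert them. The `to_builtin` function here is a hypothetical helper.

```python
import json

import numpy as np


def to_builtin(obj):
    """Hypothetical json `default` hook: convert NumPy scalars and arrays to Python types."""
    if isinstance(obj, np.generic):
        return obj.item()          # e.g. np.float32 -> float, np.int64 -> int
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


payload = {"auc": np.float32(0.5), "n_splits": np.int64(5), "fold_auc": np.array([0.5, 0.75])}
print(json.dumps(payload, default=to_builtin))
# {"auc": 0.5, "n_splits": 5, "fold_auc": [0.5, 0.75]}
```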

## Summary

In this notebook, we completed the classification pipeline:

### Key Accomplishments

1. **Binary Classification**: HIIT vs Control
   - Batch-aware Random Forest classifier
   - Repeated stratified cross-validation (accuracy, AUC-ROC, F1)
   - Training-set ROC curve and confusion matrix

2. **Multiclass Classification**: 4W/8W/12W duration
   - Trained a batch-aware duration classifier
   - Evaluated with repeated stratified cross-validation

3. **Multi-Version Comparison**
   - Compared cross-validated AUC across preprocessing versions
   - Assessed how sensitive performance is to the preprocessing choice

4. **Feature Importance**
   - Ranked CpG probes by model feature importance
   - Listed the top-ranked probes as candidate biomarkers for follow-up

### Next Steps

Continue to **05_enrichment_analysis.ipynb** to:
- Map CpG features to genes
- Perform GO and KEGG pathway enrichment
- Understand biological significance of identified biomarkers

In [None]:
# Session summary
print("=" * 60)
print("CLASSIFICATION COMPLETE")
print("=" * 60)
print("\nBinary Classification (HIIT vs Control):")
print(f"  CV Accuracy: {cv_results['accuracy_mean']:.3f}")
print(f"  CV AUC-ROC: {cv_results['auc_mean']:.3f}")
print("\nMulticlass Classification (Duration):")
print(f"  CV Accuracy: {multi_cv_results['accuracy_mean']:.3f}")
print(f"\nModels saved to: {models_dir}")