# Custom Analysis: Extending the HIIT Methylation Pipeline

## Overview

This notebook demonstrates how to customize and extend the HIIT methylation analysis pipeline for your specific research needs. We cover advanced customization options that go beyond the default configurations shown in the main pipeline notebooks.

### Topics Covered

1. **Custom Feature Selection**: Adjust stringency parameters and create custom selection criteria
2. **Custom Classifiers**: Integrate new machine learning models
3. **Extended Enrichment Analysis**: Add custom annotation databases and gene sets
4. **Custom Visualizations**: Create publication-ready figures with custom styling
5. **Pipeline Extensions**: Integrate external tools and analyses

### Prerequisites

This notebook assumes familiarity with the main pipeline notebooks (01-05). Ensure you have:
- Completed preprocessing (02_preprocessing.ipynb)
- Generated initial features (03_feature_selection.ipynb)
- Trained baseline models (04_classification.ipynb)

## 1. Environment Setup

In [None]:
# Standard library imports
import sys
import pickle
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, field

# Scientific computing
import numpy as np
import pandas as pd
from scipy import stats

# Machine learning
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.ensemble import GradientBoostingClassifier, AdaBoostClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import roc_auc_score, accuracy_score

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

# Add project root to path
project_root = Path.cwd().parent.parent
sys.path.insert(0, str(project_root))

# Project-specific imports
from src.features import (
    TenLevelFeatureSelector,
    FeatureSelectionConfig,
    StatisticalFeatureSelector,
    run_ttest,
    run_anova,
    calculate_effect_size,
    adjust_pvalues
)
from src.models import (
    ClassifierConfig,
    BatchAwareClassifier,
    CrossValidationStrategy,
    ModelEvaluator
)
from src.enrichment import (
    EnrichmentAnalyzer,
    EnrichmentConfig,
    EPICAnnotationMapper,
    MSigDBLoader
)
from src.visualization import (
    plot_pca_visualization,
    plot_heatmap,
    plot_volcano,
    plot_roc_curves,
    PublicationFigureGenerator
)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

print(f"Project root: {project_root}")

In [None]:
# Define paths
DATA_DIR = project_root / 'data' / 'raw'
PROCESSED_DIR = project_root / 'data' / 'processed'
MODELS_DIR = project_root / 'models'
RESULTS_DIR = project_root / 'results'
FIGURES_DIR = project_root / 'data' / 'figures' / 'custom'

# Create output directories
FIGURES_DIR.mkdir(parents=True, exist_ok=True)

# Load preprocessed data
with open(PROCESSED_DIR / 'methyl_data_preprocessed.pkl', 'rb') as f:
    methylation_data = pickle.load(f)

sample_mapping = pd.read_csv(DATA_DIR / 'GSE171140_sample_mapping.csv')

print(f"Loaded data: {methylation_data.shape[0]:,} probes x {methylation_data.shape[1]} samples")

## 2. Custom Feature Selection Configurations

The pipeline's `FeatureSelectionConfig` class gives fine-grained control over the built-in feature selection process. Here we define a standalone `CustomFeatureConfig` dataclass to show how custom configurations can be built for specific research questions.

### 2.1 Creating Custom Stringency Levels

The Ten-Level Framework provides predefined stringency levels (L1-L10), but you can create custom configurations for specific needs.

In [None]:
@dataclass
class CustomFeatureConfig:
    """Custom feature selection configuration with flexible parameters."""
    
    # Statistical thresholds
    p_value_threshold: float = 0.05
    fdr_threshold: float = 0.1
    effect_size_threshold: float = 0.3
    
    # Variance filtering
    min_variance: float = 0.01
    
    # ML-based selection
    use_lasso: bool = True
    use_random_forest: bool = True
    lasso_alpha: float = 0.01
    rf_n_estimators: int = 100
    rf_importance_threshold: float = 0.001
    
    # Consensus requirements
    min_methods_agreement: int = 2
    
    def __str__(self):
        return (
            f"CustomFeatureConfig(p<{self.p_value_threshold}, "
            f"FDR<{self.fdr_threshold}, |d|>{self.effect_size_threshold})"
        )


# Example: Ultra-conservative configuration for publication-quality biomarkers
publication_config = CustomFeatureConfig(
    p_value_threshold=0.001,
    fdr_threshold=0.01,
    effect_size_threshold=0.8,  # Large effect size
    min_variance=0.02,
    min_methods_agreement=3  # Must appear in 3+ selection methods
)

# Example: Exploratory configuration for pathway discovery
exploratory_config = CustomFeatureConfig(
    p_value_threshold=0.05,
    fdr_threshold=0.2,
    effect_size_threshold=0.2,  # Small effect size acceptable
    min_variance=0.005,
    min_methods_agreement=1  # Any method selection is acceptable
)

print("Publication config:", publication_config)
print("Exploratory config:", exploratory_config)

### 2.2 Custom Feature Selection Pipeline

In [None]:
class CustomFeatureSelector:
    """Custom feature selector with configurable criteria.
    
    Three methods vote independently: a statistical filter (variance, t-test,
    FDR and Cohen's d), random forest importance, and L1-penalized logistic
    regression. A feature enters the consensus set when at least
    ``config.min_methods_agreement`` of the enabled methods select it.
    """
    
    def __init__(self, config: CustomFeatureConfig):
        self.config = config
        self.selection_results = {}
    
    def select_features(
        self,
        X: np.ndarray,
        y: np.ndarray,
        feature_names: List[str]
    ) -> Dict[str, Any]:
        """Run custom feature selection pipeline.
        
        X has shape (n_samples, n_features) and y holds binary labels (0/1).
        Returns sets for 'statistical', each enabled ML method
        ('random_forest', 'lasso'), their union 'ml_based', and 'consensus'.
        """
        # Step 1: Statistical selection
        stat_features = self._statistical_selection(X, y, feature_names)
        
        # Step 2: ML-based selection (one set per enabled method)
        ml_sets = self._ml_selection(X, y, feature_names)
        
        # Step 3: Consensus across the individual methods, so that
        # min_methods_agreement can range from 1 up to 3
        results = {
            'statistical': stat_features,
            **ml_sets,
            'ml_based': set().union(*ml_sets.values()),
            'consensus': self._compute_consensus([stat_features, *ml_sets.values()])
        }
        
        self.selection_results = results
        return results
    
    def _statistical_selection(
        self,
        X: np.ndarray,
        y: np.ndarray,
        feature_names: List[str]
    ) -> set:
        """Apply variance, p-value, FDR and effect-size criteria."""
        from statsmodels.stats.multitest import multipletests
        
        group1 = X[y == 0]
        group2 = X[y == 1]
        n1, n2 = len(group1), len(group2)
        
        # Vectorized two-sample t-tests (one per feature column)
        _, p_values = stats.ttest_ind(group1, group2, axis=0)
        # Constant features yield NaN p-values; treat them as non-significant
        p_values = np.nan_to_num(p_values, nan=1.0)
        
        # Cohen's d using the pooled *sample* variance (ddof=1)
        pooled_std = np.sqrt(
            ((n1 - 1) * group1.var(axis=0, ddof=1) +
             (n2 - 1) * group2.var(axis=0, ddof=1)) /
            (n1 + n2 - 2)
        )
        mean_diff = np.abs(group1.mean(axis=0) - group2.mean(axis=0))
        effect_sizes = np.divide(
            mean_diff, pooled_std,
            out=np.zeros_like(mean_diff), where=pooled_std > 0
        )
        
        # Benjamini-Hochberg adjustment for multiple testing
        _, fdr_values, _, _ = multipletests(p_values, method='fdr_bh')
        
        # Keep features meeting every configured criterion
        keep = (
            (X.var(axis=0) >= self.config.min_variance) &
            (p_values < self.config.p_value_threshold) &
            (fdr_values < self.config.fdr_threshold) &
            (effect_sizes > self.config.effect_size_threshold)
        )
        return {name for name, k in zip(feature_names, keep) if k}
    
    def _ml_selection(
        self,
        X: np.ndarray,
        y: np.ndarray,
        feature_names: List[str]
    ) -> Dict[str, set]:
        """Apply ML-based selection; returns one feature set per enabled method."""
        from sklearn.ensemble import RandomForestClassifier
        from sklearn.linear_model import LogisticRegression
        
        selected = {}
        
        if self.config.use_random_forest:
            rf = RandomForestClassifier(
                n_estimators=self.config.rf_n_estimators,
                random_state=42
            )
            rf.fit(X, y)
            selected['random_forest'] = {
                name for name, imp in zip(feature_names, rf.feature_importances_)
                if imp > self.config.rf_importance_threshold
            }
        
        if self.config.use_lasso:
            # L1-penalized logistic regression; C is the inverse of alpha.
            # liblinear handles small-sample, high-dimensional L1 problems well.
            lasso = LogisticRegression(
                penalty='l1',
                solver='liblinear',
                C=1 / self.config.lasso_alpha,
                random_state=42,
                max_iter=1000
            )
            lasso.fit(StandardScaler().fit_transform(X), y)
            selected['lasso'] = {
                name for name, coef in zip(feature_names, lasso.coef_[0])
                if coef != 0
            }
        
        return selected
    
    def _compute_consensus(
        self,
        feature_sets: List[set]
    ) -> set:
        """Compute consensus features based on method agreement."""
        from collections import Counter
        
        # Count in how many methods each feature was selected
        feature_counts = Counter()
        for feat_set in feature_sets:
            feature_counts.update(feat_set)
        
        # Keep features meeting minimum agreement
        return {
            feature for feature, count in feature_counts.items()
            if count >= self.config.min_methods_agreement
        }


print("CustomFeatureSelector class defined.")
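The statistical step can be checked in isolation on toy data. The sketch below computes per-feature two-sample t-tests and Cohen's d in vectorized form, and applies the Benjamini-Hochberg step-up adjustment in plain NumPy; it should give the same adjusted values as statsmodels' `multipletests(..., method='fdr_bh')`. The names `feature_tests` and `benjamini_hochberg` are illustrative helpers, not part of `src.features`.

```python
import numpy as np
from scipy import stats


def feature_tests(X: np.ndarray, y: np.ndarray):
    """Per-feature Student t-test p-values and |Cohen's d| for binary labels y."""
    g0, g1 = X[y == 0], X[y == 1]
    n0, n1 = len(g0), len(g1)
    _, p = stats.ttest_ind(g0, g1, axis=0)
    p = np.nan_to_num(p, nan=1.0)  # constant features -> not significant

    # Pooled standard deviation from sample variances (ddof=1)
    pooled_sd = np.sqrt(((n0 - 1) * g0.var(axis=0, ddof=1) +
                         (n1 - 1) * g1.var(axis=0, ddof=1)) / (n0 + n1 - 2))
    diff = np.abs(g0.mean(axis=0) - g1.mean(axis=0))
    d = np.divide(diff, pooled_sd, out=np.zeros_like(diff), where=pooled_sd > 0)
    return p, d


def benjamini_hochberg(p_values) -> np.ndarray:
    """Benjamini-Hochberg adjusted p-values, returned in the input order."""
    p = np.asarray(p_values, dtype=float)
    n = p.size
    order = np.argsort(p)
    scaled = p[order] * n / np.arange(1, n + 1)
    # Step-up: each adjusted value is the minimum over all larger ranks
    scaled = np.minimum.accumulate(scaled[::-1])[::-1]
    adjusted = np.empty(n)
    adjusted[order] = np.minimum(scaled, 1.0)
    return adjusted


# Toy data: feature 0 separates the groups, feature 1 is constant
X_toy = np.array([[1.0, 5.0], [2.0, 5.0], [3.0, 5.0],
                  [4.0, 5.0], [5.0, 5.0], [6.0, 5.0]])
y_toy = np.array([0, 0, 0, 1, 1, 1])
p_toy, d_toy = feature_tests(X_toy, y_toy)
print(p_toy, d_toy)
print(benjamini_hochberg([0.01, 0.04, 0.03, 0.005]))
```

Vectorizing over columns avoids a Python loop over hundreds of thousands of probes, which matters on the full EPIC array.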

In [None]:
# Demonstrate custom feature selection

# Prepare data
sample_ids = methylation_data.columns.tolist()
sample_info = sample_mapping.set_index('sample_id').loc[sample_ids].reset_index()

binary_mask = sample_info['binary_class'].isin(['HIIT', 'Control'])
binary_samples = sample_info[binary_mask]['sample_id'].tolist()
binary_labels = (sample_info[binary_mask]['binary_class'] == 'HIIT').astype(int).values

# Select subset of features for demonstration
feature_subset = methylation_data.index[:5000].tolist()  # First 5000 probes
X_demo = methylation_data.loc[feature_subset, binary_samples].T.values
y_demo = binary_labels

# Run custom selection with exploratory config
custom_selector = CustomFeatureSelector(exploratory_config)
results = custom_selector.select_features(X_demo, y_demo, feature_subset)

print(f"Statistical selection: {len(results['statistical'])} features")
print(f"ML-based selection: {len(results['ml_based'])} features")
print(f"Consensus features: {len(results['consensus'])} features")
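Set sizes alone do not show how much the selection routes agree. A quick diagnostic, assuming `results` is the dictionary returned by `select_features()`, is the Jaccard index between two sets; `jaccard_index` is a small illustrative helper.

```python
def jaccard_index(a, b) -> float:
    """Size of the intersection over size of the union; 0.0 if both are empty."""
    a, b = set(a), set(b)
    union = a | b
    return len(a & b) / len(union) if union else 0.0


# Usage with the demo results above:
# print(f"Jaccard(statistical, ml_based) = "
#       f"{jaccard_index(results['statistical'], results['ml_based']):.3f}")
print(jaccard_index({'cg1', 'cg2'}, {'cg2', 'cg3'}))
```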

### 2.3 Region-Based Feature Selection

You may also want to restrict features to particular genomic contexts, such as promoters, gene bodies, or CpG islands, using the annotation columns of the Illumina EPIC manifest.

In [None]:
class RegionBasedSelector:
    """Select features based on genomic region annotations."""
    
    def __init__(self, annotation_file: Path):
        self.annotation_file = annotation_file
        self.annotations = None
    
    def load_annotations(self) -> pd.DataFrame:
        """Load EPIC array annotations (one row per probe).
        
        The raw Illumina manifest CSV begins with a [Heading] block that must
        be skipped (e.g. with ``skiprows``); this assumes a cleaned table.
        """
        if self.annotation_file.exists():
            self.annotations = pd.read_csv(
                self.annotation_file,
                usecols=['Name', 'CHR', 'MAPINFO', 'UCSC_RefGene_Name',
                         'UCSC_RefGene_Group', 'Relation_to_UCSC_CpG_Island'],
                low_memory=False
            )
        else:
            print(f"Warning: Annotation file not found at {self.annotation_file}")
            self.annotations = pd.DataFrame()
        return self.annotations
    
    def filter_by_region(
        self,
        probe_ids: List[str],
        regions: Tuple[str, ...] = ('TSS200', 'TSS1500', '1stExon')
    ) -> List[str]:
        """Filter probes to specific genomic regions.
        
        Parameters
        ----------
        probe_ids : List[str]
            List of CpG probe IDs
        regions : Tuple[str, ...]
            Regions to keep, as labelled in the manifest's UCSC_RefGene_Group
            column. Options include:
            - 'TSS200', 'TSS1500': Promoter regions
            - '1stExon', "5'UTR": Gene start regions
            - 'Body', "3'UTR": Gene body and end regions
        
        Returns
        -------
        List[str]
            Filtered probe IDs (all input IDs if annotations are not loaded)
        """
        if self.annotations is None or self.annotations.empty:
            return probe_ids
        
        # Filter annotations to relevant probes
        subset = self.annotations[self.annotations['Name'].isin(probe_ids)]
        
        # UCSC_RefGene_Group is ';'-separated (one entry per transcript);
        # match whole labels rather than substrings
        wanted = set(regions)
        groups = subset['UCSC_RefGene_Group'].fillna('').astype(str).str.split(';')
        mask = groups.apply(lambda labels: not wanted.isdisjoint(labels)).astype(bool)
        
        return subset.loc[mask, 'Name'].tolist()
    
    def filter_by_cpg_island(
        self,
        probe_ids: List[str],
        relation: str = 'Island'  # 'Island', 'Shore', 'Shelf', 'OpenSea'
    ) -> List[str]:
        """Filter probes by CpG island relationship."""
        if self.annotations is None or self.annotations.empty:
            return probe_ids
        
        subset = self.annotations[self.annotations['Name'].isin(probe_ids)]
        relation_col = subset['Relation_to_UCSC_CpG_Island'].fillna('').astype(str)
        
        if relation == 'OpenSea':
            # Open-sea probes are blank in the Illumina manifest; some
            # processed annotations label them 'OpenSea' explicitly
            mask = (relation_col.str.strip() == '') | (relation_col == 'OpenSea')
        else:
            # 'Shore' and 'Shelf' match both the N_ and S_ variants
            mask = relation_col.str.contains(relation, regex=False)
        
        return subset.loc[mask, 'Name'].tolist()


# Example usage (requires annotation file)
annotation_path = project_root / 'data' / 'external' / 'EPIC_manifest.csv'
region_selector = RegionBasedSelector(annotation_path)
region_selector.load_annotations()  # filters return their input unchanged until this is called

print("RegionBasedSelector configured.")
print("Use filter_by_region() for promoter/body filtering.")
print("Use filter_by_cpg_island() for CpG island context filtering.")
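Before filtering, it can help to see how a probe list is distributed across gene regions and CpG-island contexts. The sketch below runs on a toy table with the same column names as the manifest; labelling probes without a RefGene annotation as 'Intergenic' and folding the N_/S_ prefixes into 'Shore'/'Shelf' are choices made for this example, and `summarize_probe_context` is a hypothetical helper.

```python
import pandas as pd


def summarize_probe_context(annotations: pd.DataFrame, probe_ids):
    """Count probes per RefGene group and per CpG-island context."""
    subset = annotations[annotations['Name'].isin(probe_ids)]

    # One row per (probe, region label); each label counted once per probe
    groups = (subset['UCSC_RefGene_Group'].fillna('').replace('', 'Intergenic')
              .str.split(';'))
    regions = (subset.assign(Group=groups)
               .explode('Group')
               .drop_duplicates(subset=['Name', 'Group'])['Group']
               .value_counts())

    # 'N_Shore'/'S_Shore' -> 'Shore', etc.; blank relation -> 'OpenSea'
    context = (subset['Relation_to_UCSC_CpG_Island'].fillna('')
               .str.replace(r'^[NS]_', '', regex=True)
               .replace('', 'OpenSea')
               .value_counts())
    return regions, context


toy_manifest = pd.DataFrame({
    'Name': ['cg01', 'cg02', 'cg03', 'cg04'],
    'UCSC_RefGene_Group': ['TSS200;Body;Body', 'Body', None, "5'UTR;1stExon"],
    'Relation_to_UCSC_CpG_Island': ['Island', 'N_Shore', None, 'S_Shelf'],
})
region_counts, context_counts = summarize_probe_context(
    toy_manifest, ['cg01', 'cg02', 'cg03', 'cg04'])
print(region_counts)
print(context_counts)
```

With a loaded manifest, the same call would be `summarize_probe_context(region_selector.annotations, list(results['consensus']))`.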

## 3. Adding Custom Classifiers

The pipeline can be extended with custom classifiers. Here we show how to wrap new scikit-learn-compatible models so they can be evaluated alongside the existing ones.

### 3.1 Wrapper for Custom Classifiers

In [None]:
class CustomClassifierWrapper(ClassifierMixin, BaseEstimator):
    """Wrapper to integrate custom classifiers with the pipeline."""
    
    def __init__(
        self,
        base_classifier: BaseEstimator,
        batch_handling: str = 'none',
        scale_features: bool = True
    ):
        """
        Parameters
        ----------
        base_classifier : BaseEstimator
            Any scikit-learn compatible classifier
        batch_handling : str
            How to handle batch effects. 'covariate' appends one-hot batch
            indicator columns to the features; 'none' ignores batch labels.
            ('stratified' is not implemented here and behaves like 'none'.)
        scale_features : bool
            Whether to standardize features before fitting
        """
        # Store parameters only (scikit-learn convention) so that get_params()
        # and clone() work; fitted state is created in fit()
        self.base_classifier = base_classifier
        self.batch_handling = batch_handling
        self.scale_features = scale_features
    
    def fit(self, X: np.ndarray, y: np.ndarray, batch: np.ndarray = None):
        """Fit the classifier with optional batch handling."""
        X = np.asarray(X)
        self.n_features_in_ = X.shape[1]
        
        # Batch levels seen during training; None means no batch covariates.
        # One-hot indicators avoid imposing an arbitrary order on batch labels.
        self.batch_levels_ = None
        if batch is not None and self.batch_handling == 'covariate':
            self.batch_levels_ = np.unique(batch)
        
        X_processed = self._append_batch(X, batch)
        
        # Fit a fresh scaler on the training data
        self.scaler_ = StandardScaler() if self.scale_features else None
        if self.scaler_ is not None:
            X_processed = self.scaler_.fit_transform(X_processed)
        
        # Fit classifier
        self.base_classifier.fit(X_processed, y)
        self.classes_ = self.base_classifier.classes_
        return self
    
    def predict(self, X: np.ndarray, batch: np.ndarray = None) -> np.ndarray:
        """Make predictions."""
        X_processed = self._preprocess(X, batch)
        return self.base_classifier.predict(X_processed)
    
    def predict_proba(self, X: np.ndarray, batch: np.ndarray = None) -> np.ndarray:
        """Predict class probabilities."""
        X_processed = self._preprocess(X, batch)
        return self.base_classifier.predict_proba(X_processed)
    
    def _append_batch(self, X: np.ndarray, batch: np.ndarray = None) -> np.ndarray:
        """Append one-hot batch indicators if the model was fit with them."""
        if self.batch_levels_ is None:
            return X
        if batch is None:
            raise ValueError("Model was fit with batch covariates; pass `batch`.")
        # One column per training batch; unseen batches get an all-zero row
        indicators = (
            np.asarray(batch).reshape(-1, 1) == self.batch_levels_
        ).astype(float)
        return np.hstack([X, indicators])
    
    def _preprocess(self, X: np.ndarray, batch: np.ndarray = None) -> np.ndarray:
        """Apply the preprocessing steps learned in fit()."""
        X_processed = self._append_batch(np.asarray(X), batch)
        
        if self.scaler_ is not None:
            X_processed = self.scaler_.transform(X_processed)
        
        return X_processed
    
    def get_feature_importance(self) -> Optional[np.ndarray]:
        """Get feature importances if available (batch columns excluded)."""
        if hasattr(self.base_classifier, 'feature_importances_'):
            importances = self.base_classifier.feature_importances_
        elif hasattr(self.base_classifier, 'coef_'):
            importances = np.abs(self.base_classifier.coef_[0])
        else:
            return None
        # Drop any trailing batch-indicator columns
        return importances[:self.n_features_in_]


print("CustomClassifierWrapper defined.")
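If you prefer to delegate batch encoding to scikit-learn, `OneHotEncoder(handle_unknown='ignore')` learns the indicator columns at fit time and maps unseen batches to an all-zero row. This is a sketch of that alternative, not what `CustomClassifierWrapper` does internally; `append_batch_indicators` is a hypothetical helper.

```python
import numpy as np
from sklearn.preprocessing import OneHotEncoder


def append_batch_indicators(X, batch, encoder=None):
    """Append one-hot batch columns to X; fit a new encoder if none is given."""
    batch = np.asarray(batch).reshape(-1, 1)
    if encoder is None:
        # Learn the batch levels from the training data only
        encoder = OneHotEncoder(handle_unknown='ignore').fit(batch)
    # transform() returns a sparse matrix by default; densify for np.hstack
    indicators = encoder.transform(batch).toarray()
    return np.hstack([np.asarray(X), indicators]), encoder


X_train = np.zeros((3, 2))
X_aug, enc = append_batch_indicators(X_train, ['b1', 'b2', 'b1'])
# Reuse the fitted encoder at prediction time; 'b3' was never seen
X_new, _ = append_batch_indicators(np.zeros((1, 2)), ['b3'], encoder=enc)
print(X_aug)
print(X_new)
```

Reusing the encoder fitted on the training fold keeps the column layout identical between training and prediction.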

### 3.2 Adding New Classifier Types

In [None]:
# Define new classifier configurations

# Gradient Boosting
gb_classifier = CustomClassifierWrapper(
    base_classifier=GradientBoostingClassifier(
        n_estimators=100,
        max_depth=5,
        learning_rate=0.1,
        random_state=42
    ),
    batch_handling='covariate',
    scale_features=False
)

# AdaBoost
ada_classifier = CustomClassifierWrapper(
    base_classifier=AdaBoostClassifier(
        n_estimators=50,
        learning_rate=1.0,
        random_state=42
    ),
    batch_handling='covariate',
    scale_features=True
)

# Neural Network (MLP)
mlp_classifier = CustomClassifierWrapper(
    base_classifier=MLPClassifier(
        hidden_layer_sizes=(100, 50),
        activation='relu',
        max_iter=500,
        random_state=42
    ),
    batch_handling='covariate',
    scale_features=True
)

# Collection of classifiers for comparison
custom_classifiers = {
    'Gradient Boosting': gb_classifier,
    'AdaBoost': ada_classifier,
    'MLP Neural Network': mlp_classifier
}

print("Custom classifiers configured:")
for name in custom_classifiers:
    print(f"  - {name}")

### 3.3 Comparing Custom Classifiers

In [None]:
from sklearn.base import clone


def compare_classifiers(
    classifiers: Dict[str, BaseEstimator],
    X: np.ndarray,
    y: np.ndarray,
    batch: np.ndarray = None,
    cv: int = 5
) -> pd.DataFrame:
    """Compare multiple classifiers using stratified cross-validation.
    
    Parameters
    ----------
    classifiers : Dict[str, BaseEstimator]
        Dictionary mapping names to classifier objects
    X : np.ndarray
        Feature matrix (n_samples, n_features)
    y : np.ndarray
        Binary labels
    batch : np.ndarray, optional
        Batch labels. When given, every classifier's fit/predict/predict_proba
        must accept a ``batch`` keyword (e.g. CustomClassifierWrapper)
    cv : int
        Number of cross-validation folds
    
    Returns
    -------
    pd.DataFrame
        Mean and standard deviation of accuracy and ROC AUC across folds
    """
    results = []
    
    skf = StratifiedKFold(n_splits=cv, shuffle=True, random_state=42)
    
    for name, clf in classifiers.items():
        accuracies = []
        aucs = []
        
        for train_idx, test_idx in skf.split(X, y):
            X_train, X_test = X[train_idx], X[test_idx]
            y_train, y_test = y[train_idx], y[test_idx]
            
            # Fresh, unfitted copy per fold so no fitted state carries over
            model = clone(clf)
            
            if batch is not None:
                batch_train = batch[train_idx]
                batch_test = batch[test_idx]
                model.fit(X_train, y_train, batch=batch_train)
                y_pred = model.predict(X_test, batch=batch_test)
                y_prob = model.predict_proba(X_test, batch=batch_test)[:, 1]
            else:
                model.fit(X_train, y_train)
                y_pred = model.predict(X_test)
                y_prob = model.predict_proba(X_test)[:, 1]
            
            accuracies.append(accuracy_score(y_test, y_pred))
            aucs.append(roc_auc_score(y_test, y_prob))
        
        results.append({
            'Classifier': name,
            'Accuracy': np.mean(accuracies),
            'Accuracy_std': np.std(accuracies),
            'AUC': np.mean(aucs),
            'AUC_std': np.std(aucs)
        })
    
    return pd.DataFrame(results)


print("Classifier comparison function defined.")
print("Use compare_classifiers() to evaluate multiple models.")
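One caveat when comparing classifiers: if features are selected on the full dataset before cross-validation, the held-out folds have already influenced which features are kept, and the scores come out optimistic. Placing selection inside a scikit-learn `Pipeline` re-runs it on each training fold. The sketch below contrasts the two on pure-noise data; `SelectKBest(f_classif)` stands in for the pipeline's own selectors, and the data sizes are arbitrary.

```python
import numpy as np
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler


def make_model():
    """Scaler + logistic regression, rebuilt fresh on every call."""
    return Pipeline([
        ('scale', StandardScaler()),
        ('clf', LogisticRegression(max_iter=1000)),
    ])


def cv_auc(X, y, k=10, leaky=False, seed=42):
    """Mean ROC AUC with top-k F-test selection outside (leaky) or inside the CV."""
    cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
    if leaky:
        # Selection sees every sample, including the future test folds
        X_sel = SelectKBest(f_classif, k=k).fit_transform(X, y)
        return cross_val_score(make_model(), X_sel, y, cv=cv, scoring='roc_auc').mean()
    # Selection is refit on each training fold only
    pipe = Pipeline([('select', SelectKBest(f_classif, k=k))] + make_model().steps)
    return cross_val_score(pipe, X, y, cv=cv, scoring='roc_auc').mean()


rng = np.random.default_rng(0)
X_noise = rng.normal(size=(60, 500))   # no real signal
y_noise = np.repeat([0, 1], 30)
print(f"Leaky AUC:  {cv_auc(X_noise, y_noise, leaky=True):.2f}")
print(f"Nested AUC: {cv_auc(X_noise, y_noise):.2f}")
```

On noise, the leaky estimate is typically well above 0.5 while the nested one hovers around chance, which is the gap to watch for when selected features feed `compare_classifiers()`.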

In [None]:
# Example: Compare classifiers on the demo data
# (Uses subset of features for faster execution)

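# Get batch information
# Note: 'study_group' is used as the batch label here. If it is confounded
# with the HIIT/Control labels, adding it as a covariate leaks class information.
batch_demo = sample_info[binary_mask]['study_group'].values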

# Compare classifiers
comparison_results = compare_classifiers(
    custom_classifiers,
    X_demo,
    y_demo,
    batch=batch_demo,
    cv=5
)

print("\nClassifier Comparison Results:")
print(comparison_results.to_string(index=False))

## 4. Extended Enrichment Analysis

Extend the enrichment analysis with custom gene sets and additional annotation databases.

### 4.1 Custom Gene Set Analysis

In [None]:
class CustomGeneSetAnalyzer:
    """Perform enrichment analysis with custom gene sets."""
    
    def __init__(self):
        self.custom_gene_sets = {}
    
    def add_gene_set(
        self,
        name: str,
        genes: List[str],
        category: str = 'custom'
    ):
        """Add a custom gene set for analysis.
        
        Parameters
        ----------
        name : str
            Name of the gene set
        genes : List[str]
            List of gene symbols
        category : str
            Category for grouping gene sets
        """
        if category not in self.custom_gene_sets:
            self.custom_gene_sets[category] = {}
        
        self.custom_gene_sets[category][name] = set(genes)
    
    def load_gmt_file(self, gmt_path: Path, category: str = 'gmt'):
        """Load gene sets from GMT format file.
        
        GMT format: gene_set_name<TAB>description<TAB>gene1<TAB>gene2...
        """
        if not gmt_path.exists():
            print(f"Warning: GMT file not found: {gmt_path}")
            return
        
        if category not in self.custom_gene_sets:
            self.custom_gene_sets[category] = {}
        
        with open(gmt_path, 'r') as f:
            for line in f:
                parts = line.strip().split('\t')
                if len(parts) >= 3:
                    name = parts[0]
                    genes = parts[2:]  # Skip description
                    self.custom_gene_sets[category][name] = set(genes)
    
    def test_enrichment(
        self,
        query_genes: List[str],
        background_genes: List[str],
        min_overlap: int = 3
    ) -> pd.DataFrame:
        """Test enrichment using Fisher's exact test.
        
        Parameters
        ----------
        query_genes : List[str]
            Genes to test (e.g., from selected features)
        background_genes : List[str]
            Background gene set for comparison
        min_overlap : int
            Minimum overlap required for testing
        
        Returns
        -------
        pd.DataFrame
            Enrichment results with p-values and fold enrichment
        """
        from scipy.stats import fisher_exact
        
        query_set = set(query_genes)
        background_set = set(background_genes)
        # Restrict the query to the background so every count uses one universe
        query_bg = query_set & background_set
        n_background = len(background_set)
        
        results = []
        
        for category, gene_sets in self.custom_gene_sets.items():
            for name, pathway_genes in gene_sets.items():
                # Calculate overlap within the background
                pathway_bg = pathway_genes & background_set
                overlap = query_bg & pathway_bg
                
                if len(overlap) < min_overlap:
                    continue
                
                # Build 2x2 contingency table
                a = len(overlap)                  # Query AND pathway
                b = len(query_bg) - a             # Query NOT pathway
                c = len(pathway_bg) - a           # Pathway NOT query
                d = n_background - a - b - c      # Neither
                
                # One-sided Fisher's exact test for over-representation
                _, p_value = fisher_exact([[a, b], [c, d]], alternative='greater')
                
                # Fold enrichment: observed overlap / overlap expected by chance
                expected = len(query_bg) * len(pathway_bg) / n_background
                fold_enrichment = a / expected if expected > 0 else 0.0
                
                results.append({
                    'Category': category,
                    'Term': name,
                    'Overlap': a,
                    'GeneSet_Size': len(pathway_bg),
                    'Fold_Enrichment': fold_enrichment,
                    'P_Value': p_value,
                    'Genes': ';'.join(sorted(overlap))
                })
        
        if not results:
            return pd.DataFrame()
        
        results_df = pd.DataFrame(results)
        
        # Benjamini-Hochberg FDR across the tested gene sets
        from statsmodels.stats.multitest import multipletests
        results_df['FDR'] = multipletests(results_df['P_Value'], method='fdr_bh')[1]
        
        return results_df.sort_values('P_Value')


print("CustomGeneSetAnalyzer defined.")
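As a worked check of the contingency table used in `test_enrichment`: with a 100-gene background, a 10-gene query, and a 10-gene set sharing 5 genes with the query, the table is [[5, 5], [5, 85]] and the expected overlap is 10 x 10 / 100 = 1, so fold enrichment is 5. The one-sided Fisher p-value equals the hypergeometric upper tail P(X >= 5). The gene names below are synthetic placeholders.

```python
from scipy.stats import fisher_exact, hypergeom


def enrichment_table(query, gene_set, background):
    """2x2 table [[a, b], [c, d]] restricted to the background universe."""
    background = set(background)
    query = set(query) & background
    gene_set = set(gene_set) & background
    a = len(query & gene_set)          # in query and gene set
    b = len(query) - a                 # in query only
    c = len(gene_set) - a              # in gene set only
    d = len(background) - a - b - c    # in neither
    return [[a, b], [c, d]]


background = [f"G{i}" for i in range(100)]
query = background[:10]                          # G0..G9
gene_set = background[:5] + background[50:55]    # shares G0..G4 with the query

table = enrichment_table(query, gene_set, background)
_, p_fisher = fisher_exact(table, alternative='greater')
# P(X >= a) for X ~ Hypergeom(100 genes, 10 in the gene set, 10 drawn)
p_hyper = hypergeom.sf(table[0][0] - 1, 100, 10, 10)
expected = 10 * 10 / 100
fold_enrichment = table[0][0] / expected
print(table, fold_enrichment, p_fisher, p_hyper)
```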

In [None]:
# Example: Create analyzer with exercise-related gene sets

gene_set_analyzer = CustomGeneSetAnalyzer()

# Add custom exercise/metabolism-related gene sets
# These are example gene sets - replace with actual research-relevant genes

gene_set_analyzer.add_gene_set(
    name='Mitochondrial_Biogenesis',
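    # Nuclear respiratory factor 2 (NRF-2) is the GABPA/GABPB1 complex; 'NRF2' is
    # not an HGNC symbol and usually denotes NFE2L2, so official symbols are used
    genes=['PPARGC1A', 'TFAM', 'NRF1', 'GABPA', 'GABPB1', 'ESRRA'],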
    category='Exercise_Response'
)

gene_set_analyzer.add_gene_set(
    name='Glucose_Metabolism',
    genes=['SLC2A4', 'HK2', 'PFKM', 'PDK4', 'GYS1', 'PYGM'],
    category='Exercise_Response'
)

gene_set_analyzer.add_gene_set(
    name='Muscle_Adaptation',
    genes=['MYOD1', 'MYF5', 'MYOG', 'MYF6', 'MEF2A', 'MEF2C'],  # MYF6 is the HGNC symbol for MRF4
    category='Exercise_Response'
)

gene_set_analyzer.add_gene_set(
    name='Inflammatory_Response',
    genes=['IL6', 'TNF', 'IL1B', 'CRP', 'CCL2', 'NFKB1'],
    category='Exercise_Response'
)

print(f"Added {sum(len(gs) for gs in gene_set_analyzer.custom_gene_sets.values())} gene sets")
for category, gene_sets in gene_set_analyzer.custom_gene_sets.items():
    print(f"  {category}: {len(gene_sets)} sets")

### 4.2 Multi-Database Enrichment

In [None]:
class MultiDatabaseEnrichment:
    """Run enrichment across multiple databases and combine results."""
    
    def __init__(self):
        self.databases = {}
    
    def add_database(
        self,
        name: str,
        analyzer: Any
    ):
        """Add a database/analyzer for enrichment."""
        self.databases[name] = analyzer
    
    def run_all(
        self,
        query_genes: List[str],
        background_genes: List[str] = None,
        top_n: int = 10
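    ) -> Dict[str, pd.DataFrame]:
        """Run enrichment across all databases.
        
        Parameters
        ----------
        query_genes : List[str]
            Genes to test
        background_genes : List[str], optional
            Background universe; required by analyzers that expose
            ``test_enrichment`` (e.g. CustomGeneSetAnalyzer)
        top_n : int
            Number of top-ranked terms kept per database
        
        Returns
        -------
        Dict[str, pd.DataFrame]
            Results from each database
        """
        results = {}
        
        for name, analyzer in self.databases.items():
            print(f"Running enrichment: {name}")
            
            try:
                if hasattr(analyzer, 'test_enrichment'):
                    # Custom gene set analyzer. Using the query as its own
                    # background would make every test uninformative.
                    if background_genes is None:
                        raise ValueError(
                            "background_genes is required for test_enrichment"
                        )
                    result = analyzer.test_enrichment(query_genes, background_genes)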
                elif hasattr(analyzer, 'run_analysis'):
                    # Built-in enrichment analyzer
                    result = analyzer.run_analysis(query_genes)
                else:
                    result = pd.DataFrame()
                
                if not result.empty:
                    results[name] = result.head(top_n)
                else:
                    results[name] = pd.DataFrame({'Message': ['No significant results']})
            
            except Exception as e:
                print(f"  Error in {name}: {e}")
                results[name] = pd.DataFrame({'Error': [str(e)]})
        
        return results
    
    def summarize_results(
        self,
        results: Dict[str, pd.DataFrame],
        alpha: float = 0.05
    ) -> pd.DataFrame:
        """Create summary table across all databases.
        
        'Total_Significant' counts terms below ``alpha`` (FDR if available,
        otherwise raw p-value) within the top_n rows returned by run_all().
        """
        summaries = []
        
        for db_name, result_df in results.items():
            if 'Term' in result_df.columns and 'P_Value' in result_df.columns:
                top_term = result_df.iloc[0]
                sig_col = 'FDR' if 'FDR' in result_df.columns else 'P_Value'
                summaries.append({
                    'Database': db_name,
                    'Top_Term': top_term['Term'],
                    'P_Value': top_term['P_Value'],
                    'Total_Significant': int((result_df[sig_col] < alpha).sum())
                })
        
        return pd.DataFrame(summaries)


# Setup multi-database enrichment
multi_enrichment = MultiDatabaseEnrichment()
multi_enrichment.add_database('Custom_Exercise', gene_set_analyzer)

print("MultiDatabaseEnrichment configured.")
print("Add more databases with add_database() method.")
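To compare or plot results from several databases together, the per-database tables returned by `run_all()` can be stacked into one long table. This sketch keeps only tables that have `Term` and `P_Value` columns (skipping the 'Message'/'Error' placeholders) and adds a `-log10(p)` column for plotting; `combine_enrichment_results` is an illustrative helper and the toy tables are made up.

```python
import numpy as np
import pandas as pd


def combine_enrichment_results(results: dict) -> pd.DataFrame:
    """Stack per-database enrichment tables into one table with a Database column."""
    frames = [
        df.assign(Database=name)
        for name, df in results.items()
        if {'Term', 'P_Value'}.issubset(df.columns)
    ]
    if not frames:
        return pd.DataFrame(columns=['Database', 'Term', 'P_Value', 'neg_log10_p'])
    combined = pd.concat(frames, ignore_index=True)
    # Clip to avoid -log10(0) = inf for extremely small p-values
    combined['neg_log10_p'] = -np.log10(combined['P_Value'].clip(lower=1e-300))
    return combined.sort_values('P_Value', ignore_index=True)


toy_results = {
    'DB_A': pd.DataFrame({'Term': ['t1', 't2'], 'P_Value': [0.01, 0.2]}),
    'DB_B': pd.DataFrame({'Term': ['t3'], 'P_Value': [0.001]}),
    'DB_C': pd.DataFrame({'Message': ['No significant results']}),
}
combined = combine_enrichment_results(toy_results)
print(combined)
```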

## 5. Custom Visualizations

Create publication-ready figures with custom styling and layouts.

### 5.1 Publication Figure Style Configuration

In [None]:
@dataclass
class PublicationStyle:
    """Configuration for publication-quality figures."""
    
    # Figure dimensions (in inches)
    single_column_width: float = 3.5  # ~89 mm, e.g. Nature single column
    double_column_width: float = 7.0  # ~178 mm, close to full page width
    max_height: float = 9.0
    
    # Font settings
    font_family: str = 'Arial'
    title_size: int = 10
    label_size: int = 9
    tick_size: int = 8
    legend_size: int = 8
    
    # Color palette
    primary_colors: List[str] = field(default_factory=lambda: [
        '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd'
    ])
    
    # Line and marker settings
    line_width: float = 1.5
    marker_size: int = 6
    
    # DPI for saving
    save_dpi: int = 300
    
    def apply(self):
        """Apply the style settings to matplotlib."""
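        plt.rcParams.update({
            # Prefer the configured font, falling back to other sans-serif
            # fonts if it is not installed
            'font.family': 'sans-serif',
            'font.sans-serif': [self.font_family, 'Helvetica', 'DejaVu Sans'],
            'font.size': self.label_size,
            'axes.titlesize': self.title_size,
            'axes.labelsize': self.label_size,
            'xtick.labelsize': self.tick_size,
            'ytick.labelsize': self.tick_size,
            'legend.fontsize': self.legend_size,
            'lines.linewidth': self.line_width,
            'lines.markersize': self.marker_size,
            'axes.linewidth': 0.8,
            'axes.spines.top': False,
            'axes.spines.right': False,
            'figure.dpi': 100,
            'savefig.dpi': self.save_dpi,
            'savefig.bbox': 'tight',
            'savefig.pad_inches': 0.1,
            # Embed TrueType (Type 42) fonts so text stays editable in PDF/PS files
            'pdf.fonttype': 42,
            'ps.fonttype': 42
        })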


# Apply publication style
pub_style = PublicationStyle()
pub_style.apply()

print("Publication style applied.")
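Journal size limits are usually given in millimetres, while matplotlib's `figsize` is in inches (25.4 mm per inch). A small helper makes the conversion explicit; the 89 mm width below is only an example single-column width, so check the target journal's author guidelines. `mm_to_inches` and `figure_for_width` are illustrative helpers.

```python
import matplotlib.pyplot as plt

MM_PER_INCH = 25.4


def mm_to_inches(mm: float) -> float:
    """Convert a length in millimetres to inches."""
    return mm / MM_PER_INCH


def figure_for_width(width_mm: float, aspect: float = 0.75):
    """Create a figure of the given width (mm) with height = width * aspect."""
    width_in = mm_to_inches(width_mm)
    return plt.subplots(figsize=(width_in, width_in * aspect))


fig, ax = figure_for_width(89)
print(fig.get_size_inches())  # approximately 3.50 x 2.63 inches
plt.close(fig)
```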

### 5.2 Custom Visualization Functions

In [None]:
def plot_multi_panel_figure(
    data_dict: Dict[str, Tuple[np.ndarray, np.ndarray]],
    style: PublicationStyle = None,
    title: str = None,
    save_path: Path = None
) -> Tuple[plt.Figure, np.ndarray]:
    """Create a multi-panel comparison figure.
    
    Parameters
    ----------
    data_dict : Dict[str, Tuple[np.ndarray, np.ndarray]]
        Dictionary mapping panel names to (x, y) data tuples
    style : PublicationStyle, optional
        Style configuration
    title : str, optional
        Figure title
    save_path : Path, optional
        Path to save the figure
    
    Returns
    -------
    Tuple[plt.Figure, np.ndarray]
        Figure and axes array
    """
    if style is None:
        style = PublicationStyle()
    
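    n_panels = len(data_dict)
    if n_panels == 0:
        raise ValueError("data_dict must contain at least one panel")
    n_cols = min(3, n_panels)
    n_rows = (n_panels + n_cols - 1) // n_cols  # Ceiling division
    
    # squeeze=False always returns a 2-D array, so flatten() is safe for any grid
    fig, axes = plt.subplots(
        n_rows, n_cols,
        figsize=(style.single_column_width * n_cols, 3 * n_rows),
        squeeze=False
    )
    axes = axes.flatten()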
    
    for idx, (name, (x_data, y_data)) in enumerate(data_dict.items()):
        ax = axes[idx]
        ax.scatter(x_data, y_data, alpha=0.6, c=style.primary_colors[idx % len(style.primary_colors)])
        ax.set_title(name)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
    
    # Hide unused axes
    for idx in range(n_panels, len(axes)):
        axes[idx].set_visible(False)
    
    if title:
        fig.suptitle(title, fontsize=style.title_size + 2, y=1.02)
    
    plt.tight_layout()
    
    if save_path:
        fig.savefig(save_path, dpi=style.save_dpi, bbox_inches='tight')
        print(f"Figure saved to: {save_path}")
    
    return fig, axes


print("Custom visualization functions defined.")
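Both figure helpers size their grids with `(n_panels + n_cols - 1) // n_cols`, which is integer ceiling division: the smallest number of rows that fits every panel. The hypothetical `grid_shape` helper below isolates that arithmetic so it can be checked on its own; for example, 4 panels with at most 3 columns give a 2 x 3 grid, leaving 2 cells hidden.

```python
from typing import Tuple


def grid_shape(n_panels: int, max_cols: int = 3) -> Tuple[int, int]:
    """Return (n_rows, n_cols) for a row-major grid that holds n_panels."""
    if n_panels < 1:
        raise ValueError("n_panels must be at least 1")
    n_cols = min(max_cols, n_panels)
    # (a + b - 1) // b rounds a / b up without floating-point math
    n_rows = (n_panels + n_cols - 1) // n_cols
    return n_rows, n_cols


for n in (1, 3, 4, 7):
    rows, cols = grid_shape(n)
    print(f"{n} panels -> {rows} x {cols} grid, {rows * cols - n} hidden")
```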

In [None]:
def plot_enrichment_dotplot(
    enrichment_df: pd.DataFrame,
    x_col: str = 'Fold_Enrichment',
    y_col: str = 'Term',
    size_col: str = 'Overlap',
    color_col: str = 'P_Value',
    title: str = 'Enrichment Analysis',
    max_terms: int = 15,
    save_path: Path = None
) -> Tuple[plt.Figure, plt.Axes]:
    """Create a dot plot for enrichment results.
    
    Parameters
    ----------
    enrichment_df : pd.DataFrame
        Enrichment results with required columns
    x_col : str
        Column for x-axis (typically fold enrichment)
    y_col : str
        Column for y-axis (term names)
    size_col : str
        Column for dot size (typically overlap count)
    color_col : str
        Column for dot color (typically p-value)
    title : str
        Plot title
    max_terms : int
        Maximum number of terms to display
    save_path : Path, optional
        Path to save the figure
    
    Returns
    -------
    Tuple[plt.Figure, plt.Axes]
        Figure and axes
    """
    if enrichment_df.empty:
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, 'No enrichment results to display',
                ha='center', va='center', transform=ax.transAxes)
        return fig, ax
    
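    # Prepare data: keep the most significant terms (smallest p-values)
    plot_df = enrichment_df.sort_values(color_col).head(max_terms).copy()
    plot_df = plot_df.iloc[::-1]  # Reverse so the most significant term is drawn at the top
    
    # Create figure
    fig, ax = plt.subplots(figsize=(10, 0.4 * len(plot_df) + 2))
    
    # Clip p-values so that p == 0 maps to a large finite value instead of inf
    neg_log_p = -np.log10(plot_df[color_col].clip(lower=1e-300, upper=1.0))
    
    # Create scatter plot
    scatter = ax.scatter(
        plot_df[x_col],
        range(len(plot_df)),
        s=plot_df[size_col] * 20,  # Scale size
        c=neg_log_p,  # -log10 p-value
        cmap='RdYlBu_r',
        alpha=0.8,
        edgecolors='black',
        linewidths=0.5
    )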
    
    # Customize axes
    ax.set_yticks(range(len(plot_df)))
    ax.set_yticklabels(plot_df[y_col])
    ax.set_xlabel(x_col.replace('_', ' '))
    ax.set_title(title)
    
    # Add colorbar
    cbar = plt.colorbar(scatter, ax=ax, shrink=0.6)
    cbar.set_label('-log10(P-value)')
    
    # Add size legend
    sizes = [5, 10, 20]
    for size in sizes:
        ax.scatter([], [], s=size * 20, c='gray', alpha=0.6,
                   label=f'{size} genes', edgecolors='black', linewidths=0.5)
    ax.legend(title='Overlap', loc='lower right', framealpha=0.9)
    
    plt.tight_layout()
    
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Figure saved to: {save_path}")
    
    return fig, ax


print("Enrichment dot plot function defined.")
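The dot plot shows a limited number of terms colored by `-log10(P_Value)`. Two details matter when preparing its input: rows should be ordered by significance before truncating, and a p-value reported as exactly 0 would make `-log10` infinite. The sketch below is a hypothetical preprocessing helper, `prepare_dotplot_table`; the `1e-300` floor is an arbitrary choice just above the smallest normal double (about 2.2e-308), and the toy table values are made up for illustration.

```python
import numpy as np
import pandas as pd


def prepare_dotplot_table(df, p_col="P_Value", max_terms=15, p_floor=1e-300):
    """Keep the `max_terms` most significant rows and add a finite -log10(p) column."""
    top = df.sort_values(p_col, ascending=True).head(max_terms).copy()
    # Clip so that p == 0 (underflow in some tools) maps to a large finite value, not inf
    top["neg_log10_p"] = -np.log10(top[p_col].clip(lower=p_floor, upper=1.0))
    return top


# Toy input with the column names plot_enrichment_dotplot expects (values are made up)
toy = pd.DataFrame({
    "Term": ["term_a", "term_b", "term_c", "term_d"],
    "Fold_Enrichment": [2.1, 3.4, 1.5, 4.0],
    "Overlap": [5, 12, 3, 8],
    "P_Value": [0.04, 0.0, 0.2, 1e-5],
})
table = prepare_dotplot_table(toy, max_terms=3)
print(table[["Term", "P_Value", "neg_log10_p"]])
```

With the toy table, the rows kept are `term_b`, `term_d` and `term_a`, and the p-value of 0 becomes 300 rather than `inf`.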

### 5.3 Composite Figure Creation

In [None]:
def create_composite_figure(
    panels: Dict[str, Tuple[callable, Dict]],
    layout: Tuple[int, int] = None,
    figsize: Tuple[float, float] = None,
    save_path: Path = None
) -> plt.Figure:
    """Create a composite figure with multiple panel types.
    
    Parameters
    ----------
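    panels : Dict[str, Tuple[callable, Dict]]
        Dictionary mapping panel labels (A, B, C...) to
        (plotting_function, kwargs) tuples. Each plotting function is
        called as ``plotting_function(ax=ax, **kwargs)``, so it must
        accept an ``ax`` keyword and draw into that axes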
    layout : Tuple[int, int], optional
        Grid layout (rows, cols)
    figsize : Tuple[float, float], optional
        Figure size in inches
    save_path : Path, optional
        Path to save the figure
    
    Returns
    -------
    plt.Figure
        The composite figure
    """
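    n_panels = len(panels)
    if n_panels == 0:
        raise ValueError("panels must contain at least one entry")
    
    if layout is None:
        n_cols = min(2, n_panels)
        n_rows = (n_panels + n_cols - 1) // n_cols
        layout = (n_rows, n_cols)
    elif layout[0] * layout[1] < n_panels:
        raise ValueError(f"Layout {layout} has fewer cells than the {n_panels} panels")
    
    if figsize is None:
        figsize = (7 * layout[1], 5 * layout[0])
    
    # squeeze=False always returns a 2-D array, so flatten() works for any layout
    fig, axes = plt.subplots(layout[0], layout[1], figsize=figsize, squeeze=False)
    axes = axes.flatten()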
    
    for idx, (label, (plot_func, kwargs)) in enumerate(panels.items()):
        ax = axes[idx]
        
        # Add panel label
        ax.text(-0.1, 1.1, label, transform=ax.transAxes,
                fontsize=14, fontweight='bold', va='top')
        
        # Execute plotting function
        try:
            plot_func(ax=ax, **kwargs)
        except Exception as e:
            ax.text(0.5, 0.5, f'Error: {str(e)[:50]}',
                    ha='center', va='center', transform=ax.transAxes)
    
    # Hide unused axes
    for idx in range(n_panels, len(axes)):
        axes[idx].set_visible(False)
    
    plt.tight_layout()
    
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Composite figure saved to: {save_path}")
    
    return fig


print("Composite figure function defined.")
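`create_composite_figure` calls each panel's function as `plot_func(ax=ax, **kwargs)`, so panel functions must draw into the axes they receive instead of creating a new figure. (`plot_enrichment_dotplot` above creates its own figure, so it does not fit this contract as written.) The sketch below shows a hypothetical panel function, `plot_beta_histogram`, and the same label-then-draw loop on random toy data, using `squeeze=False` so the axes array is always 2-D.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def plot_beta_histogram(ax, values, bins=20, color="#1f77b4"):
    """Panel function that draws into the `ax` it is given (no new figure)."""
    ax.hist(values, bins=bins, color=color, edgecolor="black", linewidth=0.5)
    ax.set_xlabel("Beta value")
    ax.set_ylabel("Count")


rng = np.random.default_rng(0)
# Hypothetical panels: label -> (plotting function, keyword arguments), toy data only
panels = {
    "A": (plot_beta_histogram, {"values": rng.uniform(0, 1, 500)}),
    "B": (plot_beta_histogram, {"values": rng.beta(2, 5, 500), "bins": 30}),
}

# squeeze=False keeps `axes` 2-D even for a single row or a single panel
fig, axes = plt.subplots(1, len(panels), figsize=(7, 3), squeeze=False)
for ax, (label, (plot_func, kwargs)) in zip(axes.flat, panels.items()):
    ax.text(-0.1, 1.1, label, transform=ax.transAxes,
            fontsize=14, fontweight="bold", va="top")
    plot_func(ax=ax, **kwargs)
fig.tight_layout()
```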

## 6. Pipeline Extensions

Integrate external tools and create custom analysis workflows.

### 6.1 External Tool Integration Template

In [None]:
class ExternalToolIntegration:
    """Template for integrating external bioinformatics tools."""
    
    def __init__(self, tool_path: str = None):
        """
        Parameters
        ----------
        tool_path : str, optional
            Path to the external tool executable
        """
        self.tool_path = tool_path
        self.results = None
    
    def prepare_input(
        self,
        data: pd.DataFrame,
        output_path: Path,
        format: str = 'tsv'
    ) -> Path:
        """Prepare input file for external tool.
        
        Parameters
        ----------
        data : pd.DataFrame
            Data to export
        output_path : Path
            Output file path
        format : str
            Output format ('tsv', 'csv', 'bed')
        
        Returns
        -------
        Path
            Path to the prepared file
        """
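        output_path = Path(output_path)
        
        if format == 'tsv':
            data.to_csv(output_path, sep='\t', index=False)
        elif format == 'csv':
            data.to_csv(output_path, index=False)
        elif format == 'bed':
            # BED requires chrom, start, end (start is 0-based, end exclusive);
            # name, score and strand are optional and are not written here
            required = ['chr', 'start', 'end']
            missing = [col for col in required if col not in data.columns]
            if missing:
                raise ValueError(f"BED export requires columns {required}; missing {missing}")
            data[required].to_csv(
                output_path, sep='\t', index=False, header=False
            )
        else:
            raise ValueError(f"Unsupported format: {format!r}")
        
        return output_path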
    
    def run_tool(self, input_path: Path, **kwargs) -> str:
        """Run the external tool.
        
        This is a template - override in subclass for specific tools.
        
        Parameters
        ----------
        input_path : Path
            Path to input file
        **kwargs
            Additional tool-specific arguments
        
        Returns
        -------
        str
            Tool output or path to output file
        """
        import subprocess
        
        if self.tool_path is None:
            raise ValueError("Tool path not specified")
        
        cmd = [self.tool_path, str(input_path)]
        
        # Add additional arguments
        for key, value in kwargs.items():
            cmd.extend([f'--{key}', str(value)])
        
        result = subprocess.run(cmd, capture_output=True, text=True)
        
        if result.returncode != 0:
            raise RuntimeError(f"Tool failed: {result.stderr}")
        
        return result.stdout
    
    def parse_output(
        self,
        output: str,
        format: str = 'tsv'
    ) -> pd.DataFrame:
        """Parse tool output into DataFrame.
        
        Parameters
        ----------
        output : str
            Tool output string or path to output file
        format : str
            Output format
        
        Returns
        -------
        pd.DataFrame
            Parsed results
        """
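        sep = {'tsv': '\t', 'csv': ','}.get(format)
        if sep is None:
            raise ValueError(f"Unsupported format: {format!r}")
        
        # Treat `output` as a file path only if it is a single line naming an
        # existing file; multi-line stdout is parsed directly as text
        candidate = output.strip()
        if '\n' not in candidate:
            try:
                output_path = Path(candidate)
                if output_path.is_file():
                    return pd.read_csv(output_path, sep=sep)
            except (OSError, ValueError):
                pass  # Not a usable path (e.g. name too long); parse as text
        
        # Parse string output
        from io import StringIO
        return pd.read_csv(StringIO(output), sep=sep)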


print("ExternalToolIntegration template defined.")
print("Subclass this for specific tools (HOMER, GREAT, etc.).")
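`run_tool` and `parse_output` together implement a run-then-parse round trip: launch a command with `subprocess.run`, fail on a non-zero exit code, and read tab-separated stdout into a DataFrame. The sketch below reproduces that round trip with a stand-in "tool" (a Python one-liner run through `sys.executable`) so it works on any machine; `run_and_parse` and `FAKE_TOOL` are hypothetical, and a real subclass would build the command line for the tool it wraps.

```python
import subprocess
import sys
from io import StringIO

import pandas as pd

# Stand-in "tool": a Python one-liner that writes a small TSV table to stdout
FAKE_TOOL = [sys.executable, "-c",
             "print('region\\tscore'); print('chr1:100-200\\t4.2')"]


def run_and_parse(cmd, sep="\t"):
    """Run `cmd`, raise on failure, and parse its stdout as a delimited table."""
    result = subprocess.run(cmd, capture_output=True, text=True, check=False)
    if result.returncode != 0:
        raise RuntimeError(f"Tool failed: {result.stderr}")
    return pd.read_csv(StringIO(result.stdout), sep=sep)


df = run_and_parse(FAKE_TOOL)
print(df)
```

Passing the command as a list (rather than a shell string) avoids shell quoting problems with file paths, which is also how the template builds `cmd`.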

### 6.2 Custom Analysis Pipeline

In [None]:
class CustomAnalysisPipeline:
    """Template for creating custom analysis pipelines."""
    
    def __init__(self, config: Dict[str, Any] = None):
        """
        Parameters
        ----------
        config : Dict[str, Any], optional
            Pipeline configuration
        """
        self.config = config or {}
        self.steps = []
        self.results = {}
        self.status = {}
    
    def add_step(
        self,
        name: str,
        function: callable,
        depends_on: List[str] = None,
        **kwargs
    ):
        """Add a step to the pipeline.
        
        Steps run in the order they are added, so every dependency must be
        added before the steps that use it. Each dependency's result is passed
        to ``function`` as a keyword argument named after that dependency; a
        step without dependencies receives the initial data as ``data``.
        
        Parameters
        ----------
        name : str
            Step name (used as key in results)
        function : callable
            Function to execute
        depends_on : List[str], optional
            Names of steps this depends on
        **kwargs
            Additional arguments for the function
        """
        self.steps.append({
            'name': name,
            'function': function,
            'depends_on': depends_on or [],
            'kwargs': kwargs
        })
    
    def run(self, initial_data: Any = None) -> Dict[str, Any]:
        """Execute the pipeline.
        
        Parameters
        ----------
        initial_data : Any, optional
            Initial data passed to steps that have no dependencies
        
        Returns
        -------
        Dict[str, Any]
            Results from all steps
        """
        self.results = {'initial': initial_data}
        self.status = {}
        
        for step in self.steps:
            name = step['name']
            print(f"Running step: {name}")
            
            # Gather inputs from dependencies (keyword = dependency name)
            inputs = {}
            for dep in step['depends_on']:
                if dep not in self.results:
                    raise ValueError(f"Dependency '{dep}' not found; add it before '{name}'")
                inputs[dep] = self.results[dep]
            
            # Skip the step if a dependency failed, instead of passing None on
            failed = [d for d in step['depends_on'] if self.status.get(d) in ('Failed', 'Skipped')]
            if failed:
                print(f"  Skipped {name}: dependencies did not succeed: {failed}")
                self.results[name] = None
                self.status[name] = 'Skipped'
                continue
            
            # Add initial data if no dependencies
            if not inputs and initial_data is not None:
                inputs['data'] = initial_data
            
            # Execute step
            try:
                self.results[name] = step['function'](**inputs, **step['kwargs'])
                self.status[name] = 'Success'
                print(f"  Completed: {name}")
            except Exception as e:
                print(f"  Error in {name}: {e}")
                self.results[name] = None
                self.status[name] = 'Failed'
        
        return self.results
    
    def get_result(self, step_name: str) -> Any:
        """Get result from a specific step."""
        return self.results.get(step_name)
    
    def summarize(self) -> pd.DataFrame:
        """Create summary of pipeline execution."""
        summary = []
        for step in self.steps:
            result = self.results.get(step['name'])
            summary.append({
                'Step': step['name'],
                # Status is tracked separately, so a step that returns None can still succeed
                'Status': self.status.get(step['name'], 'Not run'),
                'Result_Type': type(result).__name__ if result is not None else 'None'
            })
        return pd.DataFrame(summary)


print("CustomAnalysisPipeline template defined.")
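`CustomAnalysisPipeline.run` executes steps in the order they were added, so a step added before its dependency raises a `ValueError`. If you would rather add steps in any order, the standard library's `graphlib.TopologicalSorter` (Python 3.9+) can compute a valid execution order from the `depends_on` lists. This is a sketch of one possible extension, not something the class above does; `execution_order` is a hypothetical helper.

```python
from graphlib import CycleError, TopologicalSorter  # Python 3.9+


def execution_order(dependencies):
    """Order step names so each step comes after every step it depends on.

    `dependencies` maps step name -> iterable of step names it depends on,
    i.e. the same information as each step's `depends_on` list.
    """
    return list(TopologicalSorter(dependencies).static_order())


# The example steps used in the next cell, deliberately listed out of order
deps = {
    "step_summarize": ["step_select_top"],
    "step_select_top": ["step_filter_features"],
    "step_filter_features": [],
}
order = execution_order(deps)
print(order)
```

For these three steps the order is filter, then select-top, then summarize; a dependency cycle raises `graphlib.CycleError`, which is a clearer failure than a missing-dependency error halfway through a run.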

In [None]:
# Example: Create a custom pipeline

def step_filter_features(data: pd.DataFrame, variance_threshold: float = 0.01):
    """Example step: filter by variance."""
    variances = data.var(axis=1)
    return data.loc[variances > variance_threshold]

def step_select_top(step_filter_features: pd.DataFrame, n_top: int = 100):
    """Example step: select top variable features."""
    variances = step_filter_features.var(axis=1)
    top_features = variances.nlargest(n_top).index
    return step_filter_features.loc[top_features]

def step_summarize(step_select_top: pd.DataFrame):
    """Example step: create summary statistics."""
    return {
        'n_features': len(step_select_top),
        'mean_variance': step_select_top.var(axis=1).mean()
    }

# Build pipeline
pipeline = CustomAnalysisPipeline()
pipeline.add_step('step_filter_features', step_filter_features, variance_threshold=0.01)
pipeline.add_step('step_select_top', step_select_top, depends_on=['step_filter_features'], n_top=50)
pipeline.add_step('step_summarize', step_summarize, depends_on=['step_select_top'])

print("Example pipeline configured with 3 steps.")
print("Run with: pipeline.run(methylation_data)")
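When `pipeline.run(methylation_data)` executes, the first step receives the data as `data`, and each later step receives the previous result under a keyword named after the step it depends on (which is why `step_select_top` takes a parameter called `step_filter_features`). The sketch below performs the same three operations by hand on a toy probes-by-samples matrix, so the filtering logic can be checked without real data; `filter_by_variance` and `select_top_variable` are hypothetical copies of the example steps.

```python
import numpy as np
import pandas as pd


def filter_by_variance(data, variance_threshold=0.01):
    """Same logic as step_filter_features: keep probes with variance above the threshold."""
    return data.loc[data.var(axis=1) > variance_threshold]


def select_top_variable(data, n_top=100):
    """Same logic as step_select_top: keep the n_top most variable probes."""
    return data.loc[data.var(axis=1).nlargest(n_top).index]


rng = np.random.default_rng(0)
# Toy probes x samples matrix: 200 random probes plus 5 constant (zero-variance) probes
values = np.vstack([rng.uniform(0, 1, size=(200, 10)), np.full((5, 10), 0.5)])
toy = pd.DataFrame(values,
                   index=[f"probe_{i}" for i in range(205)],
                   columns=[f"sample_{j}" for j in range(10)])

filtered = filter_by_variance(toy)              # result stored as 'step_filter_features'
top = select_top_variable(filtered, n_top=50)   # result stored as 'step_select_top'
summary = {"n_features": len(top), "mean_variance": top.var(axis=1).mean()}
print(summary)
```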

## 7. Putting It All Together

This section combines the techniques above into a complete custom analysis workflow.

In [None]:
def run_complete_custom_analysis(
    methylation_data: pd.DataFrame,
    sample_mapping: pd.DataFrame,
    feature_config: CustomFeatureConfig = None,
    output_dir: Path = None
) -> Dict[str, Any]:
    """Run a complete custom analysis workflow.
    
    This demonstrates how to combine custom configurations,
    classifiers, and visualizations into a cohesive analysis.
    
    Parameters
    ----------
    methylation_data : pd.DataFrame
        Preprocessed methylation data (probes x samples)
    sample_mapping : pd.DataFrame
        Sample metadata
    feature_config : CustomFeatureConfig, optional
        Custom feature selection configuration
    output_dir : Path, optional
        Directory to save outputs
    
    Returns
    -------
    Dict[str, Any]
        Analysis results
    """
    results = {}
    
    # Step 1: Prepare data
    print("Step 1: Preparing data...")
    sample_ids = methylation_data.columns.tolist()
    sample_info = sample_mapping.set_index('sample_id').loc[sample_ids].reset_index()
    
    binary_mask = sample_info['binary_class'].isin(['HIIT', 'Control'])
    binary_samples = sample_info[binary_mask]['sample_id'].tolist()
    binary_labels = (sample_info[binary_mask]['binary_class'] == 'HIIT').astype(int).values
    batch = sample_info[binary_mask]['study_group'].values
    
    results['n_samples'] = len(binary_samples)
    results['n_features_initial'] = methylation_data.shape[0]
    
    # Step 2: Custom feature selection
    print("Step 2: Running custom feature selection...")
    if feature_config is None:
        feature_config = CustomFeatureConfig()
    
    feature_names = methylation_data.index.tolist()
    X = methylation_data[binary_samples].T.values
    y = binary_labels
    
    selector = CustomFeatureSelector(feature_config)
    selection_results = selector.select_features(X, y, feature_names)
    
    results['features_statistical'] = len(selection_results['statistical'])
    results['features_ml'] = len(selection_results['ml_based'])
    results['features_consensus'] = len(selection_results['consensus'])
    
    # Step 3: Train and evaluate classifiers
    print("Step 3: Training classifiers...")
    
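    # Use consensus features; fall back to the first 100 statistical features
    # (this slice is only a "top 100" if selection_results['statistical'] is ranked)
    selected_features = list(selection_results['consensus'])
    if len(selected_features) == 0:
        selected_features = list(selection_results['statistical'])[:100]
    
    available_features = [f for f in selected_features if f in methylation_data.index]
    if not available_features:
        raise ValueError("None of the selected features are present in methylation_data")
    results['n_features_used'] = len(available_features)
    X_selected = methylation_data.loc[available_features, binary_samples].T.values
    
    # Note: features were selected on all samples in Step 2, before cross-validation,
    # so the CV scores below are optimistically biased. For an unbiased estimate,
    # repeat the feature selection inside each training fold.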
    
    # Compare classifiers
    classifiers = {
        'Gradient Boosting': CustomClassifierWrapper(
            GradientBoostingClassifier(n_estimators=50, random_state=42),
            batch_handling='covariate'
        ),
        'MLP': CustomClassifierWrapper(
            MLPClassifier(hidden_layer_sizes=(50,), max_iter=300, random_state=42),
            batch_handling='covariate'
        )
    }
    
    comparison = compare_classifiers(classifiers, X_selected, y, batch=batch, cv=5)
    results['classifier_comparison'] = comparison
    
    # Step 4: Save outputs
    if output_dir is not None:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        
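        # Save feature lists: the consensus set and the features actually used
        # for classification (these differ when the statistical fallback is used)
        pd.DataFrame({'probe_id': list(selection_results['consensus'])}).to_csv(
            output_dir / 'consensus_features.csv', index=False
        )
        pd.DataFrame({'probe_id': available_features}).to_csv(
            output_dir / 'classifier_features.csv', index=False
        )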
        
        # Save classifier comparison
        comparison.to_csv(output_dir / 'classifier_comparison.csv', index=False)
        
        print(f"Results saved to: {output_dir}")
    
    print("\nAnalysis complete!")
    return results


print("Complete custom analysis function defined.")
print("Run with: run_complete_custom_analysis(methylation_data, sample_mapping)")
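The workflow above selects features on all samples and then cross-validates the classifiers on those same samples. Because the test folds influenced which features were chosen, the resulting scores are optimistically biased. A common remedy is to put the selection step inside a scikit-learn `Pipeline`, so `cross_val_score` refits it on each training fold only. The sketch below contrasts the two approaches on pure-noise toy data; `SelectKBest` with `f_classif` is a stand-in for `CustomFeatureSelector`, not a reimplementation of it.

```python
import numpy as np
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(42)
# Toy data: 40 samples x 1000 pure-noise features, balanced labels (no real signal)
X = rng.normal(size=(40, 1000))
y = np.array([0, 1] * 20)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# In-fold selection: the Pipeline refits SelectKBest on each training fold
honest_model = Pipeline([
    ("select", SelectKBest(f_classif, k=50)),
    ("scale", StandardScaler()),
    ("clf", LogisticRegression(max_iter=1000)),
])
honest = cross_val_score(honest_model, X, y, cv=cv, scoring="roc_auc")

# Select-then-CV: features chosen on all samples before cross-validation
X_leaky = SelectKBest(f_classif, k=50).fit_transform(X, y)
leaky_model = Pipeline([
    ("scale", StandardScaler()),
    ("clf", LogisticRegression(max_iter=1000)),
])
leaky = cross_val_score(leaky_model, X_leaky, y, cv=cv, scoring="roc_auc")

print(f"in-fold selection AUC: {honest.mean():.2f}")
print(f"select-then-CV AUC:    {leaky.mean():.2f}")
```

On noise, the in-fold pipeline should score near chance (AUC around 0.5), while selecting first typically looks much better; that gap is the bias to watch for when reading the classifier comparison.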

## Summary

This notebook demonstrated advanced customization options for the HIIT methylation analysis pipeline:

### Key Customizations Covered

1. **Feature Selection**
   - Custom stringency configurations
   - Region-based filtering (promoters, gene bodies)
   - Multi-method consensus selection

2. **Classifiers**
   - Wrapper for integrating new models
   - Gradient Boosting, AdaBoost, MLP examples
   - Batch-aware training support

3. **Enrichment Analysis**
   - Custom gene set creation
   - GMT file loading
   - Multi-database analysis

4. **Visualizations**
   - Publication-quality styling
   - Multi-panel figures
   - Enrichment dot plots

5. **Pipeline Extensions**
   - External tool integration template
   - Custom analysis pipelines

### Next Steps

- Adapt these templates to your specific research questions
- Add domain-specific gene sets for enrichment
- Integrate additional external tools as needed
- Create custom visualization styles for your publications

In [None]:
# Session summary
print("=" * 60)
print("CUSTOM ANALYSIS NOTEBOOK READY")
print("=" * 60)
print("\nCustomization options demonstrated:")
print("  - CustomFeatureConfig: Flexible feature selection")
print("  - CustomClassifierWrapper: New model integration")
print("  - CustomGeneSetAnalyzer: Custom enrichment")
print("  - PublicationStyle: Figure formatting")
print("  - CustomAnalysisPipeline: Workflow automation")
print("\nRefer to the main pipeline notebooks (01-05) for")
print("standard analysis workflows.")