# Data Preprocessing

## Overview

This notebook demonstrates the preprocessing pipeline for DNA methylation data, preparing it for feature selection and classification. Proper preprocessing is critical for obtaining reliable and reproducible results in methylation-based biomarker discovery.

### Learning Objectives

By the end of this notebook, you will be able to:

1. Assess data quality and identify issues
2. Filter low-variance CpG probes
3. Handle missing values appropriately
4. Detect and correct batch effects
5. Create multiple preprocessing versions for robust analysis

### Prerequisites

This notebook assumes you have completed **01_data_acquisition.ipynb** and have:
- Downloaded the GSE171140 series matrix file
- Created the sample mapping file

## 1. Environment Setup

In [None]:
# Standard library imports
import sys
import logging
from pathlib import Path

# Scientific computing
import numpy as np
import pandas as pd

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Add project root to path
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

# Project-specific imports
from src.data.loader import GEODataLoader
from src.data.preprocessing import (
    MethylationPreprocessor,
    normalize_beta_values,
    calculate_missing_rate
)
from src.visualization import (
    plot_pca_visualization,
    plot_heatmap
)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Visualization settings
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('colorblind')

print(f"Project root: {project_root}")

## 2. Load Raw Data

In [None]:
# Define paths
data_dir = project_root / 'data' / 'raw'
processed_dir = project_root / 'data' / 'processed'
figures_dir = project_root / 'data' / 'figures' / 'qc'

# Create output directories
processed_dir.mkdir(parents=True, exist_ok=True)
figures_dir.mkdir(parents=True, exist_ok=True)

# Load methylation data
loader = GEODataLoader('GSE171140', data_dir=data_dir)
methylation_data = loader.load_methylation_matrix()

# Load sample mapping
sample_mapping = pd.read_csv(data_dir / 'GSE171140_sample_mapping.csv')

print(f"Methylation data shape: {methylation_data.shape}")
print(f"Sample mapping shape: {sample_mapping.shape}")

## 3. Initial Data Quality Assessment

Before preprocessing, we need to understand the quality characteristics of our data:

- **Beta value range**: Should be [0, 1]
- **Missing value patterns**: How many and where?
- **Variance distribution**: Identify uninformative probes
- **Sample-to-sample correlations**: Identify outliers
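The first three checks reduce to a few pandas operations on a probes x samples matrix. Below is a minimal sketch, assuming rows are probes and columns are samples (as in `methylation_data`). `quick_qc_summary` is a hypothetical helper, not part of `src`.

```python
import numpy as np
import pandas as pd


def quick_qc_summary(beta: pd.DataFrame) -> dict:
    """Range, missingness, and spread checks for a probes x samples beta matrix.

    Hypothetical helper for illustration; not part of src.data.
    """
    values = beta.to_numpy(dtype=float)
    observed = values[~np.isnan(values)]  # drop NaNs before range checks
    return {
        # Beta value range over observed (non-missing) values
        "min": float(observed.min()),
        "max": float(observed.max()),
        "n_out_of_range": int(((observed < 0) | (observed > 1)).sum()),
        # Fraction of samples missing for each probe (row)
        "probe_missing_rate": beta.isna().mean(axis=1),
        # Per-probe standard deviation across samples (NaNs skipped, ddof=1)
        "probe_std": beta.std(axis=1),
    }
```

The project's `calculate_missing_rate(..., axis=1)` presumably returns the same per-probe fraction as `beta.isna().mean(axis=1)`; its docstring is the authority on that.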

In [None]:
# Basic data quality statistics
print("=" * 60)
print("DATA QUALITY SUMMARY")
print("=" * 60)

# Beta value range check
print(f"\nBeta Value Range:")
print(f"  Minimum: {methylation_data.min().min():.6f}")
print(f"  Maximum: {methylation_data.max().max():.6f}")
print(f"  Mean: {methylation_data.mean().mean():.6f}")

# Missing value assessment
total_missing = methylation_data.isna().sum().sum()
total_values = methylation_data.size
missing_pct = (total_missing / total_values) * 100

print(f"\nMissing Values:")
print(f"  Total missing: {total_missing:,}")
print(f"  Total values: {total_values:,}")
print(f"  Missing percentage: {missing_pct:.4f}%")

In [None]:
# Probe-level statistics
probe_missing_rate = calculate_missing_rate(methylation_data, axis=1)
probe_variance = methylation_data.var(axis=1)

print("\nProbe-level Statistics:")
print(f"  Probes with any missing: {(probe_missing_rate > 0).sum():,}")
print(f"  Probes with >10% missing: {(probe_missing_rate > 0.1).sum():,}")
print(f"  Probes with >20% missing: {(probe_missing_rate > 0.2).sum():,}")

print(f"\nVariance Statistics:")
print(f"  Min variance: {probe_variance.min():.6f}")
print(f"  Max variance: {probe_variance.max():.6f}")
print(f"  Median variance: {probe_variance.median():.6f}")

In [None]:
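# Visualize probe standard deviation and missing-rate distributions
# The filter in Section 6 uses a standard-deviation threshold (0.02), so plot std
# rather than variance (std = 0.02 corresponds to variance = 0.0004)
probe_std = methylation_data.std(axis=1)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Standard deviation histogram (NaNs dropped so the histogram range is finite)
axes[0].hist(probe_std.dropna(), bins=100, edgecolor='none', alpha=0.7)
axes[0].axvline(x=0.02, color='red', linestyle='--', label='Threshold (std = 0.02)')
axes[0].set_xlabel('Probe Standard Deviation')
axes[0].set_ylabel('Frequency')
axes[0].set_title('Distribution of Probe Standard Deviation')
axes[0].legend()
axes[0].set_xlim(0, 0.3)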

# Missing rate histogram
axes[1].hist(probe_missing_rate[probe_missing_rate > 0], bins=50, edgecolor='none', alpha=0.7)
axes[1].axvline(x=0.2, color='red', linestyle='--', label='Threshold (20%)')
axes[1].set_xlabel('Missing Rate per Probe')
axes[1].set_ylabel('Frequency')
axes[1].set_title('Distribution of Missing Rates')
axes[1].legend()

plt.tight_layout()
plt.savefig(figures_dir / 'qc_probe_statistics.png', dpi=150, bbox_inches='tight')
plt.show()

## 4. Normalize Beta Values

Beta values represent the methylated fraction of a probe's signal, so valid values lie in [0, 1]. Values outside this range can arise from technical artifacts (for example, during normalization) and are clipped to the nearest bound.
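If `normalize_beta_values` performs a plain clip (an assumption; its actual behavior is defined in `src.data.preprocessing`), the step is equivalent to `DataFrame.clip`. The sketch below also counts how many values were changed, which is worth logging: a large count may point to an upstream normalization problem rather than minor artifacts. `clip_beta_values` is a hypothetical name.

```python
import pandas as pd


def clip_beta_values(beta: pd.DataFrame, lower: float = 0.0, upper: float = 1.0):
    """Clip beta values into [lower, upper] and report how many were changed.

    Hypothetical sketch of a clipping step; not the src implementation.
    """
    # Mask of values outside the valid range (comparisons with NaN are False)
    outside = (beta < lower) | (beta > upper)
    n_clipped = int(outside.to_numpy().sum())
    # DataFrame.clip leaves NaN untouched, so missing values survive for imputation
    clipped = beta.clip(lower=lower, upper=upper)
    return clipped, n_clipped
```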

In [None]:
# Normalize beta values to valid range
methylation_normalized = normalize_beta_values(
    methylation_data,
    clip_range=(0.0, 1.0)
)

print("After normalization:")
print(f"  Min: {methylation_normalized.min().min():.6f}")
print(f"  Max: {methylation_normalized.max().max():.6f}")

## 5. Initialize the Preprocessor

The `MethylationPreprocessor` class provides a unified interface for all preprocessing operations with configurable thresholds.

In [None]:
# Initialize preprocessor with configurable thresholds
# std_threshold: probes with std < threshold are removed
# missing_threshold: probes with missing rate > threshold are removed

preprocessor = MethylationPreprocessor(
    std_threshold=0.02,      # Remove probes with std < 0.02
    missing_threshold=0.2    # Remove probes with >20% missing
)

print("Preprocessor configuration:")
print(f"  Standard deviation threshold: {preprocessor.std_threshold}")
print(f"  Missing value threshold: {preprocessor.missing_threshold}")

## 6. Filter Low-Variance Probes

CpG sites with low variance across samples provide little discriminative information and add noise. We remove probes whose standard deviation across samples falls below `std_threshold` (0.02).
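Because the threshold is on the standard deviation, std = 0.02 is equivalent to a variance cutoff of 0.02² = 0.0004. A minimal sketch of the filter with plain pandas, assuming a probes x samples matrix; `filter_by_std` is hypothetical and the project's `filter_low_variance` may differ in details such as `ddof` or NaN handling.

```python
import pandas as pd


def filter_by_std(beta: pd.DataFrame, std_threshold: float = 0.02):
    """Keep probes (rows) whose across-sample standard deviation >= std_threshold.

    Hypothetical sketch; returns the filtered matrix and per-probe statistics.
    """
    # Per-probe summary statistics; NaNs are skipped and ddof=1 (pandas default)
    stats = pd.DataFrame({"mean": beta.mean(axis=1), "std": beta.std(axis=1)})
    # Probes with an undefined std (NaN) compare False and are dropped as well
    keep = stats["std"] >= std_threshold
    return beta.loc[keep], stats
```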

In [None]:
# Filter low variance probes
print(f"\nOriginal probe count: {methylation_normalized.shape[0]:,}")

filtered_data, probe_stats = preprocessor.filter_low_variance(
    methylation_normalized,
    threshold=preprocessor.std_threshold,  # reuse the configured value (0.02)
    return_stats=True
)

n_removed = methylation_normalized.shape[0] - filtered_data.shape[0]
pct_removed = (n_removed / methylation_normalized.shape[0]) * 100

print(f"\nAfter variance filtering:")
print(f"  Remaining probes: {filtered_data.shape[0]:,}")
print(f"  Removed probes: {n_removed:,} ({pct_removed:.1f}%)")

In [None]:
# Visualize probe statistics
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Mean vs Std plot (random subset of probes, seeded for reproducibility)
rng = np.random.default_rng(42)
sample_idx = rng.choice(len(probe_stats), size=min(10000, len(probe_stats)), replace=False)
axes[0].scatter(
    probe_stats.iloc[sample_idx]['mean'],
    probe_stats.iloc[sample_idx]['std'],
    alpha=0.3, s=1
)
axes[0].axhline(y=preprocessor.std_threshold, color='red', linestyle='--',
                label=f'Std threshold ({preprocessor.std_threshold})')
axes[0].set_xlabel('Mean Beta Value')
axes[0].set_ylabel('Standard Deviation')
axes[0].set_title('Mean-Standard Deviation Relationship')
axes[0].legend()

# Retained vs removed comparison (NaN stds dropped so histograms have a finite range)
retained_mask = probe_stats['std'] >= preprocessor.std_threshold
axes[1].hist(probe_stats.loc[~retained_mask, 'std'].dropna(), bins=50, alpha=0.5, label='Removed', color='red')
axes[1].hist(probe_stats.loc[retained_mask, 'std'], bins=50, alpha=0.5, label='Retained', color='green')
axes[1].set_xlabel('Standard Deviation')
axes[1].set_ylabel('Frequency')
axes[1].set_title('Standard Deviation by Filtering Status')
axes[1].legend()

plt.tight_layout()
plt.savefig(figures_dir / 'qc_variance_filtering.png', dpi=150, bbox_inches='tight')
plt.show()

## 7. Handle Missing Values

Missing values in methylation data can occur due to detection failures or quality filtering. We use a two-step approach:

1. **Remove probes with high missing rates** (>20%)
2. **Impute remaining missing values** using probe median
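The two steps can be sketched with plain pandas as follows, assuming a probes x samples matrix and that a probe's missing rate is the fraction of samples with NaN. `drop_then_impute_median` is a hypothetical helper; `handle_missing_values` in `src` may implement the steps differently.

```python
import pandas as pd


def drop_then_impute_median(beta: pd.DataFrame, drop_threshold: float = 0.2) -> pd.DataFrame:
    """Two-step missing-value handling for a probes x samples matrix.

    Hypothetical sketch, not the project implementation.
    """
    # Step 1: drop probes whose fraction of missing samples exceeds the threshold
    missing_rate = beta.isna().mean(axis=1)
    kept = beta.loc[missing_rate <= drop_threshold]
    # Step 2: fill each probe's remaining NaNs with that probe's own median
    probe_median = kept.median(axis=1)
    # Transpose so probes become columns; fillna(Series) fills by column label
    return kept.T.fillna(probe_median).T
```

Vectorizing through the transpose avoids a row-by-row `apply`, which would be slow on hundreds of thousands of probes.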

In [None]:
# Check missing values after variance filtering
missing_before = filtered_data.isna().sum().sum()
print(f"Missing values before imputation: {missing_before:,}")

# Handle missing values
imputed_data = preprocessor.handle_missing_values(
    filtered_data,
    strategy='median',  # Options: 'median', 'mean', 'knn', 'drop'
    drop_threshold=preprocessor.missing_threshold  # reuse the configured value (0.2)
)

missing_after = imputed_data.isna().sum().sum()
print(f"Missing values after imputation: {missing_after}")
print(f"\nFinal probe count: {imputed_data.shape[0]:,}")

## 8. Batch Effect Detection

Batch effects are systematic technical variations that can confound biological signals. Potential batch effects can be detected by:

1. Examining PCA plots colored by study group
2. Comparing sample correlations within and between batches
3. Testing for significant differences between batch groups

The PCA plots below address the first point; the other two are complementary checks.
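Comparing correlations within and between batches can be sketched as follows. If samples from the same batch correlate noticeably more strongly with each other than with samples from other batches, that is a sign of batch structure. `within_between_correlation` is a hypothetical helper; it assumes a probes x samples matrix and a batch Series indexed by sample ID (like `batch_info` below).

```python
import numpy as np
import pandas as pd


def within_between_correlation(beta: pd.DataFrame, batch: pd.Series) -> tuple[float, float]:
    """Mean Pearson correlation of sample pairs within vs between batches.

    beta: probes x samples; batch: batch label per sample, indexed by sample ID.
    Hypothetical helper, not part of the project code.
    """
    corr = beta.corr().to_numpy()              # samples x samples
    labels = batch.loc[beta.columns].to_numpy()
    same = labels[:, None] == labels[None, :]  # True where both samples share a batch
    # Count each unordered pair once: upper triangle, diagonal excluded
    upper = np.triu(np.ones_like(same, dtype=bool), k=1)
    within = corr[same & upper].mean()
    between = corr[~same & upper].mean()
    return float(within), float(between)
```

Usage in this notebook would look like `within_between_correlation(imputed_data, batch_info)`.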

In [None]:
# Create batch information from sample mapping
# Here 'study_group' is used as a proxy for batch. If it encodes a biological
# condition rather than a technical run, grouping by it will also reflect biology
sample_ids = imputed_data.columns.tolist()
batch_info = sample_mapping.set_index('sample_id').loc[sample_ids, 'study_group']

print("Batch distribution:")
print(batch_info.value_counts())

In [None]:
# PCA visualization to detect batch effects
# Transpose data: samples as rows, probes as columns
data_for_pca = imputed_data.T

# Create sample labels for coloring
sample_labels = batch_info.values

# Perform PCA visualization
fig, ax, pca_result = plot_pca_visualization(
    data_for_pca,
    sample_labels,
    title='PCA of Methylation Data by Study Group',
    figsize=(10, 8)
)

plt.savefig(figures_dir / 'pca_batch_effect.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nPCA variance explained:")
print(f"  PC1: {pca_result.explained_variance_ratio_[0]*100:.1f}%")
print(f"  PC2: {pca_result.explained_variance_ratio_[1]*100:.1f}%")
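The third check (testing for differences between batch groups) can be made concrete with a rank-based test on PC1 scores. This sketch computes PC1 from the eigendecomposition of the samples x samples Gram matrix, which gives the same scores as PCA on the centered data without decomposing every probe. It assumes `scipy` is installed and uses `scipy.stats.kruskal`; `pc1_batch_test` is a hypothetical helper, independent of `plot_pca_visualization`.

```python
import numpy as np
import pandas as pd
from scipy.stats import kruskal


def pc1_batch_test(beta: pd.DataFrame, batch: pd.Series):
    """Kruskal-Wallis test of PC1 sample scores across batch groups.

    beta: probes x samples with no missing values (e.g. imputed_data).
    batch: batch label per sample, indexed by sample ID.
    Hypothetical helper, not part of the project code.
    """
    X = beta.T.to_numpy(dtype=float)  # samples x probes
    X = X - X.mean(axis=0)            # center each probe
    # X @ X.T = U S^2 U^T, so the top eigenvector scaled by sqrt(eigenvalue)
    # gives the PC1 scores; eigh returns eigenvalues in ascending order
    eigvals, eigvecs = np.linalg.eigh(X @ X.T)
    pc1 = eigvecs[:, -1] * np.sqrt(max(eigvals[-1], 0.0))
    labels = batch.loc[beta.columns].to_numpy()
    groups = [pc1[labels == g] for g in np.unique(labels)]
    stat, p_value = kruskal(*groups)
    return pc1, float(stat), float(p_value)
```

The sign of a principal component is arbitrary, but the Kruskal-Wallis statistic is unchanged when all scores are negated, so the test result does not depend on it.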

In [None]:
# Color by timepoint to compare with batch structure
timepoint_labels = sample_mapping.set_index('sample_id').loc[sample_ids, 'time_point'].values

fig, ax, _ = plot_pca_visualization(
    data_for_pca,
    timepoint_labels,
    title='PCA of Methylation Data by Timepoint',
    figsize=(10, 8)
)

plt.savefig(figures_dir / 'pca_timepoint.png', dpi=150, bbox_inches='tight')
plt.show()

## 9. Batch Effect Correction (Optional)

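If significant batch effects are detected, we can apply a correction method. The preprocessor supports:

- **Median centering**: A simple approach that shifts each probe within each batch so that the batch median matches that probe's global median
- **ComBat**: An empirical Bayes method that adjusts batch-specific location and scale (requires the pycombat package)

Correct only for technical variables. If the batch variable is confounded with the biological groups being compared, correction will remove biological signal along with the batch effect.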
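Median centering as described above can be sketched with a pandas `groupby(...).transform("median")`. This is a hypothetical helper (`median_center_batches`), not the preprocessor's `apply_batch_correction`; it assumes a probes x samples matrix and a batch Series indexed by sample ID.

```python
import pandas as pd


def median_center_batches(beta: pd.DataFrame, batch: pd.Series) -> pd.DataFrame:
    """Shift each batch so that, per probe, its median equals the global median.

    beta: probes x samples; batch: batch label per sample, indexed by sample ID.
    Hypothetical sketch of median centering.
    """
    samples_by_probe = beta.T                      # samples x probes
    labels = batch.loc[samples_by_probe.index]
    # For every sample, the median of its own batch for each probe
    batch_median = samples_by_probe.groupby(labels).transform("median")
    global_median = beta.median(axis=1)            # one value per probe
    # Adding a probe-indexed Series aligns on the columns (probes)
    corrected = samples_by_probe - batch_median + global_median
    return corrected.T
```

Shifted values can fall slightly outside [0, 1]; clipping again afterwards is one option if downstream steps require valid beta values.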

In [None]:
# Apply batch correction if needed
# Uncomment if batch effects are significant in your analysis

# batch_corrected = preprocessor.apply_batch_correction(
#     imputed_data,
#     batch_info,
#     method='median_centering'  # Options: 'median_centering', 'combat'
# )

# For this example, we'll continue without batch correction
# The decision depends on the magnitude of batch effects relative to biological signal

print("Batch correction: Not applied (optional step)")
print("Evaluate PCA plots to determine if correction is needed.")

## 10. Create Multiple Data Versions

For robust analysis, we create multiple versions of the preprocessed data:

1. **Original**: Filtered and imputed data
2. **Standardized**: Z-score normalized per probe
3. **Batch-corrected**: With median centering (if applicable)

Different versions may perform better for different classification tasks.
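Per-probe z-scoring subtracts each probe's mean across samples and divides by its standard deviation. A minimal sketch, assuming a probes x samples matrix; `standardize_probes` is hypothetical, and `create_data_versions` may standardize differently.

```python
import numpy as np
import pandas as pd


def standardize_probes(beta: pd.DataFrame) -> pd.DataFrame:
    """Z-score each probe (row) across samples: (x - mean) / std.

    Hypothetical sketch, not the project implementation.
    """
    mean = beta.mean(axis=1)
    # Constant probes would divide by zero; map their std to NaN instead
    std = beta.std(axis=1).replace(0, np.nan)
    return beta.sub(mean, axis=0).div(std, axis=0)
```

When these versions are later split into training and test sets, estimating the mean and standard deviation on the training samples only avoids leaking test information into the model.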

In [None]:
# Create multiple data versions
data_versions = preprocessor.create_data_versions(
    imputed_data,
    batch_info=batch_info
)

print("Created data versions:")
for version_name, version_data in data_versions.items():
    print(f"  {version_name}: {version_data.shape}")

In [None]:
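# Compare data distributions across versions
# Size the grid from the number of versions instead of assuming exactly four
n_versions = len(data_versions)
n_cols = 2
n_rows = int(np.ceil(n_versions / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 5 * n_rows), squeeze=False)
rng = np.random.default_rng(42)

for ax, (version_name, version_data) in zip(axes.flat, data_versions.items()):
    # Sample a subset of non-missing values for visualization
    sample_values = version_data.to_numpy(dtype=float).ravel()
    sample_values = sample_values[~np.isnan(sample_values)]
    sample_values = rng.choice(sample_values, size=min(100000, len(sample_values)), replace=False)

    ax.hist(sample_values, bins=100, edgecolor='none', alpha=0.7)
    ax.set_xlabel('Value')
    ax.set_ylabel('Frequency')
    ax.set_title(version_name.replace('_', ' ').title())

# Hide unused panels (e.g., three versions on a 2 x 2 grid)
for ax in axes.flat[n_versions:]:
    ax.set_visible(False)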

plt.tight_layout()
plt.savefig(figures_dir / 'data_version_distributions.png', dpi=150, bbox_inches='tight')
plt.show()

## 11. Sample Quality Control

Identify potential outlier samples that may need exclusion from downstream analysis.
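Besides summary statistics, the sample-to-sample correlations listed among the QC checks in Section 3 offer another outlier signal: a sample that correlates poorly with all the others often reflects a failed array. This sketch applies the same Tukey lower fence (Q1 - 1.5 x IQR) used below, but to each sample's mean correlation with the other samples. `correlation_outliers` is a hypothetical helper.

```python
import numpy as np
import pandas as pd


def correlation_outliers(beta: pd.DataFrame, k: float = 1.5) -> pd.DataFrame:
    """Flag samples whose mean correlation with the other samples is unusually low.

    beta: probes x samples. Uses a Tukey lower fence, Q1 - k * IQR.
    Hypothetical helper, not part of the project code.
    """
    corr = beta.corr().to_numpy(copy=True)  # samples x samples, Pearson
    np.fill_diagonal(corr, np.nan)          # ignore self-correlation
    mean_corr = pd.Series(np.nanmean(corr, axis=1), index=beta.columns)
    q1, q3 = mean_corr.quantile([0.25, 0.75])
    lower_fence = q1 - k * (q3 - q1)
    return pd.DataFrame({"mean_corr": mean_corr, "outlier": mean_corr < lower_fence})
```

In this notebook the call would be `correlation_outliers(imputed_data)`.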

In [None]:
# Calculate sample-level statistics
sample_stats = pd.DataFrame({
    'sample_id': imputed_data.columns,
    'mean': imputed_data.mean(axis=0),
    'std': imputed_data.std(axis=0),
    'median': imputed_data.median(axis=0)
})

# Add sample information
sample_stats = sample_stats.merge(
    sample_mapping[['sample_id', 'time_point', 'binary_class']],
    on='sample_id'
)

print("Sample statistics summary:")
print(sample_stats[['mean', 'std', 'median']].describe())

In [None]:
# Identify potential outliers using IQR method
Q1 = sample_stats['mean'].quantile(0.25)
Q3 = sample_stats['mean'].quantile(0.75)
IQR = Q3 - Q1

lower_bound = Q1 - 1.5 * IQR
upper_bound = Q3 + 1.5 * IQR

outliers = sample_stats[
    (sample_stats['mean'] < lower_bound) | 
    (sample_stats['mean'] > upper_bound)
]

print(f"\nPotential outlier samples: {len(outliers)}")
if len(outliers) > 0:
    print(outliers[['sample_id', 'mean', 'time_point']])

In [None]:
# Visualize sample statistics by group
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Boxplot of mean by timepoint
sample_stats.boxplot(column='mean', by='time_point', ax=axes[0])
axes[0].set_xlabel('Timepoint')
axes[0].set_ylabel('Mean Beta Value')
axes[0].set_title('Sample Mean by Timepoint')
plt.suptitle('')

# Mean vs Std scatter
for group in sample_stats['binary_class'].unique():
    mask = sample_stats['binary_class'] == group
    axes[1].scatter(
        sample_stats.loc[mask, 'mean'],
        sample_stats.loc[mask, 'std'],
        label=group, alpha=0.7
    )
axes[1].set_xlabel('Mean Beta Value')
axes[1].set_ylabel('Standard Deviation')
axes[1].set_title('Sample Mean vs Std by Class')
axes[1].legend()

plt.tight_layout()
plt.savefig(figures_dir / 'sample_qc.png', dpi=150, bbox_inches='tight')
plt.show()

## 12. Save Preprocessed Data

Save all preprocessed data versions for use in feature selection and classification.

In [None]:
import pickle
import json

# Save the primary preprocessed data
preprocessed_path = processed_dir / 'methyl_data_preprocessed.pkl'
with open(preprocessed_path, 'wb') as f:
    pickle.dump(imputed_data, f)

print(f"Preprocessed data saved to: {preprocessed_path}")

# Save all versions
versions_path = processed_dir / 'methyl_data_versions.pkl'
with open(versions_path, 'wb') as f:
    pickle.dump(data_versions, f)

print(f"All data versions saved to: {versions_path}")

In [None]:
# Save preprocessing configuration for reproducibility
preprocessing_config = {
    'std_threshold': preprocessor.std_threshold,
    'missing_threshold': preprocessor.missing_threshold,
    'original_probes': methylation_data.shape[0],
    'filtered_probes': imputed_data.shape[0],
    'n_samples': imputed_data.shape[1],
    'data_versions': list(data_versions.keys())
}

config_path = processed_dir / 'preprocessing_config.json'
with open(config_path, 'w') as f:
    json.dump(preprocessing_config, f, indent=2)

print(f"Configuration saved to: {config_path}")

## Summary

In this notebook, we completed the following preprocessing steps:

1. **Quality Assessment**: Evaluated raw data for missing values and variance distribution
2. **Beta Value Normalization**: Clipped values to [0, 1] range
3. **Variance Filtering**: Removed probes with a standard deviation below 0.02 across samples
4. **Missing Value Handling**: Dropped probes missing in more than 20% of samples and imputed the remaining gaps with the probe median
5. **Batch Effect Detection**: Visualized potential batch effects using PCA
6. **Multi-Version Creation**: Created original, standardized, and batch-corrected versions
7. **Sample QC**: Identified potential outlier samples

### Next Steps

Continue to **03_feature_selection.ipynb** to:
- Apply the Ten-Level Feature Selection Framework
- Select binary classification features (HIIT vs Control)
- Select multiclass classification features (4W/8W/12W)
- Analyze time-series features

In [None]:
# Session summary
print("=" * 60)
print("PREPROCESSING COMPLETE")
print("=" * 60)
print(f"\nOriginal probes: {methylation_data.shape[0]:,}")
print(f"Final probes: {imputed_data.shape[0]:,}")
print(f"Reduction: {(1 - imputed_data.shape[0]/methylation_data.shape[0])*100:.1f}%")
print(f"\nSamples: {imputed_data.shape[1]}")
print(f"Data versions created: {len(data_versions)}")
print(f"\nOutput files:")
print(f"  - {preprocessed_path}")
print(f"  - {versions_path}")
print(f"  - {config_path}")