# EMS Prediction

This notebook demonstrates how to apply trained sc-EMS models to score your own genetic variants for functional impact. After completing [EMS Training](https://statfungen.github.io/xqtl-protocol/code/xqtl_modifier_score/ems_training.html), use your trained model to predict Expression Modifier Scores.

## What This Does
- **Input**: Your list of genetic variants + trained EMS model
- **Output**: Same variants with EMS functional scores (0-1 scale)
- **Use Case**: Prioritize variants for experimental validation or clinical follow-up

## Prerequisites

**Required Files:**
- Trained model: `model_standard_subset_weighted_chr_chr2_NPR_10.joblib`
- Your variant list: TSV/CSV with `variant_id` column (format: "chr:pos:ref:alt")
- Supporting data files from training pipeline

**Data Requirements:**
- Genomic annotations are not computed in this notebook; features that cannot be derived from `variant_id` are filled with fixed default values (Step 3)
- Chromosome: chr2 variants recommended (model was trained on chr2)
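Malformed IDs will break the `str.split(':')` parsing in Step 3, and off-chromosome variants fall outside the model's training scope. Below is a standard-library pre-check, offered only as a sketch. `parse_variant_id` and `check_variants` are hypothetical helpers. The pattern assumes plain A/C/G/T alleles and an optional `chr` prefix, so symbolic alleles such as `<DEL>` are rejected.

```python
import re

# Hypothetical validator: accepts IDs like "2:12345:A:T" (an optional "chr"
# prefix is stripped) with alleles made only of A, C, G and T.
VARIANT_RE = re.compile(r"^(?:chr)?([0-9]+|X|Y):([0-9]+):([ACGT]+):([ACGT]+)$")


def parse_variant_id(variant_id):
    """Split a chr:pos:ref:alt ID into (chrom, pos, ref, alt) or raise ValueError."""
    match = VARIANT_RE.match(variant_id.strip())
    if match is None:
        raise ValueError(f"malformed variant_id: {variant_id!r}")
    chrom, pos, ref, alt = match.groups()
    return chrom, int(pos), ref, alt


def check_variants(variant_ids, expected_chrom="2"):
    """Sort IDs into parsed records, malformed IDs, and IDs off the expected chromosome."""
    parsed, malformed, off_chrom = [], [], []
    for vid in variant_ids:
        try:
            record = parse_variant_id(vid)
        except ValueError:
            malformed.append(vid)
            continue
        parsed.append(record)
        if record[0] != expected_chrom:
            off_chrom.append(vid)
    return parsed, malformed, off_chrom


parsed, malformed, off_chrom = check_variants(
    ["2:12345:A:T", "chr2:67890:G:C", "3:100:A:G", "2:abc:A:T"]
)
print(len(parsed), malformed, off_chrom)  # 3 ['2:abc:A:T'] ['3:100:A:G']
```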

## Step 1: Load Trained Model

In [None]:
import pandas as pd
import numpy as np
import joblib
import warnings
warnings.filterwarnings('ignore')

# Configure paths - update for your setup
MODEL_PATH = "../../data/Mic_mega_eQTL/model_results/model_standard_subset_weighted_chr_chr2_NPR_10.joblib"
CONFIG_PATH = "data_config.yaml"  # training data config (not read in this notebook)

# Load trained model
print("Loading trained CatBoost model...")
trained_model = joblib.load(MODEL_PATH)
print(f"✅ Model loaded: {trained_model.__class__.__name__}")
print(f"Features required: {len(trained_model.feature_names_)}")
print("Reported training performance: AUC=89.78%, AP=50.5%")
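`joblib.load` returns whatever object was pickled. A quick check that the object exposes the attributes used in later steps (`predict_proba`, `predict`, `feature_names_`) can therefore catch a wrong file early. This is a duck-typing sketch: `missing_attributes` is hypothetical, and `StandInModel` exists only to make the example runnable.

```python
REQUIRED_ATTRIBUTES = ("predict_proba", "predict", "feature_names_")


def missing_attributes(model, required=REQUIRED_ATTRIBUTES):
    """List the attributes used by the scoring steps that `model` does not have."""
    return [name for name in required if not hasattr(model, name)]


class StandInModel:
    """Illustrative stand-in exposing the same attribute names; not a real classifier."""

    feature_names_ = ["length_diff", "is_SNP"]

    def predict_proba(self, X):
        return [[0.5, 0.5] for _ in X]

    def predict(self, X):
        return [0 for _ in X]


print(missing_attributes(StandInModel()))   # []
print(missing_attributes({"model": None}))  # ['predict_proba', 'predict', 'feature_names_']
```

In the notebook this would be used as `missing = missing_attributes(trained_model)`, stopping with an error if the list is non-empty.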

### Model Overview

**What the model predicts:**
- **Score range**: 0.0 - 1.0 (predicted probability that the variant is functional)
- **Functional variant**: Affects gene expression (score > 0.5 typically)
- **Non-functional variant**: No detectable expression effect (score < 0.5)
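Calling a variant functional is a threshold comparison on the score. The sketch below uses a hypothetical `call_functional` helper. Its 0.5 default mirrors the rule above and is a tunable choice, not a property of the model.

```python
def call_functional(scores, threshold=0.5):
    """Return 1 for scores strictly above `threshold`, else 0."""
    if not 0.0 <= threshold <= 1.0:
        raise ValueError("threshold must lie in [0, 1]")
    return [1 if s > threshold else 0 for s in scores]


scores = [0.12, 0.5, 0.51, 0.93]
print(call_functional(scores))                 # [0, 0, 1, 1]
print(call_functional(scores, threshold=0.3))  # [0, 1, 1, 1]
```

Raising the threshold typically trades recall for precision, which may suit a short list for experimental validation.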

**Key technical details:**
- **Algorithm**: Feature-weighted CatBoost classifier
- **Training data**: 3,056 variants (chr2), 4,839 genomic features
- **Top predictor**: Distance to transcription start site
- **Cell-type**: Optimized for microglia regulatory effects

## Step 2: Load Your Variant List

**Expected format:**
```
variant_id
2:12345:A:T
2:67890:G:C
2:11111:T:A
```
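The prerequisites allow either TSV or CSV input, and the single-column file above contains no delimiter at all. The sketch below uses only the standard library. It detects tab or comma with `csv.Sniffer` and falls back to tab when detection fails. `load_variant_ids` is a hypothetical helper that takes the file contents as a string (e.g. `pathlib.Path(path).read_text()`).

```python
import csv
import io


def load_variant_ids(text, column="variant_id"):
    """Return the non-empty values of `column` from TSV or CSV text."""
    try:
        # Restrict detection to the two delimiters this notebook expects.
        delimiter = csv.Sniffer().sniff(text[:4096], delimiters="\t,").delimiter
    except csv.Error:
        # A single-column file has no delimiter to detect; tab is a safe default.
        delimiter = "\t"
    reader = csv.DictReader(io.StringIO(text), delimiter=delimiter)
    if reader.fieldnames is None or column not in reader.fieldnames:
        raise ValueError(f"input has no {column!r} column")
    ids = []
    for row in reader:
        value = (row.get(column) or "").strip()
        if value:
            ids.append(value)
    return ids


tsv = "variant_id\tgene\n2:12345:A:T\tGENE1\n2:67890:G:C\tGENE2\n"
csv_text = "variant_id,gene\n2:12345:A:T,GENE1\n2:67890:G:C,GENE2\n"
single = "variant_id\n2:12345:A:T\n2:67890:G:C\n"
for text in (tsv, csv_text, single):
    print(load_variant_ids(text))  # ['2:12345:A:T', '2:67890:G:C'] each time
```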

In [None]:
# Option 1: Use toy example (for testing)
toy_variants = pd.DataFrame({
    'variant_id': [
        '2:12345:A:T',
        '2:67890:G:C', 
        '2:11111:T:A',
        '2:22222:C:G',
        '2:33333:A:G'
    ]
})

print("Toy example variants:")
print(toy_variants)

# Use toy variants for demonstration
user_variants = toy_variants.copy()

In [None]:
# Option 2: Load your own variant file (uncomment to use)
# YOUR_VARIANT_FILE = "path/to/your/variants.tsv"
# user_variants = pd.read_csv(YOUR_VARIANT_FILE, sep='\t')
# print(f"Loaded {len(user_variants)} variants from file")
# print(user_variants.head())

print(f"\nPreparing to score {len(user_variants)} variants")

## Step 3: Feature Preparation

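The model expects exactly the feature set it was trained on. This step builds a feature matrix with those columns in the training order. A few variant-type features are derived from `variant_id`, and every other annotation is filled with a fixed default value.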

In [None]:
def create_variant_features(df):
    """Create basic variant features from variant_id"""
    df = df.copy()
    
    # Parse variant_id: chr:pos:ref:alt
    df[['chr','pos','ref','alt']] = df['variant_id'].str.split(':', expand=True)
    
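    # Calculate variant type features (length_diff = len(ref) - len(alt))
    df['length_diff'] = df['ref'].str.len() - df['alt'].str.len()
    df['is_SNP'] = (df['length_diff'] == 0).astype(int)
    df['is_indel'] = (df['length_diff'] != 0).astype(int)
    # ALT longer than REF is an insertion; REF longer than ALT is a deletion.
    # Check that this matches the definition used to build the training features.
    df['is_insertion'] = (df['length_diff'] < 0).astype(int)
    df['is_deletion'] = (df['length_diff'] > 0).astype(int)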
    
    # Add placeholder genomic annotations (use training medians)
    df['gene_lof'] = -10.0  # Gene constraint score
    df['gnomad_MAF'] = 0.1  # Population allele frequency
    
    # Clean up
    df = df.drop(columns=['chr','pos','ref','alt'])
    
    return df

# Create features
processed_variants = create_variant_features(user_variants)
print("✅ Basic variant features created")
print(f"Features added: {[c for c in processed_variants.columns if c not in user_variants.columns]}")
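The variant-type flags depend on which allele is longer. Under the usual VCF-style convention, ALT longer than REF is an insertion and REF longer than ALT is a deletion. The sketch below computes the flags per variant under that convention, using a hypothetical `variant_type_flags` helper. If the training pipeline defined the flags differently, mirror its definition instead so that prediction features match training features.

```python
def variant_type_flags(ref, alt):
    """Variant-type flags for one REF/ALT pair, with length_diff = len(ref) - len(alt)."""
    length_diff = len(ref) - len(alt)
    return {
        "length_diff": length_diff,
        "is_SNP": int(length_diff == 0),       # also 1 for equal-length MNPs
        "is_indel": int(length_diff != 0),
        "is_insertion": int(length_diff < 0),  # ALT longer than REF
        "is_deletion": int(length_diff > 0),   # REF longer than ALT
    }


for vid in ["2:12345:A:T", "2:500:A:ATT", "2:900:GCA:G"]:
    _, _, ref, alt = vid.split(":")
    print(vid, variant_type_flags(ref, alt))
```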

In [None]:
def prepare_prediction_matrix(df, model):
    """Prepare full feature matrix for model prediction"""
    
    # Get required features from trained model
    required_features = model.feature_names_
    
    # Create prediction dataframe with all required features
    prediction_df = df.copy()
    
    # Features the model needs that were not created upstream
    imputed = [f for f in required_features if f not in prediction_df.columns]
    
    # Add missing features with fixed defaults (intended to approximate training values)
    for feature in imputed:
        name = feature.lower()
        if 'distance' in name and 'log' in name:
            prediction_df[feature] = 8.5  # Median log distance from training
        elif 'abc_score' in name:
            prediction_df[feature] = 0.05  # Median ABC score
        else:
            prediction_df[feature] = 0.0  # Differential signals and all other features
    
    # Select and order features to match training
    X = prediction_df[required_features]
    
    # Handle missing/invalid values
    X = X.replace([np.inf, -np.inf], 0)
    X = X.fillna(0)
    
    print(f"✅ Prediction matrix prepared: {X.shape}")
    print(f"Note: {len(imputed)} of {len(required_features)} features imputed with default values")
    
    return X

# Prepare prediction matrix
X_prediction = prepare_prediction_matrix(processed_variants, trained_model)
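The essential job of `prepare_prediction_matrix` is to return columns in exactly the model's `feature_names_` order and to report how many were imputed. The sketch below does the same alignment for a single variant without pandas, using a hypothetical `align_features` helper. The feature names and the 8.5 default are placeholders for illustration, not the model's real feature list.

```python
import math


def align_features(observed, required, defaults=None, fallback=0.0):
    """Order feature values to match `required`; return (row, imputed_names).

    Missing features take defaults[name] (or `fallback`); NaN/inf values are
    replaced with the same fill value.
    """
    defaults = defaults or {}
    row, imputed = [], []
    for name in required:
        fill = defaults.get(name, fallback)
        value = observed.get(name)
        if value is None:
            imputed.append(name)
            value = fill
        elif not math.isfinite(value):
            value = fill
        row.append(float(value))
    return row, imputed


# Placeholder feature names, for illustration only
required = ["length_diff", "is_SNP", "tss_distance_log", "abc_score_max"]
observed = {"is_SNP": 1, "length_diff": 0, "extra_column": 7}  # extra columns are ignored
row, imputed = align_features(observed, required, defaults={"tss_distance_log": 8.5})
print(row)      # [0.0, 1.0, 8.5, 0.0]
print(imputed)  # ['tss_distance_log', 'abc_score_max']
print(f"{len(imputed)} of {len(required)} features imputed")  # 2 of 4 features imputed
```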

## Step 4: Generate EMS Scores

Apply the trained model to generate functional scores for your variants.

In [None]:
# Generate predictions
print("Generating EMS scores...")

# Get probability scores (0-1) and binary predictions
ems_scores = trained_model.predict_proba(X_prediction)[:, 1]
binary_predictions = trained_model.predict(X_prediction)

# Add results to original dataframe
results_df = user_variants.copy()
results_df['ems_score'] = ems_scores.round(4)
results_df['predicted_functional'] = binary_predictions

# Add confidence categories (include_lowest=True so a score of exactly 0 falls in 'Low')
results_df['confidence'] = pd.cut(ems_scores,
                                  bins=[0, 0.3, 0.7, 1.0],
                                  labels=['Low', 'Medium', 'High'],
                                  include_lowest=True)

print(f"✅ EMS scores generated for {len(results_df)} variants")
print("\nResults preview:")
print(results_df.to_string(index=False))
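`predict_proba(X)[:, 1]` assumes the functional class is the second entry of the model's `classes_` (e.g. `[0, 1]`). The sketch below looks the column up by label instead. It assumes the functional class is labeled `1` and that the model exposes `classes_`, as scikit-learn-style classifiers do. `positive_column` and `positive_scores` are hypothetical helpers.

```python
def positive_column(classes, positive_label=1):
    """Return the predict_proba column index for `positive_label`."""
    classes = list(classes)
    if positive_label not in classes:
        raise ValueError(f"label {positive_label!r} not in classes {classes}")
    return classes.index(positive_label)


def positive_scores(proba_rows, classes, positive_label=1):
    """Extract the positive-class probability from each row of predict_proba output."""
    col = positive_column(classes, positive_label)
    return [row[col] for row in proba_rows]


proba = [[0.8, 0.2], [0.1, 0.9]]
print(positive_scores(proba, classes=[0, 1]))  # [0.2, 0.9]
print(positive_scores(proba, classes=[1, 0]))  # [0.8, 0.1]
```

With the real model this would read `trained_model.predict_proba(X_prediction)[:, positive_column(trained_model.classes_)]`.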

## Step 5: Interpret Results

### Score Interpretation Guide
- **High (>0.7)**: Strong evidence for regulatory function
- **Medium (>0.3 to 0.7)**: Uncertain, may require additional evidence  
- **Low (≤0.3)**: Limited evidence for functional impact
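In Step 4 the tiers come from `pd.cut` with `bins=[0, 0.3, 0.7, 1.0]`, whose intervals are closed on the right. A score of exactly 0.3 is therefore Low, and a score of exactly 0.7 is Medium. The sketch below reproduces the same boundaries with the standard library. It uses a hypothetical `score_tier` helper, which also assigns 0.0 to Low, as `pd.cut(..., include_lowest=True)` does.

```python
from bisect import bisect_left

INNER_EDGES = [0.3, 0.7]  # inner edges of bins=[0, 0.3, 0.7, 1.0]
TIERS = ["Low", "Medium", "High"]


def score_tier(score):
    """Tier for a score in [0, 1]; bins are right-inclusive, so 0.3 -> Low, 0.7 -> Medium."""
    if not 0.0 <= score <= 1.0:
        raise ValueError(f"score out of range: {score}")
    # bisect_left places a score equal to an edge in the lower bin
    return TIERS[bisect_left(INNER_EDGES, score)]


print([score_tier(s) for s in [0.0, 0.3, 0.31, 0.7, 0.71, 1.0]])
# ['Low', 'Low', 'Medium', 'Medium', 'High', 'High']
```

Swapping `bisect_left` for `bisect_right` would move scores sitting exactly on an edge into the upper tier instead.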

### Recommended Actions
- **High-scoring variants**: Priority for experimental validation
- **Medium-scoring variants**: Consider additional computational analysis
- **Low-scoring variants**: Likely neutral, lower priority

In [None]:
# Summary statistics
print("📊 PREDICTION SUMMARY")
print("=" * 30)
print(f"Total variants: {len(results_df)}")
print(f"Predicted functional: {sum(results_df['predicted_functional'])}")
print(f"Average EMS score: {results_df['ems_score'].mean():.3f}")
print(f"Score range: {results_df['ems_score'].min():.3f} - {results_df['ems_score'].max():.3f}")

print("\nConfidence distribution:")
print(results_df['confidence'].value_counts())

# Highlight top variants
if results_df['ems_score'].max() > 0.5:
    top_variants = results_df.nlargest(3, 'ems_score')[['variant_id', 'ems_score']]
    print("\n🎯 Top-scoring variants:")
    for _, row in top_variants.iterrows():
        print(f"   {row['variant_id']}: {row['ems_score']:.4f}")

## Step 6: Export Results

In [None]:
# Save results
output_file = "ems_predictions.tsv"
results_df.to_csv(output_file, sep='\t', index=False)

print(f"✅ Results saved to: {output_file}")
print(f"\nOutput columns:")
for col in results_df.columns:
    print(f"   - {col}")

# Optional: Filter high-confidence predictions
high_confidence = results_df[results_df['ems_score'] > 0.7]
if len(high_confidence) > 0:
    high_conf_file = "high_confidence_variants.tsv"
    high_confidence.to_csv(high_conf_file, sep='\t', index=False)
    print(f"📈 High-confidence variants saved to: {high_conf_file}")
    print(f"   ({len(high_confidence)} variants with score > 0.7)")

## Usage Notes

### For Your Own Data
1. **Replace toy variants** with your variant list (TSV/CSV format)
2. **Ensure variant_id format**: "chr:pos:ref:alt" (e.g., "2:12345:A:T")
3. **Chromosome compatibility**: Model trained on chr2, works best with chr2 variants
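4. **Feature imputation**: Annotations that cannot be derived from `variant_id` are filled with fixed default values, so variants that differ only in those annotations get identical scores (for example, all five toy SNVs should receive the same EMS score)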
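Imputed features are constant, so it can help to check which variants share an identical feature row before interpreting score differences. The sketch below uses only the standard library, with a hypothetical `identical_feature_groups` helper; the rows are made-up placeholders. With the pandas matrix from Step 3, `X_prediction.duplicated(keep=False)` flags the same rows.

```python
from collections import defaultdict


def identical_feature_groups(variant_ids, rows):
    """Return lists of variant IDs whose feature rows are exactly equal.

    A deterministic model must give every variant in such a group the same score.
    """
    groups = defaultdict(list)
    for vid, row in zip(variant_ids, rows):
        groups[tuple(row)].append(vid)
    return [ids for ids in groups.values() if len(ids) > 1]


ids = ["2:12345:A:T", "2:67890:G:C", "2:500:A:ATT"]
rows = [[0, 1, 0, 8.5], [0, 1, 0, 8.5], [-2, 0, 1, 8.5]]  # placeholder feature rows
print(identical_feature_groups(ids, rows))  # [['2:12345:A:T', '2:67890:G:C']]
```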

### Model Limitations
- **Training scope**: Optimized for microglia cell type in brain tissue
- **Chromosome bias**: Best performance on chromosome 2 variants
- **Feature dependency**: Some genomic annotations approximated when unavailable
- **Population**: Training data based on specific demographic groups

### Validation Recommendations
- **High-scoring variants**: Prioritize for experimental validation
- **Cross-reference**: Compare with other variant annotation tools
- **Literature check**: Review existing functional studies
- **Clinical correlation**: Assess disease association when applicable