# One-versus-Rest Classification for CJD Subtype Analysis

This notebook implements a binary classification approach to identify protein signatures specific to each CJD subtype using Random Forest models.

## Analysis Approach
1. **Binary Classification**: Convert the multiclass problem into one-versus-rest binary classification for each subtype
2. **Class Imbalance**: Address class imbalance using a class-weighted Random Forest (inverse-frequency class weights)
3. **Feature Selection**: Perform ANOVA F-test (top-k) feature selection within each cross-validation fold, fitted on the training split only
4. **Stability Analysis**: Identify stable biomarkers, i.e. features selected in at least 60% of the cross-validation folds
5. **Evaluation**: Assess model performance with multiple metrics
6. **Interpretation**: Analyze feature importance using Gini and SHAP approaches

## Expected Outputs
- Performance metrics for each subtype classification
- Stable feature sets specific to each subtype
- Comparative analysis of subtype-specific protein signatures
- Importance ranking of identified biomarkers

In [1]:
# General imports
import os
import warnings
import statistics as stat
import numpy as np
import pandas as pd
from scipy import stats
from matplotlib_venn import venn3
from typing import Dict
from upsetplot import UpSet

# Visualization libraries
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
import plotly.express as px

# Machine learning libraries
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val_predict
from sklearn.preprocessing import StandardScaler, label_binarize
from sklearn.impute import KNNImputer
from sklearn.decomposition import PCA
from sklearn.feature_selection import SelectKBest, f_classif

# Metrics
from sklearn.metrics import (confusion_matrix, classification_report, 
                             accuracy_score, precision_score, recall_score, 
                             f1_score, roc_auc_score, balanced_accuracy_score)

# SHAP for interpretability
import shap

# Ignore warnings
# NOTE(review): this silences *every* warning for the rest of the session,
# e.g. metric zero-division or small-class CV-split warnings that
# scikit-learn may emit. Consider narrowing to specific categories.
warnings.filterwarnings("ignore")

# Enable inline plotting for Jupyter (IPython magic; not valid plain Python)
%matplotlib inline

In [2]:
# Data and figure directories are siblings of the current working directory:
# <parent of cwd>/data and <parent of cwd>/figures.
data_path, figure_path = (
    os.path.dirname(os.getcwd()) + sub_dir for sub_dir in ('/data', '/figures')
)

In [4]:
# Clinical / metadata columns that are not used as model features
columns_to_drop = [
    'Codon 129', 'SampleID', 'Group', 'Strain', 'age at LP', 'Sex',
    'onset-LP', 'onset-death', 'LP-death', 'NP_subtype'
]

# Import Olink Data: load the curated table, drop the metadata columns and
# exclude control samples so that only CJD subgroups remain.
df = (
    pd.read_excel(data_path + '/curated/olink.xlsx')
    .drop(columns=columns_to_drop)
    .loc[lambda frame: frame['SubGroup'] != 'CTRL']
)

### Function for 5-fold cross-validation

In [5]:
# Features: every remaining column except the label; target: the CJD subgroup
X, y = df.drop(columns='SubGroup'), df['SubGroup']

In [6]:
def analyze_subtype(X, y, target_subtype, cv, k=50, stability_threshold=0.6):
    """
    Run a class-weighted one-versus-rest Random Forest analysis for one subtype.

    For every fold produced by ``cv`` the function:
    selects the ``k`` best features by ANOVA F-test on the training split only,
    fits a class-weighted ``RandomForestClassifier`` (200 trees), scores it on
    the held-out split, and records the Gini importance and mean |SHAP| value
    of every selected feature. Features selected in at least
    ``stability_threshold`` of the folds are reported as "stable".

    Side effects: prints per-fold progress and a metric summary, writes
    ``<data_path>/results/<target_subtype>_feature_importance.csv`` and saves a
    bar plot of the top 20 stable features to
    ``<figure_path>/ml_subgroup/<target_subtype>_top_stable_features.png``
    (``data_path`` / ``figure_path`` are notebook-level globals).

    Parameters
    ----------
    X : pandas.DataFrame
        Feature matrix; columns are candidate features (proteins).
    y : pandas.Series
        Subtype labels aligned with ``X``.
    target_subtype : label
        Subtype treated as the positive class; all other labels are negative.
    cv : cross-validator
        Object exposing ``split(X, y)``, e.g. ``StratifiedKFold``.
    k : int, default 50
        Number of features kept per fold; clamped to the number of columns.
    stability_threshold : float, default 0.6
        Minimum fraction of folds in which a feature must be selected to be
        considered stable.

    Returns
    -------
    dict
        'metrics': {metric name: (mean, std)} across folds, ignoring NaN folds;
        'importance': DataFrame (Feature, Stability, Gini, SHAP) of stable
        features sorted by SHAP descending;
        'feature_stability': {feature: fraction of folds that selected it};
        'selected_features': list of stable feature names.
    """
    # Binary target: 1 for the subtype of interest, 0 for every other subtype
    y_binary = (y == target_subtype).astype(int)

    # SelectKBest raises ValueError when k exceeds the number of columns
    k = min(k, X.shape[1])

    # Per-fold scores, keyed by the names used in the returned dict
    metrics = {
        'Accuracy': [], 'Precision': [], 'Recall': [],
        'F1': [], 'ROC-AUC': [], 'Balanced accuracy': []
    }
    feature_importances = {}    # feature -> Gini importance, one per fold that selected it
    feature_shap_values = {}    # feature -> mean |SHAP|, one per fold that selected it
    all_selected_features = []  # one list of selected feature names per fold

    for fold, (train_idx, test_idx) in enumerate(cv.split(X, y_binary), 1):
        print(f"Processing {target_subtype} - Fold {fold}")

        X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
        y_train, y_test = y_binary.iloc[train_idx], y_binary.iloc[test_idx]

        # Inverse-frequency class weights derived from the training split
        # only, so held-out labels cannot influence the fitted model.
        n_train = len(y_train)
        n_positive = int(y_train.sum())
        class_weight = {0: n_positive / n_train,
                        1: (n_train - n_positive) / n_train}

        # Feature selection fitted on the training split only (no test leakage)
        selector = SelectKBest(f_classif, k=k)
        X_train_selected = selector.fit_transform(X_train, y_train)
        X_test_selected = selector.transform(X_test)
        fold_features = X.columns[selector.get_support()].tolist()
        all_selected_features.append(fold_features)

        model = RandomForestClassifier(
            random_state=42,
            class_weight=class_weight,
            n_estimators=200
        )
        model.fit(X_train_selected, y_train)

        y_pred = model.predict(X_test_selected)
        y_pred_proba = model.predict_proba(X_test_selected)[:, 1]

        metrics['Accuracy'].append(accuracy_score(y_test, y_pred))
        metrics['Balanced accuracy'].append(balanced_accuracy_score(y_test, y_pred))
        # zero_division=0 makes explicit the score used when a fold has no
        # predicted (precision) or actual (recall) positives.
        metrics['Precision'].append(precision_score(y_test, y_pred, zero_division=0))
        metrics['Recall'].append(recall_score(y_test, y_pred, zero_division=0))
        metrics['F1'].append(f1_score(y_test, y_pred, zero_division=0))
        if y_test.nunique() == 2:
            metrics['ROC-AUC'].append(roc_auc_score(y_test, y_pred_proba))
        else:
            # ROC-AUC is undefined for a single-class fold (happens when the
            # subtype has fewer samples than folds); excluded from the summary.
            metrics['ROC-AUC'].append(np.nan)

        # SHAP values on the held-out split
        explainer = shap.TreeExplainer(model)
        fold_shap = explainer(X_test_selected).values
        # Classifier explanations may be (samples, features, classes); keep the
        # positive class. Class probabilities sum to 1, so the two class columns
        # are negatives of each other and |SHAP| is the same for either.
        if fold_shap.ndim == 3:
            fold_shap = fold_shap[:, :, 1]
        mean_abs_shap = np.abs(fold_shap).mean(axis=0)

        # Columns of X_*_selected follow the order of fold_features
        for feature, gini, shap_value in zip(fold_features,
                                             model.feature_importances_,
                                             mean_abs_shap):
            feature_importances.setdefault(feature, []).append(gini)
            feature_shap_values.setdefault(feature, []).append(shap_value)

    # Selection stability: each feature is recorded at most once per fold, so
    # the number of recorded importances equals the folds that selected it.
    n_folds = len(all_selected_features)
    feature_stability = {
        feature: len(values) / n_folds
        for feature, values in feature_importances.items()
    }
    stable_features = [f for f, stability in feature_stability.items()
                       if stability >= stability_threshold]

    # Mean importances over the folds in which each stable feature was selected
    stable_importance_df = pd.DataFrame({
        'Feature': stable_features,
        'Stability': [feature_stability[f] for f in stable_features],
        'Gini': [np.mean(feature_importances[f]) for f in stable_features],
        'SHAP': [np.mean(feature_shap_values[f]) for f in stable_features]
    }).sort_values('SHAP', ascending=False)

    results_dir = os.path.join(data_path, 'results')
    os.makedirs(results_dir, exist_ok=True)
    stable_importance_df.to_csv(
        os.path.join(results_dir, f'{target_subtype}_feature_importance.csv'),
        index=False
    )

    # Plot the top 20 stable features, most important at the top
    top = stable_importance_df.head(20)
    positions = range(len(top))
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.barh(positions, top['SHAP'])
    ax.set_yticks(positions)
    ax.set_yticklabels(top['Feature'])
    ax.invert_yaxis()
    ax.set_xlabel('Mean |SHAP value|')
    ax.set_title(f'Top Stable Features for {target_subtype}\n'
                 f'(Selected in ≥{stability_threshold:.0%} of folds)')
    fig.tight_layout()
    plot_dir = os.path.join(figure_path, 'ml_subgroup')
    os.makedirs(plot_dir, exist_ok=True)
    fig.savefig(os.path.join(plot_dir, f'{target_subtype}_top_stable_features.png'),
                dpi=1200, bbox_inches='tight')
    plt.close(fig)

    # NaN folds (undefined ROC-AUC) are ignored in the summary statistics
    print(f"\nResults for {target_subtype}:")
    for metric, values in metrics.items():
        print(f"{metric}: {np.nanmean(values):.4f} (±{np.nanstd(values):.4f})")
    print(f"\nNumber of stable features: {len(stable_features)}")

    return {
        'metrics': {name: (np.nanmean(v), np.nanstd(v)) for name, v in metrics.items()},
        'importance': stable_importance_df,
        'feature_stability': feature_stability,
        'selected_features': stable_features
    }

In [None]:
# Shuffled, stratified 5-fold splitter with a fixed seed; the same splitter
# is reused for every subtype.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# One one-versus-rest analysis per subtype label present in y
subtypes = y.unique()

results = {}
for subtype in subtypes:
    print(f"\nAnalyzing {subtype}...")
    results[subtype] = analyze_subtype(X, y, target_subtype=subtype, cv=cv, k=50)

# Cross-fold mean of selected metrics (index 0 of each (mean, std) pair),
# one row per subtype.
comparison_df = pd.DataFrame.from_dict(
    {
        subtype: {
            'Accuracy': results[subtype]['metrics']['Accuracy'][0],
            'Balanced Accuracy': results[subtype]['metrics']['Balanced accuracy'][0],
            'ROC-AUC': results[subtype]['metrics']['ROC-AUC'][0]
        }
        for subtype in subtypes
    },
    orient='index'
)

print("\nComparison across subtypes:")
print(comparison_df)

In [None]:
# Stable proteins per subtype, in the SHAP-descending order of each
# 'importance' table. Built without reassigning the global `df`, which holds
# the Olink data.
important_proteins = {
    subtype: res['importance']['Feature'].tolist()
    for subtype, res in results.items()
}

# Union of all stable proteins across subtypes
all_proteins = set().union(*important_proteins.values())

# Sets for easy comparison
protein_sets = {subtype: set(proteins) for subtype, proteins in important_proteins.items()}

# Proteins stable only for this subtype. set().union(...) is used instead of
# set.union(...) because the latter raises TypeError when there are no other
# subtypes to union.
unique_proteins = {
    subtype: protein_sets[subtype] - set().union(*(
        proteins for other, proteins in protein_sets.items() if other != subtype
    ))
    for subtype in protein_sets
}

# Proteins stable in every subtype (empty set when there are no subtypes)
shared_proteins = set.intersection(*protein_sets.values()) if protein_sets else set()

# Summary: total stable proteins, how many are subtype-specific, and how many
# also appear in at least one other subtype
summary = pd.DataFrame({
    'Subtype': list(important_proteins.keys()),
    'Total_Stable_Proteins': [len(proteins) for proteins in important_proteins.values()],
    'Unique_Proteins': [len(unique_proteins[st]) for st in important_proteins],
    'Shared_With_Other_Subtypes': [
        len(protein_sets[st] - unique_proteins[st])
        for st in important_proteins
    ]
})

print("\nProtein Signature Summary:")
print(summary)

# Background set for enrichment analysis: all proteins in the panel
background_proteins = X.columns.tolist()

# Top 10 stable proteins per subtype (importance tables are sorted by SHAP)
for subtype in important_proteins:
    print(f"\nTop 10 proteins for {subtype}:")
    top_proteins = results[subtype]['importance'].head(10)
    print(top_proteins[['Feature', 'SHAP', 'Stability']])

In [None]:
# Long-format table: one row per (subtype, metric) with cross-fold mean and std
metrics_data = [
    {'Subtype': subtype, 'Metric': metric, 'Mean': mean, 'Std': std}
    for subtype, res in results.items()
    for metric, (mean, std) in res['metrics'].items()
]
# Separate name so the global `df` (Olink data) is not overwritten
metrics_df = pd.DataFrame(metrics_data)

fig = plt.figure(figsize=(9, 4))
metrics = metrics_df['Metric'].unique()
subtypes = metrics_df['Subtype'].unique()
x = np.arange(len(metrics))
# 0.25 per bar for up to three subtypes; narrower beyond that so that
# neighbouring metric groups do not overlap.
width = min(0.25, 0.8 / max(len(subtypes), 1))

# Grouped bars with std error bars, one bar per subtype within each metric group
for i, subtype in enumerate(subtypes):
    # Align rows to the shared metric order so every bar sits over its tick
    subtype_data = (metrics_df[metrics_df['Subtype'] == subtype]
                    .set_index('Metric')
                    .reindex(metrics))
    plt.bar(x + i * width, subtype_data['Mean'], width, label=subtype, alpha=0.8)
    plt.errorbar(x + i * width, subtype_data['Mean'], yerr=subtype_data['Std'], fmt='none',
                 capsize=5, color='black', alpha=0.5)

plt.xlabel('Metrics')
plt.ylabel('Score')
plt.title('Classification Performance Metrics by Subtype')
# Tick at the centre of each group (x + width when there are three subtypes)
plt.xticks(x + width * (len(subtypes) - 1) / 2, metrics, rotation=0)
plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
plt.grid(True, alpha=0.3)
plt.ylim(0, 1.1)

os.makedirs(figure_path + '/ml_subgroup', exist_ok=True)
plt.savefig(figure_path + '/ml_subgroup/performance_metrics.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Long-format table of each subtype's top-20 stable features by mean |SHAP|
plot_data = [
    {'Subtype': subtype, **record}
    for subtype, res in results.items()
    for record in (res['importance']
                   .nlargest(20, 'SHAP')[['Feature', 'SHAP', 'Stability']]
                   .to_dict('records'))
]
# Explicit columns keep the frame plottable even when no subtype has stable
# features; a separate name avoids overwriting the global `df` (Olink data).
top_features_df = pd.DataFrame(plot_data, columns=['Subtype', 'Feature', 'SHAP', 'Stability'])

# Height 10 as before, growing only when more than 40 distinct features are shown
n_plotted_features = top_features_df['Feature'].nunique()
plt.figure(figsize=(4, max(10, 0.25 * n_plotted_features)))
g = sns.scatterplot(data=top_features_df, x='Subtype', y='Feature',
                    size='SHAP', hue='Stability',
                    sizes=(50, 400), alpha=0.6)
g.set_title('Top Features by Subtype')
plt.xticks(rotation=0)
plt.legend(bbox_to_anchor=(0.08, 0.3), loc='upper left', borderaxespad=0)

os.makedirs(figure_path + '/ml_subgroup', exist_ok=True)
plt.savefig(figure_path + '/ml_subgroup/top_features.png', dpi=300, bbox_inches='tight')
plt.show()
