# Machine Learning Classification Analysis for Olink Data and feature analyses

This notebook performs a machine learning analysis on Olink protein expression data using Random Forest classification with selected features (top 50) and analyzes the most relevant features.

## Input:
- olink.xlsx: Protein expression data with columns:
  * SampleID: Unique sample identifier
  * Group: Clinical group classification
  * SubGroup: Clinical subgroup
  * Strain: Sample strain type
  * age at LP: Age at lumbar puncture
  * Sex: Patient sex
  * [Protein Names]: NPX values for each protein

## Analysis steps

1. **Model Training**
   - Implement 5-fold cross-validation
   - Train models with different feature sets
   - Generate performance metrics
   - Calculate standard metrics (accuracy, precision, recall, F1, ROC-AUC)
   - Generate confusion matrices
   - Analyze feature importance using SHAP and Gini

2. **Analyse important features (top 10)**
   - Plot biomarker levels across diagnostic groups
   - Plot biomarkers for sCJD strain differentiation
   - Assess prognostic value of top 20 features with Cox regression analyses
   - Plot KM curves of top 10 prognostic biomarkers

## Outputs
- Gini Feature importance plots
- SHAP analysis visualisations
- Violin plots for top 10 proteins
- Violin plots for the proteins distinguishing M1 vs V2 strains
- Forest plots of univariate and multivariate Cox regression analyses
- KM curves for 10 proteins

In [45]:
# General imports
import os
import warnings
import statistics as stat
import numpy as np
import pandas as pd
from scipy import stats
from itertools import combinations
from matplotlib.gridspec import GridSpec
import requests

# Visualization libraries
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
import plotly.express as px

# Machine learning libraries
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val_predict
from sklearn.preprocessing import StandardScaler, label_binarize
from sklearn.impute import KNNImputer
from sklearn.decomposition import PCA
from sklearn.feature_selection import SelectKBest, f_classif

from lifelines import KaplanMeierFitter, CoxPHFitter 
from lifelines.statistics import logrank_test

# Metrics
from sklearn.metrics import (confusion_matrix, classification_report, 
                             accuracy_score, precision_score, recall_score, 
                             f1_score, roc_auc_score)

# SHAP for interpretability
import shap

# Ignore warnings
warnings.filterwarnings("ignore")

# Enable inline plotting for Jupyter
%matplotlib inline

In [46]:
# Both locations are resolved relative to the parent of the current working directory.
project_root = os.path.dirname(os.getcwd())
data_path = f"{project_root}/data"
figure_path = f"{project_root}/figures/feature_analysis"

In [47]:
# Columns that are removed before modelling (everything listed here is dropped;
# 'SubGroup' is kept because it is the classification target).
columns_to_drop = [
    'Codon 129', 'SampleID', 'Group', 'Strain', 'age at LP', 'Sex',
    'onset-LP', 'onset-death', 'LP-death', 'NP_subtype'
]

# Read the curated Olink table, drop the columns above and exclude control samples.
df = (
    pd.read_excel(data_path + '/curated/olink.xlsx')
    .drop(columns=columns_to_drop)
    .loc[lambda frame: frame['SubGroup'] != 'CTRL']
)

### Function for 5-fold cross-validation with feature selection

In [None]:
# Target labels and feature matrix (every remaining column except the target).
y = df['SubGroup']
X = df.drop(columns=['SubGroup'])

# Stratified, shuffled 5-fold splitter with a fixed seed.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

In [50]:
def train_and_evaluate_reduced_model(X, y, cv, k=50, plot_confusion=True):
    """Cross-validate a Random Forest on the top-``k`` ANOVA-selected features.

    For every fold produced by ``cv`` the function selects ``k`` features with
    ``SelectKBest(f_classif)`` fitted on the training split only, trains a
    ``RandomForestClassifier(random_state=42)``, scores it on the held-out
    split, and records the Gini and mean |SHAP| importance of each selected
    feature.

    Parameters
    ----------
    X : pandas.DataFrame
        Feature matrix; rows are addressed positionally with ``.iloc``.
    y : pandas.Series
        Class labels aligned with ``X``.
    cv : cross-validation splitter
        Must provide ``split(X, y)`` and ``get_n_splits()``.
    k : int, default 50
        Number of features kept by ``SelectKBest`` in each fold.
    plot_confusion : bool, default True
        When True, four figures are written to ``figure_path`` and the
        accumulated confusion matrix is converted to row percentages.

    Returns
    -------
    dict
        ``'metrics'`` maps metric name -> ``(mean, std)`` across folds;
        ``'feature_rankings'`` is a DataFrame sorted by mean SHAP importance;
        ``'shap_by_class'`` holds fold-averaged mean |SHAP| per feature and class;
        ``'confusion_matrix'`` is the row-percentage matrix when
        ``plot_confusion`` is True, otherwise the summed raw counts.

    Side effects
    ------------
    Writes ``feature_importance_rankings.csv`` and
    ``top_20_shap_importance_by_class.csv`` to ``data_path + '/results'``.
    """
    classes = sorted(y.unique())
    n_classes = len(classes)
    n_splits = cv.get_n_splits()

    # Per-fold metrics
    accuracies, precisions, recalls, f1_scores, roc_aucs = [], [], [], [], []
    all_confusion_matrices = np.zeros((n_classes, n_classes))

    # Per-fold importance tables; concatenated once after the loop
    fold_importance_frames = []
    fold_shap_class_df = []

    for fold, (train_index, test_index) in enumerate(cv.split(X, y), 1):
        print(f"Processing fold {fold}")
        X_train, X_test = X.iloc[train_index], X.iloc[test_index]
        y_train, y_test = y.iloc[train_index], y.iloc[test_index]

        # Feature selection is fitted on the training split only
        selector = SelectKBest(f_classif, k=k)
        X_train_selected = selector.fit_transform(X_train, y_train)
        X_test_selected = selector.transform(X_test)

        # Names of the features retained in this fold
        feature_mask = selector.get_support()
        fold_features = X.columns[feature_mask].tolist()

        # Train model
        model = RandomForestClassifier(random_state=42)
        model.fit(X_train_selected, y_train)

        # Predictions
        y_pred = model.predict(X_test_selected)
        y_pred_proba = model.predict_proba(X_test_selected)

        # Classification metrics (weighted across classes)
        accuracies.append(accuracy_score(y_test, y_pred))
        precisions.append(precision_score(y_test, y_pred, average='weighted'))
        recalls.append(recall_score(y_test, y_pred, average='weighted'))
        f1_scores.append(f1_score(y_test, y_pred, average='weighted'))

        # One-vs-rest macro ROC AUC.
        # NOTE(review): assumes at least three classes so the binarised labels
        # have one column per class — confirm if used with a binary target.
        y_test_bin = label_binarize(y_test, classes=classes)
        roc_auc = roc_auc_score(y_test_bin, y_pred_proba, average='macro', multi_class='ovr')
        roc_aucs.append(roc_auc)

        # Accumulate raw confusion counts over folds
        cm = confusion_matrix(y_test, y_pred, labels=classes)
        all_confusion_matrices += cm

        # Gini importances for this fold
        fold_gini = pd.DataFrame({
            'Feature': fold_features,
            'Gini_Importance': model.feature_importances_,
            'Fold': fold
        })

        # SHAP analysis on the held-out samples
        explainer = shap.TreeExplainer(model)
        shap_values = explainer(X_test_selected)

        # Mean |SHAP| over samples.
        # NOTE(review): assumes shap_values.values is (samples, features, classes)
        # — confirm for the installed shap version.
        shap_class_importance = np.abs(shap_values.values).mean(axis=0)
        shap_class_df = pd.DataFrame(
            shap_class_importance,
            columns=classes,
            index=fold_features
        )
        shap_class_df['Fold'] = fold
        fold_shap_class_df.append(shap_class_df)

        # Overall SHAP importance: additionally averaged over classes
        fold_gini['SHAP_Importance'] = shap_class_importance.mean(axis=1)

        fold_importance_frames.append(fold_gini)

    fold_features_df = pd.concat(fold_importance_frames)

    # Average importances per feature, counting only the folds where the
    # feature was selected; sort=False keeps first-appearance order before sorting.
    feature_rankings = (
        fold_features_df
        .groupby('Feature', sort=False)
        .agg(
            Gini_Importance=('Gini_Importance', 'mean'),
            SHAP_Importance=('SHAP_Importance', 'mean'),
            Folds_Appeared=('Fold', 'count'),
        )
        .reset_index()
        .sort_values('SHAP_Importance', ascending=False)
    )

    # Average SHAP importance per class across folds
    all_shap_class = pd.concat(fold_shap_class_df)
    shap_by_class = all_shap_class.groupby(all_shap_class.index).mean()
    shap_by_class = shap_by_class.drop('Fold', axis=1)

    # Top 20 (or fewer, if fewer features were ranked) by overall SHAP importance
    top_20_features = feature_rankings['Feature'].head(20).tolist()
    top_20_shap_by_class = shap_by_class.loc[top_20_features]

    # Save results (create the output folder if it is missing)
    os.makedirs(data_path + '/results', exist_ok=True)
    feature_rankings.to_csv(data_path + '/results/feature_importance_rankings.csv', index=False)
    top_20_shap_by_class.to_csv(data_path + '/results/top_20_shap_importance_by_class.csv')

    # Create visualizations
    if plot_confusion:
        os.makedirs(figure_path, exist_ok=True)

        # Confusion matrix as row percentages
        all_confusion_matrices = all_confusion_matrices / n_splits
        all_confusion_matrices = (all_confusion_matrices / all_confusion_matrices.sum(axis=1, keepdims=True)) * 100

        plt.figure(figsize=(8, 6))
        sns.heatmap(all_confusion_matrices, annot=True, fmt='.2f', cmap='Blues',
                    xticklabels=classes, yticklabels=classes,
                    annot_kws={"size": 12})
        plt.title('Confusion matrix (%, RF selected features)', fontsize=14)
        plt.xlabel('Predicted')
        plt.ylabel('Actual')
        plt.savefig(figure_path + '/confusion_matrix_reduced_rf.png', dpi=1200, bbox_inches='tight')
        plt.close()

        # Gini importance of the top-ranked features (ranking is by SHAP)
        plt.figure(figsize=(12, 8))
        top_20_gini = feature_rankings.head(20)
        bar_positions = range(len(top_20_gini))  # may be < 20 when few features were ranked
        plt.barh(bar_positions, top_20_gini['Gini_Importance'])
        plt.yticks(bar_positions, top_20_gini['Feature'])
        plt.xlabel('Average Gini Importance')
        plt.title('Top 20 Feature Importance (Gini)')
        plt.tight_layout()
        plt.savefig(figure_path + '/top_20_features_gini_importance.png', dpi=1200, bbox_inches='tight')
        plt.close()

        # SHAP importance heatmap (features x classes)
        plt.figure(figsize=(8, 6))
        im = plt.imshow(top_20_shap_by_class, cmap='Blues')
        colorbar = plt.colorbar(im, label='Average |SHAP value|', orientation='vertical')
        colorbar.set_label('Average |SHAP value|', fontsize=12)
        plt.xticks(range(len(top_20_shap_by_class.columns)), top_20_shap_by_class.columns, rotation=45)
        plt.yticks(range(len(top_20_shap_by_class.index)), top_20_shap_by_class.index, fontsize=12)
        plt.title('Top 20 Feature Importance Heatmap (SHAP)', fontsize=14)
        plt.tight_layout()
        plt.savefig(figure_path + '/top_20_shap_importance_heatmap.png', dpi=1200, bbox_inches='tight')
        plt.close()

        # Per-fold Gini importance of each top feature
        plt.figure(figsize=(14, 10))
        for feature in top_20_features:
            feature_stability = fold_features_df[fold_features_df['Feature'] == feature]
            plt.plot(feature_stability['Fold'], feature_stability['Gini_Importance'],
                     label=feature, marker='o')
        plt.xlabel('Fold')
        plt.ylabel('Gini Importance')
        plt.title('Top 20 Feature Importance Stability Across Folds')
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout()
        plt.savefig(figure_path + '/feature_importance_stability.png', dpi=1200, bbox_inches='tight')
        plt.close()

    # Print results
    print("\nModel Performance Metrics:")
    print(f"Accuracy: {np.mean(accuracies):.4f} (±{np.std(accuracies):.4f})")
    print(f"Precision: {np.mean(precisions):.4f} (±{np.std(precisions):.4f})")
    print(f"Recall: {np.mean(recalls):.4f} (±{np.std(recalls):.4f})")
    print(f"F1 Score: {np.mean(f1_scores):.4f} (±{np.std(f1_scores):.4f})")
    print(f"ROC-AUC: {np.mean(roc_aucs):.4f} (±{np.std(roc_aucs):.4f})")

    print("\nTop 20 Features by SHAP Importance:")
    print("(showing number of folds each feature appeared in)")
    for _, row in feature_rankings.head(20).iterrows():
        # Denominator comes from the splitter instead of a hard-coded 5
        print(f"{row['Feature']}: {row['SHAP_Importance']:.4f} (appeared in {row['Folds_Appeared']}/{n_splits} folds)")

    return {
        'metrics': {
            'Accuracy': (np.mean(accuracies), np.std(accuracies)),
            'Precision': (np.mean(precisions), np.std(precisions)),
            'Recall': (np.mean(recalls), np.std(recalls)),
            'F1 Score': (np.mean(f1_scores), np.std(f1_scores)),
            'ROC-AUC': (np.mean(roc_aucs), np.std(roc_aucs))
        },
        'feature_rankings': feature_rankings,
        'shap_by_class': shap_by_class,
        'confusion_matrix': all_confusion_matrices
    }

In [None]:
# Fresh splitter with the same configuration, then run the full CV pipeline
# keeping the 50 best features per fold and saving all figures.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
results = train_and_evaluate_reduced_model(X, y, cv=cv, k=50, plot_confusion=True)

# Feature analysis

In [15]:
# Reload the full curated table (controls and metadata included) and the saved ranking.
df = pd.read_excel(data_path + '/curated/olink.xlsx')
feature_importance_rankings = pd.read_csv(data_path + '/results/feature_importance_rankings.csv')

# First ten features of the saved ranking file.
top_biomarkers = feature_importance_rankings['Feature'].iloc[:10].tolist()

In [None]:
def get_protein_info(protein_id, timeout=10):
    """Look up a display label for a protein via the UniProt REST search API.

    Parameters
    ----------
    protein_id : str
        Identifier used as the UniProt search query (the assay/gene name).
    timeout : float, default 10
        Seconds to wait for the HTTP response before giving up.

    Returns
    -------
    str
        ``"<recommended full name> (<protein_id>, <UniProt accession>)"`` built
        from the first search hit, or ``protein_id`` unchanged when there is no
        hit or the lookup fails for any reason. The result is always a string,
        so callers can use it directly as a plot title.
    """
    try:
        # Passing the query through `params` lets requests URL-encode it.
        response = requests.get(
            "https://rest.uniprot.org/uniprotkb/search",
            params={"query": protein_id, "format": "json"},
            timeout=timeout,
        )
        response.raise_for_status()
        results = response.json().get('results') or []

        if not results:
            return protein_id

        result = results[0]  # Take first match
        protein_name = result['proteinDescription']['recommendedName']['fullName']['value']
        uniprot_id = result['primaryAccession']
        return f"{protein_name} ({protein_id}, {uniprot_id})"

    except Exception as e:
        # Best-effort lookup: report the problem and fall back to the raw identifier.
        print(f"Error fetching info for {protein_id}: {e}")
        return protein_id

# Resolve a display label for every selected biomarker (one lookup per protein).
protein_names = [get_protein_info(protein) for protein in top_biomarkers]

# Show the identifier -> label pairs
for protein, name in zip(top_biomarkers, protein_names):
    print(f"{protein}: {name}")

In [None]:
def perform_mann_whitney(data, column, group1, group2):
    """Two-sided Mann-Whitney U p-value for ``column`` between two SubGroups.

    Rows are selected by the 'SubGroup' column and missing values are dropped.
    Returns ``np.nan`` (after emitting a warning) when either group has fewer
    than two observations or when SciPy raises a ``ValueError``.
    """
    subgroup = data['SubGroup']
    sample_a = data.loc[subgroup == group1, column].dropna()
    sample_b = data.loc[subgroup == group2, column].dropna()

    if min(len(sample_a), len(sample_b)) < 2:
        warnings.warn(f"Insufficient data for Mann-Whitney U test between {group1} and {group2} for {column}")
        return np.nan

    try:
        _, p_value = stats.mannwhitneyu(sample_a, sample_b, alternative='two-sided')
    except ValueError as e:
        warnings.warn(f"Error performing Mann-Whitney U test between {group1} and {group2} for {column}: {str(e)}")
        return np.nan
    return p_value

# Fixed colour per subgroup
colors = {
    'MV2K': '#2ecc71',    # green
    'VV2': '#e74c3c',     # red
    'MM(V)1': '#3498db',  # blue
    'CTRL': '#333c42'     # grey
}

# Column name -> display label
protein_name_dict = {marker: label for marker, label in zip(top_biomarkers, protein_names)}

# Subgroups in order of first appearance, and every unordered pair of them
groups = df['SubGroup'].unique()
group_pairs = list(combinations(groups, 2))

# One Mann-Whitney test per (biomarker, subgroup pair)
stats_results = [
    {
        'Column': column,
        'Group1': first_group,
        'Group2': second_group,
        'P-value': perform_mann_whitney(df, column, first_group, second_group),
    }
    for column in top_biomarkers
    for first_group, second_group in group_pairs
]

# Tabulate the test results
stats_df = pd.DataFrame(stats_results)

# Map a p-value to the conventional star notation
def get_significance_stars(p_value):
    """Return 'N/A' for missing values, '***' / '**' / '*' for p <= 0.001 / 0.01 / 0.05, else 'ns'."""
    if pd.isna(p_value):
        return 'N/A'
    for threshold, stars in ((0.001, '***'), (0.01, '**'), (0.05, '*')):
        if p_value <= threshold:
            return stars
    return 'ns'

# Add significance stars to the DataFrame
stats_df['Significance'] = stats_df['P-value'].apply(get_significance_stars)

# One panel per biomarker on a 2 x 5 grid
plt.figure(figsize=(20, 10))
group_order = list(groups)  # x position of a subgroup = its index in this list

for i, column in enumerate(top_biomarkers, 1):
    plt.subplot(2, 5, i)

    # Translucent violin for the distribution shape
    sns.violinplot(x='SubGroup', y=column, data=df,
                   palette=colors, alpha=0.3, inner=None)

    # Narrow box plot drawn on top, outliers hidden
    sns.boxplot(x='SubGroup', y=column, data=df,
                palette=colors, width=0.3, showfliers=False)

    # Significance annotations above the data
    y_max = df[column].max()
    y_range = y_max - df[column].min()

    comparisons = [
        (group_order.index(row['Group1']), group_order.index(row['Group2']), row['Significance'])
        for _, row in stats_df[stats_df['Column'] == column].iterrows()
    ]
    # Shortest spans lowest; every comparison gets its own level so that bars
    # with equal span (e.g. positions 0-2 and 1-3) no longer overlap.
    comparisons.sort(key=lambda item: abs(item[0] - item[1]))
    for level, (x1, x2, stars) in enumerate(comparisons, start=1):
        # Named `bar_y` on purpose: the global `y` holds the model target
        # labels and must not be overwritten by this cell.
        bar_y = y_max + y_range * 0.05 * (1 + level)
        plt.plot([x1, x2], [bar_y, bar_y], 'k-', linewidth=1)
        plt.text((x1 + x2) / 2, bar_y, stars,
                 ha='center', va='bottom', fontsize=8)

    title = protein_name_dict[column]
    if not isinstance(title, str):
        # A failed lookup may have stored a (protein_id, None) tuple; use the id.
        title = title[0]
    split_title = title.replace(" (", "\n(")  # Break the line before the opening parenthesis
    plt.title(split_title, fontsize=10, fontweight='bold')
    plt.xlabel('')
    plt.ylabel('Relative protein abundance (log2)')
    plt.xticks(rotation=0)

plt.tight_layout()
plt.savefig(figure_path + '/top10_all.png', dpi=1200, bbox_inches='tight')
plt.show()

## Visualize most interesting M1 vs V2 proteins

In [22]:
# Reload the curated table and the saved ranking for this section.
feature_importance_rankings = pd.read_csv(data_path + '/results/feature_importance_rankings.csv')
df = pd.read_excel(data_path + '/curated/olink.xlsx')

# Hand-picked proteins shown in this section
top_biomarkers = [
    "GPC5",
    "CCDC80",
    "WASF1",
]

In [None]:
def get_protein_info(protein_id, timeout=10):
    """Look up a display label for a protein via the UniProt REST search API.

    Parameters
    ----------
    protein_id : str
        Identifier used as the UniProt search query (the assay/gene name).
    timeout : float, default 10
        Seconds to wait for the HTTP response before giving up.

    Returns
    -------
    str
        ``"<recommended full name> (<protein_id>, <UniProt accession>)"`` built
        from the first search hit, or ``protein_id`` unchanged when there is no
        hit or the lookup fails for any reason. The result is always a string,
        so callers can use it directly as a plot title.
    """
    try:
        # Passing the query through `params` lets requests URL-encode it.
        response = requests.get(
            "https://rest.uniprot.org/uniprotkb/search",
            params={"query": protein_id, "format": "json"},
            timeout=timeout,
        )
        response.raise_for_status()
        results = response.json().get('results') or []

        if not results:
            return protein_id

        result = results[0]  # Take first match
        protein_name = result['proteinDescription']['recommendedName']['fullName']['value']
        uniprot_id = result['primaryAccession']
        return f"{protein_name} ({protein_id}, {uniprot_id})"

    except Exception as e:
        # Best-effort lookup: report the problem and fall back to the raw identifier.
        print(f"Error fetching info for {protein_id}: {e}")
        return protein_id

# Resolve a display label for every selected biomarker (one lookup per protein).
protein_names = [get_protein_info(protein) for protein in top_biomarkers]

# Show the identifier -> label pairs
for protein, name in zip(top_biomarkers, protein_names):
    print(f"{protein}: {name}")

In [None]:
def perform_mann_whitney(data, column, group1, group2):
    """Two-sided Mann-Whitney U p-value for ``column`` between two SubGroups.

    Rows are selected by the 'SubGroup' column and missing values are dropped.
    Returns ``np.nan`` (after emitting a warning) when either group has fewer
    than two observations or when SciPy raises a ``ValueError``.
    """
    subgroup = data['SubGroup']
    sample_a = data.loc[subgroup == group1, column].dropna()
    sample_b = data.loc[subgroup == group2, column].dropna()

    if min(len(sample_a), len(sample_b)) < 2:
        warnings.warn(f"Insufficient data for Mann-Whitney U test between {group1} and {group2} for {column}")
        return np.nan

    try:
        _, p_value = stats.mannwhitneyu(sample_a, sample_b, alternative='two-sided')
    except ValueError as e:
        warnings.warn(f"Error performing Mann-Whitney U test between {group1} and {group2} for {column}: {str(e)}")
        return np.nan
    return p_value

# Fixed colour per subgroup
colors = {
    'MV2K': '#2ecc71',    # green
    'VV2': '#e74c3c',     # red
    'MM(V)1': '#3498db',  # blue
    'CTRL': '#333c42'     # grey
}

# Column name -> display label
protein_name_dict = {marker: label for marker, label in zip(top_biomarkers, protein_names)}

# Subgroups in order of first appearance, and every unordered pair of them
groups = df['SubGroup'].unique()
group_pairs = list(combinations(groups, 2))

# One Mann-Whitney test per (biomarker, subgroup pair)
stats_results = [
    {
        'Column': column,
        'Group1': first_group,
        'Group2': second_group,
        'P-value': perform_mann_whitney(df, column, first_group, second_group),
    }
    for column in top_biomarkers
    for first_group, second_group in group_pairs
]

# Tabulate the test results
stats_df = pd.DataFrame(stats_results)

# Map a p-value to the conventional star notation
def get_significance_stars(p_value):
    """Return 'N/A' for missing values, '***' / '**' / '*' for p <= 0.001 / 0.01 / 0.05, else 'ns'."""
    if pd.isna(p_value):
        return 'N/A'
    for threshold, stars in ((0.001, '***'), (0.01, '**'), (0.05, '*')):
        if p_value <= threshold:
            return stars
    return 'ns'

# Add significance stars to the DataFrame
stats_df['Significance'] = stats_df['P-value'].apply(get_significance_stars)

# One panel per biomarker on a 2 x 5 grid
plt.figure(figsize=(20, 10))
group_order = list(groups)  # x position of a subgroup = its index in this list

for i, column in enumerate(top_biomarkers, 1):
    plt.subplot(2, 5, i)

    # Translucent violin for the distribution shape
    sns.violinplot(x='SubGroup', y=column, data=df,
                   palette=colors, alpha=0.3, inner=None)

    # Narrow box plot drawn on top, outliers hidden
    sns.boxplot(x='SubGroup', y=column, data=df,
                palette=colors, width=0.3, showfliers=False)

    # Significance annotations above the data
    y_max = df[column].max()
    y_range = y_max - df[column].min()

    comparisons = [
        (group_order.index(row['Group1']), group_order.index(row['Group2']), row['Significance'])
        for _, row in stats_df[stats_df['Column'] == column].iterrows()
    ]
    # Shortest spans lowest; every comparison gets its own level so that bars
    # with equal span (e.g. positions 0-2 and 1-3) no longer overlap.
    comparisons.sort(key=lambda item: abs(item[0] - item[1]))
    for level, (x1, x2, stars) in enumerate(comparisons, start=1):
        # Named `bar_y` on purpose: the global `y` holds the model target
        # labels and must not be overwritten by this cell.
        bar_y = y_max + y_range * 0.05 * (1 + level)
        plt.plot([x1, x2], [bar_y, bar_y], 'k-', linewidth=1)
        plt.text((x1 + x2) / 2, bar_y, stars,
                 ha='center', va='bottom', fontsize=8)

    title = protein_name_dict[column]
    if not isinstance(title, str):
        # A failed lookup may have stored a (protein_id, None) tuple; use the id.
        title = title[0]
    split_title = title.replace(" (", "\n(")  # Break the line before the opening parenthesis
    plt.title(split_title, fontsize=12)
    plt.xlabel('')
    plt.ylabel('Relative protein abundance (log2)')
    plt.xticks(rotation=0)

plt.tight_layout()
plt.savefig(figure_path + '/interesting_bmk_all.png', dpi=300, bbox_inches='tight')
plt.show()

# Survival analysis

## Perform Cox regression analysis for top biomarkers

In [26]:
# Reload the full curated table and the saved ranking for the survival section.
df = pd.read_excel(data_path + '/curated/olink.xlsx')
feature_importance_rankings = pd.read_csv(data_path + '/results/feature_importance_rankings.csv')

# First twenty features of the saved ranking file.
protein_list = feature_importance_rankings['Feature'].iloc[:20].tolist()

In [27]:
# Work on an independent copy: drop controls and give the survival columns clearer names
df_cox = (
    df.copy()
    .loc[lambda frame: frame['SubGroup'] != 'CTRL']
    .rename(columns={'LP-death': 'Survival_time', 'onset-LP': 'Onset_lp', 'age at LP': 'Age'})
)

# Every retained sample is flagged as having the event
df_cox['Event'] = 1

# Z-score age and onset-to-LP interval (subtract the mean, divide by the standard deviation)
for raw_column, scaled_column in (('Age', 'Age_scaled'), ('Onset_lp', 'Onset_lp_scaled')):
    raw_values = df_cox[raw_column]
    df_cox[scaled_column] = (raw_values - raw_values.mean()) / raw_values.std()

# Integer codes for the sCJD subgroup
subgroup_mapping = {
    'MV2K': 1,
    'VV2': 2,
    'MM(V)1': 3
}
df_cox['SubGroup_numeric'] = df_cox['SubGroup'].map(subgroup_mapping)

# Integer codes for the codon 129 genotype
codon_mapping = {
    'MV': 1,
    'VV': 2,
    'MM': 3
}
df_cox['codon_numeric'] = df_cox['Codon 129'].map(codon_mapping)

# Keep the candidate proteins plus the survival outcome and the covariates
# NOTE(review): 'Sex' is passed through unchanged — confirm it is numerically encoded.
columns_to_select = protein_list + ['Survival_time', 'Event',
                                    'codon_numeric',
                                    'Sex', 'Age_scaled', 'Onset_lp_scaled', 'SubGroup_numeric']
df_cox = df_cox[columns_to_select]

In [None]:
# Run Cox analysis with selected covariates (optionally excluding SubGroup_numeric)
def run_cox_analysis_with_selected_covariates(data, protein_list, covariates):
    """Fit one Cox proportional-hazards model per protein and tabulate the protein effect.

    Each model uses 'Survival_time' as duration, 'Event' as event indicator,
    the columns named in ``covariates`` as adjustment variables, and the
    protein itself. Rows with a missing value in any used column are dropped
    per protein.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain 'Survival_time', 'Event', every covariate and every protein.
    protein_list : iterable of str
        Protein columns to test, one model each.
    covariates : list of str
        Adjustment columns. An empty list (or None) gives the univariate model.
        The columns actually listed here are the ones used; previously a fixed
        set was used regardless of the list's content.

    Returns
    -------
    pandas.DataFrame
        One row per successfully fitted protein, sorted by 'p_value', with
        columns 'Protein', 'Hazard Ratio', 'CI_lower', 'CI_upper', 'p_value'
        and, when covariates are given, 'Covariate', 'Covariate_HR' and
        'Covariate_p_value' (which repeat the protein's own values).

    Notes
    -----
    A protein whose model cannot be fitted (or whose assumption check raises)
    is skipped, and the reason is printed so the omission is visible.
    """
    covariates = list(covariates) if covariates else []

    results_columns = ['Protein', 'Hazard Ratio', 'CI_lower', 'CI_upper', 'p_value']
    if covariates:
        results_columns += ['Covariate', 'Covariate_HR', 'Covariate_p_value']

    base_cols = ['Survival_time', 'Event'] + covariates

    records = []
    for protein in protein_list:
        try:
            # Complete cases for the columns of this model only
            current_data = data[base_cols + [protein]].dropna()

            cph = CoxPHFitter()
            cph.fit(current_data, duration_col='Survival_time', event_col='Event')

            # Proportional-hazards check is run for adjusted models only
            if covariates:
                cph.check_assumptions(current_data, p_value_threshold=0.05, show_plots=False)

            coef = cph.params_[protein]
            p_val_protein = cph.summary.loc[protein, 'p']
            se_protein = cph.summary.loc[protein, 'se(coef)']
            hr_protein = np.exp(coef)

            record = {
                'Protein': protein,
                'Hazard Ratio': hr_protein,
                # 95% Wald interval on the log-hazard scale, back-transformed
                'CI_lower': np.exp(coef - 1.96 * se_protein),
                'CI_upper': np.exp(coef + 1.96 * se_protein),
                'p_value': p_val_protein,
            }
            if covariates:
                record.update({
                    'Covariate': 'Protein',
                    'Covariate_HR': hr_protein,
                    'Covariate_p_value': p_val_protein,
                })
        except Exception as exc:
            # Printed instead of warnings.warn because the notebook silences warnings globally.
            print(f"Skipping {protein}: Cox model failed ({type(exc).__name__}: {exc})")
            continue

        records.append(record)

    # Build the table once; `columns=` keeps the schema even when nothing was fitted
    results = pd.DataFrame(records, columns=results_columns)

    return results.sort_values('p_value')

# Three covariate configurations, fitted in this order:
#   1. univariate (protein only)
#   2. all covariates except SubGroup_numeric
#   3. all covariates
covariates_no_covariates = []
covariates_without_subgroup = ['codon_numeric', 'Sex', 'Age_scaled', 'Onset_lp_scaled']
covariates_with_all = ['codon_numeric', 'SubGroup_numeric', 'Sex', 'Age_scaled', 'Onset_lp_scaled']

results_univariate, results_without_subgroup, results_with_all_covariates = [
    run_cox_analysis_with_selected_covariates(df_cox, protein_list, covariate_set)
    for covariate_set in (covariates_no_covariates,
                          covariates_without_subgroup,
                          covariates_with_all)
]


In [None]:
def create_forest_plot(results, ax, title, protein_order):
    """Draw a forest plot of protein hazard ratios on ``ax``.

    Parameters
    ----------
    results : pandas.DataFrame
        Needs the columns 'Protein', 'Hazard Ratio', 'CI_lower', 'CI_upper'
        and 'p_value'. The frame is not modified.
    ax : matplotlib Axes
        Target axes.
    title : str
        Axes title.
    protein_order : sequence of str
        Row order (bottom to top); only the first row per protein is drawn.
    """
    # Work on a copy so the caller's table keeps its original 'Protein' dtype
    plot_data = results.copy()
    plot_data['Protein'] = pd.Categorical(plot_data['Protein'], categories=protein_order, ordered=True)
    plot_data = plot_data.drop_duplicates(subset=['Protein']).sort_values('Protein')
    y_positions = np.arange(len(plot_data))

    def format_pvalue(p):
        return 'p < 0.001' if p < 0.001 else f'p = {p:.3f}'

    # Light grey band behind every other row for readability
    for i in range(0, len(plot_data), 2):
        ax.axhspan(i - 0.5, i + 0.5, color='gray', alpha=0.1)

    # Point estimates
    ax.scatter(plot_data['Hazard Ratio'], y_positions, c='#F6A15C', s=50, zorder=2)

    # Confidence intervals
    for i, (ci_l, ci_u) in enumerate(zip(plot_data['CI_lower'], plot_data['CI_upper'])):
        ax.plot([ci_l, ci_u], [i, i], color='#F6A15C', linewidth=2, zorder=1)

    # Reference line at HR = 1
    ax.axvline(x=1, color='gray', linestyle='--', alpha=0.7, zorder=0)

    # Protein names on the y axis
    ax.set_yticks(y_positions)
    ax.set_yticklabels(plot_data['Protein'], fontsize=10)

    # Text annotations to the right of the widest interval. Columns are read by
    # name, so the result does not depend on the column order of `results`.
    max_ci = plot_data['CI_upper'].max()
    annotation_rows = zip(plot_data['Hazard Ratio'], plot_data['CI_lower'],
                          plot_data['CI_upper'], plot_data['p_value'])
    for i, (hr, ci_l, ci_u, p_value) in enumerate(annotation_rows):
        hr_text = f'HR = {hr:.2f} ({ci_l:.2f}-{ci_u:.2f})'
        p_text = format_pvalue(p_value)
        ax.text(max_ci * 1.1, i, f'{hr_text}\n{p_text}', va='center', fontsize=9)

    ax.set_xlabel('Hazard Ratio (95% CI)', fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.grid(True, axis='x', linestyle='--', alpha=0.3)


# Row order shared by all panels: the proteins as ordered in the univariate results
protein_order = results_univariate['Protein'].unique()

# One row of three panels
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 7))

# (results table, panel title) for each panel, left to right
forest_panels = [
    (results_univariate, 'Univariate Cox regression'),
    (results_without_subgroup, 'Multivariate Cox regression'),
    (results_with_all_covariates, 'Multivariate Cox regression including sCJD subtype'),
]
for panel_ax, (panel_results, panel_title) in zip(axes, forest_panels):
    create_forest_plot(panel_results, panel_ax, panel_title, protein_order)

# Horizontal spacing between panels
plt.subplots_adjust(wspace=0.3)

# Finalise layout, write the figure to disk and display it
plt.tight_layout()
plt.savefig(figure_path + '/cox_forest_plots_side_by_side.png', bbox_inches='tight', dpi=1200)
plt.show()


# Kaplan-Meier curves

In [None]:
# Reload the curated table and the saved ranking for the Kaplan-Meier section.
feature_importance_rankings = pd.read_csv(data_path + '/results/feature_importance_rankings.csv')
df = pd.read_excel(data_path + '/curated/olink.xlsx')

# Proteins plotted in this section (fixed list)
top_biomarkers = [
    'CCDC80', 'MAPT', 'PRDX3', 'FKBP4', 'APEX1',
    'FOSB', 'PAG1', 'PPP3R1', 'THOP1', 'WASF1',
]


In [None]:
def get_protein_info(protein_id, timeout=10):
    """Look up a display label for a protein via the UniProt REST search API.

    Parameters
    ----------
    protein_id : str
        Identifier used as the UniProt search query (the assay/gene name).
    timeout : float, default 10
        Seconds to wait for the HTTP response before giving up.

    Returns
    -------
    str
        ``"<recommended full name> (<protein_id>, <UniProt accession>)"`` built
        from the first search hit, or ``protein_id`` unchanged when there is no
        hit or the lookup fails for any reason. The result is always a string,
        so callers can use it directly as a plot title.
    """
    try:
        # Passing the query through `params` lets requests URL-encode it.
        response = requests.get(
            "https://rest.uniprot.org/uniprotkb/search",
            params={"query": protein_id, "format": "json"},
            timeout=timeout,
        )
        response.raise_for_status()
        results = response.json().get('results') or []

        if not results:
            return protein_id

        result = results[0]  # Take first match
        protein_name = result['proteinDescription']['recommendedName']['fullName']['value']
        uniprot_id = result['primaryAccession']
        return f"{protein_name} ({protein_id}, {uniprot_id})"

    except Exception as e:
        # Best-effort lookup: report the problem and fall back to the raw identifier.
        print(f"Error fetching info for {protein_id}: {e}")
        return protein_id

# Resolve a display label for every selected biomarker (one lookup per protein).
protein_names = [get_protein_info(protein) for protein in top_biomarkers]

# Show the identifier -> label pairs
for protein, name in zip(top_biomarkers, protein_names):
    print(f"{protein}: {name}")

In [42]:
# Exclude controls
df = df[df['SubGroup'] != 'CTRL']

# Rename to the names used by the survival code. The source column is spelled
# 'age at LP' (lower-case 'a'); the previous key 'Age at LP' matched nothing,
# so the rename to 'age' silently did not happen.
df = df.rename(columns={
    'age at LP': 'age',
    'LP-death': 'survival_time'
})

# Add event indicator for survival analysis (every retained sample has the event)
df['event'] = 1

# Keep samples with 0 < survival_time < 20
df = df[(df['survival_time'] > 0) & (df['survival_time'] < 20)]

In [None]:
def create_multiplot_survival_analysis(df, protein_list, protein_names, n_cols=5):
    """
    Draw one Kaplan-Meier panel per protein, stratified by `SubGroup`.

    For each protein, the samples of every subtype are split into tertiles of
    that protein's values *within the subtype*. The upper and lower tertiles
    are compared with a log-rank test and drawn as two KM curves (solid =
    upper, dashed = lower); the middle tertile is not plotted. The figure is
    saved to `{figure_path}/all_proteins_survival.png` and shown.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns 'SubGroup', 'survival_time', 'event' and one column
        per entry of `protein_list`. It is not modified.
    protein_list : sequence of str
        Column names of the proteins to plot.
    protein_names : sequence
        Panel titles, paired positionally with `protein_list`. A non-string
        entry is replaced by the protein's column name.
    n_cols : int, optional
        Panels per row (default 5). Rows are added as needed, so more than
        ten proteins no longer overflow a fixed 2x5 grid.

    Notes
    -----
    Uses `plt`, `GridSpec`, `logrank_test`, `KaplanMeierFitter` and
    `figure_path` from the notebook namespace. A subtype whose values cannot
    be cut into three tertiles (`pd.qcut` raises ValueError) is reported and
    skipped instead of aborting the whole figure.
    """
    colors = {
        'MV2K': {'Upper': '#1a9850', 'Lower': '#91cf60'},     # Dark and light green
        'VV2': {'Upper': '#d73027', 'Lower': '#fc8d59'},      # Dark and light red
        'MM(V)1': {'Upper': '#2166ac', 'Lower': '#67a9cf'}    # Dark and light blue
    }
    # Neutral pair for any subtype that has no dedicated palette entry
    fallback_colors = {'Upper': '#4d4d4d', 'Lower': '#bababa'}

    panels = list(zip(protein_list, protein_names))
    n_rows = max(1, -(-len(panels) // n_cols))  # ceiling division, at least one row
    fig = plt.figure(figsize=(4 * n_cols, 5 * n_rows))  # 20x10 for the default 2x5 layout
    gs = GridSpec(n_rows, n_cols, figure=fig)

    # Missing subtypes are dropped so that sorting never mixes NaN with strings
    subtypes = sorted(df['SubGroup'].dropna().unique())

    # Process each protein
    for idx, (protein_id, protein_name) in enumerate(panels):
        row, col = divmod(idx, n_cols)
        ax = fig.add_subplot(gs[row, col])

        stats_results = {}

        # Process each subtype separately
        for subtype in subtypes:
            subtype_data = df[df['SubGroup'] == subtype]

            # Tertiles are computed within this subtype only
            try:
                expression_group = pd.qcut(subtype_data[protein_id],
                                           q=3,
                                           labels=['Lower', 'Middle', 'Upper'])
            except ValueError as err:
                print(f"Skipping {subtype} for {protein_id}: {err}")
                continue

            # Only the two extreme tertiles are compared and plotted
            tertile_data = {
                group: subtype_data[expression_group == group]
                for group in ('Upper', 'Lower')
            }
            upper, lower = tertile_data['Upper'], tertile_data['Lower']
            if upper.empty or lower.empty:
                continue

            # Perform logrank test (upper vs lower tertile)
            results = logrank_test(
                upper['survival_time'],
                lower['survival_time'],
                upper['event'],
                lower['event']
            )
            stats_results[subtype] = {
                'p_value': results.p_value,
                'n': f"{len(upper)}/{len(lower)}"
            }

            # Plot KM curves for this subtype
            palette = colors.get(subtype, fallback_colors)
            kmf = KaplanMeierFitter()
            for group, group_data in tertile_data.items():
                kmf.fit(group_data['survival_time'],
                        group_data['event'],
                        label=f'{subtype}-{group}')
                kmf.plot(ax=ax, ci_show=False,
                         color=palette[group],
                         linestyle='-' if group == 'Upper' else '--')

        # One line of statistics per subtype that was actually tested
        stats_text = []
        for subtype in sorted(stats_results.keys()):
            p_val = stats_results[subtype]['p_value']
            n = stats_results[subtype]['n']
            stars = '***' if p_val <= 0.001 else '**' if p_val <= 0.01 else '*' if p_val <= 0.05 else 'ns'
            stats_text.append(f'{subtype} (n={n}): {stars} p={p_val:.3f}')

        # Right-aligned text box in the lower part of the panel
        if stats_text:
            ax.text(0.95, 0.25, '\n'.join(stats_text),
                    transform=ax.transAxes,
                    horizontalalignment='right',
                    verticalalignment='bottom',
                    fontsize=9,
                    bbox=dict(facecolor='white', alpha=0.8))

        # Customize plot: put the "(id, accession)" part of the label on its own line
        title = protein_name if isinstance(protein_name, str) else str(protein_id)
        ax.set_title(title.replace(" (", "\n("), fontsize=10)
        # The x label is shown on every row (the original condition was always true)
        ax.set_xlabel('Time (months)')
        if col == 0:  # Only leftmost column
            ax.set_ylabel('Survival probability')
        else:
            ax.set_ylabel('')

    plt.tight_layout()
    plt.savefig(f'{figure_path}/all_proteins_survival.png', dpi=1200, bbox_inches='tight')
    plt.show()

# Render the per-subtype Kaplan-Meier panels for the selected biomarkers
create_multiplot_survival_analysis(df, top_biomarkers, protein_names)

In [None]:
# Plain protein symbols used as panel titles for the pooled-tertile figure.
# This rebinds protein_names, replacing the labels built by get_protein_info.
protein_names = [
    'CCDC80', 'MAPT', 'PRDX3', 'FKBP4', 'APEX1',
    'FOSB', 'PAG1', 'PPP3R1', 'THOP1', 'WASF1',
]

def create_multiplot_survival_analysis(df, protein_list, protein_names, n_cols=5):
    """
    Draw one Kaplan-Meier panel per protein using tertiles of the whole cohort.

    For each protein, all rows of `df` are split into lower / middle / upper
    tertiles of that protein's values. All three groups are drawn as KM
    curves; the p-value shown in the panel is the log-rank test of the upper
    against the lower tertile only. The figure is saved to
    `{figure_path}/km_survival_tertiles.png` and shown.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns 'survival_time', 'event' and one column per entry of
        `protein_list`. It is not modified: the tertile labels are kept in a
        local Series instead of being written to `df['expression_group']`.
    protein_list : sequence of str
        Column names of the proteins to plot.
    protein_names : sequence
        Panel titles, paired positionally with `protein_list`; only the text
        before the first '(' is shown. A non-string entry is replaced by the
        protein's column name.
    n_cols : int, optional
        Panels per row (default 5). Rows are added as needed.

    Notes
    -----
    Uses `plt`, `GridSpec`, `logrank_test`, `KaplanMeierFitter` and
    `figure_path` from the notebook namespace. This definition has the same
    name as the per-subtype plotting function defined earlier in the notebook
    and replaces it once this cell has run.
    """
    colors = {
        'Upper': '#1a9850',  # Dark green
        'Middle': '#fee08b', # Yellow
        'Lower': '#d73027'   # Dark red
    }

    # Set up the figure with GridSpec, sized to the number of panels
    panels = list(zip(protein_list, protein_names))
    n_rows = max(1, -(-len(panels) // n_cols))  # ceiling division, at least one row
    fig = plt.figure(figsize=(22 * n_cols / 5, 5 * n_rows))  # 22x10 for the default 2x5 layout
    gs = GridSpec(n_rows, n_cols, figure=fig)

    # Process each protein
    for idx, (protein_id, protein_name) in enumerate(panels):
        row, col = divmod(idx, n_cols)
        ax = fig.add_subplot(gs[row, col])

        title = protein_name if isinstance(protein_name, str) else str(protein_id)
        ax.set_title(title.split('(')[0].strip(), fontsize=16, fontweight='bold')

        # Tertiles over the entire dataset, held locally so the caller's frame
        # is left untouched
        try:
            expression_group = pd.qcut(df[protein_id],
                                       q=3,
                                       labels=['Lower', 'Middle', 'Upper'])
        except ValueError as err:
            # e.g. non-unique bin edges; leave this panel empty and carry on
            print(f"Skipping {protein_id}: {err}")
            continue

        # Get group sizes
        group_sizes = expression_group.value_counts()

        # Perform logrank test (upper vs lower tertile)
        upper_data = df[expression_group == 'Upper']
        lower_data = df[expression_group == 'Lower']

        results = logrank_test(
            upper_data['survival_time'],
            lower_data['survival_time'],
            upper_data['event'],
            lower_data['event']
        )

        # Plot KM curves
        kmf = KaplanMeierFitter()
        for group in ['Upper', 'Middle', 'Lower']:
            mask = (expression_group == group)
            if mask.sum() > 0:
                kmf.fit(df.loc[mask, 'survival_time'],
                        df.loc[mask, 'event'],
                        label=f'{group} (n={group_sizes[group]})')
                kmf.plot(ax=ax,
                         ci_show=False,
                         color=colors[group],
                         linestyle='-')

        # Add p-value
        p_val = results.p_value
        stars = '***' if p_val <= 0.001 else '**' if p_val <= 0.01 else '*' if p_val <= 0.05 else 'ns'
        stats_text = f'p={p_val:.3e} {stars}'

        ax.text(0.95, 0.95, stats_text,
                transform=ax.transAxes,
                horizontalalignment='right',
                verticalalignment='top',
                fontsize=13,
                bbox=dict(facecolor='white', alpha=0.8))

        # The x label is shown on every row (the original condition was always true)
        ax.set_xlabel('Time (months)', fontsize=12)
        if col == 0:  # Only leftmost column
            ax.set_ylabel('Survival probability', fontsize=12)
        else:
            ax.set_ylabel('')
        ax.legend(fontsize=13)  # Larger legend font

    plt.tight_layout()
    plt.savefig(f'{figure_path}/km_survival_tertiles.png', dpi=1200, bbox_inches='tight')
    plt.show()

# Render the pooled-tertile Kaplan-Meier panels for the selected biomarkers
create_multiplot_survival_analysis(df, top_biomarkers, protein_names)