# Getting Started

## Import dependencies

In [None]:
%%capture

%pip install lifelines
%pip install lightgbm
%pip install numba==0.58.1
%pip install openpyxl
%pip install rpy2
%pip install scikit-learn
%pip install shap==0.44.0
%pip install statsmodels
%pip install xgboost

%load_ext rpy2.ipython


In [None]:
%%capture
%%R

require(data.table)
install.packages("caret")
library("pROC")
library("caret")


In [None]:
import dxdata
import dxpy
import glob
import io
import lifelines
import math
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pyspark
import re
import scipy.stats as stats
import seaborn as sns
import shap
import statsmodels.api as sm
import statsmodels.formula.api as smf
import subprocess
import warnings

from collections import defaultdict
from datetime import datetime
from lifelines import CoxPHFitter, KaplanMeierFitter
from lifelines.statistics import logrank_test
from lightgbm import LGBMRegressor
from numpy import interp
from pathlib import Path
from PIL import Image
from scipy.stats import iqr, ttest_ind
from sklearn.feature_selection import SelectFromModel
from sklearn.metrics import accuracy_score, auc, balanced_accuracy_score, confusion_matrix, roc_auc_score, roc_curve 
from sklearn.model_selection import StratifiedKFold, GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC
from statsmodels.stats.multitest import fdrcorrection
from statsmodels.stats.proportion import proportion_confint
from tqdm import tqdm
from xgboost import XGBClassifier


In [None]:
# Build the Spark configuration first, raising the Kryo serializer buffer limit
# to 128 MB, then create the context and the session used by spark.sql(...) below.
conf = pyspark.SparkConf()
conf.set("spark.kryoserializer.buffer.max", "128m")
sc = pyspark.SparkContext(conf=conf)
spark = pyspark.sql.SparkSession(sc)


## Download files

In [None]:
# Copy the three input files used below from the DNAnexus project into the
# working directory: the control cohort table, the kinship table and the
# ICD-10 field mapping. --overwrite makes the cell safe to re-run.
! dx download "old_data/healthy_control_11_19_23.csv" --overwrite
! dx download "Bulk/Genotype\ Results/Genotype\ calls/ukb_rel.dat" --overwrite
! dx download /data/icd10_mapping.txt --overwrite


## Initialize variables

In [None]:
cutoff_date = datetime(1999, 1, 1)

! dx find projects --name "gut_brain" > projectid.txt
projectid = open("projectid.txt", "r")
projectid = projectid.read()
projectid = projectid.split(" : ")[0]

icd10_mapping = pd.read_csv("icd10_mapping.txt", sep="\t")
icd10_description_mapping = icd10_mapping.copy()
icd10_description_mapping.columns = icd10_description_mapping.iloc[0]
icd10_description_mapping = icd10_description_mapping.drop(0)
icd10_description_mapping.reset_index(drop=True, inplace=True)


## Define helper functions

In [None]:
def dx_upload(cmd):
    """Run a shell command (in this notebook, a ``dx upload ...`` call) quietly.

    Both stdout and stderr are discarded and the exit status is not inspected,
    so a failing command goes unnoticed by the caller.

    Args:
        cmd: Command line passed to the shell as a single string.
    """
    subprocess.run(cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        

# Fetch cohorts

In [None]:
# Locate the dispensed dataset object in the project root (name matched as a
# glob) and open its "participant" entity for the field retrievals below.
dispensed_dataset = dxpy.find_one_data_object(
    typename='Dataset',
    name='app*.dataset',
    folder='/',
    name_mode='glob',
)
dispensed_dataset_id = dispensed_dataset['id']
dataset = dxdata.load_dataset(id=dispensed_dataset_id)
participant = dataset['participant']

# ID of the project, used when building `dx extract_dataset` commands.
project_id = dxpy.find_one_project()["id"]


In [None]:
# Field names requested from the participant entity; every list starts with "eid".
# Covariates: renamed via ukb_covar_mapping (age at recruitment, year of birth, townsend, sex).
fields_covar = ["eid"] + ["p21022","p34","p22189","p31"]
# First five array entries of field 22009, renamed to PC1..PC5 via ukb_pcs_mapping.
fields_pcs = ["eid"] + ["p22009_a1","p22009_a2","p22009_a3","p22009_a4","p22009_a5"]
# Date fields: the first three are renamed to AD, AD_F00 and PD below; the rest are
# renamed through icd10_mapping and become `icd_codes` (hence the [4:] slice there).
fields_icd10 = ["eid"] + [
    'p131036', 'p130836', 'p131022', 'p130000', 'p130002', 'p130004', 'p130006', 'p130008', 'p130010', 'p130012', 'p130014', 
    'p130016', 'p130018', 'p131552', 'p131554', 'p131556', 'p131558', 'p131560', 'p131562', 'p131564', 'p131566', 'p131568', 
    'p131570', 'p131572', 'p131574', 'p131576', 'p131578', 'p131580', 'p131582', 'p131584', 'p131586', 'p131588', 'p131590', 
    'p131592', 'p131594', 'p131596', 'p131598', 'p131600', 'p131602', 'p131604', 'p131606', 'p131608', 'p131610', 'p131612', 
    'p131614', 'p131616', 'p131618', 'p131620', 'p131622', 'p131624', 'p131626', 'p131628', 'p131630', 'p131632', 'p131634', 
    'p131636', 'p131638', 'p131640', 'p131642', 'p131644', 'p131646', 'p131648', 'p131650', 'p131652', 'p131654', 'p131656', 
    'p131658', 'p131660', 'p131662', 'p131664', 'p131666', 'p131668', 'p131670', 'p131672', 'p131674', 'p131676', 'p131678', 
    'p131680', 'p131682', 'p131684', 'p131686', 'p131688', 'p131690', 'p131692', 'p131694', 'p130692', 'p130694', 'p130696', 
    'p130698', 'p130700', 'p130702', 'p130704', 'p130706', 'p130708', 'p130710', 'p130712', 'p130714', 'p130716', 'p130718', 
    'p130720', 'p130722', 'p130724', 'p130726', 'p130728', 'p130730', 'p130732', 'p130734', 'p130736', 'p130738', 'p130740', 
    'p130742', 'p130744', 'p130746', 'p130748', 'p130750', 'p130752', 'p130756', 'p130758', 'p130760', 'p130762', 'p130764', 
    'p130766', 'p130768', 'p130770', 'p130772', 'p130774', 'p130776', 'p130778', 'p130780', 'p130782', 'p130784', 'p130786', 
    'p130788', 'p130790', 'p130792', 'p130794', 'p130796', 'p130798', 'p130800', 'p130802', 'p130804', 'p130806', 'p130808', 
    'p130810', 'p130812', 'p130814', 'p130816', 'p130818', 'p130820', 'p130822', 'p130824', 'p130826', 'p130828', 'p130830', 
    'p130832',
]
# Date of death and the two age-at-death instances (see ukb_death_mapping).
fields_death = ["eid"] + ["p40000_i0", "p40007_i0", "p40007_i1"]
# Renamed to "Caucasian" and used to restrict the cohorts (see ukb_european_mapping).
fields_european = ["eid"] + ["p22006"]


In [None]:
# Rename tables: UK Biobank field name -> column name used in this notebook.
ukb_pcs_mapping = {
    "p22009_a1" : "PC1",
    "p22009_a2" : "PC2",
    "p22009_a3" : "PC3",
    "p22009_a4" : "PC4",
    "p22009_a5" : "PC5",
}

ukb_covar_mapping = {
    "p34" : "Year_of_birth", 
    "p31" : "sex", 
    "p21022" : "Age_at_recruitment", 
    "p22189" : "townsend",
}

# Prefix every header of icd10_mapping.txt with "p" and map it to the value in the
# file's first data row; the three neurodegenerative-disease fields are added by hand.
ukb_icd10_mapping = icd10_mapping.rename(columns=lambda col: "p" + col).to_dict(orient="records")[0]
ukb_icd10_mapping["p131036"] = "AD"
ukb_icd10_mapping["p130836"] = "AD_F00"
ukb_icd10_mapping["p131022"] = "PD"
# Skip "eid" and the three AD / AD_F00 / PD fields: what remains are the exposure codes.
icd_codes = [ukb_icd10_mapping[ukb_id] for ukb_id in fields_icd10[4:]]

ukb_death_mapping =  {
    "p40000_i0" : "death_date",
    "p40007_i0" : "death_age_1",
    "p40007_i1" : "death_age_2",
}

ukb_european_mapping = {"p22006" : "Caucasian"}


In [None]:
# Principal components per participant, with columns renamed to PC1..PC5.
df_pcs = (
    participant
    .retrieve_fields(names=fields_pcs, engine=dxdata.connect())
    .toPandas()
    .rename(columns=ukb_pcs_mapping)
)


In [None]:
# Covariates per participant, renamed to readable column names.
df_covar = (
    participant
    .retrieve_fields(names=fields_covar, engine=dxdata.connect())
    .toPandas()
    .rename(columns=ukb_covar_mapping)
)

# Calendar year in which follow-up starts: year of birth plus age at recruitment.
df_covar['start'] = df_covar['Year_of_birth'] + df_covar['Age_at_recruitment']


In [None]:
# Diagnosis date columns per participant, renamed to AD / PD / ICD-10 labels.
df_icd10 = (
    participant
    .retrieve_fields(names=fields_icd10, engine=dxdata.connect())
    .toPandas()
    .rename(columns=ukb_icd10_mapping)
)

# Where AD is empty but AD_F00 has a date, use the AD_F00 date as the AD date,
# then discard the helper column.
missing_ad_with_f00 = df_icd10['AD'].isna() & df_icd10['AD_F00'].notna()
df_icd10.loc[missing_ad_with_f00, 'AD'] = df_icd10['AD_F00']
df_icd10 = df_icd10.drop(columns='AD_F00')


In [None]:
# Death date and the two age-at-death columns per participant.
df_death = (
    participant
    .retrieve_fields(names=fields_death, engine=dxdata.connect())
    .toPandas()
    .rename(columns=ukb_death_mapping)
)


In [None]:
# Control cohort: keep the ID plus the AD and PD date columns, rename them to
# match df_icd10, and store the ID as a string like the other tables.
df_controls = (
    pd.read_csv('healthy_control_11_19_23.csv')[["eid", "p131036", "p131022"]]
    .rename(columns={"p131036": "AD", "p131022": "PD"})
    .astype({"eid": str})
)


In [None]:
# Ancestry flag per participant (field p22006 renamed to "Caucasian").
df_european = (
    participant
    .retrieve_fields(names=fields_european, engine=dxdata.connect())
    .toPandas()
    .rename(columns=ukb_european_mapping)
)

# Participants whose flag equals 1.
eur_ids = df_european.loc[df_european['Caucasian'] == 1, 'eid'].tolist()

# Restrict both the case table and the control table to those participants.
df_icd10 = df_icd10[df_icd10['eid'].isin(eur_ids)]
df_controls = df_controls[df_controls['eid'].isin(eur_ids)]


In [None]:
# Pairs of related participants with their kinship coefficient.
df_relatedness = pd.read_csv('ukb_rel.dat', sep = r'\s+')
# NOTE(review): 0.0884 looks like the usual KING cut-off for second-degree or
# closer relatives - confirm. Only the first member (ID1) of each pair is dropped.
rel_remove = df_relatedness[df_relatedness['Kinship'] > 0.0884]
# IDs come out of the CSV as integers while `eid` is a string in every cohort
# table, so cast before comparing; otherwise isin() never matches and nobody
# is removed.
eids_to_remove = rel_remove["ID1"].astype(str).tolist()

df_icd10 = df_icd10[~df_icd10['eid'].isin(eids_to_remove)]
df_controls = df_controls[~df_controls['eid'].isin(eids_to_remove)]


In [None]:
def process_lag_data(df_icd10, df_controls, df_covar, df_death, ndd, censor_date='2023-01-01'):
    """Assemble the survival table (cases plus controls) for one disease.

    Cases are the rows of ``df_icd10`` with a date in column ``ndd``; controls
    are all rows of ``df_controls``. Both are joined to the covariates and the
    death records on ``eid`` (inner joins).

    Args:
        df_icd10: Frame with ``eid`` and a date column named ``ndd``.
        df_controls: Frame with ``eid`` and a column named ``ndd``.
        df_covar: Covariates; must contain ``eid`` and ``start`` (a year).
        df_death: Death records; must contain ``eid`` and ``death_date``.
        ndd: Name of the disease date column, e.g. ``"AD"`` or ``"PD"``.
        censor_date: End of follow-up for participants without a death date.

    Returns:
        A new DataFrame with the merged columns (minus ``death_date``) plus
        ``complete_stop_date``, ``stop``, ``{ndd}_year``, ``duration`` and
        ``event``. Rows are kept only if ``stop > start`` and the diagnosis
        year, when present, is not before ``start``.
    """
    cases = df_icd10.loc[df_icd10[ndd].notna(), ['eid', ndd]]
    df = pd.concat([cases, df_controls[['eid', ndd]]], ignore_index=True)
    df = df.merge(df_covar, on='eid')
    df = df.merge(df_death, on='eid')

    # Follow-up ends at death, or at the censoring date when no death is recorded.
    # Assigning the filled Series (instead of a chained in-place fillna) keeps
    # this working under pandas copy-on-write.
    df['complete_stop_date'] = pd.to_datetime(df['death_date']).fillna(pd.Timestamp(censor_date))
    df['stop'] = df['complete_stop_date'].dt.year
    df = df.drop(columns=["death_date"])

    df[ndd] = pd.to_datetime(df[ndd])
    df[f"{ndd}_year"] = df[ndd].dt.year
    # Years from start to diagnosis for cases, to end of follow-up otherwise.
    df['duration'] = np.where(df[ndd].isnull(), df['stop'] - df['start'], df[f'{ndd}_year'] - df['start'])
    df['event'] = df[ndd].notna().astype(int)

    # For cases diagnosed before the end of follow-up, the diagnosis date
    # becomes the stop date. `stop` (the year) is deliberately left unchanged,
    # as in the original implementation.
    diagnosed_first = df[ndd] < df['complete_stop_date']
    df.loc[diagnosed_first, 'complete_stop_date'] = df.loc[diagnosed_first, ndd]

    keep = (df["stop"] > df["start"]) & ((df['start'] <= df[f'{ndd}_year']) | df[f'{ndd}_year'].isna())
    # Return an independent frame so callers can add columns without
    # SettingWithCopyWarning.
    return df[keep].copy()


In [None]:
# One survival table per disease, built from the same inputs.
df_ad, df_pd = (
    process_lag_data(df_icd10, df_controls, df_covar, df_death, disease)
    for disease in ("AD", "PD")
)


# Fetch proteomics data

In [None]:
# Build the Olink table once and cache it in the project; later runs just download it.
if not Path("/mnt/project/results/olink_data.txt").exists():
    # Dump the dataset dictionary so the olink_instance_0 field names can be listed.
    ! dx extract_dataset {project_id}:{dispensed_dataset_id} -ddd --delimiter ,

    data_dict_csv = glob.glob("*.data_dictionary.csv")[0]
    data_dict_df = pd.read_csv(data_dict_csv, low_memory=False)
    field_names = list(data_dict_df.loc[data_dict_df["entity"] == "olink_instance_0", "name"].values)
    field_names_str = [f"olink_instance_0.{f}" for f in field_names]
    field_names_query = ",".join(field_names_str)

    # Ask dx for the SQL that retrieves those fields, then run it through Spark.
    ! dx extract_dataset {project_id}:{dispensed_dataset_id} --fields {field_names_query} --delimiter , --output extracted_data.sql --sql
    with open("extracted_data.sql", "r") as file:
        retrieve_sql = ""
        for line in file:
            retrieve_sql += line.strip()

    df_olink = spark.sql(retrieve_sql.strip(";")).toPandas()
    # Drop the entity prefix from every column name.
    rename_dict = {col: col.replace('olink_instance_0.', '') for col in df_olink.columns.tolist()}
    df_olink.rename(columns=rename_dict, inplace=True)

    df_olink.to_csv('olink_data.txt', index=False, sep="\t")
    ! dx upload 'olink_data.txt' --path /results/olink_data.txt

else:
    ! dx download /results/olink_data.txt --overwrite 
    df_olink = pd.read_csv("olink_data.txt", sep="\t")
    # The ID is stored as a string, matching the other tables in this notebook.
    df_olink["eid"] = df_olink["eid"].astype(str)


# APOE genotyping

In [None]:
# Submit the plink2 extraction of the APOE variants unless its output folder
# is already present in the mounted project.
if not Path("/mnt/project/results/APOE_genotyping/plink/").is_dir():
    cmd = " ".join([
        "dx run swiss-army-knife",
        "-iin='/data/plink/chr19_pgen.pgen'",
        "-iin='/data/plink/chr19_pgen.pvar'",
        "-iin='/data/plink/chr19_pgen.psam'",
        "-iin='/data/APOE_variants.txt'",
        "-icmd='plink2 --pfile chr19_pgen --extract APOE_variants.txt --make-bed --export compound-genotypes --out APOE_snps'",
        "--instance-type mem2_ssd1_v2_x16",
        f"--destination '{projectid}:/results/APOE_genotyping/plink/'",
    ])

    result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    # Surface the submission error instead of failing silently.
    if result.returncode != 0:
        print("Error running command:")
        print(result.stderr.decode("utf-8"))


In [None]:
if not Path(f"/mnt/project/results/APOE_genotyping/apoe_genotyping.tsv").exists():
    ! dx download /data/APOE_genotypes_PLINK_ped.py --overwrite
    ! dx download /results/APOE_genotyping/plink/APOE_snps.ped --overwrite
    ! python APOE_genotypes_PLINK_ped.py -i APOE_snps.ped -o APOE_final

    df_apoe = pd.read_csv('APOE_final.APOE_GENOTYPES.csv')

    # Count the number of APOE4 copies
    def count_e4_copies(apoe_genotype):
        if apoe_genotype == 'e4/e4':
            return 2
        elif 'e4' in apoe_genotype:
            return 1
        else:
            return 0

    # Create a new column in the DataFrame based on the APOE_GENOTYPE
    df_apoe['e4_copies'] = df_apoe['APOE_GENOTYPE'].apply(count_e4_copies)
    df_apoe.rename(columns = {"FID":"eid"}, inplace=True)
    df_apoe["eid"] = df_apoe["eid"].astype(str)

    df_apoe.to_csv('apoe_genotyping.tsv', index=False, sep="\t")
    ! dx upload apoe_genotyping.tsv --path /results/APOE_genotyping/apoe_genotyping.tsv

else:
    ! dx download /results/APOE_genotyping/apoe_genotyping.tsv
    df_apoe = pd.read_csv('apoe_genotyping.tsv', sep="\t")


In [None]:
# Only the participant ID and the e4 allele count are used downstream.
df_apoe = df_apoe.loc[:, ["eid", "e4_copies"]]


# LRRK2 + GBA1 genotyping

### LRRK2

***
rs76904798 -> 12:40220632:C:T  
rs34637584 -> 12:40340400:G:A
***

In [None]:
# Submit the plink2 allele-count export for the LRRK2 variants unless the
# output folder is already present in the mounted project.
if not Path("/mnt/project/results/LRRK2_GBA1/plink/").is_dir():
    cmd = " ".join([
        "dx run swiss-army-knife",
        "-iin='/data/plink/chr12_pgen.pgen'",
        "-iin='/data/plink/chr12_pgen.pvar'",
        "-iin='/data/plink/chr12_pgen.psam'",
        "-iin='/data/LRRK2_variants.txt'",
        "-icmd='plink2 --pfile chr12_pgen --extract LRRK2_variants.txt --recode A --out LRRK2_counts'",
        "--instance-type mem2_ssd1_v2_x16",
        f"--destination '{projectid}:/results/LRRK2_GBA1/plink/'",
    ])

    result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    # Surface the submission error instead of failing silently.
    if result.returncode != 0:
        print("Error running command:")
        print(result.stderr.decode("utf-8"))


In [None]:
! dx download /results/LRRK2_GBA1/plink/LRRK2_counts.raw --overwrite
df_lrrk2 = pd.read_csv("LRRK2_counts.raw", sep="\t")
df_lrrk2['12:40220632:C:T_C'] = df_lrrk2['12:40220632:C:T_C'].apply(lambda x: 0 if x >= 1.5 else (1 if pd.notna(x) else 0))
df_lrrk2['12:40340400:G:A_G'] = df_lrrk2['12:40340400:G:A_G'].apply(lambda x: 0 if x == 2 else (1 if pd.notna(x) else np.nan))
df_lrrk2 = df_lrrk2[["IID", "12:40220632:C:T_C", "12:40340400:G:A_G"]]
df_lrrk2.rename(columns={"IID":"eid"}, inplace=True)


In [None]:
# Store the carrier table in the project unless a copy is already there.
if not Path(f"/mnt/project/results/LRRK2_GBA1/LRRK2_carriers.tsv").exists():
    df_lrrk2.to_csv('LRRK2_carriers.tsv', sep='\t', index=False)
    ! dx upload LRRK2_carriers.tsv --path /results/LRRK2_GBA1/LRRK2_carriers.tsv


### GBA1

***
rs35749011 -> 1:155162560:G:A  
rs76763715 -> 1:155235843:T:C
***

In [None]:
# Submit the plink2 allele-count export for the GBA1 variants unless its output
# file already exists. The LRRK2 job writes to the same folder, so checking
# only for the folder would skip this job whenever the LRRK2 output is present.
if not Path("/mnt/project/results/LRRK2_GBA1/plink/GBA1_counts.raw").exists():
    cmd = f"dx run swiss-army-knife "
    cmd += f"-iin='/data/plink/chr1_pgen.pgen' "
    cmd += f"-iin='/data/plink/chr1_pgen.pvar' "
    cmd += f"-iin='/data/plink/chr1_pgen.psam' "
    cmd += f"-iin='/data/GBA1_variants.txt' "
    cmd += f"-icmd='plink2 --pfile chr1_pgen --extract GBA1_variants.txt --recode A --out GBA1_counts' "
    cmd += f"--instance-type mem2_ssd1_v2_x16 "
    cmd += f"--destination '{projectid}:/results/LRRK2_GBA1/plink/'"

    result = subprocess.run(
        cmd, 
        shell=True, 
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        )

    # Surface the submission error instead of failing silently.
    if result.returncode != 0:
        print(f"Error running command:")
        print(result.stderr.decode("utf-8"))


In [None]:
! dx download /results/LRRK2_GBA1/plink/GBA1_counts.raw --overwrite
df_gba1 = pd.read_csv("GBA1_counts.raw", sep="\t")
df_gba1['1:155162560:G:A_G'] = df_gba1['1:155162560:G:A_G'].apply(lambda x: 0 if x == 2 else (1 if pd.notna(x) else np.nan))
df_gba1['1:155235843:T:C_T'] = df_gba1['1:155235843:T:C_T'].apply(lambda x: 0 if x == 2 else (1 if pd.notna(x) else np.nan))
df_gba1 = df_gba1[["IID", "1:155162560:G:A_G", "1:155235843:T:C_T"]]
df_gba1.rename(columns={"IID":"eid"}, inplace=True)


In [None]:
# Store the carrier table in the project unless a copy is already there.
if not Path(f"/mnt/project/results/LRRK2_GBA1/GBA1_carriers.tsv").exists():
    df_gba1.to_csv("GBA1_carriers.tsv", sep='\t', index=False)
    ! dx upload GBA1_carriers.tsv --path /results/LRRK2_GBA1/GBA1_carriers.tsv


# PRS

In [None]:
# For each score, submit one plink2 job per autosome that extracts the score's
# loci, unless the score's plink output folder already exists.
for ndd in ["ad", "ad_no_apoe", "pd"]:
    if Path(f"/mnt/project/results/prs/{ndd}/plink/").is_dir():
        continue

    for chrnum in range(1, 23):
        cmd = " ".join([
            "dx run swiss-army-knife",
            f"-iin='/data/plink/chr{chrnum}_pgen.pgen'",
            f"-iin='/data/plink/chr{chrnum}_pgen.pvar'",
            f"-iin='/data/plink/chr{chrnum}_pgen.psam'",
            f"-iin='/data/loci_{ndd}.txt'",
            f"-icmd='plink2 --pfile chr{chrnum}_pgen --extract loci_{ndd}.txt --make-pgen --out chr{chrnum}'",
            "--instance-type mem2_ssd1_v2_x16",
            f"--destination '{projectid}:/results/prs/{ndd}/plink/'",
        ])

        result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

        # Surface the submission error instead of failing silently.
        if result.returncode != 0:
            print("Error running command:")
            print(result.stderr.decode("utf-8"))


In [None]:
%%capture

# For each score without a "<score>_chr.txt" list in the project: download the
# per-chromosome .pvar files, record which chromosomes actually produced one,
# and upload that list (it drives the merge step that follows).
for ndd in ["ad", "ad_no_apoe", "pd"]:
    if len(glob.glob(f"/mnt/project/results/prs/{ndd}/plink/*_chr.txt")) == 0:
        ! mkdir {ndd}
        for chrnum in range(1, 23):
            ! dx download /results/prs/{ndd}/plink/chr{chrnum}.pvar --overwrite
            ! mv chr{chrnum}.pvar {ndd}/
            
        # A chromosome is listed only if its .pvar file was downloaded.
        with open(f"{ndd}_chr.txt","w") as f:
            for chrnum in range(1, 23):
                if Path(f"{ndd}/chr{chrnum}.pvar").exists():
                    f.write(f"chr{chrnum}\n")
        ! dx upload {ndd}_chr.txt --path /results/prs/{ndd}/plink/{ndd}_chr.txt


In [None]:
# For each score, submit one plink2 job that merges the per-chromosome files
# named in the local "<score>_chr.txt" list, unless a merged file already exists.
for ndd in ["ad", "ad_no_apoe", "pd"]:
    if len(glob.glob(f"/mnt/project/results/prs/{ndd}/plink/*merged*")) != 0:
        continue

    with open(f'{ndd}_chr.txt', 'r') as file:
        chromosomes = [chromosome.replace("\n","") for chromosome in file.readlines()]

    # Every listed chromosome contributes its .pgen/.pvar/.psam triple as input.
    parts = ["dx run swiss-army-knife"]
    for chromosome in chromosomes:
        for extension in ("pgen", "pvar", "psam"):
            parts.append(f"-iin='/results/prs/{ndd}/plink/{chromosome}.{extension}'")
    parts.append(f"-iin='/results/prs/{ndd}/plink/{ndd}_chr.txt'")
    parts.append(f"-icmd='plink2 --pmerge-list {ndd}_chr.txt --set-all-var-ids @:# --make-pgen --out merged'")
    parts.append("--instance-type mem2_ssd1_v2_x16")
    parts.append(f"--destination '{projectid}:/results/prs/{ndd}/plink/'")
    cmd = " ".join(parts)

    result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    # Surface the submission error instead of failing silently.
    if result.returncode != 0:
        print("Error running command:")
        print(result.stderr.decode("utf-8"))


In [None]:
# For each score, submit the plink2 --score job on the merged genotype files,
# unless a .sscore result already exists in the project.
for ndd in ["ad", "ad_no_apoe", "pd"]:
    if len(glob.glob(f"/mnt/project/results/prs/{ndd}/*.sscore")) != 0:
        continue

    cmd = " ".join([
        "dx run swiss-army-knife",
        f"-iin='/results/prs/{ndd}/plink/merged.pgen'",
        f"-iin='/results/prs/{ndd}/plink/merged.pvar'",
        f"-iin='/results/prs/{ndd}/plink/merged.psam'",
        f"-iin='/data/sumstats_{ndd}.txt'",
        f"-icmd='plink2 --pfile merged --score sumstats_{ndd}.txt list-variants --out {ndd}'",
        "--instance-type mem2_ssd1_v2_x16",
        f"--destination '{projectid}:/results/prs/{ndd}/'",
    ])

    result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    # Surface the submission error instead of failing silently.
    if result.returncode != 0:
        print("Error running command:")
        print(result.stderr.decode("utf-8"))


In [None]:
# Fetch the three plink2 score files produced by the scoring jobs above.
! dx download /results/prs/pd/pd.sscore --overwrite
! dx download /results/prs/ad/ad.sscore --overwrite
! dx download /results/prs/ad_no_apoe/ad_no_apoe.sscore --overwrite


In [None]:
def calculate_zscore(score_path, ids_control):
    """Standardise the scores in a plink2 ``.sscore`` file against the controls.

    Args:
        score_path: Tab-separated file with ``#FID`` and ``SCORE1_AVG`` columns.
        ids_control: Collection of control IDs (as strings) matched against ``#FID``.

    Returns:
        A Series with one z-score per row of the file, in file order and with
        the file's positional index (it is NOT indexed by participant ID).
    """
    scores = pd.read_csv(score_path, sep="\t")
    scores["#FID"] = scores["#FID"].astype(str)
    control_scores = scores.loc[scores["#FID"].isin(ids_control), "SCORE1_AVG"]
    return (scores["SCORE1_AVG"] - control_scores.mean()) / control_scores.std()


In [None]:
def _zscores_for_eids(score_path, eids):
    """Return the control-standardised score of each participant in `eids`.

    calculate_zscore yields one value per row of the score file, in file order.
    Assigning that Series straight to a cohort frame would align on unrelated
    row labels, so the values are re-indexed by the file's #FID and looked up
    by participant ID. Participants absent from the score file get NaN.
    """
    fids = pd.read_csv(score_path, sep="\t", usecols=["#FID"])["#FID"].astype(str)
    zscores = pd.Series(
        calculate_zscore(score_path, df_controls["eid"]).to_numpy(),
        index=fids.to_numpy(),
    )
    # Series.map needs a unique index; keep the first row per ID.
    zscores = zscores[~zscores.index.duplicated()]
    return eids.astype(str).map(zscores)


df_pd["zscore_pd"] = _zscores_for_eids("pd.sscore", df_pd["eid"])
df_ad["zscore_ad"] = _zscores_for_eids("ad.sscore", df_ad["eid"])
df_ad["zscore_ad_without_apoe"] = _zscores_for_eids("ad_no_apoe.sscore", df_ad["eid"])


# PRS Sensitivity Analyses

In [None]:
def prs_regression(df_ndd_icd, zscore, icd_code):
    """Fit a binomial GLM of ``pheno`` on a risk score plus covariates.

    The model is ``pheno ~ <zscore> + sex + age + PC1..PC5``.

    Args:
        df_ndd_icd: Frame containing ``pheno``, the score column and the covariates.
        zscore: Name of the score column.
        icd_code: Label written to the ``ICD-10`` column of the result.

    Returns:
        One-row DataFrame with the score term's coefficient (``BETA``), standard
        error (``SE``), p-value (``P``) and 95% confidence limits (``L95``, ``U95``).
    """
    covariates = [zscore, "sex", "age", "PC1", "PC2", "PC3", "PC4", "PC5"]
    formula = "pheno ~ " + " + ".join(covariates)
    model = smf.glm(formula, data=df_ndd_icd, family=sm.families.Binomial()).fit()

    # Row of the coefficient table and of the interval table for the score term.
    coef_row = model.summary2().tables[1].loc[zscore, ["Coef.", "Std.Err.", "P>|z|"]]
    interval = model.conf_int(alpha=0.05).loc[zscore]

    return pd.DataFrame({
        "ICD-10" : [icd_code],
        "BETA" : [coef_row["Coef."]],
        "SE" : [coef_row["Std.Err."]],
        "P" : [coef_row["P>|z|"]],
        "L95" : [interval[0]],
        "U95" : [interval[1]],
    })


In [None]:
def prs_confmat(df_ndd_icd, zscore, icd_code, ndd):
    """Evaluate how well a risk score alone classifies ``pheno``.

    Fits ``pheno ~ <zscore>`` as a binomial GLM, picks the probability threshold
    whose ROC point is closest to the top-left corner, and reports the resulting
    classification metrics. A ROC plot is saved as ``roc_<icd_code>_<name>.png``
    and uploaded with ``dx_upload``.

    Args:
        df_ndd_icd: Frame with ``pheno`` (0/1 or boolean) and the score column.
            It is not modified.
        zscore: Name of the score column; the part after the first underscore
            is used in the output file name.
        icd_code: Label used in the result table, plot title and file name.
        ndd: Disease label used in the plot title.

    Returns:
        One-row DataFrame with AUC, accuracy and its 95% Clopper-Pearson
        interval (``L95``/``U95``), balanced accuracy, sensitivity, specificity.
    """
    # Work on a copy so the caller's frame does not gain helper columns.
    df = df_ndd_icd.copy()
    # A boolean outcome is expanded by the formula interface into two columns,
    # which a Binomial GLM reads as (successes, failures) with pheno == False as
    # the success column. Casting to 0/1 makes the model predict pheno == 1.
    df["pheno"] = df["pheno"].astype(int)
    model = smf.glm(f"pheno ~ {zscore}", data=df, family=sm.families.Binomial()).fit()

    df["probDisease"] = model.predict(df)
    df["reported"] = df["pheno"].apply(lambda x: "DISEASE" if x == 1 else "CONTROL")

    # ROC analysis (computed once; reused for the plot below)
    fpr, tpr, thresholds = roc_curve(df["pheno"], df["probDisease"])
    auc_value = roc_auc_score(df["pheno"], df["probDisease"])

    # Find best threshold (closest ROC point to (0, 1)) and re-classify
    distances = ((1 - tpr)**2 + fpr**2)**0.5
    best_idx = distances.argmin()
    best_threshold = thresholds[best_idx]
    df["predicted"] = df["probDisease"].apply(lambda p: "DISEASE" if p > best_threshold else "CONTROL")

    # Confusion matrix + metrics
    cm = confusion_matrix(df["reported"], df["predicted"], labels=["CONTROL", "DISEASE"])
    tn, fp, fn, tp = cm.ravel()
    accuracy = accuracy_score(df["reported"], df["predicted"])
    balanced_acc = balanced_accuracy_score(df["reported"], df["predicted"])
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)

    # 95% CI of the accuracy using Clopper-Pearson; statsmodels calls this
    # method "beta" (its default would be the normal approximation).
    correct_preds = (df["predicted"] == df["reported"]).sum()
    ci_lower, ci_upper = proportion_confint(count=correct_preds, nobs=len(df), alpha=0.05, method="beta")
    
    df_summary = pd.DataFrame({
        "ICD-10" : [icd_code],
        "AUC" : [auc_value],
        "Accuracy" : [accuracy],
        "L95" : [ci_lower],
        "U95" : [ci_upper],
        "Balanced_Accuracy" : [balanced_acc],
        "Sensitivity" : [sensitivity],
        "Specificity" : [specificity],
    })
    
    # Plot ROC curve
    plt.figure(figsize=(8, 5))
    roc_auc = auc(fpr, tpr)
    plt.plot(fpr, tpr, label=f"AUC = {roc_auc:.2f}")
    
    plt.plot([0, 1], [0, 1], "k--", lw=2)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curve for {icd_code} and {ndd}")
    plt.legend(loc="lower right")
    zsc_name = "_".join(zscore.split("_")[1:])
    plt.savefig(f"roc_{icd_code}_{zsc_name}.png", dpi=300)
    plt.close()
    
    dx_upload(f"dx upload roc_{icd_code}_{zsc_name}.png --path results/prs/sensitivity_analyses/roc_{icd_code}_{zsc_name}.png")
    
    return df_summary


In [None]:
def prep_sensitivity_analysis(icd_codes, ndd, df_ndd, df_icd10, zscore):
    """Test, among cases of one disease, whether a risk score predicts each ICD-10 code.

    Only rows of ``df_ndd`` with a diagnosis date in column ``ndd`` are used. For
    every code the outcome ``pheno`` is 1 when the participant has a date for that
    code. ``prs_regression`` and ``prs_confmat`` are run per code and their results
    are written to ``<name>_regression.tsv`` / ``<name>_confmat.tsv`` (``<name>`` is
    ``zscore`` without its first underscore-separated token) and uploaded.

    Reads the notebook-level ``df_pcs`` frame for the principal components.

    Args:
        icd_codes: Column names of ``df_icd10`` to test.
        ndd: Disease date column in ``df_ndd`` (``"AD"`` or ``"PD"``).
        df_ndd: Survival table containing ``eid``, ``sex``, ``Year_of_birth``,
            ``ndd``, ``death_age_1``, ``death_age_2`` and the score column.
        df_icd10: Frame with ``eid`` and one date column per ICD-10 code.
        zscore: Name of the score column in ``df_ndd``.
    """
    model_cols = ["eid", "sex", "age", "PC1", "PC2", "PC3", "PC4", "PC5", "pheno", zscore]

    df_ndd = df_ndd.loc[df_ndd[ndd].notna(), ["eid", "sex", "Year_of_birth", ndd, "death_age_1", "death_age_2", zscore]].copy()
    # Age at diagnosis; the fallbacks (age at death, then age in 2023) only apply
    # to rows without a diagnosis year.
    df_ndd.insert(1, "age", (pd.to_datetime(df_ndd[ndd]).dt.year - df_ndd["Year_of_birth"]).fillna(df_ndd["death_age_2"].fillna(df_ndd["death_age_1"]).fillna(2023 - df_ndd["Year_of_birth"])))
    df_ndd["sex"] = df_ndd["sex"].replace({"Male": 0, "Female": 1})
    df_ndd = df_ndd.merge(df_pcs, on="eid")
    
    df_regression_summary = []
    df_confmat = []
    for icd_code in tqdm(icd_codes, desc=f"{ndd} PRS Sensitivity Analysis"):
        df_ndd_icd = df_ndd.merge(df_icd10[["eid", icd_code]], on="eid")
        # 0/1 outcome: a boolean outcome would make the formula-based Binomial
        # GLM model the probability of False instead of True.
        df_ndd_icd["pheno"] = df_ndd_icd[icd_code].notna().astype(int)
        # Rows with missing model inputs are ignored by the fit anyway, but would
        # produce NaN predictions that the ROC computation cannot handle.
        df_ndd_icd = df_ndd_icd[model_cols].dropna().copy()

        # Both analyses need cases and non-cases of the code.
        if df_ndd_icd["pheno"].nunique() < 2:
            continue
        
        df_regression_summary.append(prs_regression(df_ndd_icd, zscore, icd_code))
        df_confmat.append(prs_confmat(df_ndd_icd, zscore, icd_code, ndd))
    
    zsc_name = "_".join(zscore.split("_")[1:])

    # Nothing to write when every code was skipped.
    if not df_regression_summary:
        return
    
    df_regression_summary = pd.concat(df_regression_summary)
    df_regression_summary.to_csv(f"{zsc_name}_regression.tsv", sep="\t", index=False)
    dx_upload(f"dx upload {zsc_name}_regression.tsv --path results/prs/sensitivity_analyses/{zsc_name}_regression.tsv")
    
    df_confmat = pd.concat(df_confmat)
    df_confmat.to_csv(f"{zsc_name}_confmat.tsv", sep="\t", index=False)
    dx_upload(f"dx upload {zsc_name}_confmat.tsv --path results/prs/sensitivity_analyses/{zsc_name}_confmat.tsv")
    

In [None]:
# One sensitivity analysis per risk score: PD, AD, and AD without the APOE region.
prep_sensitivity_analysis(icd_codes, "PD", df_pd, df_icd10, "zscore_pd")
prep_sensitivity_analysis(icd_codes, "AD", df_ad, df_icd10, "zscore_ad")
prep_sensitivity_analysis(icd_codes, "AD", df_ad, df_icd10, "zscore_ad_without_apoe")


## ROC Analysis

In [None]:
%%capture

# NOTE(review): the *_roc_results.txt files are not written by any cell visible
# in this notebook - presumably produced elsewhere; confirm before a fresh run.
! dx upload pd_roc_results.txt --path /sg_results/roc/pd/pd_roc_results.txt
! dx upload ad_roc_results.txt --path /sg_results/roc/ad/ad_roc_results.txt
! dx upload ad_without_apoe_roc_results.txt --path /sg_results/roc/ad_without_apoe/ad_without_apoe_roc_results.txt


In [None]:
# Draw one ROC curve per "<name>_predictions.txt" file in the working directory
# and save it as "roc_<name>.png". Each file must provide PHENO and probDisease.
fnames = [fname.replace("_predictions.txt", "") for fname in glob.glob("*predictions.txt")]
for fname in fnames:
    to_plot = pd.read_csv(f"{fname}_predictions.txt", sep="\t")
    fpr, tpr, _ = roc_curve(to_plot["PHENO"], to_plot["probDisease"])
    roc_auc = auc(fpr, tpr)
    output_path = f"roc_{fname}.png"

    plt.figure(figsize=(8, 5))
    plt.plot(fpr, tpr, label=f"AUC = {roc_auc:.2f}")
    # Diagonal = performance of a random classifier.
    plt.plot([0, 1], [0, 1], "k--", lw=2)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curve for {fname.replace('_',' ').replace('PRS final ', '')}")
    plt.legend(loc="lower right")
    plt.savefig(output_path, dpi=300)
    # Close the figure so repeated iterations do not accumulate open figures.
    plt.close()


In [None]:
%%capture

# Upload every ROC plot to the folder matching its file-name prefix; anything
# that is neither "roc_pd_" nor "roc_ad_without_apoe_" goes to the AD folder.
fnames = glob.glob("roc_*.png")
for fname in fnames:
    if fname.startswith("roc_pd_"):
        ! dx upload {fname} --path /sg_results/roc/pd/{fname}
    elif fname.startswith("roc_ad_without_apoe_"):
        ! dx upload {fname} --path /sg_results/roc/ad_without_apoe/{fname}
    else:
        ! dx upload {fname} --path /sg_results/roc/ad/{fname}


# Cox Regression

In [None]:
def _formula_safe_name(name):
    """Turn a column label into an identifier that a model formula reads as one term."""
    safe = re.sub(r"\W", "_", str(name))
    return f"v_{safe}" if safe[:1].isdigit() else safe


def run_cox_regression(
    icd_codes, 
    ndd, 
    df_ndd, 
    df_icd10, 
    df_covar, 
    icd10_description_mapping, 
    cutoff_date, 
    lag_min=None, 
    lag_max=None, 
    df_pcs=None, 
    df_riskvars=None,
    df_zscore=None
):
    """Fit one Cox model per ICD-10 code with that code as exposure for a disease.

    Every model uses ``duration``/``event`` from ``df_ndd`` and the terms
    ``<icd_code> + Year_of_birth + townsend + sex``, plus the optional principal
    components, risk variants and risk score. A code counts as an exposure only
    if it was diagnosed on/after ``cutoff_date`` and strictly before the disease;
    participants with a code or disease date before ``cutoff_date`` are dropped.
    Codes with fewer than 3 exposed participants, or at most 1 exposed event,
    are skipped.

    The result is FDR-corrected across all fitted codes, sorted by hazard ratio,
    restricted to ``N_pairs > 5``, written to ``<out_path>.txt`` and uploaded
    with ``dx_upload``.

    Args:
        icd_codes: Exposure column names in ``df_icd10``.
        ndd: Disease date column in ``df_ndd`` (``"AD"`` or ``"PD"``).
        df_ndd: Survival table with ``eid``, ``complete_stop_date``,
            ``duration``, ``event`` and ``ndd``.
        df_icd10: Frame with ``eid`` and one date column per code.
        df_covar: Frame with ``eid``, ``Year_of_birth``, ``townsend``, ``sex``.
        icd10_description_mapping: Frame whose row 0 holds the description of
            each code.
        cutoff_date: Earliest diagnosis date that is accepted.
        lag_min, lag_max: Optional window, in years before the stop date
            (``lag_min <= years < lag_max``), in which the code must have been
            diagnosed to count as an exposure.
        df_pcs: Optional frame, ``eid`` first, then principal-component columns.
        df_riskvars: Optional frame, ``eid`` first, then risk-variant columns.
        df_zscore: Optional two-column frame: ``eid`` and one score column.

    Returns:
        DataFrame with one row per retained code (empty, with the same columns,
        when no code could be fitted).
    """
    summary_columns = [
        "ICD10", "NDD", "Description", "HR", "ci_min", "ci_max", "P_VAL",
        "N_pairs", "n", "icd10_diagnosis_range", "P_VAL_FDR_CORRECTED", "rejected",
    ]
    cox_summary = []

    pc_cols = list(df_pcs.columns.values)[1:] if df_pcs is not None else []
    zscore_col = df_zscore.columns.values[1] if df_zscore is not None else None
    riskvar_cols = list(df_riskvars.columns.values)[1:] if df_riskvars is not None else []

    # Variant labels such as "12:40340400:G:A_G" are not valid formula terms
    # (":" is an interaction operator), so the model sees renamed copies.
    safe_riskvar_cols = [_formula_safe_name(col) for col in riskvar_cols]
    if df_riskvars is not None:
        df_riskvars = df_riskvars.rename(columns=dict(zip(riskvar_cols, safe_riskvar_cols)))
    
    out_path = f"cox_{ndd.lower()}"
    
    if lag_min is not None:
        out_path += f"_{lag_min}year_{lag_max}year"
    if df_pcs is not None:
        out_path += "_pcs"
    if df_zscore is not None:
        out_path += f"_{zscore_col.replace(f'_{ndd.lower()}','')}"
    if df_riskvars is not None:
        # Colons are replaced because ":" separates project and path in dx
        # paths, so a file name containing one cannot be uploaded as-is.
        out_path += "_" + "_".join([var.replace(":", "_") for var in riskvar_cols])

    covariate_terms = ["Year_of_birth", "townsend", "sex"] + pc_cols + safe_riskvar_cols
    if zscore_col is not None:
        covariate_terms.append(zscore_col)
    
    for icd_code in tqdm(icd_codes, desc=f"{ndd} Cox Regression"):
        df = pd.merge(df_ndd[["eid", "complete_stop_date", "duration", "event", ndd]], df_icd10[["eid", icd_code]], on="eid")
        df = df.merge(df_covar[["eid", "Year_of_birth", "townsend", "sex"]], on='eid')
        if df_pcs is not None:
            df = df.merge(df_pcs, on='eid')
        if df_riskvars is not None:
            df = df.merge(df_riskvars, on='eid')
        if df_zscore is not None:
            df = df.merge(df_zscore, on='eid')
    
        df[ndd] = pd.to_datetime(df[ndd])
        df[icd_code] = pd.to_datetime(df[icd_code])
        
        # Drop participants diagnosed (code or disease) before the cut-off, and
        # ignore codes diagnosed on or after the disease date.
        df = df[~(df[icd_code] < cutoff_date)]
        df = df[~(df[ndd] < cutoff_date)]
        df.loc[df[icd_code] >= df[ndd], icd_code] = pd.NaT
        
        if lag_min is not None:
            # Keep the code as an exposure only inside the requested lag window.
            df['years_from_stop'] = (df['complete_stop_date'] - df[icd_code]).dt.days / 365.25
            df.loc[(df['years_from_stop'] < lag_min) | (df['years_from_stop'] >= lag_max), icd_code] = pd.NaT
            df = df[(df[icd_code].isna()) | ((df['years_from_stop'] >= lag_min) & (df['years_from_stop'] < lag_max))].copy()
            df.drop(columns=["years_from_stop"], inplace=True)
        df = df.drop(columns = ["complete_stop_date", ndd, "eid"])
        df[icd_code] = (df[icd_code].notna()).astype(int)
        
        # Too few exposed participants / exposed events to fit a model.
        if sum(df[icd_code] == 1) < 3 or sum(df[df[icd_code] == 1]['event']) <= 1:
            continue

        # The exposure is the first term; the summary rows below rely on that.
        formula = " + ".join([f"{icd_code}"] + covariate_terms)

        df = df.dropna()
        cph = CoxPHFitter()
        model = cph.fit(
            df,
            duration_col = 'duration', 
            event_col = 'event',
            formula = formula,
            fit_options = {"max_steps": 100, 'step_size': 0.1},
        )
        
        cox_summary.append(pd.DataFrame({
            "ICD10" : [icd_code],
            "NDD" : [ndd],
            "Description" : [icd10_description_mapping.loc[0, icd_code]],
            "HR" : [model.summary['exp(coef)'].iloc[0]],
            "ci_min" : [model.summary['exp(coef) lower 95%'].iloc[0]],
            "ci_max" : [model.summary['exp(coef) upper 95%'].iloc[0]],
            "P_VAL" : [model.summary['p'].iloc[0]],
            "N_pairs" : [sum(df[df[icd_code] == 1]['event'])],
            "n" : [sum(df[icd_code] == 1)], 
            "icd10_diagnosis_range" : [f'{icd_code}_{lag_min}-{lag_max}_years' if lag_min is not None else "complete"],
        }))
        
    if cox_summary:
        cox_summary = pd.concat(cox_summary)
        # Correction is applied across all fitted codes, before the N_pairs filter.
        rejected, pvals_corrected = fdrcorrection(cox_summary['P_VAL'].values)
        cox_summary['P_VAL_FDR_CORRECTED'] = pvals_corrected
        cox_summary['rejected'] = rejected
        cox_summary = cox_summary.sort_values(by='HR', ascending=False)
        cox_summary = cox_summary[cox_summary['N_pairs'] > 5]
    else:
        # Every code was skipped: return an empty table instead of letting
        # pd.concat raise on an empty list.
        cox_summary = pd.DataFrame(columns=summary_columns)
    
    cox_summary.to_csv(f"{out_path}.txt", sep="\t", index=False)
    dx_upload(f"dx upload {out_path}.txt --path results/cox/{out_path}.txt")

    return cox_summary


### Time-stratified Cox regression analysis

In [None]:
# AD: exposure restricted to codes diagnosed 1-5, 5-10 and 10-15 years before
# the stop date (lag_min inclusive, lag_max exclusive).
ad_cox_summary_1_5 = run_cox_regression(
    icd_codes, "AD", df_ad, df_icd10, df_covar, icd10_description_mapping, cutoff_date, lag_min=1, lag_max=5,
)

ad_cox_summary_5_10 = run_cox_regression(
    icd_codes, "AD", df_ad, df_icd10, df_covar, icd10_description_mapping, cutoff_date, lag_min=5, lag_max=10,
)

ad_cox_summary_10_15 = run_cox_regression(
    icd_codes, "AD", df_ad, df_icd10, df_covar, icd10_description_mapping, cutoff_date, lag_min=10, lag_max=15,
)


In [None]:
# PD: exposure restricted to codes diagnosed 1-5, 5-10 and 10-15 years before
# the stop date (lag_min inclusive, lag_max exclusive).
pd_cox_summary_1_5 = run_cox_regression(
    icd_codes, "PD", df_pd, df_icd10, df_covar, icd10_description_mapping, cutoff_date, lag_min=1, lag_max=5,
)

pd_cox_summary_5_10 = run_cox_regression(
    icd_codes, "PD", df_pd, df_icd10, df_covar, icd10_description_mapping, cutoff_date, lag_min=5, lag_max=10,
)

pd_cox_summary_10_15 = run_cox_regression(
    icd_codes, "PD", df_pd, df_icd10, df_covar, icd10_description_mapping, cutoff_date, lag_min=10, lag_max=15,
)


### Cox regression analysis without time stratification

In [None]:
# AD models with increasing adjustment. Each call writes and uploads its own
# result file; `ad_cox_summary` is overwritten and ends up holding the last one.

# Base covariates only.
ad_cox_summary = run_cox_regression(
    icd_codes, "AD", df_ad, df_icd10, df_covar, icd10_description_mapping, cutoff_date,
)

# + principal components.
ad_cox_summary = run_cox_regression(
    icd_codes, "AD", df_ad, df_icd10, df_covar, icd10_description_mapping, cutoff_date, df_pcs=df_pcs,
)

# + principal components and the AD risk score.
ad_cox_summary = run_cox_regression(
    icd_codes, "AD", df_ad, df_icd10, df_covar, icd10_description_mapping, cutoff_date, df_pcs=df_pcs, df_zscore=df_ad[["eid", "zscore_ad"]],
)

# + principal components and the AD risk score computed without APOE.
ad_cox_summary = run_cox_regression(
    icd_codes, "AD", df_ad, df_icd10, df_covar, icd10_description_mapping, cutoff_date, df_pcs=df_pcs, df_zscore=df_ad[["eid", "zscore_ad_without_apoe"]],
)

# + principal components and the APOE e4 allele count.
ad_cox_summary = run_cox_regression(
    icd_codes, "AD", df_ad, df_icd10, df_covar, icd10_description_mapping, cutoff_date, df_pcs=df_pcs, df_riskvars=df_apoe, 
)

# + principal components, APOE e4 allele count and the AD risk score.
ad_cox_summary = run_cox_regression(
    icd_codes, "AD", df_ad, df_icd10, df_covar, icd10_description_mapping, cutoff_date, df_pcs=df_pcs, df_riskvars=df_apoe, df_zscore=df_ad[["eid", "zscore_ad"]],
)

# + principal components, APOE e4 allele count and the score without APOE.
ad_cox_summary = run_cox_regression(
    icd_codes, "AD", df_ad, df_icd10, df_covar, icd10_description_mapping, cutoff_date, df_pcs=df_pcs, df_riskvars=df_apoe, df_zscore=df_ad[["eid", "zscore_ad_without_apoe"]],
)


In [None]:
# Baseline PD model: no genetic adjustment.
pd_cox_summary = run_cox_regression(
    icd_codes, "PD", df_pd, df_icd10, df_covar, icd10_description_mapping, cutoff_date,
)

# Risk-variant adjustment groups, each entry being (source frame, variant columns):
#   1. no risk variants, 2. each GBA1 / LRRK2 variant on its own,
#   3. both variants of a gene together.
pd_variant_groups = [
    [
        (None, []),
    ],
    [
        (df_gba1, ["1:155162560:G:A_G"]),
        (df_gba1, ["1:155235843:T:C_T"]),
        (df_lrrk2, ["12:40220632:C:T_C"]),
        (df_lrrk2, ["12:40340400:G:A_G"]),
    ],
    [
        (df_gba1, ["1:155162560:G:A_G", "1:155235843:T:C_T"]),
        (df_lrrk2, ["12:40220632:C:T_C", "12:40340400:G:A_G"]),
    ],
]

# Every model below is PC-adjusted. Within a group, all entries are run first
# without the PRS z-score and then again with it. `pd_cox_summary` is overwritten
# on every run, so it ends up holding the last model.
for pd_variant_group in pd_variant_groups:
    for include_prs in (False, True):
        for variant_frame, variant_columns in pd_variant_group:
            pd_adjustments = {"df_pcs": df_pcs}
            if variant_frame is not None:
                pd_adjustments["df_riskvars"] = variant_frame[["eid"] + variant_columns]
            if include_prs:
                pd_adjustments["df_zscore"] = df_pd[["eid", "zscore_pd"]]
            pd_cox_summary = run_cox_regression(
                icd_codes, "PD", df_pd, df_icd10, df_covar, icd10_description_mapping, cutoff_date,
                **pd_adjustments,
            )


# Kaplan-Meier Survival Analysis

In [None]:
def plot_kaplan_meier(
    icd_codes,
    ndd,
    df_ndd,
    df_icd10,
    cutoff_date,
    icd10_description_mapping,
):
    """Plot Kaplan-Meier curves of NDD-free survival with vs. without each ICD-10 code.

    For every code in ``icd_codes`` the cohort is split into participants with and
    without that code, a Kaplan-Meier curve is fitted to each group, the figure is
    saved as ``Kaplan_Meier_{ndd}_{icd_code}.png`` and a log-rank test compares the
    two groups. One summary row per analysed code is written to
    ``Kaplan_Meier_{ndd}_Summary.txt`` (tab-separated) and uploaded via ``dx_upload``.

    Parameters
    ----------
    icd_codes : iterable of str
        Column names in ``df_icd10`` to analyse.
    ndd : str
        Name of the diagnosis-date column in ``df_ndd``; also used in labels and file names.
    df_ndd : pandas.DataFrame
        Must contain ``eid``, ``complete_stop_date``, ``duration``, ``event`` and ``ndd``.
    df_icd10 : pandas.DataFrame
        Must contain ``eid`` and one date column per ICD-10 code.
    cutoff_date
        Rows whose NDD date or ICD-10 date is earlier than this value are excluded.
    icd10_description_mapping : pandas.DataFrame
        Row ``0`` holds the description of each code, keyed by column name.

    Returns
    -------
    None
    """
    summary_data = []
    for icd_code in tqdm(icd_codes, desc=f"{ndd} Kaplan-Meier Survival Analysis"):
        df = pd.merge(
            df_ndd[["eid", "complete_stop_date", "duration", "event", ndd]],
            df_icd10[["eid", icd_code]],
            on="eid",
        )

        df[ndd] = pd.to_datetime(df[ndd])
        df[icd_code] = pd.to_datetime(df[icd_code])

        # Drop rows dated before the cutoff (missing dates compare False and are kept).
        # .copy() makes the assignments below act on an independent frame, not a slice.
        df = df[~(df[ndd] < cutoff_date) & ~(df[icd_code] < cutoff_date)].copy()
        # A code recorded on or after the NDD date does not count as prior exposure.
        df.loc[df[icd_code] >= df[ndd], icd_code] = pd.NaT
        df[icd_code] = df[icd_code].notna().astype(int)

        # Remove incomplete rows BEFORE checking group sizes, so the thresholds are
        # evaluated on exactly the rows that are fitted. Checking first lets a NaN in
        # `event` turn the sum into NaN, which makes the `<= 1` test silently False.
        df = df.drop(columns=[ndd, "eid"]).dropna()

        icd_mask = df[icd_code] == 1
        if icd_mask.sum() < 3 or df.loc[icd_mask, "event"].sum() <= 1:
            continue

        description = icd10_description_mapping.loc[0, icd_code]
        label_with = f"With {description}"
        label_without = f"Without {description}"

        df_duration = df["duration"]
        df_event = df["event"]

        kmf_without = KaplanMeierFitter()
        kmf_with = KaplanMeierFitter()
        kmf_without.fit(durations=df_duration[~icd_mask], event_observed=df_event[~icd_mask], label=label_without)
        kmf_with.fit(durations=df_duration[icd_mask], event_observed=df_event[icd_mask], label=label_with)

        fig, ax = plt.subplots()
        kmf_without.plot_survival_function(ax=ax, ci_show=True)
        kmf_with.plot_survival_function(ax=ax, ci_show=True)

        # Survival probability at the last observed time point of each curve.
        final_surv_with = round(kmf_with.survival_function_[label_with].iloc[-1], 3)
        final_surv_without = round(kmf_without.survival_function_[label_without].iloc[-1], 3)

        ax.set_title(f"Development of {ndd} in {icd_code} vs Non-{icd_code} Groups")
        ax.set_xlabel("Duration in years")
        ax.set_ylabel(f"Proportion Without {ndd}")
        ax.legend(fontsize=7, loc='lower left')
        fig.savefig(f'Kaplan_Meier_{ndd}_{icd_code}.png')
        plt.close(fig)

        results = logrank_test(
            df_duration[~icd_mask],
            df_duration[icd_mask],
            event_observed_A=df_event[~icd_mask],
            event_observed_B=df_event[icd_mask],
        )

        summary_data.append({
            "ICD-10 Code": icd_code,
            "Condition Description" : description,
            "Number with Condition" : int(icd_mask.sum()),
            "Number without Condition" : int((~icd_mask).sum()),
            f"{ndd} Events in Condition Group" : int(df_event[icd_mask].sum()),
            f"{ndd} Events in No-Condition Group" : int(df_event[~icd_mask].sum()),
            f"Final {ndd}-Free Survival (With Condition)" : final_surv_with,
            f"Final {ndd}-Free Survival (Without Condition)" : final_surv_without,
            # Built-in round() works for both Python and NumPy floats.
            "Absolute difference (%)" : round((final_surv_without - final_surv_with) * 100, 2),
            "Log-Rank p-Value" : f"{results.p_value:.2e}",
            "Statistically Significant (p < 0.05)" : "Yes" if results.p_value < 0.05 else "No"
        })

    summary_df = pd.DataFrame(summary_data)
    summary_df.to_csv(f"Kaplan_Meier_{ndd}_Summary.txt", index=False, sep="\t")
    dx_upload(f"dx upload Kaplan_Meier_{ndd}_Summary.txt --path results/kaplan_meier/Kaplan_Meier_{ndd}_Summary.txt")
  

In [None]:
# Kaplan-Meier analysis for each disease cohort, AD first and then PD.
for km_ndd, km_cohort in (("AD", df_ad), ("PD", df_pd)):
    plot_kaplan_meier(icd_codes, km_ndd, km_cohort, df_icd10, cutoff_date, icd10_description_mapping)


# PRS-ICD10 Interaction GLM

In [None]:
def glm_interaction(icd_codes, ndd, df_ndd, df_covar, icd10_description_mapping, zscore):
    """Fit a binomial GLM per ICD-10 code and report the PRS x ICD-10 interaction term.

    Model: ``ndd ~ (zscore * icd_code) + Year_of_birth + townsend + sex``.
    The results are written to ``{ndd}_{zscore}_glm_interaction_summary.txt``
    (tab-separated) and uploaded via ``dx_upload``.

    NOTE: ``df_icd10`` and ``cutoff_date`` are not parameters; they are read from
    the notebook's global namespace and must be defined before this is called.

    Parameters
    ----------
    icd_codes : iterable of str
        Column names in the global ``df_icd10`` to analyse.
    ndd : str
        Name of the diagnosis-date column in ``df_ndd``.
    df_ndd : pandas.DataFrame
        Must contain ``eid``, ``ndd`` and ``zscore``.
    df_covar : pandas.DataFrame
        Must contain ``eid``, ``Year_of_birth``, ``townsend`` and ``sex``.
    icd10_description_mapping : pandas.DataFrame
        Row ``0`` holds the description of each code, keyed by column name.
    zscore : str
        Name of the PRS z-score column in ``df_ndd``.

    Returns
    -------
    None
    """
    records = []
    for icd_code in tqdm(icd_codes, desc=f"{ndd} PRS-ICD10 GLM"):
        df = pd.merge(df_ndd[["eid", ndd, zscore]], df_icd10[["eid", icd_code]], on="eid")
        df = df.merge(df_covar[["eid", "Year_of_birth", "townsend", "sex"]], on='eid')

        df[ndd] = pd.to_datetime(df[ndd])
        df[icd_code] = pd.to_datetime(df[icd_code])

        # Drop rows dated before the cutoff (missing dates compare False and are kept).
        # .copy() makes the assignments below act on an independent frame, not a slice.
        df = df[~(df[icd_code] < cutoff_date) & ~(df[ndd] < cutoff_date)].copy()
        # A code recorded on or after the NDD date does not count as prior exposure.
        df.loc[df[icd_code] >= df[ndd], icd_code] = pd.NaT

        # From here on both columns are 0/1 indicators rather than dates.
        df[icd_code] = df[icd_code].notna().astype(int)
        df[ndd] = df[ndd].notna().astype(int)

        exposed = df[icd_code] == 1
        n_exposed = int(exposed.sum())
        n_exposed_cases = int(df.loc[exposed, ndd].sum())
        if n_exposed < 3 or n_exposed_cases <= 1:
            continue

        formula = f"{ndd} ~ ({zscore} * {icd_code}) + Year_of_birth + townsend + sex"
        model = smf.glm(formula=formula, data=df, family=sm.families.Binomial()).fit()

        interaction = f"{zscore}:{icd_code}"
        beta = model.params[interaction]
        se = model.bse[interaction]
        ci_low, ci_high = np.exp(model.conf_int().loc[interaction])

        records.append({
            "Interaction term" : interaction,
            "NDD" : ndd,
            "ICD10-Description" : icd10_description_mapping.loc[0, icd_code],
            "ICD10-CODE" : icd_code,
            "OR" : np.exp(beta),
            "Beta" : beta,
            "SE" : se,
            "95% CI low" : ci_low,
            "95% CI high" : ci_high,
            # The Wald z statistic is the coefficient over its standard error;
            # dividing the odds ratio (exp(beta)) by the SE is not a z value.
            "z" : beta / se,
            "P-value" : model.pvalues[interaction],
            "N_pairs" : n_exposed_cases,
            "n" : n_exposed,
        })

    if not records:
        # Nothing passed the inclusion thresholds: there is nothing to correct or write.
        warnings.warn(f"glm_interaction: no ICD-10 code met the inclusion thresholds for {ndd} / {zscore}")
        return

    interaction_summary = pd.DataFrame(records)
    rejected, pvals_corrected = fdrcorrection(interaction_summary['P-value'].values)
    interaction_summary['P_VAL_FDR_CORRECTED'] = pvals_corrected
    interaction_summary['rejected'] = rejected

    interaction_summary.to_csv(f'{ndd}_{zscore}_glm_interaction_summary.txt', index=False, sep="\t")
    dx_upload(f"dx upload {ndd}_{zscore}_glm_interaction_summary.txt --path results/ndd_icd10_interaction/{ndd}_{zscore}_glm_interaction_summary.txt")


In [None]:
# PRS x ICD-10 interaction models: PD, AD, and AD with the APOE-free PRS.
for glm_ndd, glm_cohort, glm_prs in (
    ("PD", df_pd, "zscore_pd"),
    ("AD", df_ad, "zscore_ad"),
    ("AD", df_ad, "zscore_ad_without_apoe"),
):
    glm_interaction(icd_codes, glm_ndd, glm_cohort, df_covar, icd10_description_mapping, glm_prs)


# Density Plots of PRS z-Scores

In [None]:
def density_plot(icd_codes, ndd, df_ndd, icd10_description_mapping, zscore):
    """Draw PRS z-score densities for NDD cases with vs. without each ICD-10 code.

    For every code, rows with a non-missing ``ndd`` value are split by whether the
    ICD-10 column is missing. Codes where either group has fewer than 5 rows are
    skipped. Each figure is annotated with a Welch t-test, saved as
    ``{zsc_name}_{icd_code}_densityplot.png`` and uploaded via ``dx_upload``.

    NOTE: ``df_icd10`` is not a parameter; it is read from the notebook's global
    namespace.

    Parameters
    ----------
    icd_codes : iterable of str
        Column names in the global ``df_icd10``.
    ndd : str
        Name of the diagnosis column in ``df_ndd``; also used in labels.
    df_ndd : pandas.DataFrame
        Must contain ``eid``, ``ndd`` and ``zscore``.
    icd10_description_mapping : pandas.DataFrame
        Row ``0`` holds the description of each code, keyed by column name.
    zscore : str
        Name of the PRS z-score column; everything after its first ``_`` becomes
        the output file prefix.

    Returns
    -------
    None
    """
    # Both values depend only on the arguments, so they are built once up front.
    zsc_name = "_".join(zscore.split("_")[1:])
    x_label = "PRS " + zscore.replace(f"_{ndd.lower()}","").replace("no apoe","without APOE")

    for icd_code in tqdm(icd_codes, desc=f"{ndd} PRS Density Plots"):
        merged = pd.merge(df_ndd[["eid", ndd, zscore]], df_icd10[["eid", icd_code]], on="eid")

        is_case = merged[ndd].notna()
        has_code = merged[icd_code].notna()
        cases_without_code = merged.loc[is_case & ~has_code].copy()
        cases_with_code = merged.loc[is_case & has_code].copy()

        if min(len(cases_with_code), len(cases_without_code)) < 5:
            continue

        fig, ax = plt.subplots(figsize=(12, 6))

        sns.kdeplot(cases_without_code[zscore], label=ndd, fill=True, ax=ax)
        sns.kdeplot(
            cases_with_code[zscore],
            label=f'{ndd} and {icd10_description_mapping.loc[0, icd_code]}',
            fill=True,
            ax=ax,
        )

        # Welch t-test (unequal variances) of the PRS between the two case groups.
        t_stat, p_val = ttest_ind(
            cases_with_code[zscore].dropna(),
            cases_without_code[zscore].dropna(),
            equal_var=False,
            nan_policy='omit',
        )
        ax.text(0.05, 0.95, f'T-stat: {t_stat:.2f}\nP-val: {p_val:.3e}', transform=ax.transAxes, fontsize=10, verticalalignment='top')
        ax.legend()
        ax.set_title(f'PRS Density Plot for individuals with {icd_code} and {ndd}')
        ax.set_xlabel(x_label)
        ax.set_ylabel('Density')

        fig.tight_layout()
        fig.savefig(f'{zsc_name}_{icd_code}_densityplot.png')
        plt.close(fig)

        dx_upload(f"dx upload {zsc_name}_{icd_code}_densityplot.png --path results/density_plots/{zsc_name}/{zsc_name}_{icd_code}_densityplot.png")


In [None]:
# FutureWarnings are silenced only for the duration of the plotting calls.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=FutureWarning)
    for density_ndd, density_cohort, density_prs in (
        ("PD", df_pd, "zscore_pd"),
        ("AD", df_ad, "zscore_ad"),
        ("AD", df_ad, "zscore_ad_without_apoe"),
    ):
        density_plot(icd_codes, density_ndd, density_cohort, icd10_description_mapping, density_prs)


# Association Between Olink Biomarker Levels and Neurodegenerative Disease (AD and PD)

In [None]:
def olink_ndd_association(ndd, df_ndd, df_olink, df_covar, df_pcs, zscore):
    """Test each Olink marker for association with NDD status using a binomial GLM.

    Model per marker:
    ``ndd ~ marker + zscore + Age_at_recruitment + townsend + sex + PC1..PC5``.
    Results, with Bonferroni and FDR flags and the maximum Cook's distance, are
    sorted by odds ratio, written to ``{zsc_name}_olink_association.txt``
    (tab-separated) and uploaded via ``dx_upload``.

    Parameters
    ----------
    ndd : str
        Name of the diagnosis column in ``df_ndd``; non-missing values count as cases.
    df_ndd : pandas.DataFrame
        Must contain ``eid``, ``ndd`` and ``zscore``.
    df_olink : pandas.DataFrame
        First column is skipped (the ``eid`` merge key); every other column is
        treated as a marker.
    df_covar : pandas.DataFrame
        Must contain ``eid`` and the covariates named in the formula.
    df_pcs : pandas.DataFrame
        Must contain ``eid`` and ``PC1`` .. ``PC5``.
    zscore : str
        Name of the PRS z-score column; everything after its first ``_`` becomes
        the output file prefix.

    Returns
    -------
    None
    """
    df = pd.merge(df_ndd[["eid", ndd, zscore]], df_covar, on="eid")
    df = df.merge(df_olink, on="eid")
    df = df.merge(df_pcs, on="eid")
    df[ndd] = df[ndd].notna().astype(int)

    records = []
    for olink_marker in tqdm(df_olink.columns.tolist()[1:], desc=f"{ndd} OLINK GLMs"):
        measured = df[olink_marker].notna()
        if measured.sum() < 5:
            continue

        formula = f"{ndd} ~ {olink_marker} + {zscore} + Age_at_recruitment + townsend + sex + PC1 + PC2 + PC3 + PC4 + PC5"
        model = smf.glm(formula=formula, data=df, family=sm.families.Binomial()).fit()

        n_cases = int((measured & (df[ndd] == 1)).sum())
        n_controls = int((measured & (df[ndd] == 0)).sum())

        try:
            cooks_d = model.get_influence().cooks_distance[0]
            max_cooks_d = cooks_d.max()
            potential_outlier = max_cooks_d > 0.5
        except np.linalg.LinAlgError:
            # Influence diagnostics could not be computed for this fit.
            max_cooks_d = np.nan
            potential_outlier = False

        # Position 1 is the marker term: position 0 is the intercept and the
        # marker is the first predictor in the formula.
        conf_int = model.conf_int(alpha=0.05)
        records.append({
            "NDD" : ndd,
            "olink_marker" : olink_marker,
            "odds_ratio" : np.exp(model.params.iloc[1]),
            "ci_min" : np.exp(conf_int.iloc[1, 0]),
            "ci_max" : np.exp(conf_int.iloc[1, 1]),
            "P_VAL" : model.pvalues.iloc[1],
            "n_cases" : n_cases,
            "n_controls" : n_controls,
            "nobs" : model.nobs,
            "max_cooks_d" : max_cooks_d,
            "fixed_threshold" : 0.5,
            "potential_outlier" : potential_outlier,
        })

    if not records:
        # No marker had enough measurements: there is nothing to correct or write.
        warnings.warn(f"olink_ndd_association: no Olink marker had at least 5 measurements for {ndd} / {zscore}")
        return

    df_association_summary = pd.DataFrame(records)
    df_association_summary['Bonferroni_Significant'] = df_association_summary['P_VAL'] < 0.05 / len(df_association_summary)

    rejected, pvals_corrected = fdrcorrection(df_association_summary['P_VAL'].values)
    df_association_summary['rejected'] = rejected
    df_association_summary['P_VAL_FDR_CORRECTED'] = pvals_corrected

    # The original code called DataFrame.merge(on='olink_marker', how='left') here
    # without a right-hand frame, which raises TypeError every time. No annotation
    # table is available in this function, so the call has been removed.
    df_association_summary = df_association_summary.sort_values(by='odds_ratio')

    zsc_name = "_".join(zscore.split("_")[1:])
    df_association_summary.to_csv(f"{zsc_name}_olink_association.txt", index=False, sep="\t")
    dx_upload(f"dx upload {zsc_name}_olink_association.txt --path results/olink_association/{zsc_name}_olink_association.txt")
    

In [None]:
# Olink-marker association models: PD, AD, and AD with the APOE-free PRS.
for assoc_ndd, assoc_cohort, assoc_prs in (
    ("PD", df_pd, "zscore_pd"),
    ("AD", df_ad, "zscore_ad"),
    ("AD", df_ad, "zscore_ad_without_apoe"),
):
    olink_ndd_association(assoc_ndd, assoc_cohort, df_olink, df_covar, df_pcs, assoc_prs)


# t-test for Olink biomarker levels in NDD vs. NDD + ICD10

In [None]:
def olink_ndd_ttest(icd_codes, ndd, df_ndd, df_olink, df_covar, zscore):
    """Compare Olink marker levels between NDD cases with and without each ICD-10 code.

    Markers are taken from ``{zsc_name}_olink_association.txt`` (the file written
    by ``olink_ndd_association``), keeping those with ``odds_ratio >= 1`` and
    ``P_VAL_FDR_CORRECTED <= 0.05``. For each (ICD-10 code, marker) pair a Welch
    t-test compares NDD cases with the code against NDD cases without it. The
    results are written to ``{zsc_name}_ttest_summary.txt`` (tab-separated) and
    uploaded via ``dx_upload``.

    NOTE: ``df_icd10`` and ``icd10_description_mapping`` are not parameters; they
    are read from the notebook's global namespace.

    Parameters
    ----------
    icd_codes : iterable of str
        Column names in the global ``df_icd10``.
    ndd : str
        Name of the diagnosis column in ``df_ndd``; non-missing values count as cases.
    df_ndd : pandas.DataFrame
        Must contain ``eid`` and ``ndd``.
    df_olink : pandas.DataFrame
        Must contain ``eid`` and the selected marker columns.
    df_covar : pandas.DataFrame
        Only its ``eid`` column is used, to restrict to participants with covariates.
    zscore : str
        Name of the PRS z-score column; everything after its first ``_`` is the
        prefix of the input and output files.

    Returns
    -------
    None
    """
    zsc_name = "_".join(zscore.split("_")[1:])
    olink_df = pd.read_csv(f"{zsc_name}_olink_association.txt", sep="\t")
    olink_df = olink_df[(olink_df["odds_ratio"] >= 1) & (olink_df["P_VAL_FDR_CORRECTED"] <= 0.05)]
    olink_markers = olink_df["olink_marker"].tolist()

    # Loop-invariant part of the join, built once: NDD cases that also have
    # covariates and Olink data. The original converted `ndd` to 0/1 and then
    # filtered on `.notna()`, which is always True, so controls ended up in both
    # comparison groups. Selecting cases here restores the intended comparison.
    df_cases = (
        df_ndd.loc[df_ndd[ndd].notna(), ["eid"]]
        .merge(df_covar[["eid"]], on="eid")
        .merge(df_olink[["eid"] + olink_markers], on="eid")
    )

    t_stat_column = f'T-Statistic ({ndd}+ICD10 vs. {ndd})'
    records = []

    # RuntimeWarnings are silenced only inside this block instead of installing a
    # process-wide filter that would also hide warnings in later cells.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)

        for icd_code in tqdm(icd_codes, desc=f"{ndd} Olink-ICD-NDD t-test"):
            # The ICD column is renamed so it cannot collide with an Olink marker
            # that happens to share its name.
            df = df_cases.merge(
                df_icd10[["eid", icd_code]].rename(columns={icd_code: "_icd_date"}),
                on="eid",
            )
            has_icd = df["_icd_date"].notna()

            for olink in olink_markers:
                ndd_icd_values = df.loc[has_icd, olink].dropna()
                ndd_only_values = df.loc[~has_icd, olink].dropna()

                # A variance needs at least two observations per group. Skipping
                # here also keeps NaN p-values out of the multiple-testing step.
                if len(ndd_icd_values) < 2 or len(ndd_only_values) < 2:
                    continue

                t_stat, p_val = ttest_ind(ndd_icd_values, ndd_only_values, equal_var=False)
                records.append({
                    'olink_marker' : olink,
                    'ICD10_Code' : icd_code,
                    'ICD10_Code_definition' : icd10_description_mapping.loc[0, icd_code],
                    f'{ndd}_mean' : ndd_only_values.mean(),
                    f'{ndd}_ICD10_mean' : ndd_icd_values.mean(),
                    t_stat_column : t_stat,
                    # NOTE(review): this is the pooled-variance n1 + n2 - 2, kept for
                    # output compatibility; the Welch test above uses a different
                    # (Welch-Satterthwaite) degrees of freedom.
                    'Degrees of Freedom' : len(ndd_icd_values) + len(ndd_only_values) - 2,
                    f'{ndd}_size' : len(ndd_only_values),
                    f'{ndd}_ICD10_size' : len(ndd_icd_values),
                    'P-Value' : p_val,
                })

    if not records:
        # No marker/code pair could be tested: there is nothing to correct or write.
        warnings.warn(f"olink_ndd_ttest: no testable Olink marker / ICD-10 pair for {ndd} / {zscore}")
        return

    df_ttest_summary = pd.DataFrame(records)
    df_ttest_summary['Bonferroni_Significant'] = df_ttest_summary['P-Value'] < 0.05 / len(df_ttest_summary)

    # Correct only the finite p-values: one NaN (e.g. two constant groups) would
    # otherwise turn every corrected p-value into NaN.
    pvals = df_ttest_summary['P-Value'].values.astype(float)
    finite = np.isfinite(pvals)
    pvals_corrected = np.full(len(pvals), np.nan)
    rejected = np.zeros(len(pvals), dtype=bool)
    if finite.any():
        rejected[finite], pvals_corrected[finite] = fdrcorrection(pvals[finite])
    df_ttest_summary['P_VAL_FDR_CORRECTED'] = pvals_corrected
    df_ttest_summary['FDR_CORRECTED_rejected'] = rejected

    df_ttest_summary = df_ttest_summary.sort_values(by=t_stat_column)
    df_ttest_summary.to_csv(f'{zsc_name}_ttest_summary.txt', index=False, sep="\t")
    dx_upload(f"dx upload {zsc_name}_ttest_summary.txt --path results/olink_ttest/{zsc_name}_ttest_summary.txt")


In [None]:
# Olink t-tests within cases: PD, AD, and AD with the APOE-free PRS.
for ttest_ndd, ttest_cohort, ttest_prs in (
    ("PD", df_pd, "zscore_pd"),
    ("AD", df_ad, "zscore_ad"),
    ("AD", df_ad, "zscore_ad_without_apoe"),
):
    olink_ndd_ttest(icd_codes, ttest_ndd, ttest_cohort, df_olink, df_covar, ttest_prs)


# Multi-modal Regression

In [None]:
# === Statistical helper functions ===
def calculate_95_ci(data):
    """Format the mean of ``data`` with the half-width of its 95% t-based CI.

    Parameters
    ----------
    data : sequence of float
        Sample values, e.g. one score per cross-validation fold.

    Returns
    -------
    str
        ``"<mean> ± <half-width>"``, both rounded to two decimals. The half-width
        is ``t(0.975, n - 1) * SEM``; it is ``nan`` when fewer than two values are
        given, because the standard error is undefined.
    """
    values = np.asarray(data, dtype=float)
    mean = np.mean(values)
    # The half-width is computed directly rather than via stats.t.interval, which
    # returns NaN when the scale (SEM) is exactly 0. Identical scores in every
    # fold used to be reported as "x ± nan" instead of "x ± 0.00".
    half_width = stats.t.ppf(0.975, len(values) - 1) * stats.sem(values)
    return f"{mean:.2f} ± {half_width:.2f}"


In [None]:
def fit_classifier(ndd, y, feature_sets, param_grid):
    """Evaluate an XGBoost pipeline on each feature set with nested 5-fold CV.

    For every feature set, an outer stratified 5-fold split holds out a test fold
    and an inner stratified 5-fold grid search (scored by ROC AUC) tunes a
    scaler -> L1-LinearSVC feature selection -> XGBoost pipeline on the rest.
    Per-fold metrics are averaged per feature set and written to
    ``{ndd}_metrics.txt`` (tab-separated).

    Parameters
    ----------
    ndd : str
        Used only as the prefix of the metrics file name.
    y : numpy.ndarray
        Binary labels, indexed positionally with the fold indices.
    feature_sets : dict[str, pandas.DataFrame]
        Feature matrices keyed by modality name, row-aligned with ``y``.
    param_grid : dict
        Grid passed unchanged to ``GridSearchCV``.

    Returns
    -------
    df_metrics : pandas.DataFrame
        One row per feature set with mean metrics and their ``*_CI`` strings.
    results_per_modality : dict[str, list]
        Per feature set, one ``[test_probabilities, y_test]`` pair per outer fold.
    cv_models : list
        Best estimator of each outer fold, collected for the 'Combined' set only.
    """
    metric_names = (
        'Train AUC',
        'Test AUC',
        'Train Balanced Accuracy',
        'Test Balanced Accuracy',
        'Sensitivity',
        'Specificity',
    )

    results_per_modality = {}
    cv_models = []
    summary_rows = []

    outer_cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

    for label, X in feature_sets.items():
        fold_scores = {name: [] for name in metric_names}
        fold_predictions = []
        results_per_modality[label] = fold_predictions

        for train_index, test_index in outer_cv.split(X, y):
            X_fit, X_held_out = X.iloc[train_index], X.iloc[test_index]
            y_fit, y_held_out = y[train_index], y[test_index]

            # Hyperparameters are tuned on the outer-training portion only.
            search = GridSearchCV(
                Pipeline([
                    ('scaler', StandardScaler()),
                    ('feature_selection', SelectFromModel(LinearSVC(penalty="l1", dual=False))),
                    ('xgb', XGBClassifier())
                ]),
                param_grid,
                cv=StratifiedKFold(n_splits=5, shuffle=True, random_state=42),
                scoring='roc_auc',
                verbose=1,
                n_jobs=-1,
            )
            search.fit(X_fit, y_fit)

            best_model = search.best_estimator_
            if label == 'Combined':
                cv_models.append(best_model)

            # Scores and hard predictions on both partitions of the outer fold.
            proba_fit = best_model.predict_proba(X_fit)[:, 1]
            proba_held_out = best_model.predict_proba(X_held_out)[:, 1]
            pred_fit = best_model.predict(X_fit)
            pred_held_out = best_model.predict(X_held_out)
            fold_predictions.append([proba_held_out, y_held_out])

            fold_scores['Train AUC'].append(roc_auc_score(y_fit, proba_fit))
            fold_scores['Test AUC'].append(roc_auc_score(y_held_out, proba_held_out))
            fold_scores['Train Balanced Accuracy'].append(balanced_accuracy_score(y_fit, pred_fit))
            fold_scores['Test Balanced Accuracy'].append(balanced_accuracy_score(y_held_out, pred_held_out))

            # Sensitivity and specificity from the held-out confusion matrix.
            tn, fp, fn, tp = confusion_matrix(y_held_out, pred_held_out).ravel()
            fold_scores['Sensitivity'].append(tp / (tp + fn))
            fold_scores['Specificity'].append(tn / (tn + fp))

        # Column order: feature set, the six means, then the six CI strings.
        row = {'Feature Set': label}
        for name in metric_names:
            row[name] = np.mean(fold_scores[name])
        for name in metric_names:
            row[f'{name}_CI'] = calculate_95_ci(fold_scores[name])
        summary_rows.append(row)

    df_metrics = pd.DataFrame(summary_rows)
    df_metrics.to_csv(f'{ndd}_metrics.txt', index=False, sep="\t")

    return df_metrics, results_per_modality, cv_models


In [None]:
# === AUC-ROC plotting function ===
def plot_auc_roc(ndd, results_per_modality, df_metrics=None):
    """Plot the mean cross-validated ROC curve of every modality on one figure.

    Each fold's ROC curve is interpolated onto a common 100-point FPR grid. The
    per-modality mean curve is drawn with a +/- 1 SD band, and the AUC of each
    mean curve is written to ``{ndd}_aucs.tsv``.

    Parameters
    ----------
    ndd : str
        Used in the figure title and as the prefix of the output file name.
    results_per_modality : dict[str, list]
        Per modality, a list of ``(predicted_probabilities, true_labels)`` pairs,
        one per fold.
    df_metrics : pandas.DataFrame, optional
        If given, the ``'Test Balanced Accuracy_CI'`` value of the row whose
        ``'Feature Set'`` equals the modality name is appended to the legend
        entry. When omitted, or when no row matches, the legend shows the AUC only.

    Returns
    -------
    None
    """
    colorlist = ['blue', 'green', 'red', 'purple', 'orange', 'cyan']
    mean_fpr = np.linspace(0, 1, 100)

    # === Create plot ===
    fig, ax = plt.subplots(1, 1, figsize=(14.4, 10.8))
    sns.despine(offset=5, trim=True)

    all_means = defaultdict(list)
    for position, (modality, fold_results) in enumerate(results_per_modality.items()):
        if not fold_results:
            # There are no folds to average for this modality.
            continue

        class_name = modality + ':'
        # Colors wrap around so more than six modalities do not raise IndexError.
        color = colorlist[position % len(colorlist)]

        tprs = []
        aucs = []
        for preds, label in fold_results:
            fpr, tpr, _ = roc_curve(label, preds)
            fold_tpr = np.interp(mean_fpr, fpr, tpr)
            fold_tpr[0] = 0.0  # every curve is anchored at the origin
            tprs.append(fold_tpr)
            aucs.append(auc(fpr, tpr))

        mean_tpr = np.mean(tprs, axis=0)
        mean_tpr[-1] = 1.0  # and ends at (1, 1)
        mean_auc = auc(mean_fpr, mean_tpr)
        std_auc = np.std(aucs)
        all_means[class_name].append(mean_auc)

        # The balanced accuracy is optional: the original indexed df_metrics
        # unconditionally and crashed with the documented default of None.
        acc_text = ''
        if df_metrics is not None:
            matching = df_metrics.loc[df_metrics['Feature Set'] == modality, 'Test Balanced Accuracy_CI']
            if not matching.empty:
                acc_text = ', Bal Acc = {0}'.format(matching.iloc[0])

        if std_auc > 0:
            auc_text = r'AUC = {0:0.2f} $\pm$ {1:0.2f}'.format(mean_auc, std_auc)
        else:
            auc_text = 'AUC = {0:0.2f}'.format(mean_auc)

        ax.plot(mean_fpr, mean_tpr,
                label='{0} mean ROC ({1}{2})'.format(class_name, auc_text, acc_text),
                lw=2, alpha=.9, color=color)

        std_tpr = np.std(tprs, axis=0)
        tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
        tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
        ax.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.1)

    df_results = pd.DataFrame(all_means)
    df_results.to_csv(f"{ndd}_aucs.tsv", sep="\t")

    ax.plot([0, 1], [0, 1], 'k--')  # chance diagonal
    ax.set_xlim([-0.025, 1.025])
    ax.set_ylim([-0.025, 1.025])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title(f"{ndd} ROC Curve")
    ax.legend()

    plt.show()


In [None]:
def shap_analysis(X, cv_models):
    """Explain the cross-validated ensemble with SHAP values from a surrogate model.

    The positive-class probabilities of all models in ``cv_models`` are averaged.
    A ``LGBMRegressor`` is fitted on ``X`` to reproduce that average, and SHAP
    values are computed for the surrogate on ``X``.

    Parameters
    ----------
    X : pandas.DataFrame
        Feature matrix with the columns the models in ``cv_models`` were fitted on.
    cv_models : sequence
        Fitted estimators exposing ``predict_proba``. Any number of models is
        accepted; the original code required exactly five.

    Returns
    -------
    shap.Explanation
        SHAP values of the surrogate model for every row of ``X``.

    Raises
    ------
    ValueError
        If ``cv_models`` is empty.
    """
    if len(cv_models) == 0:
        raise ValueError("shap_analysis requires at least one fitted model in cv_models")

    # Ensemble target: mean positive-class probability across all fitted models.
    y_regr = np.mean([fitted.predict_proba(X)[:, 1] for fitted in cv_models], axis=0)

    model = LGBMRegressor()
    model.fit(X, y_regr)

    explainer = shap.Explainer(model, X)
    shap_values = explainer(X, check_additivity=False)

    return shap_values
    

In [None]:
def plot_shap(shap_values, ndd):
    """Render SHAP bar and beeswarm plots side by side and save them as one image.

    Both plots (top 20 features each) are drawn into in-memory PNG buffers,
    re-opened with PIL, brought to a common height and pasted onto a white canvas
    with the bar plot on the left and the beeswarm plot on the right. The combined
    image is saved as ``{ndd}_shap.png`` at 300 dpi and shown.

    Side effect: ``plt.rcParams['font.size']`` is set to 12 globally.

    Parameters
    ----------
    shap_values
        Object accepted by ``shap.plots.beeswarm`` and ``shap.plots.bar``.
    ndd : str
        Prefix of the output file name; ``"AD"`` adds panel label 'A' and
        ``"PD"`` adds panel label 'B' (any other value adds no label).

    Returns
    -------
    None
    """
    # Font setup
    plt.rcParams.update({
        'font.size': 12,
    })

    figsize = (6, 5)

    # --- Step 1: Beeswarm plot ---
    buf1 = io.BytesIO()
    plt.figure(figsize=figsize)
    shap.plots.beeswarm(shap_values, max_display=20, show=False)

    # shap draws onto the current pyplot figure, so it is fetched via plt.gcf().
    fig = plt.gcf()

    # Remove only the "Feature value" label, keep colorbar
    for ax in fig.axes:
        if ax != plt.gca():
            for text in ax.get_yticklabels() + ax.get_xticklabels() + ax.texts:
                if text.get_text() == 'Feature value':
                    text.set_visible(False)

    ax1 = plt.gca()
    for tick in ax1.get_xticklabels() + ax1.get_yticklabels():
        tick.set_fontsize(12)

    ax1.set_title('Beeswarm Plot', fontsize=12)
    ax1.set_xlabel(ax1.get_xlabel(), fontsize=12)
    ax1.set_ylabel(ax1.get_ylabel(), fontsize=12)
    # Feature names are hidden here; the bar plot pasted to the left of this one
    # keeps its own y tick labels.
    ax1.set_yticklabels([])

    ax1.spines['bottom'].set_position(('outward', 10))  # Move x-axis outward

    plt.tight_layout()
    plt.savefig(buf1, format='png', bbox_inches='tight', dpi=300)
    plt.close()

    # --- Step 2: Bar plot ---
    buf2 = io.BytesIO()
    plt.figure(figsize=figsize)
    shap.plots.bar(shap_values, max_display=20, show=False)

    ax2 = plt.gca()
    ax2.tick_params(axis='y', pad=4)

    for tick in ax2.get_xticklabels() + ax2.get_yticklabels():
        tick.set_fontsize(12)

    ax2.set_title('Bar Plot', fontsize=12)
    ax2.set_xlabel(ax2.get_xlabel(), fontsize=12)
    ax2.set_ylabel(ax2.get_ylabel(), fontsize=12)
    ax2.title.set_position([0, -5])

    plt.tight_layout()
    plt.savefig(buf2, format='png', bbox_inches='tight', dpi=300)
    plt.close()

    # --- Step 3: Load images ---
    buf1.seek(0)
    buf2.seek(0)
    img1 = Image.open(buf1)
    img2 = Image.open(buf2)

    # --- Step 4: Resize to same height ---
    # NOTE(review): only the height changes (the width is kept), so the shorter
    # image is stretched vertically rather than scaled proportionally.
    if img1.height != img2.height:
        target_height = max(img1.height, img2.height)
        img1 = img1.resize((img1.width, target_height))
        img2 = img2.resize((img2.width, target_height))

    # --- Step 5: Adjust vertical alignment if needed ---
    # NOTE(review): after step 4 both heights are equal, so height_diff is 0 and
    # neither crop branch below is ever taken.
    height_diff = abs(img1.height - img2.height)
    if img1.height > img2.height:
        img2 = img2.crop((0, height_diff, img2.width, img2.height))
    elif img1.height < img2.height:
        img1 = img1.crop((0, height_diff, img1.width, img1.height))

    # --- Step 6: Lower the beeswarm plot slightly ---
    vertical_offset = 30
    img1_lowered = Image.new('RGB', (img1.width, img1.height + vertical_offset), (255, 255, 255))
    img1_lowered.paste(img1, (0, vertical_offset))

    # --- Step 7: Add left margin to bar plot to shift it right ---
    left_margin = 50  # pixels to shift right
    bar_with_margin = Image.new('RGB', (img2.width + left_margin, img2.height), (255, 255, 255))
    bar_with_margin.paste(img2, (left_margin, 0))

    # --- Step 8: Combine images ---
    gap = 10
    total_width = img1_lowered.width + bar_with_margin.width + gap
    combined_height = max(img1_lowered.height, bar_with_margin.height)
    combined = Image.new('RGB', (total_width, combined_height), (255, 255, 255))

    # Bar plot on the left, beeswarm plot to its right after the gap.
    combined.paste(bar_with_margin, (0, 0))
    combined.paste(img1_lowered, (bar_with_margin.width + gap, 0))

    # --- Step 9: Final combined figure ---
    fig, ax = plt.subplots(figsize=(12, 5))
    ax.imshow(combined)
    ax.axis('off')

    # Panel label: 'A' for AD, 'B' for PD
    if ndd == "AD":
        ax.text(30, 30, 'A', fontsize=20, fontweight='bold', color='black')
    if ndd == "PD":
        ax.text(30, 30, 'B', fontsize=20, fontweight='bold', color='black')

    plt.tight_layout()
    fig.savefig(f'{ndd}_shap.png', dpi=300, bbox_inches='tight')
    plt.show()


In [None]:
# === Plotting style ===
# Order matters: each call below overrides global rc state set by the previous ones.
# Start from seaborn's and matplotlib's defaults so the theme is applied from a clean state.
sns.reset_defaults()
mpl.rcParams.update(mpl.rcParamsDefault)
# "notebook" context with enlarged fonts and slightly thicker lines.
sns.set_theme("notebook", font_scale=1.5, rc={"lines.linewidth": 1.5})
# White background first, then the "ticks" style with longer major tick marks.
sns.set_style("white")
sns.set_style("ticks", {"xtick.major.size": 8, "ytick.major.size": 8})
# mpl.rc('font', family='serif', serif='Times New Roman')
# Figures saved without an explicit file extension are written as PNG.
plt.rcParams.update({"savefig.format": 'png'})


In [None]:
def multimodal_regression(ndd, df_ndd, df_covar, df_icd10, df_olink, df_pcs, df_riskvars, olink_features, icd_codes, zscore, n_est, lr, depth):
    """Fit and evaluate per-modality and combined classifiers for one disease label.

    Despite the name, this drives a binary classification workflow: it merges
    all input tables on ``eid``, builds a case/control label, balances the two
    classes by down-sampling controls, assembles one feature matrix per
    modality, and hands them to ``fit_classifier``. ROC curves and SHAP plots
    are produced as side effects; nothing is returned.

    Args:
        ndd: Name of the disease column in ``df_ndd`` (e.g. "PD", "AD"). A
            non-null value in that column marks a case.
        df_ndd: Frame with columns ``eid``, ``ndd`` and ``zscore``.
        df_covar: Frame with ``eid`` plus 'townsend', 'sex', 'Age_at_recruitment'.
        df_icd10: Frame with ``eid`` plus the columns listed in ``icd_codes``.
        df_olink: Frame with ``eid`` plus the columns listed in ``olink_features``.
        df_pcs: Frame with ``eid`` and at least 'PC1'..'PC5'; all of its
            columns are merged in.
        df_riskvars: Frame with ``eid`` plus one column per risk variant; every
            non-``eid`` column is used as a genetics feature.
        olink_features: List of column names to take from ``df_olink``.
        icd_codes: List of column names to take from ``df_icd10``.
        zscore: Name of the score column in ``df_ndd`` to use as a genetics
            feature (presumably a polygenic score; verify against the caller).
        n_est: Candidate values for the ``xgb__n_estimators`` grid entry.
        lr: Candidate values for the ``xgb__learning_rate`` grid entry.
        depth: Candidate values for the ``xgb__max_depth`` grid entry.

    Raises:
        ValueError: If, after merging, there are no cases or no controls.
    """
    demographics_features = ['townsend', 'sex', 'Age_at_recruitment']
    # Pick the risk-variant columns by name instead of by position so the
    # result does not depend on "eid" happening to be the first column.
    riskvar_features = [col for col in df_riskvars.columns if col != "eid"]
    genetics_features = [zscore, 'PC1', 'PC2', 'PC3', 'PC4', 'PC5'] + riskvar_features

    # Inner-join every modality on eid: only participants present in all
    # tables are kept.
    df = pd.merge(df_ndd[["eid", ndd, zscore]], df_covar[["eid"] + demographics_features], on="eid")
    df = df.merge(df_icd10[["eid"] + icd_codes], on="eid")
    df = df.merge(df_olink[["eid"] + olink_features], on="eid")
    df = df.merge(df_pcs, on="eid")
    df = df.merge(df_riskvars, on='eid')

    # Binary label: 1 where the disease column is populated, 0 otherwise.
    df[ndd] = df[ndd].notna().astype(int)

    df_cases = df[df[ndd] == 1]
    df_control_pool = df[df[ndd] == 0]
    if df_cases.empty or df_control_pool.empty:
        raise ValueError(
            f"Cannot build a balanced {ndd} dataset: "
            f"{len(df_cases)} cases and {len(df_control_pool)} controls after merging."
        )

    # Down-sample controls to match the number of cases. Cap at the pool size
    # so a small control pool does not make DataFrame.sample raise.
    n_controls = min(len(df_cases), len(df_control_pool))
    df_controls = df_control_pool.sample(n_controls, random_state=42)
    print(f"CASES: {df_cases['Age_at_recruitment'].mean():.2f} +/- {df_cases['Age_at_recruitment'].std():.2f}")
    print(f"CONTROLS: {df_controls['Age_at_recruitment'].mean():.2f} +/- {df_controls['Age_at_recruitment'].std():.2f}")
    df_balanced = pd.concat([df_cases, df_controls], axis=0)

    y = df_balanced[ndd].values
    # X keeps every merged column except the label, including "eid".
    # NOTE(review): "eid" is an identifier, not a predictor, yet it is part of
    # the 'Combined' feature set below — confirm that fit_classifier and
    # shap_analysis exclude it.
    X = df_balanced.drop([ndd], axis=1)

    # Data preparation: one feature matrix per modality plus two combinations.
    feature_sets = {
        'Genetics': X[genetics_features],
        'Clinical': X[icd_codes],
        'Olink': X[olink_features],
        'Demographics': X[demographics_features],
        'Combined without Clinical': X[genetics_features + olink_features + demographics_features],
        'Combined': X,
    }

    # Hyper-parameter grid; the "xgb__" prefix addresses a pipeline step named
    # "xgb" (presumably defined inside fit_classifier).
    param_grid = {
        'xgb__n_estimators': n_est,
        'xgb__learning_rate': lr,
        'xgb__max_depth': depth,
    }

    # Fit classifier model
    df_metrics, results_per_modality, cv_models = fit_classifier(ndd, y, feature_sets, param_grid)
    plot_auc_roc(ndd, results_per_modality, df_metrics=df_metrics)

    # Shapley analysis
    shap_values = shap_analysis(X, cv_models)
    plot_shap(shap_values, ndd)


In [None]:
# Olink columns used as features in the PD and AD models respectively.
olink_features_pd = [
    'itgav', 'vat1', 'egfr', 'megf9', 'adgrg2', 'tnxb', 'itgam', 'il13ra1', 'itgb1', 'clec10a', 'itga11', 
    'setmar', 'itgb2', 'hpgds', 'bag3', 'eps8l2', 'cdon', 'ptprn2', 'scg2', 'klk8', 'dctpp1', 'crhbp', 
    'ifnlr1', 'tppp3', 'angptl3', 'hnmt', 'comp', 'furin', 'tafa5', 'cxcl17', 'enah', 'sftpd', 'cd276', 'nefl',
]
olink_features_ad = ['nptxr', 'cst5', 'psg1', 'ren', 'calb1', 'gdf15', 'il1rl1', 'ltbp2', 'nefl', 'gfap']

# PD risk variants: two columns from df_lrrk2 and two from df_gba1, joined on eid.
# eid is cast to str (presumably to match the eid dtype of the other frames that
# multimodal_regression merges with — confirm).
df_riskvars_pd = (
    df_lrrk2[["eid", "12:40220632:C:T_C", "12:40340400:G:A_G"]]
    .merge(df_gba1[["eid", "1:155162560:G:A_G", "1:155235843:T:C_T"]], on="eid")
    .assign(eid=lambda merged: merged["eid"].astype(str))
)

# AD risk variants: all of df_apoe with eid cast to str. `.assign` returns a new
# frame, so df_apoe itself is left untouched (a plain `df_riskvars_ad = df_apoe`
# would alias it and the cast would silently rewrite df_apoe["eid"] as well).
df_riskvars_ad = df_apoe.assign(eid=df_apoe["eid"].astype(str))


In [None]:
# Run the multimodal classification workflow three times:
#   1. PD with the "zscore_pd" score column,
#   2. AD with the "zscore_ad" score column,
#   3. AD with the "zscore_ad_without_apoe" score column.
# NOTE(review): runs 2 and 3 both pass ndd="AD"; any output file whose name is
# derived only from ndd would be overwritten by run 3 — confirm against
# plot_shap / plot_auc_roc.
multimodal_regression(
    ndd="PD",
    df_ndd=df_pd,
    df_covar=df_covar,
    df_icd10=df_icd10,
    df_olink=df_olink,
    df_pcs=df_pcs,
    df_riskvars=df_riskvars_pd,
    olink_features=olink_features_pd,
    icd_codes=icd_codes,
    zscore="zscore_pd",
    n_est=[1, 3, 5, 10, 15],
    lr=[0.001, 0.01, 0.1],
    depth=[2, 3, 5],
)
multimodal_regression(
    ndd="AD",
    df_ndd=df_ad,
    df_covar=df_covar,
    df_icd10=df_icd10,
    df_olink=df_olink,
    df_pcs=df_pcs,
    df_riskvars=df_riskvars_ad,
    olink_features=olink_features_ad,
    icd_codes=icd_codes,
    zscore="zscore_ad",
    n_est=[2, 3, 5, 10, 15],
    lr=[0.001, 0.01, 0.1],
    depth=[3, 4, 5],
)
multimodal_regression(
    ndd="AD",
    df_ndd=df_ad,
    df_covar=df_covar,
    df_icd10=df_icd10,
    df_olink=df_olink,
    df_pcs=df_pcs,
    df_riskvars=df_riskvars_ad,
    olink_features=olink_features_ad,
    icd_codes=icd_codes,
    zscore="zscore_ad_without_apoe",
    n_est=[2, 3, 5, 10, 15],
    lr=[0.001, 0.01, 0.1],
    depth=[3, 4, 5],
)
