## Installing SuStaIn and setting it up to run in a notebook

To get SuStaIn up and running, you first need to install the package. I'm using Anaconda and had some conflicts with existing packages, so I had to create a new environment. For me the whole setup process looked like this...

Step 1: Open a terminal window, create a new Anaconda environment named "sustain_tutorial_env" that uses Python 3.7, and activate it so it is ready for installing pySuStaIn.
```console
conda create --name sustain_tutorial_env python=3.7
conda activate sustain_tutorial_env
```

Step 2: Use the terminal to install necessary packages for running the notebook and pySuStaIn within the environment.
```console
conda install -y ipython jupyter matplotlib statsmodels numpy pandas scipy seaborn openpyxl pip
pip install nbconvert
pip install git+https://github.com/ucl-pond/pySuStaIn
```

Step 3: Use the terminal to run the notebook from inside the environment.
```console
jupyter notebook
```

Once your environment is set up, the general workflow is: open a terminal window and navigate to the directory containing the notebook, activate the environment, start Jupyter Notebook from inside it and use the notebook to run your analyses, then use the terminal to deactivate the environment once you've finished.
```console
conda activate sustain_tutorial_env
jupyter notebook
conda deactivate
```

In [None]:
import os

# Make the Windows Pandoc install visible to nbconvert by prepending it to PATH.
# PATH entries are separated by os.pathsep (";" on Windows, ":" on macOS/Linux).
# Only touch PATH when the directory exists. Unconditionally prepending
# "C:\Program Files\Pandoc;" on macOS/Linux fuses that text with the first real
# PATH entry, which breaks command lookup in that directory.
PANDOC_DIR = r'C:\Program Files\Pandoc'
if os.path.isdir(PANDOC_DIR):
    os.environ['PATH'] = PANDOC_DIR + os.pathsep + os.environ.get('PATH', '')

print(os.environ.get('PATH'))

In [None]:
import multiprocessing

# How much parallel hardware can the SuStaIn runs use?
# The logical count includes hyper-threads, so treat it as the upper bound for n_jobs.
num_logical_cores = multiprocessing.cpu_count()

print(f"Logical CPU cores available: {num_logical_cores}")

# psutil can report physical (non-hyper-threaded) cores separately.
# psutil documents that this may be None when the count cannot be determined.
import psutil

num_physical_cores = psutil.cpu_count(logical=False)
print(f"Physical CPU cores available: {num_physical_cores}")

In [3]:
# Load libraries

import os
import pandas
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pySuStaIn
import statsmodels.formula.api as smf
from scipy import stats
import sklearn.model_selection


In [None]:
import pandas as pd

# Path to the combined table used by every cell below: one row per case, with
# a 'Lab_no' ID column, a 'Diagnosis' label column and one column per biomarker.
# TODO: point this at your own file.
DATA_PATH = "COMBINED.csv"

# `pd.read_excel(#COMBINED .csv)` was a syntax error: the `#` commented out the
# closing parenthesis. Choose the reader from the file extension, because
# read_excel cannot parse a .csv file.
if DATA_PATH.lower().endswith(".csv"):
    data = pd.read_csv(DATA_PATH)
else:
    data = pd.read_excel(DATA_PATH)
data.tail(20)

In [None]:
# Rescale every numeric biomarker column to a percentage of that column's
# maximum: value / column_max * 100 (0-100 for non-negative data).
# NOTE: this is percent-of-max scaling, NOT a percentile rank. A value of 50
# means "half of the column maximum", not "above 50% of cases".
import pandas as pd

# Rebind `data` to a copy so the DataFrame loaded above is not modified in place.
data = data.copy()

for col in data.columns:
    # Leave the ID/label columns and any non-numeric column untouched.
    if col in ['Lab_no', 'Diagnosis'] or not np.issubdtype(data[col].dtype, np.number):
        continue
    max_value = data[col].max()
    # Columns whose maximum is <= 0 (or NaN) are left as they are, which avoids
    # dividing by zero.
    if max_value > 0:
        data[col] = (data[col] / max_value) * 100

# View the transformed data
data.tail(150)

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pySuStaIn
import statsmodels.formula.api as smf
from scipy import stats
import sklearn.model_selection
def generate_synthetic_controls(data, n_controls=10, id_prefix='CTRL_Synth_', random_state=None):
    """Create synthetic control rows whose numeric values sit well below the patients'.

    For each numeric column other than 'Lab_no'/'Diagnosis', the value is drawn
    from a normal distribution. Its mean is 10% of the patient
    (Diagnosis == 1) mean and its SD is 10% of the patient SD. Draws are
    floored at 0.05. If the patient SD is NaN or 0, a uniform draw in
    [0.01, 0.05) is used instead. Non-numeric columns are set to None.

    Parameters
    ----------
    data : pandas.DataFrame
        Table with 'Lab_no' and 'Diagnosis' columns plus biomarker columns.
    n_controls : int
        Number of synthetic rows to generate.
    id_prefix : str
        Prefix of the generated 'Lab_no' IDs. The suffix runs from 1 to n_controls.
    random_state : None, int or numpy.random.Generator
        None (default) draws from NumPy's global random state, as before.
        Any other value is passed to ``numpy.random.default_rng`` so the
        output is reproducible.

    Returns
    -------
    pandas.DataFrame
        One row per synthetic control, with the same columns as ``data``.
    """
    # Both the np.random module and a Generator expose .uniform / .normal.
    rng = np.random if random_state is None else np.random.default_rng(random_state)
    real_data = data[data['Diagnosis'] == 1]  # Use only patient data

    # Patient mean/SD per numeric column, computed once rather than per row.
    numeric_stats = {
        col: (real_data[col].mean(), real_data[col].std())
        for col in data.columns
        if col not in ('Lab_no', 'Diagnosis') and np.issubdtype(data[col].dtype, np.number)
    }

    synthetic_controls = []
    for i in range(n_controls):
        synthetic_row = {}
        for col in data.columns:
            if col == 'Lab_no':
                synthetic_row[col] = f"{id_prefix}{i+1}"
            elif col == 'Diagnosis':
                synthetic_row[col] = 0
            elif col in numeric_stats:
                col_mean, col_std = numeric_stats[col]
                if pd.isna(col_std) or col_std == 0:
                    # No usable spread in the patient data: use a small positive value.
                    synthetic_row[col] = rng.uniform(0.01, 0.05)
                else:
                    # Controls centred on 10% of the patient mean, with 10% of the patient SD.
                    synth_mean = 0.1 * col_mean
                    synth_std = 0.1 * col_std
                    value = rng.normal(loc=synth_mean, scale=synth_std)
                    # Floor at 0.05 to avoid negative or near-zero values.
                    synthetic_row[col] = np.clip(value, 0.05, None)
            else:
                synthetic_row[col] = None

        synthetic_controls.append(synthetic_row)

    return pd.DataFrame(synthetic_controls)
# Number of synthetic control rows to generate and append.
N_SYNTHETIC_CONTROLS = 50
synthetic_controls = generate_synthetic_controls(data, n_controls=N_SYNTHETIC_CONTROLS)

# Combine with the original data. ignore_index renumbers all rows 0..n-1.
# NOTE: re-running this cell appends another batch, because `data` is rebound
# to the concatenated table.
data = pd.concat([data, synthetic_controls], ignore_index=True)

# Show the appended synthetic rows
data.tail(N_SYNTHETIC_CONTROLS)

In [None]:
# Number of rows per Diagnosis label (0 = control, 1 = patient).
data['Diagnosis'].value_counts()

In [None]:
# Every column except the label ('Diagnosis') and the ID ('Lab_no') is a
# biomarker. The original column order is kept.
_non_biomarker_cols = ('Diagnosis', 'Lab_no')
biomarkers = [name for name in data.columns if name not in _non_biomarker_cols]
print(biomarkers)

# Normalize to control group

In [None]:
# Columns that are not numeric. Presumably only the 'Lab_no' ID column; any
# other column listed here would break the z-scoring below.
non_numeric_cols = data.select_dtypes(exclude='number').columns
print(non_numeric_cols)

In [None]:
# Quick look at the patient vs control distribution of every biomarker:
# one KDE plot per biomarker, coloured by Diagnosis (0 = control, 1 = patient).
for biomarker in biomarkers:
    sns.displot(data=data, x=biomarker, hue='Diagnosis', kind='kde')
    plt.title(biomarker)
    plt.show()

In [None]:
import pandas as pd
import numpy as np

# `data` is already a DataFrame at this point; `df` wraps it as a DataFrame.
df = pd.DataFrame(data)

# Controls are the Diagnosis == 0 rows. Copy them so later edits don't touch df.
control_df = df.loc[df['Diagnosis'] == 0].copy()

# Numeric columns of the control rows, minus the Diagnosis label itself.
biomarker_columns = control_df.select_dtypes(include=np.number)
biomarker_columns = biomarker_columns.drop(columns=['Diagnosis'])

# Control-group reference statistics used for z-scoring in the next cell.
# ddof=0 gives the population standard deviation.
biomarker_means = biomarker_columns.mean()
biomarker_stds = biomarker_columns.std(ddof=0)

print("🧪 Control subjects preview:\n", control_df.head())
print("\n📊 Mean of each control biomarker:\n", biomarker_means)
print("\n📈 Standard deviation of each control biomarker:\n", biomarker_stds)

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.feature_selection import VarianceThreshold

# === PARAMETERS ===
# Bounds for the control-referenced z-scores. They are applied after
# decreasing markers are flipped, so positive means more abnormal.
CLIP_RANGE = (-3, 5)
VARIANCE_THRESHOLD = 5e-11  # minimum variance (sklearn VarianceThreshold) to keep a biomarker
STD_THRESHOLD = 1e-5        # minimum sample SD to keep a biomarker

# === Step 1: Copy Data ===
# NOTE: CLIP_RANGE is on the z-score scale, so clipping happens after
# normalisation (Step 4). Clipping the raw percent-of-max values here would
# squash every value above 5 down to 5 and erase most of the patient signal.
zdata = pd.DataFrame(data).copy()
control_df = control_df.copy()

# === Step 2: Normalize Using Control Means/SDs ===
# biomarker_means / biomarker_stds are the raw control statistics
# (population SD) from the previous cell.
for biomarker in biomarkers:
    mean_val = biomarker_means[biomarker]
    std_val = biomarker_stds[biomarker]

    if std_val > 0:
        zdata[biomarker] = (zdata[biomarker] - mean_val) / std_val
        control_df[biomarker] = (control_df[biomarker] - mean_val) / std_val
    else:
        # This biomarker stays on its raw scale; consider excluding it.
        print(f"⚠️ Skipped {biomarker} due to 0 std")

# === Step 3: Flip Decreasing Biomarkers ===
# Compare raw (pre-z) means on a single scale: the mean over all subjects vs the
# raw control mean. control_df is already z-scored, so its mean (~0) must not be
# compared with raw values. The overall mean is a weighted average of the
# control and non-control means. So when both groups have data,
# "overall < control" means the non-controls are lower than the controls.
original_df = pd.DataFrame(data)
mean_all = original_df[biomarkers].mean()
mean_controls = biomarker_means[biomarkers]  # raw control means, same index as mean_all

is_decreasing = mean_all < mean_controls
print("🧪 Biomarkers that decrease with disease:\n", is_decreasing)

for biomarker in biomarkers:
    if is_decreasing[biomarker]:
        zdata[biomarker] *= -1
        control_df[biomarker] *= -1

# === Step 4: Clip Extreme Z-scores ===
# Clip after flipping, so the upper bound limits the "abnormal" direction.
# NOTE(review): a later cell sets Z_max = 15 for every biomarker. With values
# capped at CLIP_RANGE[1], no z-score can exceed 5; confirm Z_max vs CLIP_RANGE.
zdata[biomarkers] = zdata[biomarkers].clip(lower=CLIP_RANGE[0], upper=CLIP_RANGE[1])
control_df[biomarkers] = control_df[biomarkers].clip(lower=CLIP_RANGE[0], upper=CLIP_RANGE[1])

# === Step 5: Diagnostics ===
print("\n📊 Control Means (after normalization):\n", control_df[biomarkers].mean())
print("\n📈 Control Std Devs:\n", control_df[biomarkers].std())
print("\n📊 All Data Means:\n", zdata[biomarkers].mean())
print("\n📈 All Data Std Devs:\n", zdata[biomarkers].std())

# === Step 6: Filtering Biomarkers ===
# Drop near-constant biomarkers: first by variance (sklearn), then by sample SD.
# NOTE(review): final_filtered only drives the plots below; the SuStaIn cells
# pass `biomarkers`, not `final_filtered`. Apply the filter there if intended.
selector = VarianceThreshold(threshold=VARIANCE_THRESHOLD)
selector.fit(zdata[biomarkers])

filtered_biomarkers = [
    name for name, variance in zip(biomarkers, selector.variances_)
    if variance > VARIANCE_THRESHOLD
]
final_filtered = [b for b in filtered_biomarkers if zdata[b].std() >= STD_THRESHOLD]

# === Step 7: Plot Each Retained Biomarker ===
zdata_filtered = zdata[final_filtered].copy()
zdata_filtered['Diagnosis'] = data['Diagnosis'].values  # Add back Diagnosis column

# KDE per biomarker by diagnosis; the dashed line marks the control mean (z = 0).
for biomarker in final_filtered:
    sns.displot(data=zdata_filtered, x=biomarker, hue='Diagnosis', kind='kde')
    plt.title(f"{biomarker} (KDE by Diagnosis)")
    plt.axvline(0, ls='--', c='black')
    plt.show()



# Prepare SuStaIn inputs

In [None]:
import os
import warnings
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.feature_selection import VarianceThreshold

# === Reproducibility & Clean Output ===
# Seeds NumPy's *global* RNG only. NOTE(review): pySuStaIn takes its own `seed`
# argument (see `seed=42` in the multi-N run cell); np.random.seed alone may
# not make SuStaIn's results reproducible, so verify this.
np.random.seed(42)
# NOTE(review): this silences *all* warnings for the rest of the session,
# including numerical RuntimeWarnings (NaN/inf) that may matter.
warnings.filterwarnings("ignore")

# === SuStaIn model specification ===
N = len(biomarkers)         # number of biomarkers
SuStaInLabels = biomarkers  # same list object; used as biomarker labels

# Z-score thresholds for each biomarker: 1, 2 and 3.
Z_vals = np.array([[1, 2, 3] for _ in range(N)])
# Maximum z-score for each biomarker.
Z_max = np.array([15 for _ in range(N)])
# Alternative per-biomarker maxima (kept for reference):
# Z_max = np.array([29, 48, 45, 43, 19, 26, 50, 119, 152, 90, 57, 44, 56, 61, 34])

print("Z_max:", Z_max)

In [49]:
# Settings for z-score SuStaIn.
# N_startpoints = 30 and N_iterations_MCMC = int(1e6) below are full-run
# settings. For a quick test run use e.g. N_startpoints = 10 and
# N_iterations_MCMC = int(1e4). The general recommendation is
# N_startpoints = 25 and N_iterations_MCMC = int(1e5) or int(1e6).

import os
import pySuStaIn
from joblib import parallel_backend

# Limit BLAS/OpenMP threads used by numpy/scipy.
# NOTE(review): OpenMP/BLAS usually read OMP_NUM_THREADS when the library is
# first loaded. numpy was imported in an earlier cell, so this may have no
# effect unless it is set before that import (e.g. after a kernel restart).
os.environ["OMP_NUM_THREADS"] = "12"  # Or 32 if thermals are fine

# SuStaIn Settings
N_startpoints = 30
N_S_max = 4
N_iterations_MCMC = int(1e6)
# An absolute path; joining it onto os.getcwd() would discard the cwd anyway.
output_folder = "/Users/hemanthnelvagal/Desktop/SUSTAIN/MPP_SUSTAIN/MPP_SUSTAIN_COMMON_4N_LIMBIC"
dataset_name = 'MPP_SUSTAIN_COMMON_4N_LIMBIC'

# Save the exact model input for the record (written to the working directory).
zdata.to_excel('zdata_input_norm_lowctrl.xlsx', index=False)

# Build the SuStaIn object. Only construction happens inside this `with` block;
# the start-point/MCMC work runs later in run_sustain_algorithm().
# NOTE(review): confirm whether pySuStaIn's parallel start points use joblib at
# all; if not, parallel_backend has no effect here.
with parallel_backend('loky', n_jobs=24):  # Adjust n_jobs if needed
    sustain_input = pySuStaIn.ZscoreSustainMissingData(
        zdata[biomarkers].values,
        Z_vals,
        Z_max,
        SuStaInLabels,
        N_startpoints,
        N_S_max,
        N_iterations_MCMC,
        output_folder,
        dataset_name,
        use_parallel_startpoints=True,
        seed=42,  # same seed as the multi-N runs, intended to make runs reproducible
    )

In [None]:
# Subjects whose biomarker values are all identical give SuStaIn no ordering
# information; count them.
print("Any rows in zdata[biomarkers] with zero variance?")
print((zdata[biomarkers].std(axis=1) == 0).sum(), "rows found.")

import seaborn as sns

# Plot-only copy under its own name, so the filtered `zdata_filtered` from the
# z-scoring cell is not overwritten. zdata was copied from `data`, so it
# already carries the 'Lab_no' column.
zdata_plot = zdata.copy()

melted = zdata_plot.melt(id_vars=['Lab_no'], value_vars=biomarkers, var_name='Biomarker', value_name='Zscore')
melted['Type'] = melted['Lab_no'].apply(lambda x: 'Synthetic' if str(x).startswith("CTRL_Synth") else 'Real')

plt.figure(figsize=(14, 6))
sns.boxplot(x='Biomarker', y='Zscore', hue='Type', data=melted)
plt.xticks(rotation=90)
plt.title("Z-score distributions by biomarker: Synthetic vs Real")
plt.tight_layout()
plt.show()

# Run SuStaIn!

In [None]:
import pickle
from pySuStaIn import ZscoreSustain

# Create the output directory (and any missing parents) if needed.
os.makedirs(output_folder, exist_ok=True)

# === Run SuStaIn ===
# A failed run is re-raised rather than replaced with placeholder labels.
# Without posterior samples there is nothing valid to pickle or plot below,
# and samples_sequence / samples_f would be undefined.
try:
    results = sustain_input.run_sustain_algorithm(plot=True)
    (samples_sequence, samples_f, ml_subtype, prob_ml_subtype,
     ml_stage, prob_ml_stage, prob_subtype_stage) = results
except ValueError as e:
    print("❌ ValueError during SuStaIn run. Likely due to NaNs or division by zero.")
    print("🔍 Error:", e)
    raise

# === Normalize probability matrix and fix NaNs ===
# prob_subtype_stage is treated as [subject, stage, subtype]: subtype totals
# sum over axis 1, and a subject's stage profile is read at [i, :, subtype].
n_subjects = zdata.shape[0]
prob_subtype_stage = np.nan_to_num(prob_subtype_stage, nan=0.0)
total_probs = prob_subtype_stage.sum(axis=(1, 2), keepdims=True)
total_probs[total_probs == 0] = 1  # all-zero subjects stay all-zero instead of dividing by 0
prob_subtype_stage = prob_subtype_stage / total_probs

# === Most likely subtype and stage for each subject (vectorised) ===
rows = np.arange(n_subjects)
subtype_probs = prob_subtype_stage[rows].sum(axis=1)            # (subjects, subtypes)
ml_subtype = subtype_probs.argmax(axis=1).astype(int)
prob_ml_subtype = subtype_probs[rows, ml_subtype]

stage_probs = prob_subtype_stage[rows, :, ml_subtype]           # (subjects, stages)
ml_stage = stage_probs.argmax(axis=1).astype(int)
prob_ml_stage = stage_probs[rows, ml_stage]

print("✅ SuStaIn finished with normalized probabilities and safe subtype/stage assignments.")

# === Save key outputs ===
pickle_dir = os.path.join(output_folder, "pickle_files")
os.makedirs(pickle_dir, exist_ok=True)
for name, obj in {
    "samples_sequence": samples_sequence,
    "samples_f": samples_f,
    "ml_subtype": ml_subtype,
    "prob_ml_subtype": prob_ml_subtype,
    "Z_vals": Z_vals,
    "biomarker_labels": SuStaInLabels
}.items():
    with open(os.path.join(pickle_dir, f"{name}.pickle"), "wb") as f:
        pickle.dump(obj, f)

# === Positional variance diagrams (PVD) ===
N_SAMPLES = 1000
pvd_dir = os.path.join(output_folder, "PVD")
os.makedirs(pvd_dir, exist_ok=True)

figs = ZscoreSustain.plot_positional_var(
    samples_sequence=samples_sequence,
    samples_f=samples_f,
    n_samples=N_SAMPLES,
    Z_vals=Z_vals,
    biomarker_labels=SuStaInLabels,
    separate_subtypes=True,
    figsize=(16, 4)
)

# The return value may hold Figures directly or inside lists/tuples (alongside
# non-Figure items). Collect each Figure once, in order.
pvd_figures = []
for item in figs:
    candidates = item if isinstance(item, (list, tuple)) else [item]
    for candidate in candidates:
        if isinstance(candidate, plt.Figure) and all(candidate is not seen for seen in pvd_figures):
            pvd_figures.append(candidate)

# Deterministic file names: re-runs overwrite instead of piling up id()-named files.
for subtype_idx, fig in enumerate(pvd_figures, start=1):
    fig.axes[0].set_title("")
    stem = os.path.join(pvd_dir, f"PVD_N{N_S_max}_Subtype{subtype_idx}")
    fig.savefig(stem + ".png", dpi=1000, bbox_inches="tight")
    fig.savefig(stem + ".pdf", bbox_inches="tight")
    plt.show()
    plt.close(fig)

In [None]:
import pickle

# Persist the key SuStaIn outputs so later cells, or a fresh kernel, can reload
# them without re-running the MCMC.
output_pickle_dir = os.path.join(output_folder, "pickle_files")
os.makedirs(output_pickle_dir, exist_ok=True)

_outputs_to_save = (
    ("samples_sequence.pickle", samples_sequence),
    ("samples_f.pickle", samples_f),
    ("ml_subtype.pickle", ml_subtype),
    ("prob_ml_subtype.pickle", prob_ml_subtype),
    ("Z_vals.pickle", Z_vals),
    ("biomarker_labels.pickle", SuStaInLabels),
)
for _file_name, _payload in _outputs_to_save:
    with open(os.path.join(output_pickle_dir, _file_name), "wb") as _handle:
        pickle.dump(_payload, _handle)

print("✅ Saved all key arrays from 1e6 run to disk.")

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pickle
from pySuStaIn import ZscoreSustain
# === Constants ===
# Biomarker labels, which must be in the column order the model was fitted on.
# NOTE: this replaces any `biomarkers` list defined in earlier cells.
biomarkers = [
    'ASYN FRONTAL', 'ABETA FRONTAL', 'AT8 FRONTAL', 'ASYN TEMPORAL', 'ABETA TEMPORAL', 'AT8 TEMPORAL', 
    'ASYN PARIETAL', 'ABETA PARIETAL', 'AT8 PARIETAL',
    'ASYN_HIPPO', 'ABETA_HIPPO', 'AT8_HIPPO', 'ASYN_PARAHIPPO', 'ABETA_PARAHIPPO', 'AT8_PARAHIPPO'
]
Z_vals = np.array([[1, 2, 3]] * len(biomarkers))
N_SAMPLES = 1000
base_dir = "/Users/hemanthnelvagal/Desktop/SUSTAIN/MPP_SUSTAIN/MPP_SUSTAIN_COMMON_4N_LIMBIC"
pickle_dir = os.path.join(base_dir, "pickle_files")
out_base = os.path.join(base_dir, "PVD_highres_from_full_model")
os.makedirs(out_base, exist_ok=True)

# === Load model output ===
print("📦 Loading SuStaIn outputs...")
with open(os.path.join(pickle_dir, "samples_sequence.pickle"), "rb") as f:
    samples_sequence_full = pickle.load(f)
with open(os.path.join(pickle_dir, "samples_f.pickle"), "rb") as f:
    samples_f_full = pickle.load(f)

# The labels above are hard-coded. If labels were pickled with the model, check
# that they agree, otherwise the PVD rows would be mislabelled.
labels_path = os.path.join(pickle_dir, "biomarker_labels.pickle")
if os.path.exists(labels_path):
    with open(labels_path, "rb") as f:
        saved_labels = list(pickle.load(f))
    if saved_labels != biomarkers:
        print("⚠️ Hard-coded biomarker labels differ from the saved model labels:\n", saved_labels)

# === Helper to flatten figure list ===
def flatten_figs(obj, _seen=None):
    """Yield every distinct matplotlib Figure reachable from *obj*, in order.

    *obj* may be a Figure, anything with ``get_figure()`` (e.g. an Axes), or an
    arbitrarily nested list / tuple / NumPy array of those. Each Figure is
    yielded at most once. A container holding both figures and their axes
    (presumably what plot_positional_var returns; confirm) would otherwise
    yield, and get saved, every figure twice.

    ``_seen`` is internal recursion state; callers should omit it.
    """
    if _seen is None:
        _seen = set()

    if isinstance(obj, plt.Figure):
        fig = obj
    elif hasattr(obj, "get_figure"):
        fig = obj.get_figure()
    else:
        if isinstance(obj, np.ndarray):
            children = obj.flat  # e.g. an array of Axes from plt.subplots
        elif isinstance(obj, (list, tuple)):
            children = obj
        else:
            return
        for child in children:
            yield from flatten_figs(child, _seen)
        return

    if fig is not None and id(fig) not in _seen:
        _seen.add(id(fig))
        yield fig

# === Plot high-res PVDs by slicing from full model ===
for N in [4]:
    print(f"\n🔁 Generating PVDs for N={N} (sliced from full model)")
    out_dir = os.path.join(out_base, f"N{N}")
    os.makedirs(out_dir, exist_ok=True)

    # Take the first N entries along axis 0 (presumably subtypes) and the first
    # N_SAMPLES along the last axis (presumably MCMC samples). This is a prefix
    # of the chain, not a thinned draw across it.
    seq_n = samples_sequence_full[:N, :, :N_SAMPLES]
    f_n = samples_f_full[:N, :N_SAMPLES]

    figs = ZscoreSustain.plot_positional_var(
        samples_sequence=seq_n,
        samples_f=f_n,
        n_samples=N_SAMPLES,
        Z_vals=Z_vals,
        biomarker_labels=biomarkers,
        separate_subtypes=True,
        figsize=(16, 4 * N)
    )

    # Keep each Figure once. If the same Figure is reached both directly and
    # through its axes, duplicates would otherwise be saved under shifted
    # "Subtype" numbers.
    unique_figs = []
    for fig in flatten_figs(figs):
        if all(fig is not seen for seen in unique_figs):
            unique_figs.append(fig)

    for i, fig in enumerate(unique_figs, start=1):
        fig.axes[0].set_title("")  # remove title
        fig.savefig(os.path.join(out_dir, f"PVD_N{N}_Subtype{i}.png"), dpi=1000, bbox_inches="tight")
        fig.savefig(os.path.join(out_dir, f"PVD_N{N}_Subtype{i}.pdf"), bbox_inches="tight")
        plt.close(fig)
        print(f"  ✅ Saved high-res PVD for Subtype {i} (N={N})")

print("\n🎉 Done: All high-resolution PVDs sliced from N=4 model.")

In [None]:
import os
import pandas as pd
from joblib import parallel_backend
import pySuStaIn

# === Setup paths and parameters ===
# output_folder, dataset_name, N_startpoints and N_iterations_MCMC are reused
# unchanged from the SuStaIn settings cell.
main_output_folder = output_folder
n_jobs_to_use = 28
all_Ns = [1, 2, 3, 4]

# === Initialize SuStaIn parameters ===
# NOTE(review): `biomarkers` must list zdata's biomarker columns in the order
# the model expects. An earlier cell may have replaced it with a hard-coded
# list; confirm the two match.
SuStaInLabels = biomarkers.copy()
Z_vals = np.array([[1, 2, 3]] * len(biomarkers))
Z_max = np.array([15] * len(biomarkers))

# === Step 1: Run SuStaIn and save assignments for each N ===
# Every run writes its SuStaIn output to main_output_folder. Only the
# assignment tables go into the per-N sub-folders.
assignment_tables = []
for N_S_max in all_Ns:
    sub_output_folder = os.path.join(main_output_folder, f"N{N_S_max}")
    os.makedirs(sub_output_folder, exist_ok=True)

    print(f"▶ Running ZscoreSustainMissingData for N_S = {N_S_max}...")

    with parallel_backend('loky', n_jobs=n_jobs_to_use):
        sustain = pySuStaIn.ZscoreSustainMissingData(
            zdata[biomarkers].values,
            Z_vals,
            Z_max,
            SuStaInLabels,
            N_startpoints,
            N_S_max,
            N_iterations_MCMC,
            main_output_folder,
            dataset_name,
            use_parallel_startpoints=True,
            seed=42
        )
        results = sustain.run_sustain_algorithm(plot=False)

    # === Extract output ===
    (samples_sequence, samples_f,
     ml_subtype, prob_ml_subtype,
     ml_stage, prob_ml_stage,
     prob_subtype_stage) = results

    df_assign = pd.DataFrame({
        "Lab_no": zdata["Lab_no"].values.ravel(),
        f"Subtype_N{N_S_max}": ml_subtype.ravel(),
        f"Stage_N{N_S_max}": ml_stage.ravel()
    })

    assign_path = os.path.join(sub_output_folder, f"assignments_N{N_S_max}.xlsx")
    df_assign.to_excel(assign_path, index=False)
    print(f"✅ Saved: {assign_path}")
    assignment_tables.append(df_assign)

# === Step 2: Combine into one Excel ===
# Reuse the tables built above instead of re-reading every workbook from disk.
# Rows are aligned because every table was built from the same zdata rows.
print("📦 Combining assignments...")
combined_df = pd.concat(
    [assignment_tables[0][["Lab_no"]]]
    + [table.drop(columns="Lab_no") for table in assignment_tables],
    axis=1,
)

# Final export
combined_output_path = os.path.join(main_output_folder, "Combined_Assignments_All_Ns.xlsx")
combined_df.to_excel(combined_output_path, index=False)
print(f"✅ All assignments combined → {combined_output_path}")

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt

# === Create plot figures ===
plt.figure(0, figsize=(8, 5))  # Trace plot
plt.title('MCMC Log-Likelihood Trace', fontsize=16)
plt.xlabel('MCMC Samples', fontsize=14)
plt.ylabel('Log Likelihood', fontsize=14)

plt.figure(1, figsize=(8, 5))  # Histogram plot
plt.title('Histogram of Log Likelihoods', fontsize=16)
plt.xlabel('Log Likelihood', fontsize=14)
plt.ylabel('Number of Samples', fontsize=14)

# === Plot per subtype ===
for s in range(N_S_max):
    # Load SuStaIn output
    pickle_path = os.path.join(output_folder, 'pickle_files', f'{dataset_name}_subtype{s}.pickle')
    pk = pd.read_pickle(pickle_path)
    samples_likelihood = pk["samples_likelihood"]

    # Add trace to Figure 0. The x-axis comes from the stored samples, not
    # N_iterations_MCMC, because the two differ if the pickle came from a run
    # with another iteration count.
    plt.figure(0)
    plt.plot(range(len(samples_likelihood)), samples_likelihood, label=f"Subtype {s}")

    # Add histogram to Figure 1
    plt.figure(1)
    plt.hist(samples_likelihood, bins=50, alpha=0.6, label=f"Subtype {s}")

# === Finalize and save Figure 0 ===
plt.figure(0)
plt.legend(loc='upper right', fontsize=12)
plt.tight_layout()
trace_path = os.path.join(output_folder, 'MCMC_loglikelihood_trace_LRG.png')
plt.savefig(trace_path)

# === Finalize and save Figure 1 ===
plt.figure(1)
plt.legend(loc='upper right', fontsize=12)
plt.tight_layout()
hist_path = os.path.join(output_folder, 'MCMC_histogram_LRG.png')
plt.savefig(hist_path)

# === Show and close ===
plt.show()
plt.close('all')

print(f"✅ Saved trace plot: {trace_path}")
print(f"✅ Saved histogram: {hist_path}")

In [None]:
# =========================
# SuStaIn MCMC plots + stats Excel (Supplementary Figure 2)
# =========================
import os, re, glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ----------- USER CONFIG -----------
# output_folder and dataset_name are reused from the SuStaIn settings cell.
# Reassign them here only if the pickles live somewhere else.
excel_name    = "Supplementary_Figure_2_stats.xlsx"
# -----------------------------------

# ---------- helpers: IO / discovery ----------
def collect_subtype_pickles(output_folder, dataset_name):
    """Return the per-subtype SuStaIn pickle paths, ordered by subtype index.

    Looks for ``<output_folder>/pickle_files/<dataset_name>_subtype*.pickle``.
    Files are ordered numerically by the trailing ``_subtype<N>`` index (so 2
    comes before 10). Files without a parsable index go last. Ties are broken
    alphabetically by path. Returns an empty list if nothing matches.
    """
    search_pattern = os.path.join(output_folder, "pickle_files", f"{dataset_name}_subtype*.pickle")

    def subtype_index(path):
        match = re.search(r"_subtype(\d+)\.pickle$", path)
        if match is None:
            return 10**9
        return int(match.group(1))

    return sorted(glob.glob(search_pattern), key=lambda path: (subtype_index(path), path))

def load_samples_likelihood(pickle_path):
    """Load the MCMC log-likelihood samples stored in a SuStaIn subtype pickle.

    The value is read from the 'samples_likelihood' key, falling back to
    'samples_likelihoods', and returned as a float ndarray.

    Raises KeyError if neither key is present.
    NOTE: unpickling runs arbitrary code, so only use this on the notebook's
    own output files.
    """
    contents = pd.read_pickle(pickle_path)
    found_key = next(
        (candidate for candidate in ("samples_likelihood", "samples_likelihoods") if candidate in contents),
        None,
    )
    if found_key is None:
        raise KeyError(f"'samples_likelihood' not found in {pickle_path}")
    return np.asarray(contents[found_key], dtype=float)

# ---------- helpers: stats ----------
def safe_array(x):
    """Return *x* as a flat float array with every NaN / ±inf entry removed."""
    flat = np.ravel(np.asarray(x, dtype=float))
    keep = np.isfinite(flat)
    return flat[keep]

def quantiles(a, qs=(0.025, 0.25, 0.75, 0.975)):
    """Return the requested quantiles of *a* as a dict keyed ``"q<percent>%"``.

    Keys are derived from *qs* (0.025 -> "q2.5%", 0.25 -> "q25%", ...), so any
    set of quantiles is labelled correctly. Hard-coded labels only fit the four
    defaults. The default output is unchanged: keys q2.5%, q25%, q75%, q97.5%,
    in that order. *qs* may also be a single number. Values are Python floats.
    """
    q_levels = np.atleast_1d(np.asarray(qs, dtype=float))
    q_values = np.atleast_1d(np.quantile(a, q_levels))
    return {f"q{level * 100:g}%": float(value) for level, value in zip(q_levels, q_values)}

def summary_for_subtype(samples, subtype_name):
    """Summarise one set of MCMC log-likelihood samples as a flat dict.

    Non-finite samples are dropped first (via ``safe_array``). The keys, in
    order, are:
    Subtype, n, mean, sd, CV, median, min, q2.5%, q25%, q75%, q97.5%, max,
    IQR, range, best_loglik, best_index.
    - sd is the sample SD (ddof=1), or 0.0 for a single sample.
    - CV is sd/mean, or NaN for n <= 1 or mean == 0.
    - best_loglik is the maximum sample.
    - best_index is its position within the *filtered* samples.
    With no finite samples, n is 0 and every statistic is NaN.
    """
    values = safe_array(samples)
    stat_keys = ("mean", "sd", "CV", "median", "min", "q2.5%", "q25%", "q75%",
                 "q97.5%", "max", "IQR", "range", "best_loglik", "best_index")

    if not values.size:
        return {"Subtype": subtype_name, "n": 0, **dict.fromkeys(stat_keys, np.nan)}

    qv = quantiles(values)
    count = int(values.size)
    avg = float(np.mean(values))
    spread = float(np.std(values, ddof=1)) if count > 1 else 0.0
    peak_at = int(np.argmax(values))

    stat_values = (
        avg,
        spread,
        (spread / avg) if (count > 1 and avg != 0) else np.nan,
        float(np.median(values)),
        float(np.min(values)),
        qv["q2.5%"],
        qv["q25%"],
        qv["q75%"],
        qv["q97.5%"],
        float(np.max(values)),
        float(qv["q75%"] - qv["q25%"]),
        float(np.max(values) - np.min(values)),
        float(values[peak_at]),
        peak_at,
    )
    return {"Subtype": subtype_name, "n": count, **dict(zip(stat_keys, stat_values))}

def make_stats_excel(output_folder, dataset_name, excel_name="Supplementary_Figure_2_stats.xlsx"):
    """Write MCMC log-likelihood summary statistics for every subtype pickle to Excel.

    Sheets:
      * "Per-subtype summary": one ``summary_for_subtype`` row per pickle.
      * "Quantiles": Subtype plus q2.5%, q25%, median, q75%, q97.5%.
      * "Overall": one summary of all finite samples pooled. The sheet has
        headers only if no pickle has finite samples.

    Returns the path of the written workbook (``output_folder/excel_name``).
    Raises FileNotFoundError if no ``<dataset_name>_subtype*.pickle`` exists.
    """
    files = collect_subtype_pickles(output_folder, dataset_name)
    if not files:
        raise FileNotFoundError(
            f"No pickle files found in {os.path.join(output_folder,'pickle_files')} "
            f"matching {dataset_name}_subtype*.pickle"
        )

    summaries = []
    quantiles_long = []
    finite_chunks = []  # finite samples per file, reused for the pooled summary

    for p in files:
        samples = load_samples_likelihood(p)
        finite_chunks.append(safe_array(samples))
        m = re.search(r"_subtype(\d+)\.pickle$", os.path.basename(p))
        s_label = f"Subtype {m.group(1)}" if m else os.path.basename(p)

        smry = summary_for_subtype(samples, s_label)
        summaries.append(smry)

        quantiles_long.append({
            "Subtype": s_label,
            "q2.5%": smry["q2.5%"],
            "q25%": smry["q25%"],
            "median": smry["median"],
            "q75%": smry["q75%"],
            "q97.5%": smry["q97.5%"],
        })

    df_summary = pd.DataFrame(summaries)[[
        "Subtype","n","mean","sd","CV","median","min","q2.5%","q25%","q75%","q97.5%","max","IQR","range","best_loglik","best_index"
    ]]
    df_quants  = pd.DataFrame(quantiles_long)

    # Overall across subtypes. Reuse the arrays loaded above rather than
    # unpickling every file a second time. Guard the all-empty case, because
    # np.concatenate([]) raises ValueError.
    non_empty = [a for a in finite_chunks if a.size > 0]
    if non_empty:
        all_samples = np.concatenate(non_empty)
        df_overall = pd.DataFrame([summary_for_subtype(all_samples, "All subtypes combined")])[df_summary.columns]
    else:
        df_overall = pd.DataFrame(columns=df_summary.columns)

    # Write Excel. Let pandas choose an installed xlsx engine: the setup
    # instructions install openpyxl, not xlsxwriter, so hard-coding
    # engine="xlsxwriter" can fail with a missing-module error.
    excel_path = os.path.join(output_folder, excel_name)
    with pd.ExcelWriter(excel_path) as xw:
        df_summary.to_excel(xw, sheet_name="Per-subtype summary", index=False)
        df_quants.to_excel(xw,   sheet_name="Quantiles", index=False)
        df_overall.to_excel(xw,  sheet_name="Overall", index=False)

    print(f"✅ Wrote stats Excel: {excel_path}")
    return excel_path

# ---------- plotting ----------
def plot_mcmc(output_folder, dataset_name):
    """Plot MCMC log-likelihood traces and histograms for every subtype pickle.

    Reads ``<output_folder>/pickle_files/<dataset_name>_subtype*.pickle`` and
    overlays one trace and one histogram per file. It then saves
    ``MCMC_loglikelihood_trace_LRG.png`` and ``MCMC_histogram_LRG.png`` in
    *output_folder*, shows the plots, and closes the two figures it created.

    Raises FileNotFoundError if no matching pickle exists.
    """
    files = collect_subtype_pickles(output_folder, dataset_name)
    if not files:
        raise FileNotFoundError(
            f"No pickle files found in {os.path.join(output_folder,'pickle_files')} "
            f"matching {dataset_name}_subtype*.pickle"
        )

    # Fresh figures, held by reference. plt.figure(0) / plt.figure(1) would
    # reuse any already-open figure with that number (ignoring figsize), and
    # its old curves would end up in the saved PNGs.
    trace_fig, trace_ax = plt.subplots(figsize=(8, 5))
    trace_ax.set_title('MCMC Log-Likelihood Trace', fontsize=16)
    trace_ax.set_xlabel('MCMC Samples', fontsize=14)
    trace_ax.set_ylabel('Log Likelihood', fontsize=14)

    hist_fig, hist_ax = plt.subplots(figsize=(8, 5))
    hist_ax.set_title('Histogram of Log Likelihoods', fontsize=16)
    hist_ax.set_xlabel('Log Likelihood', fontsize=14)
    hist_ax.set_ylabel('Number of Samples', fontsize=14)

    # Add per-subtype traces/histograms
    for p in files:
        m = re.search(r"_subtype(\d+)\.pickle$", os.path.basename(p))
        s_idx = int(m.group(1)) if m else None
        label = f"Subtype {s_idx}" if s_idx is not None else os.path.basename(p)

        samples = load_samples_likelihood(p)
        trace_ax.plot(range(len(samples)), samples, label=label)
        hist_ax.hist(samples, bins=50, alpha=0.6, label=label)

    # Save figures
    trace_ax.legend(loc='upper right', fontsize=12)
    trace_fig.tight_layout()
    trace_path = os.path.join(output_folder, 'MCMC_loglikelihood_trace_LRG.png')
    trace_fig.savefig(trace_path)

    hist_ax.legend(loc='upper right', fontsize=12)
    hist_fig.tight_layout()
    hist_path = os.path.join(output_folder, 'MCMC_histogram_LRG.png')
    hist_fig.savefig(hist_path)

    # Show, then close only the figures created here.
    plt.show()
    plt.close(trace_fig)
    plt.close(hist_fig)

    print(f"✅ Saved trace plot: {trace_path}")
    print(f"✅ Saved histogram: {hist_path}")

# ---------- main ----------
if __name__ == "__main__":
    # Inside a Jupyter kernel __name__ is "__main__", so this runs when the cell executes.
    plot_mcmc(output_folder, dataset_name)  # step 1: trace + histogram PNGs
    stats_xlsx = make_stats_excel(output_folder, dataset_name, excel_name=excel_name)  # step 2: stats workbook
    print("✅ Excel path:", stats_xlsx)

In [None]:
# Per-biomarker sample standard deviation of the model input.
print(zdata.loc[:, biomarkers].std())

In [None]:
# Summary statistics of the model input, plus where outputs are written.
print(zdata.loc[:, biomarkers].describe())
print(output_folder)

In [None]:
# Quick sanity check of the assignments: distinct labels and leftover NaNs.
for _label, _values in (("ml_subtype:", ml_subtype), ("ml_stage:", ml_stage)):
    print(_label, np.unique(_values))
print("prob_subtype_stage NaNs:", np.isnan(prob_subtype_stage).sum())

In [None]:
# Retrieve the per-case subtype and stage assignments from the model.
# The SuStaIn output is: [samples_sequence, samples_f, ml_subtype,
# prob_ml_subtype, ml_stage, prob_ml_stage, prob_subtype_stage]

print(ml_subtype)

# Flatten to 1-D (the arrays may arrive as (n, 1) columns).
ml_subtype_flat = ml_subtype.flatten()
ml_stage_flat = ml_stage.flatten()

# Create the pandas Series
ml_subtype_series = pandas.Series(ml_subtype_flat)
ml_stage_series = pandas.Series(ml_stage_flat)

# One row per case: Lab_no, assigned subtype and stage.
combined_df = pandas.DataFrame({
    'Case': data['Lab_no'].reset_index(drop=True),  # Reset index in case of mismatch
    'Subtype': ml_subtype_series,
    'Stage': ml_stage_series
})

# Print the combined DataFrame
print(combined_df)

# Save inside output_folder. os.path.join uses the correct separator for the
# OS. A literal "\Combined..." suffix would create a file named
# "<folder>\Combined..." next to the folder on macOS/Linux, and "\C" is an
# invalid escape sequence.
file_path = os.path.join(output_folder, "Combined_CaseNumber_&_Subtype3.xlsx")

# Save the DataFrame to an Excel file
combined_df.to_excel(file_path, index=False)  # index=False avoids saving the DataFrame index

# Optional: Print confirmation
print(f"DataFrame successfully saved to {file_path}")

combined_df.Subtype.value_counts()

# Evaluate subtypes

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt

# Overlay the MCMC log-likelihood trace (figure 0) and histogram (figure 1) of
# each subtype pickle, then save the histogram figure.
for s in range(N_S_max):
    # load pickle file (SuStaIn output) and get the sample log likelihood values
    pickle_filename_s = os.path.join(output_folder, 'pickle_files', f"{dataset_name}_subtype{s}.pickle")
    pk = pd.read_pickle(pickle_filename_s)
    samples_likelihood = pk["samples_likelihood"]

    # === Line plot of likelihood trace ===
    # The x-axis comes from the stored samples, not N_iterations_MCMC; they
    # differ if the pickle came from a run with another iteration count.
    plt.figure(0)
    plt.plot(range(len(samples_likelihood)), samples_likelihood, label=f"Subtype {s}")

    # === Histogram of likelihoods ===
    plt.figure(1)
    plt.hist(samples_likelihood, bins=50, alpha=0.6, label=f"Subtype {s}")

# Labels and legends once, after every curve has been drawn.
plt.figure(0)
plt.xlabel('MCMC Samples', fontsize=14)
plt.ylabel('Log Likelihood', fontsize=14)
plt.title('MCMC Log-Likelihood Trace', fontsize=16)
plt.legend(loc='upper right', fontsize=12)

plt.figure(1)
plt.xlabel('Log Likelihood', fontsize=14)
plt.ylabel('Number of Samples', fontsize=14)
plt.title('Histogram of Log Likelihoods', fontsize=16)
plt.legend(loc='upper right', fontsize=12)

# Save combined histogram figure
save_path = os.path.join(output_folder, 'MCMC_histogram.png')
plt.figure(1)
plt.tight_layout()
plt.savefig(save_path)
plt.show()
plt.close('all')

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Bar chart of how many subjects were assigned to each subtype. Each bar is
# labelled with its count (n) and the mean probability of the assigned subtype
# within that group (f).

# Make both assignment arrays 1-D before building the table
ml_subtype, prob_ml_subtype = (
    np.asarray(values).flatten() for values in (ml_subtype, prob_ml_subtype)
)

combined_df = pd.DataFrame({'Subtype': ml_subtype, 'Prob_Subtype': prob_ml_subtype})

# Per-subtype count and mean assignment probability, both ordered by subtype label
subtype_counts = combined_df['Subtype'].value_counts().sort_index()
subtype_f_scores = combined_df.groupby('Subtype')['Prob_Subtype'].mean().sort_index()

fig, ax = plt.subplots(figsize=(8, 5))
subtype_counts.plot(kind='bar', color='skyblue', ax=ax)
ax.set_xlabel('Subtype')
ax.set_ylabel('Number of subjects')
ax.set_title('Frequency of Subjects by Subtype')

# Bars sit at x = 0, 1, 2, ...; write the annotation just above each one
bar_positions = range(len(subtype_counts))
for x_pos, count, f_score in zip(bar_positions, subtype_counts, subtype_f_scores):
    annotation = f'n={count}\nf={f_score:.2f}'
    ax.text(x_pos, count + 1, annotation,
            ha='center', va='bottom', fontsize=10, fontweight='bold')

fig.tight_layout()
plt.show()

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Stacked histogram of assigned SuStaIn stage, coloured by assigned subtype,
# with one unit-width bin per stage.

# Make the assignment arrays 1-D
ml_subtype, prob_ml_subtype, ml_stage = (
    np.asarray(values).flatten() for values in (ml_subtype, prob_ml_subtype, ml_stage)
)

combined_df = pd.DataFrame({
    'Subtype': ml_subtype,
    'Prob_Subtype': prob_ml_subtype,
    'Stage': ml_stage,
})

# Bin edges 0, 1, ..., top_stage + 1 so the highest stage gets its own bin
top_stage = int(max(combined_df['Stage']))
bins = range(0, top_stage + 2)

fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(
    data=combined_df, x='Stage', hue='Subtype',
    multiple='stack', bins=bins, palette='tab10', ax=ax,
)
ax.set_xlabel('SuStaIn Stage')
ax.set_ylabel('Number of Subjects')
ax.set_title('Stage Distribution Across Subtypes')
fig.tight_layout()
plt.show()

# Subtype and stage individuals

In [None]:
# let's take a look at all of the things that exist in SuStaIn's output (pickle) file
# NOTE: `pk` is whichever pickle was loaded most recently; when the cells run in
# order that is the last (largest) subtype model from the likelihood-plotting loop.
pk.keys()

In [None]:
# The SuStaIn output has everything we need. We'll use it to populate our dataframe.

# Model index s = N_S_max - 1 is the largest fitted model (it has s + 1 subtypes).
s = N_S_max-1
pickle_filename_s = os.path.join(
    output_folder, 'pickle_files', dataset_name + '_subtype' + str(s) + '.pickle'
)
pk = pandas.read_pickle(pickle_filename_s)

print(f"zdata shape: {zdata.shape}")
print(len(pk['ml_subtype']))

for variable in ['ml_subtype',       # the assigned subtype
                 'prob_ml_subtype',  # the probability of the assigned subtype
                 'ml_stage',         # the assigned stage
                 'prob_ml_stage']:   # the probability of the assigned stage
    # add SuStaIn output to dataframe; ravel() turns an (n_subjects, 1) column
    # vector into the 1-D array a single DataFrame column expects
    zdata.loc[:, variable] = np.asarray(pk[variable]).ravel()

# let's also add the probability for each subject of being each subtype.
# Iterate over every column of prob_subtype: the model has s + 1 subtypes, so
# range(s) would silently drop the last subtype's probability column.
for i in range(pk['prob_subtype'].shape[1]):
    zdata.loc[:, 'prob_S%s' % i] = pk['prob_subtype'][:, i]
zdata.head(90)

In [None]:
# IMPORTANT!!! The last thing we need to do is to set all "Stage 0" subtypes to their own subtype.
# We'll shift the current subtypes (0 and 1) to 1 and 2, and call "Stage 0" individuals subtype 0.
# NOTE: this modifies 'ml_subtype' in place, so re-running this cell shifts the labels again.

# make current subtypes (0, 1, ...) 1, 2, ... instead
zdata.loc[:, 'ml_subtype'] = zdata.ml_subtype.values + 1

# convert "Stage 0" subjects to subtype 0
zdata.loc[zdata.ml_stage == 0, 'ml_subtype'] = 0

# Specify the directory and filename. os.path.join is used because a
# "\FullOutput.xlsx" suffix only acts as a separator on Windows (and "\F" is an
# invalid escape sequence); elsewhere the file would land outside output_folder.
directory = output_folder  # Replace with your desired directory
filename = "FullOutput.xlsx"  # Replace with your desired file name

# Full path to save the Excel file
file_path = os.path.join(directory, filename)

# Save the DataFrame to an Excel file (index=False avoids saving the DataFrame index)
zdata.to_excel(file_path, index=False)

# Optional: Print confirmation
print(f"DataFrame successfully saved to {file_path}")

In [None]:
# IMPORTANT!!! Give all "Stage 0" subjects their own subtype.
# Non-destructive variant: fitted subtypes 0, 1, ... are shifted to 1, 2, ... in a
# new 'Adjusted_Subtype' column, "Stage 0" subjects get subtype 0, and the
# 'ml_subtype' column itself is left untouched.
# NOTE(review): the previous cell already shifts 'ml_subtype' in place; if both
# cells are run, Adjusted_Subtype ends up shifted twice — run only one of them.
subtype_column_present = 'ml_subtype' in zdata.columns

if subtype_column_present:
    # === Final Subtype Reassignment ===
    zdata['Adjusted_Subtype'] = zdata['ml_subtype'] + 1
    stage_zero_mask = zdata['ml_stage'] == 0
    zdata.loc[stage_zero_mask, 'Adjusted_Subtype'] = 0

    # === Save the updated DataFrame ===
    filename = "FullOutput.xlsx"
    file_path = os.path.join(output_folder, filename)
    zdata.to_excel(file_path, index=False)
    print(f"✅ Final output with adjusted subtypes saved to: {file_path}")
else:
    print("❌ 'ml_subtype' column is missing! Please check the model output.")


In [None]:
# Number of subjects per subtype label. After the in-place relabelling cell,
# subtype 0 holds the stage-0 subjects and the fitted subtypes start at 1.
zdata.ml_subtype.value_counts()

In [None]:
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Histogram of assigned SuStaIn stage: one panel per subtype, stacked by diagnosis.

# Ensure Diagnosis is numeric (optional sanity check); unparseable entries become NaN
zdata['Diagnosis'] = pd.to_numeric(zdata['Diagnosis'], errors='coerce')

# Plot ml_stage distribution by Diagnosis, separated by subtype
g = sns.displot(
    data=zdata,
    x='ml_stage',
    hue='Diagnosis',
    col='ml_subtype',
    kind='hist',
    multiple='stack',
    bins=15,
    palette='Set2'
)

g.set_axis_labels("SuStaIn Stage", "Count")
g.set_titles("Subtype {col_name}")
plt.suptitle("SuStaIn Stage Distributions by Subtype & Diagnosis", y=1.05)

# Save the figure. y=1.05 puts the suptitle above the top edge of the figure
# canvas, so bbox_inches='tight' is needed to keep it from being cropped.
save_path = os.path.join(output_folder, 'hist_stages_per_subtype.png')
plt.savefig(save_path, bbox_inches='tight')
print(f"✅ Saved histogram figure to {save_path}")

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import os

# Mean probability of the assigned subtype at each SuStaIn stage, one line per
# subtype. Subtype 0 (the stage-0 group) is excluded.
fig, ax = plt.subplots()
sns.pointplot(x='ml_stage', y='prob_ml_subtype',  # input variables
              hue='ml_subtype',                   # "grouping" variable
              data=zdata[zdata.ml_subtype > 0],   # only fitted subtypes (not 0)
              ax=ax)

# Annotate once, then save. The dashed line marks chance level
# (0.5 in the case of 2 subtypes).
ax.set_ylim(0, 1)
ax.axhline(0.5, ls='--', color='k', label='Decision boundary (0.5)')
ax.set_title("Subtype Probability vs. SuStaIn Stage")
ax.set_xlabel("SuStaIn Stage")
ax.set_ylabel("Probability of Assigned Subtype")
ax.legend(title="Subtype")
fig.tight_layout()

# Save the finished figure under both file names this cell has always produced
for fname in ('Subtype probabilities.png', 'Subtype_probabilities_by_stage.png'):
    save_path = os.path.join(output_folder, fname)
    fig.savefig(save_path)
print(f"✅ Saved plot to {save_path}")

# Evaluate relationships

In [None]:
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import stats

# For each biomarker:
#   - plot its value against SuStaIn stage with one regression line per subtype;
#   - write the per-subtype Spearman and Pearson correlations under the axes.
# All correlations are then exported to a two-sheet Excel workbook.

plot_dir = os.path.join(output_folder, "Biomarker_vs_Stage_Plots")
os.makedirs(plot_dir, exist_ok=True)

# Subjects with a non-negative subtype label
plot_data = zdata.loc[zdata.ml_subtype >= 0].copy()
subtype_levels = sorted(plot_data.ml_subtype.unique())

results_pearson = []
results_spearman = []

for var in biomarkers:
    print(f"🔍 Processing: {var}")

    g = sns.lmplot(
        data=plot_data,
        x='ml_stage',
        y=var,
        hue='ml_subtype',
        height=6,
        aspect=1.5,
        scatter_kws={'alpha': 0.6},
        line_kws={'lw': 2},
        legend=False,
    )
    ax = g.ax
    ax.set_title(f'{var} vs SuStaIn Stage by Subtype', fontsize=18)
    ax.set_xlabel('SuStaIn Stage', fontsize=14)
    ax.set_ylabel(var, fontsize=14)
    ax.grid(True)

    # One summary line per subtype that has at least 3 complete observations
    summary_lines = []
    for subtype in subtype_levels:
        in_subtype = plot_data.ml_subtype == subtype
        paired = plot_data.loc[in_subtype, ['ml_stage', var]].dropna()
        n_obs = len(paired)
        if n_obs < 3:
            continue

        stage_values = paired['ml_stage']
        marker_values = paired[var]

        pearson_r, pearson_p = stats.pearsonr(stage_values, marker_values)
        results_pearson.append({
            "Biomarker": var,
            "Subtype": subtype,
            "N": n_obs,
            "r": pearson_r,
            "p": pearson_p,
        })

        spearman_rho, spearman_p = stats.spearmanr(stage_values, marker_values)
        results_spearman.append({
            "Biomarker": var,
            "Subtype": subtype,
            "N": n_obs,
            "rho": spearman_rho,
            "p": spearman_p,
        })

        summary_lines.append(
            f"Subtype {subtype}: Spearman:ρ = {spearman_rho:.2f}, Spearman:p = {spearman_p:.3f}   |   Pearson:r = {pearson_r:.2f}, Pearson:p = {pearson_p:.3f}"
        )

    # Legend to the right; leave room at the bottom/right for legend and text box
    g.add_legend(title="Subtype", bbox_to_anchor=(1.05, 0.5), loc='center left', borderaxespad=0.)
    plt.tight_layout(rect=[0, 0.15, 0.85, 1])

    if summary_lines:
        ax.text(
            0.5, -0.4,  # centred, below the axes
            "\n".join(summary_lines),
            transform=ax.transAxes,
            fontsize=12,
            ha='center',
            va='top',
            bbox=dict(facecolor='white', edgecolor='gray', boxstyle='round,pad=0.5'),
        )

    fname = os.path.join(plot_dir, f"{var}_vs_Stage_by_Subtype.png")
    plt.savefig(fname, bbox_inches='tight')
    plt.show()
    plt.close()
    print(f"✅ Saved plot: {fname}")

# Export all correlations, one sheet per method
excel_path = os.path.join(output_folder, "Biomarker_Correlations.xlsx")
with pd.ExcelWriter(excel_path) as writer:
    pd.DataFrame(results_pearson).to_excel(writer, sheet_name='Pearson', index=False)
    pd.DataFrame(results_spearman).to_excel(writer, sheet_name='Spearman', index=False)

print(f"📄 Exported correlation results to: {excel_path}")

In [None]:
# we can also look at differences in each biomarker across subtypes
from scipy import stats
import pandas as pd

def run_ttest(zdata, biomarkers, group_a, group_b):
    """Compare each biomarker between two subtypes with an independent t-test.

    Parameters
    ----------
    zdata : pandas.DataFrame
        Must contain an ``ml_subtype`` column and one column per biomarker.
    biomarkers : list of str
        Columns to test; each becomes one row of the result.
    group_a, group_b
        ``ml_subtype`` values of the two groups being compared.

    Returns
    -------
    pandas.DataFrame
        Indexed by biomarker, with float columns ``t_stat`` and ``p_value`` and a
        boolean ``significant (p<0.05)`` column, sorted by ascending p-value.
        NaNs within each group are ignored (``nan_policy='omit'``).
    """
    in_a = zdata.ml_subtype == group_a
    in_b = zdata.ml_subtype == group_b

    t_stats = []
    p_values = []
    for biomarker in biomarkers:
        t_stat, p_val = stats.ttest_ind(
            zdata.loc[in_a, biomarker], zdata.loc[in_b, biomarker], nan_policy='omit'
        )
        t_stats.append(float(t_stat))
        p_values.append(float(p_val))

    # Build float64 columns directly. Filling an empty frame cell by cell leaves
    # object-dtype columns that need .astype(float) before numeric use.
    results = pd.DataFrame(
        {'t_stat': t_stats, 'p_value': p_values}, index=biomarkers, dtype=float
    )
    results['significant (p<0.05)'] = results['p_value'] < 0.05
    return results.sort_values('p_value')

# Pairwise t-test comparisons between subtypes 0-3. The list order controls both
# the printed order and the Excel sheet order.
subtype_pairs = [(0, 1), (1, 2), (0, 2), (0, 3), (1, 3), (2, 3)]
comparison_results = {
    (a, b): run_ttest(zdata, biomarkers, group_a=a, group_b=b)
    for a, b in subtype_pairs
}

# Named results, used by the heatmap cell below
results_0vs1 = comparison_results[(0, 1)]
results_1vs2 = comparison_results[(1, 2)]
results_0vs2 = comparison_results[(0, 2)]
results_0vs3 = comparison_results[(0, 3)]
results_1vs3 = comparison_results[(1, 3)]
results_2vs3 = comparison_results[(2, 3)]

# Display: a blank line separates every table after the first
for position, (a, b) in enumerate(subtype_pairs):
    heading = f"🧪 Comparison: Subtype {a} vs {b}"
    print(heading if position == 0 else "\n" + heading)
    print(comparison_results[(a, b)])

# Save: one sheet per comparison
excel_path = os.path.join(output_folder, 'Subtype_Comparisons.xlsx')
with pd.ExcelWriter(excel_path) as writer:
    for a, b in subtype_pairs:
        comparison_results[(a, b)].to_excel(writer, sheet_name=f'{a}_vs_{b}')

print(f"✅ Saved all subtype comparison results to: {excel_path}")

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import os

# Function to convert t_stat to float and plot heatmap
def plot_tstat_heatmap(results_df, comparison_name):
    """Draw, save and display a one-column heatmap of per-biomarker t-statistics.

    Parameters
    ----------
    results_df : pandas.DataFrame
        A table with a ``t_stat`` column indexed by biomarker (as returned by
        ``run_ttest``).
    comparison_name : str
        Used as the plot title and, with spaces replaced by underscores, in the
        file name ``heatmap_<comparison_name>.png`` inside ``output_folder``.
    """
    # Ensure the t_stat column is float (it may arrive as object dtype)
    tstat_df = results_df[['t_stat']].astype(float)

    # Height grows with the number of biomarkers; the floor keeps an empty or
    # very short table from requesting a zero-height figure.
    fig = plt.figure(figsize=(6, max(len(tstat_df) * 0.4, 2)))
    sns.heatmap(tstat_df, annot=True, fmt=".2f", square=False,
                cmap='RdBu_r', center=0, cbar_kws={'label': 't-statistic'})
    plt.title(f"T-test: {comparison_name}")
    plt.tight_layout()

    # Save BEFORE showing. With Jupyter's inline backend, plt.show() renders and
    # closes the figure, so saving afterwards writes an empty image.
    save_path = os.path.join(output_folder, f"heatmap_{comparison_name.replace(' ', '_')}.png")
    fig.savefig(save_path)
    plt.show()
    plt.close(fig)
    print(f"✅ Saved heatmap to: {save_path}")
    
# One heatmap per subtype comparison, in the same order as the comparisons above
for comparison_df, comparison_label in (
    (results_0vs1, "Subtype 0 vs 1"),
    (results_1vs2, "Subtype 1 vs 2"),
    (results_0vs2, "Subtype 0 vs 2"),
    (results_0vs3, "Subtype 0 vs 3"),
    (results_1vs3, "Subtype 1 vs 3"),
    (results_2vs3, "Subtype 2 vs 3"),
):
    plot_tstat_heatmap(comparison_df, comparison_label)

In [None]:
# One box plot per biomarker showing its distribution in each SuStaIn subtype;
# each figure is saved to output_folder, shown, then closed.
for var in biomarkers:
    print(f"🔍 Processing: {var}")

    ax = plt.gca()
    fig = ax.figure
    sns.boxplot(data=zdata, x='ml_subtype', y=var, ax=ax)
    ax.set_title(f"{var} vs SuStaIn Subtype")
    ax.set_xlabel("Subtype")
    ax.set_ylabel(var)
    fig.tight_layout()

    save_path = os.path.join(output_folder, f'{var}_vs_ml_subtype.png')
    fig.savefig(save_path)
    plt.show()
    plt.close(fig)
    print(f"✅ Saved: {save_path}")

# Evaluate 1 subtype

In [None]:
# Input the settings for z-score SuStaIn
# To make the tutorial run faster I've set
# N_startpoints = 10 and N_iterations_MCMC = int(1e4)
# I recommend using N_startpoints = 25 and
# N_iterations_MCMC = int(1e5) or int(1e6) in general though
#
# This cell:
#   - fits a single-subtype (N_S_max = 1) z-score SuStaIn model;
#   - recomputes each subject's most likely subtype and stage from the normalised
#     subject x stage x subtype probabilities;
#   - pickles the key outputs and saves high-resolution positional variance
#     diagrams (PVDs).

import pickle
import os
import numpy as np
import matplotlib.pyplot as plt
import pySuStaIn
from pySuStaIn.ZscoreSustain import ZscoreSustain  # Required for plot_positional_var()
from joblib import parallel_backend

# Limit BLAS threads (numpy/scipy etc.).
# NOTE(review): BLAS/OpenMP libraries normally read OMP_NUM_THREADS when numpy is
# first imported. numpy is already imported earlier in this notebook, so this
# probably only takes effect after a kernel restart — confirm.
os.environ["OMP_NUM_THREADS"] = "28"  # Or 32 if thermals are fine

# SuStaIn Settings
N_SAMPLES = 1000  # number of posterior samples used for the PVD
N_startpoints = 30
N_S_max = 1
N_iterations_MCMC = int(1e6)
# The second argument is an absolute path, so os.path.join ignores os.getcwd() here.
output_folder = os.path.join(os.getcwd(), "/Users/hemanthnelvagal/Desktop/SUSTAIN/MPP_SUSTAIN/MPP_SUSTAIN_COMMON_4N_LIMBIC/1N")
dataset_name = '1N'

# Save input data to Excel (for record)
zdata.to_excel('zdata_input_norm_lowctrl.xlsx', index=False)

# Create the output folders up front; the pickle export below writes into pickle_files.
pickle_dir = os.path.join(output_folder, "pickle_files")
os.makedirs(pickle_dir, exist_ok=True)

n_subjects = zdata.shape[0]
n_stages = Z_vals.shape[1]

# Preset to None so a failed run can never pickle or plot samples left over
# from an earlier model fit in this session.
samples_sequence = None
samples_f = None

# Construct AND run the model inside the joblib context (previously only the
# constructor was inside it, so the run itself was not covered).
# NOTE(review): confirm pySuStaIn's parallel start points actually go through
# joblib; if they use another process pool this context does not limit them.
with parallel_backend('loky', n_jobs=24):  # Adjust n_jobs if needed
    sustain_input = pySuStaIn.ZscoreSustainMissingData(
        zdata[biomarkers].values,
        Z_vals,
        Z_max,
        SuStaInLabels,
        N_startpoints,
        N_S_max,
        N_iterations_MCMC,
        output_folder,
        dataset_name,
        True
    )
    try:
        # === Run SuStaIn ===
        results = sustain_input.run_sustain_algorithm(plot=True)
        samples_sequence, samples_f, ml_subtype, prob_ml_subtype, ml_stage, prob_ml_stage, prob_subtype_stage = results

        # === Normalize probability matrix and fix NaNs ===
        # Each subject's probabilities are rescaled to sum to 1 over stages x subtypes.
        prob_subtype_stage = np.nan_to_num(prob_subtype_stage, nan=0.0)
        total_probs = prob_subtype_stage.sum(axis=(1, 2), keepdims=True)
        total_probs[total_probs == 0] = 1  # avoid division by zero
        prob_subtype_stage /= total_probs

        # === Most likely subtype (summing over stages), then most likely stage within it ===
        ml_subtype = np.full(n_subjects, -1, dtype=int)
        ml_stage = np.full(n_subjects, -1, dtype=int)
        prob_ml_subtype = np.zeros(n_subjects)
        prob_ml_stage = np.zeros(n_subjects)

        for i in range(n_subjects):
            subtype_probs = np.sum(prob_subtype_stage[i], axis=0)
            ml_subtype[i] = int(np.argmax(subtype_probs))
            prob_ml_subtype[i] = subtype_probs[ml_subtype[i]]

            stage_probs = prob_subtype_stage[i, :, ml_subtype[i]]
            ml_stage[i] = int(np.argmax(stage_probs))
            prob_ml_stage[i] = stage_probs[ml_stage[i]]

        print("✅ SuStaIn finished with normalized probabilities and safe subtype/stage assignments.")

    except ValueError as e:
        # === Emergency fallback: mark every subject as unassigned ===
        print("❌ ValueError during SuStaIn assignment. Likely due to NaNs or division by zero.")
        print("🔍 Error:", e)

        ml_subtype = np.full(n_subjects, -1, dtype=int)
        ml_stage = np.full(n_subjects, -1, dtype=int)
        prob_ml_subtype = np.zeros(n_subjects)
        prob_ml_stage = np.zeros(n_subjects)
        prob_subtype_stage = np.zeros((n_subjects, n_stages + 1, N_S_max))

if samples_sequence is None:
    print("⚠️ No SuStaIn samples available — skipping pickle export and PVD plots.")
else:
    # Pickle the key outputs individually for later reuse
    for name, obj in {
        "samples_sequence": samples_sequence,
        "samples_f": samples_f,
        "ml_subtype": ml_subtype,
        "prob_ml_subtype": prob_ml_subtype,
        "Z_vals": Z_vals,
        "biomarker_labels": SuStaInLabels
    }.items():
        with open(os.path.join(pickle_dir, f"{name}.pickle"), "wb") as f:
            pickle.dump(obj, f)

    # === High-Resolution PVD Plot ===
    pvd_dir = os.path.join(output_folder, "PVD")
    os.makedirs(pvd_dir, exist_ok=True)

    figs = ZscoreSustain.plot_positional_var(
        samples_sequence=samples_sequence,
        samples_f=samples_f,
        n_samples=N_SAMPLES,
        Z_vals=Z_vals,
        biomarker_labels=SuStaInLabels,
        separate_subtypes=True,
        figsize=(16, 4)
    )

    # The return value may nest figures inside lists/tuples (possibly alongside
    # axes); collect only the Figure objects.
    pvd_figures = []
    for item in figs:
        if isinstance(item, (list, tuple)):
            pvd_figures.extend(c for c in item if isinstance(c, plt.Figure))
        elif isinstance(item, plt.Figure):
            pvd_figures.append(item)

    # Number files by position: id(fig) is a memory address that changes every
    # run, so it produced a new set of files on each execution.
    for fig_index, fig in enumerate(pvd_figures):
        fig.axes[0].set_title("")
        stem = os.path.join(pvd_dir, f"PVD_N1_Subtype_from_true_model_{fig_index}")
        fig.savefig(stem + ".png", dpi=1000, bbox_inches="tight")
        fig.savefig(stem + ".pdf", bbox_inches="tight")
        plt.show()
        plt.close(fig)

# Cross-validation

In [45]:
import os

# Logical CPU count reported by the OS (upper bound for n_proc below)
logical_core_count = os.cpu_count()
print("Logical cores (n_proc max):", logical_core_count)

Logical cores (n_proc max): 14


In [None]:
import sklearn.model_selection
import numpy as np
import pandas as pd
import os

# Stratified k-fold cross-validation of SuStaIn models with 1..N_S_max_cv
# subtypes, using cheaper MCMC settings than the main fit.

# === CONFIGURATION ===
N_folds = 10                     # number of CV folds
N_startpoints_cv = 30            # Reduced for speed
N_iterations_MCMC_cv = int(1e5)  # Reduced for speed
N_S_max_cv = 7                   # Max subtypes to evaluate
n_proc = 12                      # NOTE(review): not passed to SuStaIn below; kept for reference

# === Folds stratified by diagnosis, with a fixed seed for reproducibility ===
diagnosis_labels = zdata['Diagnosis'].values
labels = diagnosis_labels
splitter = sklearn.model_selection.StratifiedKFold(n_splits=N_folds, shuffle=True, random_state=42)
cv = splitter
cv_it = list(splitter.split(zdata, diagnosis_labels))

# Only the held-out (test) indices of each fold are needed by SuStaIn
test_idxs = [fold_test for _fold_train, fold_test in cv_it]

print(f"Number of folds: {N_folds}")
print(f"Test indices per fold: {test_idxs}")

# === Separate SuStaIn object for CV, writing to its own temporary folder ===
cv_output_folder = os.path.join(output_folder, 'crossval_temp')
os.makedirs(cv_output_folder, exist_ok=True)

sustain_input_cv = pySuStaIn.ZscoreSustainMissingData(
    zdata[biomarkers].values,
    Z_vals,
    Z_max,
    SuStaInLabels,
    N_startpoints_cv,
    N_S_max_cv,
    N_iterations_MCMC_cv,
    cv_output_folder,   # temporary folder for CV output
    "CV_TEMP",          # dummy dataset name for CV
    True,               # use parallel startpoints
)

print("🚀 Running SuStaIn cross-validation with parallelization...")
CVIC, loglike_matrix = sustain_input_cv.cross_validate_sustain_model(test_idxs=test_idxs)

print("\n✅ Cross-validation complete!")
print("CVIC (Cross-Validation Information Criterion):\n", CVIC)
print("\nLog-likelihood matrix:\n", loglike_matrix)

# Choosing the optimal number of subtypes

In [None]:
# === Print CVIC & Log-Likelihood ===
print("CVIC for each subtype model: " + str(CVIC))
print("Average test set log-likelihood for each subtype model: " + str(np.mean(loglike_matrix, axis=0)))

# Fresh figures are used (instead of plt.figure(0) / plt.figure(1)) so these
# plots are never drawn on top of an earlier, still-open figure with those numbers.

# === Plot CVIC ===
fig_cvic, ax_cvic = plt.subplots()
subtype_numbers = np.arange(1, N_S_max_cv + 1, dtype=int)
ax_cvic.plot(subtype_numbers, CVIC, marker='o')
ax_cvic.set_xticks(subtype_numbers)
ax_cvic.set_ylabel('CVIC')
ax_cvic.set_xlabel('Number of Subtypes')
ax_cvic.set_title('CVIC across Subtype Models')

save_path = os.path.join(output_folder, 'CVIC.png')
fig_cvic.savefig(save_path)

# === Plot Test Set Log-Likelihood (Boxplot, one box per model across folds) ===
fig_ll, ax_ll = plt.subplots()
df_loglike = pd.DataFrame(data=loglike_matrix, columns=[f"Subtype_{i+1}" for i in range(N_S_max_cv)])
df_loglike.boxplot(grid=False, ax=ax_ll)
ax_ll.set_ylabel('Log-Likelihood')
ax_ll.set_xlabel('Number of Subtypes')
ax_ll.set_title('Test Set Log-Likelihood across Folds')

save_path = os.path.join(output_folder, 'LogLikelihood_Boxplot.png')
fig_ll.savefig(save_path)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Side-by-side model-selection summary: CVIC (left) and mean held-out
# log-likelihood (right) against the number of subtypes.
S_range = np.arange(1, len(CVIC) + 1)
mean_ll = loglike_matrix.mean(axis=0)

fig, (ax_cvic, ax_ll) = plt.subplots(1, 2, figsize=(10, 5))

ax_cvic.plot(S_range, CVIC, marker='o', label='Corrected CVIC')
ax_cvic.set_xlabel("Number of Subtypes (S)")
ax_cvic.set_ylabel("Corrected CVIC")
ax_cvic.set_title("Corrected CVIC vs. Subtypes")
ax_cvic.grid(True)
ax_cvic.legend()

ax_ll.plot(S_range, mean_ll, marker='s', color='orange', label='Mean Log-Likelihood')
ax_ll.set_xlabel("Number of Subtypes (S)")
ax_ll.set_ylabel("Mean Test Log-Likelihood")
ax_ll.set_title("Test Log-Likelihood vs. Subtypes")
ax_ll.grid(True)
ax_ll.legend()

fig.tight_layout()
plt.show()

In [None]:
import numpy as np
import pandas as pd
from itertools import combinations
from scipy.stats import wilcoxon

# === 1) CVIC stats ===
# One row per model size; ΔCVIC is the distance from the best (lowest) CVIC.
cvic_df = pd.DataFrame({
    "Subtypes (S)": S_range,
    "CVIC": CVIC
})
cvic_df["ΔCVIC_vs_best"] = cvic_df["CVIC"] - cvic_df["CVIC"].min()
# method="min" gives tied models the same integer rank. The default "average"
# method yields x.5 ranks that astype(int) would silently truncate.
cvic_df["Rank"] = cvic_df["CVIC"].rank(method="min").astype(int)

# === 2) Log-likelihood descriptive stats ===
def summarise_array(arr):
    """Return descriptive statistics for a 1-D sequence of numbers.

    Keys, in order: n, mean, sd (sample SD, ddof=1), median, min, max,
    IQR (75th minus 25th percentile), q2.5%, q97.5%.
    """
    values = np.asarray(arr, dtype=float)
    p2_5, p25, p75, p97_5 = (np.percentile(values, q) for q in (2.5, 25, 75, 97.5))

    summary = {}
    summary["n"] = len(values)
    summary["mean"] = np.mean(values)
    summary["sd"] = np.std(values, ddof=1)
    summary["median"] = np.median(values)
    summary["min"] = np.min(values)
    summary["max"] = np.max(values)
    summary["IQR"] = p75 - p25
    summary["q2.5%"] = p2_5
    summary["q97.5%"] = p97_5
    return summary

loglike_stats = []
for j in range(loglike_matrix.shape[1]):
    # Named col_stats (not `stats`) so the scipy.stats module imported in other
    # cells is not shadowed if those cells are re-run afterwards.
    col_stats = summarise_array(loglike_matrix[:, j])
    col_stats["Subtypes (S)"] = j + 1
    loglike_stats.append(col_stats)

loglike_df = pd.DataFrame(loglike_stats)

# === 3) Pairwise Wilcoxon signed-rank tests ===
# Paired across folds: each row of loglike_matrix is one CV fold.
pairs = []
for (i, j) in combinations(range(loglike_matrix.shape[1]), 2):
    x, y = loglike_matrix[:, i], loglike_matrix[:, j]
    try:
        stat, p = wilcoxon(x, y)
    except ValueError:  # if identical values → cannot compute
        stat, p = np.nan, np.nan
    pairs.append({
        "Comparison": f"S={i+1} vs S={j+1}",
        "Wilcoxon_stat": stat,
        "p_value": p
    })

pairwise_df = pd.DataFrame(pairs)

# === 4) Write all results to Excel ===
# No explicit engine: pandas picks an installed writer. The setup instructions
# install openpyxl, not xlsxwriter, so engine="xlsxwriter" would fail there.
with pd.ExcelWriter("CVIC_LogLike_stats.xlsx") as xw:
    cvic_df.to_excel(xw, sheet_name="CVIC", index=False)
    loglike_df.to_excel(xw, sheet_name="LogLikelihoods", index=False)
    pairwise_df.to_excel(xw, sheet_name="Pairwise_Wilcoxon", index=False)

print("✅ Wrote CVIC + loglike stats + pairwise tests to CVIC_LogLike_stats.xlsx")

# SANKEYS

In [None]:
import os
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio

# === Load Excel ===
# Set COMBINED_XLSX to the path of the combined subtype/clinical table.
# A bare `#COMBINED.xlsx` placeholder inside the call comments out the closing
# parenthesis and turns the cell into a SyntaxError.
COMBINED_XLSX = "COMBINED.xlsx"
df = pd.read_excel(COMBINED_XLSX)
df.columns = df.columns.str.strip()  # drop stray whitespace around column names

# === Constants ===
# Sankey node columns, in left-to-right order
base_vars = ['ASYN_Subtype', 'Age_at_death', 'Duration', 'BIO_SEX','ApoE',  'OH', 'INFARCT', 'AD_LEVEL','DLB', 'Dementia']
group_var = 'MPP_Subtype'
outdir = "MPP_SUSTAIN_Sankey_Output_Dur15_BW_DLB"
os.makedirs(outdir, exist_ok=True)

# === Clean NAs: every Sankey column becomes text, and missing values become "NA" ===
sankey_columns = base_vars + [group_var]
df[sankey_columns] = df[sankey_columns].astype(str).replace("nan", "NA").fillna("NA")

# Replace a text prefix with an HTML line-broken one (e.g. "ASYN X" → "ASYN<br>X");
# "NA" values are left as they are.
prefix_rules = [
    ("ASYN_Subtype", "ASYN<br>", "ASYN "),
    ("AD_LEVEL", "AD<br>", "AD "),
    ("Duration", "Duration<br>", "Duration "),
]
for column, html_prefix, plain_prefix in prefix_rules:
    df[column] = df[column].apply(
        lambda x, new=html_prefix, old=plain_prefix: x if x == "NA" else new + x.replace(old, "")
    )
df["MPP_Subtype"] = df["MPP_Subtype"].apply(
    lambda x: x if x == "NA" else "MPP<br>Subtype<br>" + x.replace("MPP Subtype", "").strip()
)

# === Label mappings for display (values not listed are kept unchanged)
label_mapping = {
    "Dementia": {"present": "Dementia", "absent": "No dementia"},
    "AD_LEVEL": {"low": "AD Low", "high": "AD High"},
    "OH": {"present": "OH Present", "absent": "OH Absent"},
    "INFARCT": {"present": "Ischaemia Present", "absent": "Ischaemia Absent"},
    "DLB": { "DLB": "DLB","PDD": "PDD","PD": "PD"}
}
for col in label_mapping:
    mapped = df[col].map(label_mapping[col])
    df[col] = mapped.fillna(df[col])

# === Label renaming for final nodes
# Raw cell value -> node label shown in the diagram. make_sankey_data applies it
# both when building nodes and when building links.
# NOTE(review): keys such as "OH"/"No OH" only match if the raw data holds those
# exact strings; values produced by label_mapping ("OH Present", ...) are not
# renamed by this table — confirm which form the spreadsheet uses.
label_renaming = {
    "ApoE4": "APOE ε4",
    "ApoE3": "APOE ε3",
    "AD<br>int/high": "AD<br>Int/High",
    "AD<br>none/low": "AD<br>None/Low",
    "OH": "OH<br>Present",
    "No OH": "OH<br>Absent",
    "Ischaemia": "Ischaemia<br>Present",
    "No Ischaemia": "Ischaemia<br>Absent"
}

# === Color override
# Node label (after renaming) -> greyscale fill colour. make_sankey_data tries the
# full label first, then the text before the first "<br>", then falls back to "#999999".
# NOTE(review): entries such as "AAD <77yrs", "APOE NA" or "Dementia NA" only
# apply if the data contain those exact values — verify against the spreadsheet.
label_color_override = {
    "ASYN<br>Amygdala": "#1a1a1a",
    "ASYN<br>Brainstem1": "#666666",
    "ASYN<br>Brainstem2": "#aaaaaa",
    "ASYN<br>Neocortical": "#eeeeee",
    "AAD <77yrs": "#1a1a1a",
    "AAD ≥77yrs": "#666666",
    "Duration<br><15yrs": "#1a1a1a",
    "Duration<br>≥15yrs": "#666666",
    "APOE ε4": "#1a1a1a",
    "APOE ε3": "#666666",
    "APOE NA": "#eeeeee",
    "AD<br>Int/High": "#1a1a1a", "AD<br>None/Low": "#666666",
    "Dementia": "#1a1a1a", "No dementia": "#666666", "Dementia NA": "#eeeeee",
    "Female": "#666666", "Male": "#1a1a1a",
    "OH<br>Present": "#1a1a1a", "OH<br>Absent": "#666666", "OH NA": "#eeeeee",
    "Ischaemia<br>Present": "#1a1a1a", "Ischaemia<br>Absent": "#666666",
    "DLB": "#1a1a1a","PDD": "#666666", "PD": "#eeeeee"
}

# === Sankey data builder
def make_sankey_data(df_subset):
    """Build the node/link dictionaries for a plotly Sankey from ``df_subset``.

    Nodes are the (renamed) values of every column in ``base_vars``. Links
    connect each pair of adjacent columns, weighted by the number of rows with
    that value pair; rows where either value is exactly "NA" are not linked.
    Node labels get a "<br>xx.x%" suffix giving the value's share of its column.

    Returns
    -------
    dict
        ``{"node": {...}, "link": {...}}``, ready for ``go.Sankey(**result)``.

    NOTE(review):
      - "NA" values still become a node and are included in the percentage
        denominators; only labels ending in " NA" are excluded from the
        counts. Confirm that is intended.
      - A value that occurs in more than one column maps to a single shared
        node, because nodes are keyed by label alone.
    """
    df_plot = df_subset.copy()
    ordered_vars = base_vars

    label_lookup, raw_source, raw_target, raw_value = {}, [], [], []

    # Count for percentages: per column, value -> count (values ending in " NA" excluded)
    label_counts_per_var = {}
    for col in ordered_vars:
        col_series = df_plot[col]
        clean_series = col_series[~col_series.str.endswith(" NA", na=False)]
        label_counts_per_var[col] = clean_series.value_counts().to_dict()

    # Create unique label list with renaming (node index = order of first appearance)
    for col in ordered_vars:
        for val in df_plot[col].unique():
            renamed = label_renaming.get(val, val)
            if renamed not in label_lookup:
                label_lookup[renamed] = len(label_lookup)

    # Create connections with renaming: one link per (value in column k, value in column k+1) pair
    for i in range(len(ordered_vars) - 1):
        col_from, col_to = ordered_vars[i], ordered_vars[i + 1]
        df_pair = df_plot[[col_from, col_to]]
        df_pair = df_pair[(df_pair[col_from] != "NA") & (df_pair[col_to] != "NA")]
        grouped = df_pair.groupby([col_from, col_to]).size().reset_index(name="count")
        for _, row in grouped.iterrows():
            src = label_renaming.get(row[col_from], row[col_from])
            tgt = label_renaming.get(row[col_to], row[col_to])
            raw_source.append(label_lookup[src])
            raw_target.append(label_lookup[tgt])
            raw_value.append(row['count'])

    # Format node labels with %
    # The percentage comes from the first column (in ordered_vars order) whose counts contain the label.
    label_list = list(label_lookup.keys())
    label_list_with_pct = []
    for label in label_list:
        handled = False
        for col in ordered_vars:
            col_counts = label_counts_per_var.get(col, {})
            # Inverse-rename the label to match raw data keys
            inv_label = next((orig for orig, renamed in label_renaming.items() if renamed == label), label)
            if inv_label in col_counts:
                count = col_counts[inv_label]
                total = sum(col_counts.values())
                pct = (count / total * 100) if total > 0 else 0
                label_list_with_pct.append(f"{label}<br>{pct:.1f}%")
                handled = True
                break
            elif label.endswith(" NA"):
                # "... NA" labels are shown without a percentage
                label_list_with_pct.append(label)
                handled = True
                break
        if not handled:
            label_list_with_pct.append(label)

    # Assign node colors: exact label, else the text before the first "<br>", else grey
    node_colors = []
    for label in label_list:
        raw_label = label.split("<br>")[0] if "<br>" in label else label
        node_colors.append(label_color_override.get(label, label_color_override.get(raw_label, "#999999")))

    return dict(
        node=dict(label=label_list_with_pct, pad=20, thickness=20, color=node_colors),
        link=dict(source=raw_source, target=raw_target, value=raw_value)
    )

# === Plotter
def plot_sankey(df_subset, title, filename):
    """Render a Sankey diagram of ``df_subset``, write it to ``outdir`` and show it.

    Always writes ``<filename>.html``. It also attempts ``<filename>.pdf`` via
    plotly's static image export; a failure there is reported, not raised.
    """
    figure = go.Figure(go.Sankey(**make_sankey_data(df_subset)))
    layout_options = dict(
        font=dict(size=42, color="black"),
        width=3000,
        height=1000,
        margin=dict(t=60, l=30, r=30, b=30),
        paper_bgcolor="white",
        title=dict(text=title, font_size=24),
    )
    figure.update_layout(**layout_options)

    output_stem = os.path.join(outdir, filename)
    figure.write_html(output_stem + ".html")
    try:
        pio.write_image(figure, output_stem + ".pdf", format="pdf", scale=2)
    except Exception as e:
        # PDF export is best-effort; the HTML file has already been written.
        print(f"[⚠️] PDF export failed for {filename}: {e}")
    figure.show()

# === Run plots
# === Run plots: one diagram for the whole cohort, then one per MPP subtype
plot_sankey(df, "", "sankey_ALL_MPP")

# NOTE(review): missing subtypes were recoded to the string "NA", which dropna()
# keeps, so an "NA" group is plotted too — confirm that is wanted.
for subtype in sorted(df[group_var].dropna().unique()):
    df_sub = df[df[group_var] == subtype]
    if not df_sub.empty:
        # Subtype labels contain HTML line breaks ("MPP<br>Subtype<br>1"); replace
        # them so the file name has no '<' or '>' (invalid on Windows).
        safe_name = str(subtype).replace("<br>", "_").replace(" ", "_").replace("/", "_")
        plot_sankey(df_sub, "", f"sankey_MPP_{safe_name}")


import itertools
import numpy as np
import pandas as pd
from scipy.stats import chi2_contingency, fisher_exact

# For every Sankey node variable and every pair of MPP subtypes, test whether
# the variable's category distribution differs between the two subtypes:
#   - 2x2 tables: Fisher's exact test if any cell is < 5, otherwise chi-square;
#   - larger tables: chi-square.
# P-values are Bonferroni-corrected over all tests actually run.

node_vars = ['ASYN_Subtype', 'Age_at_death', 'Duration', 'BIO_SEX','ApoE',  'OH', 'INFARCT', 'AD_LEVEL','DLB', 'Dementia']
group_var = 'MPP_Subtype'

# Missing values were recoded to the string "NA" during the Sankey clean-up, so
# dropna() does not remove them. Exclude them explicitly: "NA" is neither a
# subtype to compare nor a category to test (the Sankey links skip it too).
MISSING_LABEL = "NA"
subtypes = [st for st in df[group_var].dropna().unique() if st != MISSING_LABEL]

pairwise_results = []

for var in node_vars:
    has_value = df[var] != MISSING_LABEL
    for st1, st2 in itertools.combinations(subtypes, 2):
        sub_df = df[df[group_var].isin([st1, st2]) & has_value]
        ct = pd.crosstab(sub_df[var], sub_df[group_var])

        # Need both subtypes present and at least two categories to test anything
        if ct.shape[1] == 2 and ct.shape[0] >= 2:
            if ct.shape == (2, 2):
                # 2x2 table → Fisher if any cell < 5, else Chi-square
                if (ct.values < 5).any():
                    _, p = fisher_exact(ct)
                    method = "Fisher’s exact (2×2, small counts)"
                else:
                    _, p, _, _ = chi2_contingency(ct)
                    method = "Chi-square (2×2)"
            else:
                # >2 categories vs 2 subtypes → Chi-square
                _, p, _, _ = chi2_contingency(ct)
                method = "Chi-square (>2×2)"
            pairwise_results.append((var, st1, st2, p, method))

pairwise_df = pd.DataFrame(
    pairwise_results,
    columns=["Node", "Subtype1", "Subtype2", "P_value", "Method"]
)

# Bonferroni correction across all pairwise tests
pairwise_df["P_adj"] = (pairwise_df["P_value"] * len(pairwise_df)).clip(upper=1.0)

# Significance flags
pairwise_df["p-value significant (p≤0.05)"] = np.where(pairwise_df["P_value"] <= 0.05, "Yes", "No")
pairwise_df["adj p-value significant (p≤0.05)"] = np.where(pairwise_df["P_adj"] <= 0.05, "Yes", "No")

# (Optional) tidy rounding for readability
pairwise_df["P_value"] = pairwise_df["P_value"].round(6)
pairwise_df["P_adj"]   = pairwise_df["P_adj"].round(6)

print(pairwise_df.sort_values(["Node", "P_adj"]))

# Save to CSV and Excel
pairwise_df.to_csv("pairwise_node_subtype_comparisons.csv", index=False, float_format="%.10g")
pairwise_df.to_excel("pairwise_node_subtype_comparisons.xlsx", index=False)

print("✅ Saved: 'pairwise_node_subtype_comparisons.csv' and '.xlsx'")

In [6]:
# PATHOLOGY BURDEN

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stats
import scikit_posthocs as sp
from statannotations.Annotator import Annotator

# === Setup ===
# Violin + strip plots of each variable per MPP subtype. Each figure gets a
# Kruskal-Wallis omnibus test, and the BH-adjusted Dunn pairs with p < 0.05
# are annotated on it. Figures are saved to `outdir` as PNG, PDF and SVG.
outdir = "MPP_ViolinPlots"
os.makedirs(outdir, exist_ok=True)

# Load your data
# NOTE(review): the original passed the file only as a comment (`#COMBINED.xlsx`),
# so read_excel() received no path and raised TypeError. Set this to the real
# (possibly absolute) path of the workbook.
input_path = "COMBINED.xlsx"
zdata = pd.read_excel(input_path)

# Variables and labels
variables_to_plot = [
    'ASYN_PATH', 'AT8_PATH', 'ABETA_PATH',
    'Age_at_death_Continuous', 'Disease_duration_CONTINUOUS'
]

y_labels = {
    'ASYN_PATH': r'LB/mm$^2$',
    'AT8_PATH': r'%pTau',
    'ABETA_PATH': r'%Aβ',
    'Age_at_death_Continuous': r'Age at death (years)',
    'Disease_duration_CONTINUOUS': r'Disease duration (years)'
}

group_var = 'MPP_Subtype'

# Map old subtype labels → new ones
subtype_label_map = {
    "MPP Subtype 1": "Subtype 1",
    "MPP Subtype 2": "Subtype 2",
    "MPP Subtype 3": "Subtype 3",
    "MPP Subtype 4": "Subtype 4"
}

# Apply new labels. `.replace` (unlike `.map`) leaves labels that are not keys
# of the map unchanged instead of silently turning them into NaN.
zdata[group_var] = zdata[group_var].replace(subtype_label_map)

# Updated order
subtype_order = ["Subtype 1", "Subtype 2", "Subtype 3", "Subtype 4"]

# One dict per variable with the Kruskal-Wallis statistic and p-value
stats_summary = []

# Plot loop
for var in variables_to_plot:
    plt.figure(figsize=(8, 6))

    # Violin + strip; hue_order pins each subtype to the same Set2 colour
    # in every figure.
    ax = sns.violinplot(
        data=zdata, x=group_var, y=var, hue=group_var,
        order=subtype_order, hue_order=subtype_order,
        palette='Set2', legend=False, inner=None
    )
    sns.stripplot(
        data=zdata, x=group_var, y=var,
        order=subtype_order,
        color='black', size=3, jitter=0.2, ax=ax
    )

    ax.set_ylabel(y_labels[var], fontsize=20)
    ax.set_xlabel("MPP Subtype", fontsize=20)
    ax.tick_params(axis='x', labelsize=20)
    ax.tick_params(axis='y', labelsize=20)

    # Test data: only the plotted subtypes and only rows with a value for
    # `var`. The omnibus and post-hoc tests then use the same observations.
    test_df = zdata.loc[zdata[group_var].isin(subtype_order), [group_var, var]].dropna()
    groups = [grp[var].values for _, grp in test_df.groupby(group_var)]
    groups = [g for g in groups if len(g) > 0]

    pairs = []
    pvals = []
    if len(groups) >= 2:
        # Kruskal-Wallis
        kw_stat, kw_p = stats.kruskal(*groups)

        # Dunn test (BH-adjusted)
        dunn = sp.posthoc_dunn(test_df, val_col=var, group_col=group_var, p_adjust='fdr_bh')

        # Significant pairs; `i < j` keeps each unordered pair once
        pairs = [(i, j) for i in dunn.index for j in dunn.columns if i < j and dunn.loc[i, j] < 0.05]
        pvals = [dunn.loc[i, j] for i, j in pairs]
    else:
        # Kruskal-Wallis needs at least two non-empty groups
        kw_stat = kw_p = float('nan')

    stats_summary.append({
        'Variable': var,
        'Test': 'Kruskal-Wallis',
        'Statistic': kw_stat,
        'p-value': kw_p
    })

    # Annotate significant pairs with precomputed (Dunn) p-values
    if pairs:
        annotator = Annotator(ax, pairs, data=zdata, x=group_var, y=var, order=subtype_order)
        annotator.configure(
            test=None,
            verbose=False,
            loc="outside",
            text_format="star",
            line_offset_to_group=0.2,
            fontsize=20
        )
        annotator.set_pvalues(pvals)
        annotator.annotate()

    sns.despine(top=True, right=True)
    plt.tight_layout()

    for ext in ['png', 'pdf', 'svg']:
        plt.savefig(os.path.join(outdir, f"violin_{var}.{ext}"), dpi=300)
    plt.show()
    plt.close()