# Clinical Data Analysis for CJD Subtypes
========================================================

This notebook performs comprehensive statistical analysis and visualisation of clinical data for different 
Sporadic Creutzfeldt-Jakob Disease (sCJD) subtypes [MM(V)1, VV2, MV2K] compared to controls (CTRL).

Analysis Components:
------------------
1. Data Preprocessing:
   - Imports clinical data from Excel file
   - Imports and normalises Olink data and checks for missing values
   - Formats columns appropriately (Int64 for categorical variables)
   - Performs data type verification and unique value checks

2. Statistical Analysis:
   - Normality testing for continuous variables using:
     * Shapiro-Wilk test
     * Kolmogorov-Smirnov test
   - Demographics table generation including:
     * Basic demographics (age, sex)
     * Disease timeline metrics (onset-LP, onset-death, LP-death)
     * Clinical symptoms and diagnostic markers
     * Statistical comparisons between groups using appropriate tests:
       - ANOVA for normal distributions
       - Kruskal-Wallis for non-normal distributions
       - Chi-square/Fisher's exact for categorical variables

3. Visualizations:
   - Kaplan-Meier survival analysis:
     * Survival curves by CJD subtype
     * Log-rank tests for between-group comparisons
   - Clinical variable distribution:
     * 3x4 grid of countplots showing presence/absence of clinical features
     * Includes all major symptoms and diagnostic markers
     * Stratified by CJD subtype

Outputs:
-------
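- Comprehensive demographics table with statistical comparisons (demographics_summary.csv)
- Disease-stage comparison across sCJD subtypes with Kruskal-Wallis and Dunn post-hoc tests (disease_stage_kruskal_dunn.xlsx)
- Survival analysis plot (kaplan_meier_cjd_subtype.png)
- Clinical variables distribution plot (clinical_features_cjd_subtype.png)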

In [None]:
# Standard library imports
import os
import warnings

# Third-party imports
import numpy as np
import pandas as pd

# Statistical analysis
from scipy import stats
from scipy.stats import (
    chi2_contingency,
    f_oneway,
    fisher_exact,
    kruskal,
    mannwhitneyu,
    kstest,
    norm
)
import scikit_posthocs as sp

# Visualization
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import seaborn as sns

# Survival analysis
from lifelines import KaplanMeierFitter
from lifelines.statistics import logrank_test

# Jupyter specific
%matplotlib inline

# Suppress warnings
warnings.filterwarnings("ignore")

In [2]:
# Project directories, resolved relative to the parent of the notebook's working directory
data_path = f"{os.path.dirname(os.getcwd())}/data"
figure_path = f"{os.path.dirname(os.getcwd())}/figures/demographics"

### Exploring clinical variables

In [3]:
# Load the raw clinical spreadsheet into the working frame used by the rest of the notebook
df = pd.read_excel(f"{data_path}/raw/clinical_data.xlsx")

In [4]:
# Store the clinical flags (and the listed numeric measurements) as pandas nullable
# integers ("Int64") so missing entries stay <NA> instead of forcing a float dtype.
# NOTE(review): 'age at LP', '14-3-3 ELISA', 't-tau' and 'NfL' are continuous
# measurements; the cast fails if any of them holds a non-integer value -- confirm
# they are integer-valued in the source sheet.
integer_columns = [
    'age at LP', 'Sex', 'psychiatric', 'dementia', 'myoclonus', 'pyramidal',
    'extrapyramidal', 'cerebellar', 'visual', 'akinetic mutism', 'PSWC',
    'positive 14-3-3 WB', 'increased t-tau', 'positive MRI', '14-3-3 ELISA',
    't-tau', 'NfL',
]
df[integer_columns] = df[integer_columns].astype("Int64")

# The CSF ISS code is kept as a generic object column
df[['CSF ISS code']] = df[['CSF ISS code']].astype("object")

In [None]:
# Sanity check: frequency table of every categorical column (value_counts excludes missing values)
categorical_columns = [
    'group', 'CJD subtype', 'subtype', 'Sex', 'diagnosis', 'Codon 129',
    'psychiatric', 'dementia', 'myoclonus', 'pyramidal', 'extrapyramidal',
    'cerebellar', 'visual', 'akinetic mutism', 'PSWC', 'positive 14-3-3 WB',
    'increased t-tau', 'positive MRI',
]
for column in categorical_columns:
    print(f"\nUnique values in {column}:")
    print(df[column].value_counts())

### Check the normality of numeric variables with both Shapiro-Wilk and Kolmogorov-Smirnov tests
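Shapiro-Wilk is generally more powerful for smaller sample sizes. Note that the Kolmogorov-Smirnov test here uses a normal distribution whose mean and SD are estimated from the same data. This makes its p-values conservative (biased towards "normal"); the Lilliefors correction addresses this.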

In [None]:
def test_normality(df, variables):
    results = {}
    for var in variables:
        # Drop NaN values
        data = df[var].dropna()
        
        # Perform Shapiro-Wilk test
        shapiro_stat, shapiro_p = stats.shapiro(data)
        
        # Perform Kolmogorov-Smirnov test
        mean = data.mean()
        std = data.std()
        ks_stat, ks_p = stats.kstest(data, 'norm', args=(mean, std))
        
        # Store results
        results[var] = {
            'Shapiro-Wilk': {'statistic': shapiro_stat, 'p-value': shapiro_p},
            'Kolmogorov-Smirnov': {'statistic': ks_stat, 'p-value': ks_p}
        }
    
    # Print results
    for var, tests in results.items():
        print(f"\nVariable: {var}")
        print("-" * 50)
        
        # Print Shapiro-Wilk results
        sw_p = tests['Shapiro-Wilk']['p-value']
        print(f"Shapiro-Wilk Test:")
        print(f"Statistic: {tests['Shapiro-Wilk']['statistic']:.4f}")
        print(f"P-value: {sw_p:.4f}")
        print(f"Interpretation: {'Normal' if sw_p > 0.05 else 'Non-normal'} distribution")
        
        # Print Kolmogorov-Smirnov results
        ks_p = tests['Kolmogorov-Smirnov']['p-value']
        print(f"\nKolmogorov-Smirnov Test:")
        print(f"Statistic: {tests['Kolmogorov-Smirnov']['statistic']:.4f}")
        print(f"P-value: {ks_p:.4f}")
        print(f"Interpretation: {'Normal' if ks_p > 0.05 else 'Non-normal'} distribution")
        
        # Print final interpretation if tests disagree
        if (sw_p > 0.05) != (ks_p > 0.05):
            print("\nNote: Tests disagree on normality.")
            print("Consider checking Q-Q plots and histograms for visual inspection.")
        
        print("-" * 50)
    
    return results

# Continuous variables whose distribution decides between parametric (ANOVA)
# and rank-based (Kruskal-Wallis) comparisons in the demographics table
variables_to_check = ['age at LP', 'onset-LP', 'onset-death', 'LP-death']
normality_results = test_normality(df, variables_to_check)

## Summary statistics

Here we create a comprehensive demographics table comparing groups of patients, focusing on sporadic CJD (Creutzfeldt-Jakob disease) subtypes. It works as follows:

1. Main Function Purpose:
- Creates a statistical comparison table between different CJD subtypes (MM(V)1, VV2, MV2K) and a control group (CTRL)
- Handles both categorical and continuous variables
- Calculates appropriate statistical tests based on data type

2. Key Components:
Helper Functions:
- `get_categorical_summary`: Calculates proportions and percentages for categorical variables
- `get_continuous_summary_nd`: Handles normally distributed data (mean ± SD)
- `get_continuous_summary_nnd`: Handles non-normally distributed data (median and IQR)
- `get_available_n`: Tracks available data points vs. total sample size

Statistical Testing (`calculate_p_value`):
- Implements three types of tests:
  - ANOVA for normally distributed continuous data
  - Kruskal-Wallis for non-normally distributed continuous data
  - Chi-square/Fisher's exact for categorical data
- Automatically switches between the chi-square test and pairwise Fisher's exact tests based on expected frequencies (chi-square is used when at least 80% of cells have an expected count ≥ 5)

In [7]:
def create_demographics_table(df):
    """Build the demographics / clinical-characteristics table for sCJD subtypes vs. controls.

    Patients are grouped by 'CJD subtype' when 'group' == 'CJD' and by 'group'
    otherwise (controls appear as 'CTRL'). One output row per parameter:

    * N per group.
    * Female, n/N (%) -- 'Sex' == 1 is counted as female; compared across all
      four groups with the chi-square / Fisher rule below.
    * Age at LP, mean ± SD [n] -- one-way ANOVA across all four groups.
    * Onset-to-LP, disease duration (onset-death) and LP-to-death,
      median (IQR) [n] -- Kruskal-Wallis across the three sCJD subtypes;
      CTRL is shown as "NA".
    * Each clinical / diagnostic variable, n/N (%) -- sCJD subtypes only. Values
      are summed as counts, so they are assumed to be coded 0/1 (TODO confirm).

    "[n]" is the number of non-missing values, shown as "available/total" when
    some values are missing. Categorical p-values come from a chi-square test
    when at least 80% of expected cell counts are >= 5. Otherwise they come from
    pairwise Fisher's exact tests: the smallest pairwise p-value is reported,
    Bonferroni-adjusted for the number of pairs and capped at 1. p-values are
    rounded to 3 decimals, shown as "<0.001" below that, and as "NA" when a
    test cannot be computed.

    Parameters
    ----------
    df : pandas.DataFrame
        Clinical data with at least 'group', 'CJD subtype', 'Sex', 'age at LP',
        'onset-LP', 'onset-death', 'LP-death' and the clinical variables listed
        in ``categorical_vars``. The caller's frame is not modified.

    Returns
    -------
    pandas.DataFrame
        Columns 'Parameter', 'MM(V)1', 'VV2', 'MV2K', 'CTRL', 'p-value'.
    """
    # Work on a copy so the caller's frame does not gain the helper column
    df = df.copy()

    # sCJD patients are grouped by molecular subtype, everyone else by 'group'
    df['analysis_group'] = df['CJD subtype'].where(df['group'] == 'CJD', df['group'])

    # Column order of the output table (CTRL last, so subgroups_no_ctrl is a prefix)
    subgroups = ['MM(V)1', 'VV2', 'MV2K', 'CTRL']
    subgroups_no_ctrl = ['MM(V)1', 'VV2', 'MV2K']

    def column_by_group(column, groups):
        """One Series per group; missing values are kept so n/N can be reported."""
        return [df.loc[df['analysis_group'] == group, column] for group in groups]

    def as_test_input(series_list):
        """Drop missing values and cast to float (columns may be nullable Int64)."""
        return [series.dropna().astype(float) for series in series_list]

    def get_categorical_summary(data):
        """'count/n_valid (pct)' for a 0/1 variable, ignoring missing values."""
        n_valid = len(data.dropna())
        if n_valid == 0:
            return "NA"
        count = data.dropna().sum()
        percentage = (count / n_valid) * 100
        return f"{count}/{n_valid} ({percentage:.1f})"

    def get_continuous_summary_nd(data):
        """'mean ± SD' (for roughly normal variables), ignoring missing values."""
        if len(data.dropna()) == 0:
            return "NA"
        return f"{data.mean():.1f} ± {data.std():.1f}"

    def get_continuous_summary_nnd(data):
        """'median (Q1 - Q3)' (for non-normal variables), ignoring missing values."""
        if len(data.dropna()) == 0:
            return "NA"
        median = data.median()
        q25 = data.quantile(0.25)
        q75 = data.quantile(0.75)
        return f"{median:.1f} ({q25:.1f} - {q75:.1f})"

    def get_available_n(data):
        """'n' when nothing is missing, otherwise 'n_available/n_total'."""
        n_total = len(data)
        n_available = len(data.dropna())
        return str(n_available) if n_total == n_available else f"{n_available}/{n_total}"

    def format_p_value(p_value):
        """Table formatting: 'NA' for NaN, '<0.001', else rounded to 3 decimals."""
        if np.isnan(p_value):
            return "NA"
        if p_value < 0.001:
            return "<0.001"
        return round(p_value, 3)

    def pairwise_fisher_p(contingency_rows):
        """Bonferroni-adjusted smallest p-value of all pairwise 2x2 Fisher tests."""
        p_values = []
        for i in range(len(contingency_rows)):
            for j in range(i + 1, len(contingency_rows)):
                try:
                    _, p = fisher_exact(np.array([contingency_rows[i], contingency_rows[j]]))
                    p_values.append(p)
                except Exception:
                    p_values.append(np.nan)
        valid = [p for p in p_values if not np.isnan(p)]
        if not valid:
            return np.nan
        # The minimum of k unadjusted pairwise p-values is not a valid group-level
        # p-value; multiplying by k (Bonferroni) controls the family-wise error rate.
        return min(1.0, min(valid) * len(p_values))

    def calculate_p_value(data_list, test_type):
        """Group-level p-value, formatted for the table ("NA" on failure).

        test_type: 'anova' / 'kruskal' take a list of numeric samples;
        'chi2_fisher' takes a list of [count_yes, count_no] rows, one per group.
        """
        try:
            if test_type == "anova":
                _, p_value = f_oneway(*data_list)
            elif test_type == "kruskal":
                _, p_value = kruskal(*data_list)
            elif test_type == "chi2_fisher":
                contingency_table = np.array(data_list)
                _, chi2_p, _, expected = chi2_contingency(contingency_table)
                # Chi-square is trusted when >= 80% of cells have expected count >= 5
                if (expected >= 5).mean() >= 0.8:
                    p_value = chi2_p
                else:
                    p_value = pairwise_fisher_p(data_list)
            else:
                raise ValueError(f"Unknown test type: {test_type}")
            return format_p_value(p_value)
        except Exception as e:
            print(f"Error in calculate_p_value: {str(e)}")  # For debugging
            return "NA"

    def continuous_row(label, column, groups, summarise, test_type):
        """Summary cell per group plus p-value; groups not analysed get 'NA'."""
        series_list = column_by_group(column, groups)
        # n/N is computed on the un-dropped series so missing data is visible
        cells = [f"{summarise(series)} [n={get_available_n(series)}]" for series in series_list]
        cells += ["NA"] * (len(subgroups) - len(groups))
        return [label, *cells, calculate_p_value(as_test_input(series_list), test_type)]

    rows = []

    # Sample size
    rows.append(["N", *[str(len(df[df['analysis_group'] == group])) for group in subgroups], "NA"])

    # Sex ('Sex' == 1 counted as female), all four groups
    female_stats = []
    sex_contingency = []
    for sex in column_by_group('Sex', subgroups):
        females = (sex == 1).sum()
        n_valid = len(sex.dropna())
        if n_valid == 0:
            female_stats.append("NA")
        else:
            percentage = (females / n_valid) * 100
            female_stats.append(f"{females}/{n_valid} ({percentage:.1f})")
        sex_contingency.append([females, n_valid - females])
    rows.append(["Female, n/N (%)", *female_stats, calculate_p_value(sex_contingency, "chi2_fisher")])

    # Continuous variables
    rows.append(continuous_row("Age at LP, mean ± SD [n]", 'age at LP',
                               subgroups, get_continuous_summary_nd, "anova"))
    rows.append(continuous_row("Time onset to LP (months), median (IQR) [n]", 'onset-LP',
                               subgroups_no_ctrl, get_continuous_summary_nnd, "kruskal"))
    rows.append(continuous_row("Disease duration (months), median (IQR) [n]", 'onset-death',
                               subgroups_no_ctrl, get_continuous_summary_nnd, "kruskal"))
    rows.append(continuous_row("Time LP to death (months), median (IQR) [n]", 'LP-death',
                               subgroups_no_ctrl, get_continuous_summary_nnd, "kruskal"))

    # Clinical / diagnostic variables, sCJD subtypes only
    categorical_vars = [
        'psychiatric', 'dementia', 'myoclonus', 'pyramidal', 'extrapyramidal',
        'cerebellar', 'visual', 'akinetic mutism', 'PSWC', 'positive 14-3-3 WB',
        'increased t-tau', 'positive MRI'
    ]

    for var in categorical_vars:
        categorical_stats = []
        contingency_table = []
        for values in column_by_group(var, subgroups_no_ctrl):
            categorical_stats.append(get_categorical_summary(values))
            n_valid = len(values.dropna())
            if n_valid > 0:
                present = values.sum()
                contingency_table.append([present, n_valid - present])
        categorical_stats.append("NA")  # CTRL not analysed
        # Upper-case only the first character: str.capitalize() would also
        # lower-case acronyms ('PSWC' -> 'Pswc', 'positive MRI' -> 'Positive mri')
        label = var[:1].upper() + var[1:]
        rows.append([f"{label}, n/N (%)", *categorical_stats,
                     calculate_p_value(contingency_table, "chi2_fisher")])

    return pd.DataFrame(rows, columns=['Parameter', *subgroups, 'p-value'])

In [None]:
demographics_table = create_demographics_table(df)

# Persist the table. The RangeIndex carries no information, so it is not written
# (otherwise the CSV gains an unnamed leading column of row numbers).
results_dir = data_path + '/results'
os.makedirs(results_dir, exist_ok=True)
demographics_table.to_csv(results_dir + '/demographics_summary.csv', index=False)
demographics_table

### Evaluate disease stage across sCJD subtypes

In [None]:
# Disease stage at LP: fraction of the onset-to-death course that had elapsed at the LP.
# A zero onset-death would give +/-inf, which is treated as missing so it cannot
# distort the rank-based tests.
df["disease_stage"] = (df["onset-LP"] / df["onset-death"]).replace([np.inf, -np.inf], np.nan)

# Complete cases only, shared by the omnibus test, the post-hoc test and the summary, so
# groups without any valid stage (e.g. controls) do not enter Kruskal-Wallis as empty samples.
# NOTE(review): 'SubGroup' is not used elsewhere in this notebook (which groups by
# 'CJD subtype') -- confirm it is the intended grouping column.
stage_data = df[["SubGroup", "disease_stage"]].dropna()

groups_unique = stage_data["SubGroup"].unique()
groups_data = [stage_data.loc[stage_data["SubGroup"] == g, "disease_stage"] for g in groups_unique]

# Kruskal-Wallis test
kruskal_stat, kruskal_p = kruskal(*groups_data)
print(f"Kruskal-Wallis test: H = {kruskal_stat:.4f}, p = {kruskal_p:.4e}")

# Post-hoc Dunn test (Bonferroni-adjusted) on the same rows as the Kruskal-Wallis test
dunn_res = sp.posthoc_dunn(stage_data, val_col="disease_stage", group_col="SubGroup", p_adjust="bonferroni")

# Median (IQR) by subtype
summary_stage = (
    stage_data.groupby("SubGroup")["disease_stage"]
    .agg(
        Median="median",
        Q1=lambda x: x.quantile(0.25),
        Q3=lambda x: x.quantile(0.75),
    )
    .reset_index()
)
summary_stage["IQR"] = summary_stage["Q1"].round(3).astype(str) + " - " + summary_stage["Q3"].round(3).astype(str)
summary_stage = summary_stage.drop(columns=["Q1", "Q3"])

# Save everything to one workbook under the notebook's results directory
output_dir = os.path.join(data_path, "results")
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, "disease_stage_kruskal_dunn.xlsx")

with pd.ExcelWriter(output_path, engine="openpyxl") as writer:
    summary_stage.to_excel(writer, sheet_name="Median_IQR_by_SubGroup", index=False)
    pd.DataFrame({
        "H_statistic": [kruskal_stat],
        "p_value": [kruskal_p]
    }).to_excel(writer, sheet_name="Kruskal_Wallis", index=False)
    dunn_res.to_excel(writer, sheet_name="Dunn_posthoc")

### Survival analyses among CJD subtypes

In [None]:
def plot_km_curves(df, figure_path):
    """Plot Kaplan-Meier curves of time from LP to death for each sCJD subtype.

    Pairwise log-rank p-values (not adjusted for multiple comparisons) are
    written onto the axes. The figure is saved as
    ``<figure_path>/kaplan_meier_cjd_subtype.png`` and then closed. It is
    therefore only rendered where the caller displays the returned figure.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'CJD subtype' and 'LP-death' (months). Not modified.
    figure_path : str
        Output directory; created if it does not exist.

    Returns
    -------
    matplotlib.figure.Figure
        The Kaplan-Meier figure.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    kmf = KaplanMeierFitter()

    # Colour per sCJD subtype
    colors = {
        'MV2K': '#2ecc71',    # green
        'VV2': '#e74c3c',     # red
        'MM(V)1': '#3498db'   # blue
    }

    # Keep rows with a positive LP-to-death time and mark every row as an event
    # (no censoring). NOTE(review): this assumes all included patients died -- confirm.
    # A filtered copy is used so the caller's frame does not gain an 'event' column.
    surv = df.loc[df['LP-death'] > 0].assign(event=1)
    subtypes = list(surv['CJD subtype'].dropna().unique())

    # One KM curve (with confidence band) per subtype
    for subtype in subtypes:
        mask = surv['CJD subtype'] == subtype
        kmf.fit(
            surv.loc[mask, 'LP-death'],
            surv.loc[mask, 'event'],
            label=f"{subtype} (n={mask.sum()})"
        )
        kmf.plot(ci_show=True, color=colors[subtype], ax=ax)

    # Pairwise log-rank tests, listed top-down in axes coordinates
    y_position = 0.35
    for i in range(len(subtypes)):
        for j in range(i + 1, len(subtypes)):
            s1 = surv[surv['CJD subtype'] == subtypes[i]]
            s2 = surv[surv['CJD subtype'] == subtypes[j]]
            lr_test = logrank_test(
                s1['LP-death'],
                s2['LP-death'],
                s1['event'],
                s2['event']
            )
            p_text = 'p < 0.001' if lr_test.p_value < 0.001 else f'p = {lr_test.p_value:.3f}'
            ax.text(0.35, y_position,
                    f'{subtypes[i]} vs {subtypes[j]}: {p_text}',
                    transform=ax.transAxes,
                    fontsize=10,
                    bbox=dict(facecolor='white', alpha=0.8, edgecolor='none'))
            y_position -= 0.05

    # Titles, labels, grid
    ax.set_title('Survival by sCJD Subtype', pad=20, fontsize=14)
    ax.set_xlabel('Time (months)', fontsize=12)
    ax.set_ylabel('Survival Probability', fontsize=12)
    ax.grid(True, alpha=0.3)

    # Legend anchored inside the axes, towards the upper-right (axes coordinates)
    ax.legend(bbox_to_anchor=(0.7, 0.96), loc='upper left', frameon=True, framealpha=0.8)
    fig.subplots_adjust(right=0.85)

    # Save, then close so the figure is not auto-rendered twice in the notebook
    os.makedirs(figure_path, exist_ok=True)
    fig.savefig(figure_path + '/kaplan_meier_cjd_subtype.png', dpi=300, bbox_inches='tight')
    plt.close(fig)

    return fig

# Build the Kaplan-Meier figure (the function also saves it under figure_path) and display it
fig = plot_km_curves(df=df, figure_path=figure_path)
fig

### Plot clinical variables

In [None]:
# Figure: presence/absence counts of each clinical feature per sCJD subtype (3 x 4 panels)
fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(20, 18))
axes = axes.flatten()

# Define colors and variables
palette = ['#8ecae6', '#219ebc']  # Light blue for No, Darker blue for Yes
clinical_variables = ['psychiatric', 'dementia', 'myoclonus', 'pyramidal',
                      'extrapyramidal', 'cerebellar', 'visual', 'akinetic mutism',
                      'PSWC', 'positive 14-3-3 WB', 'increased t-tau', 'positive MRI']

var_names = {
    'psychiatric': 'Psychiatric Symptoms', 'dementia': 'Dementia',
    'myoclonus': 'Myoclonus', 'pyramidal': 'Pyramidal Signs',
    'extrapyramidal': 'Extrapyramidal Signs', 'cerebellar': 'Cerebellar Signs',
    'visual': 'Visual Disturbances', 'akinetic mutism': 'Akinetic Mutism',
    'PSWC': 'PSWC at EEG', 'positive 14-3-3 WB': '14-3-3 WB Positive',
    'increased t-tau': 't-tau Positive', 'positive MRI': 'MRI Positive'
}

# sCJD patients only. Rows without a subtype are dropped explicitly: seaborn and
# crosstab ignore them anyway, but they would otherwise leak into the subtype list.
data_to_plot = df[df['CJD subtype'].notna() & (df['CJD subtype'] != 'CTRL')]
# Fix the x order explicitly so every bar can be mapped back to its subtype
subtype_order = list(data_to_plot['CJD subtype'].unique())

for ax, var in zip(axes, clinical_variables):
    sns.countplot(data=data_to_plot, x='CJD subtype', hue=var, order=subtype_order,
                  ax=ax, palette=palette, legend=False)

    # Denominator: patients of the subtype with a non-missing value for this variable,
    # so a subtype's Absent/Present percentages sum to 100% (as in the n/N of the
    # demographics table).
    n_valid = data_to_plot.groupby('CJD subtype')[var].count()

    # Percentage labels. Categorical x positions are 0, 1, 2, ... in subtype_order, and
    # hue-dodged bars sit within +/-0.5 of them, so rounding the bar centre recovers the
    # subtype (indexing patches by position would break when a bar is missing).
    tallest = 0
    for container in ax.containers:
        for bar in container.patches:
            height = bar.get_height()
            if not np.isfinite(height):
                continue
            tallest = max(tallest, height)
            x_center = bar.get_x() + bar.get_width() / 2.
            subtype = subtype_order[int(round(x_center))]
            total = n_valid.get(subtype, 0)
            if total == 0:
                continue
            percentage = (height / total) * 100
            if percentage >= 5:
                ax.text(x_center, height, f'{percentage:.0f}%',
                        ha='center', va='bottom', fontsize=10)

    # Chi-square p-value on the subtype x level table. Unlike the demographics table,
    # this does not fall back to Fisher's exact test for small expected counts.
    contingency_table = pd.crosstab(data_to_plot['CJD subtype'], data_to_plot[var])
    p_value = chi2_contingency(contingency_table)[1]
    p_value_text = 'p < 0.001' if p_value < 0.001 else f'p = {p_value:.3f}'
    ax.text(0.95, 0.95, p_value_text, transform=ax.transAxes,
            ha='right', va='top', fontsize=10,
            bbox=dict(facecolor='white', alpha=0.8, edgecolor='none'))

    # Style the subplot
    ax.set_title(var_names[var], fontsize=12, pad=10)
    ax.set_ylabel('Number of Cases', fontsize=10)
    ax.set_xlabel('')
    ax.tick_params(axis='x', rotation=0, labelsize=10)
    ax.tick_params(axis='y', labelsize=10)
    ax.yaxis.grid(True, linestyle='--', alpha=0.7)
    ax.set_axisbelow(True)
    # Common y-scale of 48 across panels, extended only if a bar would be clipped
    ax.set_ylim(0, max(48, tallest * 1.1))

# Shared legend and title
handles = [Rectangle((0, 0), 1, 1, color=c) for c in palette]
fig.legend(handles, ['Absent', 'Present'],
           loc='center', bbox_to_anchor=(0.5, 0.05),
           ncol=2, fontsize=12, title='Clinical Feature', frameon=True)

fig.suptitle('Clinical Features Across sCJD Subtypes', fontsize=16, y=0.95)
fig.tight_layout()
fig.subplots_adjust(top=0.9, bottom=0.1)

os.makedirs(figure_path, exist_ok=True)
fig.savefig(figure_path + '/clinical_features_cjd_subtype.png', dpi=1200, bbox_inches='tight')
plt.show()