# In [None]:
# -*- coding: utf-8 -*-
"""
Longitudinal trajectory clustering analysis of behavior_score (with covariate adjustment)
- Fixed number of clusters: 3
- Re-order group labels based on mean behavior_score (low → high = 0 → 2)
- Output wide-format data
- Analyze association between group and dementia (dementia_status)
- Automatically detect CSV delimiter
"""

import os
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from scipy.interpolate import interp1d
from scipy.signal import savgol_filter
from scipy.cluster.hierarchy import dendrogram, linkage

from sklearn.cluster import AgglomerativeClustering
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import silhouette_score

from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder
from sklearn.linear_model import LinearRegression

import statsmodels.api as sm

# =========================
# Paths
# =========================
base_dir = r"./data/F8/"
output_dir = r"./plot/F8/"

# Every figure and CSV below is written into output_dir; neither matplotlib's
# savefig nor pandas' to_csv creates missing directories, so create it up front.
os.makedirs(output_dir, exist_ok=True)

main_csv = os.path.join(base_dir, "trajectory_charls.csv")  # Replace with other cohort data if needed

# Dataset name (used in output filenames)
DATASET_NAME = "CHARLS"

# ★ Score variable name
SCORE_VAR = "behavior_score"

# =========================
# ★ Fixed number of clusters = 3
# =========================
FIXED_K = 3

# =========================
# ★ Function to auto-detect CSV delimiter
# =========================
def read_csv_auto(filepath, name="self_data"):
    """Read a delimited text file, auto-detecting the delimiter from its header.

    The first line is inspected and whichever candidate delimiter (comma,
    tab, semicolon) occurs most often in it is used. Ties are resolved in
    that order (comma first). A header containing none of them falls back
    to comma.

    Args:
        filepath: Path to the delimited file (read as UTF-8).
        name: Human-readable label used only in the progress printout.

    Returns:
        pandas.DataFrame with the parsed contents.
    """
    print(f"\n--- Reading {name}: {os.path.basename(filepath)} ---")

    # Peek at the header line only; the full parse is left to pandas.
    with open(filepath, 'r', encoding='utf-8') as f:
        first_line = f.readline()

    # Majority vote rather than "first delimiter seen". Otherwise a tab- or
    # semicolon-separated file whose column names contain a comma would be
    # mis-detected as comma-separated.
    candidates = [(',', "comma"), ('\t', "tab"), (';', "semicolon")]
    counts = [first_line.count(candidate) for candidate, _ in candidates]
    best = max(range(len(candidates)), key=lambda i: counts[i])  # first max wins ties
    if counts[best] > 0:
        sep, sep_name = candidates[best]
    else:
        sep, sep_name = ',', "default comma"

    print(f"  Detected delimiter: {sep_name}")

    df = pd.read_csv(filepath, sep=sep)
    print(f"  Rows: {len(df)}, Columns: {len(df.columns)}")
    print(f"  Column names: {list(df.columns)}")

    return df

# =========================
# Load data
# =========================
print("=== Loading data ===")

# Main dataset (single file)
df = read_csv_auto(main_csv, "main dataset")

# Check for dementia_status (optional: the dementia analysis is skipped without it)
if 'dementia_status' in df.columns:
    print(f"  ★ dementia_status found. Distribution: {df['dementia_status'].value_counts().to_dict()}")
else:
    print("  ⚠ dementia_status not found in data")

# Check score variable
if SCORE_VAR in df.columns:
    print(f"  ★ {SCORE_VAR} found")
else:
    print(f"  ⚠ {SCORE_VAR} not found! Please check column names")

# id, wave and the score column are required by every later step. Fail here
# with a clear message instead of an opaque KeyError deep in preprocessing.
missing_required = [col for col in ('id', 'wave', SCORE_VAR) if col not in df.columns]
if missing_required:
    raise KeyError(f"Required column(s) missing from {main_csv}: {missing_required}")

# =========================
# Data preprocessing
# =========================
print("\n=== Data preprocessing ===")

# Normalise key types: ids are compared as strings, waves as numbers
# (values that cannot be parsed become NaN).
df = df.assign(
    id=df['id'].astype(str),
    wave=pd.to_numeric(df['wave'], errors='coerce'),
)

# Observed wave range of the cohort (min/max skip NaN waves)
WAVE_MIN, WAVE_MAX = df['wave'].min(), df['wave'].max()
print(f"\n★ Cohort wave range: {WAVE_MIN} ~ {WAVE_MAX}")

# =========================
# Data cleaning
# =========================
print("\n=== Data cleaning ===")

# ★ Covariates used for adjustment: age, sex, bmi
# Numeric columns are coerced when present; unparseable entries become NaN.
numeric_cols = [SCORE_VAR, 'age', 'bmi', 'dementia_status']
for col in numeric_cols:
    if col not in df.columns:
        continue
    df[col] = pd.to_numeric(df[col], errors='coerce')

# Rows lacking an id, a wave or a score cannot contribute to a trajectory.
key_cols = ['id', 'wave', SCORE_VAR]
df = df.dropna(subset=key_cols).copy()
print(f"After removing missing {SCORE_VAR}: {len(df)} rows")

# =========================
# Handle duplicates
# =========================
def mode_or_first(x):
    """Return the most frequent non-null value of the Series ``x``.

    When several values tie, pandas returns the modes sorted, so the smallest
    one is chosen. When ``x`` has no non-null values, its first element is
    returned instead.
    """
    modes = x.mode()
    if modes.empty:
        return x.iloc[0]
    return modes.iloc[0]

# How each column is collapsed when an (id, wave) pair occurs more than once.
# The score is always aggregated; the other columns only if present.
agg_rules = [
    (SCORE_VAR, 'mean'),
    ('age', 'mean'),
    ('sex', mode_or_first),
    ('bmi', 'mean'),
    ('dementia_status', 'max'),
]
agg_dict = {
    col: how
    for col, how in agg_rules
    if col == SCORE_VAR or col in df.columns
}

df = df.groupby(['id', 'wave'], as_index=False).agg(agg_dict)
print(f"After duplicate handling: {len(df)} rows")

# A trajectory needs at least two time points: keep only such individuals.
measurement_counts = df.groupby('id').size()
ids_with_at_least_2 = measurement_counts.loc[lambda counts: counts >= 2].index
df = df.loc[df['id'].isin(ids_with_at_least_2)].copy()
print(f"IDs with ≥2 measurements: {len(ids_with_at_least_2)} individuals")
print(f"Final dataset: {len(df)} rows")

df = df.sort_values(['id', 'wave']).reset_index(drop=True)

# =========================
# Covariate adjustment
# =========================
print("\n=== Covariate adjustment ===")

# ★ Adjust only for age, sex, bmi
adjustment_vars = ['age', 'sex', 'bmi']
available_adj_vars = [v for v in adjustment_vars if v in df.columns]
print(f"Adjustment covariates: {available_adj_vars}")

df_for_regression = df.dropna(subset=available_adj_vars + [SCORE_VAR]).copy()
print(f"After removing covariate missing values: {len(df_for_regression)} rows, "
      f"{df_for_regression['id'].nunique()} individuals")

if available_adj_vars:
    # sex may be categorical: one-hot encode it (a single dummy if binary).
    # age/bmi pass through unchanged.
    transformers = []
    if 'sex' in available_adj_vars:
        transformers.append(('sex_onehot', OneHotEncoder(drop='if_binary', handle_unknown='ignore'), ['sex']))

    ct = ColumnTransformer(transformers=transformers, remainder='passthrough')

    X = ct.fit_transform(df_for_regression[available_adj_vars])
    y = df_for_regression[SCORE_VAR].values

    linreg = LinearRegression()
    linreg.fit(X, y)
    y_pred = linreg.predict(X)

    # The adjusted score is the residual: the part of the score the
    # covariates do not explain (pooled over all rows/waves).
    df_for_regression[f'{SCORE_VAR}_adj'] = y - y_pred

    print(f"Adjustment model R² = {linreg.score(X, y):.4f}")
else:
    # No covariates in the data. ColumnTransformer cannot be fit on zero
    # columns, so reproduce the intercept-only residual directly:
    # score minus its mean.
    print("  ⚠ No adjustment covariates found; using mean-centred score as adjusted score")
    score_values = df_for_regression[SCORE_VAR].values
    df_for_regression[f'{SCORE_VAR}_adj'] = score_values - score_values.mean()

df = df_for_regression.copy()

# =========================
# Interpolate to uniform time grid
# =========================
print("\n=== Trajectory feature extraction ===")

# 20 evenly spaced points spanning the cohort's full wave range; every
# individual's trajectory is resampled onto this common grid.
time_grid = np.linspace(start=WAVE_MIN, stop=WAVE_MAX, num=20)

def interpolate_trajectory(group, score_col, grid=None):
    """Resample one individual's scores onto a common time grid.

    Args:
        group: DataFrame holding one individual's rows. It must contain a
            ``'wave'`` column and ``score_col``.
        score_col: Name of the score column to interpolate.
        grid: Time points at which to evaluate the trajectory. Defaults to the
            module-level ``time_grid``, so existing two-argument calls are
            unchanged.

    Returns:
        1-D numpy array of length ``len(grid)``.
        - With a single observation, that value is repeated (a flat trajectory).
        - Otherwise the values come from linear interpolation. Points outside
          the observed waves are extrapolated linearly.
    """
    if grid is None:
        grid = time_grid
    if len(group) < 2:
        return np.full(len(grid), group[score_col].values[0])
    f = interp1d(group['wave'], group[score_col], kind='linear', fill_value='extrapolate')
    return f(grid)

unique_ids = df['id'].unique()

# Split df by individual once, instead of re-filtering the whole frame for
# every id (which is O(ids × rows)). groupby(sort=False) yields groups in
# order of first appearance, i.e. the same order as df['id'].unique().
rows_by_id = {id_val: rows.sort_values('wave') for id_val, rows in df.groupby('id', sort=False)}

interpolated_original = []
interpolated_adjusted = []
smoothed_trajectories = []
derivative_features = []

for id_val in unique_ids:
    g = rows_by_id[id_val]

    # Raw-score trajectory (used only for plotting later).
    interp_orig = interpolate_trajectory(g, SCORE_VAR)
    interpolated_original.append(interp_orig)

    # Covariate-adjusted trajectory (used for clustering).
    interp_adj = interpolate_trajectory(g, f'{SCORE_VAR}_adj')
    interpolated_adjusted.append(interp_adj)

    # Savitzky-Golay smoothing (window 5, quadratic) needs at least 5 points.
    if len(interp_adj) >= 5:
        smoothed = savgol_filter(interp_adj, window_length=5, polyorder=2)
    else:
        smoothed = interp_adj

    # Slope along the time grid, used as shape features alongside the levels.
    deriv = np.gradient(smoothed, time_grid)
    smoothed_trajectories.append(smoothed)
    derivative_features.append(deriv)

interpolated_original = np.array(interpolated_original)
interpolated_adjusted = np.array(interpolated_adjusted)
X_smoothed = np.array(smoothed_trajectories)
X_deriv = np.array(derivative_features)
X_feat = np.hstack([X_smoothed, X_deriv])

# Standardise every feature column so levels and slopes are comparable.
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_feat)

print(f"Feature matrix: {X_scaled.shape[0]} individuals × {X_scaled.shape[1]} features")

# =========================
# Hierarchical clustering
# =========================
print("\n=== Hierarchical clustering (fixed k=3) ===")

# Ward linkage on the standardised features; used here only for the dendrogram.
linkage_matrix = linkage(X_scaled, method='ward', metric='euclidean')

plt.figure(figsize=(12, 8))
dendrogram(linkage_matrix, truncate_mode='lastp', p=30, leaf_rotation=90)
plt.title(f'Dendrogram - {DATASET_NAME} {SCORE_VAR} Trajectory Clustering', fontsize=14)
plt.xlabel('Sample Index / Cluster')
plt.ylabel('Distance')
plt.tight_layout()
dendro_stem = os.path.join(output_dir, f'dendrogram_{DATASET_NAME}_{SCORE_VAR}_adjusted_k3')
plt.savefig(f'{dendro_stem}.pdf', format='pdf', bbox_inches='tight')
plt.savefig(f'{dendro_stem}.png', format='png', dpi=300, bbox_inches='tight')
plt.close()

# Silhouette scores for a range of k, reported for reference only.
print("\nSilhouette scores for reference:")
k_range = range(2, 7)
sil_scores = []
for k in k_range:
    labels_k = AgglomerativeClustering(n_clusters=k, linkage='ward').fit_predict(X_scaled)
    sil = silhouette_score(X_scaled, labels_k)
    sil_scores.append((k, sil))
    selected_note = " ← selected" if k == FIXED_K else ""
    print(f"  k={k}: Silhouette = {sil:.4f}{selected_note}")

# Final partition with the fixed number of clusters.
hierarchical = AgglomerativeClustering(n_clusters=FIXED_K, linkage='ward')
cluster_labels = hierarchical.fit_predict(X_scaled)

# =========================
# ★ Re-order groups by mean behavior_score
# =========================
print(f"\n=== ★ Re-order groups by mean {SCORE_VAR} ===")

# Row-level (pooled over waves) mean of the raw score for each raw cluster label.
group_means = {
    label: df.loc[df['id'].isin(unique_ids[cluster_labels == label]), SCORE_VAR].mean()
    for label in range(FIXED_K)
}

print("Original group means:")
for label in sorted(group_means):
    print(f"  Original Group {label}: mean {SCORE_VAR} = {group_means[label]:.3f}")

# Rank the raw labels by their mean: the lowest mean becomes new label 0.
sorted_groups = sorted(group_means, key=group_means.get)
old_to_new = {old: rank for rank, old in enumerate(sorted_groups)}

print("\nRemapping:")
for old in sorted(old_to_new):
    print(f"  Original Group {old} → New Group {old_to_new[old]}")

cluster_labels = np.array([old_to_new[label] for label in cluster_labels])

print("\n★ After re-ordering:")
print("  Group 0 = Low level (low protection → high dementia risk)")
print("  Group 1 = Medium level")
print("  Group 2 = High level (high protection → low dementia risk, reference)")

# =========================
# Write group labels back
# =========================
# unique_ids and cluster_labels are aligned element by element.
id_to_group = {uid: label for uid, label in zip(unique_ids, cluster_labels)}
df['group'] = df['id'].map(id_to_group)

cluster_sizes = pd.Series(cluster_labels).value_counts().sort_index()
n_clustered = len(cluster_labels)
print("\nCluster sizes:")
for label, size in cluster_sizes.items():
    print(f"  Group {label}: n = {size} ({size/n_clustered*100:.1f}%)")

# =========================
# Generate wide-format data
# =========================
print("\n=== Generating wide-format data ===")

all_waves = sorted(df['wave'].unique())


def _to_wide(value_col):
    """Pivot ``value_col`` to one row per id with a ``{value_col}_wave{w}`` column per wave."""
    wide = df.pivot(index='id', columns='wave', values=value_col)
    wide.columns = [f'{value_col}_wave{int(w)}' for w in wide.columns]
    return wide.reset_index()


# Raw and covariate-adjusted scores, one column per wave
df_score_wide = _to_wide(SCORE_VAR)
df_score_adj_wide = _to_wide(f'{SCORE_VAR}_adj')

# Baseline data: first non-missing value per column at the earliest wave
df_baseline = df.sort_values(['id', 'wave']).groupby('id').first().reset_index()
baseline_cols = ['id', 'group'] + [v for v in available_adj_vars if v in df_baseline.columns]
df_baseline_selected = df_baseline[baseline_cols].copy()

rename_dict = {v: f'{v}_baseline' for v in available_adj_vars if v in df_baseline_selected.columns}
df_baseline_selected = df_baseline_selected.rename(columns=rename_dict)

# ★ dementia_status (take max per person)
# A person counts as a case if any of their retained rows has dementia_status == max.
print("\n--- Extracting dementia_status ---")
if 'dementia_status' not in df.columns:
    df_dementia = pd.DataFrame({'id': unique_ids})
    print("  ⚠ dementia_status not found in data")
else:
    df_dementia = (
        df.groupby('id')['dementia_status'].max()
        .reset_index()
        .loc[lambda frame: frame['id'].isin(unique_ids)]
    )
    print(f"  dementia_status extracted: {len(df_dementia)} individuals")
    print(f"  Distribution: {df_dementia['dementia_status'].value_counts().to_dict()}")

# Trajectory features: one summary row per individual, in unique_ids order.
# Split df by id once rather than re-scanning it for every individual
# (O(rows) instead of O(ids × rows)).
rows_for_id = dict(tuple(df.groupby('id', sort=False)))

trajectory_features = []
for id_val in unique_ids:
    g = rows_for_id[id_val].sort_values('wave')
    scores = g[SCORE_VAR]
    first_score, last_score = scores.iloc[0], scores.iloc[-1]
    trajectory_features.append({
        'id': id_val,
        'n_measurements': len(g),
        'wave_min': g['wave'].min(),
        'wave_max': g['wave'].max(),
        'follow_up_duration': g['wave'].max() - g['wave'].min(),
        f'{SCORE_VAR}_mean': scores.mean(),
        # pandas sample SD (ddof=1); NaN if only one measurement remains
        f'{SCORE_VAR}_std': scores.std(),
        f'{SCORE_VAR}_first': first_score,
        f'{SCORE_VAR}_last': last_score,
        f'{SCORE_VAR}_change': last_score - first_score,
    })

df_traj_features = pd.DataFrame(trajectory_features)

# Merge all
# One row per individual: baseline covariates + trajectory summaries
# (+ dementia status when present) + raw and adjusted scores per wave.
df_wide = df_baseline_selected.merge(df_traj_features, on='id', how='left')
if 'dementia_status' in df_dementia.columns:
    df_wide = df_wide.merge(df_dementia, on='id', how='left')
for wide_scores in (df_score_wide, df_score_adj_wide):
    df_wide = df_wide.merge(wide_scores, on='id', how='left')

n_wide_rows, n_wide_cols = df_wide.shape
print(f"Wide-format data: {n_wide_rows} rows × {n_wide_cols} columns")
print(f"dementia_status present: {'dementia_status' in df_wide.columns}")

out_wide_csv = os.path.join(output_dir, f'{DATASET_NAME}_{SCORE_VAR}_with_trajectory_group_adjusted_k3.csv')
df_wide.to_csv(out_wide_csv, index=False, encoding='utf-8-sig')
print(f"Saved: {out_wide_csv}")

# =========================
# ★★★ Association between Group and Dementia ★★★
# =========================
print("\n" + "="*60)
print("=== ★ Association between Group and Dementia (dementia_status) ===")
print("="*60)

# Display labels for the dummy-coded groups; Group 2 (highest score) is the reference.
GROUP_DUMMY_LABELS = {
    'group_0': "Group 0 (Low protection)",
    'group_1': "Group 1 (Medium)",
}
REFERENCE_GROUP_LABEL = 'Group 2 (High protection)'


def _sig_stars(pval):
    """Return conventional significance stars for a P-value ('' if not < 0.05 or NaN)."""
    return "***" if pval < 0.001 else "**" if pval < 0.01 else "*" if pval < 0.05 else ""


def _group_odds_ratios(model, model_name, label_width):
    """Print and collect odds ratios for the group dummies of a fitted Logit.

    The CI is the Wald 95% interval, exp(coef ± 1.96·SE).

    Args:
        model: fitted statsmodels Logit result with ``group_0``/``group_1`` terms.
        model_name: value stored in the ``'Model'`` field ('Unadjusted'/'Adjusted').
        label_width: column width of the group label in the printed table.

    Returns:
        List of result dicts: the two dummies plus the reference row (OR = 1).
    """
    rows = []
    for var, group_name in GROUP_DUMMY_LABELS.items():
        coef = model.params[var]
        se = model.bse[var]
        pval = model.pvalues[var]
        or_val = np.exp(coef)
        ci_low = np.exp(coef - 1.96 * se)
        ci_high = np.exp(coef + 1.96 * se)
        print(f"{group_name:<{label_width}} {or_val:<10.3f} ({ci_low:.3f}-{ci_high:.3f})    {pval:.4f} {_sig_stars(pval)}")
        rows.append({
            'Model': model_name, 'Group': group_name,
            'OR': or_val, 'CI_low': ci_low, 'CI_high': ci_high, 'P_value': pval
        })
    print(f"{REFERENCE_GROUP_LABEL:<{label_width}} {'1.000':<10} {'(Reference)':<20}")
    rows.append({'Model': model_name, 'Group': REFERENCE_GROUP_LABEL,
                 'OR': 1.0, 'CI_low': 1.0, 'CI_high': 1.0, 'P_value': np.nan})
    return rows


if 'dementia_status' not in df_wide.columns:
    print("\n⚠ Error: dementia_status variable not found! Skipping dementia analysis")
else:
    df_analysis = df_wide.dropna(subset=['group', 'dementia_status']).copy()
    df_analysis['dementia_status'] = df_analysis['dementia_status'].astype(int)
    df_analysis['group'] = df_analysis['group'].astype(int)

    print(f"\nAnalysis sample size: {len(df_analysis)} individuals")
    print(f"Dementia events: {df_analysis['dementia_status'].sum()} ({df_analysis['dementia_status'].mean()*100:.1f}%)")

    # 1. Dementia incidence by group
    print("\n--- 1. Dementia incidence by group ---")
    dementia_by_group = df_analysis.groupby('group').agg({
        'dementia_status': ['count', 'sum', 'mean']
    }).round(4)
    dementia_by_group.columns = ['n', 'dementia_cases', 'dementia_rate']
    dementia_by_group['dementia_pct'] = (dementia_by_group['dementia_rate'] * 100).round(2)
    print(dementia_by_group)

    # 2. Chi-square test
    print("\n--- 2. Chi-square test ---")
    from scipy.stats import chi2_contingency

    contingency_table = pd.crosstab(df_analysis['group'], df_analysis['dementia_status'])
    print("Contingency table:")
    print(contingency_table)

    chi2, p_value, dof, expected = chi2_contingency(contingency_table)
    print(f"\nChi-square = {chi2:.4f}, P-value = {p_value:.4e}")

    # 3. Logistic regression
    print("\n--- 3. Logistic regression analysis ---")

    # Dummy-code groups 0 and 1; Group 2 is the reference category.
    df_analysis['group_0'] = (df_analysis['group'] == 0).astype(int)
    df_analysis['group_1'] = (df_analysis['group'] == 1).astype(int)

    # 3a. Unadjusted model
    print("\n3a. Unadjusted logistic regression (reference: Group 2):")

    X_unadj = sm.add_constant(df_analysis[['group_0', 'group_1']])
    y = df_analysis['dementia_status']

    results_unadj = []
    try:
        model_unadj = sm.Logit(y, X_unadj).fit(disp=0)

        print(f"\n{'Group':<20} {'OR':<10} {'95% CI':<20} {'P-value':<10}")
        print("-" * 60)
        results_unadj = _group_odds_ratios(model_unadj, 'Unadjusted', 20)
    except Exception as e:
        # Best effort: e.g. perfect separation or a singular design should not
        # abort the remaining analysis.
        print(f"Unadjusted logistic regression failed: {e}")

    # 3b. Adjusted model
    print("\n3b. Adjusted logistic regression (covariates adjusted):")

    adj_vars_for_model = [f'{var}_baseline' for var in available_adj_vars if f'{var}_baseline' in df_analysis.columns]
    print(f"Adjustment variables: {adj_vars_for_model}")

    results_adj = []
    df_adj = df_analysis.dropna(subset=adj_vars_for_model + ['group_0', 'group_1', 'dementia_status']).copy()
    print(f"Adjusted sample size: {len(df_adj)} individuals")

    if len(df_adj) > 0 and len(adj_vars_for_model) > 0:
        # sm.Logit needs a purely numeric design matrix. A text-coded
        # covariate (e.g. sex stored as labels) would make the fit fail, so
        # dummy-code non-numeric covariates, dropping the first level as the
        # reference. Numeric columns pass through get_dummies unchanged.
        covariate_design = pd.get_dummies(df_adj[adj_vars_for_model], drop_first=True).astype(float)
        X_adj = sm.add_constant(pd.concat([df_adj[['group_0', 'group_1']], covariate_design], axis=1))
        y_adj = df_adj['dementia_status']

        try:
            model_adj = sm.Logit(y_adj, X_adj).fit(disp=0)

            print(f"\n{'Variable':<25} {'OR':<10} {'95% CI':<20} {'P-value':<10}")
            print("-" * 65)
            results_adj = _group_odds_ratios(model_adj, 'Adjusted', 25)

            print("\nCovariate effects:")
            for var in covariate_design.columns:
                if var in model_adj.params.index:
                    or_val = np.exp(model_adj.params[var])
                    pval = model_adj.pvalues[var]
                    print(f"  {var}: OR = {or_val:.3f}, P = {pval:.4f} {_sig_stars(pval)}")

        except Exception as e:
            print(f"Adjusted logistic regression failed: {e}")

    # 4. Save logistic regression results
    all_results = results_unadj + results_adj
    if all_results:
        df_results = pd.DataFrame(all_results)
        results_csv = os.path.join(output_dir, f'{DATASET_NAME}_{SCORE_VAR}_group_dementia_logistic_results.csv')
        df_results.to_csv(results_csv, index=False, encoding='utf-8-sig')
        print(f"\nSaved: {results_csv}")

    # 5. Forest plot
    print("\n--- 5. Forest plot ---")

    # Draw only the non-reference groups; the reference is the OR = 1 line.
    plot_data = [r for r in results_unadj + results_adj if r['Group'] != REFERENCE_GROUP_LABEL]

    if plot_data:
        df_plot = pd.DataFrame(plot_data)

        fig, ax = plt.subplots(figsize=(10, 6))
        colors = {'Unadjusted': '#1f77b4', 'Adjusted': '#ff7f0e'}

        y_pos = 0
        y_positions = []
        y_labels = []

        for group in GROUP_DUMMY_LABELS.values():
            for model in ['Unadjusted', 'Adjusted']:
                row = df_plot[(df_plot['Group'] == group) & (df_plot['Model'] == model)]
                if len(row) > 0:
                    row = row.iloc[0]
                    y_positions.append(y_pos)
                    y_labels.append(f"{group}\n{model}")

                    ax.plot([row['CI_low'], row['CI_high']], [y_pos, y_pos],
                            color=colors[model], linewidth=2)
                    ax.scatter(row['OR'], y_pos, color=colors[model], s=100, zorder=5)

                    sig = _sig_stars(row['P_value'])
                    ax.annotate(f"{row['OR']:.2f}{sig}", xy=(row['CI_high'] + 0.1, y_pos), fontsize=10, va='center')

                    y_pos += 1
            y_pos += 0.5  # visual gap between groups

        ax.axvline(x=1, color='black', linestyle='--', linewidth=1)
        ax.set_yticks(y_positions)
        ax.set_yticklabels(y_labels)
        ax.set_xlabel('Odds Ratio (95% CI)', fontsize=12)
        ax.set_title(f'{DATASET_NAME}: Trajectory Group and Dementia Risk\n(Reference: Group 2 - High Protection)', fontsize=14)
        ax.set_xlim(0, max(df_plot['CI_high'].max() * 1.3, 3))
        ax.invert_yaxis()
        ax.grid(axis='x', alpha=0.3)

        from matplotlib.patches import Patch
        legend_elements = [Patch(facecolor='#1f77b4', label='Unadjusted'),
                           Patch(facecolor='#ff7f0e', label='Adjusted')]
        ax.legend(handles=legend_elements, loc='lower right')

        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, f'{DATASET_NAME}_{SCORE_VAR}_group_dementia_forest_plot.pdf'),
                    format='pdf', bbox_inches='tight')
        plt.savefig(os.path.join(output_dir, f'{DATASET_NAME}_{SCORE_VAR}_group_dementia_forest_plot.png'),
                    format='png', dpi=300, bbox_inches='tight')
        plt.close()
        print("Saved: forest_plot")

    # 6. Bar plot
    print("\n--- 6. Bar plot ---")

    fig, ax = plt.subplots(figsize=(8, 6))
    # Label and colour each bar from the groups actually present. A group with
    # no analysable individuals would otherwise cause a length mismatch.
    bar_label_map = {0: 'Group 0\n(Low Protection)', 1: 'Group 1\n(Medium)', 2: 'Group 2\n(High Protection)'}
    bar_color_map = {0: '#d62728', 1: '#ff7f0e', 2: '#2ca02c'}
    present_groups = list(dementia_by_group.index)
    group_labels = [bar_label_map.get(grp, f'Group {grp}') for grp in present_groups]
    dementia_rates = dementia_by_group['dementia_pct'].values
    colors = [bar_color_map.get(grp, 'grey') for grp in present_groups]

    bars = ax.bar(group_labels, dementia_rates, color=colors, edgecolor='black', linewidth=1.5)

    for bar, rate, (idx, row) in zip(bars, dementia_rates, dementia_by_group.iterrows()):
        ax.annotate(f'{rate:.1f}%\n(n={int(row["dementia_cases"])}/{int(row["n"])})',
                    xy=(bar.get_x() + bar.get_width()/2, bar.get_height()),
                    ha='center', va='bottom', fontsize=11)

    ax.set_ylabel('Dementia Incidence Rate (%)', fontsize=12)
    ax.set_title(f'{DATASET_NAME}: Dementia Incidence by {SCORE_VAR} Trajectory Group', fontsize=14)
    # Avoid a zero-height axis (set_ylim(0, 0)) when no group has any case.
    max_rate = max(dementia_rates)
    ax.set_ylim(0, max_rate * 1.3 if max_rate > 0 else 1)
    ax.annotate(f'Chi-square: P = {p_value:.4e}', xy=(0.5, 0.95), xycoords='axes fraction',
                ha='center', fontsize=11, bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f'{DATASET_NAME}_{SCORE_VAR}_group_dementia_barplot.pdf'),
                format='pdf', bbox_inches='tight')
    plt.savefig(os.path.join(output_dir, f'{DATASET_NAME}_{SCORE_VAR}_group_dementia_barplot.png'),
                format='png', dpi=300, bbox_inches='tight')
    plt.close()
    print("Saved: barplot")

# =========================
# Plot trajectories
# =========================
print("\n=== Plotting trajectories ===")

# Long-format table of the interpolated raw-score trajectories: one row per
# (individual, grid point), with individuals in unique_ids order.
n_grid_points = len(time_grid)
df_traj = pd.DataFrame({
    'id': np.repeat(unique_ids, n_grid_points),
    'wave': np.tile(time_grid, len(unique_ids)),
    f'{SCORE_VAR}_interp': interpolated_original.reshape(-1),
    'group': np.repeat(cluster_labels, n_grid_points),
})

# One fixed colour per (re-ordered) group label.
palette_colors = ['#1f77b4', '#ff7f0e', '#2ca02c']
unique_groups = sorted(df['group'].dropna().unique().tolist())
n_palette = len(palette_colors)
color_map = {label: palette_colors[label % n_palette] for label in unique_groups}

# Trajectory plot 1: Mean trajectories
# Observed raw score per wave: group mean with a 95% CI band.
plt.figure(figsize=(10, 6))
for grp in unique_groups:
    grp_rows = df.loc[df['group'] == grp]
    sns.lineplot(data=grp_rows, x='wave', y=SCORE_VAR, color=color_map[grp],
                 errorbar=('ci', 95), linewidth=2.5,
                 label=f"Group {grp} (n={grp_rows['id'].nunique()})")
plt.xlim(WAVE_MIN, WAVE_MAX)
plt.xlabel('Wave', fontsize=12)
plt.ylabel(SCORE_VAR, fontsize=12)
plt.title(f'{DATASET_NAME}: Mean {SCORE_VAR} Trajectories by Group (k=3)', fontsize=14)
plt.grid(True, alpha=0.3)
plt.legend(title='Trajectory Group')
plt.tight_layout()
mean_stem = os.path.join(output_dir, f'mean_trajectories_{DATASET_NAME}_{SCORE_VAR}_original_k3')
plt.savefig(f'{mean_stem}.pdf', format='pdf', bbox_inches='tight')
plt.savefig(f'{mean_stem}.png', format='png', dpi=300, bbox_inches='tight')
plt.close()

# Trajectory plot 2: Smoothed mean trajectories
# Mean of the interpolated grid trajectories per group, with a 95% CI band.
interp_col = f'{SCORE_VAR}_interp'
plt.figure(figsize=(10, 6))
for grp in unique_groups:
    grp_traj = df_traj.loc[df_traj['group'] == grp]
    sns.lineplot(data=grp_traj, x='wave', y=interp_col, color=color_map[grp],
                 errorbar=('ci', 95), linewidth=2.5,
                 label=f"Group {grp} (n={grp_traj['id'].nunique()})")
plt.xlim(WAVE_MIN, WAVE_MAX)
plt.xlabel('Wave', fontsize=12)
plt.ylabel(f'{SCORE_VAR} (Interpolated)', fontsize=12)
plt.title(f'{DATASET_NAME}: Smoothed {SCORE_VAR} Trajectories (k=3)', fontsize=14)
plt.grid(True, alpha=0.3)
plt.legend(title='Trajectory Group')
plt.tight_layout()
smooth_stem = os.path.join(output_dir, f'smooth_trajectories_{DATASET_NAME}_{SCORE_VAR}_original_k3')
plt.savefig(f'{smooth_stem}.pdf', format='pdf', bbox_inches='tight')
plt.savefig(f'{smooth_stem}.png', format='png', dpi=300, bbox_inches='tight')
plt.close()

# Trajectory plot 3: Individual trajectories
# Faint line per individual, overlaid with the group mean (no CI band).
plt.figure(figsize=(10, 6))
for grp in unique_groups:
    grp_rows = df.loc[df['group'] == grp]
    # groupby(sort=False) visits individuals in order of first appearance,
    # the same order as grp_rows['id'].unique().
    for _, person in grp_rows.groupby('id', sort=False):
        person = person.sort_values('wave')
        plt.plot(person['wave'], person[SCORE_VAR], color=color_map[grp], alpha=0.1, linewidth=0.5)
    sns.lineplot(data=grp_rows, x='wave', y=SCORE_VAR, color=color_map[grp],
                 errorbar=None, linewidth=3,
                 label=f"Group {grp} (n={grp_rows['id'].nunique()})")
plt.xlim(WAVE_MIN, WAVE_MAX)
plt.xlabel('Wave', fontsize=12)
plt.ylabel(SCORE_VAR, fontsize=12)
plt.title(f'{DATASET_NAME}: Individual {SCORE_VAR} Trajectories by Group (k=3)', fontsize=14)
plt.grid(True, alpha=0.3)
plt.legend(title='Trajectory Group')
plt.tight_layout()
indiv_stem = os.path.join(output_dir, f'individual_trajectories_{DATASET_NAME}_{SCORE_VAR}_original_k3')
plt.savefig(f'{indiv_stem}.pdf', format='pdf', bbox_inches='tight')
plt.savefig(f'{indiv_stem}.png', format='png', dpi=300, bbox_inches='tight')
plt.close()

# Silhouette scores
# Silhouette score for each candidate k, with the fixed k marked in red.
sil_df = pd.DataFrame(sil_scores, columns=['k', 'silhouette'])
plt.figure(figsize=(8, 5))
plt.plot(sil_df['k'], sil_df['silhouette'], 'o-', linewidth=2, markersize=8)
plt.axvline(x=FIXED_K, color='red', linestyle='--', label=f'Selected k={FIXED_K}')
plt.xlabel('Number of Clusters (k)', fontsize=12)
plt.ylabel('Silhouette Score', fontsize=12)
plt.title(f'{DATASET_NAME}: Silhouette Score by Number of Clusters', fontsize=14)
plt.xticks(list(k_range))
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
sil_stem = os.path.join(output_dir, f'silhouette_scores_{DATASET_NAME}_{SCORE_VAR}_adjusted_k3')
plt.savefig(f'{sil_stem}.pdf', format='pdf', bbox_inches='tight')
plt.savefig(f'{sil_stem}.png', format='png', dpi=300, bbox_inches='tight')
plt.close()

print("\nTrajectory plots saved")

# =========================
# Completion
# =========================
banner = "=" * 60
print("\n" + banner)
print(f"=== {DATASET_NAME} processing completed ===")
print(banner)
print("\n★ Group interpretation:")
for interpretation_line in (
    "  Group 0 = Low level (low protection → high dementia risk)",
    "  Group 1 = Medium level",
    "  Group 2 = High level (high protection → low dementia risk, reference)",
):
    print(interpretation_line)