# *[Working Title]* Resting-State EEG Aperiodic Exponent Moderates the Association Between Age and Memory Performance in Older Adults

Alicia J. Campbell, Toomas Erik Anijärv, Mikael Johansson, Jim Lagopoulos, Daniel F. Hermens, Jacob M. Levenstein, & Sophie C. Andrews

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import statsmodels.api as sm
import numpy as np
import seaborn as sns
from scipy import stats
import pingouin as pg
import warnings
import statsmodels.stats.multitest as smm
from HLR import HierarchicalLinearRegression

# Suppress FutureWarning messages (library deprecation notices) so notebook
# output stays readable; other warning categories are still shown.
warnings.simplefilter(action='ignore', category=FutureWarning)

### data prep

In [None]:
# Load the raw (pre-exclusion, pre-winsorization) sample; the file name indicates
# T1 demographics + CANTAB scores + resting-state EEG measures.
# Cells below add derived columns to this DataFrame; the CSV itself is not modified.
fullsample_df = pd.read_csv('data/leisure_t1_demo_cantab_resting_eeg_raw.csv')

In [None]:
#### Exclusions

# Less than 50% epochs remaining after artifact rejection 
# (do not use for alpha or aperiodic analysis)
       
    # ID: [36] n=1

# The fit of the parameterized power spectrum to the original PSD was below a cut-off of explained variance (R2) < .9 
# (do not use for alpha or aperiodic analysis)

    # Fronto-central: [ID : 9, 93, 119] n=3

# Bad aperiodic fit as per visual inspection (i.e., fit includes significant omissions of signal) (do not use for alpha or aperiodic analysis)

    # Parieto-occipital: [ID : 14, 118, 119] n=3
    # Fronto-central: [ID : 16, 38, 61, 118] n=4

# No individual alpha peak could be detected. 
# (still use for aperiodic analyses)

    # Parieto-occipital: [ID : 61, 75, 93] n=3
    # Fronto-central [ID : 14, 80, 130] n=3

#### Exclusions together per measure:

    # IAF_po [ID : 14, 36, 61, 75, 93, 118, 119] n=7
    # Exponent_po [ID : 14, 36, 118, 119] n=4
    # IAF_fc [ID : 9, 14, 16, 36, 38, 61, 80, 93, 118, 119, 130] n=11
    # Exponent_fc [ID : 9, 16, 36, 38, 61, 93, 118, 119] n=8

In [None]:
# EEG measures that receive participant-level exclusions
variables = ['Exponent_po', 'Exponent_fc', 'IAF_po', 'IAF_fc']

# Subject IDs to drop per measure (union of the exclusion criteria noted above;
# IAF exclusions include the aperiodic ones for the same region)
exclude_subjects = {
    'Exponent_po': [14, 36, 118, 119],
    'Exponent_fc': [9, 16, 36, 38, 61, 93, 118, 119],
    'IAF_po': [14, 36, 61, 75, 93, 118, 119 ],
    'IAF_fc': [9, 14, 16, 36, 38, 61, 80, 93, 118, 119, 130]
}

# Add '<measure>_EX': a copy of the measure with excluded subjects set to NaN
for measure in variables:
    is_excluded = fullsample_df['Subject'].isin(exclude_subjects.get(measure, []))
    fullsample_df[f'{measure}_EX'] = fullsample_df[measure].where(~is_excluded)

In [None]:
### IDENTIFY OUTLIERS

columns_to_check = ['Education', 'Exponent_po_EX', 'Exponent_fc_EX', 'IAF_po_EX', 'IAF_fc_EX', 'DMSPCAD', 'PALTEA', 'SWMBE'] 

# Listwise-complete rows only: every z-score below is computed on participants
# who have ALL of the checked variables, not per variable.
df_vis = fullsample_df.dropna(subset=columns_to_check)

# scipy z-scores (ddof=0), computed once and reused for plot and listing
z_by_col = {col: np.asarray(stats.zscore(df_vis[col])) for col in columns_to_check}

# Scatter of z-scores per subject with the +/-3 cut-offs
plt.figure(figsize=(12, 8))
for col, col_z in z_by_col.items():
    plt.plot(df_vis['Subject'], col_z, marker='o', linestyle='none', label=f'{col} z-score')
for cutoff, line_colour in ((3, 'red'), (-3, 'green')):
    plt.axhline(cutoff, color=line_colour, linestyle='--', label=f'Outlier threshold ({cutoff:+d})')
plt.xlabel('Subject')
plt.ylabel('Z-Score')
plt.title('Z-Scores of Columns')
plt.legend()
plt.show()

# (subject, column, raw value, z) for every |z| > 3
outliers = []
for col, col_z in z_by_col.items():
    for pos in np.flatnonzero(np.abs(col_z) > 3):
        row = df_vis.iloc[pos]
        outliers.append((row['Subject'], col, row[col], col_z[pos]))

print("Outliers found:")
for subject, col, value, z in outliers:
    print(f"Subject: {subject}, Column: {col}, Value: {value}, Z-Score: {z:.2f}")

In [None]:
columns_to_check = [
    'Age', 'Gender_F', 'Education', 'Handedness_right',
    'DMSPCAD', 'PALTEA', 'SWMBE',
    'IAF_fc_EX', 'Exponent_fc_EX', 'IAF_po_EX', 'Exponent_po_EX',
]

# Distribution shape on listwise-complete rows (pandas skew / excess kurtosis)
df_skew = fullsample_df.dropna(subset=columns_to_check)
shape_subset = df_skew[columns_to_check]

skewness_kurtosis_df = pd.concat(
    {'Skewness': shape_subset.skew(), 'Kurtosis': shape_subset.kurtosis()},
    axis=1,
).round(3)

print(skewness_kurtosis_df)

In [None]:
# Independent copy: winsorized '_OA' columns are added here, not to fullsample_df
fullsample_df_outliers_clean = fullsample_df.copy(deep=True)

### Winsorize outliers (Z-score: >3)
# Columns to winsorize (presumably those flagged by the |z| > 3 check above)
column_to_adjust = [
    'Education',
    'SWMBE',
    'IAF_fc_EX',
]

def safe_winsorize(series, z_threshold=3):
    """Clip outliers to the most extreme non-outlying observed value.

    Z-scores come from ``scipy.stats.zscore`` (ddof=0, NaNs omitted). Values
    with z > ``z_threshold`` are replaced by the largest value whose z is within
    the threshold, and values with z < -``z_threshold`` by the smallest such
    value. NaNs stay NaN.

    Parameters
    ----------
    series : pandas.Series
        Values to winsorize.
    z_threshold : float, default 3
        Absolute z-score beyond which a value counts as an outlier.

    Returns
    -------
    pandas.Series
        Clipped copy of ``series`` with the same index.
    """
    z = stats.zscore(series, nan_policy='omit')
    floor = series[z >= -z_threshold].min()
    ceiling = series[z <= z_threshold].max()
    return series.clip(upper=ceiling, lower=floor)

# Add a winsorized '<col>_OA' column next to each adjusted column (originals kept)
for name in column_to_adjust:
    fullsample_df_outliers_clean[name + '_OA'] = safe_winsorize(fullsample_df_outliers_clean[name])


In [None]:
# Repeat the |z| > 3 check after winsorization (winsorized '_OA' columns)
columns_to_check = ['Education_OA', 'Exponent_po_EX', 'Exponent_fc_EX', 'IAF_po_EX', 'IAF_fc_EX_OA', 'DMSPCAD', 'PALTEA', 'SWMBE_OA'] 

# Listwise-complete rows only, as in the pre-winsorization check
df_vis = fullsample_df_outliers_clean.dropna(subset=columns_to_check)

# scipy z-scores (ddof=0), computed once and reused for plot and listing
z_by_col = {col: np.asarray(stats.zscore(df_vis[col])) for col in columns_to_check}

plt.figure(figsize=(12, 8))
for col, col_z in z_by_col.items():
    plt.plot(df_vis['Subject'], col_z, marker='o', linestyle='none', label=f'{col} z-score')
for cutoff, line_colour in ((3, 'red'), (-3, 'green')):
    plt.axhline(cutoff, color=line_colour, linestyle='--', label=f'Outlier threshold ({cutoff:+d})')
plt.xlabel('Subject')
plt.ylabel('Z-Score')
plt.title('Z-Scores of Columns')
plt.legend()
plt.show()

# (subject, column, raw value, z) for every remaining |z| > 3
outliers = []
for col, col_z in z_by_col.items():
    for pos in np.flatnonzero(np.abs(col_z) > 3):
        row = df_vis.iloc[pos]
        outliers.append((row['Subject'], col, row[col], col_z[pos]))

print("Outliers found:")
for subject, col, value, z in outliers:
    print(f"Subject: {subject}, Column: {col}, Value: {value}, Z-Score: {z:.2f}")

In [None]:
# Re-check distribution shape of the winsorized columns only
columns_to_check = ['Education_OA', 'IAF_fc_EX_OA', 'SWMBE_OA']

df_skew_corrected = fullsample_df_outliers_clean.dropna(subset=columns_to_check)
corrected_subset = df_skew_corrected[columns_to_check]

skewness_kurtosis_corrected_df = pd.concat(
    {'Skewness': corrected_subset.skew(), 'Kurtosis': corrected_subset.kurtosis()},
    axis=1,
).round(3)

print(skewness_kurtosis_corrected_df)

In [None]:
# Standardize continuous variables into new '<col>_z' columns.
# pandas .std() is the sample SD (ddof=1), unlike scipy.stats.zscore (ddof=0)
# used in the outlier checks.
columns_to_zscore = ['Age', 'Education_OA', 'DMSPCAD', 'PALTEA', 'SWMBE_OA', 'Exponent_po_EX', 'Exponent_fc_EX', 'IAF_po_EX', 'IAF_fc_EX_OA']

standardized = {
    f'{name}_z': (fullsample_df_outliers_clean[name] - fullsample_df_outliers_clean[name].mean())
    / fullsample_df_outliers_clean[name].std()
    for name in columns_to_zscore
}
fullsample_df_outliers_clean_z = fullsample_df_outliers_clean.assign(**standardized)

In [None]:
# Analysis columns (raw + z-scored) carried into the clean dataset
columns_to_include = ['Subject', 
                      'Age', 'Age_z',
                      'Gender_F', 
                      'Education_OA', 'Education_OA_z',
                      'Handedness_right', 
                      'IAF_po_EX', 'IAF_po_EX_z',
                      'IAF_fc_EX_OA', 'IAF_fc_EX_OA_z',
                      'Exponent_po_EX', 'Exponent_po_EX_z',
                      'Exponent_fc_EX','Exponent_fc_EX_z',
                      'DMSPCAD', 'DMSPCAD_z',
                      'PALTEA', 'PALTEA_z',
                      'SWMBE_OA', 'SWMBE_OA_z']

# Drop the processing suffixes (_EX = exclusions applied, _OA = winsorized)
new_column_names = {
    'Education_OA': 'Education',
    'Education_OA_z': 'Education_z',
    'IAF_po_EX': 'IAF_po',
    'IAF_po_EX_z': 'IAF_po_z',
    'IAF_fc_EX_OA': 'IAF_fc',
    'IAF_fc_EX_OA_z':'IAF_fc_z',
    'Exponent_po_EX': 'Exponent_po',
    'Exponent_po_EX_z': 'Exponent_po_z',
    'Exponent_fc_EX': 'Exponent_fc',
    'Exponent_fc_EX_z': 'Exponent_fc_z',
    'SWMBE_OA': 'SWMBE',
    'SWMBE_OA_z': 'SWMBE_z'}

fullsample_df_CLEAN = (
    fullsample_df_outliers_clean_z[columns_to_include]
    .copy()
    .rename(columns=new_column_names)
)

In [None]:
# Age x moderator product terms, built from the z-scored variables
interaction_terms = {
    'Age_Exponent_po': ['Age_z', 'Exponent_po_z'],
    'Age_IAF_po': ['Age_z', 'IAF_po_z'],
    'Age_Exponent_fc': ['Age_z', 'Exponent_fc_z'],
    'Age_IAF_fc': ['Age_z', 'IAF_fc_z']
}

for interaction_name, (left, right) in interaction_terms.items():
    # A NaN in either factor makes the product NaN, so only complete pairs get values
    fullsample_df_CLEAN[interaction_name] = fullsample_df_CLEAN[left] * fullsample_df_CLEAN[right]

In [None]:
# fullsample_df_CLEAN.to_csv('data/leisure_t1_demo_cantab_resting_eeg_clean_2.csv', index=False)

## Analysis

In [None]:
# Reload the cleaned dataset from disk so the analysis cells can run without
# re-executing preprocessing. This replaces the in-memory fullsample_df_CLEAN.
# NOTE(review): the to_csv that writes this file is commented out above; confirm
# the CSV is up to date with the current preprocessing before reporting results.
fullsample_df_CLEAN = pd.read_csv('data/leisure_t1_demo_cantab_resting_eeg_clean_2.csv')

### Sample descriptives

In [None]:
# Counts of the binary demographics (column sums; assumes 0/1 coding)
num_females = fullsample_df_CLEAN['Gender_F'].sum()
num_righthanded = fullsample_df_CLEAN['Handedness_right'].sum()
print(f"Number of females: {num_females}")
print(f"Number of right handed: {num_righthanded}")

# Summary statistics of the continuous variables, 2 d.p.
descriptive_columns = [
    'Age', 'Education', 'DMSPCAD', 'PALTEA', 'SWMBE',
    'IAF_fc', 'Exponent_fc', 'IAF_po', 'Exponent_po',
]
descriptives_df = (
    fullsample_df_CLEAN[descriptive_columns]
    .describe()
    .round(2)
)
filtered_stats = descriptives_df.loc[['count', 'mean', 'std', 'min', 'max']]

# filtered_stats.to_csv('results/sampledescriptives.csv')

display(filtered_stats)

#### *Spearman correlations*

In [None]:
from itertools import combinations

from scipy.stats import spearmanr
from statsmodels.stats.multitest import multipletests

correlation_list = ['Age', 'DMSPCAD', 'PALTEA', 'SWMBE', 'IAF_fc', 'Exponent_fc', 'IAF_po', 'Exponent_po']

# correlation_matrix holds rho formatted as "%.3f" strings; p_value_matrix holds
# floats (FDR-corrected below). Both are symmetric; the diagonal has no test.
correlation_matrix = pd.DataFrame(index=correlation_list, columns=correlation_list)
p_value_matrix = pd.DataFrame(index=correlation_list, columns=correlation_list)
for col in correlation_list:
    correlation_matrix.loc[col, col] = 1.000
    p_value_matrix.loc[col, col] = np.nan

all_p_values = []
p_value_locs = []

# Spearman correlation for each UNIQUE pair on pairwise-complete data, mirrored
# into both triangles. Each test must enter the FDR family exactly once:
# testing (a, b) and (b, a) separately would double-count every p-value.
for col1, col2 in combinations(correlation_list, 2):
    valid_data = fullsample_df_CLEAN[[col1, col2]].dropna()
    mirrored = ((col1, col2), (col2, col1))

    if valid_data.empty:
        for a, b in mirrored:
            correlation_matrix.loc[a, b] = np.nan
            p_value_matrix.loc[a, b] = np.nan
        continue

    corr, p_value = spearmanr(valid_data[col1], valid_data[col2])
    for a, b in mirrored:
        correlation_matrix.loc[a, b] = f"{corr:.3f}"
        p_value_matrix.loc[a, b] = p_value

    all_p_values.append(p_value)
    p_value_locs.append((col1, col2))

## Comment out the FDR step below to report raw (uncorrected) p-values.
# Benjamini-Hochberg FDR across the unique correlation tests; the corrected value
# replaces the raw one in both mirrored cells.
rejected, pvals_corrected, _, _ = multipletests(all_p_values, alpha=0.05, method='fdr_bh')
for (col1, col2), p_corr in zip(p_value_locs, pvals_corrected):
    p_value_matrix.loc[col1, col2] = p_corr
    p_value_matrix.loc[col2, col1] = p_corr

# Cell text: correlation with its p-value in brackets underneath ("<0.001" when
# tiny); cells without a p-value (e.g. the diagonal) keep the plain correlation.
annot_matrix = correlation_matrix.copy()
for row_label in correlation_list:
    for col_label in correlation_list:
        p = p_value_matrix.loc[row_label, col_label]
        if pd.isna(p):
            continue
        p_text = "<0.001" if p < 0.001 else f"{p:.3f}"
        annot_matrix.loc[row_label, col_label] = f"{correlation_matrix.loc[row_label, col_label]}\n({p_text})"

# Display names for axis labels (names not listed, e.g. 'Age', are kept)
renaming_dict = {
    'DMSPCAD': "DMS PCAD", 
    'PALTEA': "PAL TEA",
    'SWMBE': "SWM BE",
    'Exponent_po': "PO exponent", 
    'Exponent_fc': "FC exponent", 
    'IAF_po': "PO IAF", 
    'IAF_fc': "FC IAF",
}
for labelled in (correlation_matrix, annot_matrix):
    labelled.rename(columns=renaming_dict, index=renaming_dict, inplace=True)

sns.set_theme(style='white')
cmap = plt.get_cmap("RdBu_r")

# Hide the strict upper triangle so each pair is drawn once (diagonal stays)
upper_triangle = np.triu(np.ones(correlation_matrix.shape, dtype=bool), k=1)

plt.figure(figsize=(7.5, 6), dpi=150)
ax = sns.heatmap(
    correlation_matrix.astype(float),
    mask=upper_triangle,
    annot=annot_matrix,
    fmt="",
    annot_kws={"size": 9},
    cmap=cmap,
    vmin=-1.0,
    vmax=1.0,
    center=0.0,
    cbar=True,
    linewidths=.5,
)

ax.set_xticklabels(ax.get_xticklabels(), rotation=30, ha='right', fontsize=9)
ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=9)

cbar = ax.collections[0].colorbar
cbar.ax.tick_params(labelsize=9)

plt.tight_layout()

# Save the plot to file
plt.savefig('results/plots/corrmatrix.pdf', format='pdf', bbox_inches='tight', dpi=300)
plt.show()

### *Hierarchical linear regressions / Simple slopes analysis and plots*

In [None]:
def simple_slope_analysis(model, level, var='Age_z', interaction='Age_Exponent_po', ci_level=0.95):
    """Simple slope of ``var`` at a given value of the moderator.

    For a model with terms b_var*var + b_int*(var*moderator), the slope of
    ``var`` when the moderator equals ``level`` is b_var + b_int*level, with
    variance Var(b_var) + level**2*Var(b_int) + 2*level*Cov(b_var, b_int).

    Parameters
    ----------
    model : fitted regression results
        Must expose ``params`` (indexable by term name), ``cov_params()``
        returning a DataFrame labelled by term name, and ``df_resid``
        (as statsmodels OLS results do).
    level : float
        Moderator value at which the slope is evaluated (e.g. -1/0/1 for a
        z-scored moderator at -1 SD / mean / +1 SD).
    var : str
        Coefficient name of the focal predictor.
    interaction : str
        Coefficient name of the product term.
    ci_level : float, default 0.95
        Confidence level of the returned interval.

    Returns
    -------
    tuple
        (slope, se, tval, pval, ci), where ``pval`` is two-sided and ``ci`` is
        a list ``[lower, upper]``.
    """
    b_var, b_int = model.params[var], model.params[interaction]
    cov_matrix = model.cov_params()
    dof = model.df_resid

    slope = b_var + b_int * level

    se = np.sqrt(
        cov_matrix.loc[var, var]
        + level ** 2 * cov_matrix.loc[interaction, interaction]
        + 2 * level * cov_matrix.loc[var, interaction]
    )
    tval = slope / se
    # Survival function instead of 1 - cdf: the subtraction cancels to exactly 0
    # once the cdf rounds to 1 (|t| large), while sf keeps tail precision.
    pval = 2 * stats.t.sf(np.abs(tval), dof)
    t_crit = stats.t.ppf(1 - (1 - ci_level) / 2, dof)
    ci = [slope - t_crit * se, slope + t_crit * se]
    return slope, se, tval, pval, ci

dict_hlr = {}
df_ = fullsample_df_CLEAN.copy()

# Publication styling: 9-pt text throughout
sns.set_theme(style='whitegrid', font='Helvetica', context='paper') 
plt.rc('grid', linestyle='-', alpha=0.3)
plt.rcParams.update({
    'font.size':9,
    'axes.labelsize':9,
    'axes.titlesize':9,
    'xtick.labelsize':9,
    'ytick.labelsize':9,
    'legend.fontsize':9,
    'legend.title_fontsize':9
})
# 180 x 90 mm figure; one panel per significant step-3 model
(fig, axs), i = plt.subplots(1, 2, figsize=(180/25.4, 90/25.4), dpi=300), 0

# Legend titles (per iv) and y-axis labels (per '<dv>_pred'). Every iv/dv looped
# over below needs an entry, otherwise plotting a significant model raises KeyError.
dict_labels = {
    'Exponent_fc': 'Fronto-central exponent', 'Exponent_po': 'Parieto-occipital exponent',
    'IAF_fc': 'Fronto-central IAF', 'IAF_po': 'Parieto-occipital IAF',
    'DMSPCAD_z_pred': 'DMS PCAD (z-scored)', 'PALTEA_z_pred': 'PAL TEA (z-scored)',
    'SWMBE_z_pred': 'SWM BE (z-scored)',
}
# colors = ['#2E627AFF', '#AEB2B7FF', '#B53737FF']
cmap = plt.get_cmap("RdBu_r")
colors = [cmap(0.1), '#AEB2B7FF', cmap(0.825)]
# Moderator group label -> moderator level (z units) used for the simple slopes
grp_conditions = {'Low (≤-1SD)':-1, 'Mean': 0, 'High (≥1SD)': 1}

import os
HLR_output_dir = 'results/hlr_summaries'
os.makedirs(HLR_output_dir, exist_ok=True)

covars = ['Age_z', 'Gender_F', 'Education_z', 'Handedness_right']
for iv in ['Exponent_fc', 'Exponent_po', 'IAF_fc', 'IAF_po']:
    for dv in ['DMSPCAD_z' , 'PALTEA_z', 'SWMBE_z']:
        # Step 1: covariates; step 2: + moderator; step 3: + Age x moderator
        X = {
            1: covars,
            2: covars + [f'{iv}_z'],
            3: covars + [f'Age_{iv}'] + [f'{iv}_z']
        }

        # Complete cases on every step-3 predictor and the outcome
        df_t = df_[['Subject', 'Age']+X[3]+[dv]].dropna()

        hlr_model = HierarchicalLinearRegression(df_t, X, dv)
        hlr_summary = hlr_model.summary()

        hlr_summary_csv_path = os.path.join(HLR_output_dir, f'hlr_summary_{iv}_{dv}.csv')
        # hlr_summary.to_csv(hlr_summary_csv_path, index=False)

        # predict dv values based on final model
        s3_model = hlr_model.fit_models()[3]
        df_t[f'{dv}_pred'] = s3_model.predict(sm.add_constant(df_t[X[3]]))

        # add HLR model and step3 LM results to dict
        dict_hlr[(iv, dv)] = {'hlr_model': hlr_model, 'hlr_summary': hlr_summary,
                              's3_model': s3_model, 's3_df': df_t}

        # F-change p-value of step 3 (third summary row); only significant
        # interactions are plotted and probed
        s3_pval = hlr_summary.iloc[2]['P-value (F-value change)']
        if not s3_pval < 0.05:
            continue

        print(f'\n---\niv={iv}, dv={dv}, step 3 model p-val = {s3_pval:.3f}')

        # Moderator groups (z >= 1, z <= -1, otherwise 'Mean') for plotting
        df_t.loc[:, 'iv_group'] = np.where(df_t[f'{iv}_z'] >= 1, 'High (≥1SD)',
                                        np.where(df_t[f'{iv}_z'] <= -1, 'Low (≤-1SD)', 'Mean'))
        # A group with no members is absent from value_counts(); read with .get()
        group_counts = df_t['iv_group'].value_counts()

        if i < len(axs):
            desired_order = ['Low (≤-1SD)', 'Mean', 'High (≥1SD)']
            ax = axs[i]
            sns.scatterplot(ax=ax,x='Age_z', y=f'{dv}_pred', data=df_t, hue='iv_group', hue_order=desired_order, 
                            palette=colors, edgecolor='#494949', lw=0.8, s=20, alpha=0.7)
            for c, g_lab in enumerate(grp_conditions.keys()):
                group_df = df_t[df_t['iv_group'] == g_lab]
                if group_df.empty:  # nothing to fit a line to
                    continue
                sns.regplot(ax=ax, x='Age_z', y=f'{dv}_pred', data=group_df, color=colors[c], scatter=False)
            ax.legend(title=dict_labels[iv], ncols=3, loc='upper center', frameon=False, bbox_to_anchor=(0.5, 1.2), 
                      handletextpad=0.1, columnspacing=2)
            ax.set_ylabel(dict_labels[f'{dv}_pred'])
            ax.set_xlabel("Age (z-scored)")  
            ax.set_ylim([-1.65, 1.65])
        else:
            # The figure has a fixed number of panels; keep the statistics output
            print(f'\n(only {len(axs)} figure panels; plot for iv={iv}, dv={dv} skipped)')

        # Simple slopes of Age at moderator = -1 / 0 / +1 SD
        print("\n=== Age Effects Across IV Groups ===")
        for iv_group, iv_level in grp_conditions.items():
            group_size = group_counts.get(iv_group, 0)
            slope, se, tval, pval, ci = simple_slope_analysis(
                model=s3_model,
                level=iv_level,
                var='Age_z',
                interaction=f'Age_{iv}'
            )
            print(f"\n{iv_group} IV Group (n={group_size}):")
            print(f"Age Slope: {slope:.3f}, SE: {se:.3f}, [95% CI: {ci[0]:.3f}, {ci[1]:.3f}], t = {tval:.3f}, p = {pval:.3f}")
        i += 1

plt.tight_layout(pad=-0.3)
plt.savefig('results/plots/simpleslopesplot.pdf', format='pdf', bbox_inches='tight', dpi=300)
plt.show()

In [None]:
def round_summary_df(summary_df, decimals=3):
    """Return a copy of an HLR summary table with numeric results rounded.

    The input DataFrame is not modified. Rounding in place would also round the
    caller's 'P-value (F-value change)', so a later ``p < 0.05`` check could
    flip (e.g. 0.0496 rounds to 0.05).

    Parameters
    ----------
    summary_df : pandas.DataFrame
        HLR summary table (e.g. the output of
        ``HierarchicalLinearRegression.summary()``). Only the columns listed
        below are touched, and only if present.
    decimals : int, default 3
        Number of decimal places.

    Returns
    -------
    pandas.DataFrame
        Rounded copy. Scalar columns are rounded with ``Series.round``. In
        dict-valued columns every int/float value is rounded; other values,
        and non-dict cells, are kept as they are.
    """
    rounded = summary_df.copy()

    # 1. Scalar numeric columns
    numeric_cols = [
        'R-squared', 'F-value', 'P-value (F)', 'SSR', 'SSTO',
        'MSE (model)', 'MSE (residuals)', 'MSE (total)',
        'R-squared change', 'F-value change', 'P-value (F-value change)'
    ]
    for col in numeric_cols:
        if col in rounded.columns:
            rounded[col] = rounded[col].round(decimals)

    # 2. Columns whose cells are {term: value} dicts
    dict_cols = [
        'Beta coefs', 'P-values (beta coefs)', 'T-values (beta coefs)',
        'Standard errors', 'Std Beta coefs', 'Partial correlations',
        'Semi-partial correlations', 'Unique variance %'
    ]

    def round_dict(d):
        # Builds a new dict, so the dicts referenced by summary_df stay unrounded
        if not isinstance(d, dict):
            return d
        return {k: round(v, decimals) if isinstance(v, (int, float)) else v
                for k, v in d.items()}

    for col in dict_cols:
        if col in rounded.columns:
            rounded[col] = rounded[col].apply(round_dict)

    return rounded


def simple_slope_analysis(model, level, var='Age_z', interaction='Age_Exponent_po', ci_level=0.95):
    """Simple slope of ``var`` at a given value of the moderator.

    For a model with terms b_var*var + b_int*(var*moderator), the slope of
    ``var`` when the moderator equals ``level`` is b_var + b_int*level, with
    variance Var(b_var) + level**2*Var(b_int) + 2*level*Cov(b_var, b_int).

    Parameters
    ----------
    model : fitted regression results
        Must expose ``params`` (indexable by term name), ``cov_params()``
        returning a DataFrame labelled by term name, and ``df_resid``
        (as statsmodels OLS results do).
    level : float
        Moderator value at which the slope is evaluated (e.g. -1/0/1 for a
        z-scored moderator at -1 SD / mean / +1 SD).
    var : str
        Coefficient name of the focal predictor.
    interaction : str
        Coefficient name of the product term.
    ci_level : float, default 0.95
        Confidence level of the returned interval.

    Returns
    -------
    tuple
        (slope, se, tval, pval, ci), where ``pval`` is two-sided and ``ci`` is
        a list ``[lower, upper]``.
    """
    b_var, b_int = model.params[var], model.params[interaction]
    cov_matrix = model.cov_params()
    dof = model.df_resid

    slope = b_var + b_int * level

    se = np.sqrt(
        cov_matrix.loc[var, var]
        + level ** 2 * cov_matrix.loc[interaction, interaction]
        + 2 * level * cov_matrix.loc[var, interaction]
    )
    tval = slope / se
    # Survival function instead of 1 - cdf: the subtraction cancels to exactly 0
    # once the cdf rounds to 1 (|t| large), while sf keeps tail precision.
    pval = 2 * stats.t.sf(np.abs(tval), dof)
    t_crit = stats.t.ppf(1 - (1 - ci_level) / 2, dof)
    ci = [slope - t_crit * se, slope + t_crit * se]
    return slope, se, tval, pval, ci

dict_hlr = {}
df_ = fullsample_df_CLEAN.copy()

# Publication styling: 9-pt text throughout
sns.set_theme(style='whitegrid', font='Helvetica', context='paper') 
plt.rc('grid', linestyle='-', alpha=0.3)
plt.rcParams.update({
    'font.size':9,
    'axes.labelsize':9,
    'axes.titlesize':9,
    'xtick.labelsize':9,
    'ytick.labelsize':9,
    'legend.fontsize':9,
    'legend.title_fontsize':9
})
# 180 x 90 mm figure; one panel per significant step-3 model
(fig, axs), i = plt.subplots(1, 2, figsize=(180/25.4, 90/25.4), dpi=300), 0

# Legend titles (per iv) and y-axis labels (per '<dv>_pred'). Every iv/dv looped
# over below needs an entry, otherwise plotting a significant model raises KeyError.
dict_labels = {
    'Exponent_fc': 'Fronto-central exponent', 'Exponent_po': 'Parieto-occipital exponent',
    'IAF_fc': 'Fronto-central IAF', 'IAF_po': 'Parieto-occipital IAF',
    'DMSPCAD_z_pred': 'DMS PCAD (z-scored)', 'PALTEA_z_pred': 'PAL TEA (z-scored)',
    'SWMBE_z_pred': 'SWM BE (z-scored)',
}
# colors = ['#2E627AFF', '#AEB2B7FF', '#B53737FF']
cmap = plt.get_cmap("RdBu_r")
colors = [cmap(0.1), '#AEB2B7FF', cmap(0.825)]
# Moderator group label -> moderator level (z units) used for the simple slopes
grp_conditions = {'Low (≤-1SD)':-1, 'Mean': 0, 'High (≥1SD)': 1}

import os
HLR_output_dir = 'results/hlr_summaries'
os.makedirs(HLR_output_dir, exist_ok=True)

# slope_output_dir = 'results/simple_slope_analyses'
# os.makedirs(slope_output_dir, exist_ok=True)

covars = ['Age_z', 'Gender_F', 'Education_z', 'Handedness_right']
for iv in ['Exponent_fc', 'Exponent_po', 'IAF_fc', 'IAF_po']:
    for dv in ['DMSPCAD_z' , 'PALTEA_z', 'SWMBE_z']:
        # Step 1: covariates; step 2: + moderator; step 3: + Age x moderator
        X = {
            1: covars,
            2: covars + [f'{iv}_z'],
            3: covars + [f'Age_{iv}'] + [f'{iv}_z']
        }

        # Complete cases on every step-3 predictor and the outcome
        df_t = df_[['Subject', 'Age']+X[3]+[dv]].dropna()

        hlr_model = HierarchicalLinearRegression(df_t, X, dv)
        hlr_summary = hlr_model.summary()

        # Round a COPY for the CSV export, so the unrounded p-values in hlr_summary
        # are what the significance check below (and dict_hlr) use; a rounded
        # p of e.g. 0.0496 would become 0.05 and fail `< 0.05`.
        hlr_summary_rounded = round_summary_df(hlr_summary.copy())

        hlr_summary_csv_path = os.path.join(HLR_output_dir, f'hlr_summary_rounded_{iv}_{dv}.csv')
        hlr_summary_rounded.to_csv(hlr_summary_csv_path, index=False)

        # predict dv values based on final model
        s3_model = hlr_model.fit_models()[3]
        df_t[f'{dv}_pred'] = s3_model.predict(sm.add_constant(df_t[X[3]]))

        # add HLR model and step3 LM results to dict
        dict_hlr[(iv, dv)] = {'hlr_model': hlr_model, 'hlr_summary': hlr_summary,
                              's3_model': s3_model, 's3_df': df_t}

        # F-change p-value of step 3 (third summary row); only significant
        # interactions are plotted and probed
        s3_pval = hlr_summary.iloc[2]['P-value (F-value change)']
        if not s3_pval < 0.05:
            continue

        print(f'\n---\niv={iv}, dv={dv}, step 3 model p-val = {s3_pval:.3f}')

        # 1. IV-BASED GROUPS (for Age slope analysis)
        df_t.loc[:, 'iv_group'] = np.where(df_t[f'{iv}_z'] >= 1, 'High (≥1SD)',
                                        np.where(df_t[f'{iv}_z'] <= -1, 'Low (≤-1SD)', 'Mean'))
        # A group with no members is absent from value_counts(); read with .get()
        group_counts = df_t['iv_group'].value_counts()

        # # 2. AGE-BASED GROUPS (for IV slope analysis)
        # age_z = (df_t['Age'] - df_t['Age'].mean())/df_t['Age'].std()
        # df_t.loc[:, 'age_group'] = np.where(age_z >= 1, 'Old (≥1SD)',
        #                                 np.where(age_z <= -1, 'Young (≤-1SD)', 'Middle Aged'))

        if i < len(axs):
            desired_order = ['Low (≤-1SD)', 'Mean', 'High (≥1SD)']
            ax = axs[i]
            sns.scatterplot(ax=ax,x='Age_z', y=f'{dv}_pred', data=df_t, hue='iv_group', hue_order=desired_order, 
                            palette=colors, edgecolor='#494949', lw=0.8, s=20, alpha=0.7)
            for c, g_lab in enumerate(grp_conditions.keys()):
                group_df = df_t[df_t['iv_group'] == g_lab]
                if group_df.empty:  # nothing to fit a line to
                    continue
                sns.regplot(ax=ax, x='Age_z', y=f'{dv}_pred', data=group_df, color=colors[c], scatter=False)
            ax.legend(title=dict_labels[iv], ncols=3, loc='upper center', frameon=False, bbox_to_anchor=(0.5, 1.2), 
                      handletextpad=0.1, columnspacing=2)
            ax.set_ylabel(dict_labels[f'{dv}_pred'])
            ax.set_xlabel("Age (z-scored)")  
            ax.set_ylim([-1.65, 1.65])
        else:
            # The figure has a fixed number of panels; keep the statistics output
            print(f'\n(only {len(axs)} figure panels; plot for iv={iv}, dv={dv} skipped)')

        # Slope Analysis
        # 1. Age Slopes in IV Groups (moderator = -1 / 0 / +1 SD)
        print("\n=== Age Effects Across IV Groups ===")
        for iv_group, iv_level in grp_conditions.items():
            group_size = group_counts.get(iv_group, 0)
            slope, se, tval, pval, ci = simple_slope_analysis(
                model=s3_model,
                level=iv_level,
                var='Age_z',
                interaction=f'Age_{iv}'
            )
            print(f"\n{iv_group} IV Group (n={group_size}):")
            print(f"Age Slope: {slope:.3f} [95% CI: {ci[0]:.3f}, {ci[1]:.3f}], t = {tval:.2f}, p = {pval:.3f}")

        # # 2. IV Slopes in Age Groups
        # print("\n=== IV Effects Across Age Groups ===")
        # age_group_levels = {
        #     'Young (≤-1SD)': -1,
        #     'Middle Aged': 0,
        #     'Old (≥1SD)': 1
        # }
        # for age_group, age_level in age_group_levels.items():
        #     group_size = df_t['age_group'].value_counts()[age_group]
        #     slope, se, tval, pval, ci = simple_slope_analysis(
        #         model=s3_model,
        #         level=age_level,
        #         var=f'{iv}_z',
        #         interaction=f'Age_{iv}'
        #     )
        #     print(f"\n{age_group} Group (n={group_size}):")
        #     print(f"{iv} Slope: {slope:.3f} [95% CI: {ci[0]:.3f}, {ci[1]:.3f}], t = {tval:.2f}, p = {pval:.3f}")

        i += 1

plt.tight_layout(pad=-0.3)
# plt.savefig('results/plots/simpleslopesplot.pdf', format='pdf', bbox_inches='tight', dpi=300)
plt.show()
