In [1]:
import pandas as pd
import numpy as np
import matplotlib as plt
import os
import logging
import arviz as az

ArviZ is undergoing a major refactor to improve flexibility and extensibility while maintaining a user-friendly interface.
Some upcoming changes may be backward incompatible.
For details and migration guidance, visit: https://python.arviz.org/en/latest/user_guide/migration_guide.html
  warn(


In [None]:
# Logging: timestamped records at INFO level and above for the processing loop below.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Directory containing the fitted models saved as <model_name>.nc NetCDF files.
# Override with the HDDM_MODEL_DIR environment variable when running on another machine;
# the default is the original absolute path.
PATH_MODEL = os.environ.get(
    'HDDM_MODEL_DIR',
    '/volumes/hyijie_psy/CPP_stage2_HDDM/high_Sun_2023/model_fitted',
)
# Output directory for the posterior CSVs and the LOO table (override: HDDM_RESULTS_DIR).
# Created if it does not exist yet; exist_ok makes re-running this cell safe.
PATH_MODEL_FIt = os.environ.get('HDDM_RESULTS_DIR', '../results/ddm_new')
os.makedirs(PATH_MODEL_FIt, exist_ok=True)

# Drift-rate (v) regression terms to extract from each fitted model.
# Every entry maps a model file stem to:
#   'main_param'   - posterior variable of the regressor's main effect ('v_<regressor>')
#   'interactions' - posterior variables of the regressor's interaction with
#                    condition levels 2, 3 and 4 ('v_C(condition)[T.<level>]:<regressor>')
# NOTE(review): the *_bin models use 'pam_bin' / 'slp_bin' rather than
# 'pams_bin' / 'slps_bin' as regressor names - confirm these match the
# variable names actually stored in the fitted models.
MODEL_CONFIG = {
    model_name: {
        'main_param': f'v_{regressor}',
        'interactions': [
            f'v_C(condition)[T.{level}]:{regressor}' for level in (2, 3, 4)
        ],
    }
    for model_name, regressor in (
        ('m2_ams', 'ams'),
        ('m2_pams', 'pams'),
        ('m2_slps', 'slps'),
        ('m2_ams_bin', 'ams_bin'),
        ('m2_pams_bin', 'pam_bin'),
        ('m2_slps_bin', 'slp_bin'),
    )
}

In [6]:
# Extract and save the LOO summary and the posterior draws of the drift-rate
# regression coefficients for every model listed in MODEL_CONFIG.
#
# For each model this cell:
#   1. loads <PATH_MODEL>/<model_name>.nc with az.from_netcdf,
#   2. computes PSIS-LOO (slow: several minutes per model in the logged run)
#      and appends one summary row to `loo_results`,
#   3. flattens the main-effect and interaction posteriors over (chain, draw)
#      and writes them to <PATH_MODEL_FIt>/<model_name>_posterior.csv.
# The LOO table is rewritten after every model so that a failure part-way
# through does not throw away the LOO values already computed.

loo_results = []
path_save_loo = os.path.join(PATH_MODEL_FIt, "loo_comparison.csv")

for model_name, config in MODEL_CONFIG.items():

    main_param = config['main_param']
    interaction_names = config['interactions']

    path_model_infdata = os.path.join(PATH_MODEL, model_name + '.nc')
    # Explicit raise rather than `assert`, which is stripped under `python -O`.
    if not os.path.exists(path_model_infdata):
        raise FileNotFoundError(f"nc file not found for {model_name}: {path_model_infdata}")

    logger.info(f"Processing model: {model_name}")

    # LOO is recorded for every model, even if its parameters are missing below.
    m_infdata = az.from_netcdf(path_model_infdata)
    loo = az.loo(m_infdata, pointwise=True)
    # Pareto-k threshold above which the PSIS estimate for a point is unreliable.
    # Newer ArviZ reports a sample-size-dependent `good_k`; fall back to the classic 0.7.
    k_threshold = loo.get("good_k", 0.7)
    # Original columns first (unchanged order), diagnostic columns appended after.
    loo_dict = {
        "model": main_param,
        "loo": loo.elpd_loo,
        "p_loo": loo.p_loo,
        "model_name": model_name,
        "loo_se": loo["se"],
        "warning": loo["warning"],
        "n_high_pareto_k": int((loo["pareto_k"] > k_threshold).sum()),
    }
    loo_results.append(loo_dict)
    # Checkpoint the LOO table after each (expensive) model.
    pd.DataFrame(loo_results).to_csv(path_save_loo, index=False)

    # Posterior draws per parameter, flattened over (chain, draw).
    # Assumes each parameter has only (chain, draw) dims - TODO confirm; pd.DataFrame
    # below needs 1-D columns.
    posterior_dict = {}

    # Main effect: without it the posterior CSV is not written for this model.
    if main_param in m_infdata.posterior:
        posterior_dict[main_param] = m_infdata.posterior[main_param].stack(sample=("chain", "draw")).values
    else:
        logger.error(f"Main parameter '{main_param}' not in {model_name}")
        continue
    # Interaction effects: missing ones are logged and left out of the CSV.
    for inter in interaction_names:
        if inter in m_infdata.posterior:
            posterior_dict[inter] = m_infdata.posterior[inter].stack(sample=("chain", "draw")).values
        else:
            logger.warning(f"Interaction '{inter}' not found in model '{model_name}'")

    # Save data
    data_posterior = pd.DataFrame(posterior_dict)
    path_save_posterior = os.path.join(PATH_MODEL_FIt, f'{model_name}_posterior.csv')
    data_posterior.to_csv(path_save_posterior, index=False)

# Final LOO table (also covers the no-models case, which writes an empty CSV).
loo_df = pd.DataFrame(loo_results)
loo_df.to_csv(path_save_loo, index=False)

2025-12-22 20:42:53,417 - INFO - Processing model: m2_ams
  weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
  x = np.expm1(-kappa * np.log1p(-probs)) / kappa
  return umr_sum(a, axis, dtype, out, keepdims, initial, where)
2025-12-22 20:50:38,374 - INFO - Processing model: m2_pams
  weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
  return umr_sum(a, axis, dtype, out, keepdims, initial, where)
  x = np.expm1(-kappa * np.log1p(-probs)) / kappa
  b_ary /= prior_bs * ary[int(n / 4 + 0.5) - 1]
  len_scale = n * (np.log(-(b_ary / k_ary)) - k_ary - 1)
  sigma = -k_post / b_post
2025-12-22 20:58:39,717 - INFO - Processing model: m2_slps
  weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
  x = np.expm1(-kappa * np.log1p(-probs)) / kappa
  return umr_sum(a, axis, dtype, out, keepdims, initial, where)
2025-12-22 21:06:03,717 - INFO - Processing model: m2_ams_bin
  weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
  return umr_sum(a, 