# Role Moderation of SDT → Intervention-Specific Acceptance (China Sample)

Goal of H3

Test whether the association between self-determination (SDT; TENS_Life_mean_imputed) and intervention-specific acceptance:

- Accept_avatar_imputed (AI avatar / generic AI therapist)
- Accept_chatbot_imputed (AI chatbot)
- Accept_tele_imputed (teletherapy / human therapist)

is moderated by clinical role (role_label: client vs therapist) in the Chinese sample.

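Note: Because the USA sample has role_label = "unknown" for all cases, a joint SDT × Country × Role model is not identified. Cross-country differences are instead handled via Country main effects in H1/H2. H3 focuses on role moderation within China, where both clients and therapists are observed.

NOTE(review): the analysis cells below (sections 1.0–4.4) currently fit SDT × Country moderation on the pooled China + USA sample rather than the within-China role moderation described above. Reconcile the description and the code before reporting H3.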

# 0.0 Paths and Data Loading

In [None]:
from __future__ import annotations

import warnings
from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import statsmodels.api as sm
import statsmodels.formula.api as smf
from statsmodels.stats.anova import anova_lm
from statsmodels.stats.outliers_influence import variance_inflation_factor

warnings.filterwarnings("ignore", category=FutureWarning)

# Global plotting style for every figure in this notebook.
sns.set(style="whitegrid")
plt.rcParams["figure.figsize"] = (8, 5)
plt.rcParams["axes.titlesize"] = 13
plt.rcParams["axes.labelsize"] = 12
plt.rcParams["font.size"] = 11

# Paths are resolved from the current working directory, so the notebook
# must be started from the project root for these to point at real files.
PROJECT_ROOT = Path.cwd().resolve()
DATA_DIR = PROJECT_ROOT / "data"
OUTPUT_DIR = DATA_DIR / "output"

PROCESSED_PATH = OUTPUT_DIR / "processed_for_analysis.csv"

# Load the processed analysis dataset; the H3 cells below all read `processed`.
# Fail early with the resolved path rather than a later NameError/KeyError.
if not PROCESSED_PATH.is_file():
    raise FileNotFoundError(f"Processed data not found at {PROCESSED_PATH}")
processed = pd.read_csv(PROCESSED_PATH)
print(
    f"Loaded {PROCESSED_PATH.name}: "
    f"{processed.shape[0]} rows x {processed.shape[1]} columns"
)

# 1.0. H3 - Country moderation of SDT → intervention-specific acceptance (3 technologies)

In [None]:
h3_vars = [
    # Outcomes
    "Accept_avatar_imputed",
    "Accept_chatbot_imputed",
    "Accept_tele_imputed",
    # SDT predictor
    "TENS_Life_mean_imputed",
    # Confounders (covariates adjusted for in every H3 model)
    "GAAIS_mean_imputed",
    "ET_mean_imputed",
    "PHQ5_mean_imputed",
    "SSRPH_mean_imputed",
    "age_imputed",
    "gender",
    # Country moderator
    "Country",
]

missing_h3 = [c for c in h3_vars if c not in processed.columns]
print("Missing H3 variables:", missing_h3)
# Stop here with an explicit message; otherwise the selection below fails
# with a generic pandas KeyError that does not name the H3 context.
if missing_h3:
    raise KeyError(f"processed is missing required H3 columns: {missing_h3}")

# Work on a copy so later centering/column additions never touch `processed`.
h3_df = processed[h3_vars].copy()

In [None]:
# Restrict to China + USA
# Keep only the two countries compared in the H3 Country-moderation models.
h3_df = h3_df.loc[h3_df["Country"].isin(("China", "USA"))].copy()

In [None]:
# Drop rows missing key categorical covariates (gender, Country)
# Listwise-drop rows lacking the categorical covariates, recording N before/after.
n_total = h3_df.shape[0]
h3_df = h3_df.dropna(axis=0, how="any", subset=["gender", "Country"])
n_analytic = h3_df.shape[0]

In [None]:
# Report the analytic sample size for H3.
for report_line in (
    "H3 analytic sample (China + USA):",
    f"N total (China & USA before drop): {n_total}",
    f"N with non-missing gender & Country: {n_analytic}",
):
    print(report_line)

In [None]:
# Country counts (NaN included, although none should remain after the drop).
print("Country distribution:", h3_df["Country"].value_counts(dropna=False), sep="\n")

In [None]:
# Gender counts (NaN included, although none should remain after the drop).
print("Gender distribution:", h3_df["gender"].value_counts(dropna=False), sep="\n")

# 2.0 Descriptive & Correlation for H3 Variables

In [None]:
continuous_h3 = (
    # SDT predictor followed by the continuous covariates
    [
        "TENS_Life_mean_imputed",
        "GAAIS_mean_imputed",
        "ET_mean_imputed",
        "PHQ5_mean_imputed",
        "SSRPH_mean_imputed",
        "age_imputed",
    ]
    # Intervention-specific acceptance outcomes
    + ["Accept_avatar_imputed", "Accept_chatbot_imputed", "Accept_tele_imputed"]
)

In [None]:
# One row per variable: count, mean, std, min, quartiles, max.
print("Descriptive statistics (H3 continuous variables, China + USA):")
display(h3_df.loc[:, continuous_h3].describe().transpose())

In [None]:
# Pairwise Pearson correlations among the continuous H3 variables.
print("Correlation matrix (H3 – SDT, GAAIS, ET, PHQ, SSRPH, age, outcomes):")
corr_h3 = h3_df[continuous_h3].corr(method="pearson")
display(corr_h3.round(decimals=3))

# 3.0. Center Continuous Predictors in H3 Sample

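We mean-center SDT (TENS), GAAIS, ET, PHQ, SSRPH, and age within the H3 sample. Each coefficient is then interpretable at the sample mean of the other predictors, and the SDT × Country product term is less collinear with its main effects.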

In [None]:
# Continuous predictors to mean-center: SDT plus the continuous covariates.
center_cols_h3 = (
    "TENS_Life_mean_imputed GAAIS_mean_imputed ET_mean_imputed "
    "PHQ5_mean_imputed SSRPH_mean_imputed age_imputed"
).split()

In [None]:
# Add a "<name>_c" column holding each predictor minus its H3-sample mean.
for var in center_cols_h3:
    var_mean = h3_df[var].mean()
    print(f"{var} mean for centering (H3 sample): {var_mean:.3f}")
    h3_df[var + "_c"] = h3_df[var].sub(var_mean)

In [None]:
# Sanity check: every centered column should average to (numerically) zero.
print("Means of centered variables (should be ≈ 0):")
display(h3_df[list(map("{}_c".format, center_cols_h3))].mean())

# 4.0. Helper Function: Baseline vs Country-Moderation Models

### Baseline H3 model (main effects):
Outcome ~ TENS_c + GAAIS_c + ET_c + PHQ_c + SSRPH_c + age_c + C(gender) + C(Country)

### Full H3 model (Country moderation):
Outcome ~ TENS_c * C(Country) + GAAIS_c + ET_c + PHQ_c + SSRPH_c + age_c + C(gender)

The full model adds only the SDT × Country term to the baseline. The two models are therefore nested and can be compared with an incremental F-test (ANOVA).

In [None]:
def fit_country_moderation_h3(outcome: str, data: pd.DataFrame):
    """
    Fit the H3 baseline and SDT × Country moderation models for one outcome.

    Both models are OLS fits on the same complete-case subset of ``data``:

    - Baseline: SDT + confounders + Country (main effects)
    - Full: SDT * Country + same confounders

    The full model adds only the SDT × Country interaction to the baseline,
    so the two models are nested and are compared with ``anova_lm``.

    Parameters
    ----------
    outcome : str
        Name of the outcome column in ``data``.
    data : pd.DataFrame
        Must contain ``outcome``, the centered ``*_c`` predictors,
        ``gender`` and ``Country``.

    Returns
    -------
    tuple
        ``(sub_df, baseline_fit, country_fit)``, where ``sub_df`` is the
        complete-case frame both models were fitted on. Returns
        ``(None, None, None)`` when there are no complete cases, or when
        fewer than two countries remain (the interaction is then not
        estimable).
    """
    # Covariates shared verbatim by both models. Building both formulas from
    # this single list guarantees they differ only by the interaction term,
    # which the nested-model comparison below relies on.
    covariate_terms = [
        "GAAIS_mean_imputed_c",
        "ET_mean_imputed_c",
        "PHQ5_mean_imputed_c",
        "SSRPH_mean_imputed_c",
        "age_imputed_c",
        "C(gender)",
    ]
    covariates_rhs = " + ".join(covariate_terms)

    cols = [
        outcome,
        "TENS_Life_mean_imputed_c",
        "GAAIS_mean_imputed_c",
        "ET_mean_imputed_c",
        "PHQ5_mean_imputed_c",
        "SSRPH_mean_imputed_c",
        "age_imputed_c",
        "gender",
        "Country",
    ]

    # Listwise deletion on exactly the modelled columns, so both fits use
    # identical rows.
    sub_df = data[cols].dropna().copy()
    if sub_df.empty:
        print(f"\n{outcome}: no complete cases for H3.")
        return None, None, None

    # With a single country there is no second SDT slope to contrast, so the
    # SDT × Country term cannot be estimated.
    n_countries = sub_df["Country"].nunique()
    if n_countries < 2:
        print(
            f"\n{outcome}: only {n_countries} country in complete cases; "
            "SDT × Country interaction is not estimable."
        )
        return None, None, None

    print(f"=== H3 Country moderation for {outcome} (China + USA; N={len(sub_df)}) ===")

    # Baseline: SDT + confounders + Country main effect
    baseline_formula = (
        f"{outcome} ~ TENS_Life_mean_imputed_c + {covariates_rhs} + C(Country)"
    )
    h3_baseline = smf.ols(formula=baseline_formula, data=sub_df).fit()
    print("Baseline model (main effects only):")
    display(h3_baseline.summary().tables[1])
    print(f"R² (baseline) = {h3_baseline.rsquared:.3f}")

    # Full: `a * b` expands to a + b + a:b, i.e. the baseline plus SDT × Country
    full_formula = (
        f"{outcome} ~ TENS_Life_mean_imputed_c * C(Country) + {covariates_rhs}"
    )
    h3_country_model = smf.ols(formula=full_formula, data=sub_df).fit()
    print("Country-moderation model (SDT × Country):")
    display(h3_country_model.summary().tables[1])
    print(f"R² (country-moderation) = {h3_country_model.rsquared:.3f}")

    # Model comparison (confounders identical by construction)
    print("Model comparison (Baseline vs Country-moderation):")
    comp = anova_lm(h3_baseline, h3_country_model)
    display(comp)

    return sub_df, h3_baseline, h3_country_model

## 4.1. Fit H3 Models for All 3 Technologies

In [None]:
# One acceptance outcome per technology: AI avatar, AI chatbot, teletherapy.
h3_outcomes = [f"Accept_{tech}_imputed" for tech in ("avatar", "chatbot", "tele")]

# outcome -> {"data": complete-case frame, "baseline": fit, "country_model": fit}
h3_models: Dict[str, Dict[str, object]] = dict()

In [None]:
# Fit baseline + moderation models for each outcome and store all three
# returned objects (any of which may be None) under the outcome name.
for outcome in h3_outcomes:
    fitted = fit_country_moderation_h3(outcome, h3_df)
    h3_models[outcome] = dict(zip(("data", "baseline", "country_model"), fitted))

## 4.2. Summary Table of SDT × Country Effects

We extract the difference in SDT slope for USA vs China for each outcome.

In [None]:
h3_summary_rows = []

for outcome in h3_outcomes:
    country_model = h3_models[outcome]["country_model"]
    if country_model is None:
        continue

    # Patsy name of the SDT slope difference for USA relative to the China
    # reference level.
    term_name = "TENS_Life_mean_imputed_c:C(Country)[T.USA]"
    if term_name not in country_model.params.index:
        print(f"{outcome}: interaction term not found in model.")
        continue

    ci_low, ci_high = country_model.conf_int().loc[term_name]
    h3_summary_rows.append(
        {
            "Outcome": outcome,
            "Interaction_term": term_name,
            "beta_SDTxCountry(USA_vs_China)": country_model.params[term_name],
            "SE": country_model.bse[term_name],
            "p": country_model.pvalues[term_name],
            "CI_low": ci_low,
            "CI_high": ci_high,
            "R2_country_model": country_model.rsquared,
        }
    )

In [None]:
# Tabulate the per-outcome interaction estimates collected above.
h3_summary_df = pd.DataFrame(data=h3_summary_rows)
print("H3: SDT × Country interaction summary (China + USA):")
display(h3_summary_df)

## 4.3. VIF Check for One Focal H3 Model

In [None]:
focal_outcome_h3 = "Accept_chatbot_imputed"
focal_country_model = h3_models[focal_outcome_h3]["country_model"]

if focal_country_model is not None:
    # VIFs are computed on the full design matrix (the intercept column is
    # kept in the matrix) but only reported for non-intercept predictors.
    design = focal_country_model.model.exog
    vif_h3_df = pd.DataFrame(
        [
            {"Predictor": term, "VIF": variance_inflation_factor(design, idx)}
            for idx, term in enumerate(focal_country_model.model.exog_names)
            if term != "Intercept"
        ]
    ).sort_values("VIF", ascending=False)

    print(f"Variance Inflation Factors – H3 Country model for {focal_outcome_h3}:")
    display(vif_h3_df)

## 4.4. Predicted SDT → Acceptance Slopes by Country (AI Chatbot Outcome)

In [None]:
if focal_country_model is not None:
    focal_df = h3_models[focal_outcome_h3]["data"].copy()

    # SDT grid spanning the 5th–95th percentile of centered SDT in the fitted
    # data, so predictions stay inside the well-populated range.
    tens_min = focal_df["TENS_Life_mean_imputed_c"].quantile(0.05)
    tens_max = focal_df["TENS_Life_mean_imputed_c"].quantile(0.95)
    tens_grid = np.linspace(tens_min, tens_max, 50)

    # Typical covariate profile: 0 on every centered covariate (the sample
    # mean) and the modal gender.
    covariate_profile = {
        "GAAIS_mean_imputed_c": 0.0,
        "ET_mean_imputed_c": 0.0,
        "PHQ5_mean_imputed_c": 0.0,
        "SSRPH_mean_imputed_c": 0.0,
        "age_imputed_c": 0.0,
        "gender": focal_df["gender"].mode()[0],
    }

    # Predict only for countries present in the fitted data: a level the
    # model was not fitted on cannot be encoded by the formula at predict time.
    country_levels = sorted(focal_df["Country"].unique())

    pred_df = pd.DataFrame(
        [
            {"TENS_Life_mean_imputed_c": t_val, **covariate_profile, "Country": country}
            for country in country_levels
            for t_val in tens_grid
        ]
    )
    pred_df["pred_accept"] = focal_country_model.predict(pred_df)

    # Back-transform SDT to raw scale for a readable x-axis, using the h3_df
    # mean (the same frame whose mean was subtracted to create the _c column).
    tens_raw_mean = h3_df["TENS_Life_mean_imputed"].mean()
    pred_df["TENS_Life_raw"] = (
        pred_df["TENS_Life_mean_imputed_c"] + tens_raw_mean
    )

    plt.figure(figsize=(8, 6))
    sns.lineplot(
        data=pred_df,
        x="TENS_Life_raw",
        y="pred_accept",
        hue="Country"
    )
    plt.xlabel("Self-Determination (TENS_Life_mean, raw scale)")
    plt.ylabel(f"Predicted {focal_outcome_h3}")
    plt.title(
        f"H3: Predicted {focal_outcome_h3} across SDT\n"
        "for China vs USA (SDT × Country)"
    )
    plt.tight_layout()
    plt.show()
