In [None]:
import pickle
from pathlib import Path

import networkx as nx
import pandas as pd
from dowhy import CausalModel

from structure_data import choose_columns, preprocess_data

In [None]:
# Locations of inputs and artifacts, all expressed relative to ROOT.
ROOT = Path(".")
RAW_DATA = ROOT / "raw_data" / "csv" / "eqls_2007and2011.csv"
DICT_PATH = ROOT / "data" / "dictionary.json"
GRAPH_PATH = ROOT / "graphs" / "full_causal.gpickle"

# Column names used as treatment and outcome in the causal model.
TREATMENT = "Y11_Q57"
OUTCOME = "Y11_MWIndex"


def _prep_ordinals_inplace(df: pd.DataFrame, ord_cols, known_orders=None):
    """
    Map ordinal levels to 0..m-1 and scale to [0,1] so L1 equals Gower per-feature.
    known_orders: optional dict {col: [lowest,...,highest]} to control ordering.
    """
    for c in ord_cols:
        if known_orders and c in known_orders:
            order = {v:i for i, v in enumerate(known_orders[c])}
            df[c] = df[c].map(order).astype(float)
            m = max(order.values()) if order else 1
            df[c] = df[c] / (m if m > 0 else 1.0)
        else:
            # infer: if numeric-coded, scale min–max; else sort unique labels
            if pd.api.types.is_numeric_dtype(df[c]):
                lo, hi = df[c].min(), df[c].max()
                rng = (hi - lo) if hi > lo else 1.0
                df[c] = (df[c] - lo) / rng
            else:
                levels = sorted(df[c].dropna().unique().tolist())
                mapping = {v:i for i, v in enumerate(levels)}
                m = max(mapping.values()) if mapping else 1
                df[c] = df[c].map(mapping).astype(float) / (m if m > 0 else 1.0)


def load_data() -> pd.DataFrame:
    """Load and preprocess the raw EQLS data, and save the result as CSV.

    The columns returned by ``choose_columns()`` are passed to
    ``preprocess_data`` together with the treatment column and the list of
    backdoor variables. The processed frame is written to
    ``ROOT / "data" / "eqls_processed.csv"`` (the directory is created if it
    does not exist) and returned.

    NOTE(review): RAW_DATA is not passed to ``choose_columns()`` here;
    presumably that helper locates the raw CSV itself -- confirm.
    """
    backdoor_variables = [
        'Y11_EmploymentStatus',
        'Y11_HHstructure',
        'Y11_HHsize',
        'Y11_Agecategory',
        'Y11_Q7',
        'Y11_Q31',
        'Y11_Country',
        'Y11_Q32',
        'Y11_HH2a',
        TREATMENT,
        OUTCOME,
    ]
    selected = choose_columns()
    processed = preprocess_data(
        selected,
        na_threshold=0.5,
        impute_strategy="drop",
        treatment_dichotomize_value="median",
        treatment_column=TREATMENT,
        backdoor_variables=backdoor_variables,
    )

    # Anchor the output at ROOT like the other project paths, and make sure
    # the target directory exists: to_csv does not create missing directories.
    out_path = ROOT / "data" / "eqls_processed.csv"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    processed.to_csv(out_path, index=False)
    return processed

def get_schema() -> dict:
    """Return the covariate schema: column names grouped by measurement type.

    Keys are ``'cat'`` (categorical columns) and ``'ord'`` (ordinal columns).
    A fresh dict with fresh lists is built on every call.
    """
    schema = {}
    schema['cat'] = ['Y11_Q32', 'Y11_Q7']
    schema['ord'] = [
        'Y11_Agecategory',
        'Y11_Country',
        'Y11_EmploymentStatus',
        'Y11_HH2a',
        'Y11_HHsize',
        'Y11_HHstructure',
        'Y11_Q31',
    ]
    return schema

def load_graph() -> nx.DiGraph:
    """Unpickle and return the causal graph stored at GRAPH_PATH.

    The returned object is whatever was pickled into that file; it is not
    validated here.

    NOTE: unpickling can execute arbitrary code, so only load graph files
    from a trusted source.
    """
    with open(GRAPH_PATH, "rb") as handle:
        causal_graph = pickle.load(handle)
    return causal_graph


def estimate_effects(df: pd.DataFrame, graph) -> dict:
    """Estimate the effect of TREATMENT on OUTCOME with several backdoor methods.

    Parameters
    ----------
    df : pd.DataFrame
        Input data; it is copied, so the caller's frame is left untouched.
    graph
        Causal graph; it is converted to a DOT string via pydot for DoWhy.

    Returns
    -------
    dict
        ``{method_name: estimate}`` with one float per method. A method that
        raises is reported on stdout and recorded as NaN.
    """
    schema = get_schema()

    # Scale the ordinal covariates on a private copy of the data.
    scaled = df.copy()
    _prep_ordinals_inplace(scaled, schema['ord'])

    graph_dot = nx.nx_pydot.to_pydot(graph).to_string()
    model = CausalModel(
        data=scaled,
        # NOTE(review): the original comment requires a binary {0,1}
        # treatment; load_data asks preprocess_data to dichotomize -- confirm.
        treatment=TREATMENT,
        outcome=OUTCOME,
        graph=graph_dot,
    )
    estimand = model.identify_effect()

    # Only distance matching needs extra arguments; the other methods use
    # the estimator defaults.
    distance_matching_args = dict(
        target_units="ate",  # or "att"/"atc"
        method_params={
            "distance_metric": "minkowski",  # with p=1 below: L1 distance
            "p": 1,
            "num_matches_per_unit": 1,  # raise for m:1 matching
            "exact_match_cols": schema['cat'],  # exact match on pure categoricals
        },
    )
    extra_args = {"backdoor.distance_matching": distance_matching_args}

    method_names = (
        "backdoor.propensity_score_matching",
        "backdoor.propensity_score_weighting",
        "backdoor.propensity_score_stratification",
        "backdoor.linear_regression",
        "backdoor.distance_matching",
    )

    results = {}
    for method_name in method_names:
        try:
            estimate = model.estimate_effect(
                estimand,
                method_name=method_name,
                **extra_args.get(method_name, {}),
            )
            value = float(estimate.value)
        except Exception as e:
            # Best effort: one failing estimator must not abort the others.
            print(f"[!] Estimation with {method_name} failed: {e}")
            value = float("nan")
        results[method_name] = value
    return results


In [None]:
# Run the pipeline end to end: preprocess the data, load the graph, estimate.
df = load_data()
graph = load_graph()
results = estimate_effects(df, graph)

# One row per estimator; a failed estimator is shown as "nan".
print("Estimation results (ATE):")
for method_name, ate in results.items():
    row = f"  {method_name:<40} {ate:.4f}"
    print(row)