In [None]:
import sys
from tqdm import tqdm
import numpy as np
import itertools
from typing import List

# sys.path.append('..')
sys.path.append("../src")

from causal_falsify.utils.simulate_data import simulate_data
from causal_falsify.transport import TransportabilityTest
from causal_falsify.mint import MINT
from causal_falsify.hgic import HGIC

In [None]:
# Preview one small simulated dataset (conf_strength=0.0 and
# transportability_violation=0.0) to eyeball the generated columns before
# running the full test grids below.
data = simulate_data(
    100,
    seed=42,
    n_envs=50,
    n_observed_confounders=5,
    degree=1,
    conf_strength=0.0,
    transportability_violation=0.0,
)

data.head()

In [None]:
def test_method_grid(
    method,
    iterations: int,
    seed: int,
    alpha: float = 0.05,
    n_samples_list=[100],
    degree_list=[1, 3],
    n_envs_list=[100],
    transportability_violation_list=[0.0, 1.0],
    n_observed_confounders_list=[2],
):

    config_grid = list(
        itertools.product(
            n_samples_list,
            degree_list,
            n_envs_list,
            transportability_violation_list,
            n_observed_confounders_list,
        )
    )

    for (
        n_samples,
        degree,
        n_envs,
        transportability_violation,
        n_observed_confounders,
    ) in config_grid:

        print("\n--- Testing Configuration ---")
        print(
            f"n_samples: {n_samples}, degree: {degree}, n_envs: {n_envs}, "
            f"transportability_violation: {transportability_violation}, "
            f"n_observed_confounders: {n_observed_confounders}"
        )

        def run_test(conf_strength):
            rejections = []
            for _ in tqdm(
                range(iterations), desc=f"conf_strength={conf_strength}", leave=False
            ):
                data = simulate_data(
                    n_samples=n_samples,
                    degree=degree,
                    conf_strength=conf_strength,
                    transportability_violation=transportability_violation,
                    n_envs=n_envs,
                    n_observed_confounders=n_observed_confounders,
                    seed=seed,
                )

                covariates = [f"X_{i}" for i in range(n_observed_confounders)]

                result = method.test(
                    data,
                    covariate_vars=covariates,
                    treatment_var="A",
                    outcome_var="Y",
                    source_var="S",
                )
                rejections.append(result < alpha)
            return rejections

        rejections_null_true = run_test(conf_strength=0.0)
        rejections_null_false = run_test(conf_strength=1.0)

        type_1_error = np.mean(rejections_null_true)
        type_2_error = 1 - np.mean(rejections_null_false)

        print(f"\nResults for current configuration:")
        print(f"  Type 1 error:  {type_1_error:.4f} (should be < {alpha})")
        print(f"  Type 2 error:  {type_2_error:.4f} (should be < 0.2)")

        if type_1_error < alpha:
            print(f"  ✅ PASS Type 1 Error check")
        else:
            print(f"  ❌ FAIL Type 1 Error check")

        if type_2_error < 0.2:
            print(f"  ✅ PASS Type 2 Error check")
        else:
            print(f"  ❌ FAIL Type 2 Error check")

In [None]:
# HGIC with the "fisherz" conditional-independence test. max_tests=-1 is
# passed through as-is (presumably "no cap on the number of tests"; confirm
# in the HGIC docs). The trailing ";" keeps the cell from echoing a return value.
hgic_method = HGIC(cond_indep_test="fisherz", max_tests=-1)
test_method_grid(method=hgic_method, seed=42, iterations=100);

In [None]:
# Same HGIC grid, now with the "kcit_rbf" conditional-independence test
# (presumably a kernel CI test with an RBF kernel, so expect a slower run).
hgic_method = HGIC(cond_indep_test="kcit_rbf", max_tests=-1)
test_method_grid(method=hgic_method, seed=42, iterations=100);

In [None]:
# MINT using a linear feature representation.
mint_method = MINT(feature_representation="linear")
test_method_grid(method=mint_method, seed=42, iterations=100);

In [None]:
# MINT using polynomial features. Degree 3 matches the largest `degree`
# in test_method_grid's default degree grid.
poly_params = {"degree": 3}
mint_method = MINT(
    feature_representation="poly",
    feature_representation_params=poly_params,
)
test_method_grid(method=mint_method, seed=42, iterations=100);

In [None]:
# Transportability test with the "fisherz" conditional-independence test.
transportability_method = TransportabilityTest(
    cond_indep_test="fisherz",
)
test_method_grid(method=transportability_method, seed=42, iterations=100);

In [None]:
# Transportability test with the "kcit_rbf" CI test. max_sample_size=250
# presumably caps the rows fed to the (expensive) kernel test; confirm in
# the TransportabilityTest docs.
transportability_method = TransportabilityTest(
    max_sample_size=250,
    cond_indep_test="kcit_rbf",
)
test_method_grid(method=transportability_method, seed=42, iterations=100);