In [19]:
import datetime
import itertools
import os
import sys

# Make the repository root importable before the project-local imports below.
sys.path.append("../../")

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

from src.ours.method import FalsificationAlgorithm
from experiments.data.dpg import DGP

In [None]:
# --- Experiment configuration ---------------------------------------------
# Every combination of the four lists below is simulated `iterations` times.
n_envs_list = [10]
n_samples_list = [50, 100, 500]
confounding_strength_list = [0.0, 1.0]
degree_list = list(range(1, 11))  # polynomial degrees 1..10 of the fitted model
iterations = 250
bootstrap = False

# One dict per simulated run; turned into a DataFrame once at the end.
rows = []

# Timestamp identifying this run; reused for the output file names.
current_datetime = datetime.datetime.now()
date_stamp_string = current_datetime.strftime("%Y-%m-%d_%H%M%S")
print(date_stamp_string)

# Materialise the grid so tqdm knows the total number of configurations.
param_grid = list(
    itertools.product(
        n_envs_list, n_samples_list, confounding_strength_list, degree_list
    )
)

# Bootstrapping is either switched on with 1000 resamples or disabled (None).
n_bootstraps = 1000 if bootstrap else None

for n_envs, n_samples, confounding_strength, degree in tqdm(param_grid):

    for i in range(iterations):
        # The data-generating process is always built with degree=2; only the
        # polynomial degree handed to the falsification test varies.
        dgp = DGP(
            n_envs=n_envs,
            n_observed_confounders=1,
            conf_strength=confounding_strength,
            degree=2,
        )

        data = dgp.sample(n_samples=n_samples)["data"]

        algorithm = FalsificationAlgorithm(
            feature_representation="poly",
            feature_representation_params={"degree": degree},
            n_bootstraps=n_bootstraps,
        )
        out = algorithm.test(data, observed_covariates=dgp.get_covar())

        pval = out["pval"]
        outcome_model_diagnostics = out["outcome_model_diagnostics"]
        # Collapse the per-key diagnostics into a single average value.
        average_diagnostics = np.mean(
            [outcome_model_diagnostics[name] for name in outcome_model_diagnostics]
        )

        rows.append(
            dict(
                pval=pval,
                reject=pval < 0.05,  # rejection at the 5% level
                n_envs=n_envs,
                n_samples=n_samples,
                confounding_strength=confounding_strength,
                degree=degree,
                diagnostics=average_diagnostics,
            )
        )

df = pd.DataFrame(rows)

In [21]:
# Persist the results of this run under a timestamped file name. The output
# directory is created on demand so that saving does not fail when the folder
# does not exist yet.
results_dir = "results-model_specification"
os.makedirs(results_dir, exist_ok=True)
df.to_csv(f"{results_dir}/{date_stamp_string}.csv")

# Set `read_ts` to the timestamp string of an earlier run to analyse those
# stored results instead of the ones just computed.
read_ts = None
if read_ts:
    # `to_csv` above writes the DataFrame index as the first column; reading it
    # back as the index avoids a spurious "Unnamed: 0" column after the round trip.
    df = pd.read_csv(f"{results_dir}/{read_ts}.csv", index_col=0)


In [None]:
# Requires `df` (one row per simulated run) and `date_stamp_string` from the
# earlier cells.

# Aggregate to one row per configuration: rejection rate and number of runs.
grouped_df = df.groupby(['degree', 'confounding_strength', 'n_envs', 'n_samples']).agg(
    reject_mean=('reject', 'mean'),
    count=('reject', 'count')
).reset_index()

# Binomial standard error of the rejection rate: sqrt(p * (1 - p) / n).
grouped_df['reject_se'] = np.sqrt(grouped_df['reject_mean'] * (1 - grouped_df['reject_mean']) / grouped_df['count'])

# Human-readable labels for the legend.
grouped_df["confounding_strength"] = grouped_df["confounding_strength"].replace(
    {0.0: "No", 1.0: "Yes"}
)
grouped_df = grouped_df.rename(columns={'confounding_strength': 'Unmeasured confounder', 'n_samples': 'Number of samples'})
sns.set_context("talk", font_scale=0.8)  # 'talk' is larger than default, scale up if needed
palette = sns.color_palette("colorblind")

# Pass exactly one colour per hue level; handing seaborn the full palette list
# makes it warn that the list has more values than needed.
n_hue_levels = grouped_df["Unmeasured confounder"].nunique()

# Create the main plot
ax = sns.lineplot(
    data=grouped_df,
    x="degree",
    y="reject_mean",
    hue="Unmeasured confounder",
    style="Number of samples",
    errorbar=None,
    palette=palette[:n_hue_levels]
)

# Add error bars for every aggregated point. fmt="none" draws the bars only,
# so a single call over all rows is enough.
ax.errorbar(
    grouped_df["degree"],
    grouped_df["reject_mean"],
    yerr=grouped_df["reject_se"],
    fmt="none",
    capsize=3,
    color="gray",
    alpha=0.5,
)

# Draw the horizontal line at y = 0.05
ax.axhline(0.05, linestyle="--", color="black", alpha=0.6)

plt.xlabel("Degree of polynomial model")
plt.ylabel("Falsification rate")
plt.ylim([-0.05, 1.05])
plt.yticks(np.arange(0.0, 1.1, 0.1))


# Get handles and labels for custom legend arrangement
handles, labels = ax.get_legend_handles_labels()

# The combined legend is laid out as
#   [hue title, hue levels..., style title, style levels...].
# Locate the style title instead of hard-coding positions so the split stays
# correct when the number of hue levels changes; fall back to the previous
# fixed position (index 3) if the title entry is not present.
style_title = "Number of samples"
split_at = labels.index(style_title) if style_title in labels else 3

# Separate handles for each sub-legend (skipping the two title entries)
unmeasured_handles = handles[1:split_at]
unmeasured_labels = labels[1:split_at]
sample_handles = handles[split_at + 1:]
sample_labels = labels[split_at + 1:]

# Create a legend for 'Unmeasured confounder'
legend1 = ax.legend(
    handles=unmeasured_handles,
    labels=unmeasured_labels,
    title="Unmeasured confounder",
    loc="upper center",
    bbox_to_anchor=(0.2, -0.15),  # Adjust to position on the left
    ncol=max(len(unmeasured_handles), 1),
    frameon=False,
    fontsize=11
)

# Add a second legend for 'Number of samples' positioned next to the first legend.
# Creating it replaces `legend1` as the axes' legend, hence the re-add below.
legend2 = ax.legend(
    handles=sample_handles,
    labels=sample_labels,
    title="Number of samples",
    loc="upper center",
    bbox_to_anchor=(0.75, -0.15),  # Adjust to position on the right
    ncol=max(len(sample_handles), 1),
    frameon=False,
    fontsize=11
)

# Re-add the first legend to the axis
ax.add_artist(legend1)

#plt.tight_layout()
plt.savefig(f"results-model_specification/complexity-vs-reject-{date_stamp_string}.pdf", bbox_inches="tight")
