In [None]:
import sys
sys.path.append('..')

import numpy as np
import pandas as pd
from datetime import datetime
import matplotlib as mpl 

from data.binary import BinaryCauseData

from algorithm.permutation_test import PermutationConfounderTest

from experiment import sample_efficiency, save_results

# Experiment: vary the number of samples and environments using the permutation-based test

In [None]:
# Set to True to skip running the experiment and reload saved CSV results instead.
load = False

# --- Data generator and test(s) under evaluation ---------------------------
SimulateClass = BinaryCauseData
TestClass_list = [PermutationConfounderTest]

# --- Experiment grid -------------------------------------------------------
nbr_env = [25, 100, 200]                            # numbers of environments K
nbr_samples = [25, 50, 100, 200, 400, 800, 1600]    # sample sizes to sweep
repetitions = 25                                    # presumably runs per (K, n) setting
sign_level = 0.05                                   # significance level passed to the experiment

# --- Fixed data-generating parameters --------------------------------------
# NOTE(review): only confounder strength 1 is simulated; a plotting cell further
# down filters on confounder_strength == 0.0, which these results will not contain.
conf_strength = [1]
dist_param = {
    'X': {'a': 0.0, 'b': 1},
    'Y': {'a': 0.0, 'b': 1},
    'T': {'a': 0.0, 'b': 1},
}
# NOTE(review): no random seed is set anywhere in this notebook, so results
# will differ between runs.

# Month-day-hour-minute stamp used to tag saved results and figures.
now = datetime.now()
timestamp = format(now, "%m%d%H%M")
print('Timestamp:', timestamp)


In [None]:
# Run the sweep (unless reloading) and persist one result table per algorithm.
if not load:
    experiment_results = sample_efficiency(
        dist_param,
        nbr_env,
        nbr_samples,
        conf_strength,
        SimulateClass,
        TestClass_list,
        repetitions=repetitions,
        sign_level=sign_level,
    )

    # Keys are presumably algorithm names; they become part of the saved file name.
    for alg_name, alg_result in experiment_results.items():
        save_results(alg_result, f'exp_res_{alg_name}', timestamp)


# Plot results

In [None]:
import matplotlib.pyplot as plt
from plot_tools import plot_curve, set_mpl_default_settings

# Apply the plotting defaults from plot_tools before any figure is created
# (presumably matplotlib rcParams such as fonts/sizes — see plot_tools).
set_mpl_default_settings()

In [None]:
if load:
    # Reload results saved by an earlier run instead of re-running the sweep.
    timestamp_str = "05171530"
    alg_list = ['PermutationConfounderTest']
    # Keep the timestamp as a string: int("05171530") == 5171530 drops the
    # leading zero, so figure names built from `timestamp` below would no
    # longer match the stamp of the results they were made from.
    timestamp = timestamp_str

    experiment_results = {}
    for alg in alg_list:
        path = f'results/exp_res_{alg}_{timestamp_str}.csv'
        # Order rows by sample size (presumably so curves are drawn left to right).
        experiment_results[alg] = pd.read_csv(path).sort_values('nbr_samples')

def filter_experiment(exp_dict: dict, conf_strength: float, atol: float = 1e-3) -> dict:
    """Keep only the result rows obtained at a given confounder strength.

    Parameters
    ----------
    exp_dict : dict
        Maps algorithm name -> results DataFrame with a ``confounder_strength``
        column.
    conf_strength : float
        Confounder strength to select.
    atol : float, optional
        Absolute tolerance of the comparison; rows with
        ``|confounder_strength - conf_strength| < atol`` are kept.
        Defaults to 1e-3 (the value previously hard-coded).

    Returns
    -------
    dict
        Same keys as ``exp_dict``, each mapped to the filtered DataFrame.
        The input frames are not modified.
    """
    # Tolerance instead of ``==``: strengths may be floats (e.g. read back from CSV).
    return {
        alg: df[np.abs(df['confounder_strength'] - conf_strength) < atol]
        for alg, df in exp_dict.items()
    }


## Probability of detection vs. number of samples, for fixed numbers of environments (power when confounding is present, Type 1 error when it is absent)

In [None]:
# Probability of detection vs. number of samples with confounding present
# (confounder_strength = 1), one curve per number of environments K.
# Assuming the test's null hypothesis is "no confounding", detecting the
# confounder here is correct, so this curve is the power (1 - Type 2 error).
confounder_strength = 1

# Numbers of environments to draw a curve for.
fixed_nbr_env = [25, 100, 200]

# The filter does not depend on K, so do it once instead of once per curve.
res_confounded = filter_experiment(experiment_results, confounder_strength)

# Start a fresh figure so curves never land on a figure left open by an earlier cell.
plt.figure()
for e in fixed_nbr_env:
    plot_curve(res_confounded,
               'nbr_samples',
               e,
               label=f'K={e}',
               iter=25)

plt.ylabel('Probability of detection')
plt.legend()

path = f'results/figures/sample_gamma{confounder_strength}_{timestamp}.pdf'
plt.savefig(path, format='pdf', bbox_inches='tight')

In [None]:
# Probability of detection vs. number of samples WITHOUT confounding
# (confounder_strength = 0). Assuming the test's null hypothesis is "no
# confounding", every detection here is a false positive, i.e. this curve is
# the Type 1 error rate (not the Type 2 error).
confounder_strength = 0.0
res_unconfounded = filter_experiment(experiment_results, confounder_strength)

if all(df.empty for df in res_unconfounded.values()):
    # The config cell currently simulates only conf_strength = [1]; without
    # results at strength 0 this plot (and the saved PDF) would be empty.
    print(f'No results with confounder_strength={confounder_strength}: '
          f'add it to conf_strength and re-run the experiment.')
else:
    # Fresh figure so these curves are not drawn onto the previous plot.
    plt.figure()
    # One curve per number of environments, as in the plot above: plot_curve
    # is called with a single K (and iter=25) rather than the whole list.
    for e in fixed_nbr_env:
        plot_curve(res_unconfounded, 'nbr_samples', e, label=f'K={e}', iter=25)
    plt.ylabel('Probability of detection')
    plt.title(f'nbr_env={fixed_nbr_env}, conf_strength={confounder_strength}')
    plt.legend()

    # File name kept unchanged so existing references to it keep working.
    path = f'results/figures/sample_type2_{timestamp}.pdf'
    plt.savefig(path, format='pdf', bbox_inches='tight')