In [None]:
import sys
sys.path.append('..')
from datetime import datetime 
from tqdm import tqdm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


from data.linear.binary import BinaryLinearData

from algorithm.permutation_based import PermutationBasedTest

from experiment.utils import run
from experiment.plot import plot_vary_sample_env, set_mpl_default_settings

# Apply the shared plot styling from experiment.plot (presumably matplotlib rcParams defaults) before any figure is created.
set_mpl_default_settings()

# Experiment: vary the number of samples and environments using the permutation-based procedure

In [None]:
# True: read previously saved CSV results; False: rerun the experiment from scratch.
load = True

# Data generator and the tests to evaluate
simulate_class = BinaryLinearData
test_method_list = [PermutationBasedTest()]

# Experiment grid
nbr_env = [25, 100, 200]            # numbers of environments (the K in the plot legend)
nbr_samples = [2, 10, 25, 50, 100]  # numbers of samples swept on the x-axis
repetitions = 50                    # repetitions per grid point
sign_level = 0.05                   # presumably the test's significance level

# Fixed dataset: one confounding strength, identical parameters for X, Y and T
conf_strength = [5]
dist_param = {variable: {'a': 0.0, 'b': 1} for variable in ('X', 'Y', 'T')}

# MMDDHHMM stamp used to tag the result files written by this run
now = datetime.now()
timestamp = now.strftime("%m%d%H%M")
print('Timestamp:', timestamp)


In [None]:
if not load:
    # Run every configured algorithm; results are keyed by the algorithm's class name.
    experiment_results = {}
    for alg in tqdm(test_method_list):

        args = [dist_param, nbr_env, nbr_samples, conf_strength, simulate_class, alg, repetitions, sign_level]
        alg_name = type(alg).__name__

        # Presumably `run` writes results to this CSV as it goes; the load branch
        # below reads files with the same naming pattern.
        res = run(args, save_during_run=f'results/vary_sample_env_{alg_name}_{timestamp}.csv')
        # Sort by sample size so both branches hand the plotting code identically ordered frames.
        experiment_results[alg_name] = pd.concat(res).sort_values('nbr_samples')

else:

    # Load results of a previous run, identified by its MMDDHHMM timestamp.
    timestamp_str = "11151107"
    # Use a separate name so the configured `test_method_list` (algorithm objects)
    # is not overwritten with strings; rerunning with load=False then still works.
    load_alg_names = ['PermutationBasedTest']
    # Keep the timestamp as a string: int() would drop the leading zero for
    # months January-September and corrupt figure filenames derived from it.
    timestamp = timestamp_str

    experiment_results = {}
    for alg_name in load_alg_names:
        path = f'results/vary_sample_env_{alg_name}_{timestamp_str}.csv'
        experiment_results[alg_name] = pd.read_csv(path).sort_values('nbr_samples')

# Plot results

In [None]:
def filter_experiment(exp_dict : dict, conf_strength : float, atol : float = 1e-3) -> dict:
    '''
    Select experimental results with a specific confounding strength.

    Parameters
    ----------
    exp_dict : dict
        Mapping from algorithm name to a DataFrame of results. Each DataFrame
        must have a ``confounder_strength`` column.
    conf_strength : float
        Confounding strength to keep.
    atol : float, optional
        Absolute tolerance for the float comparison (default 1e-3), so values
        that went through float arithmetic or a CSV round trip still match.

    Returns
    -------
    dict
        A new dict with the same keys. Each value holds only the rows whose
        ``confounder_strength`` lies strictly within ``atol`` of ``conf_strength``.
        The input DataFrames are not modified.
    '''
    return {
        alg: df[np.abs(df.confounder_strength - conf_strength) < atol]
        for alg, df in exp_dict.items()
    }

def plot_experiment(confounder_strength : float):
    '''
    Plot detection rate against number of samples for one confounding strength,
    with one curve per fixed number of environments, and save it as a PDF.

    Parameters
    ----------
    confounder_strength : float
        Confounding strength whose results are plotted.

    Notes
    -----
    Reads the notebook globals ``experiment_results``, ``nbr_env``,
    ``repetitions`` and ``timestamp``. The y-label, legend and saved file use
    pyplot's current figure. The PDF goes to ``results/figures/``, which is
    created if missing.
    '''
    import os

    # The filter does not depend on the number of environments, so apply it once.
    filtered_results = filter_experiment(experiment_results, confounder_strength)

    for e in nbr_env:
        plot_vary_sample_env(filtered_results, x_axis='nbr_samples', fix_val=e, label=f'K={e}', iter=repetitions)

    plt.ylabel('Detection rate')
    plt.legend()

    path = f'results/figures/vary_sample_env_cs{confounder_strength}_{timestamp}.pdf'
    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    plt.savefig(path, format='pdf', bbox_inches='tight')

In [None]:
# Draw and save one fresh figure per configured confounding strength.
for strength in conf_strength:
    plt.figure()
    plot_experiment(strength)