In [1]:
from tqdm import tqdm
import numpy as np
import pandas as pd
from multiprocess import Pool
from cfs_erf_spatial import erf_spatial
from pytorch_lightning import seed_everything

seed_everything(42);

  from .autonotebook import tqdm as notebook_tqdm
Global seed set to 42


### Load Environment

In [6]:
from spacebench import SpaceEnv, DataMaster
datamaster = DataMaster()
datasets = datamaster.master 
# find datasets that have continuous exposure
cts = datasets.index[datasets['exposure'] == 'continuous'].values
discrete = datasets.index[datasets['exposure'] == 'binary'].values


### Run spatial and spatial+ for each cts dataset environment, and dataset in parallel

Results are being saved in continuous.csv.

In [7]:
# for each environment
for envname in cts:
    env = SpaceEnv(str(envname))
    # for each masked variable
    dataset_list = list(env.make_all())
    args = list(zip(dataset_list, range(len(dataset_list))))
    with Pool(4) as p: # 4 is the number of processes
        pool_outputs = np.column_stack(
        (
        tqdm(
                p.imap(lambda m: erf_spatial(*m, envname = envname, filename = 'continuous.csv'), # does not preserve order of datasets
                            args),
                        # should write to continuous.csv a vector of envname, dataset id, confounding, smoothness, 
                        # erf_error_spatial, erf_error_spatialplus, pehe_spatial, pehe_spatialplus
                        total=len(dataset_list)
            )
        )
        )

  pool_outputs = np.column_stack(
100%|██████████| 12/12 [00:50<00:00,  4.20s/it]
  pool_outputs = np.column_stack(
 40%|████      | 4/10 [00:26<00:39,  6.64s/it]


KeyboardInterrupt: 

### Analyze results by env: continuous

In [9]:
for envname in cts:
    outputs = pd.read_csv('continuous.csv', header = None)
    outputs = outputs[outputs.iloc[:, 0] == envname]

    smoothness_scores = outputs.iloc[:, 2]
    confounding_scores = outputs.iloc[:, 3] 
    erf_error_spatial = outputs.iloc[:, 4]
    erf_error_spatialplus = outputs.iloc[:, 5]
    #pehe_spatial = outputs.iloc[:, 6]
    #pehe_spatialplus = outputs.iloc[:, 7]

    erf_errors = dict(
        smoothness=["low" if x < 0.5 else "high" for x in smoothness_scores],
        confounding=["low" if x < 0.1 else "high" for x in confounding_scores],
        spatial_erf_error=erf_error_spatial,
        spatialplus_erf_error=erf_error_spatialplus,
        #spatial_pehe_avg = pehe_spatial,
        #spatialplus_pehe_avg = pehe_spatialplus
    )
    erf_errors = pd.DataFrame(erf_errors)
    display(erf_errors.groupby(["smoothness", "confounding"]).agg(["mean", "std"]))

Unnamed: 0_level_0,Unnamed: 1_level_0,spatial_erf_error,spatial_erf_error,spatialplus_erf_error,spatialplus_erf_error
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,std,mean,std
smoothness,confounding,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2
high,high,4.986489,2.942052,4.98649,2.942053
low,high,7.585207,1.848049,7.585067,1.84699
low,low,5.771517,,5.771518,


Unnamed: 0_level_0,Unnamed: 1_level_0,spatial_erf_error,spatial_erf_error,spatialplus_erf_error,spatialplus_erf_error
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,std,mean,std
smoothness,confounding,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2
high,low,0.005352,0.00117,0.005352,0.00117


Unnamed: 0_level_0,Unnamed: 1_level_0,spatial_erf_error,spatial_erf_error,spatialplus_erf_error,spatialplus_erf_error
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,std,mean,std
smoothness,confounding,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2


Unnamed: 0_level_0,Unnamed: 1_level_0,spatial_erf_error,spatial_erf_error,spatialplus_erf_error,spatialplus_erf_error
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,std,mean,std
smoothness,confounding,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2


Unnamed: 0_level_0,Unnamed: 1_level_0,spatial_erf_error,spatial_erf_error,spatialplus_erf_error,spatialplus_erf_error
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,std,mean,std
smoothness,confounding,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2


Unnamed: 0_level_0,Unnamed: 1_level_0,spatial_erf_error,spatial_erf_error,spatialplus_erf_error,spatialplus_erf_error
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,std,mean,std
smoothness,confounding,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2


Unnamed: 0_level_0,Unnamed: 1_level_0,spatial_erf_error,spatial_erf_error,spatialplus_erf_error,spatialplus_erf_error
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,std,mean,std
smoothness,confounding,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2
