In [6]:
import warnings

import numpy as np
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import LogisticRegression, ARDRegression
from sklearn.utils._testing import ignore_warnings

from nudging.cate import get_cate
from nudging.dataset import Pennycook1, Pennycook2, Balaban, Hotard, Lieberoth, Vandenbroele
from nudging.evaluate_outcome import evaluate_outcome
from nudging.model import BiRegressor, MonoRegressor, XRegressor, ProbModel


In [7]:
# Folder containing the raw study files; every dataset loader reads from here.
DATA_DIR = "data"

# Short dataset label -> loaded dataset. Insertion order determines the order in
# which results are printed by the evaluation loop.
all_data = {
    label: dataset_cls.from_file(DATA_DIR)
    for label, dataset_cls in (
        ("penny1", Pennycook1),
        ("penny2", Pennycook2),
        ("balaban", Balaban),
        ("hotard", Hotard),
        ("lieberoth", Lieberoth),
        ("broele", Vandenbroele),
    )
}

# Model label -> model instance. The three regressors wrap an ARDRegression;
# ProbModel wraps a LogisticRegression.
# NOTE(review): each instance is reused for every dataset in the evaluation loop;
# this assumes get_cate refits the model from scratch each time — confirm.
all_models = {
    "biregressor": BiRegressor(ARDRegression()),
    "monoregressor": MonoRegressor(ARDRegression()),
    "xregressor": XRegressor(ARDRegression()),
    "prob_model": ProbModel(LogisticRegression()),
}

In [3]:
def get_results(model, dataset, model_name, data_name=None, n=10):
    """Estimate CATEs for one model/dataset pair and print its mean evaluation score.

    Parameters
    ----------
    model :
        A nudging model (e.g. ``BiRegressor``); passed to ``get_cate`` and
        ``evaluate_outcome``.
    dataset :
        A loaded nudging dataset.
    model_name : str
        Label printed in the first column of the output line.
    data_name : str, optional
        Label printed in the second column. If omitted, falls back to the
        notebook-global ``data_name`` (what the previous version silently did by
        reading the driver loop's variable). Prefer passing it explicitly.
    n : int, default 10
        Forwarded unchanged as ``n`` to ``evaluate_outcome``.

    Returns
    -------
    numpy float
        Mean of the values returned by ``evaluate_outcome`` (also printed).

    Raises
    ------
    TypeError
        If ``data_name`` is not given and no global ``data_name`` exists.
    """
    if data_name is None:
        # Backward compatibility with callers that do not pass a label: the old
        # implementation read this name from the notebook's global namespace.
        try:
            data_name = globals()["data_name"]
        except KeyError:
            raise TypeError("get_results() requires a data_name argument") from None

    # Scope the ConvergenceWarning filter to the fitting/evaluation calls using the
    # public ``warnings`` API rather than sklearn's private ``_testing`` helper.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=ConvergenceWarning)
        get_cate(model, dataset)
        score = np.mean(evaluate_outcome(model, dataset, n=n))

    print(model_name, data_name, score)
    return score

    

In [4]:
# Evaluate every (model, dataset) pair. Models form the outer loop and datasets
# the inner one, so the printed lines come out grouped by model.
# NOTE: keep the loop variable named `data_name`; get_results can pick it up from
# the notebook's globals to label its output line.
for model_name, model in all_models.items():
    for data_name, dataset in all_data.items():
        get_results(
            model=model,
            dataset=dataset,
            model_name=model_name,
        )


biregressor penny1 0.5559482479849938
biregressor penny2 0.23035300460356256
biregressor balaban -0.039560665976679316
biregressor hotard 0.09943216251015835
biregressor lieberoth 0.12067393869959686
biregressor broele 0.3848517602891001
monoregressor penny1 0.5741932314539356
monoregressor penny2 0.27444504142539916
monoregressor balaban 0.019614969047176564
monoregressor hotard 0.0681980084738301
monoregressor lieberoth -0.0777965298220599
monoregressor broele 0.1247735837724014
xregressor penny1 0.553433179367072
xregressor penny2 0.2312932277188758
xregressor balaban -0.06413450015515106
xregressor hotard 0.10240547887171632
xregressor lieberoth 0.11259671158460477
xregressor broele 0.2779365702432528
prob_model penny1 0.242410834580125
prob_model penny2 0.18148394667322865
prob_model balaban 0.15290530448624856
prob_model hotard 0.024887315795875928
prob_model lieberoth -0.1091747675785817
prob_model broele -0.05232876663382273
