In [7]:
from nudging.dataset import Pennycook1, Pennycook2, Balaban, Hotard, Lieberoth, Vandenbroele
from nudging.model import BiRegressor, MonoRegressor, XRegressor, ProbModel
from sklearn.linear_model import BayesianRidge, LogisticRegression
from nudging.cate import get_cate
from nudging.evaluate_outcome import evaluate_outcome
import numpy as np
from sklearn.utils._testing import ignore_warnings
from sklearn.exceptions import ConvergenceWarning


In [8]:
# Directory the dataset files are read from (relative to the notebook's
# working directory). Change it here instead of in every loader call.
DATA_DIR = "data"

# Short label -> dataset class. Every class is loaded the same way, via
# ``<Class>.from_file(DATA_DIR)``; the labels are what the results table prints.
DATASET_CLASSES = {
    "penny1": Pennycook1,
    "penny2": Pennycook2,
    "balaban": Balaban,
    "hotard": Hotard,
    "lieberoth": Lieberoth,
    "broele": Vandenbroele,
}
all_data = {name: cls.from_file(DATA_DIR) for name, cls in DATASET_CLASSES.items()}

# Short label -> model instance. The three *Regressor models wrap a
# BayesianRidge estimator; ProbModel wraps a LogisticRegression.
all_models = {
    "biregressor": BiRegressor(BayesianRidge()),
    "monoregressor": MonoRegressor(BayesianRidge()),
    "xregressor": XRegressor(BayesianRidge()),
    "prob_model": ProbModel(LogisticRegression()),
}

In [9]:
@ignore_warnings(category=ConvergenceWarning)
def get_results(model, dataset, model_name):
    get_cate(model, dataset)
    print(model_name, data_name, np.mean(evaluate_outcome(model, dataset, n=10)))

    

In [10]:
for model_name, model in all_models.items():
    for data_name, dataset in all_data.items():
        get_results(model, dataset, model_name)


biregressor penny1 0.5562139371803883
biregressor penny2 0.23099081291222973
biregressor balaban -0.0417916995417283
biregressor hotard 0.10169648512098349
biregressor lieberoth 0.13575054273846932
biregressor broele 0.41254253878281266
monoregressor penny1 0.573754354803573
monoregressor penny2 0.272151487193651
monoregressor balaban 0.028679841629658514
monoregressor hotard 0.0630887456455396
monoregressor lieberoth -0.054040073328721155
monoregressor broele 0.20119500933585796
xregressor penny1 0.5526021582651908
xregressor penny2 0.23407396362735475
xregressor balaban -0.058145079461449346
xregressor hotard 0.10245123097902145
xregressor lieberoth 0.08260214575234044
xregressor broele 0.3114004615375747
prob_model penny1 0.2405331560431224
prob_model penny2 0.19110417904413704
prob_model balaban 0.14946256210808026
prob_model hotard 0.030470223567570138
prob_model lieberoth -0.14030859896973685
prob_model broele 0.03589360021810301
