In [5]:
from IPython.display import display

import warnings
warnings.filterwarnings('ignore')

from carla.data.causal_model import CausalModel


In [6]:
scm = CausalModel("sanity-3-lin")
dataset = scm.generate_dataset(10000)

display(dataset.df)

Unnamed: 0,label,x1,x2,x3
0,1.0,-1.915919,3.786182,1.578276
1,1.0,0.735970,1.262074,-0.603052
2,0.0,0.266367,-0.740880,-0.817968
3,1.0,2.593845,-3.314230,0.050053
4,1.0,-2.930150,3.099031,1.105855
...,...,...,...,...
9995,0.0,-3.300128,2.471631,0.405281
9996,0.0,1.405657,-1.602510,-0.537921
9997,0.0,-0.683962,0.513155,-0.415351
9998,0.0,2.088685,-2.621772,-1.592011


In [7]:
from carla.models.catalog import MLModelCatalog


training_params = {"lr": 0.01, "epochs": 10, "batch_size": 16, "hidden_size": [18, 9, 3]}

ml_model = MLModelCatalog(
    dataset, model_type="ann", load_online=False, backend="pytorch"
)
ml_model.train(
    learning_rate=training_params["lr"],
    epochs=training_params["epochs"],
    batch_size=training_params["batch_size"],
    hidden_size=training_params["hidden_size"],
    force_train=True
)

balance on test set 0.5249333333333334, balance on test set 0.5244
Epoch 0/9
----------
train Loss: 0.3843 Acc: 0.8309

test Loss: 0.3742 Acc: 0.8268

Epoch 1/9
----------
train Loss: 0.3603 Acc: 0.8441

test Loss: 0.3598 Acc: 0.8420

Epoch 2/9
----------
train Loss: 0.3580 Acc: 0.8391

test Loss: 0.3554 Acc: 0.8456

Epoch 3/9
----------
train Loss: 0.3578 Acc: 0.8444

test Loss: 0.3486 Acc: 0.8460

Epoch 4/9
----------
train Loss: 0.3573 Acc: 0.8428

test Loss: 0.3502 Acc: 0.8436

Epoch 5/9
----------
train Loss: 0.3558 Acc: 0.8459

test Loss: 0.3791 Acc: 0.8216

Epoch 6/9
----------
train Loss: 0.3540 Acc: 0.8463

test Loss: 0.3588 Acc: 0.8432

Epoch 7/9
----------
train Loss: 0.3521 Acc: 0.8449

test Loss: 0.3659 Acc: 0.8424

Epoch 8/9
----------
train Loss: 0.3547 Acc: 0.8407

test Loss: 0.3514 Acc: 0.8436

Epoch 9/9
----------
train Loss: 0.3504 Acc: 0.8471

test Loss: 0.3536 Acc: 0.8396



In [8]:
from carla.models.negative_instances import predict_negative_instances
from carla.recourse_methods.catalog.causal_recourse import (
    CausalRecourse,
    constraints,
    samplers,
)


In [9]:
# get factuals
factuals = predict_negative_instances(ml_model, dataset.df)
test_factual = factuals.iloc[:5]

In [10]:
hyperparams = {
    "optimization_approach": "brute_force",
    "num_samples": 10,
    "scm": scm,
    "constraint_handle": constraints.point_constraint,
    "sampler_handle": samplers.sample_true_m0,
}
cfs = CausalRecourse(ml_model, hyperparams).get_counterfactuals(test_factual)

display(cfs)

Unnamed: 0,x1,x2,x3
0,2.662037,-1.07115,-0.817968
1,2.593845,-3.31423,1.233533
2,3.461964,-2.031785,-1.961044
3,2.252802,-3.727581,-3.317253
4,-0.413102,-0.472335,1.090267
