# Setup

In [1]:
import numpy as np, timeit, time, matplotlib.pyplot as plt, json, os
from tqdm import tqdm
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from matplotlib import colormaps

In [2]:
from scipy.optimize import curve_fit
from collections.abc import Iterable
from itertools import product

In [3]:
from binary_models import *
from benchmark_models import rock_throwing_model, SMK_model, get_SMK_dim_labels, get_mSMK_SCM, get_bbSMK_SCM
from benchmark_models import get_noisy_suzy_SCM, get_nSMK_SCM
from actualcauses import beam_search, show_rules, iterative_identification

# Utils

# Main

## Example on rock throwing

In [4]:
# Rock-throwing example: five variable labels and an exogenous assignment `u`
# of the same length, handed to make_SCM together with rock_throwing_model.
# NOTE(review): the saved output of the next cell lists only the first four
# labels in SCM["variables"]; presumably make_SCM treats the last label ("BS")
# as the outcome -- confirm in binary_models.
variables = ("ST", "BT", "SH", "BH", "BS")
u = (1,1,1,0,1)
SCM = make_SCM(variables=variables, V_exo=u, model=rock_throwing_model)

In [5]:
# Inspect the variable labels stored in the SCM built above.
SCM["variables"]

('ST', 'BT', 'SH', 'BH')

In [6]:
# causes = beam_search(**SCM, max_steps=-1,beam_size=3,early_stop=False,verbose=5)

In [7]:
# show_rules(causes, SCM["variables"])

## Structure example with SMK

In [8]:
# Boolean SMK example with three attackers.
n_attacker = 3
variables = get_SMK_dim_labels(n_attacker)
# Exogenous assignment (18 values); the same list is reused by the later SMK sections.
u = [0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0]
SCM = make_SCM(variables, u, SMK_model, n_attacker=n_attacker)
# dag and init_var_ids are consumed by iterative_identification further down.
dag, init_var_ids = build_DAG(n_attacker, SCM["variables"])

In [9]:
# Time the plain beam search (max_steps=6, beam_size=200) on the boolean SMK SCM.
# perf_counter() is a monotonic, high-resolution clock meant for measuring
# durations; time.time() is the wall clock and can jump if the system time changes.
t = time.perf_counter()
bs_causes = beam_search(**SCM, max_steps=6, beam_size=200, verbose=2, early_stop=False)
print(f"{time.perf_counter()-t:.3f}s")

Evaluating 35 rules
Number of causes found: 1
Number of non-causes remaining: 34
Best non-cause:
C={'FS-U1': '1'}, W={}, output=True, score=19.000
Worst non-cause:
C={'SD': '1'}, W={}, output=True, score=18.000
Evaluating 1717 rules
Number of causes found: 3
Number of non-causes remaining: 1705
Best non-cause:
C={'FS-U2': '0'}, W={'FS-U1': '0'}, output=True, score=16.000
Worst non-cause:
C={'SD': '1'}, W={'DK': '1'}, output=1, score=18.000
Evaluating 8729 rules
Number of causes found: 2
Number of non-causes remaining: 8721
Best non-cause:
C={'FS-U2': '0', 'FS-U3': '0'}, W={'FS-U1': '0'}, output=True, score=15.000
Worst non-cause:
C={'SD': '1', 'GK-U1': '0', 'GK-U3': '0'}, W={}, output=True, score=16.000
Evaluating 8820 rules
Number of causes found: 0
Number of non-causes remaining: 8820
Best non-cause:
C={'FS-U2': '0', 'FS-U3': '0', 'FN-U3': '0'}, W={'FS-U1': '0'}, output=True, score=13.000
Worst non-cause:
C={'FF-U3': '0', 'SD': '1', 'GP-U3': '0', 'GK-U1': '0'}, W={}, output=True, sco

In [10]:
# Print the rules returned by the beam search above.
show_rules(bs_causes, SCM["variables"])

C={'FS-U2': '0', 'FN-U2': '0'}, W={'DK-U3': '0'}, output=False, score=11.000
C={'FDB-U2': '0', 'FF-U2': '0'}, W={'DK-U3': '0'}, output=False, score=11.000
C={'GP-U2': '0'}, W={'DK-U3': '0'}, output=False, score=13.000
C={'GK-U2': '0'}, W={'DK-U3': '0'}, output=False, score=13.000
C={'DK-U2': '0'}, W={'DK-U3': '0'}, output=False, score=14.000
C={'DK': '0'}, W={}, output=False, score=15.000


In [12]:
# Time iterative_identification on the same SCM, using the DAG and initial
# variable ids produced by build_DAG. max_steps=-1 / beam_size=-1 are passed
# as-is (presumably "no limit" -- confirm in actualcauses).
# perf_counter() is used instead of time.time() because it is monotonic and
# intended for measuring elapsed time.
t = time.perf_counter()
sbs_causes = iterative_identification(**SCM, dag=dag,
                                      init_var_ids=init_var_ids,
                                      max_steps=-1, beam_size=-1,
                                      verbose=1, early_stop=False)
print(f"{(time.perf_counter()-t):.3f}s")

len(queue)=1
var_ids=(34, 33), beams_w=None


2it [00:00, 2465.79it/s]


----> Found 1 causes.
C={'DK': '0'}, W={}, output=False, score=15.000
  Cause C={'DK': '0'}, W={} -> (27, 28, 29) ()
len(queue)=1
var_ids=(27, 28, 29), beams_w=()


3it [00:00, 548.13it/s]


----> Found 1 causes.
C={'DK-U2': '0'}, W={'DK-U3': '0'}, output=False, score=14.000
  Cause C={'DK-U2': '0'}, W={'DK-U3': '0'} -> (19, 22, 27) ((29, False),)
len(queue)=1
var_ids=(19, 22, 27), beams_w=((29, False),)


3it [00:00, 3824.59it/s]


----> Found 2 causes.
C={'GP-U2': '0'}, W={}, output=False, score=13.000
  Cause C={'GP-U2': '0'}, W={'DK-U3': '0'} -> (1, 4) ((29, False),)
  Cause C={'GK-U2': '0'}, W={'DK-U3': '0'} -> (7, 10) ((29, False),)
len(queue)=2
var_ids=(1, 4), beams_w=((29, False),)


2it [00:00, 3247.62it/s]


----> Found 1 causes.
C={'FS-U2': '0', 'FN-U2': '0'}, W={}, output=False, score=11.000
len(queue)=1
var_ids=(7, 10), beams_w=((29, False),)


2it [00:00, 3093.14it/s]

----> Found 1 causes.
C={'FF-U2': '0', 'FDB-U2': '0'}, W={}, output=False, score=11.000
0.040s





In [13]:
# Print the causes found by iterative_identification, for comparison with the
# beam-search rules printed earlier.
show_rules(sbs_causes, SCM["variables"])

C={'DK': '0'}, W={}, output=False, score=15.000
C={'DK-U2': '0'}, W={'DK-U3': '0'}, output=False, score=14.000
C={'GP-U2': '0'}, W={'DK-U3': '0'}, output=False, score=13.000
C={'GK-U2': '0'}, W={'DK-U3': '0'}, output=False, score=13.000
C={'FS-U2': '0', 'FN-U2': '0'}, W={'DK-U3': '0'}, output=False, score=11.000
C={'FDB-U2': '0', 'FF-U2': '0'}, W={'DK-U3': '0'}, output=False, score=11.000


## Non-Boolean SMK example

In [14]:
# Same exogenous assignment as in the boolean SMK section, now passed to the
# mSMK builder. The literal 3 presumably is the number of attackers
# (n_attacker above) -- confirm in benchmark_models.
u = [0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0]
SCM = get_mSMK_SCM(3, u)

In [15]:
# Table of (index, variable name, value) triples plus the last instance entry.
# zip stops at the shorter input; in the saved output the instance has one more
# entry than the variable list, so that final entry is displayed separately.
(
    [
        (idx, name, value)
        for idx, (name, value) in enumerate(zip(SCM["variables"], SCM["instance"]))
    ],
    SCM["instance"][-1],
)

([(0, 'FS', (1, 2)),
  (1, 'FN', (1, 2)),
  (2, 'FF', (0, 1, 2)),
  (3, 'FDB', (0, 1)),
  (4, 'A', (1,)),
  (5, 'AD', ()),
  (6, 'GP', (1, 2)),
  (7, 'GK', (0, 1, 2)),
  (8, 'KMS', ()),
  (9, 'DK', 1),
  (10, 'SD', -1)],
 True)

In [16]:
def solve_non_boolean(variables, instance):
    """Unpack a non-boolean SMK instance and print its variable values.

    Work in progress: the body prints the eleven named values and seeds
    ``left_causes`` when ``DK`` differs from -1, but nothing is returned yet
    (the function implicitly returns ``None``).

    Parameters
    ----------
    variables
        Not used by the current body.
    instance
        Iterable of exactly twelve entries. The first eleven are bound to
        FS, FN, FF, FDB, A, AD, GP, GK, KMS, DK and SD; the twelfth is discarded.

    Raises
    ------
    ValueError
        If ``instance`` does not unpack into exactly twelve values.
    """
    (FS, FN, FF, FDB, A, AD,
     GP, GK, KMS, DK, SD, _) = instance
    print(f"{FS=}, {FN=}, {FF=}, {FDB=}, {A=}, {AD=}, {GP=}, {GK=}, {KMS=}, {DK=}, {SD=}")
    # Accumulators for the hand-written solution; only left_causes is touched so far.
    causes, left_causes, right_causes = [], [], []
    if DK != -1:
        left_causes.extend([[0]])
        

In [17]:
# Run the work-in-progress helper; at this stage it only prints the unpacked instance.
solve_non_boolean(SCM["variables"], SCM["instance"])

FS=(1, 2), FN=(1, 2), FF=(0, 1, 2), FDB=(0, 1), A=(1,), AD=(), GP=(1, 2), GK=(0, 1, 2), KMS=(), DK=1, SD=-1


In [18]:
# Beam search on the non-boolean SCM; max_steps and beam_size are not passed,
# so whatever defaults beam_search defines apply.
causes = beam_search(**SCM, early_stop=False, verbose=2)

Evaluating 69 rules
Number of causes found: 3
Number of non-causes remaining: 65
Best non-cause:
C={'FS': '()'}, W={}, output=True, score=13.000
Worst non-cause:
C={'SD': '2'}, W={}, output=True, score=18.000
Evaluating 518 rules
Number of causes found: 2
Number of non-causes remaining: 513
Best non-cause:
C={'FS': '()', 'FN': '(0,)'}, W={}, output=True, score=10.000
Worst non-cause:
C={'SD': '2', 'FDB': '()'}, W={}, output=True, score=16.000
Evaluating 330 rules
Number of causes found: 0
Number of non-causes remaining: 330
Best non-cause:
C={'FS': '()', 'FF': '()'}, W={'FN': '(1, 2)'}, output=True, score=9.000
Worst non-cause:
C={'FN': '(0,)', 'FF': '()', 'SD': '2'}, W={}, output=True, score=13.000
Evaluating 247 rules
Number of causes found: 0
Number of non-causes remaining: 247
Best non-cause:
C={'FS': '()', 'FF': '()'}, W={'FN': '(1, 2)', 'FDB': '(0, 1)'}, output=True, score=9.000
Worst non-cause:
C={'FN': '()', 'FF': '()', 'SD': '2', 'A': '()'}, W={}, output=True, score=11.000
Eva

In [19]:
# Print the causes found by the beam search on the non-boolean SCM.
show_rules(causes, SCM["variables"])

C={'FF': '()', 'FDB': '()'}, W={}, output=False, score=5.000
C={'FS': '()', 'FN': '()'}, W={}, output=False, score=7.000
C={'GK': '()'}, W={}, output=False, score=10.000
C={'GP': '()'}, W={}, output=False, score=11.000
C={'DK': '-1'}, W={}, output=False, score=13.000


In [21]:
# Build the DAG for the non-boolean SCM, then time iterative_identification on it.
# The DAG construction is deliberately kept outside the timed region.
dag, init_var_ids = build_DAG_non_boolean(3, SCM["variables"])
# perf_counter() is monotonic and high-resolution, which matters for a run this
# short; time.time() is the wall clock and may have coarse resolution.
t = time.perf_counter()
causes = iterative_identification(**SCM, dag=dag,
                                  init_var_ids=init_var_ids,
                                  max_steps=-1, beam_size=-1,
                                  verbose=0, early_stop=False)
print(f"{(time.perf_counter()-t):.3f}s")

0.008s


In [22]:
# Print the causes from iterative_identification; in the saved output they are
# the same five rules as the beam search found, listed in a different order.
show_rules(causes, SCM["variables"])

C={'DK': '-1'}, W={}, output=False, score=13.000
C={'GK': '()'}, W={}, output=False, score=10.000
C={'GP': '()'}, W={}, output=False, score=11.000
C={'FF': '()', 'FDB': '()'}, W={}, output=False, score=5.000
C={'FS': '()', 'FN': '()'}, W={}, output=False, score=7.000


## Black-box SMK example

In [23]:
# Same exogenous assignment as before, passed to the black-box SMK builder.
u = [0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0]
SCM = get_bbSMK_SCM(3, u)

In [24]:
# Last entry of the instance (presumably the observed outcome -- confirm).
SCM["instance"][-1]

1

In [25]:
# Call the SCM's simulation function on a batch holding a single intervention,
# written as (variable index, value) pairs. In the saved output, positions
# 1, 2, 4 and 5 of the returned state are 0 where `u` has 1, consistent with
# that reading. The meaning of the other two returned fields is not visible
# here -- confirm in benchmark_models.
SCM["simulation"]([[(1,0),(2,0),(4,0),(5,0)]])

[([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0], 0.0, 5)]

In [26]:
# Beam search on the black-box SCM with beam_size=50 and max_steps=10.
causes = beam_search(**SCM, early_stop=False, verbose=1, beam_size=50, max_steps=10)

100%|██████████| 10/10 [00:00<00:00, 26.96it/s]

----> Found 4 causes.
C={'FS-U2': '0', 'FS-U3': '0', 'FN-U2': '0', 'FN-U3': '0'}, W={}, output=0.0, score=5.000





In [27]:
# Print the causes found on the black-box SCM.
show_rules(causes, SCM["variables"])

C={'FS-U2': '0', 'FS-U3': '0', 'FN-U2': '0', 'FN-U3': '0'}, W={}, output=0.0, score=5.000
C={'FS-U3': '0', 'FDB-U2': '0', 'FN-U3': '0', 'FF-U2': '0'}, W={}, output=0.0, score=5.000
C={'FF-U3': '0', 'FS-U2': '0', 'FN-U2': '0'}, W={}, output=0.0, score=6.000
C={'FF-U3': '0', 'FDB-U2': '0', 'FF-U2': '0'}, W={}, output=0.0, score=6.000


## Noisy SMK example

In [28]:
# Shared settings for the noisy examples below.
bs = 50  # passed as beam_size to beam_search and as "beam_size" in lucb_params
N = 50  # passed as "max_iter" in lucb_params and as N= to get_nSMK_SCM
eps = .35  # passed as epsilon= to beam_search and as "a" in the LUCB SMK params
nl = .01  # passed as nl= to get_nSMK_SCM (presumably the noise level -- confirm)
# Same exogenous assignment as in the earlier SMK sections.
u = [0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0]

### Suzy example

In [29]:
# LUCB settings for the noisy Suzy example.
# NOTE(review): the key "non_cause_esp" looks like a transposition of "eps"
# (compare "cause_eps" and "beam_eps"). It is left unchanged because the code
# that consumes lucb_params is not visible here -- confirm which spelling it expects.
lucb_params = {"beam_size":bs, "a":.25, "cause_eps":.05, 
               "beam_eps":.1, "max_iter":N, "verbose":2, 
               "batch_size":1, "non_cause_esp":.05}
nSuzy_SCM = get_noisy_suzy_SCM(lucb_params=lucb_params)

In [30]:
# Beam search on the noisy Suzy SCM; epsilon=.25 matches the "a" value in lucb_params.
causes = beam_search(**nSuzy_SCM, epsilon=.25, beam_size=bs)

100%|██████████| 4/4 [00:00<00:00, 1319.27it/s]
 34%|███▍      | 69/200 [00:00<00:00, 393.23it/s]


Success: beam_bound=0.0000 / cause_bound=0.0000 / non_cause_bound=0.0402)


100%|██████████| 18/18 [00:00<00:00, 49184.02it/s]
 83%|████████▎ | 751/900 [00:01<00:00, 454.67it/s] 

Success: beam_bound=0.0000 / cause_bound=0.0498 / non_cause_bound=0.0447)





In [31]:
# Print the causes found on the noisy Suzy SCM.
show_rules(causes, nSuzy_SCM["variables"])

C={'ST': '0'}, W={'BH': '0'}, output=0.008130081300813009, score=0.198
C={'SH': '0'}, W={'BH': '0'}, output=0.008130081300813009, score=0.397


### Average SMK example

In [None]:
# Noisy SMK SCM built with do_lucb=False (the "average" variant named in this
# section's heading), using the shared N and nl settings.
SCM_avg = get_nSMK_SCM(3, u, do_lucb=False, N=N, nl=nl)

In [None]:
# Simulate a single intervention (33, 0) and keep every field after the first
# of the first (only) result.
SCM_avg["simulation"]([[(33,0)]])[0][1:]

In [None]:
# Seed NumPy's global RNG before the search so the run is repeatable
# (assumes the noisy simulation draws from np.random -- confirm).
np.random.seed(0)
causes = beam_search(**SCM_avg, max_steps=4, beam_size=bs,early_stop=False,verbose=1, epsilon=eps)
print("\nRESULTS\n")
show_rules(causes, SCM_avg["variables"])

### LUCB SMK example

In [33]:
# LUCB settings for the noisy SMK example; "a" reuses the shared eps value.
# NOTE(review): the key "non_cause_esp" looks like a transposition of "eps"
# (compare "cause_eps" and "beam_eps"). It is left unchanged because the code
# that consumes lucb_params is not visible here -- confirm which spelling it expects.
lucb_params = {"beam_size":bs, "a":eps, "cause_eps":.01, 
               "beam_eps":.7, "max_iter":N, "verbose":2, 
               "batch_size":30, "non_cause_esp":.05}
SCM_lucb = get_nSMK_SCM(3, u, do_lucb=True, N=N, nl=nl, lucb_params=lucb_params)

In [34]:
# Seed NumPy's global RNG before the search so the run is repeatable
# (assumes the noisy simulation draws from np.random -- confirm).
np.random.seed(0)
causes = beam_search(**SCM_lucb, max_steps=4, beam_size=bs, early_stop=False,verbose=2, epsilon=eps)
print("\nRESULTS\n")
show_rules(causes, SCM_lucb["variables"])

Evaluating 35 rules


100%|██████████| 35/35 [00:00<00:00, 255.51it/s]
 65%|██████▌   | 1140/1750 [00:00<00:00, 44614.01it/s] 


Success: beam_bound=0.0000 / cause_bound=-0.0336 / non_cause_bound=-0.0672)
Number of causes found: 1
Number of non-causes remaining: 34
Best non-cause:
C={'FS-U1': '1'}, W={}, output=1.0, score=0.558
Worst non-cause:
C={'SD': '1'}, W={}, output=1.0, score=0.533
Evaluating 1717 rules


100%|██████████| 1717/1717 [00:04<00:00, 358.60it/s]
 65%|██████▍   | 55530/85850 [00:03<00:02, 15007.98it/s] 


Success: beam_bound=0.6913 / cause_bound=0.0099 / non_cause_bound=0.0413)
Number of causes found: 3
Number of non-causes remaining: 1705
Best non-cause:
C={'FS-U2': '0'}, W={'FS-U1': '0'}, output=0.9666666666666667, score=0.476
Worst non-cause:
C={'SD': '1'}, W={'DK': '1'}, output=0.9666666666666667, score=0.533
Evaluating 2763 rules


100%|██████████| 2763/2763 [00:08<00:00, 337.55it/s]
 86%|████████▌ | 118590/138150 [00:34<00:05, 3460.12it/s] 


Success: beam_bound=0.6902 / cause_bound=0.0091 / non_cause_bound=0.0402)
Number of causes found: 4
Number of non-causes remaining: 2756
Best non-cause:
C={'FF-U3': '0', 'FS-U2': '0'}, W={'FS-U1': '0'}, output=0.9666666666666667, score=0.419
Worst non-cause:
C={'SD': '1', 'GK-U3': '0'}, W={'SD-U3': '0'}, output=0.9666666666666667, score=0.500
Evaluating 2942 rules


100%|██████████| 2942/2942 [00:11<00:00, 251.45it/s]
 85%|████████▍ | 125010/147100 [00:36<00:06, 3420.50it/s] 


Success: beam_bound=0.6922 / cause_bound=0.0000 / non_cause_bound=0.0422)
Number of causes found: 0
Number of non-causes remaining: 2942
Best non-cause:
C={'FF-U3': '0', 'FS-U2': '0', 'GK-U1': '0'}, W={'FS-U1': '0'}, output=1.0, score=0.394
Worst non-cause:
C={'SD': '1', 'GK-U3': '0'}, W={'DK-U1': '0', 'SD-U1': '0'}, output=1.0, score=0.494
----> Found 8 causes.
C={'GK-U2': '0'}, W={'DK-U3': '0'}, output=0.05333333333333334, score=0.393

RESULTS

C={'GK-U2': '0'}, W={'DK-U3': '0'}, output=0.05333333333333334, score=0.393
C={'DK': '0'}, W={}, output=0.058333333333333334, score=0.446
C={'GP-U2': '0'}, W={'DK-U3': '0'}, output=0.07333333333333333, score=0.393
C={'FS-U2': '0', 'FN-U2': '0'}, W={'DK-U3': '0'}, output=0.10476190476190476, score=0.342
C={'DK-U2': '0'}, W={'DK-U3': '0'}, output=0.12380952380952381, score=0.426
C={'FDB-U2': '0', 'GP-U3': '0', 'FF-U2': '0'}, W={}, output=0.12962962962962962, score=0.315
C={'FDB-U2': '0', 'GK-U3': '0', 'FF-U2': '0'}, W={}, output=0.12962962962962