In [30]:
import networkx as nx
from bcause.factors import DeterministicFactor, MultinomialFactor
from bcause.models.cmodel import StructuralCausalModel
from bcause.inference.causal.elimination import CausalVariableElimination



In [31]:
# Causal graph V -> X -> Y <- U; U and V are the root nodes.
dag = nx.DiGraph()
dag.add_edges_from([("X", "Y"), ("U", "Y"), ("V", "X")])

# State labels of every variable, keyed by node name.
domains = {
    "X": ["x1", "x2"],
    "Y": ["y1", "y2"],
    "U": ["u1", "u2", "u3", "u4"],
    "V": ["v1", "v2"],
}


In [32]:
# Structural equation for X: a deterministic function of its only parent V
# (v1 -> x1, v2 -> x2). The domain dict lists the parent V before X.
domx = {var: domains[var] for var in ("V", "X")}
fx = DeterministicFactor(domx, right_vars=["V"], values=["x1", "x2"])
fx

<DeterministicFactor fX(V), cardinality = (V:2,X:2), values=[x1,x2]>

In [33]:
# Structural equation for Y: a deterministic function of its parents X and U.
# `values` is nested following the parent order in the domain dict (X, then U):
# the outer list is indexed by X and the innermost axis by the rightmost
# parent U, so values[i][j] is the value of Y when X = x{i+1} and U = u{j+1}.
values = [
    ["y1", "y1", "y2", "y1"],  # X = x1
    ["y2", "y2", "y1", "y1"],  # X = x2
]
domy = {var: domains[var] for var in ("X", "U", "Y")}
fy = DeterministicFactor(domy, left_vars=["Y"], values=values)
fy

<DeterministicFactor fY(X,U), cardinality = (X:2,U:4,Y:2), values=[y1,y1,y2,y1,...,y1]>

In [34]:
# Prior for the root V: uniform over its two states.
domv = dict(V=domains["V"])
pv = MultinomialFactor(domv, values=[0.5, 0.5])
pv

<MultinomialFactor P(V), cardinality = (V:2), values=[0.5,0.5]>

In [35]:
# Prior for the root U over u1..u4; note that u4 has zero probability.
domu = dict(U=domains["U"])
pu = MultinomialFactor(domu, values=[0.2, 0.2, 0.6, 0.0])
pu

<MultinomialFactor P(U), cardinality = (U:4), values=[0.2,0.2,0.6,0.0]>

In [36]:
# Assemble the SCM from the graph plus one factor per node: the deterministic
# equations for X and Y, and the priors for the root nodes U and V.
model = StructuralCausalModel(
    dag,
    [fx, fy, pu, pv],
)
model

<StructuralCausalModel (X:2,Y:2|U:4,V:2), dag=[X|V][Y|X:U][U][V]>

In [37]:
# Inspect the factors stored by the model: the deterministic equations of X and
# Y are shown as 0/1-valued conditional distributions P(X|V) and P(Y|X,U).
model.factors

{'X': <MultinomialFactor P(X|V), cardinality = (V:2,X:2), values=[1.0,0.0,0.0,1.0]>,
 'Y': <MultinomialFactor P(Y|X,U), cardinality = (X:2,U:4,Y:2), values=[1.0,0.0,1.0,0.0,...,0.0]>,
 'U': <MultinomialFactor P(U), cardinality = (U:4), values=[0.2,0.2,0.6,0.0]>,
 'V': <MultinomialFactor P(V), cardinality = (V:2), values=[0.5,0.5]>}

In [38]:
# Convert the SCM into a BayesianNetwork over the same DAG. The result is only
# displayed here, not stored.
model.to_bnet()

<BayesianNetwork (X:2,Y:2,U:4,V:2), dag=[X|V][Y|X:U][U][V]>

In [39]:
# Interventional query P(Y | do(X = x1)) via causal variable elimination.
# Sanity check by hand: under do(X = x1), Y = y1 for u1, u2, u4 and y2 for u3,
# so P(y1) = 0.2 + 0.2 + 0.0 = 0.4 and P(y2) = 0.6.
cve = CausalVariableElimination(model)
cve.causal_query("Y", do={"X": "x1"})

<MultinomialFactor P(Y), cardinality = (Y:2), values=[0.4,0.6]>

In [40]:
# Counterfactual query: distribution of Y had X been set to x1, given that
# X = x1 was observed.
# FIXME: in the recorded run this cell raised `KeyError: 'V'` (V is the root
# parent of X); the cause is inside the library call and not visible here.
# NOTE(review): once it runs, the result should match the interventional
# answer above ([0.4, 0.6]). X depends only on V, which is independent of U,
# so observing X = x1 says nothing about U.
cve.counterfactual_query("Y", do={"X": "x1"}, evidence={"X": "x1"})

KeyError: 'V'