Skip to content

Commit

Permalink
working sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
akelleh committed Feb 11, 2019
1 parent 18a03b0 commit 441d714
Show file tree
Hide file tree
Showing 3 changed files with 515 additions and 1,362 deletions.
6 changes: 3 additions & 3 deletions dowhy/api/causal_data_frame.py
Expand Up @@ -27,10 +27,10 @@ def __init__(self, pandas_obj):

def do(self, x, method=None, num_cores=1, variable_types={}, outcome=None, params=None, dot_graph=None,
common_causes=None, instruments=None, estimand_type='ate', proceed_when_unidentifiable=False,
keep_original_treatment=False):
keep_original_treatment=False, use_previous_sampler=False):
if not method:
raise Exception("You must specify a do sampling method.")
if not self._obj._causal_model:
if not self._obj._causal_model or not use_previous_sampler:
self._obj._causal_model = CausalModel(self._obj,
[xi for xi in x.keys()][0],
outcome,
Expand All @@ -41,7 +41,7 @@ def do(self, x, method=None, num_cores=1, variable_types={}, outcome=None, param
proceed_when_unidentifiable=proceed_when_unidentifiable)
self._obj._identified_estimand = self._obj._causal_model.identify_effect()
do_sampler_class = do_samplers.get_class_object(method + "_sampler")
if not self._obj._sampler:
if not self._obj._sampler or not use_previous_sampler:
self._obj._sampler = do_sampler_class(self._obj,
self._obj._identified_estimand,
self._obj._causal_model._treatment,
Expand Down
9 changes: 0 additions & 9 deletions dowhy/do_samplers/mcmc_sampler.py
Expand Up @@ -109,25 +109,16 @@ def make_intervention_effective(self, x):

def do_sample(self, x):
self.reset()
print(self._df.sample(10))
g_for_surgery = nx.DiGraph(self.g)
g_modified = self.do_x_surgery(g_for_surgery, x)
print(self._df.sample(10))

self._df = self.make_intervention_effective(x)
print(self._df.sample(10))

g_modified, trace = self.sample_prior_causal_model(g_modified,
self._df,
self._variable_types,
initialization_trace=self.fit_trace)
print(self._df.sample(10))

for col in self._df:
if col in trace and col not in self._treatment_names:
self._df[col] = trace[col]
print(self._df.sample(10))

return self._df.copy()

def _construct_sampler(self):
Expand Down

0 comments on commit 441d714

Please sign in to comment.