Skip to content

Commit

Permalink
"uber#596: Bug fix - Allow using externally generated propensity scor…
Browse files Browse the repository at this point in the history
…es in estimate_ate without triggering a retraining of the model."
  • Loading branch information
AlxClt committed Jan 17, 2023
1 parent aa02308 commit 2d8b604
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions causalml/inference/meta/xlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,13 +321,19 @@ def estimate_ate(
Returns:
The mean and confidence interval (LB, UB) of the ATE estimate.
"""
if pretrain and p is None:
# when p is null, use pretrain propensity score
if not self.propensity:
raise ValueError("no propensity score, please call fit() first")
te, dhat_cs, dhat_ts = self.predict(
X, treatment, y, p=self.propensity, return_components=True
)
if pretrain:
if p is None:
# when p is null, use pretrain propensity score
if not self.propensity:
raise ValueError("no propensity score, please call fit() first")
te, dhat_cs, dhat_ts = self.predict(
X, treatment, y, p=self.propensity, return_components=True
)
else:
p = self._format_p(p, self.t_groups)
te, dhat_cs, dhat_ts = self.predict(
X, treatment, y, p=p, return_components=True
)
else:
te, dhat_cs, dhat_ts = self.fit_predict(
X, treatment, y, p, return_components=True
Expand Down

0 comments on commit 2d8b604

Please sign in to comment.