Doubly robust models renaming + multiple treatment-feature options (#28)
* Change `DoublyRobust` to `BaseDoublyRobust`

* Change `DoublyRobustVanilla` to `ResidualCorrectedStandardization`

* Change `DoublyRobustJoffe` to `WeightedStandardization`

* Change `DoublyRobustIpFeature` to `PropensityFeatureStandardization`

* Multiple treatment-based features in `PropensityFeatureStandardization`

Adds several options for propensity features:
 1. inverse-propensity weight: 1/Pr[A=a_i|X].
 2. inverse-propensity matrix: 1/Pr[A=a|X] for all possible `a`s.
 3. propensity vector: Pr[A=1|X].
 4. logit-transformed propensities: logit(Pr[A=1|X]).
 5. propensity matrix: Pr[A=a|X] for all possible `a`s.

Also adds corresponding tests for these options.
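
A minimal usage sketch of the renamed estimator with one of the new `feature_type` options (illustrative only; it assumes a pandas dataset `X, a, y` with binary treatment and mirrors the constructor pattern used in the tests below):

from sklearn.linear_model import LinearRegression, LogisticRegression
from causallib.estimation import IPW, Standardization, PropensityFeatureStandardization

ipw = IPW(LogisticRegression(max_iter=1000))
std = Standardization(LinearRegression())
# `feature_type` is one of: "weight_vector", "weight_matrix", "propensity_vector",
# "logit_propensity_vector", "propensity_matrix"
dr = PropensityFeatureStandardization(std, ipw, feature_type="logit_propensity_vector")
dr.fit(X, a, y)
pop_outcomes = dr.estimate_population_outcome(X, a)
effect = dr.estimate_effect(pop_outcomes[1], pop_outcomes[0])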
ehudkr committed Jan 25, 2022
1 parent 6572fae commit 63e171d
Showing 7 changed files with 236 additions and 89 deletions.
2 changes: 1 addition & 1 deletion causallib/estimation/__init__.py
@@ -1,4 +1,4 @@
from .doubly_robust import DoublyRobustIpFeature, DoublyRobustJoffe, DoublyRobustVanilla
from .doubly_robust import PropensityFeatureStandardization, WeightedStandardization, ResidualCorrectedStandardization
from .ipw import IPW
from .overlap_weights import OverlapWeights
from .standardization import Standardization, StratifiedStandardization
155 changes: 125 additions & 30 deletions causallib/estimation/doubly_robust.py
@@ -26,13 +26,14 @@
import warnings

import pandas as pd
import numpy as np

from .base_estimator import IndividualOutcomeEstimator
from .base_weight import WeightEstimator
from .base_weight import WeightEstimator, PropensityEstimator
from ..utils import general_tools as g_tools


class DoublyRobust(IndividualOutcomeEstimator):
class BaseDoublyRobust(IndividualOutcomeEstimator):
"""
Abstract class defining the interface and general initialization of specific doubly-robust methods.
"""
@@ -52,7 +53,7 @@ def __init__(self, outcome_model, weight_model,
If None - all covariates passed will be used.
Either list of column names or boolean mask.
"""
super(DoublyRobust, self).__init__(lambda **x: None) # Dummy initialization
super(BaseDoublyRobust, self).__init__(lambda **x: None) # Dummy initialization
delattr(self, "learner") # To remove the learner attribute a IndividualOutcomeEstimator has
self.outcome_model = outcome_model
self.weight_model = weight_model
@@ -98,25 +99,23 @@ def __repr__(self):
return repr_string


class DoublyRobustVanilla(DoublyRobust):
class ResidualCorrectedStandardization(BaseDoublyRobust):
"""
Given the measured outcome Y, the assignment Y, and the coefficients X calculate a doubly-robust estimator
of the effect of treatment
Calculates a doubly-robust estimate of the treatment effect by performing
potential-outcome prediction (`outcome_model`) and then correcting its
prediction-residuals using re-weighting from a treatment model (`weight_model`, like IPW).
Let e(X) be the estimated propensity score and m(X, A) be the estimated outcome by the outcome model,
then the individual estimates are:
| Y + (A-e(X))*(Y-m(X,1)) / e(X) if A==1, and
| Y + (e(X)-A)*(Y-m(X,0)) / (1-e(X)) if A==0
| m(X,1) + A*(Y-m(X,1))/e(X), and
| m(X,0) + (1-A)*(Y-m(X,0))/(1-e(X))
These expressions show that when e(X) is an unbiased estimator of A, or when m is an unbiased estimator of Y
then the resulting estimator is unbiased. Note that the term for A==0 is derived from (1-A)-(1-e(X))
Another way of writing these equation is by "correcting" the individual prediction rather than the individual
outcome:
| m(X,1) + A*(Y-m(X,1))/e(X), and
| m(X,0) + (1-A)*(Y-m(X,0))/(1-e(X))
Kang and Schafer (https://dx.doi.org/10.1214/07-STS227) attribute this method to
Cassel, Särndal and Wretman.
"""

def fit(self, X, a, y, refit_weight_model=True, **kwargs):
@@ -247,15 +246,50 @@ def estimate_effect(self, outcome1, outcome2, agg="population", effect_types="di
"not corrected for population effect.\n"
"In case you want individual effect use agg='individual', or in case you want population"
"effect use the estimate_population_effect() output as your input to this function.")
effect = super(DoublyRobustVanilla, self).estimate_effect(outcome1, outcome2, agg, effect_types)
effect = super(ResidualCorrectedStandardization, self).estimate_effect(outcome1, outcome2, agg, effect_types)
return effect
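
As a quick numeric illustration of the corrected potential-outcome formulas in the docstring above (hypothetical values; it assumes a binary treatment and already-fitted e(X) and m(X, a), and is not part of the library code):

import numpy as np

y = np.array([3.0, 1.0, 2.5, 0.5])   # observed outcomes
a = np.array([1, 0, 1, 0])           # treatment assignment
e = np.array([0.8, 0.3, 0.6, 0.4])   # estimated propensity e(X) = Pr[A=1|X]
m1 = np.array([2.8, 1.4, 2.2, 1.0])  # outcome-model predictions m(X, 1)
m0 = np.array([1.9, 0.9, 1.6, 0.6])  # outcome-model predictions m(X, 0)

corrected1 = m1 + a * (y - m1) / e               # m(X,1) + A*(Y-m(X,1))/e(X)
corrected0 = m0 + (1 - a) * (y - m0) / (1 - e)   # m(X,0) + (1-A)*(Y-m(X,0))/(1-e(X))
effect_estimate = corrected1.mean() - corrected0.mean()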


class DoublyRobustIpFeature(DoublyRobust):
"""
A doubly-robust estimator of the effect of treatment.
This model adds the weighting (inverse probability weighting) as feature to the model.
"""
class PropensityFeatureStandardization(BaseDoublyRobust):
def __init__(self, outcome_model, weight_model,
outcome_covariates=None, weight_covariates=None,
feature_type="weight_vector"):
"""
A doubly-robust estimator of the effect of treatment.
This model adds the weighting (inverse probability weighting)
as additional feature to the outcome model.
References:
* Bang and Robins, https://doi.org/10.1111/j.1541-0420.2005.00377.x
* Kang and Schafer, section 3.3, https://dx.doi.org/10.1214/07-STS227
Args:
outcome_model (IndividualOutcomeEstimator): A causal model that estimates on the individual level.
weight_model (WeightEstimator | PropensityEstimator): A causal model for weighting individuals (e.g. IPW).
outcome_covariates (array): Covariates to use for outcome model.
If None - all covariates passed will be used. Either list of column names or boolean mask.
weight_covariates (array): Covariates to use for weight model.
If None - all covariates passed will be used. Either list of column names or boolean mask.
feature_type (str): the type of covariate to add. One of the following options:
* "weight_vector": uses a signed weight vector. Only defined for binary treatment.
For example, if `weight_model` is IPW then: 1/Pr[A=a_i|X] for each sample `i`.
As described in Bang and Robins (2005).
* "weight_matrix": uses the entire weight matrix.
For example, if `weight_model` is IPW then: 1/Pr[A_i=a|X_i=x],
for all treatment values `a` and for every sample `i`.
* "propensity_vector": uses the probabilities for being in treatment group: Pr[A=1|X].
Better defined for binary treatment.
Equivalent to Scharfstein, Rotnitzky, and Robins (1999) that use its inverse.
* "logit_propensity_vector": uses logit transformation of the propensity to treat Pr[A=1|X].
As described in Kang and Schafer (2007)
* "propensity_matrix": uses the probabilities for all treatment options,
Pr[A_i=a|X_i=x] for all treatment values `a` and samples `i`.
"""
super().__init__(outcome_model, weight_model,
outcome_covariates, weight_covariates)
self.feature_type = feature_type

self._feature_functions = self._define_feature_functions()

def estimate_individual_outcome(self, X, a, treatment_values=None, predict_proba=None):
X_augmented = self._augment_outcome_model_data(X, a)
@@ -276,12 +310,13 @@ def _augment_outcome_model_data(self, X, a):
matrix (W | X).
"""
X_outcome, X_weight = self._prepare_data(X, a)
try:
weights_feature = self.weight_model.compute_weight_matrix(X_weight, a)
weights_feature = weights_feature.add_prefix("ipf_")
except NotImplementedError:
weights_feature = self.weight_model.compute_weights(X_weight, a)
weights_feature = weights_feature.rename("ipf")
feature_func = self._feature_functions.get(self.feature_type)
if feature_func is None:
raise ValueError(
f"feature type {self.feature_type} is not recognized."
f"Supported options are: {set(self._feature_functions.keys())}"
)
weights_feature = feature_func(X_weight, a)
# Let standardization deal with incorporating treatment assignment (a) into the data:
X_augmented = pd.concat([weights_feature, X_outcome], join="outer", axis="columns")
return X_augmented
@@ -300,12 +335,72 @@ def fit(self, X, a, y, refit_weight_model=True, **kwargs):
self.outcome_model.fit(X=X_augmented, y=y, a=a)
return self


class DoublyRobustJoffe(DoublyRobust):
def _define_feature_functions(self):

def weight_vector(X, a):
w = self.weight_model.compute_weights(X, a)
w = w.rename("ipf")
return w

def signed_weight_vector(X, a):
if a.nunique() != 2:
raise AssertionError(
f"`feature_type` 'weight_vector' can only be used with binary treatment."
f"Instead, treatment values are {set(a)}."
)
w = weight_vector(X, a)
w[a == 0] *= -1
return w

def weight_matrix(X, a):
W = self.weight_model.compute_weight_matrix(X, a)
W = W.add_prefix("ipf_")
return W

def masked_weight_matrix(X, a):
W = weight_matrix(X, a)
A = pd.get_dummies(a)
A = A.add_prefix("ipf_") # To match naming of `W`
W_masked = W * A
return W_masked

def propensity_vector(X, a):
p = self.weight_model.compute_propensity(X, a)
p = p.rename("propensity")
return p

def logit_propensity_vector(X, a, safe=True):
p = propensity_vector(X, a)
if safe:
epsilon = np.finfo(float).eps
p = np.clip(p, epsilon, 1 - epsilon)
return np.log(p / (1 - p))

def propensity_matrix(X, a):
P = self.weight_model.compute_propensity_matrix(X)
# P = P.iloc[:, 1:] # Drop first column
P = P.add_prefix("propensity_")
return P

feature_functions = {
"weight_vector": weight_vector,
# "signed_weight_vector": signed_weight_vector, # Hernan & Robins Fine Point 13.2, but seems to be biased
"weight_matrix": weight_matrix,
# "masked_weight_matrix": masked_weight_matrix, # Seems to be biased
"propensity_vector": propensity_vector,
"logit_propensity_vector": logit_propensity_vector,
"propensity_matrix": propensity_matrix,
}
return feature_functions
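
For a binary treatment (values 0 and 1), the vector-type features above add a single column to the outcome model's design matrix, while the matrix-type features add one column per treatment value. A rough illustration of the resulting column names, assuming the prefixes used above (illustrative only, not library code):

expected_added_columns = {
    "weight_vector": ["ipf"],
    "weight_matrix": ["ipf_0", "ipf_1"],
    "propensity_vector": ["propensity"],
    "logit_propensity_vector": ["propensity"],
    "propensity_matrix": ["propensity_0", "propensity_1"],
}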


class WeightedStandardization(BaseDoublyRobust):
"""
A doubly-robust estimator of the effect of treatment.
This model uses the weights from the weight-model (e.g. inverse probability weighting) as individual weights for
fitting the outcome model.
This model uses the weights from the weight-model (e.g. inverse probability weighting)
as individual weights for fitting the outcome model.
References:
* Kang and Schafer, section 3.2, https://dx.doi.org/10.1214/07-STS227
"""

def estimate_individual_outcome(self, X, a, treatment_values=None, predict_proba=None):
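Conceptually, `WeightedStandardization` uses the weight model's output as sample weights when fitting the outcome model. A rough standalone sketch of that idea (hedged: it assumes an already-fitted IPW model `ipw` and data `X, a, y`, and is not the class's actual implementation):

import pandas as pd
from sklearn.linear_model import LinearRegression

w = ipw.compute_weights(X, a)                         # e.g. 1 / Pr[A=a_i | X_i]
design = pd.concat([a.rename("a"), X], axis="columns")
outcome_model = LinearRegression()
outcome_model.fit(design, y, sample_weight=w)         # weighted standardization
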
2 changes: 1 addition & 1 deletion causallib/estimation/tmle.py
@@ -8,7 +8,7 @@
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder
from sklearn.utils.multiclass import type_of_target

from .doubly_robust import DoublyRobust as BaseDoublyRobust
from .doubly_robust import BaseDoublyRobust
from causallib.estimation.base_estimator import IndividualOutcomeEstimator
from causallib.estimation.base_weight import PropensityEstimator
from causallib.utils.stat_utils import robust_lookup
76 changes: 64 additions & 12 deletions causallib/tests/test_doublyrobust.py
@@ -26,7 +26,9 @@
from sklearn.linear_model import LogisticRegression, LinearRegression
from warnings import simplefilter, catch_warnings

from causallib.estimation import DoublyRobustVanilla, DoublyRobustIpFeature, DoublyRobustJoffe
from causallib.estimation import (
ResidualCorrectedStandardization, PropensityFeatureStandardization, WeightedStandardization
)
from causallib.estimation import IPW
from causallib.estimation import Standardization, StratifiedStandardization

@@ -124,7 +126,7 @@ def ensure_model_combinations_work(self, estimator_class):
self.assertTrue(True) # Dummy assert, didn't crash
with self.subTest("Check prediction"):
ind_outcome = dr.estimate_individual_outcome(data["X"], data["a"])
y = data["y"] if isinstance(dr, DoublyRobustVanilla) else None # Avoid warnings
y = data["y"] if isinstance(dr, ResidualCorrectedStandardization) else None # Avoid warnings
pop_outcome = dr.estimate_population_outcome(data["X"], data["a"], y)
dr.estimate_effect(ind_outcome[1], ind_outcome[0], agg="individual")
dr.estimate_effect(pop_outcome[1], pop_outcome[0])
@@ -189,14 +191,14 @@ def ensure_many_models(self, clip_min=None, clip_max=None):
self.assertTrue(True) # Fit did not crash


class TestDoublyRobustVanilla(TestDoublyRobustBase):
class TestResidualCorrectedStandardization(TestDoublyRobustBase):
@classmethod
def setUpClass(cls):
TestDoublyRobustBase.setUpClass()
# Avoids regularization of the model:
ipw = IPW(LogisticRegression(C=1e6, solver='lbfgs'), use_stabilized=False)
std = Standardization(LinearRegression(normalize=True))
cls.estimator = DoublyRobustVanilla(std, ipw)
cls.estimator = ResidualCorrectedStandardization(std, ipw)

def test_uninformative_tx_leads_to_std_like_results(self):
self.ensure_uninformative_tx_leads_to_std_like_results(self.estimator)
@@ -214,7 +216,7 @@ def test_weight_refitting_refits(self):
self.ensure_weight_refitting_refits(self.estimator)

def test_model_combinations_work(self):
self.ensure_model_combinations_work(DoublyRobustVanilla)
self.ensure_model_combinations_work(ResidualCorrectedStandardization)

def test_pipeline_learner(self):
self.ensure_pipeline_learner()
@@ -223,14 +225,14 @@ def test_many_models(self):
self.ensure_many_models()


class TestDoublyRobustJoffe(TestDoublyRobustBase):
class TestWeightedStandardization(TestDoublyRobustBase):
@classmethod
def setUpClass(cls):
TestDoublyRobustBase.setUpClass()
# Avoids regularization of the model:
ipw = IPW(LogisticRegression(C=1e6, solver='lbfgs'), use_stabilized=False)
std = Standardization(LinearRegression(normalize=True))
cls.estimator = DoublyRobustJoffe(std, ipw)
cls.estimator = WeightedStandardization(std, ipw)

def test_uninformative_tx_leads_to_std_like_results(self):
self.ensure_uninformative_tx_leads_to_std_like_results(self.estimator)
@@ -248,7 +250,7 @@ def test_weight_refitting_refits(self):
self.ensure_weight_refitting_refits(self.estimator)

def test_model_combinations_work(self):
self.ensure_model_combinations_work(DoublyRobustJoffe)
self.ensure_model_combinations_work(WeightedStandardization)

def test_pipeline_learner(self):
self.ensure_pipeline_learner()
@@ -300,14 +302,14 @@ def test_many_models(self):
model.fit(data["X"], data["a"], data["y"], refit_weight_model=False)


class TestDoublyRobustIPFeature(TestDoublyRobustBase):
class TestPropensityFeatureStandardization(TestDoublyRobustBase):
@classmethod
def setUpClass(cls):
TestDoublyRobustBase.setUpClass()
# Avoids regularization of the model:
ipw = IPW(LogisticRegression(C=1e6, solver='lbfgs'), use_stabilized=False)
std = Standardization(LinearRegression(normalize=True))
cls.estimator = DoublyRobustIpFeature(std, ipw)
cls.estimator = PropensityFeatureStandardization(std, ipw)

def fit_and_predict_all_learners(self, data, estimator):
X, a, y = data["X"], data["a"], data["y"]
@@ -327,16 +329,66 @@ def test_is_fitted(self):
self.ensure_is_fitted(self.estimator)

def test_data_is_separated_between_models(self):
self.ensure_data_is_separated_between_models(self.estimator, 2 + 1) # 2 ip-features + 1 treatment assignment
self.ensure_data_is_separated_between_models(self.estimator, 1 + 1) # 1 ip-feature + 1 treatment assignment

def test_weight_refitting_refits(self):
self.ensure_weight_refitting_refits(self.estimator)

def test_model_combinations_work(self):
self.ensure_model_combinations_work(DoublyRobustIpFeature)
self.ensure_model_combinations_work(PropensityFeatureStandardization)

def test_pipeline_learner(self):
self.ensure_pipeline_learner()

def test_many_models(self):
self.ensure_many_models(clip_min=0.001, clip_max=1-0.001)

def test_many_feature_types(self):
with self.subTest("Ensure all feature types are tested"):
feature_types = [
"weight_vector", # "signed_weight_vector",
"weight_matrix", # "masked_weight_matrix",
"propensity_vector", "propensity_matrix",
"logit_propensity_vector",
]
model_feature_types = set(self.estimator._feature_functions.keys())
if set(feature_types) != model_feature_types:
raise AssertionError(
"Hey there, there's a mismatch between `PropensityFeatureStandardization._feature_types"
"and its corresponding tests. Did you add a new type without testing?"
)

use_tmle_data = True
if use_tmle_data: # Align the datasets to the same attributes
from causallib.tests.test_tmle import generate_data
data = generate_data(1100, 2, 0, seed=0)
data['y'] = data['y_cont']
else:
data = self.create_uninformative_ox_dataset()
data['treatment_effect'] = data['beta']

for feature_type in feature_types:
with self.subTest(f"Testing {feature_type}"):
self.estimator.feature_type = feature_type
self.estimator.fit(data['X'], data['a'], data['y'])

# Test estimation:
pop_outcomes = self.estimator.estimate_population_outcome(data['X'], data['a'])
effect = pop_outcomes[1] - pop_outcomes[0]
np.testing.assert_allclose(
data['treatment_effect'], effect,
atol=0.05
)

# Test added covariates:
X_size = data['X'].shape[1]
added_covariates = 1 if "vector" in feature_type else 2 # Else it's a matrix
n_coefs = self.estimator.outcome_model.learner.coef_.size
self.assertEqual(n_coefs, X_size + added_covariates + 1) # 1 for treatment assignment

# with self.subTest("Test signed_weight_vector takes only binary", skip=True):
# a = data['a'].copy()
# a.iloc[-a.shape[0] // 4:] += 1
# self.estimator.feature_type = "signed_weight_vector"
# with self.assertRaises(AssertionError):
# self.estimator.fit(data['X'], a, data['y'])