Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 55 additions & 10 deletions doubleml/data/base_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,9 @@ def x_cols(self):
@x_cols.setter
def x_cols(self, value):
reset_value = hasattr(self, "_x_cols")

if value is not None:
# Basic checks
if isinstance(value, str):
value = [value]
if not isinstance(value, list):
Expand All @@ -476,7 +478,17 @@ def x_cols(self, value):
if not set(value).issubset(set(self.all_variables)):
raise ValueError("Invalid covariates x_cols. At least one covariate is no data column.")
assert set(value).issubset(set(self.all_variables))
self._x_cols = value

if not reset_value:
self._x_cols = value
else:
previous_value = self._x_cols
self._x_cols = value
try:
self._check_disjoint_sets()
except ValueError:
self._x_cols = previous_value
raise

else:
excluded_cols = {self.y_col} | set(self.d_cols)
Expand All @@ -486,8 +498,6 @@ def x_cols(self, value):
self._x_cols = [col for col in self.data.columns if col not in excluded_cols]

if reset_value:
self._check_disjoint_sets()
# by default, we initialize to the first treatment variable
self.set_x_d(self.d_cols[0])

@property
Expand All @@ -500,6 +510,8 @@ def d_cols(self):
@d_cols.setter
def d_cols(self, value):
reset_value = hasattr(self, "_d_cols")

# Basic checks
if isinstance(value, str):
value = [value]
if not isinstance(value, list):
Expand All @@ -511,10 +523,19 @@ def d_cols(self, value):
raise ValueError("Invalid treatment variable(s) d_cols: Contains duplicate values.")
if not set(value).issubset(set(self.all_variables)):
raise ValueError("Invalid treatment variable(s) d_cols. At least one treatment variable is no data column.")
self._d_cols = value

if not reset_value:
self._d_cols = value
else:
previous_value = self._d_cols
self._d_cols = value
try:
self._check_disjoint_sets()
except ValueError:
self._d_cols = previous_value
raise

if reset_value:
self._check_disjoint_sets()
# by default, we initialize to the first treatment variable
self.set_x_d(self.d_cols[0])

@property
Expand All @@ -527,15 +548,27 @@ def y_col(self):
@y_col.setter
def y_col(self, value):
reset_value = hasattr(self, "_y_col")

# Basic checks
if not isinstance(value, str):
raise TypeError(
f"The outcome variable y_col must be of str type. {str(value)} of type {str(type(value))} was passed."
)
if value not in self.all_variables:
raise ValueError(f"Invalid outcome variable y_col. {value} is no data column.")
self._y_col = value

if not reset_value:
self._y_col = value
else:
previous_value = self._y_col
self._y_col = value
try:
self._check_disjoint_sets()
except ValueError:
self._y_col = previous_value
raise

if reset_value:
self._check_disjoint_sets()
self._set_y_z()

@property
Expand All @@ -548,7 +581,9 @@ def z_cols(self):
@z_cols.setter
def z_cols(self, value):
reset_value = hasattr(self, "_z_cols")

if value is not None:
# Basic validation
if isinstance(value, str):
value = [value]
if not isinstance(value, list):
Expand All @@ -562,12 +597,22 @@ def z_cols(self, value):
raise ValueError(
"Invalid instrumental variable(s) z_cols. At least one instrumental variable is no data column."
)
self._z_cols = value

if not reset_value:
self._z_cols = value
else:
previous_value = self._z_cols
self._z_cols = value
try:
self._check_disjoint_sets()
except ValueError:
self._z_cols = previous_value
raise

else:
self._z_cols = None

if reset_value:
self._check_disjoint_sets()
self._set_y_z()

@property
Expand Down
42 changes: 42 additions & 0 deletions doubleml/data/tests/test_dml_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,3 +619,45 @@ def test_dml_data_w_missing_d(generate_data1):
assert dml_data.force_all_d_finite is False
dml_data.force_all_d_finite = "allow-nan"
assert dml_data.force_all_d_finite == "allow-nan"


@pytest.mark.ci
def test_property_setter_rollback_on_validation_failure():
"""Test that property setters don't mutate the object if validation fails."""
np.random.seed(3141)
dml_data = make_plr_CCDDHNR2018(n_obs=100)

# Store original values
original_y_col = dml_data.y_col
original_d_cols = dml_data.d_cols.copy()
original_x_cols = dml_data.x_cols.copy()
original_z_cols = dml_data.z_cols

# Test y_col setter - try to set y_col to a value that's already in d_cols
with pytest.raises(
ValueError, match=r"d cannot be set as outcome variable ``y_col`` and treatment variable in ``d_cols``"
):
dml_data.y_col = "d"
# Object should remain unchanged
assert dml_data.y_col == original_y_col

# Test d_cols setter - try to set d_cols to include the outcome variable
with pytest.raises(
ValueError, match=r"y cannot be set as outcome variable ``y_col`` and treatment variable in ``d_cols``"
):
dml_data.d_cols = ["y", "d"]
# Object should remain unchanged
assert dml_data.d_cols == original_d_cols

# Test x_cols setter - try to set x_cols to include the outcome variable
with pytest.raises(ValueError, match=r"y cannot be set as outcome variable ``y_col`` and covariate in ``x_cols``"):
dml_data.x_cols = ["X1", "y", "X2"]
# Object should remain unchanged
assert dml_data.x_cols == original_x_cols

# Test z_cols setter - try to set z_cols to include the outcome variable
msg = r"At least one variable/column is set as outcome variable \(``y_col``\) and instrumental variable \(``z_cols``\)"
with pytest.raises(ValueError, match=msg):
dml_data.z_cols = ["y"]
# Object should remain unchanged
assert dml_data.z_cols == original_z_cols