diff --git a/doubleml/data/base_data.py b/doubleml/data/base_data.py index a7ae30f6..93543e8b 100644 --- a/doubleml/data/base_data.py +++ b/doubleml/data/base_data.py @@ -463,7 +463,9 @@ def x_cols(self): @x_cols.setter def x_cols(self, value): reset_value = hasattr(self, "_x_cols") + if value is not None: + # Basic checks if isinstance(value, str): value = [value] if not isinstance(value, list): @@ -476,7 +478,17 @@ def x_cols(self, value): if not set(value).issubset(set(self.all_variables)): raise ValueError("Invalid covariates x_cols. At least one covariate is no data column.") assert set(value).issubset(set(self.all_variables)) - self._x_cols = value + + if not reset_value: + self._x_cols = value + else: + previous_value = self._x_cols + self._x_cols = value + try: + self._check_disjoint_sets() + except ValueError: + self._x_cols = previous_value + raise else: excluded_cols = {self.y_col} | set(self.d_cols) @@ -486,8 +498,6 @@ def x_cols(self, value): self._x_cols = [col for col in self.data.columns if col not in excluded_cols] if reset_value: - self._check_disjoint_sets() - # by default, we initialize to the first treatment variable self.set_x_d(self.d_cols[0]) @property @@ -500,6 +510,8 @@ def d_cols(self): @d_cols.setter def d_cols(self, value): reset_value = hasattr(self, "_d_cols") + + # Basic checks if isinstance(value, str): value = [value] if not isinstance(value, list): @@ -511,10 +523,19 @@ def d_cols(self, value): raise ValueError("Invalid treatment variable(s) d_cols: Contains duplicate values.") if not set(value).issubset(set(self.all_variables)): raise ValueError("Invalid treatment variable(s) d_cols. At least one treatment variable is no data column.") - self._d_cols = value + + if not reset_value: + self._d_cols = value + else: + previous_value = self._d_cols + self._d_cols = value + try: + self._check_disjoint_sets() + except ValueError: + self._d_cols = previous_value + raise + if reset_value: - self._check_disjoint_sets() - # by default, we initialize to the first treatment variable self.set_x_d(self.d_cols[0]) @property @@ -527,15 +548,27 @@ def y_col(self): @y_col.setter def y_col(self, value): reset_value = hasattr(self, "_y_col") + + # Basic checks if not isinstance(value, str): raise TypeError( f"The outcome variable y_col must be of str type. {str(value)} of type {str(type(value))} was passed." ) if value not in self.all_variables: raise ValueError(f"Invalid outcome variable y_col. {value} is no data column.") - self._y_col = value + + if not reset_value: + self._y_col = value + else: + previous_value = self._y_col + self._y_col = value + try: + self._check_disjoint_sets() + except ValueError: + self._y_col = previous_value + raise + if reset_value: - self._check_disjoint_sets() self._set_y_z() @property @@ -548,7 +581,9 @@ def z_cols(self): @z_cols.setter def z_cols(self, value): reset_value = hasattr(self, "_z_cols") + if value is not None: + # Basic validation if isinstance(value, str): value = [value] if not isinstance(value, list): @@ -562,12 +597,22 @@ def z_cols(self, value): raise ValueError( "Invalid instrumental variable(s) z_cols. At least one instrumental variable is no data column." ) - self._z_cols = value + + if not reset_value: + self._z_cols = value + else: + previous_value = self._z_cols + self._z_cols = value + try: + self._check_disjoint_sets() + except ValueError: + self._z_cols = previous_value + raise + else: self._z_cols = None if reset_value: - self._check_disjoint_sets() self._set_y_z() @property diff --git a/doubleml/data/tests/test_dml_data.py b/doubleml/data/tests/test_dml_data.py index 4890ac7a..9fb72934 100644 --- a/doubleml/data/tests/test_dml_data.py +++ b/doubleml/data/tests/test_dml_data.py @@ -619,3 +619,45 @@ def test_dml_data_w_missing_d(generate_data1): assert dml_data.force_all_d_finite is False dml_data.force_all_d_finite = "allow-nan" assert dml_data.force_all_d_finite == "allow-nan" + + +@pytest.mark.ci +def test_property_setter_rollback_on_validation_failure(): + """Test that property setters don't mutate the object if validation fails.""" + np.random.seed(3141) + dml_data = make_plr_CCDDHNR2018(n_obs=100) + + # Store original values + original_y_col = dml_data.y_col + original_d_cols = dml_data.d_cols.copy() + original_x_cols = dml_data.x_cols.copy() + original_z_cols = dml_data.z_cols + + # Test y_col setter - try to set y_col to a value that's already in d_cols + with pytest.raises( + ValueError, match=r"d cannot be set as outcome variable ``y_col`` and treatment variable in ``d_cols``" + ): + dml_data.y_col = "d" + # Object should remain unchanged + assert dml_data.y_col == original_y_col + + # Test d_cols setter - try to set d_cols to include the outcome variable + with pytest.raises( + ValueError, match=r"y cannot be set as outcome variable ``y_col`` and treatment variable in ``d_cols``" + ): + dml_data.d_cols = ["y", "d"] + # Object should remain unchanged + assert dml_data.d_cols == original_d_cols + + # Test x_cols setter - try to set x_cols to include the outcome variable + with pytest.raises(ValueError, match=r"y cannot be set as outcome variable ``y_col`` and covariate in ``x_cols``"): + dml_data.x_cols = ["X1", "y", "X2"] + # Object should remain unchanged + assert dml_data.x_cols == original_x_cols + + # Test z_cols setter - try to set z_cols to include the outcome variable + msg = r"At least one variable/column is set as outcome variable \(``y_col``\) and instrumental variable \(``z_cols``\)" + with pytest.raises(ValueError, match=msg): + dml_data.z_cols = ["y"] + # Object should remain unchanged + assert dml_data.z_cols == original_z_cols