From 1c48f0eb951dbffa3a03e3d4b8dfb23f9d52246c Mon Sep 17 00:00:00 2001 From: SvenKlaassen Date: Fri, 17 Oct 2025 10:56:12 +0200 Subject: [PATCH 1/4] add property setter test --- doubleml/data/tests/test_dml_data.py | 42 ++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/doubleml/data/tests/test_dml_data.py b/doubleml/data/tests/test_dml_data.py index 4890ac7a..9fb72934 100644 --- a/doubleml/data/tests/test_dml_data.py +++ b/doubleml/data/tests/test_dml_data.py @@ -619,3 +619,45 @@ def test_dml_data_w_missing_d(generate_data1): assert dml_data.force_all_d_finite is False dml_data.force_all_d_finite = "allow-nan" assert dml_data.force_all_d_finite == "allow-nan" + + +@pytest.mark.ci +def test_property_setter_rollback_on_validation_failure(): + """Test that property setters don't mutate the object if validation fails.""" + np.random.seed(3141) + dml_data = make_plr_CCDDHNR2018(n_obs=100) + + # Store original values + original_y_col = dml_data.y_col + original_d_cols = dml_data.d_cols.copy() + original_x_cols = dml_data.x_cols.copy() + original_z_cols = dml_data.z_cols + + # Test y_col setter - try to set y_col to a value that's already in d_cols + with pytest.raises( + ValueError, match=r"d cannot be set as outcome variable ``y_col`` and treatment variable in ``d_cols``" + ): + dml_data.y_col = "d" + # Object should remain unchanged + assert dml_data.y_col == original_y_col + + # Test d_cols setter - try to set d_cols to include the outcome variable + with pytest.raises( + ValueError, match=r"y cannot be set as outcome variable ``y_col`` and treatment variable in ``d_cols``" + ): + dml_data.d_cols = ["y", "d"] + # Object should remain unchanged + assert dml_data.d_cols == original_d_cols + + # Test x_cols setter - try to set x_cols to include the outcome variable + with pytest.raises(ValueError, match=r"y cannot be set as outcome variable ``y_col`` and covariate in ``x_cols``"): + dml_data.x_cols = ["X1", "y", "X2"] + # Object should remain unchanged + assert dml_data.x_cols == original_x_cols + + # Test z_cols setter - try to set z_cols to include the outcome variable + msg = r"At least one variable/column is set as outcome variable \(``y_col``\) and instrumental variable \(``z_cols``\)" + with pytest.raises(ValueError, match=msg): + dml_data.z_cols = ["y"] + # Object should remain unchanged + assert dml_data.z_cols == original_z_cols From 2fcf523dde632bcf60e2562d7a96281810caffab Mon Sep 17 00:00:00 2001 From: SvenKlaassen Date: Fri, 17 Oct 2025 10:56:49 +0200 Subject: [PATCH 2/4] update y_col setter --- doubleml/data/base_data.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/doubleml/data/base_data.py b/doubleml/data/base_data.py index a7ae30f6..b81c2ff3 100644 --- a/doubleml/data/base_data.py +++ b/doubleml/data/base_data.py @@ -527,15 +527,27 @@ def y_col(self): @y_col.setter def y_col(self, value): reset_value = hasattr(self, "_y_col") + + # Basic checks if not isinstance(value, str): raise TypeError( f"The outcome variable y_col must be of str type. {str(value)} of type {str(type(value))} was passed." ) if value not in self.all_variables: raise ValueError(f"Invalid outcome variable y_col. {value} is no data column.") - self._y_col = value + + if reset_value: + previous_value = self._y_col + self._y_col = value + try: + self._check_disjoint_sets() + except ValueError as e: + self._y_col = previous_value + raise e + else: + self._y_col = value + if reset_value: - self._check_disjoint_sets() self._set_y_z() @property From 2f12d4bbfa81c44485b33a8b1f1507c5c139eed8 Mon Sep 17 00:00:00 2001 From: SvenKlaassen Date: Fri, 17 Oct 2025 11:30:58 +0200 Subject: [PATCH 3/4] update z and x cols setter --- doubleml/data/base_data.py | 53 +++++++++++++++++++++++++++++++------- 1 file changed, 43 insertions(+), 10 deletions(-) diff --git a/doubleml/data/base_data.py b/doubleml/data/base_data.py index b81c2ff3..eee2d9ae 100644 --- a/doubleml/data/base_data.py +++ b/doubleml/data/base_data.py @@ -463,7 +463,9 @@ def x_cols(self): @x_cols.setter def x_cols(self, value): reset_value = hasattr(self, "_x_cols") + if value is not None: + # Basic checks if isinstance(value, str): value = [value] if not isinstance(value, list): @@ -476,7 +478,17 @@ def x_cols(self, value): if not set(value).issubset(set(self.all_variables)): raise ValueError("Invalid covariates x_cols. At least one covariate is no data column.") assert set(value).issubset(set(self.all_variables)) - self._x_cols = value + + if reset_value: + previous_value = self._x_cols + self._x_cols = value + try: + self._check_disjoint_sets() + except ValueError: + self._x_cols = previous_value + raise + else: + self._x_cols = value else: excluded_cols = {self.y_col} | set(self.d_cols) @@ -486,8 +498,6 @@ def x_cols(self, value): self._x_cols = [col for col in self.data.columns if col not in excluded_cols] if reset_value: - self._check_disjoint_sets() - # by default, we initialize to the first treatment variable self.set_x_d(self.d_cols[0]) @property @@ -500,6 +510,8 @@ def d_cols(self): @d_cols.setter def d_cols(self, value): reset_value = hasattr(self, "_d_cols") + + # Basic checks if isinstance(value, str): value = [value] if not isinstance(value, list): @@ -511,10 +523,19 @@ def d_cols(self, value): raise ValueError("Invalid treatment variable(s) d_cols: Contains duplicate values.") if not set(value).issubset(set(self.all_variables)): raise ValueError("Invalid treatment variable(s) d_cols. At least one treatment variable is no data column.") - self._d_cols = value + + if reset_value: + previous_value = self._d_cols + self._d_cols = value + try: + self._check_disjoint_sets() + except ValueError: + self._d_cols = previous_value + raise + else: + self._d_cols = value + if reset_value: - self._check_disjoint_sets() - # by default, we initialize to the first treatment variable self.set_x_d(self.d_cols[0]) @property @@ -541,9 +562,9 @@ def y_col(self, value): self._y_col = value try: self._check_disjoint_sets() - except ValueError as e: + except ValueError: self._y_col = previous_value - raise e + raise else: self._y_col = value @@ -560,7 +581,9 @@ def z_cols(self): @z_cols.setter def z_cols(self, value): reset_value = hasattr(self, "_z_cols") + if value is not None: + # Basic validation if isinstance(value, str): value = [value] if not isinstance(value, list): @@ -574,12 +597,22 @@ def z_cols(self, value): raise ValueError( "Invalid instrumental variable(s) z_cols. At least one instrumental variable is no data column." ) - self._z_cols = value + + if reset_value: + previous_value = self._z_cols + self._z_cols = value + try: + self._check_disjoint_sets() + except ValueError: + self._z_cols = previous_value + raise + else: + self._z_cols = value + else: self._z_cols = None if reset_value: - self._check_disjoint_sets() self._set_y_z() @property From c89526241dbd7d47da1d839601a23edb494d239c Mon Sep 17 00:00:00 2001 From: SvenKlaassen Date: Fri, 17 Oct 2025 11:56:15 +0200 Subject: [PATCH 4/4] reorder setter values --- doubleml/data/base_data.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/doubleml/data/base_data.py b/doubleml/data/base_data.py index eee2d9ae..93543e8b 100644 --- a/doubleml/data/base_data.py +++ b/doubleml/data/base_data.py @@ -479,7 +479,9 @@ def x_cols(self, value): raise ValueError("Invalid covariates x_cols. At least one covariate is no data column.") assert set(value).issubset(set(self.all_variables)) - if reset_value: + if not reset_value: + self._x_cols = value + else: previous_value = self._x_cols self._x_cols = value try: @@ -487,8 +489,6 @@ def x_cols(self, value): except ValueError: self._x_cols = previous_value raise - else: - self._x_cols = value else: excluded_cols = {self.y_col} | set(self.d_cols) @@ -524,7 +524,9 @@ def d_cols(self, value): if not set(value).issubset(set(self.all_variables)): raise ValueError("Invalid treatment variable(s) d_cols. At least one treatment variable is no data column.") - if reset_value: + if not reset_value: + self._d_cols = value + else: previous_value = self._d_cols self._d_cols = value try: @@ -532,8 +534,6 @@ def d_cols(self, value): except ValueError: self._d_cols = previous_value raise - else: - self._d_cols = value if reset_value: self.set_x_d(self.d_cols[0]) @@ -557,7 +557,9 @@ def y_col(self, value): if value not in self.all_variables: raise ValueError(f"Invalid outcome variable y_col. {value} is no data column.") - if reset_value: + if not reset_value: + self._y_col = value + else: previous_value = self._y_col self._y_col = value try: @@ -565,8 +567,6 @@ def y_col(self, value): except ValueError: self._y_col = previous_value raise - else: - self._y_col = value if reset_value: self._set_y_z() @@ -598,7 +598,9 @@ def z_cols(self, value): "Invalid instrumental variable(s) z_cols. At least one instrumental variable is no data column." ) - if reset_value: + if not reset_value: + self._z_cols = value + else: previous_value = self._z_cols self._z_cols = value try: @@ -606,8 +608,6 @@ def z_cols(self, value): except ValueError: self._z_cols = previous_value raise - else: - self._z_cols = value else: self._z_cols = None