From cdc938882cdb9265186d1fe1d93ed9aec5f0bfa7 Mon Sep 17 00:00:00 2001 From: PhilippBach Date: Tue, 18 Nov 2025 14:31:40 +0100 Subject: [PATCH 1/5] add sensitivity results to ci dataframe --- doubleml/did/did_multi.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/doubleml/did/did_multi.py b/doubleml/did/did_multi.py index a9e9e790..8b155dfe 100644 --- a/doubleml/did/did_multi.py +++ b/doubleml/did/did_multi.py @@ -1431,6 +1431,8 @@ def _create_ci_dataframe(self, level=0.95, joint=True): - 'CI Lower': Lower bound of confidence intervals - 'CI Upper': Upper bound of confidence intervals - 'Pre-Treatment': Boolean indicating if evaluation period is before treatment + - 'RV': Robustness values (if sensitivity_analysis() has been called before) + - 'RVa': Robustness values for (1-a) confidence bounds (if sensitivity_analysis() has been called before) Notes ----- @@ -1459,5 +1461,11 @@ def _create_ci_dataframe(self, level=0.95, joint=True): "Pre-Treatment": [gt_combination[2] < gt_combination[0] for gt_combination in self.gt_combinations], } ) - - return df + if self._framework.sensitivity_params is not None: + df["RV"] = self.framework.sensitivity_params["rv"] + df["RVa"] = self.framework.sensitivity_params["rva"] + df["CI Lower Bound"] = self.framework.sensitivity_params["ci"]["lower"] + df["CI Upper Bound"] = self.framework.sensitivity_params["ci"]["upper"] + df["Estimate Lower Bound"] = self.framework.sensitivity_params["theta"]["lower"] + df["Estimate Upper Bound"] = self.framework.sensitivity_params["theta"]["upper"] + return df \ No newline at end of file From 5f63746367e241bc93eb98fafbcb36a4910ccec3 Mon Sep 17 00:00:00 2001 From: PhilippBach Date: Tue, 18 Nov 2025 18:11:42 +0100 Subject: [PATCH 2/5] add plotting options for rv and bounds --- doubleml/did/did_multi.py | 93 +++++++++++++++++++++++++++++---------- 1 file changed, 70 insertions(+), 23 deletions(-) diff --git a/doubleml/did/did_multi.py b/doubleml/did/did_multi.py index 8b155dfe..6080e409 100644 --- a/doubleml/did/did_multi.py +++ b/doubleml/did/did_multi.py @@ -979,12 +979,13 @@ def aggregate(self, aggregation="group"): def plot_effects( self, level=0.95, + result_type="effect", joint=True, figsize=(12, 8), color_palette="colorblind", date_format=None, - y_label="Effect", - title="Estimated ATTs by Group", + y_label=None, + title=None, jitter_value=None, default_jitter=0.1, ): @@ -996,6 +997,10 @@ def plot_effects( level : float The confidence level for the intervals. Default is ``0.95``. + result_type : str + Type of result to plot. Either ``'effect'`` for point estimates, ``'rv'`` for robustness values, + ``'est_bounds'`` for estimate bounds, or ``'ci_bounds'`` for confidence interval bounds. + Default is ``'effect'``. joint : bool Indicates whether joint confidence intervals are computed. Default is ``True``. @@ -1010,10 +1015,10 @@ def plot_effects( Default is ``None``. y_label : str Label for y-axis. - Default is ``"Effect"``. + Default is ``None``. title : str Title for the entire plot. - Default is ``"Estimated ATTs by Group"``. + Default is ``None``. jitter_value : float Amount of jitter to apply to points. Default is ``None``. @@ -1035,8 +1040,24 @@ def plot_effects( """ if self.framework is None: raise ValueError("Apply fit() before plot_effects().") + + if result_type not in ["effect", "rv", "est_bounds", "ci_bounds"]: + raise ValueError("result_type must be either 'effect', 'rv', 'est_bounds' or 'ci_bounds'.") + df = self._create_ci_dataframe(level=level, joint=joint) + # Set default y_label and title based on result_type + label_configs = { + "effect": {"y_label": "Effect", "title": "Estimated ATTs by Group"}, + "rv": {"y_label": "Robustness Value", "title": "Robustness Values by Group"}, + "est_bounds": {"y_label": "Estimate Bounds", "title": "Estimate Bounds by Group"}, + "ci_bounds": {"y_label": "Confidence Interval Bounds", "title": "Confidence Interval Bounds by Group"} + } + + config = label_configs[result_type] + y_label = y_label if y_label is not None else config["y_label"] + title = title if title is not None else config["title"] + # Sort time periods and treatment groups first_treated_periods = sorted(df["First Treated"].unique()) n_periods = len(first_treated_periods) @@ -1068,7 +1089,7 @@ def plot_effects( period_df = df[df["First Treated"] == period] ax = axes[idx] - self._plot_single_group(ax, period_df, period, colors, is_datetime, jitter_value) + self._plot_single_group(ax, period_df, period, result_type, colors, is_datetime, jitter_value) # Set axis labels if idx == n_periods - 1: # Only bottom plot gets x label @@ -1085,7 +1106,7 @@ def plot_effects( legend_ax.axis("off") legend_elements = [ Line2D([0], [0], color="red", linestyle=":", alpha=0.7, label="Treatment start"), - Line2D([0], [0], color="black", linestyle="--", alpha=0.5, label="Zero effect"), + Line2D([0], [0], color="black", linestyle="--", alpha=0.5, label=f"Zero {result_type}"), Line2D([0], [0], marker="o", color=colors["pre"], linestyle="None", label="Pre-treatment", markersize=5), ] @@ -1108,7 +1129,7 @@ def plot_effects( return fig, axes - def _plot_single_group(self, ax, period_df, period, colors, is_datetime, jitter_value): + def _plot_single_group(self, ax, period_df, period, result_type, colors, is_datetime, jitter_value): """ Plot estimates for a single treatment group on the given axis. @@ -1120,6 +1141,10 @@ def _plot_single_group(self, ax, period_df, period, colors, is_datetime, jitter_ DataFrame containing estimates for a specific time period. period : int or datetime Treatment period for this group. + result_type : str + Type of result to plot. Either ``'effect'`` for point estimates, ``'rv'`` for robustness values, + ``'est_bounds'`` for estimate bounds, or ``'ci_bounds'`` for confidence interval bounds. + Default is ``'effect'``. colors : dict Dictionary with 'pre', 'anticipation' (if applicable), and 'post' color values. is_datetime : bool @@ -1165,6 +1190,21 @@ def _plot_single_group(self, ax, period_df, period, colors, is_datetime, jitter_ # Define category mappings categories = [("pre", pre_treatment_mask), ("anticipation", anticipation_mask), ("post", post_treatment_mask)] + # Define plot configurations for each result type + plot_configs = { + "effect": {"plot_col": "Estimate", "err_col_upper": "CI Upper", "err_col_lower": "CI Lower", "s_val": 30}, + "rv": {"plot_col": "RV", "plot_col_2": "RVa", "s_val": 50}, + "est_bounds": {"plot_col": "Estimate", "err_col_upper": "Estimate Upper Bound", "err_col_lower": "Estimate Lower Bound", "s_val": 30}, + "ci_bounds": {"plot_col": "Estimate", "err_col_upper": "CI Upper Bound", "err_col_lower": "CI Lower Bound", "s_val": 30} + } + + config = plot_configs[result_type] + plot_col = config["plot_col"] + plot_col_2 = config.get("plot_col_2") + err_col_upper = config.get("err_col_upper") + err_col_lower = config.get("err_col_lower") + s_val = config["s_val"] + # Plot each category for category_name, mask in categories: if not mask.any(): @@ -1179,23 +1219,30 @@ def _plot_single_group(self, ax, period_df, period, colors, is_datetime, jitter_ if not category_data.empty: ax.scatter( - category_data["jittered_x"], category_data["Estimate"], color=colors[category_name], alpha=0.8, s=30 - ) - ax.errorbar( - category_data["jittered_x"], - category_data["Estimate"], - yerr=[ - category_data["Estimate"] - category_data["CI Lower"], - category_data["CI Upper"] - category_data["Estimate"], - ], - fmt="o", - capsize=3, - color=colors[category_name], - markersize=4, - markeredgewidth=1, - linewidth=1, + category_data["jittered_x"], category_data[plot_col], color=colors[category_name], alpha=0.8, s=s_val ) - + if result_type in ["effect", "est_bounds", "ci_bounds"]: + ax.errorbar( + category_data["jittered_x"], + category_data[plot_col], + yerr=[ + category_data[plot_col] - category_data[err_col_lower], + category_data[err_col_upper] - category_data[plot_col], + ], + fmt="o", + capsize=3, + color=colors[category_name], + markersize=4, + markeredgewidth=1, + linewidth=1, + ) + + elif result_type == "rv": + ax.scatter( + category_data["jittered_x"], category_data[plot_col_2], color=colors[category_name], alpha=0.8, s=s_val, + marker="s" + ) + # Format axes if is_datetime: period_str = np.datetime64(period, self._dml_data.datetime_unit) From 5aa1f5d5e6931da17d17dfab5ff13d883db203a8 Mon Sep 17 00:00:00 2001 From: PhilippBach Date: Tue, 18 Nov 2025 18:12:12 +0100 Subject: [PATCH 3/5] add unit tests for extended did plot --- doubleml/did/tests/test_did_multi_plot.py | 110 ++++++++++++++++++++++ 1 file changed, 110 insertions(+) diff --git a/doubleml/did/tests/test_did_multi_plot.py b/doubleml/did/tests/test_did_multi_plot.py index 4a55449d..b8ade701 100644 --- a/doubleml/did/tests/test_did_multi_plot.py +++ b/doubleml/did/tests/test_did_multi_plot.py @@ -184,3 +184,113 @@ def test_plot_effects_jitter(doubleml_did_fixture): assert fig_default != fig plt.close("all") + + +@pytest.mark.ci +def test_plot_effects_result_types(doubleml_did_fixture): + """Test plot_effects with different result types.""" + dml_obj = doubleml_did_fixture["model"] + + # Test default result_type='effect' + fig_effect, axes_effect = dml_obj.plot_effects(result_type="effect") + assert isinstance(fig_effect, plt.Figure) + assert isinstance(axes_effect, list) + + # Check that the default y-label is set correctly + assert axes_effect[0].get_ylabel() == "Effect" + + plt.close("all") + + +@pytest.mark.ci +def test_plot_effects_result_type_rv(doubleml_did_fixture): + """Test plot_effects with result_type='rv' (requires sensitivity analysis).""" + dml_obj = doubleml_did_fixture["model"] + + # Perform sensitivity analysis first + dml_obj.sensitivity_analysis(cf_y=0.03, cf_d=0.03) + + # Test result_type='rv' + fig_rv, axes_rv = dml_obj.plot_effects(result_type="rv") + assert isinstance(fig_rv, plt.Figure) + assert isinstance(axes_rv, list) + + # Check that the y-label is set correctly + assert axes_rv[0].get_ylabel() == "Robustness Value" + + plt.close("all") + + +@pytest.mark.ci +def test_plot_effects_result_type_est_bounds(doubleml_did_fixture): + """Test plot_effects with result_type='est_bounds' (requires sensitivity analysis).""" + dml_obj = doubleml_did_fixture["model"] + + # Perform sensitivity analysis first + dml_obj.sensitivity_analysis(cf_y=0.03, cf_d=0.03) + + # Test result_type='est_bounds' + fig_est, axes_est = dml_obj.plot_effects(result_type="est_bounds") + assert isinstance(fig_est, plt.Figure) + assert isinstance(axes_est, list) + + # Check that the y-label is set correctly + assert axes_est[0].get_ylabel() == "Estimate Bounds" + + plt.close("all") + + +@pytest.mark.ci +def test_plot_effects_result_type_ci_bounds(doubleml_did_fixture): + """Test plot_effects with result_type='ci_bounds' (requires sensitivity analysis).""" + dml_obj = doubleml_did_fixture["model"] + + # Perform sensitivity analysis first + dml_obj.sensitivity_analysis(cf_y=0.03, cf_d=0.03) + + # Test result_type='ci_bounds' + fig_ci, axes_ci = dml_obj.plot_effects(result_type="ci_bounds") + assert isinstance(fig_ci, plt.Figure) + assert isinstance(axes_ci, list) + + # Check that the y-label is set correctly + assert axes_ci[0].get_ylabel() == "Confidence Interval Bounds" + + plt.close("all") + + +@pytest.mark.ci +def test_plot_effects_result_type_invalid(doubleml_did_fixture): + """Test plot_effects with invalid result_type.""" + dml_obj = doubleml_did_fixture["model"] + + # Test with invalid result_type + with pytest.raises(ValueError, match="result_type must be either"): + dml_obj.plot_effects(result_type="invalid_type") + + plt.close("all") + + +@pytest.mark.ci +def test_plot_effects_result_type_with_custom_labels(doubleml_did_fixture): + """Test plot_effects with result_type and custom labels.""" + dml_obj = doubleml_did_fixture["model"] + + # Perform sensitivity analysis first + dml_obj.sensitivity_analysis(cf_y=0.03, cf_d=0.03) + + # Test result_type with custom labels + custom_title = "Custom Sensitivity Plot" + custom_ylabel = "Custom Bounds Label" + + fig, axes = dml_obj.plot_effects( + result_type="est_bounds", + title=custom_title, + y_label=custom_ylabel + ) + + assert isinstance(fig, plt.Figure) + assert fig._suptitle.get_text() == custom_title + assert axes[0].get_ylabel() == custom_ylabel + + plt.close("all") From 831dcb8ee193c93e2fc5e69dcbabb3b740b899f2 Mon Sep 17 00:00:00 2001 From: PhilippBach Date: Tue, 18 Nov 2025 18:20:15 +0100 Subject: [PATCH 4/5] check for successful sensitivity when generating plot for RV --- doubleml/did/did_multi.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doubleml/did/did_multi.py b/doubleml/did/did_multi.py index 6080e409..50b2428a 100644 --- a/doubleml/did/did_multi.py +++ b/doubleml/did/did_multi.py @@ -1044,6 +1044,10 @@ def plot_effects( if result_type not in ["effect", "rv", "est_bounds", "ci_bounds"]: raise ValueError("result_type must be either 'effect', 'rv', 'est_bounds' or 'ci_bounds'.") + if result_type != "effect" and self._framework.sensitivity_params is None: + raise ValueError(f"result_type='{result_type}' requires sensitivity analysis. " + "Please call sensitivity_analysis() first.") + df = self._create_ci_dataframe(level=level, joint=joint) # Set default y_label and title based on result_type From 3fab6f949f028cfd5587e855b5f48b2968c03826 Mon Sep 17 00:00:00 2001 From: SvenKlaassen Date: Tue, 18 Nov 2025 19:05:16 +0100 Subject: [PATCH 5/5] formatting --- doubleml/did/did_multi.py | 45 ++++++++++++++------- doubleml/did/tests/test_did_multi_plot.py | 48 +++++++++++------------ 2 files changed, 52 insertions(+), 41 deletions(-) diff --git a/doubleml/did/did_multi.py b/doubleml/did/did_multi.py index 50b2428a..94ef112c 100644 --- a/doubleml/did/did_multi.py +++ b/doubleml/did/did_multi.py @@ -1040,14 +1040,15 @@ def plot_effects( """ if self.framework is None: raise ValueError("Apply fit() before plot_effects().") - + if result_type not in ["effect", "rv", "est_bounds", "ci_bounds"]: raise ValueError("result_type must be either 'effect', 'rv', 'est_bounds' or 'ci_bounds'.") - + if result_type != "effect" and self._framework.sensitivity_params is None: - raise ValueError(f"result_type='{result_type}' requires sensitivity analysis. " - "Please call sensitivity_analysis() first.") - + raise ValueError( + f"result_type='{result_type}' requires sensitivity analysis. " "Please call sensitivity_analysis() first." + ) + df = self._create_ci_dataframe(level=level, joint=joint) # Set default y_label and title based on result_type @@ -1055,9 +1056,9 @@ def plot_effects( "effect": {"y_label": "Effect", "title": "Estimated ATTs by Group"}, "rv": {"y_label": "Robustness Value", "title": "Robustness Values by Group"}, "est_bounds": {"y_label": "Estimate Bounds", "title": "Estimate Bounds by Group"}, - "ci_bounds": {"y_label": "Confidence Interval Bounds", "title": "Confidence Interval Bounds by Group"} + "ci_bounds": {"y_label": "Confidence Interval Bounds", "title": "Confidence Interval Bounds by Group"}, } - + config = label_configs[result_type] y_label = y_label if y_label is not None else config["y_label"] title = title if title is not None else config["title"] @@ -1198,10 +1199,20 @@ def _plot_single_group(self, ax, period_df, period, result_type, colors, is_date plot_configs = { "effect": {"plot_col": "Estimate", "err_col_upper": "CI Upper", "err_col_lower": "CI Lower", "s_val": 30}, "rv": {"plot_col": "RV", "plot_col_2": "RVa", "s_val": 50}, - "est_bounds": {"plot_col": "Estimate", "err_col_upper": "Estimate Upper Bound", "err_col_lower": "Estimate Lower Bound", "s_val": 30}, - "ci_bounds": {"plot_col": "Estimate", "err_col_upper": "CI Upper Bound", "err_col_lower": "CI Lower Bound", "s_val": 30} + "est_bounds": { + "plot_col": "Estimate", + "err_col_upper": "Estimate Upper Bound", + "err_col_lower": "Estimate Lower Bound", + "s_val": 30, + }, + "ci_bounds": { + "plot_col": "Estimate", + "err_col_upper": "CI Upper Bound", + "err_col_lower": "CI Lower Bound", + "s_val": 30, + }, } - + config = plot_configs[result_type] plot_col = config["plot_col"] plot_col_2 = config.get("plot_col_2") @@ -1240,13 +1251,17 @@ def _plot_single_group(self, ax, period_df, period, result_type, colors, is_date markeredgewidth=1, linewidth=1, ) - + elif result_type == "rv": ax.scatter( - category_data["jittered_x"], category_data[plot_col_2], color=colors[category_name], alpha=0.8, s=s_val, - marker="s" + category_data["jittered_x"], + category_data[plot_col_2], + color=colors[category_name], + alpha=0.8, + s=s_val, + marker="s", ) - + # Format axes if is_datetime: period_str = np.datetime64(period, self._dml_data.datetime_unit) @@ -1519,4 +1534,4 @@ def _create_ci_dataframe(self, level=0.95, joint=True): df["CI Upper Bound"] = self.framework.sensitivity_params["ci"]["upper"] df["Estimate Lower Bound"] = self.framework.sensitivity_params["theta"]["lower"] df["Estimate Upper Bound"] = self.framework.sensitivity_params["theta"]["upper"] - return df \ No newline at end of file + return df diff --git a/doubleml/did/tests/test_did_multi_plot.py b/doubleml/did/tests/test_did_multi_plot.py index b8ade701..d4275cde 100644 --- a/doubleml/did/tests/test_did_multi_plot.py +++ b/doubleml/did/tests/test_did_multi_plot.py @@ -195,10 +195,10 @@ def test_plot_effects_result_types(doubleml_did_fixture): fig_effect, axes_effect = dml_obj.plot_effects(result_type="effect") assert isinstance(fig_effect, plt.Figure) assert isinstance(axes_effect, list) - + # Check that the default y-label is set correctly assert axes_effect[0].get_ylabel() == "Effect" - + plt.close("all") @@ -206,18 +206,18 @@ def test_plot_effects_result_types(doubleml_did_fixture): def test_plot_effects_result_type_rv(doubleml_did_fixture): """Test plot_effects with result_type='rv' (requires sensitivity analysis).""" dml_obj = doubleml_did_fixture["model"] - + # Perform sensitivity analysis first dml_obj.sensitivity_analysis(cf_y=0.03, cf_d=0.03) - + # Test result_type='rv' fig_rv, axes_rv = dml_obj.plot_effects(result_type="rv") assert isinstance(fig_rv, plt.Figure) assert isinstance(axes_rv, list) - + # Check that the y-label is set correctly assert axes_rv[0].get_ylabel() == "Robustness Value" - + plt.close("all") @@ -225,18 +225,18 @@ def test_plot_effects_result_type_rv(doubleml_did_fixture): def test_plot_effects_result_type_est_bounds(doubleml_did_fixture): """Test plot_effects with result_type='est_bounds' (requires sensitivity analysis).""" dml_obj = doubleml_did_fixture["model"] - + # Perform sensitivity analysis first dml_obj.sensitivity_analysis(cf_y=0.03, cf_d=0.03) - + # Test result_type='est_bounds' fig_est, axes_est = dml_obj.plot_effects(result_type="est_bounds") assert isinstance(fig_est, plt.Figure) assert isinstance(axes_est, list) - + # Check that the y-label is set correctly assert axes_est[0].get_ylabel() == "Estimate Bounds" - + plt.close("all") @@ -244,18 +244,18 @@ def test_plot_effects_result_type_est_bounds(doubleml_did_fixture): def test_plot_effects_result_type_ci_bounds(doubleml_did_fixture): """Test plot_effects with result_type='ci_bounds' (requires sensitivity analysis).""" dml_obj = doubleml_did_fixture["model"] - + # Perform sensitivity analysis first dml_obj.sensitivity_analysis(cf_y=0.03, cf_d=0.03) - + # Test result_type='ci_bounds' fig_ci, axes_ci = dml_obj.plot_effects(result_type="ci_bounds") assert isinstance(fig_ci, plt.Figure) assert isinstance(axes_ci, list) - + # Check that the y-label is set correctly assert axes_ci[0].get_ylabel() == "Confidence Interval Bounds" - + plt.close("all") @@ -263,11 +263,11 @@ def test_plot_effects_result_type_ci_bounds(doubleml_did_fixture): def test_plot_effects_result_type_invalid(doubleml_did_fixture): """Test plot_effects with invalid result_type.""" dml_obj = doubleml_did_fixture["model"] - + # Test with invalid result_type with pytest.raises(ValueError, match="result_type must be either"): dml_obj.plot_effects(result_type="invalid_type") - + plt.close("all") @@ -275,22 +275,18 @@ def test_plot_effects_result_type_invalid(doubleml_did_fixture): def test_plot_effects_result_type_with_custom_labels(doubleml_did_fixture): """Test plot_effects with result_type and custom labels.""" dml_obj = doubleml_did_fixture["model"] - + # Perform sensitivity analysis first dml_obj.sensitivity_analysis(cf_y=0.03, cf_d=0.03) - + # Test result_type with custom labels custom_title = "Custom Sensitivity Plot" custom_ylabel = "Custom Bounds Label" - - fig, axes = dml_obj.plot_effects( - result_type="est_bounds", - title=custom_title, - y_label=custom_ylabel - ) - + + fig, axes = dml_obj.plot_effects(result_type="est_bounds", title=custom_title, y_label=custom_ylabel) + assert isinstance(fig, plt.Figure) assert fig._suptitle.get_text() == custom_title assert axes[0].get_ylabel() == custom_ylabel - + plt.close("all")