Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 97 additions & 23 deletions doubleml/did/did_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,12 +979,13 @@ def aggregate(self, aggregation="group"):
def plot_effects(
self,
level=0.95,
result_type="effect",
joint=True,
figsize=(12, 8),
color_palette="colorblind",
date_format=None,
y_label="Effect",
title="Estimated ATTs by Group",
y_label=None,
title=None,
jitter_value=None,
default_jitter=0.1,
):
Expand All @@ -996,6 +997,10 @@ def plot_effects(
level : float
The confidence level for the intervals.
Default is ``0.95``.
result_type : str
Type of result to plot. Either ``'effect'`` for point estimates, ``'rv'`` for robustness values,
``'est_bounds'`` for estimate bounds, or ``'ci_bounds'`` for confidence interval bounds.
Default is ``'effect'``.
joint : bool
Indicates whether joint confidence intervals are computed.
Default is ``True``.
Expand All @@ -1010,10 +1015,10 @@ def plot_effects(
Default is ``None``.
y_label : str
Label for y-axis.
Default is ``"Effect"``.
Default is ``None``.
title : str
Title for the entire plot.
Default is ``"Estimated ATTs by Group"``.
Default is ``None``.
jitter_value : float
Amount of jitter to apply to points.
Default is ``None``.
Expand All @@ -1035,8 +1040,29 @@ def plot_effects(
"""
if self.framework is None:
raise ValueError("Apply fit() before plot_effects().")

if result_type not in ["effect", "rv", "est_bounds", "ci_bounds"]:
raise ValueError("result_type must be either 'effect', 'rv', 'est_bounds' or 'ci_bounds'.")

if result_type != "effect" and self._framework.sensitivity_params is None:
raise ValueError(
f"result_type='{result_type}' requires sensitivity analysis. " "Please call sensitivity_analysis() first."
)

df = self._create_ci_dataframe(level=level, joint=joint)

# Set default y_label and title based on result_type
label_configs = {
"effect": {"y_label": "Effect", "title": "Estimated ATTs by Group"},
"rv": {"y_label": "Robustness Value", "title": "Robustness Values by Group"},
"est_bounds": {"y_label": "Estimate Bounds", "title": "Estimate Bounds by Group"},
"ci_bounds": {"y_label": "Confidence Interval Bounds", "title": "Confidence Interval Bounds by Group"},
}

config = label_configs[result_type]
y_label = y_label if y_label is not None else config["y_label"]
title = title if title is not None else config["title"]

# Sort time periods and treatment groups
first_treated_periods = sorted(df["First Treated"].unique())
n_periods = len(first_treated_periods)
Expand Down Expand Up @@ -1068,7 +1094,7 @@ def plot_effects(
period_df = df[df["First Treated"] == period]
ax = axes[idx]

self._plot_single_group(ax, period_df, period, colors, is_datetime, jitter_value)
self._plot_single_group(ax, period_df, period, result_type, colors, is_datetime, jitter_value)

# Set axis labels
if idx == n_periods - 1: # Only bottom plot gets x label
Expand All @@ -1085,7 +1111,7 @@ def plot_effects(
legend_ax.axis("off")
legend_elements = [
Line2D([0], [0], color="red", linestyle=":", alpha=0.7, label="Treatment start"),
Line2D([0], [0], color="black", linestyle="--", alpha=0.5, label="Zero effect"),
Line2D([0], [0], color="black", linestyle="--", alpha=0.5, label=f"Zero {result_type}"),
Line2D([0], [0], marker="o", color=colors["pre"], linestyle="None", label="Pre-treatment", markersize=5),
]

Expand All @@ -1108,7 +1134,7 @@ def plot_effects(

return fig, axes

def _plot_single_group(self, ax, period_df, period, colors, is_datetime, jitter_value):
def _plot_single_group(self, ax, period_df, period, result_type, colors, is_datetime, jitter_value):
"""
Plot estimates for a single treatment group on the given axis.

Expand All @@ -1120,6 +1146,10 @@ def _plot_single_group(self, ax, period_df, period, colors, is_datetime, jitter_
DataFrame containing estimates for a specific time period.
period : int or datetime
Treatment period for this group.
result_type : str
Type of result to plot. Either ``'effect'`` for point estimates, ``'rv'`` for robustness values,
``'est_bounds'`` for estimate bounds, or ``'ci_bounds'`` for confidence interval bounds.
Required; unlike in ``plot_effects``, this helper defines no default for it.
colors : dict
Dictionary with 'pre', 'anticipation' (if applicable), and 'post' color values.
is_datetime : bool
Expand Down Expand Up @@ -1165,6 +1195,31 @@ def _plot_single_group(self, ax, period_df, period, colors, is_datetime, jitter_
# Define category mappings
categories = [("pre", pre_treatment_mask), ("anticipation", anticipation_mask), ("post", post_treatment_mask)]

# Define plot configurations for each result type
plot_configs = {
"effect": {"plot_col": "Estimate", "err_col_upper": "CI Upper", "err_col_lower": "CI Lower", "s_val": 30},
"rv": {"plot_col": "RV", "plot_col_2": "RVa", "s_val": 50},
"est_bounds": {
"plot_col": "Estimate",
"err_col_upper": "Estimate Upper Bound",
"err_col_lower": "Estimate Lower Bound",
"s_val": 30,
},
"ci_bounds": {
"plot_col": "Estimate",
"err_col_upper": "CI Upper Bound",
"err_col_lower": "CI Lower Bound",
"s_val": 30,
},
}

config = plot_configs[result_type]
plot_col = config["plot_col"]
plot_col_2 = config.get("plot_col_2")
err_col_upper = config.get("err_col_upper")
err_col_lower = config.get("err_col_lower")
s_val = config["s_val"]

# Plot each category
for category_name, mask in categories:
if not mask.any():
Expand All @@ -1179,22 +1234,33 @@ def _plot_single_group(self, ax, period_df, period, colors, is_datetime, jitter_

if not category_data.empty:
ax.scatter(
category_data["jittered_x"], category_data["Estimate"], color=colors[category_name], alpha=0.8, s=30
)
ax.errorbar(
category_data["jittered_x"],
category_data["Estimate"],
yerr=[
category_data["Estimate"] - category_data["CI Lower"],
category_data["CI Upper"] - category_data["Estimate"],
],
fmt="o",
capsize=3,
color=colors[category_name],
markersize=4,
markeredgewidth=1,
linewidth=1,
category_data["jittered_x"], category_data[plot_col], color=colors[category_name], alpha=0.8, s=s_val
)
if result_type in ["effect", "est_bounds", "ci_bounds"]:
ax.errorbar(
category_data["jittered_x"],
category_data[plot_col],
yerr=[
category_data[plot_col] - category_data[err_col_lower],
category_data[err_col_upper] - category_data[plot_col],
],
fmt="o",
capsize=3,
color=colors[category_name],
markersize=4,
markeredgewidth=1,
linewidth=1,
)

elif result_type == "rv":
ax.scatter(
category_data["jittered_x"],
category_data[plot_col_2],
color=colors[category_name],
alpha=0.8,
s=s_val,
marker="s",
)

# Format axes
if is_datetime:
Expand Down Expand Up @@ -1431,6 +1497,8 @@ def _create_ci_dataframe(self, level=0.95, joint=True):
- 'CI Lower': Lower bound of confidence intervals
- 'CI Upper': Upper bound of confidence intervals
- 'Pre-Treatment': Boolean indicating if evaluation period is before treatment
- 'RV': Robustness values (if sensitivity_analysis() has been called before)
- 'RVa': Robustness values for (1-a) confidence bounds (if sensitivity_analysis() has been called before)

Notes
-----
Expand Down Expand Up @@ -1459,5 +1527,11 @@ def _create_ci_dataframe(self, level=0.95, joint=True):
"Pre-Treatment": [gt_combination[2] < gt_combination[0] for gt_combination in self.gt_combinations],
}
)

if self._framework.sensitivity_params is not None:
df["RV"] = self.framework.sensitivity_params["rv"]
df["RVa"] = self.framework.sensitivity_params["rva"]
df["CI Lower Bound"] = self.framework.sensitivity_params["ci"]["lower"]
df["CI Upper Bound"] = self.framework.sensitivity_params["ci"]["upper"]
df["Estimate Lower Bound"] = self.framework.sensitivity_params["theta"]["lower"]
df["Estimate Upper Bound"] = self.framework.sensitivity_params["theta"]["upper"]
return df
106 changes: 106 additions & 0 deletions doubleml/did/tests/test_did_multi_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,109 @@ def test_plot_effects_jitter(doubleml_did_fixture):
assert fig_default != fig

plt.close("all")


@pytest.mark.ci
def test_plot_effects_result_types(doubleml_did_fixture):
    """Check the plot returned for ``result_type='effect'``.

    The call has to hand back a matplotlib figure together with a list of
    axes, and the first axis must carry the ``"Effect"`` y-label that is
    used when no ``y_label`` is supplied.
    """
    model = doubleml_did_fixture["model"]

    figure, axes_list = model.plot_effects(result_type="effect")

    # Return types of the plotting call
    assert isinstance(figure, plt.Figure)
    assert isinstance(axes_list, list)

    # No y_label was passed, so the label tied to the result type applies
    first_axis = axes_list[0]
    assert first_axis.get_ylabel() == "Effect"

    # Release every figure opened by this test
    plt.close("all")


@pytest.mark.ci
def test_plot_effects_result_type_rv(doubleml_did_fixture):
    """Check the robustness-value plot (``result_type='rv'``).

    A sensitivity analysis is run on the fitted model beforehand, since
    every result type other than ``'effect'`` depends on its output.
    """
    model = doubleml_did_fixture["model"]

    # Sensitivity parameters have to exist before 'rv' can be plotted
    model.sensitivity_analysis(cf_y=0.03, cf_d=0.03)

    figure, axes_list = model.plot_effects(result_type="rv")

    # Return types of the plotting call
    assert isinstance(figure, plt.Figure)
    assert isinstance(axes_list, list)

    # The y-label must be the one associated with 'rv'
    first_axis = axes_list[0]
    assert first_axis.get_ylabel() == "Robustness Value"

    # Release every figure opened by this test
    plt.close("all")


@pytest.mark.ci
def test_plot_effects_result_type_est_bounds(doubleml_did_fixture):
    """Check the estimate-bounds plot (``result_type='est_bounds'``).

    A sensitivity analysis is run on the fitted model beforehand, since
    every result type other than ``'effect'`` depends on its output.
    """
    model = doubleml_did_fixture["model"]

    # Sensitivity parameters have to exist before 'est_bounds' can be plotted
    model.sensitivity_analysis(cf_y=0.03, cf_d=0.03)

    figure, axes_list = model.plot_effects(result_type="est_bounds")

    # Return types of the plotting call
    assert isinstance(figure, plt.Figure)
    assert isinstance(axes_list, list)

    # The y-label must be the one associated with 'est_bounds'
    first_axis = axes_list[0]
    assert first_axis.get_ylabel() == "Estimate Bounds"

    # Release every figure opened by this test
    plt.close("all")


@pytest.mark.ci
def test_plot_effects_result_type_ci_bounds(doubleml_did_fixture):
    """Check the confidence-interval-bounds plot (``result_type='ci_bounds'``).

    A sensitivity analysis is run on the fitted model beforehand, since
    every result type other than ``'effect'`` depends on its output.
    """
    model = doubleml_did_fixture["model"]

    # Sensitivity parameters have to exist before 'ci_bounds' can be plotted
    model.sensitivity_analysis(cf_y=0.03, cf_d=0.03)

    figure, axes_list = model.plot_effects(result_type="ci_bounds")

    # Return types of the plotting call
    assert isinstance(figure, plt.Figure)
    assert isinstance(axes_list, list)

    # The y-label must be the one associated with 'ci_bounds'
    first_axis = axes_list[0]
    assert first_axis.get_ylabel() == "Confidence Interval Bounds"

    # Release every figure opened by this test
    plt.close("all")


@pytest.mark.ci
def test_plot_effects_result_type_invalid(doubleml_did_fixture):
    """Check that an unsupported ``result_type`` is rejected with ``ValueError``."""
    model = doubleml_did_fixture["model"]

    # Only the leading part of the error message is matched
    expected_message = "result_type must be either"
    with pytest.raises(ValueError, match=expected_message):
        model.plot_effects(result_type="invalid_type")

    # Release any figure that may have been opened
    plt.close("all")


@pytest.mark.ci
def test_plot_effects_result_type_with_custom_labels(doubleml_did_fixture):
    """Check that explicit ``title`` and ``y_label`` values are used as given.

    The plot is requested for ``result_type='est_bounds'`` after running a
    sensitivity analysis; the figure title and the first axis' y-label must
    equal the custom strings passed in.
    """
    model = doubleml_did_fixture["model"]

    # Sensitivity parameters have to exist before 'est_bounds' can be plotted
    model.sensitivity_analysis(cf_y=0.03, cf_d=0.03)

    custom_title = "Custom Sensitivity Plot"
    custom_ylabel = "Custom Bounds Label"
    plot_kwargs = {
        "result_type": "est_bounds",
        "title": custom_title,
        "y_label": custom_ylabel,
    }

    figure, axes_list = model.plot_effects(**plot_kwargs)

    assert isinstance(figure, plt.Figure)
    # The figure-level title is read from the figure's suptitle text object
    assert figure._suptitle.get_text() == custom_title
    first_axis = axes_list[0]
    assert first_axis.get_ylabel() == custom_ylabel

    # Release every figure opened by this test
    plt.close("all")
Loading