From fe6375e91d73269844cc684dd12fe5b91be08846 Mon Sep 17 00:00:00 2001
From: Julian Diefenbacher
<93731825+JulianDiefenbacher@users.noreply.github.com>
Date: Mon, 23 Jun 2025 18:26:09 +0200
Subject: [PATCH 01/33] static panel model, data, and util
---
.../plm/datasets/dgp_static_panel_CP2025.py | 82 +++
doubleml/plm/plpr.py | 477 ++++++++++++++++++
doubleml/plm/utils/_plpr_util.py | 73 +++
3 files changed, 632 insertions(+)
create mode 100644 doubleml/plm/datasets/dgp_static_panel_CP2025.py
create mode 100644 doubleml/plm/plpr.py
create mode 100644 doubleml/plm/utils/_plpr_util.py
diff --git a/doubleml/plm/datasets/dgp_static_panel_CP2025.py b/doubleml/plm/datasets/dgp_static_panel_CP2025.py
new file mode 100644
index 000000000..af9e11165
--- /dev/null
+++ b/doubleml/plm/datasets/dgp_static_panel_CP2025.py
@@ -0,0 +1,82 @@
+import numpy as np
+import pandas as pd
+
+
+def make_static_panel_CP2025(num_n=250, num_t=10, dim_x=30, theta=0.5, dgp_type='dgp1'):
+ """
+ Generates static panel data from the simulation dgp in Clarke and Polselli (2025).
+
+ Parameters
+ ----------
+ num_n :
+ The number of units in the panel.
+ num_t :
+ The number of time periods in the panel.
+ dim_x :
+ The number of covariates.
+ theta :
+ The value of the causal parameter.
+ dgp_type :
+ The type of DGP design to be used. Default is ``'dgp1'``, other options are ``'dgp2'`` and ``'dgp3'``.
+
+ Returns
+ -------
+ pandas.DataFrame
+ DataFrame containing the simulated static panel data.
+ """
+
+ # parameters
+ a = 0.25
+ b = 0.5
+ x_var = 5
+ a_var = 0.95
+
+ # id and time vectors
+ id = np.repeat(np.arange(1, num_n+1), num_t)
+ time = np.tile(np.arange(1, num_t+1), num_n)
+
+ # individual fixed effects
+ a_i = np.repeat(np.random.normal(0, np.sqrt(a_var), num_n), num_t)
+ c_i = np.repeat(np.random.standard_normal(num_n), num_t)
+
+ # covariates and errors
+ x_it = np.random.normal(loc=0, scale=np.sqrt(x_var), size=(num_n*num_t, dim_x))
+ u_it = np.random.standard_normal(num_n*num_t)
+ v_it = np.random.standard_normal(num_n*num_t)
+
+ # functional forms in nuisance functions
+ if dgp_type == 'dgp1':
+ l_0 = a * x_it[:,0] + x_it[:,2]
+ m_0 = a * x_it[:,0] + x_it[:,2]
+ elif dgp_type == 'dgp2':
+ l_0 = np.divide(np.exp(x_it[:,0]), 1 + np.exp(x_it[:,0])) + a * np.cos(x_it[:,2])
+ m_0 = np.cos(x_it[:,0]) + a * np.divide(np.exp(x_it[:,2]), 1 + np.exp(x_it[:,2]))
+ elif dgp_type == 'dgp3':
+ l_0 = b * (x_it[:,0] * x_it[:,2]) + a * (x_it[:,2] * np.where(x_it[:,2] > 0, 1, 0))
+ m_0 = a * (x_it[:,0] * np.where(x_it[:,0] > 0, 1, 0)) + b * (x_it[:,0] * x_it[:,2])
+ else:
+ raise ValueError('Invalid dgp')
+
+ # treatment
+ d_it = m_0 + c_i + v_it
+
+ def alpha_i(x_it, d_it, a_i, num_n, num_t):
+ d_i = np.array_split(d_it, num_n)
+ d_i_term = np.repeat(np.mean(d_i, axis=1), num_t) - np.mean(d_it)
+
+ x_i = np.array_split(np.sum(x_it[:, [0, 2]], axis=1), num_n)
+ x_i_mean = np.mean(x_i, axis=1)
+ x_i_term = np.repeat(x_i_mean, num_t)
+
+ alpha_term = 0.25 * d_i_term + 0.25 * x_i_term + a_i
+ return alpha_term
+
+ # outcome
+ y_it = d_it * theta + l_0 + alpha_i(x_it, d_it, a_i, num_n, num_t) + u_it
+
+ x_cols = [f'x{i + 1}' for i in np.arange(dim_x)]
+
+ data = pd.DataFrame(np.column_stack((id, time, d_it, y_it, x_it)),
+ columns=['id', 'time', 'd', 'y'] + x_cols).astype({'id': 'int64', 'time': 'int64'})
+
+ return data
\ No newline at end of file
diff --git a/doubleml/plm/plpr.py b/doubleml/plm/plpr.py
new file mode 100644
index 000000000..1957efcab
--- /dev/null
+++ b/doubleml/plm/plpr.py
@@ -0,0 +1,477 @@
+import warnings
+
+import numpy as np
+import pandas as pd
+from sklearn.base import clone
+from sklearn.utils import check_X_y
+
+from ..data.base_data import DoubleMLData
+from ..double_ml import DoubleML
+from ..double_ml_score_mixins import LinearScoreMixin
+from ..utils._checks import _check_binary_predictions, _check_finite_predictions, _check_is_propensity, _check_score
+from ..utils._estimation import _dml_cv_predict, _dml_tune
+from ..utils.blp import DoubleMLBLP
+
+
+class DoubleMLPLPR(LinearScoreMixin, DoubleML):
+ """Double machine learning for partially linear panel regression models
+
+ Parameters
+ ----------
+ obj_dml_data : :class:`DoubleMLData` object
+ The :class:`DoubleMLData` object providing the data and specifying the variables for the causal model.
+
+ ml_l : estimator implementing ``fit()`` and ``predict()``
+ A machine learner implementing ``fit()`` and ``predict()`` methods (e.g.
+ :py:class:`sklearn.ensemble.RandomForestRegressor`) for the nuisance function :math:`\\ell_0(X) = E[Y|X]`.
+
+ ml_m : estimator implementing ``fit()`` and ``predict()``
+ A machine learner implementing ``fit()`` and ``predict()`` methods (e.g.
+ :py:class:`sklearn.ensemble.RandomForestRegressor`) for the nuisance function :math:`m_0(X) = E[D|X]`.
+ For binary treatment variables :math:`D` (with values 0 and 1), a classifier implementing ``fit()`` and
+ ``predict_proba()`` can also be specified. If :py:func:`sklearn.base.is_classifier` returns ``True``,
+ ``predict_proba()`` is used otherwise ``predict()``.
+
+ ml_g : estimator implementing ``fit()`` and ``predict()``
+ A machine learner implementing ``fit()`` and ``predict()`` methods (e.g.
+ :py:class:`sklearn.ensemble.RandomForestRegressor`) for the nuisance function
+ :math:`g_0(X) = E[Y - D \\theta_0|X]`.
+ Note: The learner `ml_g` is only required for the score ``'IV-type'``. Optionally, it can be specified and
+ estimated for callable scores.
+
+ n_folds : int
+ Number of folds.
+ Default is ``5``.
+
+ n_rep : int
+ Number of repetitions for the sample splitting.
+ Default is ``1``.
+
+ score : str or callable
+ A str (``'partialling out'`` or ``'IV-type'``) specifying the score function
+ or a callable object / function with signature ``psi_a, psi_b = score(y, d, l_hat, m_hat, g_hat, smpls)``.
+ Default is ``'partialling out'``.
+
+ pdml_approach : str
+ Panel DML approach (``'transform'``, ``'cre'``, ``'cre_general'``)
+
+ draw_sample_splitting : bool
+ Indicates whether the sample splitting should be drawn during initialization of the object.
+ Default is ``True``.
+
+ Examples
+ --------
+ >>> import numpy as np
+ >>> import doubleml as dml
+
+
+ Notes
+ -----
+ **Partially linear panel regression (PLPR)** models take the form
+
+ """
+
+ def __init__(
+ self, obj_dml_data, ml_l, ml_m, ml_g=None, n_folds=5, n_rep=1, score="partialling out", pdml_approach='transform', draw_sample_splitting=True):
+ super().__init__(obj_dml_data, n_folds, n_rep, score, draw_sample_splitting)
+
+ self._check_data(self._dml_data)
+ # assert cluster?
+ valid_scores = ["IV-type", "partialling out"]
+ _check_score(self.score, valid_scores, allow_callable=True)
+
+ valid_pdml_approach = ["transform", "cre", "cre_general"]
+ self._check_pdml_approach(pdml_approach, valid_pdml_approach)
+ self._pdml_approach = pdml_approach
+
+ _ = self._check_learner(ml_l, "ml_l", regressor=True, classifier=False)
+ ml_m_is_classifier = self._check_learner(ml_m, "ml_m", regressor=True, classifier=True)
+ self._learner = {"ml_l": ml_l, "ml_m": ml_m}
+
+ if ml_g is not None:
+ if (isinstance(self.score, str) & (self.score == "IV-type")) | callable(self.score):
+ _ = self._check_learner(ml_g, "ml_g", regressor=True, classifier=False)
+ self._learner["ml_g"] = ml_g
+ else:
+ assert isinstance(self.score, str) & (self.score == "partialling out")
+ warnings.warn(
+ (
+ 'A learner ml_g has been provided for score = "partialling out" but will be ignored. "'
+ "A learner ml_g is not required for estimation."
+ )
+ )
+ elif isinstance(self.score, str) & (self.score == "IV-type"):
+ warnings.warn(("For score = 'IV-type', learners ml_l and ml_g should be specified. Set ml_g = clone(ml_l)."))
+ self._learner["ml_g"] = clone(ml_l)
+
+ self._predict_method = {"ml_l": "predict"}
+ if "ml_g" in self._learner:
+ self._predict_method["ml_g"] = "predict"
+ if ml_m_is_classifier:
+ if self._dml_data.binary_treats.all():
+ self._predict_method["ml_m"] = "predict_proba"
+ else:
+ raise ValueError(
+ f"The ml_m learner {str(ml_m)} was identified as classifier "
+ "but at least one treatment variable is not binary with values 0 and 1."
+ )
+ else:
+ self._predict_method["ml_m"] = "predict"
+
+ self._initialize_ml_nuisance_params()
+ self._sensitivity_implemented = False
+ self._external_predictions_implemented = True
+
+ def _initialize_ml_nuisance_params(self):
+ self._params = {learner: {key: [None] * self.n_rep for key in self._dml_data.d_cols} for learner in self._learner}
+
+ def _check_data(self, obj_dml_data):
+ if not isinstance(obj_dml_data, DoubleMLData):
+ raise TypeError(
+ f"The data must be of DoubleMLData type. {str(obj_dml_data)} of type {str(type(obj_dml_data))} was passed."
+ )
+ if obj_dml_data.z_cols is not None:
+ raise ValueError(
+ "Incompatible data. " + " and ".join(obj_dml_data.z_cols) + " have been set as instrumental variable(s). "
+ "To fit a partially linear IV regression model use DoubleMLPLIV instead of DoubleMLPLR."
+ )
+ return
+
+ def _check_pdml_approach(pdml_approach, valid_pdml_approach):
+ if isinstance(pdml_approach, str):
+ if pdml_approach not in valid_pdml_approach:
+ raise ValueError("Invalid pdml_approach " + pdml_approach + ". " + "pdml_approach score " + " or ".join(pdml_approach) + ".")
+ else:
+ raise TypeError(f"score should be a string. {str(pdml_approach)} was passed.")
+ return
+
+
+ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=False):
+ x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False)
+ x, d = check_X_y(x, self._dml_data.d, force_all_finite=False)
+ m_external = external_predictions["ml_m"] is not None
+ l_external = external_predictions["ml_l"] is not None
+ if "ml_g" in self._learner:
+ g_external = external_predictions["ml_g"] is not None
+ else:
+ g_external = False
+
+ # nuisance l
+ if l_external:
+ l_hat = {"preds": external_predictions["ml_l"], "targets": None, "models": None}
+ elif self._score == "IV-type" and g_external:
+ l_hat = {"preds": None, "targets": None, "models": None}
+ else:
+ l_hat = _dml_cv_predict(
+ self._learner["ml_l"],
+ x,
+ y,
+ smpls=smpls,
+ n_jobs=n_jobs_cv,
+ est_params=self._get_params("ml_l"),
+ method=self._predict_method["ml_l"],
+ return_models=return_models,
+ )
+ _check_finite_predictions(l_hat["preds"], self._learner["ml_l"], "ml_l", smpls)
+
+ # nuisance m
+ if m_external:
+ m_hat = {"preds": external_predictions["ml_m"], "targets": None, "models": None}
+ else:
+ # cre using m_d + x for m_hat, otherwise only x
+ if self._pdml_approach == 'cre':
+ help_data = pd.DataFrame({'id': self._dml_data.cluster_vars[:, 0], 'd': d})
+ m_d = help_data.groupby(["id"]).transform('mean').values
+ x = np.column_stack((x, m_d))
+
+ m_hat = _dml_cv_predict(
+ self._learner["ml_m"],
+ x,
+ d,
+ smpls=smpls,
+ n_jobs=n_jobs_cv,
+ est_params=self._get_params("ml_m"),
+ method=self._predict_method["ml_m"],
+ return_models=return_models,
+ )
+
+ # general cre adjustment
+ if self._pdml_approach == 'cre_general':
+ help_data = pd.DataFrame({'id': self._dml_data.cluster_vars[:, 0], 'm_hat': m_hat['preds'], 'd': d})
+ group_means = help_data.groupby(['id'])[['m_hat', 'd']].transform('mean')
+ m_hat_star = m_hat['preds'] + group_means['d'] - group_means['m_hat']
+ m_hat['preds'] = m_hat_star
+
+
+ _check_finite_predictions(m_hat["preds"], self._learner["ml_m"], "ml_m", smpls)
+ if self._check_learner(self._learner["ml_m"], "ml_m", regressor=True, classifier=True):
+ _check_is_propensity(m_hat["preds"], self._learner["ml_m"], "ml_m", smpls, eps=1e-12)
+
+ if self._dml_data.binary_treats[self._dml_data.d_cols[self._i_treat]]:
+ _check_binary_predictions(m_hat["preds"], self._learner["ml_m"], "ml_m", self._dml_data.d_cols[self._i_treat])
+
+ # an estimate of g is obtained for the IV-type score and callable scores
+ g_hat = {"preds": None, "targets": None, "models": None}
+ if "ml_g" in self._learner:
+ # nuisance g
+ if g_external:
+ g_hat = {"preds": external_predictions["ml_g"], "targets": None, "models": None}
+ else:
+ # get an initial estimate for theta using the partialling out score
+ psi_a = -np.multiply(d - m_hat["preds"], d - m_hat["preds"])
+ psi_b = np.multiply(d - m_hat["preds"], y - l_hat["preds"])
+ theta_initial = -np.nanmean(psi_b) / np.nanmean(psi_a)
+ g_hat = _dml_cv_predict(
+ self._learner["ml_g"],
+ x,
+ y - theta_initial * d,
+ smpls=smpls,
+ n_jobs=n_jobs_cv,
+ est_params=self._get_params("ml_g"),
+ method=self._predict_method["ml_g"],
+ return_models=return_models,
+ )
+ _check_finite_predictions(g_hat["preds"], self._learner["ml_g"], "ml_g", smpls)
+
+ psi_a, psi_b = self._score_elements(y, d, l_hat["preds"], m_hat["preds"], g_hat["preds"], smpls)
+ psi_elements = {"psi_a": psi_a, "psi_b": psi_b}
+ preds = {
+ "predictions": {"ml_l": l_hat["preds"], "ml_m": m_hat["preds"], "ml_g": g_hat["preds"]},
+ "targets": {"ml_l": l_hat["targets"], "ml_m": m_hat["targets"], "ml_g": g_hat["targets"]},
+ "models": {"ml_l": l_hat["models"], "ml_m": m_hat["models"], "ml_g": g_hat["models"]},
+ }
+
+ return psi_elements, preds
+
+ def _score_elements(self, y, d, l_hat, m_hat, g_hat, smpls):
+ # compute residual
+ v_hat = d - m_hat
+
+ if isinstance(self.score, str):
+ if self.score == "IV-type":
+ psi_a = -np.multiply(v_hat, d)
+ psi_b = np.multiply(v_hat, y - g_hat)
+ else:
+ assert self.score == "partialling out"
+ u_hat = y - l_hat
+ psi_a = -np.multiply(v_hat, v_hat)
+ psi_b = np.multiply(v_hat, u_hat)
+ else:
+ assert callable(self.score)
+ psi_a, psi_b = self.score(y=y, d=d, l_hat=l_hat, m_hat=m_hat, g_hat=g_hat, smpls=smpls)
+
+ return psi_a, psi_b
+
+ # def _sensitivity_element_est(self, preds):
+ # # set elments for readability
+ # y = self._dml_data.y
+ # d = self._dml_data.d
+
+ # m_hat = preds["predictions"]["ml_m"]
+ # theta = self.all_coef[self._i_treat, self._i_rep]
+
+ # if self.score == "partialling out":
+ # l_hat = preds["predictions"]["ml_l"]
+ # sigma2_score_element = np.square(y - l_hat - np.multiply(theta, d - m_hat))
+ # else:
+ # assert self.score == "IV-type"
+ # g_hat = preds["predictions"]["ml_g"]
+ # sigma2_score_element = np.square(y - g_hat - np.multiply(theta, d))
+
+ # sigma2 = np.mean(sigma2_score_element)
+ # psi_sigma2 = sigma2_score_element - sigma2
+
+ # treatment_residual = d - m_hat
+ # nu2 = np.divide(1.0, np.mean(np.square(treatment_residual)))
+ # psi_nu2 = nu2 - np.multiply(np.square(treatment_residual), np.square(nu2))
+ # rr = np.multiply(treatment_residual, nu2)
+
+ # element_dict = {
+ # "sigma2": sigma2,
+ # "nu2": nu2,
+ # "psi_sigma2": psi_sigma2,
+ # "psi_nu2": psi_nu2,
+ # "riesz_rep": rr,
+ # }
+ # return element_dict
+
+ def _nuisance_tuning(
+ self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search
+ ):
+ x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False)
+ x, d = check_X_y(x, self._dml_data.d, force_all_finite=False)
+
+ if scoring_methods is None:
+ scoring_methods = {"ml_l": None, "ml_m": None, "ml_g": None}
+
+ train_inds = [train_index for (train_index, _) in smpls]
+ l_tune_res = _dml_tune(
+ y,
+ x,
+ train_inds,
+ self._learner["ml_l"],
+ param_grids["ml_l"],
+ scoring_methods["ml_l"],
+ n_folds_tune,
+ n_jobs_cv,
+ search_mode,
+ n_iter_randomized_search,
+ )
+ m_tune_res = _dml_tune(
+ d,
+ x,
+ train_inds,
+ self._learner["ml_m"],
+ param_grids["ml_m"],
+ scoring_methods["ml_m"],
+ n_folds_tune,
+ n_jobs_cv,
+ search_mode,
+ n_iter_randomized_search,
+ )
+
+ l_best_params = [xx.best_params_ for xx in l_tune_res]
+ m_best_params = [xx.best_params_ for xx in m_tune_res]
+
+ # an ML model for g is obtained for the IV-type score and callable scores
+ if "ml_g" in self._learner:
+ # construct an initial theta estimate from the tuned models using the partialling out score
+ l_hat = np.full_like(y, np.nan)
+ m_hat = np.full_like(d, np.nan)
+ for idx, (train_index, _) in enumerate(smpls):
+ l_hat[train_index] = l_tune_res[idx].predict(x[train_index, :])
+ m_hat[train_index] = m_tune_res[idx].predict(x[train_index, :])
+ psi_a = -np.multiply(d - m_hat, d - m_hat)
+ psi_b = np.multiply(d - m_hat, y - l_hat)
+ theta_initial = -np.nanmean(psi_b) / np.nanmean(psi_a)
+ g_tune_res = _dml_tune(
+ y - theta_initial * d,
+ x,
+ train_inds,
+ self._learner["ml_g"],
+ param_grids["ml_g"],
+ scoring_methods["ml_g"],
+ n_folds_tune,
+ n_jobs_cv,
+ search_mode,
+ n_iter_randomized_search,
+ )
+
+ g_best_params = [xx.best_params_ for xx in g_tune_res]
+ params = {"ml_l": l_best_params, "ml_m": m_best_params, "ml_g": g_best_params}
+ tune_res = {"l_tune": l_tune_res, "m_tune": m_tune_res, "g_tune": g_tune_res}
+ else:
+ params = {"ml_l": l_best_params, "ml_m": m_best_params}
+ tune_res = {"l_tune": l_tune_res, "m_tune": m_tune_res}
+
+ res = {"params": params, "tune_res": tune_res}
+
+ return res
+
+ # def cate(self, basis, is_gate=False, **kwargs):
+ # """
+ # Calculate conditional average treatment effects (CATE) for a given basis.
+
+ # Parameters
+ # ----------
+ # basis : :class:`pandas.DataFrame`
+ # The basis for estimating the best linear predictor. Has to have the shape ``(n_obs, d)``,
+ # where ``n_obs`` is the number of observations and ``d`` is the number of predictors.
+
+ # is_gate : bool
+ # Indicates whether the basis is constructed for GATEs (dummy-basis).
+ # Default is ``False``.
+
+ # **kwargs: dict
+ # Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit` e.g. ``cov_type``.
+
+ # Returns
+ # -------
+ # model : :class:`doubleML.DoubleMLBLP`
+ # Best linear Predictor model.
+ # """
+ # if self._dml_data.n_treat > 1:
+ # raise NotImplementedError(
+ # "Only implemented for single treatment. " + f"Number of treatments is {str(self._dml_data.n_treat)}."
+ # )
+ # if self.n_rep != 1:
+ # raise NotImplementedError("Only implemented for one repetition. " + f"Number of repetitions is {str(self.n_rep)}.")
+
+ # Y_tilde, D_tilde = self._partial_out()
+
+ # D_basis = basis * D_tilde
+ # model = DoubleMLBLP(
+ # orth_signal=Y_tilde.reshape(-1),
+ # basis=D_basis,
+ # is_gate=is_gate,
+ # )
+ # model.fit(**kwargs)
+ # return model
+
+ # def gate(self, groups, **kwargs):
+ # """
+ # Calculate group average treatment effects (GATE) for groups.
+
+ # Parameters
+ # ----------
+ # groups : :class:`pandas.DataFrame`
+ # The group indicator for estimating the best linear predictor. Groups should be mutually exclusive.
+ # Has to be dummy coded with shape ``(n_obs, d)``, where ``n_obs`` is the number of observations
+ # and ``d`` is the number of groups or ``(n_obs, 1)`` and contain the corresponding groups (as str).
+
+ # **kwargs: dict
+ # Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit` e.g. ``cov_type``.
+
+ # Returns
+ # -------
+ # model : :class:`doubleML.DoubleMLBLP`
+ # Best linear Predictor model for Group Effects.
+ # """
+
+ # if not isinstance(groups, pd.DataFrame):
+ # raise TypeError(f"Groups must be of DataFrame type. Groups of type {str(type(groups))} was passed.")
+ # if not all(groups.dtypes == bool) or all(groups.dtypes == int):
+ # if groups.shape[1] == 1:
+ # groups = pd.get_dummies(groups, prefix="Group", prefix_sep="_")
+ # else:
+ # raise TypeError(
+ # "Columns of groups must be of bool type or int type (dummy coded). "
+ # "Alternatively, groups should only contain one column."
+ # )
+
+ # if any(groups.sum(0) <= 5):
+ # warnings.warn("At least one group effect is estimated with less than 6 observations.")
+
+ # model = self.cate(groups, is_gate=True, **kwargs)
+ # return model
+
+ # def _partial_out(self):
+ # """
+ # Helper function. Returns the partialled out quantities of Y and D.
+ # Works with multiple repetitions.
+
+ # Returns
+ # -------
+ # Y_tilde : :class:`numpy.ndarray`
+ # The residual of the regression of Y on X.
+ # D_tilde : :class:`numpy.ndarray`
+ # The residual of the regression of D on X.
+ # """
+ # if self.predictions is None:
+ # raise ValueError("predictions are None. Call .fit(store_predictions=True) to store the predictions.")
+
+ # y = self._dml_data.y.reshape(-1, 1)
+ # d = self._dml_data.d.reshape(-1, 1)
+ # ml_m = self.predictions["ml_m"].squeeze(axis=2)
+
+ # if self.score == "partialling out":
+ # ml_l = self.predictions["ml_l"].squeeze(axis=2)
+ # Y_tilde = y - ml_l
+ # D_tilde = d - ml_m
+ # else:
+ # assert self.score == "IV-type"
+ # ml_g = self.predictions["ml_g"].squeeze(axis=2)
+ # Y_tilde = y - (self.coef * ml_m) - ml_g
+ # D_tilde = d - ml_m
+
+ # return Y_tilde, D_tilde
diff --git a/doubleml/plm/utils/_plpr_util.py b/doubleml/plm/utils/_plpr_util.py
new file mode 100644
index 000000000..9045ffb08
--- /dev/null
+++ b/doubleml/plm/utils/_plpr_util.py
@@ -0,0 +1,73 @@
+import numpy as np
+import pandas as pd
+from sklearn.preprocessing import PolynomialFeatures
+
+
+def extend_data(data):
+ data = data.copy()
+ poly = PolynomialFeatures(2, include_bias=False)
+
+ xdat = data.loc[:,data.columns.str.startswith('x') & ~data.columns.str.contains('lag')]
+ xpoly = poly.fit_transform(xdat)
+ x_p3 = xdat**3
+
+ x_pol_nam = poly.get_feature_names_out()
+ x_cols_p3 = [f'x{i + 1}^3' for i in np.arange(xdat.shape[1])]
+
+ if data.columns.str.startswith('m_x').any():
+ xmdat = data.loc[:,data.columns.str.startswith('m_x')]
+ xmpoly = poly.fit_transform(xmdat)
+ xm_p3 = xmdat**3
+
+ xm_pol_nam = poly.get_feature_names_out()
+ xm_cols_p3 = [f'm_x{i + 1}^3' for i in np.arange(xmdat.shape[1])]
+
+ X_all = np.column_stack((xpoly, x_p3, xmpoly, xm_p3))
+ x_df = pd.DataFrame(X_all, columns = list(x_pol_nam) + x_cols_p3 + list(xm_pol_nam) + xm_cols_p3)
+ df_ext = data[['id', 'time', 'd', 'y', 'm_d']].join(x_df)
+
+ elif data.columns.str.contains('_lag').any():
+ xldat = data.loc[:,data.columns.str.contains('_lag')]
+ xlpoly = poly.fit_transform(xldat)
+ xl_p3 = xldat**3
+
+ xl_pol_nam = poly.get_feature_names_out()
+ xl_cols_p3 = [f'x{i + 1}_lag^3' for i in np.arange(xldat.shape[1])]
+
+ X_all = np.column_stack((xpoly, x_p3, xlpoly, xl_p3))
+ x_df = pd.DataFrame(X_all, columns = list(x_pol_nam) + x_cols_p3 + list(xl_pol_nam) + xl_cols_p3)
+ df_ext = data[['id', 'time', 'd_diff', 'y_diff']].join(x_df)
+
+ else:
+ X_all = np.column_stack((xpoly, x_p3))
+ x_df = pd.DataFrame(X_all, columns = list(x_pol_nam) + x_cols_p3)
+ df_ext = data[['id', 'time', 'd', 'y']].join(x_df)
+
+ return df_ext
+
+
+def cre_fct(data):
+ df = data.copy()
+ id_means = df.loc[:,~df.columns.isin(['time', 'y'])].groupby(["id"]).transform('mean')
+ df = df.join(id_means.rename(columns=lambda x: "m_" + x))
+ return df
+
+
+def fd_fct(data):
+ df = data.copy()
+ shifted = df.loc[:,~df.columns.isin(['d', 'y', 'time'])].groupby(["id"]).shift(1)
+ first_diff = df.loc[:,df.columns.isin(['id', 'd', 'y'])].groupby(["id"]).diff()
+ df_fd = df.join(shifted.rename(columns=lambda x: x +"_lag"))
+ df_fd = df_fd.join(first_diff.rename(columns=lambda x: x +"_diff"))
+ df = df_fd.dropna(subset=['x1_lag']).reset_index(drop=True)
+ return df
+
+
+def wd_fct(data):
+ df = data.copy()
+ df_demean = df.loc[:,~df.columns.isin(['time'])].groupby(["id"]).transform(lambda x: x - x.mean())
+ # add xbar (the grand mean allows a consistent estimate of the constant term)
+ within_means = df_demean + df.loc[:,~df.columns.isin(['id','time'])].mean()
+ df_wd = df.loc[:,df.columns.isin(['id','time'])]
+ df = df_wd.join(within_means)
+ return df
\ No newline at end of file
From 20319232ba6e5c48d836edcebd898f4401fb1e49 Mon Sep 17 00:00:00 2001
From: Julian Diefenbacher
<93731825+JulianDiefenbacher@users.noreply.github.com>
Date: Thu, 26 Jun 2025 17:51:05 +0200
Subject: [PATCH 02/33] update plpr and dataset
---
.../plm/datasets/dgp_static_panel_CP2025.py | 13 +++---
doubleml/plm/plpr.py | 42 +++----------------
2 files changed, 14 insertions(+), 41 deletions(-)
diff --git a/doubleml/plm/datasets/dgp_static_panel_CP2025.py b/doubleml/plm/datasets/dgp_static_panel_CP2025.py
index af9e11165..2dd7b0576 100644
--- a/doubleml/plm/datasets/dgp_static_panel_CP2025.py
+++ b/doubleml/plm/datasets/dgp_static_panel_CP2025.py
@@ -2,7 +2,7 @@
import pandas as pd
-def make_static_panel_CP2025(num_n=250, num_t=10, dim_x=30, theta=0.5, dgp_type='dgp1'):
+def make_static_panel_CP2025(num_n=250, num_t=10, dim_x=30, theta=0.5, dgp_type='dgp1', x_var=5, a_var=0.95):
"""
Generates static panel data from the simulation dgp in Clarke and Polselli (2025).
@@ -18,7 +18,11 @@ def make_static_panel_CP2025(num_n=250, num_t=10, dim_x=30, theta=0.5, dgp_type=
The value of the causal parameter.
dgp_type :
The type of DGP design to be used. Default is ``'dgp1'``, other options are ``'dgp2'`` and ``'dgp3'``.
-
+ x_var :
+ The variance of the covariates.
+ a_var :
+ The variance of the individual fixed effect on outcome
+
Returns
-------
pandas.DataFrame
@@ -28,8 +32,6 @@ def make_static_panel_CP2025(num_n=250, num_t=10, dim_x=30, theta=0.5, dgp_type=
# parameters
a = 0.25
b = 0.5
- x_var = 5
- a_var = 0.95
# id and time vectors
id = np.repeat(np.arange(1, num_n+1), num_t)
@@ -40,7 +42,8 @@ def make_static_panel_CP2025(num_n=250, num_t=10, dim_x=30, theta=0.5, dgp_type=
c_i = np.repeat(np.random.standard_normal(num_n), num_t)
# covariates and errors
- x_it = np.random.normal(loc=0, scale=np.sqrt(x_var), size=(num_n*num_t, dim_x))
+ x_mean = 0
+ x_it = np.random.normal(loc=x_mean, scale=np.sqrt(x_var), size=(num_n*num_t, dim_x))
u_it = np.random.standard_normal(num_n*num_t)
v_it = np.random.standard_normal(num_n*num_t)
diff --git a/doubleml/plm/plpr.py b/doubleml/plm/plpr.py
index 1957efcab..63e0f841a 100644
--- a/doubleml/plm/plpr.py
+++ b/doubleml/plm/plpr.py
@@ -10,7 +10,7 @@
from ..double_ml_score_mixins import LinearScoreMixin
from ..utils._checks import _check_binary_predictions, _check_finite_predictions, _check_is_propensity, _check_score
from ..utils._estimation import _dml_cv_predict, _dml_tune
-from ..utils.blp import DoubleMLBLP
+# from ..utils.blp import DoubleMLBLP
class DoubleMLPLPR(LinearScoreMixin, DoubleML):
@@ -119,7 +119,7 @@ def __init__(
self._predict_method["ml_m"] = "predict"
self._initialize_ml_nuisance_params()
- self._sensitivity_implemented = False
+ self._sensitivity_implemented = False ###
self._external_predictions_implemented = True
def _initialize_ml_nuisance_params(self):
@@ -137,10 +137,10 @@ def _check_data(self, obj_dml_data):
)
return
- def _check_pdml_approach(pdml_approach, valid_pdml_approach):
+ def _check_pdml_approach(self, pdml_approach, valid_pdml_approach):
if isinstance(pdml_approach, str):
if pdml_approach not in valid_pdml_approach:
- raise ValueError("Invalid pdml_approach " + pdml_approach + ". " + "pdml_approach score " + " or ".join(pdml_approach) + ".")
+ raise ValueError("Invalid pdml_approach " + pdml_approach + ". " + "Valid approach " + " or ".join(valid_pdml_approach) + ".")
else:
raise TypeError(f"score should be a string. {str(pdml_approach)} was passed.")
return
@@ -262,38 +262,8 @@ def _score_elements(self, y, d, l_hat, m_hat, g_hat, smpls):
return psi_a, psi_b
- # def _sensitivity_element_est(self, preds):
- # # set elments for readability
- # y = self._dml_data.y
- # d = self._dml_data.d
-
- # m_hat = preds["predictions"]["ml_m"]
- # theta = self.all_coef[self._i_treat, self._i_rep]
-
- # if self.score == "partialling out":
- # l_hat = preds["predictions"]["ml_l"]
- # sigma2_score_element = np.square(y - l_hat - np.multiply(theta, d - m_hat))
- # else:
- # assert self.score == "IV-type"
- # g_hat = preds["predictions"]["ml_g"]
- # sigma2_score_element = np.square(y - g_hat - np.multiply(theta, d))
-
- # sigma2 = np.mean(sigma2_score_element)
- # psi_sigma2 = sigma2_score_element - sigma2
-
- # treatment_residual = d - m_hat
- # nu2 = np.divide(1.0, np.mean(np.square(treatment_residual)))
- # psi_nu2 = nu2 - np.multiply(np.square(treatment_residual), np.square(nu2))
- # rr = np.multiply(treatment_residual, nu2)
-
- # element_dict = {
- # "sigma2": sigma2,
- # "nu2": nu2,
- # "psi_sigma2": psi_sigma2,
- # "psi_nu2": psi_nu2,
- # "riesz_rep": rr,
- # }
- # return element_dict
+ def _sensitivity_element_est(self, preds):
+ pass
def _nuisance_tuning(
self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search
From 5f6a3743e832b8189c366022b12cb760e3dff79c Mon Sep 17 00:00:00 2001
From: Julian Diefenbacher
<93731825+JulianDiefenbacher@users.noreply.github.com>
Date: Thu, 26 Jun 2025 17:51:36 +0200
Subject: [PATCH 03/33] add basic simulations
---
doubleml/plm/sim/example_sim.ipynb | 467 +++++++++++
doubleml/plm/sim/learners_sim.ipynb | 1189 +++++++++++++++++++++++++++
2 files changed, 1656 insertions(+)
create mode 100644 doubleml/plm/sim/example_sim.ipynb
create mode 100644 doubleml/plm/sim/learners_sim.ipynb
diff --git a/doubleml/plm/sim/example_sim.ipynb b/doubleml/plm/sim/example_sim.ipynb
new file mode 100644
index 000000000..dcbc3cb25
--- /dev/null
+++ b/doubleml/plm/sim/example_sim.ipynb
@@ -0,0 +1,467 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "6dfa56df",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import statsmodels.api as sm\n",
+ "from doubleml.data.base_data import DoubleMLData\n",
+ "from doubleml.data.cluster_data import DoubleMLClusterData\n",
+ "from doubleml.plm.plpr import DoubleMLPLPR\n",
+ "from sklearn.linear_model import LassoCV, LinearRegression\n",
+ "from sklearn.base import clone\n",
+ "from sklearn.preprocessing import StandardScaler\n",
+ "from sklearn.pipeline import make_pipeline\n",
+ "from doubleml.plm.utils._plpr_util import cre_fct, fd_fct, wd_fct\n",
+ "from doubleml.plm.datasets.dgp_static_panel_CP2025 import make_static_panel_CP2025\n",
+ "import warnings\n",
+ "import warnings\n",
+ "warnings.filterwarnings(\"ignore\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "3c061cf1",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(np.float64(0.6719371174913912),\n",
+ " np.float64(0.6090488219157397),\n",
+ " np.float64(0.7348254130670426))"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "np.random.seed(1)\n",
+ "data = make_static_panel_CP2025(dgp_type='dgp1')\n",
+ "\n",
+ "x_cols = [col for col in data.columns if \"x\" in col]\n",
+ "\n",
+ "X = sm.add_constant(data[['d'] + x_cols])\n",
+ "y = data['y']\n",
+ "clusters = data['id']\n",
+ "\n",
+ "ols_model = sm.OLS(y, X).fit(cov_type='cluster', cov_kwds={'groups': clusters})\n",
+ "ols_model.params['d'], ols_model.conf_int().loc['d'][0], ols_model.conf_int().loc['d'][1]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "73c6599b",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d 0.48547 0.020911 23.216093 3.131961e-119 0.444485 0.526455\n"
+ ]
+ }
+ ],
+ "source": [
+ "# cre general\n",
+ "data = make_static_panel_CP2025(dgp_type='dgp1')\n",
+ "cre_data = cre_fct(data)\n",
+ "\n",
+ "x_cols = [col for col in cre_data.columns if \"x\" in col]\n",
+ "\n",
+ "obj_dml_data_pdml = DoubleMLClusterData(cre_data,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " cluster_cols='id',\n",
+ " x_cols=x_cols)\n",
+ "\n",
+ "learner = LassoCV()\n",
+ "ml_l = clone(learner)\n",
+ "ml_m = clone(learner)\n",
+ "\n",
+ "obj_dml_plpr = DoubleMLPLPR(obj_dml_data_pdml, ml_l=ml_l, ml_m=ml_m,\n",
+ " pdml_approach='cre_general', n_folds=2,\n",
+ " )\n",
+ "obj_dml_plpr.fit()\n",
+ "print(obj_dml_plpr.summary)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "5c700387",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "================== DoubleMLPLPR Object ==================\n",
+ "\n",
+ "------------------ Data summary ------------------\n",
+ "Outcome variable: y\n",
+ "Treatment variable(s): ['d']\n",
+ "Cluster variable(s): ['id']\n",
+ "Covariates: ['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15', 'x16', 'x17', 'x18', 'x19', 'x20', 'x21', 'x22', 'x23', 'x24', 'x25', 'x26', 'x27', 'x28', 'x29', 'x30', 'm_x1', 'm_x2', 'm_x3', 'm_x4', 'm_x5', 'm_x6', 'm_x7', 'm_x8', 'm_x9', 'm_x10', 'm_x11', 'm_x12', 'm_x13', 'm_x14', 'm_x15', 'm_x16', 'm_x17', 'm_x18', 'm_x19', 'm_x20', 'm_x21', 'm_x22', 'm_x23', 'm_x24', 'm_x25', 'm_x26', 'm_x27', 'm_x28', 'm_x29', 'm_x30']\n",
+ "Instrument variable(s): None\n",
+ "No. Observations: 2500\n",
+ "\n",
+ "------------------ Score & algorithm ------------------\n",
+ "Score function: partialling out\n",
+ "\n",
+ "------------------ Machine learner ------------------\n",
+ "Learner ml_l: LassoCV()\n",
+ "Learner ml_m: LassoCV()\n",
+ "Out-of-sample Performance:\n",
+ "Regression:\n",
+ "Learner ml_l RMSE: [[1.75556179]]\n",
+ "Learner ml_m RMSE: [[0.95043707]]\n",
+ "\n",
+ "------------------ Resampling ------------------\n",
+ "No. folds per cluster: 2\n",
+ "No. folds: 2\n",
+ "No. repeated sample splits: 1\n",
+ "\n",
+ "------------------ Fit summary ------------------\n",
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d 0.48547 0.020911 23.216093 3.131961e-119 0.444485 0.526455\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(obj_dml_plpr)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 59,
+ "id": "24f06d62",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d 0.498693 0.022082 22.584173 6.200824e-113 0.455414 0.541972\n"
+ ]
+ }
+ ],
+ "source": [
+ "# cre normality assumption\n",
+ "data = make_static_panel_CP2025(dgp_type='dgp1')\n",
+ "cre_data = cre_fct(data)\n",
+ "\n",
+ "x_cols = [col for col in cre_data.columns if \"x\" in col]\n",
+ "\n",
+ "obj_dml_data_pdml = DoubleMLClusterData(cre_data,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " cluster_cols='id',\n",
+ " x_cols=x_cols)\n",
+ "\n",
+ "# learner = LassoCV()\n",
+ "learner = make_pipeline(StandardScaler(), LassoCV())\n",
+ "\n",
+ "ml_l = clone(learner)\n",
+ "ml_m = clone(learner)\n",
+ "\n",
+ "obj_dml_plpr = DoubleMLPLPR(obj_dml_data_pdml, ml_l, ml_m, pdml_approach='cre')\n",
+ "obj_dml_plpr.fit()\n",
+ "print(obj_dml_plpr.summary)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 60,
+ "id": "61a72563",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d_diff 0.524973 0.022845 22.979858 7.412967e-117 0.480198 0.569749\n"
+ ]
+ }
+ ],
+ "source": [
+ "data = make_static_panel_CP2025(dgp_type='dgp1')\n",
+ "fd_data = fd_fct(data)\n",
+ "\n",
+ "obj_dml_data_pdml = DoubleMLClusterData(fd_data,\n",
+ " y_col='y_diff',\n",
+ " d_cols='d_diff',\n",
+ " cluster_cols='id',\n",
+ " x_cols=[col for col in fd_data.columns if col.startswith(\"x\")])\n",
+ "\n",
+ "learner = LassoCV()\n",
+ "ml_l = clone(learner)\n",
+ "ml_m = clone(learner)\n",
+ "\n",
+ "obj_dml_plpr = DoubleMLPLPR(obj_dml_data_pdml, ml_l, ml_m, pdml_approach='transform')\n",
+ "obj_dml_plpr.fit()\n",
+ "print(obj_dml_plpr.summary)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 61,
+ "id": "aeb00efe",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d 0.476888 0.020457 23.311636 3.378549e-120 0.436792 0.516983\n"
+ ]
+ }
+ ],
+ "source": [
+ "data = make_static_panel_CP2025(dgp_type='dgp1')\n",
+ "wd_data = wd_fct(data)\n",
+ "\n",
+ "obj_dml_data_pdml = DoubleMLClusterData(wd_data,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " cluster_cols='id',\n",
+ " x_cols=[col for col in wd_data.columns if col.startswith(\"x\")])\n",
+ "\n",
+ "learner = LassoCV()\n",
+ "ml_l = clone(learner)\n",
+ "ml_m = clone(learner)\n",
+ "\n",
+ "obj_dml_plpr = DoubleMLPLPR(obj_dml_data_pdml, ml_l, ml_m, pdml_approach='transform')\n",
+ "obj_dml_plpr.fit()\n",
+ "print(obj_dml_plpr.summary)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 62,
+ "id": "586d5edf",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Processing: 100.0 %"
+ ]
+ }
+ ],
+ "source": [
+ "n_reps = 100\n",
+ "theta = 0.5\n",
+ "\n",
+ "learner = make_pipeline(StandardScaler(), LassoCV())\n",
+ "\n",
+ "leaner_ols = LinearRegression()\n",
+ "\n",
+ "res_cre_ols = np.full((n_reps, 3), np.nan)\n",
+ "res_cre_general = np.full((n_reps, 3), np.nan)\n",
+ "res_cre_normal = np.full((n_reps, 3), np.nan)\n",
+ "res_fd = np.full((n_reps, 3), np.nan)\n",
+ "res_fd_cluster = np.full((n_reps, 3), np.nan)\n",
+ "res_wd = np.full((n_reps, 3), np.nan)\n",
+ "\n",
+ "np.random.seed(1)\n",
+ "\n",
+ "for i in range(n_reps):\n",
+ " print(f\"\\rProcessing: {round((i+1)/n_reps*100, 3)} %\", end=\"\")\n",
+ " data = make_static_panel_CP2025(num_n=100, theta=theta, dgp_type='dgp1')\n",
+ "\n",
+ " # CRE general OLS\n",
+ " cre_data = cre_fct(data)\n",
+ " dml_data = DoubleMLClusterData(cre_data, y_col='y', d_cols='d', cluster_cols='id', \n",
+ " x_cols=[col for col in cre_data.columns if \"x\" in col])\n",
+ " dml_plpr = DoubleMLPLPR(dml_data, clone(leaner_ols), clone(leaner_ols), n_folds=5, pdml_approach='cre_general')\n",
+ " dml_plpr.fit()\n",
+ " res_cre_ols[i, 0] = dml_plpr.coef[0]\n",
+ " res_cre_ols[i, 1] = dml_plpr.coef[0] - theta\n",
+ " confint = dml_plpr.confint()\n",
+ " res_cre_ols[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ "\n",
+ " # CRE general Lasso\n",
+ " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, pdml_approach='cre_general')\n",
+ " dml_plpr.fit()\n",
+ " res_cre_general[i, 0] = dml_plpr.coef[0]\n",
+ " res_cre_general[i, 1] = dml_plpr.coef[0] - theta\n",
+ " confint = dml_plpr.confint()\n",
+ " res_cre_general[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ "\n",
+ " # CRE normality\n",
+ " dml_data = DoubleMLClusterData(cre_data, y_col='y', d_cols='d', cluster_cols='id', \n",
+ " x_cols=[col for col in cre_data.columns if \"x\" in col]\n",
+ " )\n",
+ " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, pdml_approach='cre')\n",
+ " dml_plpr.fit()\n",
+ " res_cre_normal[i, 0] = dml_plpr.coef[0]\n",
+ " res_cre_normal[i, 1] = dml_plpr.coef[0] - theta\n",
+ " confint = dml_plpr.confint()\n",
+ " res_cre_normal[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ "\n",
+ " # FD approach\n",
+ " fd_data = fd_fct(data)\n",
+ " dml_data = DoubleMLClusterData(fd_data, y_col='y_diff', d_cols='d_diff', cluster_cols='id',\n",
+ " x_cols=[col for col in fd_data.columns if col.startswith(\"x\")])\n",
+ " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, pdml_approach='transform')\n",
+ " dml_plpr.fit()\n",
+ " res_fd[i, 0] = dml_plpr.coef[0]\n",
+ " res_fd[i, 1] = dml_plpr.coef[0] - theta\n",
+ " confint = dml_plpr.confint()\n",
+ " res_fd[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ " # no cluster\n",
+ " dml_data = DoubleMLData(fd_data, y_col='y_diff', d_cols='d_diff',\n",
+ " x_cols=[col for col in fd_data.columns if col.startswith(\"x\")])\n",
+ " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, pdml_approach='transform')\n",
+ " dml_plpr.fit()\n",
+ " res_fd_cluster[i, 0] = dml_plpr.coef[0]\n",
+ " res_fd_cluster[i, 1] = dml_plpr.coef[0] - theta\n",
+ " confint = dml_plpr.confint()\n",
+ " res_fd_cluster[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ " \n",
+ " # WD approach\n",
+ " wd_data = wd_fct(data)\n",
+ " dml_data = DoubleMLClusterData(wd_data, y_col='y', d_cols='d', cluster_cols='id',\n",
+ " x_cols=[col for col in wd_data.columns if \"x\" in col])\n",
+ " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, pdml_approach='transform')\n",
+ " dml_plpr.fit()\n",
+ " res_wd[i, 0] = dml_plpr.coef[0]\n",
+ " res_wd[i, 1] = dml_plpr.coef[0] - theta\n",
+ " confint = dml_plpr.confint()\n",
+ " res_wd[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 63,
+ "id": "edb9e3d5",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
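+ "<div>\n",
+ "<style scoped>\n",
+ "    .dataframe tbody tr th:only-of-type {\n",
+ "        vertical-align: middle;\n",
+ "    }\n",
+ "\n",
+ "    .dataframe tbody tr th {\n",
+ "        vertical-align: top;\n",
+ "    }\n",
+ "\n",
+ "    .dataframe thead th {\n",
+ "        text-align: right;\n",
+ "    }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ "  <thead>\n",
+ "    <tr style=\"text-align: right;\">\n",
+ "      <th></th>\n",
+ "      <th>Coef</th>\n",
+ "      <th>Bias</th>\n",
+ "      <th>Coverage</th>\n",
+ "    </tr>\n",
+ "  </thead>\n",
+ "  <tbody>\n",
+ "    <tr>\n",
+ "      <th>CRE OLS</th>\n",
+ "      <td>0.498516</td>\n",
+ "      <td>-0.001484</td>\n",
+ "      <td>0.94</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>CRE general</th>\n",
+ "      <td>0.517304</td>\n",
+ "      <td>0.017304</td>\n",
+ "      <td>0.91</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>CRE normality</th>\n",
+ "      <td>0.540535</td>\n",
+ "      <td>0.040535</td>\n",
+ "      <td>0.80</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>FD</th>\n",
+ "      <td>0.504695</td>\n",
+ "      <td>0.004695</td>\n",
+ "      <td>0.95</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>FD no cluster</th>\n",
+ "      <td>0.503954</td>\n",
+ "      <td>0.003954</td>\n",
+ "      <td>0.87</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>WD</th>\n",
+ "      <td>0.502402</td>\n",
+ "      <td>0.002402</td>\n",
+ "      <td>0.93</td>\n",
+ "    </tr>\n",
+ "  </tbody>\n",
+ "</table>\n",
+ "</div>"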
+ ],
+ "text/plain": [
+ " Coef Bias Coverage\n",
+ "CRE OLS 0.498516 -0.001484 0.94\n",
+ "CRE general 0.517304 0.017304 0.91\n",
+ "CRE normality 0.540535 0.040535 0.80\n",
+ "FD 0.504695 0.004695 0.95\n",
+ "FD no cluster 0.503954 0.003954 0.87\n",
+ "WD 0.502402 0.002402 0.93"
+ ]
+ },
+ "execution_count": 63,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pd.DataFrame(np.vstack([res_cre_ols.mean(axis=0), res_cre_general.mean(axis=0), res_cre_normal.mean(axis=0), \n",
+ " res_fd.mean(axis=0), res_fd_cluster.mean(axis=0), res_wd.mean(axis=0)]), \n",
+ " columns=['Coef', 'Bias', 'Coverage'], \n",
+ " index=['CRE OLS', 'CRE general', 'CRE normality', \n",
+ " 'FD', 'FD no cluster', 'WD'])"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/doubleml/plm/sim/learners_sim.ipynb b/doubleml/plm/sim/learners_sim.ipynb
new file mode 100644
index 000000000..b718bd7b4
--- /dev/null
+++ b/doubleml/plm/sim/learners_sim.ipynb
@@ -0,0 +1,1189 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "fe0a50cb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "from doubleml.data.cluster_data import DoubleMLClusterData\n",
+ "from doubleml.plm.plpr import DoubleMLPLPR\n",
+ "from sklearn.linear_model import LassoCV, LinearRegression\n",
+ "from sklearn.base import clone\n",
+ "from sklearn.tree import DecisionTreeRegressor\n",
+ "from sklearn.ensemble import RandomForestRegressor\n",
+ "from lightgbm import LGBMRegressor\n",
+ "from doubleml.plm.utils._plpr_util import extend_data, cre_fct, fd_fct, wd_fct\n",
+ "from doubleml.plm.datasets.dgp_static_panel_CP2025 import make_static_panel_CP2025\n",
+ "from sklearn.preprocessing import StandardScaler\n",
+ "from sklearn.pipeline import make_pipeline\n",
+ "import warnings\n",
+ "warnings.filterwarnings(\"ignore\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "2715990b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ml_ols = LinearRegression()\n",
+ "\n",
+ "ml_lasso = make_pipeline(StandardScaler(), LassoCV())\n",
+ "\n",
+ "ml_cart = DecisionTreeRegressor()\n",
+ "\n",
+ "ml_rf = RandomForestRegressor(n_estimators=100, \n",
+ " max_features=1.0, \n",
+ " min_samples_leaf=5)\n",
+ "\n",
+ "# Rf\n",
+ "# ml_rf_grid = {'ml_l': {'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)},\n",
+ "# 'ml_m': {'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)}}\n",
+ "\n",
+ "# dml_plpr = DoubleMLPLPR(data_pdml, clone(ml_rf), clone(ml_rf), pdml_approach='cre', n_folds=5)\n",
+ "# dml_plpr.tune(param_grids=ml_rf_grid, search_mode='randomized_search', n_iter_randomized_search=5, n_jobs_cv=5)\n",
+ "# dml_plpr.fit(n_jobs_cv=5)\n",
+ "# res_cre_rf[i, 0] = dml_plpr.coef[0] - theta\n",
+ "# res_cre_rf[i, 1] = dml_plpr.se[0] \n",
+ "# confint = dml_plpr.confint()\n",
+ "# res_cre_rf[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ "\n",
+ "ml_boost = LGBMRegressor(verbose=-1, \n",
+ " n_estimators=100, \n",
+ " learning_rate=0.3,\n",
+ " min_child_samples=1) "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "dca81b0b",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ "    .dataframe tbody tr th:only-of-type {\n",
+ "        vertical-align: middle;\n",
+ "    }\n",
+ "\n",
+ "    .dataframe tbody tr th {\n",
+ "        vertical-align: top;\n",
+ "    }\n",
+ "\n",
+ "    .dataframe thead th {\n",
+ "        text-align: right;\n",
+ "    }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ "  <thead>\n",
+ "    <tr style=\"text-align: right;\">\n",
+ "      <th></th>\n",
+ "      <th>id</th>\n",
+ "      <th>time</th>\n",
+ "      <th>d</th>\n",
+ "      <th>y</th>\n",
+ "      <th>x1</th>\n",
+ "      <th>x2</th>\n",
+ "      <th>x3</th>\n",
+ "      <th>x4</th>\n",
+ "      <th>x5</th>\n",
+ "      <th>x6</th>\n",
+ "      <th>...</th>\n",
+ "      <th>x21</th>\n",
+ "      <th>x22</th>\n",
+ "      <th>x23</th>\n",
+ "      <th>x24</th>\n",
+ "      <th>x25</th>\n",
+ "      <th>x26</th>\n",
+ "      <th>x27</th>\n",
+ "      <th>x28</th>\n",
+ "      <th>x29</th>\n",
+ "      <th>x30</th>\n",
+ "    </tr>\n",
+ "  </thead>\n",
+ "  <tbody>\n",
+ "    <tr>\n",
+ "      <th>0</th>\n",
+ "      <td>1</td>\n",
+ "      <td>1</td>\n",
+ "      <td>-2.213736</td>\n",
+ "      <td>-5.293206</td>\n",
+ "      <td>-0.428742</td>\n",
+ "      <td>7.394317</td>\n",
+ "      <td>-3.040323</td>\n",
+ "      <td>-0.556956</td>\n",
+ "      <td>-1.971048</td>\n",
+ "      <td>3.811521</td>\n",
+ "      <td>...</td>\n",
+ "      <td>6.212741</td>\n",
+ "      <td>-4.062078</td>\n",
+ "      <td>3.238729</td>\n",
+ "      <td>-9.268694</td>\n",
+ "      <td>-5.648472</td>\n",
+ "      <td>9.005437</td>\n",
+ "      <td>-1.537681</td>\n",
+ "      <td>-4.250953</td>\n",
+ "      <td>4.311248</td>\n",
+ "      <td>1.155672</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>1</th>\n",
+ "      <td>1</td>\n",
+ "      <td>2</td>\n",
+ "      <td>-0.705639</td>\n",
+ "      <td>-1.099439</td>\n",
+ "      <td>6.859486</td>\n",
+ "      <td>-3.476090</td>\n",
+ "      <td>-2.070501</td>\n",
+ "      <td>0.951224</td>\n",
+ "      <td>-0.037587</td>\n",
+ "      <td>5.537579</td>\n",
+ "      <td>...</td>\n",
+ "      <td>-3.596578</td>\n",
+ "      <td>2.580089</td>\n",
+ "      <td>1.015682</td>\n",
+ "      <td>-11.813514</td>\n",
+ "      <td>6.059815</td>\n",
+ "      <td>9.918398</td>\n",
+ "      <td>-0.825656</td>\n",
+ "      <td>-1.491334</td>\n",
+ "      <td>2.950680</td>\n",
+ "      <td>4.083369</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>2</th>\n",
+ "      <td>1</td>\n",
+ "      <td>3</td>\n",
+ "      <td>7.975096</td>\n",
+ "      <td>10.156811</td>\n",
+ "      <td>8.308916</td>\n",
+ "      <td>2.863192</td>\n",
+ "      <td>4.956786</td>\n",
+ "      <td>-2.209816</td>\n",
+ "      <td>-3.509712</td>\n",
+ "      <td>-5.863826</td>\n",
+ "      <td>...</td>\n",
+ "      <td>5.826138</td>\n",
+ "      <td>-0.334149</td>\n",
+ "      <td>1.253619</td>\n",
+ "      <td>-6.735010</td>\n",
+ "      <td>0.705128</td>\n",
+ "      <td>-3.978092</td>\n",
+ "      <td>3.948464</td>\n",
+ "      <td>-3.216404</td>\n",
+ "      <td>0.912388</td>\n",
+ "      <td>-1.348336</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>3</th>\n",
+ "      <td>1</td>\n",
+ "      <td>4</td>\n",
+ "      <td>-1.259776</td>\n",
+ "      <td>-2.755462</td>\n",
+ "      <td>-8.534541</td>\n",
+ "      <td>4.878212</td>\n",
+ "      <td>1.080527</td>\n",
+ "      <td>9.905302</td>\n",
+ "      <td>4.336665</td>\n",
+ "      <td>7.272071</td>\n",
+ "      <td>...</td>\n",
+ "      <td>7.800067</td>\n",
+ "      <td>6.354467</td>\n",
+ "      <td>-3.945007</td>\n",
+ "      <td>-6.898544</td>\n",
+ "      <td>-0.649614</td>\n",
+ "      <td>2.045477</td>\n",
+ "      <td>4.110676</td>\n",
+ "      <td>7.025232</td>\n",
+ "      <td>4.912177</td>\n",
+ "      <td>-1.291772</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>4</th>\n",
+ "      <td>1</td>\n",
+ "      <td>5</td>\n",
+ "      <td>-14.969484</td>\n",
+ "      <td>-24.316493</td>\n",
+ "      <td>2.135777</td>\n",
+ "      <td>1.402142</td>\n",
+ "      <td>-15.818097</td>\n",
+ "      <td>6.619429</td>\n",
+ "      <td>-2.092901</td>\n",
+ "      <td>-0.523111</td>\n",
+ "      <td>...</td>\n",
+ "      <td>-2.314368</td>\n",
+ "      <td>6.764651</td>\n",
+ "      <td>-1.790039</td>\n",
+ "      <td>-5.611779</td>\n",
+ "      <td>-0.821312</td>\n",
+ "      <td>-4.314345</td>\n",
+ "      <td>-2.108172</td>\n",
+ "      <td>1.211618</td>\n",
+ "      <td>4.972882</td>\n",
+ "      <td>-3.570822</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>...</th>\n",
+ "      <td>...</td>\n",
+ "      <td>...</td>\n",
+ "      <td>...</td>\n",
+ "      <td>...</td>\n",
+ "      <td>...</td>\n",
+ "      <td>...</td>\n",
+ "      <td>...</td>\n",
+ "      <td>...</td>\n",
+ "      <td>...</td>\n",
+ "      <td>...</td>\n",
+ "      <td>...</td>\n",
+ "      <td>...</td>\n",
+ "      <td>...</td>\n",
+ "      <td>...</td>\n",
+ "      <td>...</td>\n",
+ "      <td>...</td>\n",
+ "      <td>...</td>\n",
+ "      <td>...</td>\n",
+ "      <td>...</td>\n",
+ "      <td>...</td>\n",
+ "      <td>...</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>995</th>\n",
+ "      <td>100</td>\n",
+ "      <td>6</td>\n",
+ "      <td>-1.542279</td>\n",
+ "      <td>-3.390828</td>\n",
+ "      <td>5.548290</td>\n",
+ "      <td>0.599051</td>\n",
+ "      <td>-1.009998</td>\n",
+ "      <td>-2.323957</td>\n",
+ "      <td>3.695354</td>\n",
+ "      <td>-3.960357</td>\n",
+ "      <td>...</td>\n",
+ "      <td>-4.302796</td>\n",
+ "      <td>5.029354</td>\n",
+ "      <td>-3.867411</td>\n",
+ "      <td>-5.779731</td>\n",
+ "      <td>1.022882</td>\n",
+ "      <td>-1.610741</td>\n",
+ "      <td>-6.182220</td>\n",
+ "      <td>3.790532</td>\n",
+ "      <td>1.378363</td>\n",
+ "      <td>-3.250114</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>996</th>\n",
+ "      <td>100</td>\n",
+ "      <td>7</td>\n",
+ "      <td>-5.195300</td>\n",
+ "      <td>-6.324920</td>\n",
+ "      <td>0.875120</td>\n",
+ "      <td>-1.827447</td>\n",
+ "      <td>-3.364818</td>\n",
+ "      <td>0.381971</td>\n",
+ "      <td>-2.868032</td>\n",
+ "      <td>-0.474953</td>\n",
+ "      <td>...</td>\n",
+ "      <td>-7.145188</td>\n",
+ "      <td>-2.962602</td>\n",
+ "      <td>6.051157</td>\n",
+ "      <td>-0.310512</td>\n",
+ "      <td>3.926997</td>\n",
+ "      <td>6.150247</td>\n",
+ "      <td>-0.424644</td>\n",
+ "      <td>-0.768921</td>\n",
+ "      <td>1.381692</td>\n",
+ "      <td>-4.279590</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>997</th>\n",
+ "      <td>100</td>\n",
+ "      <td>8</td>\n",
+ "      <td>-4.935131</td>\n",
+ "      <td>-6.861080</td>\n",
+ "      <td>-5.428853</td>\n",
+ "      <td>2.489955</td>\n",
+ "      <td>-0.062190</td>\n",
+ "      <td>5.694679</td>\n",
+ "      <td>-8.452585</td>\n",
+ "      <td>-13.783593</td>\n",
+ "      <td>...</td>\n",
+ "      <td>5.129312</td>\n",
+ "      <td>6.157375</td>\n",
+ "      <td>-5.336022</td>\n",
+ "      <td>-7.142420</td>\n",
+ "      <td>6.714065</td>\n",
+ "      <td>-6.922350</td>\n",
+ "      <td>4.991919</td>\n",
+ "      <td>6.219515</td>\n",
+ "      <td>-5.687230</td>\n",
+ "      <td>-3.842934</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>998</th>\n",
+ "      <td>100</td>\n",
+ "      <td>9</td>\n",
+ "      <td>-10.656389</td>\n",
+ "      <td>-17.293887</td>\n",
+ "      <td>-8.934547</td>\n",
+ "      <td>2.080141</td>\n",
+ "      <td>-6.928987</td>\n",
+ "      <td>1.259232</td>\n",
+ "      <td>-7.253584</td>\n",
+ "      <td>1.381321</td>\n",
+ "      <td>...</td>\n",
+ "      <td>-0.607923</td>\n",
+ "      <td>10.114989</td>\n",
+ "      <td>-4.271365</td>\n",
+ "      <td>-8.707851</td>\n",
+ "      <td>1.071853</td>\n",
+ "      <td>-1.960183</td>\n",
+ "      <td>0.646585</td>\n",
+ "      <td>4.832729</td>\n",
+ "      <td>3.747999</td>\n",
+ "      <td>0.701658</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>999</th>\n",
+ "      <td>100</td>\n",
+ "      <td>10</td>\n",
+ "      <td>-9.323197</td>\n",
+ "      <td>-15.524491</td>\n",
+ "      <td>-2.411281</td>\n",
+ "      <td>5.912876</td>\n",
+ "      <td>-8.815911</td>\n",
+ "      <td>-1.317385</td>\n",
+ "      <td>2.275555</td>\n",
+ "      <td>3.626635</td>\n",
+ "      <td>...</td>\n",
+ "      <td>2.755303</td>\n",
+ "      <td>11.769591</td>\n",
+ "      <td>-6.222097</td>\n",
+ "      <td>-12.357810</td>\n",
+ "      <td>-3.555324</td>\n",
+ "      <td>0.905898</td>\n",
+ "      <td>1.622200</td>\n",
+ "      <td>-5.201272</td>\n",
+ "      <td>11.640755</td>\n",
+ "      <td>-1.110845</td>\n",
+ "    </tr>\n",
+ "  </tbody>\n",
+ "</table>\n",
+ "<p>1000 rows × 34 columns</p>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " id time d y x1 x2 x3 x4 \\\n",
+ "0 1 1 -2.213736 -5.293206 -0.428742 7.394317 -3.040323 -0.556956 \n",
+ "1 1 2 -0.705639 -1.099439 6.859486 -3.476090 -2.070501 0.951224 \n",
+ "2 1 3 7.975096 10.156811 8.308916 2.863192 4.956786 -2.209816 \n",
+ "3 1 4 -1.259776 -2.755462 -8.534541 4.878212 1.080527 9.905302 \n",
+ "4 1 5 -14.969484 -24.316493 2.135777 1.402142 -15.818097 6.619429 \n",
+ ".. ... ... ... ... ... ... ... ... \n",
+ "995 100 6 -1.542279 -3.390828 5.548290 0.599051 -1.009998 -2.323957 \n",
+ "996 100 7 -5.195300 -6.324920 0.875120 -1.827447 -3.364818 0.381971 \n",
+ "997 100 8 -4.935131 -6.861080 -5.428853 2.489955 -0.062190 5.694679 \n",
+ "998 100 9 -10.656389 -17.293887 -8.934547 2.080141 -6.928987 1.259232 \n",
+ "999 100 10 -9.323197 -15.524491 -2.411281 5.912876 -8.815911 -1.317385 \n",
+ "\n",
+ " x5 x6 ... x21 x22 x23 x24 \\\n",
+ "0 -1.971048 3.811521 ... 6.212741 -4.062078 3.238729 -9.268694 \n",
+ "1 -0.037587 5.537579 ... -3.596578 2.580089 1.015682 -11.813514 \n",
+ "2 -3.509712 -5.863826 ... 5.826138 -0.334149 1.253619 -6.735010 \n",
+ "3 4.336665 7.272071 ... 7.800067 6.354467 -3.945007 -6.898544 \n",
+ "4 -2.092901 -0.523111 ... -2.314368 6.764651 -1.790039 -5.611779 \n",
+ ".. ... ... ... ... ... ... ... \n",
+ "995 3.695354 -3.960357 ... -4.302796 5.029354 -3.867411 -5.779731 \n",
+ "996 -2.868032 -0.474953 ... -7.145188 -2.962602 6.051157 -0.310512 \n",
+ "997 -8.452585 -13.783593 ... 5.129312 6.157375 -5.336022 -7.142420 \n",
+ "998 -7.253584 1.381321 ... -0.607923 10.114989 -4.271365 -8.707851 \n",
+ "999 2.275555 3.626635 ... 2.755303 11.769591 -6.222097 -12.357810 \n",
+ "\n",
+ " x25 x26 x27 x28 x29 x30 \n",
+ "0 -5.648472 9.005437 -1.537681 -4.250953 4.311248 1.155672 \n",
+ "1 6.059815 9.918398 -0.825656 -1.491334 2.950680 4.083369 \n",
+ "2 0.705128 -3.978092 3.948464 -3.216404 0.912388 -1.348336 \n",
+ "3 -0.649614 2.045477 4.110676 7.025232 4.912177 -1.291772 \n",
+ "4 -0.821312 -4.314345 -2.108172 1.211618 4.972882 -3.570822 \n",
+ ".. ... ... ... ... ... ... \n",
+ "995 1.022882 -1.610741 -6.182220 3.790532 1.378363 -3.250114 \n",
+ "996 3.926997 6.150247 -0.424644 -0.768921 1.381692 -4.279590 \n",
+ "997 6.714065 -6.922350 4.991919 6.219515 -5.687230 -3.842934 \n",
+ "998 1.071853 -1.960183 0.646585 4.832729 3.747999 0.701658 \n",
+ "999 -3.555324 0.905898 1.622200 -5.201272 11.640755 -1.110845 \n",
+ "\n",
+ "[1000 rows x 34 columns]"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "data = make_static_panel_CP2025(num_n=100, dgp_type='dgp1', x_var=5**2, a_var=0.95**2)\n",
+ "data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "d4580f3d",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Index(['id', 'time', 'd', 'y', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8',\n",
+ " 'x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15', 'x16', 'x17', 'x18',\n",
+ " 'x19', 'x20', 'x21', 'x22', 'x23', 'x24', 'x25', 'x26', 'x27', 'x28',\n",
+ " 'x29', 'x30'],\n",
+ " dtype='object')"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "wd_fct(data).columns"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f1342648",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Processing: 100.0 %"
+ ]
+ }
+ ],
+ "source": [
+ "n_reps = 20\n",
+ "theta = 0.5\n",
+ "dgp = 'dgp3'\n",
+ "\n",
+ "res_cre_ols = np.full((n_reps, 3), np.nan)\n",
+ "res_cre_lasso = np.full((n_reps, 3), np.nan)\n",
+ "res_cre_cart = np.full((n_reps, 3), np.nan)\n",
+ "res_cre_boost = np.full((n_reps, 3), np.nan)\n",
+ "\n",
+ "res_fd_ols = np.full((n_reps, 3), np.nan)\n",
+ "res_fd_lasso = np.full((n_reps, 3), np.nan)\n",
+ "res_fd_cart = np.full((n_reps, 3), np.nan)\n",
+ "res_fd_boost = np.full((n_reps, 3), np.nan)\n",
+ "\n",
+ "res_wd_ols = np.full((n_reps, 3), np.nan)\n",
+ "res_wd_lasso = np.full((n_reps, 3), np.nan)\n",
+ "res_wd_cart = np.full((n_reps, 3), np.nan)\n",
+ "res_wd_boost = np.full((n_reps, 3), np.nan)\n",
+ "\n",
+ "\n",
+ "np.random.seed(123)\n",
+ "\n",
+ "for i in range(n_reps):\n",
+ " print(f\"\\rProcessing: {round((i+1)/n_reps*100, 3)} %\", end=\"\")\n",
+ "\n",
+ " ml_cart_grid = {'ml_l': {'ccp_alpha': np.random.choice(np.arange(0.002, 0.052 , 0.002), 5, replace=False),\n",
+ " 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)},\n",
+ " 'ml_m': {'ccp_alpha': np.random.choice(np.arange(0.002, 0.052 , 0.002), 5, replace=False),\n",
+ " 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)}}\n",
+ "\n",
+ " ml_boost_grid= {'ml_l': {'reg_lambda': np.random.choice(np.arange(0.2, 2 , 0.2), 5, replace=False), \n",
+ " 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)},\n",
+ " 'ml_m': {'reg_lambda': np.random.choice(np.arange(0.2, 2 , 0.2), 5, replace=False),\n",
+ " 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)}}\n",
+ " \n",
+ " data = make_static_panel_CP2025(num_n=100, theta=theta, dgp_type=dgp, x_var=5**2, a_var=0.95**2)\n",
+ "\n",
+ " ## CRE\n",
+ " cre_data = cre_fct(data)\n",
+ "\n",
+ " data_cre_pdml = DoubleMLClusterData(cre_data,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " cluster_cols='id',\n",
+ " x_cols=[col for col in cre_data.columns if \"x\" in col])\n",
+ "\n",
+ " # OLS\n",
+ " dml_plpr = DoubleMLPLPR(data_cre_pdml, clone(ml_ols), clone(ml_ols), pdml_approach='cre_general', n_folds=5)\n",
+ " dml_plpr.fit()\n",
+ " res_cre_ols[i, 0] = dml_plpr.coef[0] - theta\n",
+ " res_cre_ols[i, 1] = dml_plpr.se[0] \n",
+ " confint = dml_plpr.confint()\n",
+ " res_cre_ols[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ "\n",
+ " # Lasso\n",
+ " cre_data_ext = extend_data(cre_data)\n",
+ " data_cre_pdml_ext = DoubleMLClusterData(cre_data_ext,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " cluster_cols='id',\n",
+ " x_cols=[col for col in cre_data_ext.columns if \"x\" in col]\n",
+ " )\n",
+ "\n",
+ " dml_plpr = DoubleMLPLPR(data_cre_pdml_ext, clone(ml_lasso), clone(ml_lasso), pdml_approach='cre_general', n_folds=5)\n",
+ " dml_plpr.fit(n_jobs_cv=5)\n",
+ " res_cre_lasso[i, 0] = dml_plpr.coef[0] - theta\n",
+ " res_cre_lasso[i, 1] = dml_plpr.se[0] \n",
+ " confint = dml_plpr.confint()\n",
+ " res_cre_lasso[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ "\n",
+ " # Cart\n",
+ " dml_plpr = DoubleMLPLPR(data_cre_pdml, clone(ml_cart), clone(ml_cart), pdml_approach='cre_general', n_folds=5)\n",
+ " dml_plpr.tune(param_grids=ml_cart_grid, search_mode='randomized_search', n_iter_randomized_search=5, n_jobs_cv=5)\n",
+ " dml_plpr.fit(n_jobs_cv=5)\n",
+ " res_cre_cart[i, 0] = dml_plpr.coef[0] - theta\n",
+ " res_cre_cart[i, 1] = dml_plpr.se[0] \n",
+ " confint = dml_plpr.confint()\n",
+ " res_cre_cart[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ "\n",
+ " # Boost\n",
+ " dml_plpr = DoubleMLPLPR(data_cre_pdml, clone(ml_boost), clone(ml_boost), pdml_approach='cre_general', n_folds=5)\n",
+ " dml_plpr.tune(param_grids=ml_boost_grid, search_mode='randomized_search', n_iter_randomized_search=5, n_jobs_cv=5)\n",
+ " dml_plpr.fit(n_jobs_cv=5)\n",
+ " res_cre_boost[i, 0] = dml_plpr.coef[0] - theta\n",
+ " res_cre_boost[i, 1] = dml_plpr.se[0] \n",
+ " confint = dml_plpr.confint()\n",
+ " res_cre_boost[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ "\n",
+ " ## FD\n",
+ " fd_data = fd_fct(data)\n",
+ "\n",
+ " data_fd_pdml = DoubleMLClusterData(fd_data,\n",
+ " y_col='y_diff',\n",
+ " d_cols='d_diff',\n",
+ " cluster_cols='id',\n",
+ " x_cols=[col for col in fd_data.columns if \"x\" in col])\n",
+ "\n",
+ " # OLS\n",
+ " dml_plpr = DoubleMLPLPR(data_fd_pdml, clone(ml_ols), clone(ml_ols), pdml_approach='transform', n_folds=5)\n",
+ " dml_plpr.fit()\n",
+ " res_fd_ols[i, 0] = dml_plpr.coef[0] - theta\n",
+ " res_fd_ols[i, 1] = dml_plpr.se[0] \n",
+ " confint = dml_plpr.confint()\n",
+ " res_fd_ols[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ "\n",
+ " # Lasso\n",
+ " fd_data_ext = extend_data(fd_data)\n",
+ " data_fd_pdml_ext = DoubleMLClusterData(fd_data_ext,\n",
+ " y_col='y_diff',\n",
+ " d_cols='d_diff',\n",
+ " cluster_cols='id',\n",
+ " x_cols=[col for col in fd_data_ext.columns if \"x\" in col]\n",
+ " )\n",
+ "\n",
+ " dml_plpr = DoubleMLPLPR(data_fd_pdml_ext, clone(ml_lasso), clone(ml_lasso), pdml_approach='transform', n_folds=5)\n",
+ " dml_plpr.fit(n_jobs_cv=5)\n",
+ " res_fd_lasso[i, 0] = dml_plpr.coef[0] - theta\n",
+ " res_fd_lasso[i, 1] = dml_plpr.se[0] \n",
+ " confint = dml_plpr.confint()\n",
+ " res_fd_lasso[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ "\n",
+ " # Cart\n",
+ " dml_plpr = DoubleMLPLPR(data_fd_pdml, clone(ml_cart), clone(ml_cart), pdml_approach='transform', n_folds=5)\n",
+ " dml_plpr.tune(param_grids=ml_cart_grid, search_mode='randomized_search', n_iter_randomized_search=5, n_jobs_cv=5)\n",
+ " dml_plpr.fit(n_jobs_cv=5)\n",
+ " res_fd_cart[i, 0] = dml_plpr.coef[0] - theta\n",
+ " res_fd_cart[i, 1] = dml_plpr.se[0] \n",
+ " confint = dml_plpr.confint()\n",
+ " res_fd_cart[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ "\n",
+ " # Boost\n",
+ " dml_plpr = DoubleMLPLPR(data_fd_pdml, clone(ml_boost), clone(ml_boost), pdml_approach='transform', n_folds=5)\n",
+ " dml_plpr.tune(param_grids=ml_boost_grid, search_mode='randomized_search', n_iter_randomized_search=5, n_jobs_cv=5)\n",
+ " dml_plpr.fit(n_jobs_cv=5)\n",
+ " res_fd_boost[i, 0] = dml_plpr.coef[0] - theta\n",
+ " res_fd_boost[i, 1] = dml_plpr.se[0] \n",
+ " confint = dml_plpr.confint()\n",
+ " res_fd_boost[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ "\n",
+ " ## WD\n",
+ " wd_data = wd_fct(data)\n",
+ "\n",
+ " data_wd_pdml = DoubleMLClusterData(wd_data,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " cluster_cols='id',\n",
+ " x_cols=[col for col in wd_data.columns if \"x\" in col])\n",
+ "\n",
+ " # OLS\n",
+ " dml_plpr = DoubleMLPLPR(data_wd_pdml, clone(ml_ols), clone(ml_ols), pdml_approach='transform', n_folds=5)\n",
+ " dml_plpr.fit()\n",
+ " res_wd_ols[i, 0] = dml_plpr.coef[0] - theta\n",
+ " res_wd_ols[i, 1] = dml_plpr.se[0] \n",
+ " confint = dml_plpr.confint()\n",
+ " res_wd_ols[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ "\n",
+ " # Lasso\n",
+ " wd_data_ext = extend_data(wd_data)\n",
+ " data_wd_pdml_ext = DoubleMLClusterData(wd_data_ext,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " cluster_cols='id',\n",
+ " x_cols=[col for col in wd_data_ext.columns if \"x\" in col]\n",
+ " )\n",
+ "\n",
+ " dml_plpr = DoubleMLPLPR(data_wd_pdml_ext, clone(ml_lasso), clone(ml_lasso), pdml_approach='transform', n_folds=5)\n",
+ " dml_plpr.fit(n_jobs_cv=5)\n",
+ " res_wd_lasso[i, 0] = dml_plpr.coef[0] - theta\n",
+ " res_wd_lasso[i, 1] = dml_plpr.se[0] \n",
+ " confint = dml_plpr.confint()\n",
+ " res_wd_lasso[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ "\n",
+ " # Cart\n",
+ " dml_plpr = DoubleMLPLPR(data_wd_pdml, clone(ml_cart), clone(ml_cart), pdml_approach='transform', n_folds=5)\n",
+ " dml_plpr.tune(param_grids=ml_cart_grid, search_mode='randomized_search', n_iter_randomized_search=5, n_jobs_cv=5)\n",
+ " dml_plpr.fit(n_jobs_cv=5)\n",
+ " res_wd_cart[i, 0] = dml_plpr.coef[0] - theta\n",
+ " res_wd_cart[i, 1] = dml_plpr.se[0] \n",
+ " confint = dml_plpr.confint()\n",
+ " res_wd_cart[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ "\n",
+ " # Boost\n",
+ " dml_plpr = DoubleMLPLPR(data_wd_pdml, clone(ml_boost), clone(ml_boost), pdml_approach='transform', n_folds=5)\n",
+ " dml_plpr.tune(param_grids=ml_boost_grid, search_mode='randomized_search', n_iter_randomized_search=5, n_jobs_cv=5)\n",
+ " dml_plpr.fit(n_jobs_cv=5)\n",
+ " res_wd_boost[i, 0] = dml_plpr.coef[0] - theta\n",
+ " res_wd_boost[i, 1] = dml_plpr.se[0] \n",
+ " confint = dml_plpr.confint()\n",
+ " res_wd_boost[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "cf49d96e",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Bias | \n",
+ " SE | \n",
+ " Coverage | \n",
+ " SE/SD | \n",
+ " RMSE | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | OLS (CRE) | \n",
+ " 0.992168 | \n",
+ " 0.004074 | \n",
+ " 0.0 | \n",
+ " 0.111677 | \n",
+ " 0.992171 | \n",
+ "
\n",
+ " \n",
+ " | Lasso (CRE) | \n",
+ " 0.006154 | \n",
+ " 0.033672 | \n",
+ " 1.0 | \n",
+ " 0.090553 | \n",
+ " 0.028267 | \n",
+ "
\n",
+ " \n",
+ " | Cart (CRE) | \n",
+ " 0.665847 | \n",
+ " 0.087393 | \n",
+ " 0.0 | \n",
+ " 0.405625 | \n",
+ " 0.702282 | \n",
+ "
\n",
+ " \n",
+ " | Boost (CRE) | \n",
+ " 0.667262 | \n",
+ " 0.066499 | \n",
+ " 0.0 | \n",
+ " 0.225313 | \n",
+ " 0.672432 | \n",
+ "
\n",
+ " \n",
+ " | OLS (FD) | \n",
+ " 0.990554 | \n",
+ " 0.004807 | \n",
+ " 0.0 | \n",
+ " 0.096511 | \n",
+ " 0.990562 | \n",
+ "
\n",
+ " \n",
+ " | Lasso (FD) | \n",
+ " 0.025886 | \n",
+ " 0.039459 | \n",
+ " 1.0 | \n",
+ " 0.104909 | \n",
+ " 0.037748 | \n",
+ "
\n",
+ " \n",
+ " | Cart (FD) | \n",
+ " 0.851024 | \n",
+ " 0.045228 | \n",
+ " 0.0 | \n",
+ " 0.402826 | \n",
+ " 0.854973 | \n",
+ "
\n",
+ " \n",
+ " | Boost (FD) | \n",
+ " 0.835220 | \n",
+ " 0.040769 | \n",
+ " 0.0 | \n",
+ " 0.087226 | \n",
+ " 0.837982 | \n",
+ "
\n",
+ " \n",
+ " | OLS (WD) | \n",
+ " 0.992407 | \n",
+ " 0.004118 | \n",
+ " 0.0 | \n",
+ " 0.111584 | \n",
+ " 0.992413 | \n",
+ "
\n",
+ " \n",
+ " | Lasso (WD) | \n",
+ " 0.969836 | \n",
+ " 0.010887 | \n",
+ " 0.0 | \n",
+ " 0.084016 | \n",
+ " 0.969867 | \n",
+ "
\n",
+ " \n",
+ " | Cart (WD) | \n",
+ " 0.676682 | \n",
+ " 0.067397 | \n",
+ " 0.0 | \n",
+ " 0.365097 | \n",
+ " 0.709067 | \n",
+ "
\n",
+ " \n",
+ " | Boost (WD) | \n",
+ " 0.813528 | \n",
+ " 0.031009 | \n",
+ " 0.0 | \n",
+ " 0.057072 | \n",
+ " 0.816461 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Bias SE Coverage SE/SD RMSE\n",
+ "OLS (CRE) 0.992168 0.004074 0.0 0.111677 0.992171\n",
+ "Lasso (CRE) 0.006154 0.033672 1.0 0.090553 0.028267\n",
+ "Cart (CRE) 0.665847 0.087393 0.0 0.405625 0.702282\n",
+ "Boost (CRE) 0.667262 0.066499 0.0 0.225313 0.672432\n",
+ "OLS (FD) 0.990554 0.004807 0.0 0.096511 0.990562\n",
+ "Lasso (FD) 0.025886 0.039459 1.0 0.104909 0.037748\n",
+ "Cart (FD) 0.851024 0.045228 0.0 0.402826 0.854973\n",
+ "Boost (FD) 0.835220 0.040769 0.0 0.087226 0.837982\n",
+ "OLS (WD) 0.992407 0.004118 0.0 0.111584 0.992413\n",
+ "Lasso (WD) 0.969836 0.010887 0.0 0.084016 0.969867\n",
+ "Cart (WD) 0.676682 0.067397 0.0 0.365097 0.709067\n",
+ "Boost (WD) 0.813528 0.031009 0.0 0.057072 0.816461"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# cre general, dgp square\n",
+ "tab_dat = np.vstack([res_cre_ols.mean(axis=0), res_cre_lasso.mean(axis=0), \n",
+ " res_cre_cart.mean(axis=0), res_cre_boost.mean(axis=0),\n",
+ " res_fd_ols.mean(axis=0), res_fd_lasso.mean(axis=0), \n",
+ " res_fd_cart.mean(axis=0), res_fd_boost.mean(axis=0),\n",
+ " res_wd_ols.mean(axis=0), res_wd_lasso.mean(axis=0), \n",
+ " res_wd_cart.mean(axis=0), res_wd_boost.mean(axis=0)])\n",
+ "\n",
+ "tab_sd = np.vstack([res_cre_ols[:,1].std(), res_cre_lasso[:,1].std(), \n",
+ " res_cre_cart[:,1].std(), res_cre_boost[:,1].std(),\n",
+ " res_fd_ols[:,1].std(), res_fd_lasso[:,1].std(), \n",
+ " res_fd_cart[:,1].std(), res_fd_boost[:,1].std(),\n",
+ " res_wd_ols[:,1].std(), res_wd_lasso[:,1].std(), \n",
+ " res_wd_cart[:,1].std(), res_wd_boost[:,1].std()])\n",
+ "\n",
+ "tab_se = np.column_stack([res_cre_ols[:,1], res_cre_lasso[:,1], \n",
+ " res_cre_cart[:,1], res_cre_boost[:,1],\n",
+ " res_fd_ols[:,1], res_fd_lasso[:,1], \n",
+ " res_fd_cart[:,1], res_fd_boost[:,1],\n",
+ " res_wd_ols[:,1], res_wd_lasso[:,1], \n",
+ " res_wd_cart[:,1], res_wd_boost[:,1]])\n",
+ "\n",
+ "tab_rmse = np.vstack([np.sqrt(np.mean(res_cre_ols[:,0]**2)), np.sqrt(np.mean(res_cre_lasso[:,0]**2)), \n",
+ " np.sqrt(np.mean(res_cre_cart[:,0]**2)), np.sqrt(np.mean(res_cre_boost[:,0]**2)),\n",
+ " np.sqrt(np.mean(res_fd_ols[:,0]**2)), np.sqrt(np.mean(res_fd_lasso[:,0]**2)), \n",
+ " np.sqrt(np.mean(res_fd_cart[:,0]**2)), np.sqrt(np.mean(res_fd_boost[:,0]**2)),\n",
+ " np.sqrt(np.mean(res_wd_ols[:,0]**2)), np.sqrt(np.mean(res_wd_lasso[:,0]**2)), \n",
+ " np.sqrt(np.mean(res_wd_cart[:,0]**2)), np.sqrt(np.mean(res_wd_boost[:,0]**2))])\n",
+ "\n",
+ "se_sd = tab_sd / tab_dat[:,1].reshape((-1,1))\n",
+ "\n",
+ "tab_dat = np.column_stack((tab_dat, se_sd, tab_rmse))\n",
+ "\n",
+ "pd.DataFrame(tab_dat, columns=['Bias', 'SE', 'Coverage', 'SE/SD', 'RMSE'], \n",
+ " index=['OLS (CRE)', 'Lasso (CRE)', 'Cart (CRE)', 'Boost (CRE)',\n",
+ " 'OLS (FD)', 'Lasso (FD)', 'Cart (FD)', 'Boost (FD)',\n",
+ " 'OLS (WD)', 'Lasso (WD)', 'Cart (WD)', 'Boost (WD)'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "93efa3c9",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Processing: 100.0 %"
+ ]
+ }
+ ],
+ "source": [
+ "n_reps = 20\n",
+ "theta = 0.5\n",
+ "dgp = 'dgp3'\n",
+ "\n",
+ "res_cre_ols = np.full((n_reps, 3), np.nan)\n",
+ "res_cre_lasso = np.full((n_reps, 3), np.nan)\n",
+ "res_cre_cart = np.full((n_reps, 3), np.nan)\n",
+ "res_cre_boost = np.full((n_reps, 3), np.nan)\n",
+ "\n",
+ "np.random.seed(123)\n",
+ "\n",
+ "for i in range(n_reps):\n",
+ " print(f\"\\rProcessing: {round((i+1)/n_reps*100, 3)} %\", end=\"\")\n",
+ "\n",
+ " ml_cart_grid = {'ml_l': {'ccp_alpha': np.random.choice(np.arange(0.002, 0.052 , 0.002), 5, replace=False),\n",
+ " 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)},\n",
+ " 'ml_m': {'ccp_alpha': np.random.choice(np.arange(0.002, 0.052 , 0.002), 5, replace=False),\n",
+ " 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)}}\n",
+ "\n",
+ " ml_boost_grid= {'ml_l': {'reg_lambda': np.random.choice(np.arange(0.2, 2 , 0.2), 5, replace=False), \n",
+ " 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)},\n",
+ " 'ml_m': {'reg_lambda': np.random.choice(np.arange(0.2, 2 , 0.2), 5, replace=False),\n",
+ " 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)}}\n",
+ " \n",
+ " data = make_static_panel_CP2025(num_n=4000, theta=theta, dgp_type=dgp, x_var=5**2, a_var=0.95**2)\n",
+ "\n",
+ " ## CRE\n",
+ " cre_data = cre_fct(data)\n",
+ "\n",
+ " data_cre_pdml = DoubleMLClusterData(cre_data,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " cluster_cols='id',\n",
+ " x_cols=[col for col in cre_data.columns if \"x\" in col])\n",
+ "\n",
+ " # OLS\n",
+ " dml_plpr = DoubleMLPLPR(data_cre_pdml, clone(ml_ols), clone(ml_ols), pdml_approach='cre_general', n_folds=5)\n",
+ " dml_plpr.fit()\n",
+ " res_cre_ols[i, 0] = dml_plpr.coef[0] - theta\n",
+ " res_cre_ols[i, 1] = dml_plpr.se[0] \n",
+ " confint = dml_plpr.confint()\n",
+ " res_cre_ols[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ "\n",
+ " # Lasso\n",
+ " cre_data_ext = extend_data(cre_data)\n",
+ " data_cre_pdml_ext = DoubleMLClusterData(cre_data_ext,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " cluster_cols='id',\n",
+ " x_cols=[col for col in cre_data_ext.columns if \"x\" in col]\n",
+ " )\n",
+ "\n",
+ " dml_plpr = DoubleMLPLPR(data_cre_pdml_ext, clone(ml_lasso), clone(ml_lasso), pdml_approach='cre_general', n_folds=5)\n",
+ " dml_plpr.fit(n_jobs_cv=5)\n",
+ " res_cre_lasso[i, 0] = dml_plpr.coef[0] - theta\n",
+ " res_cre_lasso[i, 1] = dml_plpr.se[0] \n",
+ " confint = dml_plpr.confint()\n",
+ " res_cre_lasso[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ "\n",
+ " # Cart\n",
+ " dml_plpr = DoubleMLPLPR(data_cre_pdml, clone(ml_cart), clone(ml_cart), pdml_approach='cre_general', n_folds=5)\n",
+ " dml_plpr.tune(param_grids=ml_cart_grid, search_mode='randomized_search', n_iter_randomized_search=5, n_jobs_cv=5)\n",
+ " dml_plpr.fit(n_jobs_cv=5)\n",
+ " res_cre_cart[i, 0] = dml_plpr.coef[0] - theta\n",
+ " res_cre_cart[i, 1] = dml_plpr.se[0] \n",
+ " confint = dml_plpr.confint()\n",
+ " res_cre_cart[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ "\n",
+ " # Boost\n",
+ " dml_plpr = DoubleMLPLPR(data_cre_pdml, clone(ml_boost), clone(ml_boost), pdml_approach='cre_general', n_folds=5)\n",
+ " dml_plpr.tune(param_grids=ml_boost_grid, search_mode='randomized_search', n_iter_randomized_search=5, n_jobs_cv=5)\n",
+ " dml_plpr.fit(n_jobs_cv=5)\n",
+ " res_cre_boost[i, 0] = dml_plpr.coef[0] - theta\n",
+ " res_cre_boost[i, 1] = dml_plpr.se[0] \n",
+ " confint = dml_plpr.confint()\n",
+ " res_cre_boost[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "9eda614d",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Bias | \n",
+ " SE | \n",
+ " Coverage | \n",
+ " SE/SD | \n",
+ " RMSE | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | OLS (CRE) | \n",
+ " 0.992847 | \n",
+ " 0.000642 | \n",
+ " 0.0 | \n",
+ " 0.010548 | \n",
+ " 0.992848 | \n",
+ "
\n",
+ " \n",
+ " | Lasso (CRE) | \n",
+ " 0.025631 | \n",
+ " 0.012231 | \n",
+ " 0.4 | \n",
+ " 0.024428 | \n",
+ " 0.027179 | \n",
+ "
\n",
+ " \n",
+ " | Cart (CRE) | \n",
+ " -0.008529 | \n",
+ " 0.035518 | \n",
+ " 0.2 | \n",
+ " 0.293714 | \n",
+ " 0.174186 | \n",
+ "
\n",
+ " \n",
+ " | Boost (CRE) | \n",
+ " -0.194521 | \n",
+ " 0.018691 | \n",
+ " 0.0 | \n",
+ " 0.235885 | \n",
+ " 0.197435 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Bias SE Coverage SE/SD RMSE\n",
+ "OLS (CRE) 0.992847 0.000642 0.0 0.010548 0.992848\n",
+ "Lasso (CRE) 0.025631 0.012231 0.4 0.024428 0.027179\n",
+ "Cart (CRE) -0.008529 0.035518 0.2 0.293714 0.174186\n",
+ "Boost (CRE) -0.194521 0.018691 0.0 0.235885 0.197435"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# cre normal, dgp square\n",
+ "tab_dat = np.vstack([res_cre_ols.mean(axis=0), res_cre_lasso.mean(axis=0), \n",
+ " res_cre_cart.mean(axis=0), res_cre_boost.mean(axis=0)])\n",
+ "\n",
+ "tab_sd = np.vstack([res_cre_ols[:,1].std(), res_cre_lasso[:,1].std(), \n",
+ " res_cre_cart[:,1].std(), res_cre_boost[:,1].std()])\n",
+ "\n",
+ "tab_se = np.column_stack([res_cre_ols[:,1], res_cre_lasso[:,1], \n",
+ " res_cre_cart[:,1], res_cre_boost[:,1]])\n",
+ "\n",
+ "tab_rmse = np.vstack([np.sqrt(np.mean(res_cre_ols[:,0]**2)), np.sqrt(np.mean(res_cre_lasso[:,0]**2)), \n",
+ " np.sqrt(np.mean(res_cre_cart[:,0]**2)), np.sqrt(np.mean(res_cre_boost[:,0]**2))])\n",
+ "\n",
+ "se_sd = tab_sd / tab_dat[:,1].reshape((-1,1))\n",
+ "\n",
+ "tab_dat = np.column_stack((tab_dat, se_sd, tab_rmse))\n",
+ "\n",
+ "pd.DataFrame(tab_dat, columns=['Bias', 'SE', 'Coverage', 'SE/SD', 'RMSE'], \n",
+ " index=['OLS (CRE)', 'Lasso (CRE)', 'Cart (CRE)', 'Boost (CRE)'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "cad7f1f9",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Bias | \n",
+ " SE | \n",
+ " Coverage | \n",
+ " SE/SD | \n",
+ " RMSE | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | OLS (CRE) | \n",
+ " 0.992847 | \n",
+ " 0.000642 | \n",
+ " 0.0 | \n",
+ " 0.010548 | \n",
+ " 0.992848 | \n",
+ "
\n",
+ " \n",
+ " | Lasso (CRE) | \n",
+ " -0.008373 | \n",
+ " 0.005419 | \n",
+ " 0.6 | \n",
+ " 0.008168 | \n",
+ " 0.009689 | \n",
+ "
\n",
+ " \n",
+ " | Cart (CRE) | \n",
+ " 0.059338 | \n",
+ " 0.059713 | \n",
+ " 0.4 | \n",
+ " 0.452969 | \n",
+ " 0.187953 | \n",
+ "
\n",
+ " \n",
+ " | Boost (CRE) | \n",
+ " 0.036213 | \n",
+ " 0.031977 | \n",
+ " 0.7 | \n",
+ " 0.247513 | \n",
+ " 0.057652 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Bias SE Coverage SE/SD RMSE\n",
+ "OLS (CRE) 0.992847 0.000642 0.0 0.010548 0.992848\n",
+ "Lasso (CRE) -0.008373 0.005419 0.6 0.008168 0.009689\n",
+ "Cart (CRE) 0.059338 0.059713 0.4 0.452969 0.187953\n",
+ "Boost (CRE) 0.036213 0.031977 0.7 0.247513 0.057652"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# cre general, dgp square\n",
+ "tab_dat = np.vstack([res_cre_ols.mean(axis=0), res_cre_lasso.mean(axis=0), \n",
+ " res_cre_cart.mean(axis=0), res_cre_boost.mean(axis=0)])\n",
+ "\n",
+ "tab_sd = np.vstack([res_cre_ols[:,1].std(), res_cre_lasso[:,1].std(), \n",
+ " res_cre_cart[:,1].std(), res_cre_boost[:,1].std()])\n",
+ "\n",
+ "tab_se = np.column_stack([res_cre_ols[:,1], res_cre_lasso[:,1], \n",
+ " res_cre_cart[:,1], res_cre_boost[:,1]])\n",
+ "\n",
+ "tab_rmse = np.vstack([np.sqrt(np.mean(res_cre_ols[:,0]**2)), np.sqrt(np.mean(res_cre_lasso[:,0]**2)), \n",
+ " np.sqrt(np.mean(res_cre_cart[:,0]**2)), np.sqrt(np.mean(res_cre_boost[:,0]**2))])\n",
+ "\n",
+ "se_sd = tab_sd / tab_dat[:,1].reshape((-1,1))\n",
+ "\n",
+ "tab_dat = np.column_stack((tab_dat, se_sd, tab_rmse))\n",
+ "\n",
+ "pd.DataFrame(tab_dat, columns=['Bias', 'SE', 'Coverage', 'SE/SD', 'RMSE'], \n",
+ " index=['OLS (CRE)', 'Lasso (CRE)', 'Cart (CRE)', 'Boost (CRE)'])"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
From 215a5e5e30c5da682b7192811780be1728221781 Mon Sep 17 00:00:00 2001
From: Julian Diefenbacher
<93731825+JulianDiefenbacher@users.noreply.github.com>
Date: Thu, 10 Jul 2025 16:23:24 +0200
Subject: [PATCH 04/33] add plpr __str__, and checks
---
doubleml/plm/plpr.py | 63 +++++++++++++++++++++++++++++++++++++++++---
1 file changed, 60 insertions(+), 3 deletions(-)
diff --git a/doubleml/plm/plpr.py b/doubleml/plm/plpr.py
index 63e0f841a..bb9bf25b2 100644
--- a/doubleml/plm/plpr.py
+++ b/doubleml/plm/plpr.py
@@ -6,6 +6,7 @@
from sklearn.utils import check_X_y
from ..data.base_data import DoubleMLData
+from ..data.cluster_data import DoubleMLClusterData
from ..double_ml import DoubleML
from ..double_ml_score_mixins import LinearScoreMixin
from ..utils._checks import _check_binary_predictions, _check_finite_predictions, _check_is_propensity, _check_score
@@ -122,18 +123,74 @@ def __init__(
self._sensitivity_implemented = False ###
self._external_predictions_implemented = True
+ def __str__(self):
+ class_name = self.__class__.__name__
+ header = f"================== {class_name} Object ==================\n"
+ data_summary = self._dml_data._data_summary_str()
+ score_pdml_approach_info = (
+ f"Score function: {str(self.score)}\n"
+ f"Static panel model approach: {str(self.pdml_approach)}\n"
+ )
+ learner_info = ""
+ for key, value in self.learner.items():
+ learner_info += f"Learner {key}: {str(value)}\n"
+ if self.nuisance_loss is not None:
+ learner_info += "Out-of-sample Performance:\n"
+ is_classifier = [value for value in self._is_classifier.values()]
+ is_regressor = [not value for value in is_classifier]
+ if any(is_regressor):
+ learner_info += "Regression:\n"
+ for learner in [key for key, value in self._is_classifier.items() if value is False]:
+ learner_info += f"Learner {learner} RMSE: {self.nuisance_loss[learner]}\n"
+ if any(is_classifier):
+ learner_info += "Classification:\n"
+ for learner in [key for key, value in self._is_classifier.items() if value is True]:
+ learner_info += f"Learner {learner} Log Loss: {self.nuisance_loss[learner]}\n"
+
+ if self._is_cluster_data:
+ resampling_info = (
+ f"No. folds per cluster: {self._n_folds_per_cluster}\n"
+ f"No. folds: {self.n_folds}\n"
+ f"No. repeated sample splits: {self.n_rep}\n"
+ )
+ else:
+ resampling_info = f"No. folds: {self.n_folds}\nNo. repeated sample splits: {self.n_rep}\n"
+ fit_summary = str(self.summary)
+ res = (
+ header
+ + "\n------------------ Data summary ------------------\n"
+ + data_summary
+ + "\n------------------ Score & algorithm ------------------\n"
+ + score_pdml_approach_info
+ + "\n------------------ Machine learner ------------------\n"
+ + learner_info
+ + "\n------------------ Resampling ------------------\n"
+ + resampling_info
+ + "\n------------------ Fit summary ------------------\n"
+ + fit_summary
+ )
+ return res
+
+ @property
+ def pdml_approach(self):
+ """
+ The static panel model approach.
+ """
+ return self._pdml_approach
+
def _initialize_ml_nuisance_params(self):
self._params = {learner: {key: [None] * self.n_rep for key in self._dml_data.d_cols} for learner in self._learner}
+ # DoubleML init check for type already?
def _check_data(self, obj_dml_data):
- if not isinstance(obj_dml_data, DoubleMLData):
+ if not isinstance(obj_dml_data, DoubleMLClusterData):
raise TypeError(
- f"The data must be of DoubleMLData type. {str(obj_dml_data)} of type {str(type(obj_dml_data))} was passed."
+ f"The data must be of DoubleMLClusterData type. {str(type(obj_dml_data))} was passed."
)
if obj_dml_data.z_cols is not None:
raise ValueError(
"Incompatible data. " + " and ".join(obj_dml_data.z_cols) + " have been set as instrumental variable(s). "
- "To fit a partially linear IV regression model use DoubleMLPLIV instead of DoubleMLPLR."
+ "DoubleMLPLPR currently does not support instrumental variables."
)
return
From 4189671c572b7ed8880151c762618bd8e6b54590 Mon Sep 17 00:00:00 2001
From: Julian Diefenbacher
<93731825+JulianDiefenbacher@users.noreply.github.com>
Date: Thu, 10 Jul 2025 16:24:02 +0200
Subject: [PATCH 05/33] update example_sim
---
doubleml/plm/sim/example_sim.ipynb | 49 ++++++++++++++++++++----------
1 file changed, 33 insertions(+), 16 deletions(-)
diff --git a/doubleml/plm/sim/example_sim.ipynb b/doubleml/plm/sim/example_sim.ipynb
index dcbc3cb25..6f7e5fd9c 100644
--- a/doubleml/plm/sim/example_sim.ipynb
+++ b/doubleml/plm/sim/example_sim.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 1,
"id": "6dfa56df",
"metadata": {},
"outputs": [],
@@ -17,7 +17,7 @@
"from sklearn.base import clone\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.pipeline import make_pipeline\n",
- "from doubleml.plm.utils._plpr_util import cre_fct, fd_fct, wd_fct\n",
+ "from doubleml.plm.utils._plpr_util import cre_fct, fd_fct, wd_fct, extend_data\n",
"from doubleml.plm.datasets.dgp_static_panel_CP2025 import make_static_panel_CP2025\n",
"import warnings\n",
"import warnings\n",
@@ -26,7 +26,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 2,
"id": "3c061cf1",
"metadata": {},
"outputs": [
@@ -38,7 +38,7 @@
" np.float64(0.7348254130670426))"
]
},
- "execution_count": 3,
+ "execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
@@ -59,7 +59,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 6,
"id": "73c6599b",
"metadata": {},
"outputs": [
@@ -67,8 +67,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.48547 0.020911 23.216093 3.131961e-119 0.444485 0.526455\n"
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d 0.54181 0.021414 25.30162 3.065740e-141 0.499839 0.583781\n"
]
}
],
@@ -90,16 +90,32 @@
"ml_m = clone(learner)\n",
"\n",
"obj_dml_plpr = DoubleMLPLPR(obj_dml_data_pdml, ml_l=ml_l, ml_m=ml_m,\n",
- " pdml_approach='cre_general', n_folds=2,\n",
+ " pdml_approach='cre', n_folds=5\n",
" )\n",
+ "\n",
"obj_dml_plpr.fit()\n",
"print(obj_dml_plpr.summary)"
]
},
{
"cell_type": "code",
- "execution_count": 9,
- "id": "5c700387",
+ "execution_count": null,
+ "id": "7deabe55",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# model rmse\n",
+ "\n",
+ "# u_hat = obj_dml_plpr._dml_data.y - obj_dml_plpr.predictions['ml_l'].flatten()\n",
+ "# v_hat = obj_dml_plpr._dml_data.d - obj_dml_plpr.predictions['ml_m'].flatten()\n",
+ "\n",
+ "# np.sqrt(np.mean((u_hat - (obj_dml_plpr.coef[0] * v_hat))**2))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "83663379",
"metadata": {},
"outputs": [
{
@@ -118,23 +134,24 @@
"\n",
"------------------ Score & algorithm ------------------\n",
"Score function: partialling out\n",
+ "Static panel model approach: cre\n",
"\n",
"------------------ Machine learner ------------------\n",
"Learner ml_l: LassoCV()\n",
"Learner ml_m: LassoCV()\n",
"Out-of-sample Performance:\n",
"Regression:\n",
- "Learner ml_l RMSE: [[1.75556179]]\n",
- "Learner ml_m RMSE: [[0.95043707]]\n",
+ "Learner ml_l RMSE: [[1.78726751]]\n",
+ "Learner ml_m RMSE: [[0.98667923]]\n",
"\n",
"------------------ Resampling ------------------\n",
- "No. folds per cluster: 2\n",
- "No. folds: 2\n",
+ "No. folds per cluster: 5\n",
+ "No. folds: 5\n",
"No. repeated sample splits: 1\n",
"\n",
"------------------ Fit summary ------------------\n",
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.48547 0.020911 23.216093 3.131961e-119 0.444485 0.526455\n"
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d 0.536846 0.020889 25.700208 1.162629e-145 0.495905 0.577787\n"
]
}
],
From 2c7e23883d3089720be83980f8923988603686b9 Mon Sep 17 00:00:00 2001
From: JulianDiefenbacher
Date: Thu, 9 Oct 2025 20:06:21 +0200
Subject: [PATCH 06/33] add model descriptions
---
doubleml/plm/sim/example_sim.ipynb | 99 +++++++++++++++++++++++++-----
1 file changed, 83 insertions(+), 16 deletions(-)
diff --git a/doubleml/plm/sim/example_sim.ipynb b/doubleml/plm/sim/example_sim.ipynb
index 6f7e5fd9c..757e1ba88 100644
--- a/doubleml/plm/sim/example_sim.ipynb
+++ b/doubleml/plm/sim/example_sim.ipynb
@@ -1,5 +1,72 @@
{
"cells": [
+ {
+ "cell_type": "markdown",
+ "id": "01469cb2",
+ "metadata": {},
+ "source": [
+ "### Double Machine Learning for Static Panel Models with Fixed Effects\n",
+ "\n",
+ "Extending the partially linear model to panel data by introducing fixed effects $\\alpha^*_i$ to give the partially linear panel regression (PLPR) model.\n",
+ "\n",
+ "Partialled-out PLPR (PO-PLPR) model:\n",
+ "\n",
+ "\\begin{align*}\n",
+ " Y_{it} &= \\theta_0 D_{it} + g_0(X_{it}) + \\alpha_i + U_{it} \\\\\n",
+ " D_{it} &= m_0(X_{it}) + \\gamma_i + V_{it}\n",
+ "\\end{align*},\n",
+ "\n",
+ "- $Y_{it}$ outcome, $D_{it}$ treatment, $X_{it}$ covariates, $\\theta_0$ causal treatment effect\n",
+ "- $g_0$ and $m_0$ nuisance functions\n",
+ "- $\\alpha_i$, $\\gamma_i$ unobserved individual heterogeneity, correlated with covariates\n",
+ "- $U_{it}$, $V_{it}$ error terms\n",
+ "\n",
+ "Further note that $E[U_{it} \\mid D_{it}, X_{it}, \\alpha_i] = 0$, but $E[\\alpha_i \\mid D_{it}, X_{it}] \\neq 0$, and $E[V_{it} \\mid X_{it}, \\gamma_i] = 0$.\n",
+ "\n",
+ "#### 1. Correlated Random Effects Approach\n",
+ "\n",
+ "##### 1.1. General case:\n",
+ "\n",
+ "- Learning $g_0$ from $\\{ Y_{it}, X_{it}, \\bar{X}_i : t=1,\\dots, T \\}_{i=1}^N$\n",
+ "- First learning $m_0$ from $\\{ D_{it}, X_{it}, \\bar{X}_i : t=1,\\dots, T \\}_{i=1}^N$ with prediction $\\hat{m}_{0i} = \\tilde{m}_0 (X_{it}, \\bar{X}_i) $\n",
+ " - Calculate $\\hat{\\bar{m}}_i = T^{-1} \\sum_{t=1}^T \\hat{m}_{0i} $\n",
+ " - Calculate final nuisance part as $ \\hat{m}^*_0 (X_{it}, \\bar{X}_i, \\bar{D}_i) = \\hat{m}_{0i} + \\bar{D}_i - \\hat{\\bar{m}}_i $ \n",
+ "\n",
+ "##### 1.2. Normal assumption:\n",
+ "\n",
+ "(conditional distribution $ D_{i1}, \\dots, D_{iT} \\mid X_{i1}, \\dots X_{iT} $ is multivariate normal)\n",
+ "- Learn $m^*_{0i}$ from $\\{ D_{it}, X_{it}, \\bar{X}_i, \\bar{D}_i: t=1,\\dots, T \\}_{i=1}^N$\n",
+ "\n",
+ "#### 2. Transformation Approaches\n",
+ "\n",
+ "##### 2.1 First Difference (FD) Transformation - Exact\n",
+ "\n",
+ "Consider FD transformation $Q(Y_{it})= Y_{it} - Y_{it-1} $, under Assumptions 3.1-3.5, transformed nuisance function can be learnt as\n",
+ "\n",
+ "- $ \\Delta g_0 (X_{it-1}, X_{it}) $ from $ \\{ Y_{it}-Y_{it-1}, X_{it-1}, X_{it} : t=2, \\dots , T \\}_{i=1}^N $\n",
+ "- $ \\Delta m_0 (X_{it-1}, X_{it}) $ from $ \\{ D_{it}-D_{it-1}, X_{it-1}, X_{it} : t=2, \\dots , T \\}_{i=1}^N $\n",
+ "\n",
+ "##### 2.2 Within Group (WG) Transformation - Approximate\n",
+ "\n",
+ "For WG transformation $Q(X_{it})= X_{it} - \\bar{X}_{i} $, where $ \\bar{X}_{i} = T^{-1} \\sum_{t=1}^T W_{it} $. Approximate model\n",
+ "\\begin{align*}\n",
+ " Q(Y_{it}) &\\approx \\theta_0 Q(D_{it}) + g_0 (Q(X_{it})) + Q(U_{it}) \\\\\n",
+ " Q(D_{it}) &\\approx m_0 (Q(X_{it})) + Q(V_{it})\n",
+ "\\end{align*}\n",
+ "\n",
+ "- $g_0$ can be learnt from transformed data $ \\{ Q(Y_{it}), Q(X_{it}) : t=1,\\dots,T \\}_{i=1}^N $\n",
+ "- $m_0$ can be learnt from transformed data $ \\{ Q(D_{it}), Q(X_{it}) : t=1,\\dots,T \\}_{i=1}^N $\n",
+ "\n",
+ "#### Implementation\n",
+ "\n",
+ "- Using block-k-fold cross-fitting, where the entire time series of the sampled unit is allocated to one fold to allow for possible serial correlation\n",
+ "within each unit as is common with panel data\n",
+ "\n",
+ "- Cluster-robust standard errors\n",
+ "\n",
+ "$\\Rightarrow$ using id variable as cluster for DML"
+ ]
+ },
{
"cell_type": "code",
"execution_count": 1,
@@ -33,9 +100,9 @@
{
"data": {
"text/plain": [
- "(np.float64(0.6719371174913912),\n",
- " np.float64(0.6090488219157397),\n",
- " np.float64(0.7348254130670426))"
+ "(np.float64(0.6719371174913908),\n",
+ " np.float64(0.6090488219157394),\n",
+ " np.float64(0.7348254130670423))"
]
},
"execution_count": 2,
@@ -59,7 +126,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 8,
"id": "73c6599b",
"metadata": {},
"outputs": [
@@ -67,8 +134,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.54181 0.021414 25.30162 3.065740e-141 0.499839 0.583781\n"
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d 0.498106 0.020829 23.913515 2.215861e-126 0.457281 0.538931\n"
]
}
],
@@ -90,7 +157,7 @@
"ml_m = clone(learner)\n",
"\n",
"obj_dml_plpr = DoubleMLPLPR(obj_dml_data_pdml, ml_l=ml_l, ml_m=ml_m,\n",
- " pdml_approach='cre', n_folds=5\n",
+ " pdml_approach='cre_general', n_folds=5\n",
" )\n",
"\n",
"obj_dml_plpr.fit()\n",
@@ -114,7 +181,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 9,
"id": "83663379",
"metadata": {},
"outputs": [
@@ -134,15 +201,15 @@
"\n",
"------------------ Score & algorithm ------------------\n",
"Score function: partialling out\n",
- "Static panel model approach: cre\n",
+ "Static panel model approach: cre_general\n",
"\n",
"------------------ Machine learner ------------------\n",
"Learner ml_l: LassoCV()\n",
"Learner ml_m: LassoCV()\n",
"Out-of-sample Performance:\n",
"Regression:\n",
- "Learner ml_l RMSE: [[1.78726751]]\n",
- "Learner ml_m RMSE: [[0.98667923]]\n",
+ "Learner ml_l RMSE: [[1.72960098]]\n",
+ "Learner ml_m RMSE: [[0.95035703]]\n",
"\n",
"------------------ Resampling ------------------\n",
"No. folds per cluster: 5\n",
@@ -151,7 +218,7 @@
"\n",
"------------------ Fit summary ------------------\n",
" coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.536846 0.020889 25.700208 1.162629e-145 0.495905 0.577787\n"
+ "d 0.498106 0.020829 23.913515 2.215861e-126 0.457281 0.538931\n"
]
}
],
@@ -161,7 +228,7 @@
},
{
"cell_type": "code",
- "execution_count": 59,
+ "execution_count": 18,
"id": "24f06d62",
"metadata": {},
"outputs": [
@@ -169,8 +236,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.498693 0.022082 22.584173 6.200824e-113 0.455414 0.541972\n"
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d 0.532542 0.02172 24.518225 9.443063e-133 0.489971 0.575113\n"
]
}
],
@@ -476,7 +543,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.12.6"
+ "version": "3.13.5"
}
},
"nbformat": 4,
From 7708f9bd48998d4b3c62b5033c5d23b829bd1564 Mon Sep 17 00:00:00 2001
From: JulianDiefenbacher
Date: Thu, 9 Oct 2025 20:07:57 +0200
Subject: [PATCH 07/33] fix typo
---
doubleml/plm/sim/example_sim.ipynb | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/doubleml/plm/sim/example_sim.ipynb b/doubleml/plm/sim/example_sim.ipynb
index 757e1ba88..5c068f52d 100644
--- a/doubleml/plm/sim/example_sim.ipynb
+++ b/doubleml/plm/sim/example_sim.ipynb
@@ -14,7 +14,7 @@
"\\begin{align*}\n",
" Y_{it} &= \\theta_0 D_{it} + g_0(X_{it}) + \\alpha_i + U_{it} \\\\\n",
" D_{it} &= m_0(X_{it}) + \\gamma_i + V_{it}\n",
- "\\end{align*},\n",
+ "\\end{align*}\n",
"\n",
"- $Y_{it}$ outcome, $D_{it}$ treatment, $X_{it}$ covariates, $\\theta_0$ causal treatment effect\n",
"- $g_0$ and $m_0$ nuisance functions\n",
From d5bf9e9a86a997bae4c4b678367db6a02d8ea489 Mon Sep 17 00:00:00 2001
From: JulianDiefenbacher
Date: Fri, 10 Oct 2025 00:09:35 +0200
Subject: [PATCH 08/33] fix notation consistency
---
doubleml/plm/sim/example_sim.ipynb | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/doubleml/plm/sim/example_sim.ipynb b/doubleml/plm/sim/example_sim.ipynb
index 5c068f52d..d7312c090 100644
--- a/doubleml/plm/sim/example_sim.ipynb
+++ b/doubleml/plm/sim/example_sim.ipynb
@@ -7,7 +7,7 @@
"source": [
"### Double Machine Learning for Static Panel Models with Fixed Effects\n",
"\n",
- "Extending the partially linear model to panel data by introducing fixed effects $\\alpha^*_i$ to give the partially linear panel regression (PLPR) model.\n",
+ "Extending the partially linear model to panel data by introducing fixed effects $\\alpha_i$ to give the partially linear panel regression (PLPR) model.\n",
"\n",
"Partialled-out PLPR (PO-PLPR) model:\n",
"\n",
@@ -28,14 +28,14 @@
"##### 1.1. General case:\n",
"\n",
"- Learning $g_0$ from $\\{ Y_{it}, X_{it}, \\bar{X}_i : t=1,\\dots, T \\}_{i=1}^N$\n",
- "- First learning $m_0$ from $\\{ D_{it}, X_{it}, \\bar{X}_i : t=1,\\dots, T \\}_{i=1}^N$ with prediction $\\hat{m}_{0i} = \\tilde{m}_0 (X_{it}, \\bar{X}_i) $\n",
- " - Calculate $\\hat{\\bar{m}}_i = T^{-1} \\sum_{t=1}^T \\hat{m}_{0i} $\n",
- " - Calculate final nuisance part as $ \\hat{m}^*_0 (X_{it}, \\bar{X}_i, \\bar{D}_i) = \\hat{m}_{0i} + \\bar{D}_i - \\hat{\\bar{m}}_i $ \n",
+ "- First learning $\\tilde{m}_0({\\cdot})$ from $\\{ D_{it}, X_{it}, \\bar{X}_i : t=1,\\dots, T \\}_{i=1}^N$ with prediction $\\hat{m}_{0it} = \\tilde{m}_0 (X_{it}, \\bar{X}_i) $\n",
+ " - Calculate $\\hat{\\bar{m}}_i = T^{-1} \\sum_{t=1}^T \\hat{m}_{0it} $\n",
+ " - Calculate final nuisance part as $ \\hat{m}^*_0 (X_{it}, \\bar{X}_i, \\bar{D}_i) = \\hat{m}_{0it} + \\bar{D}_i - \\hat{\\bar{m}}_i $ \n",
"\n",
"##### 1.2. Normal assumption:\n",
"\n",
"(conditional distribution $ D_{i1}, \\dots, D_{iT} \\mid X_{i1}, \\dots X_{iT} $ is multivariate normal)\n",
- "- Learn $m^*_{0i}$ from $\\{ D_{it}, X_{it}, \\bar{X}_i, \\bar{D}_i: t=1,\\dots, T \\}_{i=1}^N$\n",
+ "- Learn $m^*_{0}$ from $\\{ D_{it}, X_{it}, \\bar{X}_i, \\bar{D}_i: t=1,\\dots, T \\}_{i=1}^N$\n",
"\n",
"#### 2. Transformation Approaches\n",
"\n",
@@ -48,7 +48,7 @@
"\n",
"##### 2.2 Within Group (WG) Transformation - Approximate\n",
"\n",
- "For WG transformation $Q(X_{it})= X_{it} - \\bar{X}_{i} $, where $ \\bar{X}_{i} = T^{-1} \\sum_{t=1}^T W_{it} $. Approximate model\n",
+ "For WG transformation $Q(X_{it})= X_{it} - \\bar{X}_{i} $, where $ \\bar{X}_{i} = T^{-1} \\sum_{t=1}^T X_{it} $. Approximate model\n",
"\\begin{align*}\n",
" Q(Y_{it}) &\\approx \\theta_0 Q(D_{it}) + g_0 (Q(X_{it})) + Q(U_{it}) \\\\\n",
" Q(D_{it}) &\\approx m_0 (Q(X_{it})) + Q(V_{it})\n",
From d8c303984c51cffca1d3c771f25231b23b6b7dce Mon Sep 17 00:00:00 2001
From: JulianDiefenbacher
Date: Fri, 10 Oct 2025 00:12:41 +0200
Subject: [PATCH 09/33] update description numbering
---
doubleml/plm/sim/example_sim.ipynb | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/doubleml/plm/sim/example_sim.ipynb b/doubleml/plm/sim/example_sim.ipynb
index d7312c090..5ea1f2af0 100644
--- a/doubleml/plm/sim/example_sim.ipynb
+++ b/doubleml/plm/sim/example_sim.ipynb
@@ -39,14 +39,14 @@
"\n",
"#### 2. Transformation Approaches\n",
"\n",
- "##### 2.1 First Difference (FD) Transformation - Exact\n",
+ "##### 2.1. First Difference (FD) Transformation - Exact\n",
"\n",
"Consider FD transformation $Q(Y_{it})= Y_{it} - Y_{it-1} $, under Assumptions 3.1-3.5, transformed nuisance function can be learnt as\n",
"\n",
"- $ \\Delta g_0 (X_{it-1}, X_{it}) $ from $ \\{ Y_{it}-Y_{it-1}, X_{it-1}, X_{it} : t=2, \\dots , T \\}_{i=1}^N $\n",
"- $ \\Delta m_0 (X_{it-1}, X_{it}) $ from $ \\{ D_{it}-D_{it-1}, X_{it-1}, X_{it} : t=2, \\dots , T \\}_{i=1}^N $\n",
"\n",
- "##### 2.2 Within Group (WG) Transformation - Approximate\n",
+ "##### 2.2. Within Group (WG) Transformation - Approximate\n",
"\n",
"For WG transformation $Q(X_{it})= X_{it} - \\bar{X}_{i} $, where $ \\bar{X}_{i} = T^{-1} \\sum_{t=1}^T X_{it} $. Approximate model\n",
"\\begin{align*}\n",
From 753f68a360bd7a19f64fa64a7278536490357d11 Mon Sep 17 00:00:00 2001
From: Julian Diefenbacher
<93731825+JulianDiefenbacher@users.noreply.github.com>
Date: Mon, 27 Oct 2025 16:17:13 +0100
Subject: [PATCH 10/33] update from ClusterData to base Data class
---
doubleml/plm/plpr.py | 6 +-
doubleml/plm/sim/example_sim.ipynb | 105 +++---
doubleml/plm/sim/learners_sim.ipynb | 531 ++++++++++++++--------------
3 files changed, 322 insertions(+), 320 deletions(-)
diff --git a/doubleml/plm/plpr.py b/doubleml/plm/plpr.py
index bb9bf25b2..852316725 100644
--- a/doubleml/plm/plpr.py
+++ b/doubleml/plm/plpr.py
@@ -6,7 +6,6 @@
from sklearn.utils import check_X_y
from ..data.base_data import DoubleMLData
-from ..data.cluster_data import DoubleMLClusterData
from ..double_ml import DoubleML
from ..double_ml_score_mixins import LinearScoreMixin
from ..utils._checks import _check_binary_predictions, _check_finite_predictions, _check_is_propensity, _check_score
@@ -182,10 +181,11 @@ def _initialize_ml_nuisance_params(self):
self._params = {learner: {key: [None] * self.n_rep for key in self._dml_data.d_cols} for learner in self._learner}
# DoubleML init check for type already?
+ # TODO: Ensure cluster usage
def _check_data(self, obj_dml_data):
- if not isinstance(obj_dml_data, DoubleMLClusterData):
+ if not isinstance(obj_dml_data, DoubleMLData):
raise TypeError(
- f"The data must be of DoubleMLClusterData type. {str(type(obj_dml_data))} was passed."
+ f"The data must be of DoubleMLData type. {str(type(obj_dml_data))} was passed."
)
if obj_dml_data.z_cols is not None:
raise ValueError(
diff --git a/doubleml/plm/sim/example_sim.ipynb b/doubleml/plm/sim/example_sim.ipynb
index 5ea1f2af0..cc4d968b1 100644
--- a/doubleml/plm/sim/example_sim.ipynb
+++ b/doubleml/plm/sim/example_sim.ipynb
@@ -78,7 +78,7 @@
"import pandas as pd\n",
"import statsmodels.api as sm\n",
"from doubleml.data.base_data import DoubleMLData\n",
- "from doubleml.data.cluster_data import DoubleMLClusterData\n",
+ "from doubleml.data.base_data import DoubleMLData\n",
"from doubleml.plm.plpr import DoubleMLPLPR\n",
"from sklearn.linear_model import LassoCV, LinearRegression\n",
"from sklearn.base import clone\n",
@@ -100,9 +100,9 @@
{
"data": {
"text/plain": [
- "(np.float64(0.6719371174913908),\n",
- " np.float64(0.6090488219157394),\n",
- " np.float64(0.7348254130670423))"
+ "(np.float64(0.6719371174913912),\n",
+ " np.float64(0.6090488219157397),\n",
+ " np.float64(0.7348254130670426))"
]
},
"execution_count": 2,
@@ -126,7 +126,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 3,
"id": "73c6599b",
"metadata": {},
"outputs": [
@@ -135,7 +135,7 @@
"output_type": "stream",
"text": [
" coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.498106 0.020829 23.913515 2.215861e-126 0.457281 0.538931\n"
+ "d 0.519934 0.020986 24.774727 1.678998e-135 0.478802 0.561067\n"
]
}
],
@@ -146,11 +146,11 @@
"\n",
"x_cols = [col for col in cre_data.columns if \"x\" in col]\n",
"\n",
- "obj_dml_data_pdml = DoubleMLClusterData(cre_data,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " cluster_cols='id',\n",
- " x_cols=x_cols)\n",
+ "obj_dml_data_pdml = DoubleMLData(cre_data,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " cluster_cols='id',\n",
+ " x_cols=x_cols)\n",
"\n",
"learner = LassoCV()\n",
"ml_l = clone(learner)\n",
@@ -181,7 +181,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 4,
"id": "83663379",
"metadata": {},
"outputs": [
@@ -194,9 +194,10 @@
"------------------ Data summary ------------------\n",
"Outcome variable: y\n",
"Treatment variable(s): ['d']\n",
- "Cluster variable(s): ['id']\n",
"Covariates: ['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15', 'x16', 'x17', 'x18', 'x19', 'x20', 'x21', 'x22', 'x23', 'x24', 'x25', 'x26', 'x27', 'x28', 'x29', 'x30', 'm_x1', 'm_x2', 'm_x3', 'm_x4', 'm_x5', 'm_x6', 'm_x7', 'm_x8', 'm_x9', 'm_x10', 'm_x11', 'm_x12', 'm_x13', 'm_x14', 'm_x15', 'm_x16', 'm_x17', 'm_x18', 'm_x19', 'm_x20', 'm_x21', 'm_x22', 'm_x23', 'm_x24', 'm_x25', 'm_x26', 'm_x27', 'm_x28', 'm_x29', 'm_x30']\n",
"Instrument variable(s): None\n",
+ "Cluster variable(s): ['id']\n",
+ "Is cluster data: True\n",
"No. Observations: 2500\n",
"\n",
"------------------ Score & algorithm ------------------\n",
@@ -208,8 +209,8 @@
"Learner ml_m: LassoCV()\n",
"Out-of-sample Performance:\n",
"Regression:\n",
- "Learner ml_l RMSE: [[1.72960098]]\n",
- "Learner ml_m RMSE: [[0.95035703]]\n",
+ "Learner ml_l RMSE: [[1.63784321]]\n",
+ "Learner ml_m RMSE: [[0.96294553]]\n",
"\n",
"------------------ Resampling ------------------\n",
"No. folds per cluster: 5\n",
@@ -218,7 +219,7 @@
"\n",
"------------------ Fit summary ------------------\n",
" coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.498106 0.020829 23.913515 2.215861e-126 0.457281 0.538931\n"
+ "d 0.519934 0.020986 24.774727 1.678998e-135 0.478802 0.561067\n"
]
}
],
@@ -228,7 +229,7 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 5,
"id": "24f06d62",
"metadata": {},
"outputs": [
@@ -237,7 +238,7 @@
"output_type": "stream",
"text": [
" coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.532542 0.02172 24.518225 9.443063e-133 0.489971 0.575113\n"
+ "d 0.548085 0.02097 26.136802 1.392306e-150 0.506985 0.589186\n"
]
}
],
@@ -248,11 +249,11 @@
"\n",
"x_cols = [col for col in cre_data.columns if \"x\" in col]\n",
"\n",
- "obj_dml_data_pdml = DoubleMLClusterData(cre_data,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " cluster_cols='id',\n",
- " x_cols=x_cols)\n",
+ "obj_dml_data_pdml = DoubleMLData(cre_data,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " cluster_cols='id',\n",
+ " x_cols=x_cols)\n",
"\n",
"# learner = LassoCV()\n",
"learner = make_pipeline(StandardScaler(), LassoCV())\n",
@@ -267,7 +268,7 @@
},
{
"cell_type": "code",
- "execution_count": 60,
+ "execution_count": 6,
"id": "61a72563",
"metadata": {},
"outputs": [
@@ -275,8 +276,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d_diff 0.524973 0.022845 22.979858 7.412967e-117 0.480198 0.569749\n"
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d_diff 0.492487 0.025352 19.42565 4.684180e-84 0.442797 0.542176\n"
]
}
],
@@ -284,11 +285,11 @@
"data = make_static_panel_CP2025(dgp_type='dgp1')\n",
"fd_data = fd_fct(data)\n",
"\n",
- "obj_dml_data_pdml = DoubleMLClusterData(fd_data,\n",
- " y_col='y_diff',\n",
- " d_cols='d_diff',\n",
- " cluster_cols='id',\n",
- " x_cols=[col for col in fd_data.columns if col.startswith(\"x\")])\n",
+ "obj_dml_data_pdml = DoubleMLData(fd_data,\n",
+ " y_col='y_diff',\n",
+ " d_cols='d_diff',\n",
+ " cluster_cols='id',\n",
+ " x_cols=[col for col in fd_data.columns if col.startswith(\"x\")])\n",
"\n",
"learner = LassoCV()\n",
"ml_l = clone(learner)\n",
@@ -301,7 +302,7 @@
},
{
"cell_type": "code",
- "execution_count": 61,
+ "execution_count": 10,
"id": "aeb00efe",
"metadata": {},
"outputs": [
@@ -309,8 +310,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.476888 0.020457 23.311636 3.378549e-120 0.436792 0.516983\n"
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d 0.5089 0.019913 25.555893 4.722065e-144 0.469871 0.547929\n"
]
}
],
@@ -318,11 +319,11 @@
"data = make_static_panel_CP2025(dgp_type='dgp1')\n",
"wd_data = wd_fct(data)\n",
"\n",
- "obj_dml_data_pdml = DoubleMLClusterData(wd_data,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " cluster_cols='id',\n",
- " x_cols=[col for col in wd_data.columns if col.startswith(\"x\")])\n",
+ "obj_dml_data_pdml = DoubleMLData(wd_data,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " cluster_cols='id',\n",
+ " x_cols=[col for col in wd_data.columns if col.startswith(\"x\")])\n",
"\n",
"learner = LassoCV()\n",
"ml_l = clone(learner)\n",
@@ -335,7 +336,7 @@
},
{
"cell_type": "code",
- "execution_count": 62,
+ "execution_count": 13,
"id": "586d5edf",
"metadata": {},
"outputs": [
@@ -370,8 +371,8 @@
"\n",
" # CRE general OLS\n",
" cre_data = cre_fct(data)\n",
- " dml_data = DoubleMLClusterData(cre_data, y_col='y', d_cols='d', cluster_cols='id', \n",
- " x_cols=[col for col in cre_data.columns if \"x\" in col])\n",
+ " dml_data = DoubleMLData(cre_data, y_col='y', d_cols='d', cluster_cols='id', \n",
+ " x_cols=[col for col in cre_data.columns if \"x\" in col])\n",
" dml_plpr = DoubleMLPLPR(dml_data, clone(leaner_ols), clone(leaner_ols), n_folds=5, pdml_approach='cre_general')\n",
" dml_plpr.fit()\n",
" res_cre_ols[i, 0] = dml_plpr.coef[0]\n",
@@ -388,8 +389,8 @@
" res_cre_general[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
"\n",
" # CRE normality\n",
- " dml_data = DoubleMLClusterData(cre_data, y_col='y', d_cols='d', cluster_cols='id', \n",
- " x_cols=[col for col in cre_data.columns if \"x\" in col]\n",
+ " dml_data = DoubleMLData(cre_data, y_col='y', d_cols='d', cluster_cols='id', \n",
+ " x_cols=[col for col in cre_data.columns if \"x\" in col]\n",
" )\n",
" dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, pdml_approach='cre')\n",
" dml_plpr.fit()\n",
@@ -400,8 +401,8 @@
"\n",
" # FD approach\n",
" fd_data = fd_fct(data)\n",
- " dml_data = DoubleMLClusterData(fd_data, y_col='y_diff', d_cols='d_diff', cluster_cols='id',\n",
- " x_cols=[col for col in fd_data.columns if col.startswith(\"x\")])\n",
+ " dml_data = DoubleMLData(fd_data, y_col='y_diff', d_cols='d_diff', cluster_cols='id',\n",
+ " x_cols=[col for col in fd_data.columns if col.startswith(\"x\")])\n",
" dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, pdml_approach='transform')\n",
" dml_plpr.fit()\n",
" res_fd[i, 0] = dml_plpr.coef[0]\n",
@@ -410,7 +411,7 @@
" res_fd[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
" # no cluster\n",
" dml_data = DoubleMLData(fd_data, y_col='y_diff', d_cols='d_diff',\n",
- " x_cols=[col for col in fd_data.columns if col.startswith(\"x\")])\n",
+ " x_cols=[col for col in fd_data.columns if col.startswith(\"x\")])\n",
" dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, pdml_approach='transform')\n",
" dml_plpr.fit()\n",
" res_fd_cluster[i, 0] = dml_plpr.coef[0]\n",
@@ -420,8 +421,8 @@
" \n",
" # WD approach\n",
" wd_data = wd_fct(data)\n",
- " dml_data = DoubleMLClusterData(wd_data, y_col='y', d_cols='d', cluster_cols='id',\n",
- " x_cols=[col for col in wd_data.columns if \"x\" in col])\n",
+ " dml_data = DoubleMLData(wd_data, y_col='y', d_cols='d', cluster_cols='id',\n",
+ " x_cols=[col for col in wd_data.columns if \"x\" in col])\n",
" dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, pdml_approach='transform')\n",
" dml_plpr.fit()\n",
" res_wd[i, 0] = dml_plpr.coef[0]\n",
@@ -432,7 +433,7 @@
},
{
"cell_type": "code",
- "execution_count": 63,
+ "execution_count": 16,
"id": "edb9e3d5",
"metadata": {},
"outputs": [
@@ -513,7 +514,7 @@
"WD 0.502402 0.002402 0.93"
]
},
- "execution_count": 63,
+ "execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
@@ -543,7 +544,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.13.5"
+ "version": "3.12.6"
}
},
"nbformat": 4,
diff --git a/doubleml/plm/sim/learners_sim.ipynb b/doubleml/plm/sim/learners_sim.ipynb
index b718bd7b4..c29e03a6b 100644
--- a/doubleml/plm/sim/learners_sim.ipynb
+++ b/doubleml/plm/sim/learners_sim.ipynb
@@ -9,7 +9,7 @@
"source": [
"import numpy as np\n",
"import pandas as pd\n",
- "from doubleml.data.cluster_data import DoubleMLClusterData\n",
+ "from doubleml.data.base_data import DoubleMLData\n",
"from doubleml.plm.plpr import DoubleMLPLPR\n",
"from sklearn.linear_model import LassoCV, LinearRegression\n",
"from sklearn.base import clone\n",
@@ -61,7 +61,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 3,
"id": "dca81b0b",
"metadata": {},
"outputs": [
@@ -114,121 +114,121 @@
" 0 | \n",
" 1 | \n",
" 1 | \n",
- " -2.213736 | \n",
- " -5.293206 | \n",
- " -0.428742 | \n",
- " 7.394317 | \n",
- " -3.040323 | \n",
- " -0.556956 | \n",
- " -1.971048 | \n",
- " 3.811521 | \n",
+ " -8.112787 | \n",
+ " -8.912584 | \n",
+ " -5.796365 | \n",
+ " -0.601492 | \n",
+ " -3.487003 | \n",
+ " 4.357256 | \n",
+ " -3.527997 | \n",
+ " -7.455948 | \n",
" ... | \n",
- " 6.212741 | \n",
- " -4.062078 | \n",
- " 3.238729 | \n",
- " -9.268694 | \n",
- " -5.648472 | \n",
- " 9.005437 | \n",
- " -1.537681 | \n",
- " -4.250953 | \n",
- " 4.311248 | \n",
- " 1.155672 | \n",
+ " 5.577388 | \n",
+ " -1.605127 | \n",
+ " -0.814059 | \n",
+ " -3.103182 | \n",
+ " 2.631538 | \n",
+ " -4.643003 | \n",
+ " 5.162550 | \n",
+ " 3.740774 | \n",
+ " 2.113925 | \n",
+ " 2.026183 | \n",
" \n",
" \n",
" | 1 | \n",
" 1 | \n",
" 2 | \n",
- " -0.705639 | \n",
- " -1.099439 | \n",
- " 6.859486 | \n",
- " -3.476090 | \n",
- " -2.070501 | \n",
- " 0.951224 | \n",
- " -0.037587 | \n",
- " 5.537579 | \n",
+ " -6.949439 | \n",
+ " -11.955038 | \n",
+ " -3.906188 | \n",
+ " 2.728437 | \n",
+ " -4.309356 | \n",
+ " 4.652335 | \n",
+ " 4.837147 | \n",
+ " 5.113480 | \n",
" ... | \n",
- " -3.596578 | \n",
- " 2.580089 | \n",
- " 1.015682 | \n",
- " -11.813514 | \n",
- " 6.059815 | \n",
- " 9.918398 | \n",
- " -0.825656 | \n",
- " -1.491334 | \n",
- " 2.950680 | \n",
- " 4.083369 | \n",
+ " -6.215166 | \n",
+ " -1.291356 | \n",
+ " 1.542859 | \n",
+ " -5.832660 | \n",
+ " -6.999235 | \n",
+ " -1.041017 | \n",
+ " 0.388897 | \n",
+ " 0.135666 | \n",
+ " -5.257444 | \n",
+ " -4.460909 | \n",
"
\n",
" \n",
" | 2 | \n",
" 1 | \n",
" 3 | \n",
- " 7.975096 | \n",
- " 10.156811 | \n",
- " 8.308916 | \n",
- " 2.863192 | \n",
- " 4.956786 | \n",
- " -2.209816 | \n",
- " -3.509712 | \n",
- " -5.863826 | \n",
+ " -4.068573 | \n",
+ " -6.083197 | \n",
+ " 1.199280 | \n",
+ " 1.113007 | \n",
+ " -3.238536 | \n",
+ " 5.611841 | \n",
+ " -3.096405 | \n",
+ " 7.262224 | \n",
" ... | \n",
- " 5.826138 | \n",
- " -0.334149 | \n",
- " 1.253619 | \n",
- " -6.735010 | \n",
- " 0.705128 | \n",
- " -3.978092 | \n",
- " 3.948464 | \n",
- " -3.216404 | \n",
- " 0.912388 | \n",
- " -1.348336 | \n",
+ " -6.793106 | \n",
+ " 5.217539 | \n",
+ " 4.765350 | \n",
+ " 3.238961 | \n",
+ " -3.244586 | \n",
+ " 0.046503 | \n",
+ " 7.297417 | \n",
+ " 5.151098 | \n",
+ " 0.353556 | \n",
+ " -6.192547 | \n",
"
\n",
" \n",
" | 3 | \n",
" 1 | \n",
" 4 | \n",
- " -1.259776 | \n",
- " -2.755462 | \n",
- " -8.534541 | \n",
- " 4.878212 | \n",
- " 1.080527 | \n",
- " 9.905302 | \n",
- " 4.336665 | \n",
- " 7.272071 | \n",
+ " 4.268473 | \n",
+ " 8.099756 | \n",
+ " -3.690119 | \n",
+ " -3.551698 | \n",
+ " 7.695905 | \n",
+ " 3.349990 | \n",
+ " -3.575687 | \n",
+ " -9.272200 | \n",
" ... | \n",
- " 7.800067 | \n",
- " 6.354467 | \n",
- " -3.945007 | \n",
- " -6.898544 | \n",
- " -0.649614 | \n",
- " 2.045477 | \n",
- " 4.110676 | \n",
- " 7.025232 | \n",
- " 4.912177 | \n",
- " -1.291772 | \n",
+ " 2.183245 | \n",
+ " -9.719218 | \n",
+ " -3.691420 | \n",
+ " -4.724887 | \n",
+ " -2.681429 | \n",
+ " -3.256659 | \n",
+ " 2.039591 | \n",
+ " -5.688881 | \n",
+ " -1.675406 | \n",
+ " -1.537060 | \n",
"
\n",
" \n",
" | 4 | \n",
" 1 | \n",
" 5 | \n",
- " -14.969484 | \n",
- " -24.316493 | \n",
- " 2.135777 | \n",
- " 1.402142 | \n",
- " -15.818097 | \n",
- " 6.619429 | \n",
- " -2.092901 | \n",
- " -0.523111 | \n",
+ " -8.490611 | \n",
+ " -13.074335 | \n",
+ " -8.383416 | \n",
+ " 1.125561 | \n",
+ " -4.826987 | \n",
+ " 1.226380 | \n",
+ " 0.565376 | \n",
+ " 1.337693 | \n",
" ... | \n",
- " -2.314368 | \n",
- " 6.764651 | \n",
- " -1.790039 | \n",
- " -5.611779 | \n",
- " -0.821312 | \n",
- " -4.314345 | \n",
- " -2.108172 | \n",
- " 1.211618 | \n",
- " 4.972882 | \n",
- " -3.570822 | \n",
+ " -1.622405 | \n",
+ " -11.514240 | \n",
+ " -4.995206 | \n",
+ " -0.293343 | \n",
+ " 5.670162 | \n",
+ " 5.218059 | \n",
+ " -10.535997 | \n",
+ " -0.007612 | \n",
+ " 4.940226 | \n",
+ " -2.512659 | \n",
"
\n",
" \n",
" | ... | \n",
@@ -258,121 +258,121 @@
" 995 | \n",
" 100 | \n",
" 6 | \n",
- " -1.542279 | \n",
- " -3.390828 | \n",
- " 5.548290 | \n",
- " 0.599051 | \n",
- " -1.009998 | \n",
- " -2.323957 | \n",
- " 3.695354 | \n",
- " -3.960357 | \n",
+ " 7.979518 | \n",
+ " 13.313478 | \n",
+ " 0.743929 | \n",
+ " 0.479841 | \n",
+ " 8.463661 | \n",
+ " -3.785925 | \n",
+ " 3.066799 | \n",
+ " -5.972398 | \n",
" ... | \n",
- " -4.302796 | \n",
- " 5.029354 | \n",
- " -3.867411 | \n",
- " -5.779731 | \n",
- " 1.022882 | \n",
- " -1.610741 | \n",
- " -6.182220 | \n",
- " 3.790532 | \n",
- " 1.378363 | \n",
- " -3.250114 | \n",
+ " -8.675939 | \n",
+ " -0.339098 | \n",
+ " 0.200580 | \n",
+ " 4.741587 | \n",
+ " 3.884253 | \n",
+ " 0.082965 | \n",
+ " -3.765886 | \n",
+ " 2.210837 | \n",
+ " -2.203842 | \n",
+ " 9.350995 | \n",
"
\n",
" \n",
" | 996 | \n",
" 100 | \n",
" 7 | \n",
- " -5.195300 | \n",
- " -6.324920 | \n",
- " 0.875120 | \n",
- " -1.827447 | \n",
- " -3.364818 | \n",
- " 0.381971 | \n",
- " -2.868032 | \n",
- " -0.474953 | \n",
+ " 4.525037 | \n",
+ " 7.323752 | \n",
+ " 2.795891 | \n",
+ " -0.028399 | \n",
+ " 3.351155 | \n",
+ " -13.480410 | \n",
+ " 4.504775 | \n",
+ " 2.866025 | \n",
" ... | \n",
- " -7.145188 | \n",
- " -2.962602 | \n",
- " 6.051157 | \n",
- " -0.310512 | \n",
- " 3.926997 | \n",
- " 6.150247 | \n",
- " -0.424644 | \n",
- " -0.768921 | \n",
- " 1.381692 | \n",
- " -4.279590 | \n",
+ " 2.935810 | \n",
+ " -6.909156 | \n",
+ " -6.092518 | \n",
+ " 7.090190 | \n",
+ " -0.192387 | \n",
+ " -0.971816 | \n",
+ " 2.114409 | \n",
+ " 7.572450 | \n",
+ " -3.337941 | \n",
+ " 4.831238 | \n",
"
\n",
" \n",
" | 997 | \n",
" 100 | \n",
" 8 | \n",
- " -4.935131 | \n",
- " -6.861080 | \n",
- " -5.428853 | \n",
- " 2.489955 | \n",
- " -0.062190 | \n",
- " 5.694679 | \n",
- " -8.452585 | \n",
- " -13.783593 | \n",
+ " 2.510815 | \n",
+ " 3.504373 | \n",
+ " 4.272010 | \n",
+ " -3.236265 | \n",
+ " 1.253958 | \n",
+ " 1.062489 | \n",
+ " -7.690689 | \n",
+ " 6.750913 | \n",
" ... | \n",
- " 5.129312 | \n",
- " 6.157375 | \n",
- " -5.336022 | \n",
- " -7.142420 | \n",
- " 6.714065 | \n",
- " -6.922350 | \n",
- " 4.991919 | \n",
- " 6.219515 | \n",
- " -5.687230 | \n",
- " -3.842934 | \n",
+ " -9.397734 | \n",
+ " 1.931898 | \n",
+ " 7.888287 | \n",
+ " 0.276521 | \n",
+ " 3.114361 | \n",
+ " 4.152857 | \n",
+ " 0.079838 | \n",
+ " 2.297878 | \n",
+ " 9.451616 | \n",
+ " -1.324771 | \n",
"
\n",
" \n",
" | 998 | \n",
" 100 | \n",
" 9 | \n",
- " -10.656389 | \n",
- " -17.293887 | \n",
- " -8.934547 | \n",
- " 2.080141 | \n",
- " -6.928987 | \n",
- " 1.259232 | \n",
- " -7.253584 | \n",
- " 1.381321 | \n",
+ " -4.087541 | \n",
+ " -3.451450 | \n",
+ " 0.115834 | \n",
+ " -2.387410 | \n",
+ " -1.961343 | \n",
+ " -4.106975 | \n",
+ " 4.037239 | \n",
+ " -3.903956 | \n",
" ... | \n",
- " -0.607923 | \n",
- " 10.114989 | \n",
- " -4.271365 | \n",
- " -8.707851 | \n",
- " 1.071853 | \n",
- " -1.960183 | \n",
- " 0.646585 | \n",
- " 4.832729 | \n",
- " 3.747999 | \n",
- " 0.701658 | \n",
+ " -5.021652 | \n",
+ " 1.694328 | \n",
+ " -1.283313 | \n",
+ " 7.283484 | \n",
+ " 8.015243 | \n",
+ " 6.879811 | \n",
+ " -7.213541 | \n",
+ " -2.226587 | \n",
+ " -0.305480 | \n",
+ " -1.568153 | \n",
"
\n",
" \n",
" | 999 | \n",
" 100 | \n",
" 10 | \n",
- " -9.323197 | \n",
- " -15.524491 | \n",
- " -2.411281 | \n",
- " 5.912876 | \n",
- " -8.815911 | \n",
- " -1.317385 | \n",
- " 2.275555 | \n",
- " 3.626635 | \n",
+ " -8.074941 | \n",
+ " -12.453872 | \n",
+ " -0.695072 | \n",
+ " -1.788528 | \n",
+ " -7.955557 | \n",
+ " 4.716530 | \n",
+ " 5.760638 | \n",
+ " -6.033057 | \n",
" ... | \n",
- " 2.755303 | \n",
- " 11.769591 | \n",
- " -6.222097 | \n",
- " -12.357810 | \n",
- " -3.555324 | \n",
- " 0.905898 | \n",
- " 1.622200 | \n",
- " -5.201272 | \n",
- " 11.640755 | \n",
- " -1.110845 | \n",
+ " 2.323859 | \n",
+ " 0.301849 | \n",
+ " 0.853097 | \n",
+ " 3.270169 | \n",
+ " 3.749521 | \n",
+ " -2.260064 | \n",
+ " 5.343868 | \n",
+ " -0.764016 | \n",
+ " 2.769752 | \n",
+ " -4.067194 | \n",
"
\n",
" \n",
"\n",
@@ -380,49 +380,49 @@
""
],
"text/plain": [
- " id time d y x1 x2 x3 x4 \\\n",
- "0 1 1 -2.213736 -5.293206 -0.428742 7.394317 -3.040323 -0.556956 \n",
- "1 1 2 -0.705639 -1.099439 6.859486 -3.476090 -2.070501 0.951224 \n",
- "2 1 3 7.975096 10.156811 8.308916 2.863192 4.956786 -2.209816 \n",
- "3 1 4 -1.259776 -2.755462 -8.534541 4.878212 1.080527 9.905302 \n",
- "4 1 5 -14.969484 -24.316493 2.135777 1.402142 -15.818097 6.619429 \n",
- ".. ... ... ... ... ... ... ... ... \n",
- "995 100 6 -1.542279 -3.390828 5.548290 0.599051 -1.009998 -2.323957 \n",
- "996 100 7 -5.195300 -6.324920 0.875120 -1.827447 -3.364818 0.381971 \n",
- "997 100 8 -4.935131 -6.861080 -5.428853 2.489955 -0.062190 5.694679 \n",
- "998 100 9 -10.656389 -17.293887 -8.934547 2.080141 -6.928987 1.259232 \n",
- "999 100 10 -9.323197 -15.524491 -2.411281 5.912876 -8.815911 -1.317385 \n",
+ " id time d y x1 x2 x3 x4 \\\n",
+ "0 1 1 -8.112787 -8.912584 -5.796365 -0.601492 -3.487003 4.357256 \n",
+ "1 1 2 -6.949439 -11.955038 -3.906188 2.728437 -4.309356 4.652335 \n",
+ "2 1 3 -4.068573 -6.083197 1.199280 1.113007 -3.238536 5.611841 \n",
+ "3 1 4 4.268473 8.099756 -3.690119 -3.551698 7.695905 3.349990 \n",
+ "4 1 5 -8.490611 -13.074335 -8.383416 1.125561 -4.826987 1.226380 \n",
+ ".. ... ... ... ... ... ... ... ... \n",
+ "995 100 6 7.979518 13.313478 0.743929 0.479841 8.463661 -3.785925 \n",
+ "996 100 7 4.525037 7.323752 2.795891 -0.028399 3.351155 -13.480410 \n",
+ "997 100 8 2.510815 3.504373 4.272010 -3.236265 1.253958 1.062489 \n",
+ "998 100 9 -4.087541 -3.451450 0.115834 -2.387410 -1.961343 -4.106975 \n",
+ "999 100 10 -8.074941 -12.453872 -0.695072 -1.788528 -7.955557 4.716530 \n",
"\n",
- " x5 x6 ... x21 x22 x23 x24 \\\n",
- "0 -1.971048 3.811521 ... 6.212741 -4.062078 3.238729 -9.268694 \n",
- "1 -0.037587 5.537579 ... -3.596578 2.580089 1.015682 -11.813514 \n",
- "2 -3.509712 -5.863826 ... 5.826138 -0.334149 1.253619 -6.735010 \n",
- "3 4.336665 7.272071 ... 7.800067 6.354467 -3.945007 -6.898544 \n",
- "4 -2.092901 -0.523111 ... -2.314368 6.764651 -1.790039 -5.611779 \n",
- ".. ... ... ... ... ... ... ... \n",
- "995 3.695354 -3.960357 ... -4.302796 5.029354 -3.867411 -5.779731 \n",
- "996 -2.868032 -0.474953 ... -7.145188 -2.962602 6.051157 -0.310512 \n",
- "997 -8.452585 -13.783593 ... 5.129312 6.157375 -5.336022 -7.142420 \n",
- "998 -7.253584 1.381321 ... -0.607923 10.114989 -4.271365 -8.707851 \n",
- "999 2.275555 3.626635 ... 2.755303 11.769591 -6.222097 -12.357810 \n",
+ " x5 x6 ... x21 x22 x23 x24 \\\n",
+ "0 -3.527997 -7.455948 ... 5.577388 -1.605127 -0.814059 -3.103182 \n",
+ "1 4.837147 5.113480 ... -6.215166 -1.291356 1.542859 -5.832660 \n",
+ "2 -3.096405 7.262224 ... -6.793106 5.217539 4.765350 3.238961 \n",
+ "3 -3.575687 -9.272200 ... 2.183245 -9.719218 -3.691420 -4.724887 \n",
+ "4 0.565376 1.337693 ... -1.622405 -11.514240 -4.995206 -0.293343 \n",
+ ".. ... ... ... ... ... ... ... \n",
+ "995 3.066799 -5.972398 ... -8.675939 -0.339098 0.200580 4.741587 \n",
+ "996 4.504775 2.866025 ... 2.935810 -6.909156 -6.092518 7.090190 \n",
+ "997 -7.690689 6.750913 ... -9.397734 1.931898 7.888287 0.276521 \n",
+ "998 4.037239 -3.903956 ... -5.021652 1.694328 -1.283313 7.283484 \n",
+ "999 5.760638 -6.033057 ... 2.323859 0.301849 0.853097 3.270169 \n",
"\n",
- " x25 x26 x27 x28 x29 x30 \n",
- "0 -5.648472 9.005437 -1.537681 -4.250953 4.311248 1.155672 \n",
- "1 6.059815 9.918398 -0.825656 -1.491334 2.950680 4.083369 \n",
- "2 0.705128 -3.978092 3.948464 -3.216404 0.912388 -1.348336 \n",
- "3 -0.649614 2.045477 4.110676 7.025232 4.912177 -1.291772 \n",
- "4 -0.821312 -4.314345 -2.108172 1.211618 4.972882 -3.570822 \n",
- ".. ... ... ... ... ... ... \n",
- "995 1.022882 -1.610741 -6.182220 3.790532 1.378363 -3.250114 \n",
- "996 3.926997 6.150247 -0.424644 -0.768921 1.381692 -4.279590 \n",
- "997 6.714065 -6.922350 4.991919 6.219515 -5.687230 -3.842934 \n",
- "998 1.071853 -1.960183 0.646585 4.832729 3.747999 0.701658 \n",
- "999 -3.555324 0.905898 1.622200 -5.201272 11.640755 -1.110845 \n",
+ " x25 x26 x27 x28 x29 x30 \n",
+ "0 2.631538 -4.643003 5.162550 3.740774 2.113925 2.026183 \n",
+ "1 -6.999235 -1.041017 0.388897 0.135666 -5.257444 -4.460909 \n",
+ "2 -3.244586 0.046503 7.297417 5.151098 0.353556 -6.192547 \n",
+ "3 -2.681429 -3.256659 2.039591 -5.688881 -1.675406 -1.537060 \n",
+ "4 5.670162 5.218059 -10.535997 -0.007612 4.940226 -2.512659 \n",
+ ".. ... ... ... ... ... ... \n",
+ "995 3.884253 0.082965 -3.765886 2.210837 -2.203842 9.350995 \n",
+ "996 -0.192387 -0.971816 2.114409 7.572450 -3.337941 4.831238 \n",
+ "997 3.114361 4.152857 0.079838 2.297878 9.451616 -1.324771 \n",
+ "998 8.015243 6.879811 -7.213541 -2.226587 -0.305480 -1.568153 \n",
+ "999 3.749521 -2.260064 5.343868 -0.764016 2.769752 -4.067194 \n",
"\n",
"[1000 rows x 34 columns]"
]
},
- "execution_count": 7,
+ "execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
@@ -434,7 +434,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 4,
"id": "d4580f3d",
"metadata": {},
"outputs": [
@@ -448,7 +448,7 @@
" dtype='object')"
]
},
- "execution_count": 8,
+ "execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@@ -512,11 +512,11 @@
" ## CRE\n",
" cre_data = cre_fct(data)\n",
"\n",
- " data_cre_pdml = DoubleMLClusterData(cre_data,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " cluster_cols='id',\n",
- " x_cols=[col for col in cre_data.columns if \"x\" in col])\n",
+ " data_cre_pdml = DoubleMLData(cre_data,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " cluster_cols='id',\n",
+ " x_cols=[col for col in cre_data.columns if \"x\" in col])\n",
"\n",
" # OLS\n",
" dml_plpr = DoubleMLPLPR(data_cre_pdml, clone(ml_ols), clone(ml_ols), pdml_approach='cre_general', n_folds=5)\n",
@@ -528,12 +528,12 @@
"\n",
" # Lasso\n",
" cre_data_ext = extend_data(cre_data)\n",
- " data_cre_pdml_ext = DoubleMLClusterData(cre_data_ext,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " cluster_cols='id',\n",
- " x_cols=[col for col in cre_data_ext.columns if \"x\" in col]\n",
- " )\n",
+ " data_cre_pdml_ext = DoubleMLData(cre_data_ext,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " cluster_cols='id',\n",
+ " x_cols=[col for col in cre_data_ext.columns if \"x\" in col]\n",
+ " )\n",
"\n",
" dml_plpr = DoubleMLPLPR(data_cre_pdml_ext, clone(ml_lasso), clone(ml_lasso), pdml_approach='cre_general', n_folds=5)\n",
" dml_plpr.fit(n_jobs_cv=5)\n",
@@ -563,11 +563,11 @@
" ## FD\n",
" fd_data = fd_fct(data)\n",
"\n",
- " data_fd_pdml = DoubleMLClusterData(fd_data,\n",
- " y_col='y_diff',\n",
- " d_cols='d_diff',\n",
- " cluster_cols='id',\n",
- " x_cols=[col for col in fd_data.columns if \"x\" in col])\n",
+ " data_fd_pdml = DoubleMLData(fd_data,\n",
+ " y_col='y_diff',\n",
+ " d_cols='d_diff',\n",
+ " cluster_cols='id',\n",
+ " x_cols=[col for col in fd_data.columns if \"x\" in col])\n",
"\n",
" # OLS\n",
" dml_plpr = DoubleMLPLPR(data_fd_pdml, clone(ml_ols), clone(ml_ols), pdml_approach='transform', n_folds=5)\n",
@@ -579,12 +579,12 @@
"\n",
" # Lasso\n",
" fd_data_ext = extend_data(fd_data)\n",
- " data_fd_pdml_ext = DoubleMLClusterData(fd_data_ext,\n",
- " y_col='y_diff',\n",
- " d_cols='d_diff',\n",
- " cluster_cols='id',\n",
- " x_cols=[col for col in fd_data_ext.columns if \"x\" in col]\n",
- " )\n",
+ " data_fd_pdml_ext = DoubleMLData(fd_data_ext,\n",
+ " y_col='y_diff',\n",
+ " d_cols='d_diff',\n",
+ " cluster_cols='id',\n",
+ " x_cols=[col for col in fd_data_ext.columns if \"x\" in col]\n",
+ " )\n",
"\n",
" dml_plpr = DoubleMLPLPR(data_fd_pdml_ext, clone(ml_lasso), clone(ml_lasso), pdml_approach='transform', n_folds=5)\n",
" dml_plpr.fit(n_jobs_cv=5)\n",
@@ -614,11 +614,11 @@
" ## WD\n",
" wd_data = wd_fct(data)\n",
"\n",
- " data_wd_pdml = DoubleMLClusterData(wd_data,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " cluster_cols='id',\n",
- " x_cols=[col for col in wd_data.columns if \"x\" in col])\n",
+ " data_wd_pdml = DoubleMLData(wd_data,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " cluster_cols='id',\n",
+ " x_cols=[col for col in wd_data.columns if \"x\" in col])\n",
"\n",
" # OLS\n",
" dml_plpr = DoubleMLPLPR(data_wd_pdml, clone(ml_ols), clone(ml_ols), pdml_approach='transform', n_folds=5)\n",
@@ -630,12 +630,12 @@
"\n",
" # Lasso\n",
" wd_data_ext = extend_data(wd_data)\n",
- " data_wd_pdml_ext = DoubleMLClusterData(wd_data_ext,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " cluster_cols='id',\n",
- " x_cols=[col for col in wd_data_ext.columns if \"x\" in col]\n",
- " )\n",
+ " data_wd_pdml_ext = DoubleMLData(wd_data_ext,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " cluster_cols='id',\n",
+ " x_cols=[col for col in wd_data_ext.columns if \"x\" in col]\n",
+ " )\n",
"\n",
" dml_plpr = DoubleMLPLPR(data_wd_pdml_ext, clone(ml_lasso), clone(ml_lasso), pdml_approach='transform', n_folds=5)\n",
" dml_plpr.fit(n_jobs_cv=5)\n",
@@ -903,11 +903,12 @@
" ## CRE\n",
" cre_data = cre_fct(data)\n",
"\n",
- " data_cre_pdml = DoubleMLClusterData(cre_data,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " cluster_cols='id',\n",
- " x_cols=[col for col in cre_data.columns if \"x\" in col])\n",
+ " data_cre_pdml = DoubleMLData(cre_data,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " cluster_cols='id',\n",
+ " x_cols=[col for col in cre_data.columns if \"x\" in col]\n",
+ " )\n",
"\n",
" # OLS\n",
" dml_plpr = DoubleMLPLPR(data_cre_pdml, clone(ml_ols), clone(ml_ols), pdml_approach='cre_general', n_folds=5)\n",
@@ -919,12 +920,12 @@
"\n",
" # Lasso\n",
" cre_data_ext = extend_data(cre_data)\n",
- " data_cre_pdml_ext = DoubleMLClusterData(cre_data_ext,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " cluster_cols='id',\n",
- " x_cols=[col for col in cre_data_ext.columns if \"x\" in col]\n",
- " )\n",
+ " data_cre_pdml_ext = DoubleMLData(cre_data_ext,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " cluster_cols='id',\n",
+ " x_cols=[col for col in cre_data_ext.columns if \"x\" in col]\n",
+ " )\n",
"\n",
" dml_plpr = DoubleMLPLPR(data_cre_pdml_ext, clone(ml_lasso), clone(ml_lasso), pdml_approach='cre_general', n_folds=5)\n",
" dml_plpr.fit(n_jobs_cv=5)\n",
From 058da4eba0354b567e0f7a97de1cb9e7630ebcca Mon Sep 17 00:00:00 2001
From: Julian Diefenbacher
<93731825+JulianDiefenbacher@users.noreply.github.com>
Date: Mon, 27 Oct 2025 17:31:23 +0100
Subject: [PATCH 11/33] add static_panel flag in PanelData
---
doubleml/data/panel_data.py | 94 ++++++++++++++++++++----------
doubleml/plm/sim/example_sim.ipynb | 66 +++++++++++++++++++--
2 files changed, 126 insertions(+), 34 deletions(-)
diff --git a/doubleml/data/panel_data.py b/doubleml/data/panel_data.py
index 22aad0f73..87fae7f95 100644
--- a/doubleml/data/panel_data.py
+++ b/doubleml/data/panel_data.py
@@ -57,6 +57,12 @@ class DoubleMLPanelData(DoubleMLData):
datetime_unit : str
The unit of the time and treatment variable (if datetime type).
+ static_panel : bool
+ Indicates whether the data model corresponds to the standard panel data, where the treatment variable(s) indicate
+ the treatment groups in terms of first time of treatment exposure, or if the data model corresponds to a static
+ panel data approach.
+ Default is ``False``.
+
Examples
--------
>>> from doubleml.did.datasets import make_did_CS2021
@@ -85,44 +91,72 @@ def __init__(
use_other_treat_as_covariate=True,
force_all_x_finite=True,
datetime_unit="M",
+ static_panel=False,
):
DoubleMLBaseData.__init__(self, data)
- # we need to set id_col (needs _data) before call to the super __init__ because of the x_cols setter
- self.id_col = id_col
- self._datetime_unit = _is_valid_datetime_unit(datetime_unit)
- self._set_id_var()
-
- # Set time column before calling parent constructor
- self.t_col = t_col
-
- # Call parent constructor
- DoubleMLData.__init__(
- self,
- data=data,
- y_col=y_col,
- d_cols=d_cols,
- x_cols=x_cols,
- z_cols=z_cols,
- use_other_treat_as_covariate=use_other_treat_as_covariate,
- force_all_x_finite=force_all_x_finite,
- force_all_d_finite=False,
- )
+ if not static_panel:
+ # we need to set id_col (needs _data) before call to the super __init__ because of the x_cols setter
+ self.id_col = id_col
+ self._datetime_unit = _is_valid_datetime_unit(datetime_unit)
+ self._set_id_var()
- # reset index to ensure a simple RangeIndex
- self.data.reset_index(drop=True, inplace=True)
+ # Set time column before calling parent constructor
+ self.t_col = t_col
+
+ # Call parent constructor
+ DoubleMLData.__init__(
+ self,
+ data=data,
+ y_col=y_col,
+ d_cols=d_cols,
+ x_cols=x_cols,
+ z_cols=z_cols,
+ use_other_treat_as_covariate=use_other_treat_as_covariate,
+ force_all_x_finite=force_all_x_finite,
+ force_all_d_finite=False,
+ )
- # Set time variable array after data is loaded
- self._set_time_var()
+ # reset index to ensure a simple RangeIndex
+ self.data.reset_index(drop=True, inplace=True)
- if self.n_treat != 1:
- raise ValueError("Only one treatment column is allowed for panel data.")
+ # Set time variable array after data is loaded
+ self._set_time_var()
- self._check_disjoint_sets_id_col()
+ if self.n_treat != 1:
+ raise ValueError("Only one treatment column is allowed for panel data.")
+
+ self._check_disjoint_sets_id_col()
+
+ # intialize the unique values of g and t
+ self._g_values = np.sort(np.unique(self.d)) # unique values of g
+ self._t_values = np.sort(np.unique(self.t)) # unique values of t
+
+ else:
+ # static panel type data class, where id column is used as the cluster variable
+
+ self.id_col = id_col
+ self._set_id_var()
+ self.t_col = t_col
+
+ DoubleMLData.__init__(
+ self,
+ data=data,
+ y_col=y_col,
+ d_cols=d_cols,
+ x_cols=x_cols,
+ z_cols=z_cols,
+ cluster_cols=id_col,
+ use_other_treat_as_covariate=use_other_treat_as_covariate,
+ force_all_x_finite=force_all_x_finite,
+ force_all_d_finite=False,
+ )
- # intialize the unique values of g and t
- self._g_values = np.sort(np.unique(self.d)) # unique values of g
- self._t_values = np.sort(np.unique(self.t)) # unique values of t
+ if self.n_treat != 1:
+ raise ValueError("Only one treatment column is allowed for panel data.")
+
+ if self.z_cols is not None:
+ raise ValueError("Static panel data currently does not support instrumental variables.")
def __str__(self):
data_summary = self._data_summary_str()
diff --git a/doubleml/plm/sim/example_sim.ipynb b/doubleml/plm/sim/example_sim.ipynb
index cc4d968b1..89dcafd98 100644
--- a/doubleml/plm/sim/example_sim.ipynb
+++ b/doubleml/plm/sim/example_sim.ipynb
@@ -78,7 +78,6 @@
"import pandas as pd\n",
"import statsmodels.api as sm\n",
"from doubleml.data.base_data import DoubleMLData\n",
- "from doubleml.data.base_data import DoubleMLData\n",
"from doubleml.plm.plpr import DoubleMLPLPR\n",
"from sklearn.linear_model import LassoCV, LinearRegression\n",
"from sklearn.base import clone\n",
@@ -93,7 +92,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 18,
"id": "3c061cf1",
"metadata": {},
"outputs": [
@@ -105,7 +104,7 @@
" np.float64(0.7348254130670426))"
]
},
- "execution_count": 2,
+ "execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
@@ -126,7 +125,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 19,
"id": "73c6599b",
"metadata": {},
"outputs": [
@@ -227,6 +226,65 @@
"print(obj_dml_plpr)"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "0c47504d",
+ "metadata": {},
+ "source": [
+ "Using Panel Data Class"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "4290cd61",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "================== DoubleMLPanelData Object ==================\n",
+ "\n",
+ "------------------ Data summary ------------------\n",
+ "Outcome variable: y\n",
+ "Treatment variable(s): ['d']\n",
+ "Covariates: ['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15', 'x16', 'x17', 'x18', 'x19', 'x20', 'x21', 'x22', 'x23', 'x24', 'x25', 'x26', 'x27', 'x28', 'x29', 'x30', 'm_x1', 'm_x2', 'm_x3', 'm_x4', 'm_x5', 'm_x6', 'm_x7', 'm_x8', 'm_x9', 'm_x10', 'm_x11', 'm_x12', 'm_x13', 'm_x14', 'm_x15', 'm_x16', 'm_x17', 'm_x18', 'm_x19', 'm_x20', 'm_x21', 'm_x22', 'm_x23', 'm_x24', 'm_x25', 'm_x26', 'm_x27', 'm_x28', 'm_x29', 'm_x30']\n",
+ "Instrument variable(s): None\n",
+ "Time variable: time\n",
+ "Id variable: id\n",
+ "No. Unique Ids: 250\n",
+ "No. Observations: 2500\n",
+ "\n",
+ "------------------ DataFrame info ------------------\n",
+ "\n",
+ "RangeIndex: 2500 entries, 0 to 2499\n",
+ "Columns: 65 entries, id to m_x30\n",
+ "dtypes: float64(63), int64(2)\n",
+ "memory usage: 1.2 MB\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "from doubleml.data.panel_data import DoubleMLPanelData\n",
+ "\n",
+ "data = make_static_panel_CP2025(dgp_type='dgp1')\n",
+ "cre_data = cre_fct(data)\n",
+ "\n",
+ "x_cols = [col for col in cre_data.columns if \"x\" in col]\n",
+ "\n",
+ "obj_dml_data_pdml = DoubleMLPanelData(cre_data,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " t_col='time',\n",
+ " id_col='id',\n",
+ " x_cols=x_cols,\n",
+ " static_panel=True)\n",
+ "\n",
+ "print(obj_dml_data_pdml)"
+ ]
+ },
{
"cell_type": "code",
"execution_count": 5,
From 85da2eb339624c3c1e5523931e78496b0b11d048 Mon Sep 17 00:00:00 2001
From: JulianDiefenbacher
Date: Tue, 4 Nov 2025 12:58:22 +0100
Subject: [PATCH 12/33] add static_panel property
---
doubleml/data/panel_data.py | 25 +++++++---
doubleml/plm/sim/example_sim.ipynb | 76 ++++++++++++++----------------
2 files changed, 55 insertions(+), 46 deletions(-)
diff --git a/doubleml/data/panel_data.py b/doubleml/data/panel_data.py
index 87fae7f95..3c9810c33 100644
--- a/doubleml/data/panel_data.py
+++ b/doubleml/data/panel_data.py
@@ -95,6 +95,8 @@ def __init__(
):
DoubleMLBaseData.__init__(self, data)
+ self.static_panel = static_panel
+
if not static_panel:
# we need to set id_col (needs _data) before call to the super __init__ because of the x_cols setter
self.id_col = id_col
@@ -123,9 +125,6 @@ def __init__(
# Set time variable array after data is loaded
self._set_time_var()
- if self.n_treat != 1:
- raise ValueError("Only one treatment column is allowed for panel data.")
-
self._check_disjoint_sets_id_col()
# intialize the unique values of g and t
@@ -152,12 +151,12 @@ def __init__(
force_all_d_finite=False,
)
- if self.n_treat != 1:
- raise ValueError("Only one treatment column is allowed for panel data.")
-
if self.z_cols is not None:
raise ValueError("Static panel data currently does not support instrumental variables.")
+ if self.n_treat != 1:
+ raise ValueError("Only one treatment column is allowed for panel data.")
+
def __str__(self):
data_summary = self._data_summary_str()
buf = io.StringIO()
@@ -329,6 +328,20 @@ def n_t_periods(self):
The number of time periods.
"""
return len(self.t_values)
+
+ @property
+ def static_panel(self):
+ """
+ Indicates whether the data model corresponds to the standard panel data or if the data model corresponds to a static
+ panel data approach.
+ """
+ return self._static_panel
+
+ @static_panel.setter
+ def static_panel(self, value):
+ if not isinstance(value, bool):
+ raise TypeError(f"static_panel must be True or False. Got {str(value)}.")
+ self._static_panel = value
def _get_optional_col_sets(self):
base_optional_col_sets = super()._get_optional_col_sets()
diff --git a/doubleml/plm/sim/example_sim.ipynb b/doubleml/plm/sim/example_sim.ipynb
index 89dcafd98..ca332ec06 100644
--- a/doubleml/plm/sim/example_sim.ipynb
+++ b/doubleml/plm/sim/example_sim.ipynb
@@ -92,19 +92,19 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 2,
"id": "3c061cf1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "(np.float64(0.6719371174913912),\n",
- " np.float64(0.6090488219157397),\n",
- " np.float64(0.7348254130670426))"
+ "(np.float64(0.6719371174913908),\n",
+ " np.float64(0.6090488219157394),\n",
+ " np.float64(0.7348254130670423))"
]
},
- "execution_count": 18,
+ "execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
@@ -125,7 +125,7 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 18,
"id": "73c6599b",
"metadata": {},
"outputs": [
@@ -133,13 +133,17 @@
"name": "stdout",
"output_type": "stream",
"text": [
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d 0.485431 0.01876 25.875155 1.268226e-147 0.448661 0.522201\n",
" coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.519934 0.020986 24.774727 1.678998e-135 0.478802 0.561067\n"
+ "d 0.483919 0.018573 26.055021 1.180159e-149 0.447517 0.520321\n"
]
}
],
"source": [
"# cre general\n",
+ "\n",
+ "# np.random.seed(1)\n",
"data = make_static_panel_CP2025(dgp_type='dgp1')\n",
"cre_data = cre_fct(data)\n",
"\n",
@@ -158,9 +162,24 @@
"obj_dml_plpr = DoubleMLPLPR(obj_dml_data_pdml, ml_l=ml_l, ml_m=ml_m,\n",
" pdml_approach='cre_general', n_folds=5\n",
" )\n",
- "\n",
"obj_dml_plpr.fit()\n",
- "print(obj_dml_plpr.summary)"
+ "print(obj_dml_plpr.summary)\n",
+ "\n",
+ "from doubleml.data.panel_data import DoubleMLPanelData\n",
+ "\n",
+ "obj_panel = DoubleMLPanelData(cre_data,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " t_col='time',\n",
+ " id_col='id',\n",
+ " x_cols=x_cols,\n",
+ " static_panel=True)\n",
+ "\n",
+ "dml_panel_plpr = DoubleMLPLPR(obj_panel, ml_l=ml_l, ml_m=ml_m,\n",
+ " pdml_approach='cre_general', n_folds=5\n",
+ " )\n",
+ "dml_panel_plpr.fit()\n",
+ "print(dml_panel_plpr.summary)"
]
},
{
@@ -180,7 +199,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 5,
"id": "83663379",
"metadata": {},
"outputs": [
@@ -208,8 +227,8 @@
"Learner ml_m: LassoCV()\n",
"Out-of-sample Performance:\n",
"Regression:\n",
- "Learner ml_l RMSE: [[1.63784321]]\n",
- "Learner ml_m RMSE: [[0.96294553]]\n",
+ "Learner ml_l RMSE: [[1.71732417]]\n",
+ "Learner ml_m RMSE: [[0.94662251]]\n",
"\n",
"------------------ Resampling ------------------\n",
"No. folds per cluster: 5\n",
@@ -218,7 +237,7 @@
"\n",
"------------------ Fit summary ------------------\n",
" coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.519934 0.020986 24.774727 1.678998e-135 0.478802 0.561067\n"
+ "d 0.518921 0.022264 23.308011 3.677034e-120 0.475285 0.562557\n"
]
}
],
@@ -226,18 +245,10 @@
"print(obj_dml_plpr)"
]
},
- {
- "cell_type": "markdown",
- "id": "0c47504d",
- "metadata": {},
- "source": [
- "Using Panel Data Class"
- ]
- },
{
"cell_type": "code",
- "execution_count": 10,
- "id": "4290cd61",
+ "execution_count": 6,
+ "id": "b544d599",
"metadata": {},
"outputs": [
{
@@ -267,22 +278,7 @@
}
],
"source": [
- "from doubleml.data.panel_data import DoubleMLPanelData\n",
- "\n",
- "data = make_static_panel_CP2025(dgp_type='dgp1')\n",
- "cre_data = cre_fct(data)\n",
- "\n",
- "x_cols = [col for col in cre_data.columns if \"x\" in col]\n",
- "\n",
- "obj_dml_data_pdml = DoubleMLPanelData(cre_data,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " t_col='time',\n",
- " id_col='id',\n",
- " x_cols=x_cols,\n",
- " static_panel=True)\n",
- "\n",
- "print(obj_dml_data_pdml)"
+ "print(obj_panel)"
]
},
{
@@ -602,7 +598,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.12.6"
+ "version": "3.13.9"
}
},
"nbformat": 4,
From 08cfd8c09df6a4a3b90be30b6908825f88c74922 Mon Sep 17 00:00:00 2001
From: SvenKlaassen <47529404+SvenKlaassen@users.noreply.github.com>
Date: Tue, 4 Nov 2025 14:17:30 +0100
Subject: [PATCH 13/33] add static_panel property and update tests for panel
data handling
---
doubleml/data/panel_data.py | 113 ++++++++++---------------
doubleml/data/tests/test_panel_data.py | 23 ++++-
2 files changed, 64 insertions(+), 72 deletions(-)
diff --git a/doubleml/data/panel_data.py b/doubleml/data/panel_data.py
index 3c9810c33..1655b410b 100644
--- a/doubleml/data/panel_data.py
+++ b/doubleml/data/panel_data.py
@@ -41,6 +41,13 @@ class DoubleMLPanelData(DoubleMLData):
The instrumental variable(s).
Default is ``None``.
+ static_panel : bool
+ Indicates whether the data model corresponds to a static
+ panel data approach (``True``) or to staggered adoption panel data
+ (``False``). In the latter case, the treatment groups/values are defined in terms of the first time of
+ treatment exposure.
+ Default is ``False``.
+
use_other_treat_as_covariate : bool
Indicates whether in the multiple-treatment case the other treatment variables should be added as covariates.
Default is ``True``.
@@ -57,12 +64,6 @@ class DoubleMLPanelData(DoubleMLData):
datetime_unit : str
The unit of the time and treatment variable (if datetime type).
- static_panel : bool
- Indicates whether the data model corresponds to the standard panel data, where the treatment variable(s) indicate
- the treatment groups in terms of first time of treatment exposure, or if the data model corresponds to a static
- panel data approach.
- Default is ``False``.
-
Examples
--------
>>> from doubleml.did.datasets import make_did_CS2021
@@ -88,71 +89,51 @@ def __init__(
id_col,
x_cols=None,
z_cols=None,
+ static_panel=False,
use_other_treat_as_covariate=True,
force_all_x_finite=True,
datetime_unit="M",
- static_panel=False,
):
DoubleMLBaseData.__init__(self, data)
- self.static_panel = static_panel
-
- if not static_panel:
- # we need to set id_col (needs _data) before call to the super __init__ because of the x_cols setter
- self.id_col = id_col
- self._datetime_unit = _is_valid_datetime_unit(datetime_unit)
- self._set_id_var()
-
- # Set time column before calling parent constructor
- self.t_col = t_col
-
- # Call parent constructor
- DoubleMLData.__init__(
- self,
- data=data,
- y_col=y_col,
- d_cols=d_cols,
- x_cols=x_cols,
- z_cols=z_cols,
- use_other_treat_as_covariate=use_other_treat_as_covariate,
- force_all_x_finite=force_all_x_finite,
- force_all_d_finite=False,
- )
-
- # reset index to ensure a simple RangeIndex
- self.data.reset_index(drop=True, inplace=True)
+ self._static_panel = static_panel
- # Set time variable array after data is loaded
- self._set_time_var()
+ # we need to set id_col (needs _data) before call to the super __init__ because of the x_cols setter
+ self.id_col = id_col
+ self._set_id_var()
+ # Set time column before calling parent constructor
+ self.t_col = t_col
+ self._datetime_unit = _is_valid_datetime_unit(datetime_unit)
- self._check_disjoint_sets_id_col()
+ if not self.static_panel:
+ cluster_cols = None
+ else:
+ cluster_cols = id_col
+
+ DoubleMLData.__init__(
+ self,
+ data=data,
+ y_col=y_col,
+ d_cols=d_cols,
+ x_cols=x_cols,
+ z_cols=z_cols,
+ cluster_cols=cluster_cols,
+ use_other_treat_as_covariate=use_other_treat_as_covariate,
+ force_all_x_finite=force_all_x_finite,
+ force_all_d_finite=False,
+ )
- # intialize the unique values of g and t
- self._g_values = np.sort(np.unique(self.d)) # unique values of g
- self._t_values = np.sort(np.unique(self.t)) # unique values of t
+ # reset index to ensure a simple RangeIndex
+ self.data.reset_index(drop=True, inplace=True)
- else:
- # static panel type data class, where id column is used as the cluster variable
+ # Set time variable array after data is loaded
+ self._set_time_var()
- self.id_col = id_col
- self._set_id_var()
- self.t_col = t_col
-
- DoubleMLData.__init__(
- self,
- data=data,
- y_col=y_col,
- d_cols=d_cols,
- x_cols=x_cols,
- z_cols=z_cols,
- cluster_cols=id_col,
- use_other_treat_as_covariate=use_other_treat_as_covariate,
- force_all_x_finite=force_all_x_finite,
- force_all_d_finite=False,
- )
+ self._check_disjoint_sets_id_col()
- if self.z_cols is not None:
- raise ValueError("Static panel data currently does not support instrumental variables.")
+ # intialize the unique values of g and t
+ self._g_values = np.sort(np.unique(self.d)) # unique values of g
+ self._t_values = np.sort(np.unique(self.t)) # unique values of t
if self.n_treat != 1:
raise ValueError("Only one treatment column is allowed for panel data.")
@@ -179,6 +160,7 @@ def _data_summary_str(self):
f"Instrument variable(s): {self.z_cols}\n"
f"Time variable: {self.t_col}\n"
f"Id variable: {self.id_col}\n"
+ f"Static panel data: {self.static_panel}\n"
)
data_summary += f"No. Unique Ids: {self.n_ids}\n" f"No. Observations: {self.n_obs}\n"
@@ -328,21 +310,12 @@ def n_t_periods(self):
The number of time periods.
"""
return len(self.t_values)
-
+
@property
def static_panel(self):
- """
- Indicates whether the data model corresponds to the standard panel data or if the data model corresponds to a static
- panel data approach.
- """
+ """Indicates whether the data model corresponds to a static panel data approach."""
return self._static_panel
- @static_panel.setter
- def static_panel(self, value):
- if not isinstance(value, bool):
- raise TypeError(f"static_panel must be True or False. Got {str(value)}.")
- self._static_panel = value
-
def _get_optional_col_sets(self):
base_optional_col_sets = super()._get_optional_col_sets()
id_col_set = {self.id_col}
diff --git a/doubleml/data/tests/test_panel_data.py b/doubleml/data/tests/test_panel_data.py
index d8506b0d7..0698368c3 100644
--- a/doubleml/data/tests/test_panel_data.py
+++ b/doubleml/data/tests/test_panel_data.py
@@ -157,14 +157,26 @@ def test_panel_data_str():
assert "Time variable: t" in dml_str
assert "Id variable: id" in dml_str
assert "No. Observations:" in dml_str
+ assert "Static panel data:" in dml_str
+
+
+@pytest.fixture(scope="module", params=[True, False])
+def static_panel(request):
+ return request.param
@pytest.mark.ci
-def test_panel_data_properties():
+def test_panel_data_properties(static_panel):
np.random.seed(3141)
df = make_did_SZ2020(n_obs=100, return_type="DoubleMLPanelData")._data
dml_data = DoubleMLPanelData(
- data=df, y_col="y", d_cols="d", t_col="t", id_col="id", x_cols=[f"Z{i + 1}" for i in np.arange(4)]
+ data=df,
+ y_col="y",
+ d_cols="d",
+ t_col="t",
+ id_col="id",
+ x_cols=[f"Z{i + 1}" for i in np.arange(4)],
+ static_panel=static_panel,
)
assert np.array_equal(dml_data.id_var, df["id"].values)
@@ -176,3 +188,10 @@ def test_panel_data_properties():
assert dml_data.n_groups == len(np.unique(df["d"].values))
assert np.array_equal(dml_data.t_values, np.sort(np.unique(df["t"].values)))
assert dml_data.n_t_periods == len(np.unique(df["t"].values))
+
+ if static_panel:
+ assert dml_data.static_panel is True
+ assert dml_data.cluster_cols == ["id"]
+ else:
+ assert dml_data.static_panel is False
+ assert dml_data.cluster_cols is None
From 1ecab6631e9a8e406658e208ee3c8f19181d56fa Mon Sep 17 00:00:00 2001
From: JulianDiefenbacher
Date: Thu, 6 Nov 2025 21:42:31 +0100
Subject: [PATCH 14/33] update plpr model, include data transformation
---
doubleml/plm/plpr.py | 227 ++++++++++---------------
doubleml/plm/sim/example_sim.ipynb | 262 +++++++++++------------------
2 files changed, 180 insertions(+), 309 deletions(-)
diff --git a/doubleml/plm/plpr.py b/doubleml/plm/plpr.py
index 852316725..f8502cd56 100644
--- a/doubleml/plm/plpr.py
+++ b/doubleml/plm/plpr.py
@@ -5,12 +5,11 @@
from sklearn.base import clone
from sklearn.utils import check_X_y
-from ..data.base_data import DoubleMLData
+from ..data.panel_data import DoubleMLPanelData
from ..double_ml import DoubleML
from ..double_ml_score_mixins import LinearScoreMixin
from ..utils._checks import _check_binary_predictions, _check_finite_predictions, _check_is_propensity, _check_score
from ..utils._estimation import _dml_cv_predict, _dml_tune
-# from ..utils.blp import DoubleMLBLP
class DoubleMLPLPR(LinearScoreMixin, DoubleML):
@@ -52,13 +51,16 @@ class DoubleMLPLPR(LinearScoreMixin, DoubleML):
or a callable object / function with signature ``psi_a, psi_b = score(y, d, l_hat, m_hat, g_hat, smpls)``.
Default is ``'partialling out'``.
- pdml_approach : str
- Panel DML approach (``'transform'``, ``'cre'``, ``'cre_general'``)
+ static_panel_approach : str
+ A str (``'cre_general'``, ``'cre_normal'``, ``'fd_exact'``, ``'wg_approx'``) specifying the type of
+ static panel approach in Clarke and Polselli (2025).
+ Default is ``'fd_exact'``.
draw_sample_splitting : bool
Indicates whether the sample splitting should be drawn during initialization of the object.
Default is ``True``.
+ TODO: include example and notes
Examples
--------
>>> import numpy as np
@@ -72,17 +74,17 @@ class DoubleMLPLPR(LinearScoreMixin, DoubleML):
"""
def __init__(
- self, obj_dml_data, ml_l, ml_m, ml_g=None, n_folds=5, n_rep=1, score="partialling out", pdml_approach='transform', draw_sample_splitting=True):
+ self, obj_dml_data, ml_l, ml_m, ml_g=None, n_folds=5, n_rep=1, score="partialling out", static_panel_approach="fd_exact", draw_sample_splitting=True):
super().__init__(obj_dml_data, n_folds, n_rep, score, draw_sample_splitting)
self._check_data(self._dml_data)
- # assert cluster?
+ # TODO: assert cluster?
valid_scores = ["IV-type", "partialling out"]
_check_score(self.score, valid_scores, allow_callable=True)
- valid_pdml_approach = ["transform", "cre", "cre_general"]
- self._check_pdml_approach(pdml_approach, valid_pdml_approach)
- self._pdml_approach = pdml_approach
+ valid_static_panel_approach = ["cre_general", "cre_normal", "fd_exact", "wg_approx"]
+ self._check_static_panel_approach(static_panel_approach, valid_static_panel_approach)
+ self._static_panel_approach = static_panel_approach
_ = self._check_learner(ml_l, "ml_l", regressor=True, classifier=False)
ml_m_is_classifier = self._check_learner(ml_m, "ml_m", regressor=True, classifier=True)
@@ -119,16 +121,19 @@ def __init__(
self._predict_method["ml_m"] = "predict"
self._initialize_ml_nuisance_params()
- self._sensitivity_implemented = False ###
+ self._sensitivity_implemented = False
self._external_predictions_implemented = True
+ # Get transformed data depending on approach
+ self._data_transform = self._transform_data(self._static_panel_approach)
+
def __str__(self):
class_name = self.__class__.__name__
header = f"================== {class_name} Object ==================\n"
data_summary = self._dml_data._data_summary_str()
- score_pdml_approach_info = (
+ score_static_panel_approach_info = (
f"Score function: {str(self.score)}\n"
- f"Static panel model approach: {str(self.pdml_approach)}\n"
+ f"Static panel model approach: {str(self.static_panel_approach)}\n"
)
learner_info = ""
for key, value in self.learner.items():
@@ -160,7 +165,7 @@ def __str__(self):
+ "\n------------------ Data summary ------------------\n"
+ data_summary
+ "\n------------------ Score & algorithm ------------------\n"
- + score_pdml_approach_info
+ + score_static_panel_approach_info
+ "\n------------------ Machine learner ------------------\n"
+ learner_info
+ "\n------------------ Resampling ------------------\n"
@@ -171,21 +176,30 @@ def __str__(self):
return res
@property
- def pdml_approach(self):
+ def static_panel_approach(self):
"""
The score function.
"""
- return self._pdml_approach
+ return self._static_panel_approach
+
+ @property
+ def data_transform(self):
+ """
+ The transformed static panel data.
+ """
+ return self._data_transform
def _initialize_ml_nuisance_params(self):
self._params = {learner: {key: [None] * self.n_rep for key in self._dml_data.d_cols} for learner in self._learner}
- # DoubleML init check for type already?
- # TODO: Ensure cluster usage
def _check_data(self, obj_dml_data):
- if not isinstance(obj_dml_data, DoubleMLData):
+ if not isinstance(obj_dml_data, DoubleMLPanelData):
raise TypeError(
- f"The data must be of DoubleMLData type. {str(type(obj_dml_data))} was passed."
+ f"The data must be of DoubleMLPanelData type. {str(type(obj_dml_data))} was passed."
+ )
+ if not obj_dml_data.static_panel:
+ raise ValueError(
+ "For the PLPR model, the DoubleMLPanelData object requires the static_panel flag to be set to True."
)
if obj_dml_data.z_cols is not None:
raise ValueError(
@@ -194,14 +208,52 @@ def _check_data(self, obj_dml_data):
)
return
- def _check_pdml_approach(self, pdml_approach, valid_pdml_approach):
- if isinstance(pdml_approach, str):
- if pdml_approach not in valid_pdml_approach:
- raise ValueError("Invalid pdml_approach " + pdml_approach + ". " + "Valid approach " + " or ".join(valid_pdml_approach) + ".")
+ def _check_static_panel_approach(self, static_panel_approach, valid_static_panel_approach):
+ if isinstance(static_panel_approach, str):
+ if static_panel_approach not in valid_static_panel_approach:
+ raise ValueError("Invalid static_panel_approach " + static_panel_approach + ". " + "Valid approach " + " or ".join(valid_static_panel_approach) + ".")
else:
- raise TypeError(f"score should be a string. {str(pdml_approach)} was passed.")
+ raise TypeError(f"static_panel_approach should be a string. {str(static_panel_approach)} was passed.")
return
+
+ # TODO: preprocess and transform data based on static_panel_approach (cre, fd, wd)
+ def _transform_data(self, static_panel_approach):
+ df = self._dml_data.data.copy()
+
+ y_col = self._dml_data.y_col
+ d_cols = self._dml_data.d_cols
+ x_cols = self._dml_data.x_cols
+ t_col = self._dml_data.t_col
+ id_col = self._dml_data.id_col
+
+ if static_panel_approach in ["cre_general", "cre_normal"]:
+ # uses regular y_col, d_cols, x_cols + m_x_cols
+ df_id_means = df[[id_col] + d_cols + x_cols].groupby(id_col).transform("mean")
+ df_means = df_id_means.add_prefix("m_")
+ data = pd.concat([df, df_means], axis=1)
+ # {"y_col": y_col, "d_cols": d_cols, "x_cols": x_cols + [f"m_{x}"" for x in x_cols]}
+ elif static_panel_approach == "fd_exact":
+ # TODO: potential issues with unbalanced panels/missing periods, right now the
+ # last available is used as the lag and for diff. Maybe reindex to a complete time grid per id.
+ # uses y_col_diff, d_cols_diff, x_cols + x_cols_lag
+ df = df.sort_values([id_col, t_col])
+ shifted = df[[id_col] + x_cols].groupby(id_col).shift(1).add_suffix("_lag")
+ first_diff = df[[id_col] + d_cols + [y_col]].groupby(id_col).diff().add_suffix("_diff")
+ df_fd = pd.concat([df, shifted, first_diff], axis=1)
+ data = df_fd.dropna(subset=[x_cols[0] + "_lag"]).reset_index(drop=True)
+ # {"y_col": f"{y_col}_diff", "d_cols": [f"{d}_diff" for d in d_cols], "x_cols": x_cols + [f"{x}_lag" for x in x_cols]}
+ elif static_panel_approach == "wg_approx":
+ # uses y_col, d_cols, x_cols
+ df_demean = df.drop(t_col, axis=1).groupby(id_col).transform(lambda x: x - x.mean())
+ # add grand means
+ grand_means = df.drop([id_col, t_col], axis=1).mean()
+ within_means = df_demean + grand_means
+ data = pd.concat([df[[id_col, t_col]], within_means], axis=1)
+ # {"y_col": y_col, "d_cols": d_cols, "x_cols": x_cols}
+ else:
+ raise ValueError(f"Invalid static_panel_approach.")
+ return data
def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=False):
x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False)
@@ -235,10 +287,11 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
if m_external:
m_hat = {"preds": external_predictions["ml_m"], "targets": None, "models": None}
else:
+ # TODO: update this section
# cre using m_d + x for m_hat, otherwise only x
- if self._pdml_approach == 'cre':
- help_data = pd.DataFrame({'id': self._dml_data.cluster_vars[:, 0], 'd': d})
- m_d = help_data.groupby(["id"]).transform('mean').values
+ if self._static_panel_approach == "cre_normal":
+ help_data = pd.DataFrame({"id": self._dml_data.cluster_vars[:, 0], "d": d})
+ m_d = help_data.groupby(["id"]).transform("mean").values
x = np.column_stack((x, m_d))
m_hat = _dml_cv_predict(
@@ -253,11 +306,11 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
)
# general cre adjustment
- if self._pdml_approach == 'cre_general':
- help_data = pd.DataFrame({'id': self._dml_data.cluster_vars[:, 0], 'm_hat': m_hat['preds'], 'd': d})
- group_means = help_data.groupby(['id'])[['m_hat', 'd']].transform('mean')
- m_hat_star = m_hat['preds'] + group_means['d'] - group_means['m_hat']
- m_hat['preds'] = m_hat_star
+ if self._static_panel_approach == "cre_general":
+ help_data = pd.DataFrame({"id": self._dml_data.cluster_vars[:, 0], "m_hat": m_hat["preds"], "d": d})
+ group_means = help_data.groupby(["id"])[["m_hat", "d"]].transform("mean")
+ m_hat_star = m_hat["preds"] + group_means["d"] - group_means["m_hat"]
+ m_hat["preds"] = m_hat_star
_check_finite_predictions(m_hat["preds"], self._learner["ml_m"], "ml_m", smpls)
@@ -393,112 +446,4 @@ def _nuisance_tuning(
res = {"params": params, "tune_res": tune_res}
- return res
-
- # def cate(self, basis, is_gate=False, **kwargs):
- # """
- # Calculate conditional average treatment effects (CATE) for a given basis.
-
- # Parameters
- # ----------
- # basis : :class:`pandas.DataFrame`
- # The basis for estimating the best linear predictor. Has to have the shape ``(n_obs, d)``,
- # where ``n_obs`` is the number of observations and ``d`` is the number of predictors.
-
- # is_gate : bool
- # Indicates whether the basis is constructed for GATEs (dummy-basis).
- # Default is ``False``.
-
- # **kwargs: dict
- # Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit` e.g. ``cov_type``.
-
- # Returns
- # -------
- # model : :class:`doubleML.DoubleMLBLP`
- # Best linear Predictor model.
- # """
- # if self._dml_data.n_treat > 1:
- # raise NotImplementedError(
- # "Only implemented for single treatment. " + f"Number of treatments is {str(self._dml_data.n_treat)}."
- # )
- # if self.n_rep != 1:
- # raise NotImplementedError("Only implemented for one repetition. " + f"Number of repetitions is {str(self.n_rep)}.")
-
- # Y_tilde, D_tilde = self._partial_out()
-
- # D_basis = basis * D_tilde
- # model = DoubleMLBLP(
- # orth_signal=Y_tilde.reshape(-1),
- # basis=D_basis,
- # is_gate=is_gate,
- # )
- # model.fit(**kwargs)
- # return model
-
- # def gate(self, groups, **kwargs):
- # """
- # Calculate group average treatment effects (GATE) for groups.
-
- # Parameters
- # ----------
- # groups : :class:`pandas.DataFrame`
- # The group indicator for estimating the best linear predictor. Groups should be mutually exclusive.
- # Has to be dummy coded with shape ``(n_obs, d)``, where ``n_obs`` is the number of observations
- # and ``d`` is the number of groups or ``(n_obs, 1)`` and contain the corresponding groups (as str).
-
- # **kwargs: dict
- # Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit` e.g. ``cov_type``.
-
- # Returns
- # -------
- # model : :class:`doubleML.DoubleMLBLP`
- # Best linear Predictor model for Group Effects.
- # """
-
- # if not isinstance(groups, pd.DataFrame):
- # raise TypeError(f"Groups must be of DataFrame type. Groups of type {str(type(groups))} was passed.")
- # if not all(groups.dtypes == bool) or all(groups.dtypes == int):
- # if groups.shape[1] == 1:
- # groups = pd.get_dummies(groups, prefix="Group", prefix_sep="_")
- # else:
- # raise TypeError(
- # "Columns of groups must be of bool type or int type (dummy coded). "
- # "Alternatively, groups should only contain one column."
- # )
-
- # if any(groups.sum(0) <= 5):
- # warnings.warn("At least one group effect is estimated with less than 6 observations.")
-
- # model = self.cate(groups, is_gate=True, **kwargs)
- # return model
-
- # def _partial_out(self):
- # """
- # Helper function. Returns the partialled out quantities of Y and D.
- # Works with multiple repetitions.
-
- # Returns
- # -------
- # Y_tilde : :class:`numpy.ndarray`
- # The residual of the regression of Y on X.
- # D_tilde : :class:`numpy.ndarray`
- # The residual of the regression of D on X.
- # """
- # if self.predictions is None:
- # raise ValueError("predictions are None. Call .fit(store_predictions=True) to store the predictions.")
-
- # y = self._dml_data.y.reshape(-1, 1)
- # d = self._dml_data.d.reshape(-1, 1)
- # ml_m = self.predictions["ml_m"].squeeze(axis=2)
-
- # if self.score == "partialling out":
- # ml_l = self.predictions["ml_l"].squeeze(axis=2)
- # Y_tilde = y - ml_l
- # D_tilde = d - ml_m
- # else:
- # assert self.score == "IV-type"
- # ml_g = self.predictions["ml_g"].squeeze(axis=2)
- # Y_tilde = y - (self.coef * ml_m) - ml_g
- # D_tilde = d - ml_m
-
- # return Y_tilde, D_tilde
+ return res
\ No newline at end of file
diff --git a/doubleml/plm/sim/example_sim.ipynb b/doubleml/plm/sim/example_sim.ipynb
index ca332ec06..755c56ae1 100644
--- a/doubleml/plm/sim/example_sim.ipynb
+++ b/doubleml/plm/sim/example_sim.ipynb
@@ -86,7 +86,6 @@
"from doubleml.plm.utils._plpr_util import cre_fct, fd_fct, wd_fct, extend_data\n",
"from doubleml.plm.datasets.dgp_static_panel_CP2025 import make_static_panel_CP2025\n",
"import warnings\n",
- "import warnings\n",
"warnings.filterwarnings(\"ignore\")"
]
},
@@ -125,7 +124,7 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 8,
"id": "73c6599b",
"metadata": {},
"outputs": [
@@ -133,10 +132,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.485431 0.01876 25.875155 1.268226e-147 0.448661 0.522201\n",
" coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.483919 0.018573 26.055021 1.180159e-149 0.447517 0.520321\n"
+ "d 0.498106 0.020829 23.913515 2.215861e-126 0.457281 0.538931\n"
]
}
],
@@ -149,22 +146,10 @@
"\n",
"x_cols = [col for col in cre_data.columns if \"x\" in col]\n",
"\n",
- "obj_dml_data_pdml = DoubleMLData(cre_data,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " cluster_cols='id',\n",
- " x_cols=x_cols)\n",
- "\n",
"learner = LassoCV()\n",
"ml_l = clone(learner)\n",
"ml_m = clone(learner)\n",
"\n",
- "obj_dml_plpr = DoubleMLPLPR(obj_dml_data_pdml, ml_l=ml_l, ml_m=ml_m,\n",
- " pdml_approach='cre_general', n_folds=5\n",
- " )\n",
- "obj_dml_plpr.fit()\n",
- "print(obj_dml_plpr.summary)\n",
- "\n",
"from doubleml.data.panel_data import DoubleMLPanelData\n",
"\n",
"obj_panel = DoubleMLPanelData(cre_data,\n",
@@ -176,7 +161,7 @@
" static_panel=True)\n",
"\n",
"dml_panel_plpr = DoubleMLPLPR(obj_panel, ml_l=ml_l, ml_m=ml_m,\n",
- " pdml_approach='cre_general', n_folds=5\n",
+ " static_panel_approach='cre_general', n_folds=5\n",
" )\n",
"dml_panel_plpr.fit()\n",
"print(dml_panel_plpr.summary)"
@@ -184,23 +169,23 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 12,
"id": "7deabe55",
"metadata": {},
"outputs": [],
"source": [
"# model rmse\n",
"\n",
- "# u_hat = obj_dml_plpr._dml_data.y - obj_dml_plpr.predictions['ml_l'].flatten()\n",
- "# v_hat = obj_dml_plpr._dml_data.d - obj_dml_plpr.predictions['ml_m'].flatten()\n",
+ "# u_hat = dml_panel_plpr._dml_data.y - dml_panel_plpr.predictions['ml_l'].flatten()\n",
+ "# v_hat = dml_panel_plpr._dml_data.d - dml_panel_plpr.predictions['ml_m'].flatten()\n",
"\n",
- "# np.sqrt(np.mean((u_hat - (obj_dml_plpr.coef[0] * v_hat))**2))"
+ "# np.sqrt(np.mean(np.square(u_hat - (dml_panel_plpr.coef[0] * v_hat))))"
]
},
{
"cell_type": "code",
- "execution_count": 5,
- "id": "83663379",
+ "execution_count": 9,
+ "id": "48d4dbd8",
"metadata": {},
"outputs": [
{
@@ -214,8 +199,10 @@
"Treatment variable(s): ['d']\n",
"Covariates: ['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15', 'x16', 'x17', 'x18', 'x19', 'x20', 'x21', 'x22', 'x23', 'x24', 'x25', 'x26', 'x27', 'x28', 'x29', 'x30', 'm_x1', 'm_x2', 'm_x3', 'm_x4', 'm_x5', 'm_x6', 'm_x7', 'm_x8', 'm_x9', 'm_x10', 'm_x11', 'm_x12', 'm_x13', 'm_x14', 'm_x15', 'm_x16', 'm_x17', 'm_x18', 'm_x19', 'm_x20', 'm_x21', 'm_x22', 'm_x23', 'm_x24', 'm_x25', 'm_x26', 'm_x27', 'm_x28', 'm_x29', 'm_x30']\n",
"Instrument variable(s): None\n",
- "Cluster variable(s): ['id']\n",
- "Is cluster data: True\n",
+ "Time variable: time\n",
+ "Id variable: id\n",
+ "Static panel data: True\n",
+ "No. Unique Ids: 250\n",
"No. Observations: 2500\n",
"\n",
"------------------ Score & algorithm ------------------\n",
@@ -227,8 +214,8 @@
"Learner ml_m: LassoCV()\n",
"Out-of-sample Performance:\n",
"Regression:\n",
- "Learner ml_l RMSE: [[1.71732417]]\n",
- "Learner ml_m RMSE: [[0.94662251]]\n",
+ "Learner ml_l RMSE: [[1.72960098]]\n",
+ "Learner ml_m RMSE: [[0.95035703]]\n",
"\n",
"------------------ Resampling ------------------\n",
"No. folds per cluster: 5\n",
@@ -237,53 +224,17 @@
"\n",
"------------------ Fit summary ------------------\n",
" coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.518921 0.022264 23.308011 3.677034e-120 0.475285 0.562557\n"
+ "d 0.498106 0.020829 23.913515 2.215861e-126 0.457281 0.538931\n"
]
}
],
"source": [
- "print(obj_dml_plpr)"
+ "print(dml_panel_plpr)"
]
},
{
"cell_type": "code",
- "execution_count": 6,
- "id": "b544d599",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "================== DoubleMLPanelData Object ==================\n",
- "\n",
- "------------------ Data summary ------------------\n",
- "Outcome variable: y\n",
- "Treatment variable(s): ['d']\n",
- "Covariates: ['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15', 'x16', 'x17', 'x18', 'x19', 'x20', 'x21', 'x22', 'x23', 'x24', 'x25', 'x26', 'x27', 'x28', 'x29', 'x30', 'm_x1', 'm_x2', 'm_x3', 'm_x4', 'm_x5', 'm_x6', 'm_x7', 'm_x8', 'm_x9', 'm_x10', 'm_x11', 'm_x12', 'm_x13', 'm_x14', 'm_x15', 'm_x16', 'm_x17', 'm_x18', 'm_x19', 'm_x20', 'm_x21', 'm_x22', 'm_x23', 'm_x24', 'm_x25', 'm_x26', 'm_x27', 'm_x28', 'm_x29', 'm_x30']\n",
- "Instrument variable(s): None\n",
- "Time variable: time\n",
- "Id variable: id\n",
- "No. Unique Ids: 250\n",
- "No. Observations: 2500\n",
- "\n",
- "------------------ DataFrame info ------------------\n",
- "\n",
- "RangeIndex: 2500 entries, 0 to 2499\n",
- "Columns: 65 entries, id to m_x30\n",
- "dtypes: float64(63), int64(2)\n",
- "memory usage: 1.2 MB\n",
- "\n"
- ]
- }
- ],
- "source": [
- "print(obj_panel)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
+ "execution_count": 14,
"id": "24f06d62",
"metadata": {},
"outputs": [
@@ -291,8 +242,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.548085 0.02097 26.136802 1.392306e-150 0.506985 0.589186\n"
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d 0.533994 0.021782 24.51491 1.024387e-132 0.491301 0.576687\n"
]
}
],
@@ -303,11 +254,13 @@
"\n",
"x_cols = [col for col in cre_data.columns if \"x\" in col]\n",
"\n",
- "obj_dml_data_pdml = DoubleMLData(cre_data,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " cluster_cols='id',\n",
- " x_cols=x_cols)\n",
+ "obj_dml_data_pdml = DoubleMLPanelData(cre_data,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " t_col='time',\n",
+ " id_col='id',\n",
+ " x_cols=x_cols,\n",
+ " static_panel=True)\n",
"\n",
"# learner = LassoCV()\n",
"learner = make_pipeline(StandardScaler(), LassoCV())\n",
@@ -315,14 +268,14 @@
"ml_l = clone(learner)\n",
"ml_m = clone(learner)\n",
"\n",
- "obj_dml_plpr = DoubleMLPLPR(obj_dml_data_pdml, ml_l, ml_m, pdml_approach='cre')\n",
+ "obj_dml_plpr = DoubleMLPLPR(obj_dml_data_pdml, ml_l, ml_m, static_panel_approach='cre_normal')\n",
"obj_dml_plpr.fit()\n",
"print(obj_dml_plpr.summary)"
]
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 18,
"id": "61a72563",
"metadata": {},
"outputs": [
@@ -330,8 +283,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d_diff 0.492487 0.025352 19.42565 4.684180e-84 0.442797 0.542176\n"
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d_diff 0.510621 0.024609 20.749141 1.248004e-95 0.462388 0.558854\n"
]
}
],
@@ -339,24 +292,26 @@
"data = make_static_panel_CP2025(dgp_type='dgp1')\n",
"fd_data = fd_fct(data)\n",
"\n",
- "obj_dml_data_pdml = DoubleMLData(fd_data,\n",
+ "obj_dml_data_pdml = DoubleMLPanelData(fd_data,\n",
" y_col='y_diff',\n",
" d_cols='d_diff',\n",
- " cluster_cols='id',\n",
- " x_cols=[col for col in fd_data.columns if col.startswith(\"x\")])\n",
+ " t_col='time',\n",
+ " id_col='id',\n",
+ " x_cols=[col for col in fd_data.columns if col.startswith(\"x\")],\n",
+ " static_panel=True)\n",
"\n",
"learner = LassoCV()\n",
"ml_l = clone(learner)\n",
"ml_m = clone(learner)\n",
"\n",
- "obj_dml_plpr = DoubleMLPLPR(obj_dml_data_pdml, ml_l, ml_m, pdml_approach='transform')\n",
+ "obj_dml_plpr = DoubleMLPLPR(obj_dml_data_pdml, ml_l, ml_m, static_panel_approach='fd_exact')\n",
"obj_dml_plpr.fit()\n",
"print(obj_dml_plpr.summary)"
]
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 25,
"id": "aeb00efe",
"metadata": {},
"outputs": [
@@ -364,8 +319,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.5089 0.019913 25.555893 4.722065e-144 0.469871 0.547929\n"
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d 0.491597 0.021075 23.325666 2.434375e-120 0.45029 0.532904\n"
]
}
],
@@ -373,24 +328,26 @@
"data = make_static_panel_CP2025(dgp_type='dgp1')\n",
"wd_data = wd_fct(data)\n",
"\n",
- "obj_dml_data_pdml = DoubleMLData(wd_data,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " cluster_cols='id',\n",
- " x_cols=[col for col in wd_data.columns if col.startswith(\"x\")])\n",
+ "obj_dml_data_pdml = DoubleMLPanelData(wd_data,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " t_col='time',\n",
+ " id_col='id',\n",
+ " x_cols=[col for col in wd_data.columns if col.startswith(\"x\")],\n",
+ " static_panel=True)\n",
"\n",
"learner = LassoCV()\n",
"ml_l = clone(learner)\n",
"ml_m = clone(learner)\n",
"\n",
- "obj_dml_plpr = DoubleMLPLPR(obj_dml_data_pdml, ml_l, ml_m, pdml_approach='transform')\n",
+ "obj_dml_plpr = DoubleMLPLPR(obj_dml_data_pdml, ml_l, ml_m, static_panel_approach='wg_approx')\n",
"obj_dml_plpr.fit()\n",
"print(obj_dml_plpr.summary)"
]
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 26,
"id": "586d5edf",
"metadata": {},
"outputs": [
@@ -408,13 +365,9 @@
"\n",
"learner = make_pipeline(StandardScaler(), LassoCV())\n",
"\n",
- "leaner_ols = LinearRegression()\n",
- "\n",
- "res_cre_ols = np.full((n_reps, 3), np.nan)\n",
"res_cre_general = np.full((n_reps, 3), np.nan)\n",
"res_cre_normal = np.full((n_reps, 3), np.nan)\n",
"res_fd = np.full((n_reps, 3), np.nan)\n",
- "res_fd_cluster = np.full((n_reps, 3), np.nan)\n",
"res_wd = np.full((n_reps, 3), np.nan)\n",
"\n",
"np.random.seed(1)\n",
@@ -423,19 +376,13 @@
" print(f\"\\rProcessing: {round((i+1)/n_reps*100, 3)} %\", end=\"\")\n",
" data = make_static_panel_CP2025(num_n=100, theta=theta, dgp_type='dgp1')\n",
"\n",
- " # CRE general OLS\n",
- " cre_data = cre_fct(data)\n",
- " dml_data = DoubleMLData(cre_data, y_col='y', d_cols='d', cluster_cols='id', \n",
- " x_cols=[col for col in cre_data.columns if \"x\" in col])\n",
- " dml_plpr = DoubleMLPLPR(dml_data, clone(leaner_ols), clone(leaner_ols), n_folds=5, pdml_approach='cre_general')\n",
- " dml_plpr.fit()\n",
- " res_cre_ols[i, 0] = dml_plpr.coef[0]\n",
- " res_cre_ols[i, 1] = dml_plpr.coef[0] - theta\n",
- " confint = dml_plpr.confint()\n",
- " res_cre_ols[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
- "\n",
" # CRE general Lasso\n",
- " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, pdml_approach='cre_general')\n",
+ " cre_data = cre_fct(data)\n",
+ " dml_data = DoubleMLPanelData(cre_data, y_col='y', d_cols='d', t_col='time', id_col='id', \n",
+ " x_cols=[col for col in cre_data.columns if \"x\" in col],\n",
+ " static_panel=True)\n",
+ " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
+ " static_panel_approach='cre_general')\n",
" dml_plpr.fit()\n",
" res_cre_general[i, 0] = dml_plpr.coef[0]\n",
" res_cre_general[i, 1] = dml_plpr.coef[0] - theta\n",
@@ -443,10 +390,8 @@
" res_cre_general[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
"\n",
" # CRE normality\n",
- " dml_data = DoubleMLData(cre_data, y_col='y', d_cols='d', cluster_cols='id', \n",
- " x_cols=[col for col in cre_data.columns if \"x\" in col]\n",
- " )\n",
- " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, pdml_approach='cre')\n",
+ " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
+ " static_panel_approach='cre_normal')\n",
" dml_plpr.fit()\n",
" res_cre_normal[i, 0] = dml_plpr.coef[0]\n",
" res_cre_normal[i, 1] = dml_plpr.coef[0] - theta\n",
@@ -455,29 +400,24 @@
"\n",
" # FD approach\n",
" fd_data = fd_fct(data)\n",
- " dml_data = DoubleMLData(fd_data, y_col='y_diff', d_cols='d_diff', cluster_cols='id',\n",
- " x_cols=[col for col in fd_data.columns if col.startswith(\"x\")])\n",
- " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, pdml_approach='transform')\n",
+ " dml_data = DoubleMLPanelData(fd_data, y_col='y_diff', d_cols='d_diff', t_col='time', id_col='id',\n",
+ " x_cols=[col for col in fd_data.columns if col.startswith(\"x\")],\n",
+ " static_panel=True)\n",
+ " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
+ " static_panel_approach='fd_exact')\n",
" dml_plpr.fit()\n",
" res_fd[i, 0] = dml_plpr.coef[0]\n",
" res_fd[i, 1] = dml_plpr.coef[0] - theta\n",
" confint = dml_plpr.confint()\n",
" res_fd[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
- " # no cluster\n",
- " dml_data = DoubleMLData(fd_data, y_col='y_diff', d_cols='d_diff',\n",
- " x_cols=[col for col in fd_data.columns if col.startswith(\"x\")])\n",
- " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, pdml_approach='transform')\n",
- " dml_plpr.fit()\n",
- " res_fd_cluster[i, 0] = dml_plpr.coef[0]\n",
- " res_fd_cluster[i, 1] = dml_plpr.coef[0] - theta\n",
- " confint = dml_plpr.confint()\n",
- " res_fd_cluster[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
" \n",
" # WD approach\n",
" wd_data = wd_fct(data)\n",
- " dml_data = DoubleMLData(wd_data, y_col='y', d_cols='d', cluster_cols='id',\n",
- " x_cols=[col for col in wd_data.columns if \"x\" in col])\n",
- " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, pdml_approach='transform')\n",
+ " dml_data = DoubleMLPanelData(wd_data, y_col='y', d_cols='d', t_col='time', id_col='id',\n",
+ " x_cols=[col for col in wd_data.columns if \"x\" in col],\n",
+ " static_panel=True)\n",
+ " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
+ " static_panel_approach='wg_approx')\n",
" dml_plpr.fit()\n",
" res_wd[i, 0] = dml_plpr.coef[0]\n",
" res_wd[i, 1] = dml_plpr.coef[0] - theta\n",
@@ -487,8 +427,8 @@
},
{
"cell_type": "code",
- "execution_count": 16,
- "id": "edb9e3d5",
+ "execution_count": 28,
+ "id": "987c0496",
"metadata": {},
"outputs": [
{
@@ -519,66 +459,52 @@
" \n",
" \n",
" \n",
- " | CRE OLS | \n",
- " 0.498516 | \n",
- " -0.001484 | \n",
- " 0.94 | \n",
- "
\n",
- " \n",
" | CRE general | \n",
- " 0.517304 | \n",
- " 0.017304 | \n",
- " 0.91 | \n",
- "
\n",
- " \n",
- " | CRE normality | \n",
- " 0.540535 | \n",
- " 0.040535 | \n",
- " 0.80 | \n",
+ " 0.516684 | \n",
+ " 0.016684 | \n",
+ " 0.92 | \n",
"
\n",
" \n",
- " | FD | \n",
- " 0.504695 | \n",
- " 0.004695 | \n",
- " 0.95 | \n",
+ " CRE normal | \n",
+ " 0.541518 | \n",
+ " 0.041518 | \n",
+ " 0.78 | \n",
"
\n",
" \n",
- " | FD no cluster | \n",
- " 0.503954 | \n",
- " 0.003954 | \n",
- " 0.87 | \n",
+ " FD exact | \n",
+ " 0.504094 | \n",
+ " 0.004094 | \n",
+ " 0.94 | \n",
"
\n",
" \n",
- " | WD | \n",
- " 0.502402 | \n",
- " 0.002402 | \n",
- " 0.93 | \n",
+ " WG approx | \n",
+ " 0.502006 | \n",
+ " 0.002006 | \n",
+ " 0.94 | \n",
"
\n",
" \n",
"\n",
""
],
"text/plain": [
- " Coef Bias Coverage\n",
- "CRE OLS 0.498516 -0.001484 0.94\n",
- "CRE general 0.517304 0.017304 0.91\n",
- "CRE normality 0.540535 0.040535 0.80\n",
- "FD 0.504695 0.004695 0.95\n",
- "FD no cluster 0.503954 0.003954 0.87\n",
- "WD 0.502402 0.002402 0.93"
+ " Coef Bias Coverage\n",
+ "CRE general 0.516684 0.016684 0.92\n",
+ "CRE normal 0.541518 0.041518 0.78\n",
+ "FD exact 0.504094 0.004094 0.94\n",
+ "WG approx 0.502006 0.002006 0.94"
]
},
- "execution_count": 16,
+ "execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "pd.DataFrame(np.vstack([res_cre_ols.mean(axis=0), res_cre_general.mean(axis=0), res_cre_normal.mean(axis=0), \n",
- " res_fd.mean(axis=0), res_fd_cluster.mean(axis=0), res_wd.mean(axis=0)]), \n",
+ "pd.DataFrame(np.vstack([res_cre_general.mean(axis=0), res_cre_normal.mean(axis=0), \n",
+ " res_fd.mean(axis=0), res_wd.mean(axis=0)]), \n",
" columns=['Coef', 'Bias', 'Coverage'], \n",
- " index=['CRE OLS', 'CRE general', 'CRE normality', \n",
- " 'FD', 'FD no cluster', 'WD'])"
+ " index=['CRE general', 'CRE normal', \n",
+ " 'FD exact', 'WG approx'])"
]
}
],
From 5dc1a448b77f563e08e60db3b295f72755b16cd4 Mon Sep 17 00:00:00 2001
From: SvenKlaassen
Date: Fri, 7 Nov 2025 15:58:38 +0100
Subject: [PATCH 15/33] refactor: simplify string representation and add
additional info method in DoubleMLPLPR class
---
doubleml/plm/plpr.py | 54 ++++++++------------------------------------
1 file changed, 10 insertions(+), 44 deletions(-)
diff --git a/doubleml/plm/plpr.py b/doubleml/plm/plpr.py
index f8502cd56..ccfc7c76d 100644
--- a/doubleml/plm/plpr.py
+++ b/doubleml/plm/plpr.py
@@ -127,53 +127,19 @@ def __init__(
# Get transformed data depending on approach
self._data_transform = self._transform_data(self._static_panel_approach)
- def __str__(self):
- class_name = self.__class__.__name__
- header = f"================== {class_name} Object ==================\n"
- data_summary = self._dml_data._data_summary_str()
+ def _format_score_info_str(self):
score_static_panel_approach_info = (
f"Score function: {str(self.score)}\n"
f"Static panel model approach: {str(self.static_panel_approach)}\n"
)
- learner_info = ""
- for key, value in self.learner.items():
- learner_info += f"Learner {key}: {str(value)}\n"
- if self.nuisance_loss is not None:
- learner_info += "Out-of-sample Performance:\n"
- is_classifier = [value for value in self._is_classifier.values()]
- is_regressor = [not value for value in is_classifier]
- if any(is_regressor):
- learner_info += "Regression:\n"
- for learner in [key for key, value in self._is_classifier.items() if value is False]:
- learner_info += f"Learner {learner} RMSE: {self.nuisance_loss[learner]}\n"
- if any(is_classifier):
- learner_info += "Classification:\n"
- for learner in [key for key, value in self._is_classifier.items() if value is True]:
- learner_info += f"Learner {learner} Log Loss: {self.nuisance_loss[learner]}\n"
-
- if self._is_cluster_data:
- resampling_info = (
- f"No. folds per cluster: {self._n_folds_per_cluster}\n"
- f"No. folds: {self.n_folds}\n"
- f"No. repeated sample splits: {self.n_rep}\n"
- )
- else:
- resampling_info = f"No. folds: {self.n_folds}\nNo. repeated sample splits: {self.n_rep}\n"
- fit_summary = str(self.summary)
- res = (
- header
- + "\n------------------ Data summary ------------------\n"
- + data_summary
- + "\n------------------ Score & algorithm ------------------\n"
- + score_static_panel_approach_info
- + "\n------------------ Machine learner ------------------\n"
- + learner_info
- + "\n------------------ Resampling ------------------\n"
- + resampling_info
- + "\n------------------ Fit summary ------------------\n"
- + fit_summary
- )
- return res
+ return score_static_panel_approach_info
+
+ def _format_additional_info_str(self):
+ """
+ Includes information on the transformed features based on the estimation approach.
+ """
+ # TODO: Add Information on features after transformation
+ return ""
@property
def static_panel_approach(self):
@@ -229,7 +195,7 @@ def _transform_data(self, static_panel_approach):
if static_panel_approach in ["cre_general", "cre_normal"]:
# uses regular y_col, d_cols, x_cols + m_x_cols
df_id_means = df[[id_col] + d_cols + x_cols].groupby(id_col).transform("mean")
- df_means = df_id_means.add_prefix("m_")
+ df_means = df_id_means.add_prefix("mean_")
data = pd.concat([df, df_means], axis=1)
# {"y_col": y_col, "d_cols": d_cols, "x_cols": x_cols + [f"m_{x}"" for x in x_cols]}
elif static_panel_approach == "fd_exact":
From fdc43308a15ba2e0ea89b9ceaec9a3160da200ac Mon Sep 17 00:00:00 2001
From: JulianDiefenbacher
Date: Mon, 10 Nov 2025 09:53:27 +0100
Subject: [PATCH 16/33] correct score info string newline spacing
---
doubleml/plm/plpr.py | 5 ++-
doubleml/plm/sim/example_sim.ipynb | 55 ++++++++++++++++++++++++++++--
2 files changed, 57 insertions(+), 3 deletions(-)
diff --git a/doubleml/plm/plpr.py b/doubleml/plm/plpr.py
index ccfc7c76d..f7947b8e0 100644
--- a/doubleml/plm/plpr.py
+++ b/doubleml/plm/plpr.py
@@ -125,12 +125,15 @@ def __init__(
self._external_predictions_implemented = True
# Get transformed data depending on approach
+ # TODO: get y, x, d cols, set additional properties for y_data, d_data, x_data to be used in
+ # nuisance
self._data_transform = self._transform_data(self._static_panel_approach)
+
def _format_score_info_str(self):
score_static_panel_approach_info = (
f"Score function: {str(self.score)}\n"
- f"Static panel model approach: {str(self.static_panel_approach)}\n"
+ f"Static panel model approach: {str(self.static_panel_approach)}"
)
return score_static_panel_approach_info
diff --git a/doubleml/plm/sim/example_sim.ipynb b/doubleml/plm/sim/example_sim.ipynb
index 755c56ae1..ebff5d89b 100644
--- a/doubleml/plm/sim/example_sim.ipynb
+++ b/doubleml/plm/sim/example_sim.ipynb
@@ -124,7 +124,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 3,
"id": "73c6599b",
"metadata": {},
"outputs": [
@@ -133,7 +133,7 @@
"output_type": "stream",
"text": [
" coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.498106 0.020829 23.913515 2.215861e-126 0.457281 0.538931\n"
+ "d 0.519934 0.020986 24.774727 1.678998e-135 0.478802 0.561067\n"
]
}
],
@@ -182,6 +182,57 @@
"# np.sqrt(np.mean(np.square(u_hat - (dml_panel_plpr.coef[0] * v_hat))))"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "11bbc988",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "================== DoubleMLPLPR Object ==================\n",
+ "\n",
+ "------------------ Data Summary ------------------\n",
+ "Outcome variable: y\n",
+ "Treatment variable(s): ['d']\n",
+ "Covariates: ['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15', 'x16', 'x17', 'x18', 'x19', 'x20', 'x21', 'x22', 'x23', 'x24', 'x25', 'x26', 'x27', 'x28', 'x29', 'x30', 'm_x1', 'm_x2', 'm_x3', 'm_x4', 'm_x5', 'm_x6', 'm_x7', 'm_x8', 'm_x9', 'm_x10', 'm_x11', 'm_x12', 'm_x13', 'm_x14', 'm_x15', 'm_x16', 'm_x17', 'm_x18', 'm_x19', 'm_x20', 'm_x21', 'm_x22', 'm_x23', 'm_x24', 'm_x25', 'm_x26', 'm_x27', 'm_x28', 'm_x29', 'm_x30']\n",
+ "Instrument variable(s): None\n",
+ "Time variable: time\n",
+ "Id variable: id\n",
+ "Static panel data: True\n",
+ "No. Unique Ids: 250\n",
+ "No. Observations: 2500\n",
+ "\n",
+ "\n",
+ "------------------ Score & Algorithm ------------------\n",
+ "Score function: partialling out\n",
+ "Static panel model approach: cre_general\n",
+ "\n",
+ "------------------ Machine Learner ------------------\n",
+ "Learner ml_l: LassoCV()\n",
+ "Learner ml_m: LassoCV()\n",
+ "Out-of-sample Performance:\n",
+ "Regression:\n",
+ "Learner ml_l RMSE: [[1.63784321]]\n",
+ "Learner ml_m RMSE: [[0.96294553]]\n",
+ "\n",
+ "------------------ Resampling ------------------\n",
+ "No. folds per cluster: 5\n",
+ "No. folds: 5\n",
+ "No. repeated sample splits: 1\n",
+ "\n",
+ "------------------ Fit Summary ------------------\n",
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d 0.519934 0.020986 24.774727 1.678998e-135 0.478802 0.561067\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(dml_panel_plpr)"
+ ]
+ },
{
"cell_type": "code",
"execution_count": 9,
From bccc81c43eccded590743317216fbcb60cbb53a5 Mon Sep 17 00:00:00 2001
From: Julian Diefenbacher
<93731825+JulianDiefenbacher@users.noreply.github.com>
Date: Mon, 10 Nov 2025 16:48:08 +0100
Subject: [PATCH 17/33] data transform update, add transform_col_names property
---
.../plm/datasets/dgp_static_panel_CP2025.py | 4 +-
doubleml/plm/plpr.py | 93 ++++++++++---------
doubleml/plm/sim/example_sim.ipynb | 60 ++++++------
3 files changed, 83 insertions(+), 74 deletions(-)
diff --git a/doubleml/plm/datasets/dgp_static_panel_CP2025.py b/doubleml/plm/datasets/dgp_static_panel_CP2025.py
index 2dd7b0576..ade148c41 100644
--- a/doubleml/plm/datasets/dgp_static_panel_CP2025.py
+++ b/doubleml/plm/datasets/dgp_static_panel_CP2025.py
@@ -79,7 +79,7 @@ def alpha_i(x_it, d_it, a_i, num_n, num_t):
x_cols = [f'x{i + 1}' for i in np.arange(dim_x)]
- data = pd.DataFrame(np.column_stack((id, time, d_it, y_it, x_it)),
- columns=['id', 'time', 'd', 'y'] + x_cols).astype({'id': 'int64', 'time': 'int64'})
+ data = pd.DataFrame(np.column_stack((id, time, y_it, d_it, x_it)),
+ columns=['id', 'time', 'y', 'd'] + x_cols).astype({'id': 'int64', 'time': 'int64'})
return data
\ No newline at end of file
diff --git a/doubleml/plm/plpr.py b/doubleml/plm/plpr.py
index f7947b8e0..f94238c22 100644
--- a/doubleml/plm/plpr.py
+++ b/doubleml/plm/plpr.py
@@ -51,7 +51,7 @@ class DoubleMLPLPR(LinearScoreMixin, DoubleML):
or a callable object / function with signature ``psi_a, psi_b = score(y, d, l_hat, m_hat, g_hat, smpls)``.
Default is ``'partialling out'``.
- static_panel_approach : str
+ approach : str
A str (``'cre_general'``, ``'cre_normal'``, ``'fd_exact'``, ``'wg_approx'``) specifying the type of
static panel approach in Clarke and Polselli (2025).
Default is ``'fd_exact'``.
@@ -74,7 +74,7 @@ class DoubleMLPLPR(LinearScoreMixin, DoubleML):
"""
def __init__(
- self, obj_dml_data, ml_l, ml_m, ml_g=None, n_folds=5, n_rep=1, score="partialling out", static_panel_approach="fd_exact", draw_sample_splitting=True):
+ self, obj_dml_data, ml_l, ml_m, ml_g=None, n_folds=5, n_rep=1, score="partialling out", approach="fd_exact", draw_sample_splitting=True):
super().__init__(obj_dml_data, n_folds, n_rep, score, draw_sample_splitting)
self._check_data(self._dml_data)
@@ -82,9 +82,9 @@ def __init__(
valid_scores = ["IV-type", "partialling out"]
_check_score(self.score, valid_scores, allow_callable=True)
- valid_static_panel_approach = ["cre_general", "cre_normal", "fd_exact", "wg_approx"]
- self._check_static_panel_approach(static_panel_approach, valid_static_panel_approach)
- self._static_panel_approach = static_panel_approach
+ valid_approach = ["cre_general", "cre_normal", "fd_exact", "wg_approx"]
+ self._check_approach(approach, valid_approach)
+ self._approach = approach
_ = self._check_learner(ml_l, "ml_l", regressor=True, classifier=False)
ml_m_is_classifier = self._check_learner(ml_m, "ml_m", regressor=True, classifier=True)
@@ -127,15 +127,15 @@ def __init__(
# Get transformed data depending on approach
# TODO: get y, x, d cols, set additional properties for y_data, d_data, x_data to be used in
# nuisance
- self._data_transform = self._transform_data(self._static_panel_approach)
+ self._data_transform, self._transform_col_names, = self._transform_data(self._approach)
def _format_score_info_str(self):
- score_static_panel_approach_info = (
+ score_approach_info = (
f"Score function: {str(self.score)}\n"
- f"Static panel model approach: {str(self.static_panel_approach)}"
+ f"Static panel model approach: {str(self.approach)}"
)
- return score_static_panel_approach_info
+ return score_approach_info
def _format_additional_info_str(self):
"""
@@ -145,11 +145,11 @@ def _format_additional_info_str(self):
return ""
@property
- def static_panel_approach(self):
+ def approach(self):
"""
The score function.
"""
- return self._static_panel_approach
+ return self._approach
@property
def data_transform(self):
@@ -157,7 +157,14 @@ def data_transform(self):
The transformed static panel data.
"""
return self._data_transform
-
+
+ @property
+ def transform_col_names(self):
+ """
+ The column names of the transformed static panel data.
+ """
+ return self._transform_col_names
+
def _initialize_ml_nuisance_params(self):
self._params = {learner: {key: [None] * self.n_rep for key in self._dml_data.d_cols} for learner in self._learner}
@@ -177,16 +184,16 @@ def _check_data(self, obj_dml_data):
)
return
- def _check_static_panel_approach(self, static_panel_approach, valid_static_panel_approach):
- if isinstance(static_panel_approach, str):
- if static_panel_approach not in valid_static_panel_approach:
- raise ValueError("Invalid static_panel_approach " + static_panel_approach + ". " + "Valid approach " + " or ".join(valid_static_panel_approach) + ".")
+ def _check_approach(self, approach, valid_approach):
+ if isinstance(approach, str):
+ if approach not in valid_approach:
+ raise ValueError("Invalid approach " + approach + ". " + "Valid approach " + " or ".join(valid_approach) + ".")
else:
- raise TypeError(f"static_panel_approach should be a string. {str(static_panel_approach)} was passed.")
+ raise TypeError(f"approach should be a string. {str(approach)} was passed.")
return
- # TODO: preprocess and transform data based on static_panel_approach (cre, fd, wd)
- def _transform_data(self, static_panel_approach):
+ # TODO: preprocess and transform data based on approach (cre, fd, wd)
+ def _transform_data(self, approach):
df = self._dml_data.data.copy()
y_col = self._dml_data.y_col
@@ -195,34 +202,36 @@ def _transform_data(self, static_panel_approach):
t_col = self._dml_data.t_col
id_col = self._dml_data.id_col
- if static_panel_approach in ["cre_general", "cre_normal"]:
- # uses regular y_col, d_cols, x_cols + m_x_cols
- df_id_means = df[[id_col] + d_cols + x_cols].groupby(id_col).transform("mean")
- df_means = df_id_means.add_prefix("mean_")
+ if approach in ["cre_general", "cre_normal"]:
+ df_id_means = df[[id_col] + x_cols].groupby(id_col).transform("mean")
+ df_means = df_id_means.add_suffix("_mean")
data = pd.concat([df, df_means], axis=1)
- # {"y_col": y_col, "d_cols": d_cols, "x_cols": x_cols + [f"m_{x}"" for x in x_cols]}
- elif static_panel_approach == "fd_exact":
+ col_names = {"y_col": y_col, "d_cols": d_cols, "x_cols": x_cols + [f"{x}_mean" for x in x_cols]}
+ elif approach == "fd_exact":
# TODO: potential issues with unbalanced panels/missing periods, right now the
- # last available is used as the lag and for diff. Maybe reindex to a complete time grid per id.
- # uses y_col_diff, d_cols_diff, x_cols + x_cols_lag
+ # last available is used for the lag and first difference. Maybe reindex to a complete time grid per id.
df = df.sort_values([id_col, t_col])
shifted = df[[id_col] + x_cols].groupby(id_col).shift(1).add_suffix("_lag")
- first_diff = df[[id_col] + d_cols + [y_col]].groupby(id_col).diff().add_suffix("_diff")
- df_fd = pd.concat([df, shifted, first_diff], axis=1)
+ first_diff = df[[id_col] + [y_col] + d_cols].groupby(id_col).diff().add_suffix("_diff")
+ df_fd = pd.concat([df, shifted], axis=1)
+ # replace original y and d columns for first-difference transformations, rename
+ df_fd[[y_col] + d_cols] = first_diff
+ cols_rename_dict = {y_col: f"{y_col}_diff"} | {col: f"{col}_diff" for col in d_cols}
+ df_fd = df_fd.rename(columns=cols_rename_dict)
+ # drop rows for first period
data = df_fd.dropna(subset=[x_cols[0] + "_lag"]).reset_index(drop=True)
- # {"y_col": f"{y_col}_diff", "d_cols": [f"{d}_diff" for d in d_cols], "x_cols": x_cols + [f"{x}_lag" for x in x_cols]}
- elif static_panel_approach == "wg_approx":
- # uses y_col, d_cols, x_cols
- df_demean = df.drop(t_col, axis=1).groupby(id_col).transform(lambda x: x - x.mean())
- # add grand means
- grand_means = df.drop([id_col, t_col], axis=1).mean()
- within_means = df_demean + grand_means
+ col_names = {"y_col": f"{y_col}_diff", "d_cols": [f"{d}_diff" for d in d_cols], "x_cols": x_cols + [f"{x}_lag" for x in x_cols]}
+ elif approach == "wg_approx":
+ cols_to_demean = [y_col] + d_cols + x_cols
+ # compute group and grand means for within means
+ group_means = df.groupby(id_col)[cols_to_demean].transform('mean')
+ grand_means = df[cols_to_demean].mean()
+ within_means = df[cols_to_demean] - group_means + grand_means
+ within_means = within_means.add_suffix("_demean")
data = pd.concat([df[[id_col, t_col]], within_means], axis=1)
- # {"y_col": y_col, "d_cols": d_cols, "x_cols": x_cols}
- else:
- raise ValueError(f"Invalid static_panel_approach.")
+ col_names = {"y_col": f"{y_col}_demean", "d_cols": [f"{d}_demean" for d in d_cols], "x_cols": [f"{x}_demean" for x in x_cols]}
- return data
+ return data, col_names
def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=False):
x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False)
@@ -258,7 +267,7 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
else:
# TODO: update this section
# cre using m_d + x for m_hat, otherwise only x
- if self._static_panel_approach == "cre_normal":
+ if self._approach == "cre_normal":
help_data = pd.DataFrame({"id": self._dml_data.cluster_vars[:, 0], "d": d})
m_d = help_data.groupby(["id"]).transform("mean").values
x = np.column_stack((x, m_d))
@@ -275,7 +284,7 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
)
# general cre adjustment
- if self._static_panel_approach == "cre_general":
+ if self._approach == "cre_general":
help_data = pd.DataFrame({"id": self._dml_data.cluster_vars[:, 0], "m_hat": m_hat["preds"], "d": d})
group_means = help_data.groupby(["id"])[["m_hat", "d"]].transform("mean")
m_hat_star = m_hat["preds"] + group_means["d"] - group_means["m_hat"]
diff --git a/doubleml/plm/sim/example_sim.ipynb b/doubleml/plm/sim/example_sim.ipynb
index ebff5d89b..8d9108dd1 100644
--- a/doubleml/plm/sim/example_sim.ipynb
+++ b/doubleml/plm/sim/example_sim.ipynb
@@ -98,9 +98,9 @@
{
"data": {
"text/plain": [
- "(np.float64(0.6719371174913908),\n",
- " np.float64(0.6090488219157394),\n",
- " np.float64(0.7348254130670423))"
+ "(np.float64(0.6719371174913912),\n",
+ " np.float64(0.6090488219157397),\n",
+ " np.float64(0.7348254130670426))"
]
},
"execution_count": 2,
@@ -124,7 +124,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 11,
"id": "73c6599b",
"metadata": {},
"outputs": [
@@ -133,7 +133,7 @@
"output_type": "stream",
"text": [
" coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.519934 0.020986 24.774727 1.678998e-135 0.478802 0.561067\n"
+ "d 0.498106 0.020829 23.913515 2.215861e-126 0.457281 0.538931\n"
]
}
],
@@ -161,7 +161,7 @@
" static_panel=True)\n",
"\n",
"dml_panel_plpr = DoubleMLPLPR(obj_panel, ml_l=ml_l, ml_m=ml_m,\n",
- " static_panel_approach='cre_general', n_folds=5\n",
+ " approach='cre_general', n_folds=5\n",
" )\n",
"dml_panel_plpr.fit()\n",
"print(dml_panel_plpr.summary)"
@@ -184,7 +184,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 12,
"id": "11bbc988",
"metadata": {},
"outputs": [
@@ -215,8 +215,8 @@
"Learner ml_m: LassoCV()\n",
"Out-of-sample Performance:\n",
"Regression:\n",
- "Learner ml_l RMSE: [[1.63784321]]\n",
- "Learner ml_m RMSE: [[0.96294553]]\n",
+ "Learner ml_l RMSE: [[1.72960098]]\n",
+ "Learner ml_m RMSE: [[0.95035703]]\n",
"\n",
"------------------ Resampling ------------------\n",
"No. folds per cluster: 5\n",
@@ -225,7 +225,7 @@
"\n",
"------------------ Fit Summary ------------------\n",
" coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.519934 0.020986 24.774727 1.678998e-135 0.478802 0.561067\n"
+ "d 0.498106 0.020829 23.913515 2.215861e-126 0.457281 0.538931\n"
]
}
],
@@ -285,7 +285,7 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 16,
"id": "24f06d62",
"metadata": {},
"outputs": [
@@ -293,8 +293,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.533994 0.021782 24.51491 1.024387e-132 0.491301 0.576687\n"
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d 0.503726 0.022932 21.965879 6.106437e-107 0.45878 0.548672\n"
]
}
],
@@ -319,7 +319,7 @@
"ml_l = clone(learner)\n",
"ml_m = clone(learner)\n",
"\n",
- "obj_dml_plpr = DoubleMLPLPR(obj_dml_data_pdml, ml_l, ml_m, static_panel_approach='cre_normal')\n",
+ "obj_dml_plpr = DoubleMLPLPR(obj_dml_data_pdml, ml_l, ml_m, approach='cre_normal')\n",
"obj_dml_plpr.fit()\n",
"print(obj_dml_plpr.summary)"
]
@@ -334,8 +334,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d_diff 0.510621 0.024609 20.749141 1.248004e-95 0.462388 0.558854\n"
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d_diff 0.531576 0.025707 20.677905 5.476813e-95 0.48119 0.581961\n"
]
}
],
@@ -355,14 +355,14 @@
"ml_l = clone(learner)\n",
"ml_m = clone(learner)\n",
"\n",
- "obj_dml_plpr = DoubleMLPLPR(obj_dml_data_pdml, ml_l, ml_m, static_panel_approach='fd_exact')\n",
+ "obj_dml_plpr = DoubleMLPLPR(obj_dml_data_pdml, ml_l, ml_m, approach='fd_exact')\n",
"obj_dml_plpr.fit()\n",
"print(obj_dml_plpr.summary)"
]
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": 19,
"id": "aeb00efe",
"metadata": {},
"outputs": [
@@ -370,8 +370,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.491597 0.021075 23.325666 2.434375e-120 0.45029 0.532904\n"
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d 0.464302 0.022013 21.092019 9.415595e-99 0.421157 0.507447\n"
]
}
],
@@ -391,14 +391,14 @@
"ml_l = clone(learner)\n",
"ml_m = clone(learner)\n",
"\n",
- "obj_dml_plpr = DoubleMLPLPR(obj_dml_data_pdml, ml_l, ml_m, static_panel_approach='wg_approx')\n",
+ "obj_dml_plpr = DoubleMLPLPR(obj_dml_data_pdml, ml_l, ml_m, approach='wg_approx')\n",
"obj_dml_plpr.fit()\n",
"print(obj_dml_plpr.summary)"
]
},
{
"cell_type": "code",
- "execution_count": 26,
+ "execution_count": 20,
"id": "586d5edf",
"metadata": {},
"outputs": [
@@ -433,7 +433,7 @@
" x_cols=[col for col in cre_data.columns if \"x\" in col],\n",
" static_panel=True)\n",
" dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
- " static_panel_approach='cre_general')\n",
+ " approach='cre_general')\n",
" dml_plpr.fit()\n",
" res_cre_general[i, 0] = dml_plpr.coef[0]\n",
" res_cre_general[i, 1] = dml_plpr.coef[0] - theta\n",
@@ -442,7 +442,7 @@
"\n",
" # CRE normality\n",
" dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
- " static_panel_approach='cre_normal')\n",
+ " approach='cre_normal')\n",
" dml_plpr.fit()\n",
" res_cre_normal[i, 0] = dml_plpr.coef[0]\n",
" res_cre_normal[i, 1] = dml_plpr.coef[0] - theta\n",
@@ -455,7 +455,7 @@
" x_cols=[col for col in fd_data.columns if col.startswith(\"x\")],\n",
" static_panel=True)\n",
" dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
- " static_panel_approach='fd_exact')\n",
+ " approach='fd_exact')\n",
" dml_plpr.fit()\n",
" res_fd[i, 0] = dml_plpr.coef[0]\n",
" res_fd[i, 1] = dml_plpr.coef[0] - theta\n",
@@ -468,7 +468,7 @@
" x_cols=[col for col in wd_data.columns if \"x\" in col],\n",
" static_panel=True)\n",
" dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
- " static_panel_approach='wg_approx')\n",
+ " approach='wg_approx')\n",
" dml_plpr.fit()\n",
" res_wd[i, 0] = dml_plpr.coef[0]\n",
" res_wd[i, 1] = dml_plpr.coef[0] - theta\n",
@@ -478,8 +478,8 @@
},
{
"cell_type": "code",
- "execution_count": 28,
- "id": "987c0496",
+ "execution_count": 21,
+ "id": "33119186",
"metadata": {},
"outputs": [
{
@@ -545,7 +545,7 @@
"WG approx 0.502006 0.002006 0.94"
]
},
- "execution_count": 28,
+ "execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
@@ -575,7 +575,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.13.9"
+ "version": "3.12.6"
}
},
"nbformat": 4,
From eb68557c3923d59b8e08366a62302ad2fba896d2 Mon Sep 17 00:00:00 2001
From: Julian Diefenbacher
<93731825+JulianDiefenbacher@users.noreply.github.com>
Date: Mon, 10 Nov 2025 19:50:51 +0100
Subject: [PATCH 18/33] add transformed data arrays for nuisance estimation
---
doubleml/plm/plpr.py | 25 +-
doubleml/plm/sim/example_sim.ipynb | 357 ++++++++++++++++++++++++++++-
2 files changed, 360 insertions(+), 22 deletions(-)
diff --git a/doubleml/plm/plpr.py b/doubleml/plm/plpr.py
index f94238c22..2fa99b8d7 100644
--- a/doubleml/plm/plpr.py
+++ b/doubleml/plm/plpr.py
@@ -124,11 +124,15 @@ def __init__(
self._sensitivity_implemented = False
self._external_predictions_implemented = True
- # Get transformed data depending on approach
- # TODO: get y, x, d cols, set additional properties for y_data, d_data, x_data to be used in
- # nuisance
+ # get transformed data depending on approach
self._data_transform, self._transform_col_names, = self._transform_data(self._approach)
+ # save transformed data parts for ML estimation
+ # TODO: check d_cols dimension issue; flattening is OK for now because panel data currently allows only one treatment
+ self._y_data_transform = self.data_transform.loc[:, self.transform_col_names['y_col']].values
+ self._d_data_transform = self.data_transform.loc[:, self.transform_col_names['d_cols']].values.flatten()
+ self._x_data_transform = self.data_transform.loc[:, self.transform_col_names['x_cols']].values
+ # TODO: for fd_exact, n_obs changes after the transformation, so the originally drawn smpls no longer work
def _format_score_info_str(self):
score_approach_info = (
@@ -164,7 +168,7 @@ def transform_col_names(self):
The column names of the transformed static panel data.
"""
return self._transform_col_names
-
+
def _initialize_ml_nuisance_params(self):
self._params = {learner: {key: [None] * self.n_rep for key in self._dml_data.d_cols} for learner in self._learner}
@@ -234,8 +238,8 @@ def _transform_data(self, approach):
return data, col_names
def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=False):
- x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False)
- x, d = check_X_y(x, self._dml_data.d, force_all_finite=False)
+ x, y = check_X_y(self._x_data_transform, self._y_data_transform, force_all_finite=False)
+ x, d = check_X_y(x, self._d_data_transform, force_all_finite=False)
m_external = external_predictions["ml_m"] is not None
l_external = external_predictions["ml_l"] is not None
if "ml_g" in self._learner:
@@ -268,9 +272,9 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
# TODO: update this section
# cre using m_d + x for m_hat, otherwise only x
if self._approach == "cre_normal":
- help_data = pd.DataFrame({"id": self._dml_data.cluster_vars[:, 0], "d": d})
- m_d = help_data.groupby(["id"]).transform("mean").values
- x = np.column_stack((x, m_d))
+ help_d_mean = pd.DataFrame({"id": self._dml_data.id_var, "d": d})
+ d_mean = help_d_mean.groupby(["id"]).transform("mean").values
+ x = np.column_stack((x, d_mean))
m_hat = _dml_cv_predict(
self._learner["ml_m"],
@@ -285,12 +289,11 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
# general cre adjustment
if self._approach == "cre_general":
- help_data = pd.DataFrame({"id": self._dml_data.cluster_vars[:, 0], "m_hat": m_hat["preds"], "d": d})
+ help_data = pd.DataFrame({"id": self._dml_data.id_var, "m_hat": m_hat["preds"], "d": d})
group_means = help_data.groupby(["id"])[["m_hat", "d"]].transform("mean")
m_hat_star = m_hat["preds"] + group_means["d"] - group_means["m_hat"]
m_hat["preds"] = m_hat_star
-
_check_finite_predictions(m_hat["preds"], self._learner["ml_m"], "ml_m", smpls)
if self._check_learner(self._learner["ml_m"], "ml_m", regressor=True, classifier=True):
_check_is_propensity(m_hat["preds"], self._learner["ml_m"], "ml_m", smpls, eps=1e-12)
diff --git a/doubleml/plm/sim/example_sim.ipynb b/doubleml/plm/sim/example_sim.ipynb
index 8d9108dd1..56a869ce8 100644
--- a/doubleml/plm/sim/example_sim.ipynb
+++ b/doubleml/plm/sim/example_sim.ipynb
@@ -77,7 +77,8 @@
"import numpy as np\n",
"import pandas as pd\n",
"import statsmodels.api as sm\n",
- "from doubleml.data.base_data import DoubleMLData\n",
+ "from doubleml.data.base_data import DoubleMLData \n",
+ "from doubleml.data.panel_data import DoubleMLPanelData\n",
"from doubleml.plm.plpr import DoubleMLPLPR\n",
"from sklearn.linear_model import LassoCV, LinearRegression\n",
"from sklearn.base import clone\n",
@@ -124,7 +125,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": null,
"id": "73c6599b",
"metadata": {},
"outputs": [
@@ -133,7 +134,7 @@
"output_type": "stream",
"text": [
" coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.498106 0.020829 23.913515 2.215861e-126 0.457281 0.538931\n"
+ "d 0.456003 0.020668 22.063024 7.163019e-108 0.415494 0.496512\n"
]
}
],
@@ -150,8 +151,6 @@
"ml_l = clone(learner)\n",
"ml_m = clone(learner)\n",
"\n",
- "from doubleml.data.panel_data import DoubleMLPanelData\n",
- "\n",
"obj_panel = DoubleMLPanelData(cre_data,\n",
" y_col='y',\n",
" d_cols='d',\n",
@@ -167,6 +166,300 @@
"print(dml_panel_plpr.summary)"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "44387af3",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d 0.515035 0.020056 25.679326 1.989663e-145 0.475725 0.554345\n"
+ ]
+ }
+ ],
+ "source": [
+ "data = make_static_panel_CP2025(dgp_type='dgp1')\n",
+ "\n",
+ "x_cols = [col for col in data.columns if \"x\" in col]\n",
+ "\n",
+ "learner = LassoCV()\n",
+ "ml_l = clone(learner)\n",
+ "ml_m = clone(learner)\n",
+ "\n",
+ "obj_panel = DoubleMLPanelData(data,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " t_col='time',\n",
+ " id_col='id',\n",
+ " x_cols=x_cols,\n",
+ " static_panel=True)\n",
+ "\n",
+ "dml_panel_plpr = DoubleMLPLPR(obj_panel, ml_l=ml_l, ml_m=ml_m,\n",
+ " approach='cre_general', n_folds=5\n",
+ " )\n",
+ "dml_panel_plpr.fit()\n",
+ "print(dml_panel_plpr.summary)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "a8592f23",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[[(array([ 0, 1, 2, ..., 2497, 2498, 2499], shape=(2000,)),\n",
+ " array([ 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 100,\n",
+ " 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,\n",
+ " 112, 113, 114, 115, 116, 117, 118, 119, 160, 161, 162,\n",
+ " 163, 164, 165, 166, 167, 168, 169, 180, 181, 182, 183,\n",
+ " 184, 185, 186, 187, 188, 189, 270, 271, 272, 273, 274,\n",
+ " 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285,\n",
+ " 286, 287, 288, 289, 310, 311, 312, 313, 314, 315, 316,\n",
+ " 317, 318, 319, 360, 361, 362, 363, 364, 365, 366, 367,\n",
+ " 368, 369, 410, 411, 412, 413, 414, 415, 416, 417, 418,\n",
+ " 419, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639,\n",
+ " 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 770,\n",
+ " 771, 772, 773, 774, 775, 776, 777, 778, 779, 810, 811,\n",
+ " 812, 813, 814, 815, 816, 817, 818, 819, 870, 871, 872,\n",
+ " 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883,\n",
+ " 884, 885, 886, 887, 888, 889, 990, 991, 992, 993, 994,\n",
+ " 995, 996, 997, 998, 999, 1050, 1051, 1052, 1053, 1054, 1055,\n",
+ " 1056, 1057, 1058, 1059, 1160, 1161, 1162, 1163, 1164, 1165, 1166,\n",
+ " 1167, 1168, 1169, 1230, 1231, 1232, 1233, 1234, 1235, 1236, 1237,\n",
+ " 1238, 1239, 1240, 1241, 1242, 1243, 1244, 1245, 1246, 1247, 1248,\n",
+ " 1249, 1260, 1261, 1262, 1263, 1264, 1265, 1266, 1267, 1268, 1269,\n",
+ " 1330, 1331, 1332, 1333, 1334, 1335, 1336, 1337, 1338, 1339, 1390,\n",
+ " 1391, 1392, 1393, 1394, 1395, 1396, 1397, 1398, 1399, 1540, 1541,\n",
+ " 1542, 1543, 1544, 1545, 1546, 1547, 1548, 1549, 1570, 1571, 1572,\n",
+ " 1573, 1574, 1575, 1576, 1577, 1578, 1579, 1660, 1661, 1662, 1663,\n",
+ " 1664, 1665, 1666, 1667, 1668, 1669, 1680, 1681, 1682, 1683, 1684,\n",
+ " 1685, 1686, 1687, 1688, 1689, 1760, 1761, 1762, 1763, 1764, 1765,\n",
+ " 1766, 1767, 1768, 1769, 1770, 1771, 1772, 1773, 1774, 1775, 1776,\n",
+ " 1777, 1778, 1779, 1780, 1781, 1782, 1783, 1784, 1785, 1786, 1787,\n",
+ " 1788, 1789, 1790, 1791, 1792, 1793, 1794, 1795, 1796, 1797, 1798,\n",
+ " 1799, 1800, 1801, 1802, 1803, 1804, 1805, 1806, 1807, 1808, 1809,\n",
+ " 1870, 1871, 1872, 1873, 1874, 1875, 1876, 1877, 1878, 1879, 1880,\n",
+ " 1881, 1882, 1883, 1884, 1885, 1886, 1887, 1888, 1889, 1910, 1911,\n",
+ " 1912, 1913, 1914, 1915, 1916, 1917, 1918, 1919, 1970, 1971, 1972,\n",
+ " 1973, 1974, 1975, 1976, 1977, 1978, 1979, 1980, 1981, 1982, 1983,\n",
+ " 1984, 1985, 1986, 1987, 1988, 1989, 2000, 2001, 2002, 2003, 2004,\n",
+ " 2005, 2006, 2007, 2008, 2009, 2040, 2041, 2042, 2043, 2044, 2045,\n",
+ " 2046, 2047, 2048, 2049, 2050, 2051, 2052, 2053, 2054, 2055, 2056,\n",
+ " 2057, 2058, 2059, 2070, 2071, 2072, 2073, 2074, 2075, 2076, 2077,\n",
+ " 2078, 2079, 2100, 2101, 2102, 2103, 2104, 2105, 2106, 2107, 2108,\n",
+ " 2109, 2210, 2211, 2212, 2213, 2214, 2215, 2216, 2217, 2218, 2219,\n",
+ " 2230, 2231, 2232, 2233, 2234, 2235, 2236, 2237, 2238, 2239, 2260,\n",
+ " 2261, 2262, 2263, 2264, 2265, 2266, 2267, 2268, 2269, 2300, 2301,\n",
+ " 2302, 2303, 2304, 2305, 2306, 2307, 2308, 2309, 2310, 2311, 2312,\n",
+ " 2313, 2314, 2315, 2316, 2317, 2318, 2319, 2340, 2341, 2342, 2343,\n",
+ " 2344, 2345, 2346, 2347, 2348, 2349, 2370, 2371, 2372, 2373, 2374,\n",
+ " 2375, 2376, 2377, 2378, 2379])),\n",
+ " (array([ 0, 1, 2, ..., 2487, 2488, 2489], shape=(2000,)),\n",
+ " array([ 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 50,\n",
+ " 51, 52, 53, 54, 55, 56, 57, 58, 59, 70, 71,\n",
+ " 72, 73, 74, 75, 76, 77, 78, 79, 120, 121, 122,\n",
+ " 123, 124, 125, 126, 127, 128, 129, 150, 151, 152, 153,\n",
+ " 154, 155, 156, 157, 158, 159, 170, 171, 172, 173, 174,\n",
+ " 175, 176, 177, 178, 179, 300, 301, 302, 303, 304, 305,\n",
+ " 306, 307, 308, 309, 320, 321, 322, 323, 324, 325, 326,\n",
+ " 327, 328, 329, 540, 541, 542, 543, 544, 545, 546, 547,\n",
+ " 548, 549, 650, 651, 652, 653, 654, 655, 656, 657, 658,\n",
+ " 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669,\n",
+ " 710, 711, 712, 713, 714, 715, 716, 717, 718, 719, 730,\n",
+ " 731, 732, 733, 734, 735, 736, 737, 738, 739, 840, 841,\n",
+ " 842, 843, 844, 845, 846, 847, 848, 849, 960, 961, 962,\n",
+ " 963, 964, 965, 966, 967, 968, 969, 970, 971, 972, 973,\n",
+ " 974, 975, 976, 977, 978, 979, 1000, 1001, 1002, 1003, 1004,\n",
+ " 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015,\n",
+ " 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024, 1025, 1026,\n",
+ " 1027, 1028, 1029, 1060, 1061, 1062, 1063, 1064, 1065, 1066, 1067,\n",
+ " 1068, 1069, 1090, 1091, 1092, 1093, 1094, 1095, 1096, 1097, 1098,\n",
+ " 1099, 1100, 1101, 1102, 1103, 1104, 1105, 1106, 1107, 1108, 1109,\n",
+ " 1130, 1131, 1132, 1133, 1134, 1135, 1136, 1137, 1138, 1139, 1180,\n",
+ " 1181, 1182, 1183, 1184, 1185, 1186, 1187, 1188, 1189, 1250, 1251,\n",
+ " 1252, 1253, 1254, 1255, 1256, 1257, 1258, 1259, 1270, 1271, 1272,\n",
+ " 1273, 1274, 1275, 1276, 1277, 1278, 1279, 1280, 1281, 1282, 1283,\n",
+ " 1284, 1285, 1286, 1287, 1288, 1289, 1320, 1321, 1322, 1323, 1324,\n",
+ " 1325, 1326, 1327, 1328, 1329, 1340, 1341, 1342, 1343, 1344, 1345,\n",
+ " 1346, 1347, 1348, 1349, 1350, 1351, 1352, 1353, 1354, 1355, 1356,\n",
+ " 1357, 1358, 1359, 1410, 1411, 1412, 1413, 1414, 1415, 1416, 1417,\n",
+ " 1418, 1419, 1420, 1421, 1422, 1423, 1424, 1425, 1426, 1427, 1428,\n",
+ " 1429, 1470, 1471, 1472, 1473, 1474, 1475, 1476, 1477, 1478, 1479,\n",
+ " 1490, 1491, 1492, 1493, 1494, 1495, 1496, 1497, 1498, 1499, 1610,\n",
+ " 1611, 1612, 1613, 1614, 1615, 1616, 1617, 1618, 1619, 1650, 1651,\n",
+ " 1652, 1653, 1654, 1655, 1656, 1657, 1658, 1659, 1700, 1701, 1702,\n",
+ " 1703, 1704, 1705, 1706, 1707, 1708, 1709, 1750, 1751, 1752, 1753,\n",
+ " 1754, 1755, 1756, 1757, 1758, 1759, 1920, 1921, 1922, 1923, 1924,\n",
+ " 1925, 1926, 1927, 1928, 1929, 2080, 2081, 2082, 2083, 2084, 2085,\n",
+ " 2086, 2087, 2088, 2089, 2140, 2141, 2142, 2143, 2144, 2145, 2146,\n",
+ " 2147, 2148, 2149, 2170, 2171, 2172, 2173, 2174, 2175, 2176, 2177,\n",
+ " 2178, 2179, 2240, 2241, 2242, 2243, 2244, 2245, 2246, 2247, 2248,\n",
+ " 2249, 2250, 2251, 2252, 2253, 2254, 2255, 2256, 2257, 2258, 2259,\n",
+ " 2320, 2321, 2322, 2323, 2324, 2325, 2326, 2327, 2328, 2329, 2330,\n",
+ " 2331, 2332, 2333, 2334, 2335, 2336, 2337, 2338, 2339, 2350, 2351,\n",
+ " 2352, 2353, 2354, 2355, 2356, 2357, 2358, 2359, 2360, 2361, 2362,\n",
+ " 2363, 2364, 2365, 2366, 2367, 2368, 2369, 2390, 2391, 2392, 2393,\n",
+ " 2394, 2395, 2396, 2397, 2398, 2399, 2490, 2491, 2492, 2493, 2494,\n",
+ " 2495, 2496, 2497, 2498, 2499])),\n",
+ " (array([ 20, 21, 22, ..., 2497, 2498, 2499], shape=(2000,)),\n",
+ " array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,\n",
+ " 11, 12, 13, 14, 15, 16, 17, 18, 19, 190, 191,\n",
+ " 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202,\n",
+ " 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213,\n",
+ " 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224,\n",
+ " 225, 226, 227, 228, 229, 290, 291, 292, 293, 294, 295,\n",
+ " 296, 297, 298, 299, 340, 341, 342, 343, 344, 345, 346,\n",
+ " 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357,\n",
+ " 358, 359, 440, 441, 442, 443, 444, 445, 446, 447, 448,\n",
+ " 449, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469,\n",
+ " 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 490,\n",
+ " 491, 492, 493, 494, 495, 496, 497, 498, 499, 530, 531,\n",
+ " 532, 533, 534, 535, 536, 537, 538, 539, 560, 561, 562,\n",
+ " 563, 564, 565, 566, 567, 568, 569, 580, 581, 582, 583,\n",
+ " 584, 585, 586, 587, 588, 589, 600, 601, 602, 603, 604,\n",
+ " 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615,\n",
+ " 616, 617, 618, 619, 690, 691, 692, 693, 694, 695, 696,\n",
+ " 697, 698, 699, 750, 751, 752, 753, 754, 755, 756, 757,\n",
+ " 758, 759, 780, 781, 782, 783, 784, 785, 786, 787, 788,\n",
+ " 789, 800, 801, 802, 803, 804, 805, 806, 807, 808, 809,\n",
+ " 900, 901, 902, 903, 904, 905, 906, 907, 908, 909, 920,\n",
+ " 921, 922, 923, 924, 925, 926, 927, 928, 929, 950, 951,\n",
+ " 952, 953, 954, 955, 956, 957, 958, 959, 980, 981, 982,\n",
+ " 983, 984, 985, 986, 987, 988, 989, 1110, 1111, 1112, 1113,\n",
+ " 1114, 1115, 1116, 1117, 1118, 1119, 1120, 1121, 1122, 1123, 1124,\n",
+ " 1125, 1126, 1127, 1128, 1129, 1150, 1151, 1152, 1153, 1154, 1155,\n",
+ " 1156, 1157, 1158, 1159, 1190, 1191, 1192, 1193, 1194, 1195, 1196,\n",
+ " 1197, 1198, 1199, 1200, 1201, 1202, 1203, 1204, 1205, 1206, 1207,\n",
+ " 1208, 1209, 1450, 1451, 1452, 1453, 1454, 1455, 1456, 1457, 1458,\n",
+ " 1459, 1460, 1461, 1462, 1463, 1464, 1465, 1466, 1467, 1468, 1469,\n",
+ " 1480, 1481, 1482, 1483, 1484, 1485, 1486, 1487, 1488, 1489, 1580,\n",
+ " 1581, 1582, 1583, 1584, 1585, 1586, 1587, 1588, 1589, 1730, 1731,\n",
+ " 1732, 1733, 1734, 1735, 1736, 1737, 1738, 1739, 1740, 1741, 1742,\n",
+ " 1743, 1744, 1745, 1746, 1747, 1748, 1749, 1820, 1821, 1822, 1823,\n",
+ " 1824, 1825, 1826, 1827, 1828, 1829, 1840, 1841, 1842, 1843, 1844,\n",
+ " 1845, 1846, 1847, 1848, 1849, 1890, 1891, 1892, 1893, 1894, 1895,\n",
+ " 1896, 1897, 1898, 1899, 1930, 1931, 1932, 1933, 1934, 1935, 1936,\n",
+ " 1937, 1938, 1939, 1950, 1951, 1952, 1953, 1954, 1955, 1956, 1957,\n",
+ " 1958, 1959, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018,\n",
+ " 2019, 2030, 2031, 2032, 2033, 2034, 2035, 2036, 2037, 2038, 2039,\n",
+ " 2130, 2131, 2132, 2133, 2134, 2135, 2136, 2137, 2138, 2139, 2180,\n",
+ " 2181, 2182, 2183, 2184, 2185, 2186, 2187, 2188, 2189, 2280, 2281,\n",
+ " 2282, 2283, 2284, 2285, 2286, 2287, 2288, 2289, 2400, 2401, 2402,\n",
+ " 2403, 2404, 2405, 2406, 2407, 2408, 2409, 2440, 2441, 2442, 2443,\n",
+ " 2444, 2445, 2446, 2447, 2448, 2449, 2450, 2451, 2452, 2453, 2454,\n",
+ " 2455, 2456, 2457, 2458, 2459])),\n",
+ " (array([ 0, 1, 2, ..., 2497, 2498, 2499], shape=(2000,)),\n",
+ " array([ 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 80,\n",
+ " 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91,\n",
+ " 92, 93, 94, 95, 96, 97, 98, 99, 230, 231, 232,\n",
+ " 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243,\n",
+ " 244, 245, 246, 247, 248, 249, 260, 261, 262, 263, 264,\n",
+ " 265, 266, 267, 268, 269, 370, 371, 372, 373, 374, 375,\n",
+ " 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386,\n",
+ " 387, 388, 389, 400, 401, 402, 403, 404, 405, 406, 407,\n",
+ " 408, 409, 480, 481, 482, 483, 484, 485, 486, 487, 488,\n",
+ " 489, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689,\n",
+ " 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 740,\n",
+ " 741, 742, 743, 744, 745, 746, 747, 748, 749, 760, 761,\n",
+ " 762, 763, 764, 765, 766, 767, 768, 769, 820, 821, 822,\n",
+ " 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833,\n",
+ " 834, 835, 836, 837, 838, 839, 850, 851, 852, 853, 854,\n",
+ " 855, 856, 857, 858, 859, 930, 931, 932, 933, 934, 935,\n",
+ " 936, 937, 938, 939, 1030, 1031, 1032, 1033, 1034, 1035, 1036,\n",
+ " 1037, 1038, 1039, 1040, 1041, 1042, 1043, 1044, 1045, 1046, 1047,\n",
+ " 1048, 1049, 1070, 1071, 1072, 1073, 1074, 1075, 1076, 1077, 1078,\n",
+ " 1079, 1140, 1141, 1142, 1143, 1144, 1145, 1146, 1147, 1148, 1149,\n",
+ " 1170, 1171, 1172, 1173, 1174, 1175, 1176, 1177, 1178, 1179, 1210,\n",
+ " 1211, 1212, 1213, 1214, 1215, 1216, 1217, 1218, 1219, 1290, 1291,\n",
+ " 1292, 1293, 1294, 1295, 1296, 1297, 1298, 1299, 1300, 1301, 1302,\n",
+ " 1303, 1304, 1305, 1306, 1307, 1308, 1309, 1360, 1361, 1362, 1363,\n",
+ " 1364, 1365, 1366, 1367, 1368, 1369, 1370, 1371, 1372, 1373, 1374,\n",
+ " 1375, 1376, 1377, 1378, 1379, 1400, 1401, 1402, 1403, 1404, 1405,\n",
+ " 1406, 1407, 1408, 1409, 1440, 1441, 1442, 1443, 1444, 1445, 1446,\n",
+ " 1447, 1448, 1449, 1500, 1501, 1502, 1503, 1504, 1505, 1506, 1507,\n",
+ " 1508, 1509, 1530, 1531, 1532, 1533, 1534, 1535, 1536, 1537, 1538,\n",
+ " 1539, 1550, 1551, 1552, 1553, 1554, 1555, 1556, 1557, 1558, 1559,\n",
+ " 1560, 1561, 1562, 1563, 1564, 1565, 1566, 1567, 1568, 1569, 1600,\n",
+ " 1601, 1602, 1603, 1604, 1605, 1606, 1607, 1608, 1609, 1620, 1621,\n",
+ " 1622, 1623, 1624, 1625, 1626, 1627, 1628, 1629, 1630, 1631, 1632,\n",
+ " 1633, 1634, 1635, 1636, 1637, 1638, 1639, 1690, 1691, 1692, 1693,\n",
+ " 1694, 1695, 1696, 1697, 1698, 1699, 1710, 1711, 1712, 1713, 1714,\n",
+ " 1715, 1716, 1717, 1718, 1719, 1850, 1851, 1852, 1853, 1854, 1855,\n",
+ " 1856, 1857, 1858, 1859, 1990, 1991, 1992, 1993, 1994, 1995, 1996,\n",
+ " 1997, 1998, 1999, 2090, 2091, 2092, 2093, 2094, 2095, 2096, 2097,\n",
+ " 2098, 2099, 2110, 2111, 2112, 2113, 2114, 2115, 2116, 2117, 2118,\n",
+ " 2119, 2120, 2121, 2122, 2123, 2124, 2125, 2126, 2127, 2128, 2129,\n",
+ " 2160, 2161, 2162, 2163, 2164, 2165, 2166, 2167, 2168, 2169, 2200,\n",
+ " 2201, 2202, 2203, 2204, 2205, 2206, 2207, 2208, 2209, 2380, 2381,\n",
+ " 2382, 2383, 2384, 2385, 2386, 2387, 2388, 2389, 2410, 2411, 2412,\n",
+ " 2413, 2414, 2415, 2416, 2417, 2418, 2419, 2470, 2471, 2472, 2473,\n",
+ " 2474, 2475, 2476, 2477, 2478, 2479, 2480, 2481, 2482, 2483, 2484,\n",
+ " 2485, 2486, 2487, 2488, 2489])),\n",
+ " (array([ 0, 1, 2, ..., 2497, 2498, 2499], shape=(2000,)),\n",
+ " array([ 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 130,\n",
+ " 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,\n",
+ " 142, 143, 144, 145, 146, 147, 148, 149, 250, 251, 252,\n",
+ " 253, 254, 255, 256, 257, 258, 259, 330, 331, 332, 333,\n",
+ " 334, 335, 336, 337, 338, 339, 390, 391, 392, 393, 394,\n",
+ " 395, 396, 397, 398, 399, 420, 421, 422, 423, 424, 425,\n",
+ " 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436,\n",
+ " 437, 438, 439, 450, 451, 452, 453, 454, 455, 456, 457,\n",
+ " 458, 459, 500, 501, 502, 503, 504, 505, 506, 507, 508,\n",
+ " 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519,\n",
+ " 520, 521, 522, 523, 524, 525, 526, 527, 528, 529, 550,\n",
+ " 551, 552, 553, 554, 555, 556, 557, 558, 559, 570, 571,\n",
+ " 572, 573, 574, 575, 576, 577, 578, 579, 590, 591, 592,\n",
+ " 593, 594, 595, 596, 597, 598, 599, 620, 621, 622, 623,\n",
+ " 624, 625, 626, 627, 628, 629, 640, 641, 642, 643, 644,\n",
+ " 645, 646, 647, 648, 649, 670, 671, 672, 673, 674, 675,\n",
+ " 676, 677, 678, 679, 790, 791, 792, 793, 794, 795, 796,\n",
+ " 797, 798, 799, 860, 861, 862, 863, 864, 865, 866, 867,\n",
+ " 868, 869, 890, 891, 892, 893, 894, 895, 896, 897, 898,\n",
+ " 899, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919,\n",
+ " 940, 941, 942, 943, 944, 945, 946, 947, 948, 949, 1080,\n",
+ " 1081, 1082, 1083, 1084, 1085, 1086, 1087, 1088, 1089, 1220, 1221,\n",
+ " 1222, 1223, 1224, 1225, 1226, 1227, 1228, 1229, 1310, 1311, 1312,\n",
+ " 1313, 1314, 1315, 1316, 1317, 1318, 1319, 1380, 1381, 1382, 1383,\n",
+ " 1384, 1385, 1386, 1387, 1388, 1389, 1430, 1431, 1432, 1433, 1434,\n",
+ " 1435, 1436, 1437, 1438, 1439, 1510, 1511, 1512, 1513, 1514, 1515,\n",
+ " 1516, 1517, 1518, 1519, 1520, 1521, 1522, 1523, 1524, 1525, 1526,\n",
+ " 1527, 1528, 1529, 1590, 1591, 1592, 1593, 1594, 1595, 1596, 1597,\n",
+ " 1598, 1599, 1640, 1641, 1642, 1643, 1644, 1645, 1646, 1647, 1648,\n",
+ " 1649, 1670, 1671, 1672, 1673, 1674, 1675, 1676, 1677, 1678, 1679,\n",
+ " 1720, 1721, 1722, 1723, 1724, 1725, 1726, 1727, 1728, 1729, 1810,\n",
+ " 1811, 1812, 1813, 1814, 1815, 1816, 1817, 1818, 1819, 1830, 1831,\n",
+ " 1832, 1833, 1834, 1835, 1836, 1837, 1838, 1839, 1860, 1861, 1862,\n",
+ " 1863, 1864, 1865, 1866, 1867, 1868, 1869, 1900, 1901, 1902, 1903,\n",
+ " 1904, 1905, 1906, 1907, 1908, 1909, 1940, 1941, 1942, 1943, 1944,\n",
+ " 1945, 1946, 1947, 1948, 1949, 1960, 1961, 1962, 1963, 1964, 1965,\n",
+ " 1966, 1967, 1968, 1969, 2020, 2021, 2022, 2023, 2024, 2025, 2026,\n",
+ " 2027, 2028, 2029, 2060, 2061, 2062, 2063, 2064, 2065, 2066, 2067,\n",
+ " 2068, 2069, 2150, 2151, 2152, 2153, 2154, 2155, 2156, 2157, 2158,\n",
+ " 2159, 2190, 2191, 2192, 2193, 2194, 2195, 2196, 2197, 2198, 2199,\n",
+ " 2220, 2221, 2222, 2223, 2224, 2225, 2226, 2227, 2228, 2229, 2270,\n",
+ " 2271, 2272, 2273, 2274, 2275, 2276, 2277, 2278, 2279, 2290, 2291,\n",
+ " 2292, 2293, 2294, 2295, 2296, 2297, 2298, 2299, 2420, 2421, 2422,\n",
+ " 2423, 2424, 2425, 2426, 2427, 2428, 2429, 2430, 2431, 2432, 2433,\n",
+ " 2434, 2435, 2436, 2437, 2438, 2439, 2460, 2461, 2462, 2463, 2464,\n",
+ " 2465, 2466, 2467, 2468, 2469]))]]"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dml_panel_plpr.smpls"
+ ]
+ },
{
"cell_type": "code",
"execution_count": 12,
@@ -326,7 +619,7 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 25,
"id": "61a72563",
"metadata": {},
"outputs": [
@@ -334,8 +627,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d_diff 0.531576 0.025707 20.677905 5.476813e-95 0.48119 0.581961\n"
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d_diff 0.512959 0.025253 20.313176 9.833638e-92 0.463465 0.562453\n"
]
}
],
@@ -362,7 +655,49 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 5,
+ "id": "ff0e322b",
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "AssertionError",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
+ "\u001b[31mAssertionError\u001b[39m Traceback (most recent call last)",
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[5]\u001b[39m\u001b[32m, line 16\u001b[39m\n\u001b[32m 13\u001b[39m ml_m = clone(learner)\n\u001b[32m 15\u001b[39m obj_dml_plpr = DoubleMLPLPR(obj_dml_data_pdml, ml_l, ml_m, approach=\u001b[33m'\u001b[39m\u001b[33mfd_exact\u001b[39m\u001b[33m'\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m16\u001b[39m \u001b[43mobj_dml_plpr\u001b[49m\u001b[43m.\u001b[49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 17\u001b[39m \u001b[38;5;28mprint\u001b[39m(obj_dml_plpr.summary)\n",
+ "\u001b[36mFile \u001b[39m\u001b[32m~\\PycharmProjects\\doubleml-for-py\\doubleml\\double_ml.py:570\u001b[39m, in \u001b[36mDoubleML.fit\u001b[39m\u001b[34m(self, n_jobs_cv, store_predictions, external_predictions, store_models)\u001b[39m\n\u001b[32m 567\u001b[39m \u001b[38;5;28mself\u001b[39m._dml_data.set_x_d(\u001b[38;5;28mself\u001b[39m._dml_data.d_cols[i_d])\n\u001b[32m 569\u001b[39m \u001b[38;5;66;03m# predictions have to be stored in loop for sensitivity analysis\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m570\u001b[39m nuisance_predictions = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_fit_nuisance_and_score_elements\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 571\u001b[39m \u001b[43m \u001b[49m\u001b[43mn_jobs_cv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstore_predictions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mexternal_predictions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstore_models\u001b[49m\n\u001b[32m 572\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 574\u001b[39m \u001b[38;5;28mself\u001b[39m._solve_score_and_estimate_se()\n\u001b[32m 576\u001b[39m \u001b[38;5;66;03m# sensitivity elements can depend on the estimated parameter\u001b[39;00m\n",
+ "\u001b[36mFile \u001b[39m\u001b[32m~\\PycharmProjects\\doubleml-for-py\\doubleml\\double_ml.py:1088\u001b[39m, in \u001b[36mDoubleML._fit_nuisance_and_score_elements\u001b[39m\u001b[34m(self, n_jobs_cv, store_predictions, external_predictions, store_models)\u001b[39m\n\u001b[32m 1083\u001b[39m ext_prediction_dict = _set_external_predictions(\n\u001b[32m 1084\u001b[39m external_predictions, learners=\u001b[38;5;28mself\u001b[39m.params_names, treatment=\u001b[38;5;28mself\u001b[39m._dml_data.d_cols[\u001b[38;5;28mself\u001b[39m._i_treat], i_rep=\u001b[38;5;28mself\u001b[39m._i_rep\n\u001b[32m 1085\u001b[39m )\n\u001b[32m 1087\u001b[39m \u001b[38;5;66;03m# ml estimation of nuisance models and computation of score elements\u001b[39;00m\n\u001b[32m-> \u001b[39m\u001b[32m1088\u001b[39m score_elements, preds = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_nuisance_est\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 1089\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m__smpls\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_jobs_cv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mexternal_predictions\u001b[49m\u001b[43m=\u001b[49m\u001b[43mext_prediction_dict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_models\u001b[49m\u001b[43m=\u001b[49m\u001b[43mstore_models\u001b[49m\n\u001b[32m 1090\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1092\u001b[39m \u001b[38;5;28mself\u001b[39m._set_score_elements(score_elements, \u001b[38;5;28mself\u001b[39m._i_rep, \u001b[38;5;28mself\u001b[39m._i_treat)\n\u001b[32m 1094\u001b[39m \u001b[38;5;66;03m# calculate nuisance losses and store predictions and targets of the nuisance models\u001b[39;00m\n",
+ "\u001b[36mFile \u001b[39m\u001b[32m~\\PycharmProjects\\doubleml-for-py\\doubleml\\plm\\plpr.py:254\u001b[39m, in \u001b[36mDoubleMLPLPR._nuisance_est\u001b[39m\u001b[34m(self, smpls, n_jobs_cv, external_predictions, return_models)\u001b[39m\n\u001b[32m 252\u001b[39m l_hat = {\u001b[33m\"\u001b[39m\u001b[33mpreds\u001b[39m\u001b[33m\"\u001b[39m: \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[33m\"\u001b[39m\u001b[33mtargets\u001b[39m\u001b[33m\"\u001b[39m: \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[33m\"\u001b[39m\u001b[33mmodels\u001b[39m\u001b[33m\"\u001b[39m: \u001b[38;5;28;01mNone\u001b[39;00m}\n\u001b[32m 253\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m254\u001b[39m l_hat = \u001b[43m_dml_cv_predict\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 255\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_learner\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mml_l\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 256\u001b[39m \u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 257\u001b[39m \u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 258\u001b[39m \u001b[43m \u001b[49m\u001b[43msmpls\u001b[49m\u001b[43m=\u001b[49m\u001b[43msmpls\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 259\u001b[39m \u001b[43m \u001b[49m\u001b[43mn_jobs\u001b[49m\u001b[43m=\u001b[49m\u001b[43mn_jobs_cv\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 260\u001b[39m \u001b[43m \u001b[49m\u001b[43mest_params\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_get_params\u001b[49m\u001b[43m(\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mml_l\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 261\u001b[39m \u001b[43m 
\u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_predict_method\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mml_l\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 262\u001b[39m \u001b[43m \u001b[49m\u001b[43mreturn_models\u001b[49m\u001b[43m=\u001b[49m\u001b[43mreturn_models\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 263\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 264\u001b[39m _check_finite_predictions(l_hat[\u001b[33m\"\u001b[39m\u001b[33mpreds\u001b[39m\u001b[33m\"\u001b[39m], \u001b[38;5;28mself\u001b[39m._learner[\u001b[33m\"\u001b[39m\u001b[33mml_l\u001b[39m\u001b[33m\"\u001b[39m], \u001b[33m\"\u001b[39m\u001b[33mml_l\u001b[39m\u001b[33m\"\u001b[39m, smpls)\n\u001b[32m 266\u001b[39m \u001b[38;5;66;03m# nuisance m\u001b[39;00m\n",
+ "\u001b[36mFile \u001b[39m\u001b[32m~\\PycharmProjects\\doubleml-for-py\\doubleml\\utils\\_estimation.py:76\u001b[39m, in \u001b[36m_dml_cv_predict\u001b[39m\u001b[34m(estimator, x, y, smpls, n_jobs, est_params, method, return_train_preds, return_models)\u001b[39m\n\u001b[32m 74\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m smpls_is_partition:\n\u001b[32m 75\u001b[39m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m fold_specific_target, \u001b[33m\"\u001b[39m\u001b[33mcombination of fold-specific y and no cross-fitting not implemented yet\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m---> \u001b[39m\u001b[32m76\u001b[39m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(smpls) == \u001b[32m1\u001b[39m\n\u001b[32m 78\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m method == \u001b[33m\"\u001b[39m\u001b[33mpredict_proba\u001b[39m\u001b[33m\"\u001b[39m:\n\u001b[32m 79\u001b[39m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m fold_specific_target \u001b[38;5;66;03m# fold_specific_target only needed for PLIV.partialXZ\u001b[39;00m\n",
+ "\u001b[31mAssertionError\u001b[39m: "
+ ]
+ }
+ ],
+ "source": [
+ "data = make_static_panel_CP2025(dgp_type='dgp1')\n",
+ "\n",
+ "obj_dml_data_pdml = DoubleMLPanelData(data,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " t_col='time',\n",
+ " id_col='id',\n",
+ " x_cols=[col for col in data.columns if col.startswith(\"x\")],\n",
+ " static_panel=True)\n",
+ "\n",
+ "learner = LassoCV()\n",
+ "ml_l = clone(learner)\n",
+ "ml_m = clone(learner)\n",
+ "\n",
+ "obj_dml_plpr = DoubleMLPLPR(obj_dml_data_pdml, ml_l, ml_m, approach='fd_exact')\n",
+ "obj_dml_plpr.fit()\n",
+ "print(obj_dml_plpr.summary)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
"id": "aeb00efe",
"metadata": {},
"outputs": [
@@ -370,8 +705,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.464302 0.022013 21.092019 9.415595e-99 0.421157 0.507447\n"
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d 0.513121 0.020666 24.828987 4.361657e-136 0.472616 0.553626\n"
]
}
],
From c404f06604e7df01ea7db91c0f0928f3c539da3f Mon Sep 17 00:00:00 2001
From: Julian Diefenbacher
<93731825+JulianDiefenbacher@users.noreply.github.com>
Date: Tue, 11 Nov 2025 17:10:16 +0100
Subject: [PATCH 19/33] add _initialize_fd_model because of n_obs and smpls
issue
---
doubleml/plm/plpr.py | 127 ++++--
doubleml/plm/sim/example_sim.ipynb | 709 +++++++++++++----------------
2 files changed, 406 insertions(+), 430 deletions(-)
diff --git a/doubleml/plm/plpr.py b/doubleml/plm/plpr.py
index 2fa99b8d7..c301b6468 100644
--- a/doubleml/plm/plpr.py
+++ b/doubleml/plm/plpr.py
@@ -10,6 +10,7 @@
from ..double_ml_score_mixins import LinearScoreMixin
from ..utils._checks import _check_binary_predictions, _check_finite_predictions, _check_is_propensity, _check_score
from ..utils._estimation import _dml_cv_predict, _dml_tune
+from ..utils.resampling import DoubleMLClusterResampling
class DoubleMLPLPR(LinearScoreMixin, DoubleML):
@@ -65,7 +66,7 @@ class DoubleMLPLPR(LinearScoreMixin, DoubleML):
--------
>>> import numpy as np
>>> import doubleml as dml
-
+
Notes
-----
@@ -74,7 +75,17 @@ class DoubleMLPLPR(LinearScoreMixin, DoubleML):
"""
def __init__(
- self, obj_dml_data, ml_l, ml_m, ml_g=None, n_folds=5, n_rep=1, score="partialling out", approach="fd_exact", draw_sample_splitting=True):
+ self,
+ obj_dml_data,
+ ml_l,
+ ml_m,
+ ml_g=None,
+ n_folds=5,
+ n_rep=1,
+ score="partialling out",
+ approach="fd_exact",
+ draw_sample_splitting=True,
+ ):
super().__init__(obj_dml_data, n_folds, n_rep, score, draw_sample_splitting)
self._check_data(self._dml_data)
@@ -125,29 +136,37 @@ def __init__(
self._external_predictions_implemented = True
# get transformed data depending on approach
- self._data_transform, self._transform_col_names, = self._transform_data(self._approach)
+ (
+ self._data_transform,
+ self._transform_col_names,
+ ) = self._transform_data(self._approach)
# save transformed data parts for ML estimation
# TODO: check d_cols dimension issue, for now as for panel data only one treatment allowed currently
- self._y_data_transform = self.data_transform.loc[:, self.transform_col_names['y_col']].values
- self._d_data_transform = self.data_transform.loc[:, self.transform_col_names['d_cols']].values.flatten()
- self._x_data_transform = self.data_transform.loc[:, self.transform_col_names['x_cols']].values
+ self._y_data_transform = self.data_transform.loc[:, self.transform_col_names["y_col"]].values
+ self._d_data_transform = self.data_transform.loc[:, self.transform_col_names["d_cols"]].values.flatten()
+ self._x_data_transform = self.data_transform.loc[:, self.transform_col_names["x_cols"]].values
# TODO: for fd_exact, n_obs changes, smpls originally drawn are not working anymore
+ self._n_obs_transform = self._data_transform.shape[0]
+ self._initialize_fd_model()
def _format_score_info_str(self):
- score_approach_info = (
- f"Score function: {str(self.score)}\n"
- f"Static panel model approach: {str(self.approach)}"
- )
+ score_approach_info = f"Score function: {str(self.score)}\n" f"Static panel model approach: {str(self.approach)}"
return score_approach_info
def _format_additional_info_str(self):
"""
Includes information on the transformed features based on the estimation approach.
"""
- # TODO: Add Information on features after transformation
- return ""
-
+ data_transform_summary = (
+ f"Post Transformation Data Summary:\n\n"
+ f"Outcome variable: {self.transform_col_names['y_col']}\n"
+ f"Treatment variable(s): {self.transform_col_names['d_cols']}\n"
+ f"Covariates: {self.transform_col_names['x_cols']}\n"
+ f"No. Observations: {self._n_obs_transform}\n"
+ )
+ return data_transform_summary
+
@property
def approach(self):
"""
@@ -161,22 +180,27 @@ def data_transform(self):
The transformed static panel data.
"""
return self._data_transform
-
+
@property
def transform_col_names(self):
"""
The column names of the transformed static panel data.
"""
return self._transform_col_names
-
+
+ @property
+ def n_obs_transform(self):
+ """
+ The number of observations after data transformation.
+ """
+ return self._n_obs_transform
+
def _initialize_ml_nuisance_params(self):
self._params = {learner: {key: [None] * self.n_rep for key in self._dml_data.d_cols} for learner in self._learner}
def _check_data(self, obj_dml_data):
if not isinstance(obj_dml_data, DoubleMLPanelData):
- raise TypeError(
- f"The data must be of DoubleMLPanelData type. {str(type(obj_dml_data))} was passed."
- )
+ raise TypeError(f"The data must be of DoubleMLPanelData type. {str(type(obj_dml_data))} was passed.")
if not obj_dml_data.static_panel:
raise ValueError(
"For the PLPR model, the DoubleMLPanelData object requires the static_panel flag to be set to True."
@@ -187,7 +211,7 @@ def _check_data(self, obj_dml_data):
"DoubleMLPLPR currently does not support instrumental variables."
)
return
-
+
def _check_approach(self, approach, valid_approach):
if isinstance(approach, str):
if approach not in valid_approach:
@@ -195,8 +219,7 @@ def _check_approach(self, approach, valid_approach):
else:
raise TypeError(f"approach should be a string. {str(approach)} was passed.")
return
-
- # TODO: preprocess and transform data based on approach (cre, fd, wd)
+
def _transform_data(self, approach):
df = self._dml_data.data.copy()
@@ -208,7 +231,7 @@ def _transform_data(self, approach):
if approach in ["cre_general", "cre_normal"]:
df_id_means = df[[id_col] + x_cols].groupby(id_col).transform("mean")
- df_means = df_id_means.add_suffix("_mean")
+ df_means = df_id_means.add_suffix("_mean")
data = pd.concat([df, df_means], axis=1)
col_names = {"y_col": y_col, "d_cols": d_cols, "x_cols": x_cols + [f"{x}_mean" for x in x_cols]}
elif approach == "fd_exact":
@@ -224,19 +247,51 @@ def _transform_data(self, approach):
df_fd = df_fd.rename(columns=cols_rename_dict)
# drop rows for first period
data = df_fd.dropna(subset=[x_cols[0] + "_lag"]).reset_index(drop=True)
- col_names = {"y_col": f"{y_col}_diff", "d_cols": [f"{d}_diff" for d in d_cols], "x_cols": x_cols + [f"{x}_lag" for x in x_cols]}
+ col_names = {
+ "y_col": f"{y_col}_diff",
+ "d_cols": [f"{d}_diff" for d in d_cols],
+ "x_cols": x_cols + [f"{x}_lag" for x in x_cols],
+ }
elif approach == "wg_approx":
cols_to_demean = [y_col] + d_cols + x_cols
# compute group and grand means for within means
- group_means = df.groupby(id_col)[cols_to_demean].transform('mean')
+ group_means = df.groupby(id_col)[cols_to_demean].transform("mean")
grand_means = df[cols_to_demean].mean()
within_means = df[cols_to_demean] - group_means + grand_means
within_means = within_means.add_suffix("_demean")
data = pd.concat([df[[id_col, t_col]], within_means], axis=1)
- col_names = {"y_col": f"{y_col}_demean", "d_cols": [f"{d}_demean" for d in d_cols], "x_cols": [f"{x}_demean" for x in x_cols]}
+ col_names = {
+ "y_col": f"{y_col}_demean",
+ "d_cols": [f"{d}_demean" for d in d_cols],
+ "x_cols": [f"{x}_demean" for x in x_cols],
+ }
return data, col_names
+ def _initialize_fd_model(self):
+ if self._approach == "fd_exact":
+ self._smpls = None
+ self._smpls_cluster = None
+ # TODO: overwriting the data object's _cluster_vars means it can't be reused with another approach;
+ # a model-specific _cluster_vars_fd alone does not work, because _se_causal_pars() reads
+ # self._dml_data.cluster_vars, whose n_obs dimension does not match the psi arrays.
+ # Possible fix: override _se_causal_pars()?
+ self._dml_data._cluster_vars = self._data_transform.loc[:, self._dml_data.cluster_cols]
+ self._cluster_vars_fd = self._data_transform.loc[:, self._dml_data.cluster_cols].values
+
+ # re-initialize the model using the score dimension of the transformed data
+ self._score_dim = (self._data_transform.shape[0], self.n_rep, self._dml_data.n_coefs)
+ self._initialize_dml_model()
+ # draw smpls for first difference transformed data
+ obj_dml_resampling = DoubleMLClusterResampling(
+ n_folds=self._n_folds_per_cluster,
+ n_rep=self.n_rep,
+ n_obs=self._n_obs_transform,
+ n_cluster_vars=self._dml_data.n_cluster_vars,
+ cluster_vars=self._dml_data.cluster_vars,
+ )
+ self._smpls, self._smpls_cluster = obj_dml_resampling.split_samples()
+
def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=False):
x, y = check_X_y(self._x_data_transform, self._y_data_transform, force_all_finite=False)
x, d = check_X_y(x, self._d_data_transform, force_all_finite=False)
@@ -277,15 +332,15 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
x = np.column_stack((x, d_mean))
m_hat = _dml_cv_predict(
- self._learner["ml_m"],
- x,
- d,
- smpls=smpls,
- n_jobs=n_jobs_cv,
- est_params=self._get_params("ml_m"),
- method=self._predict_method["ml_m"],
- return_models=return_models,
- )
+ self._learner["ml_m"],
+ x,
+ d,
+ smpls=smpls,
+ n_jobs=n_jobs_cv,
+ est_params=self._get_params("ml_m"),
+ method=self._predict_method["ml_m"],
+ return_models=return_models,
+ )
# general cre adjustment
if self._approach == "cre_general":
@@ -293,7 +348,7 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
group_means = help_data.groupby(["id"])[["m_hat", "d"]].transform("mean")
m_hat_star = m_hat["preds"] + group_means["d"] - group_means["m_hat"]
m_hat["preds"] = m_hat_star
-
+
_check_finite_predictions(m_hat["preds"], self._learner["ml_m"], "ml_m", smpls)
if self._check_learner(self._learner["ml_m"], "ml_m", regressor=True, classifier=True):
_check_is_propensity(m_hat["preds"], self._learner["ml_m"], "ml_m", smpls, eps=1e-12)
@@ -427,4 +482,4 @@ def _nuisance_tuning(
res = {"params": params, "tune_res": tune_res}
- return res
\ No newline at end of file
+ return res
diff --git a/doubleml/plm/sim/example_sim.ipynb b/doubleml/plm/sim/example_sim.ipynb
index 56a869ce8..16f4036e5 100644
--- a/doubleml/plm/sim/example_sim.ipynb
+++ b/doubleml/plm/sim/example_sim.ipynb
@@ -99,9 +99,9 @@
{
"data": {
"text/plain": [
- "(np.float64(0.6719371174913912),\n",
- " np.float64(0.6090488219157397),\n",
- " np.float64(0.7348254130670426))"
+ "(np.float64(0.6719371174913908),\n",
+ " np.float64(0.6090488219157394),\n",
+ " np.float64(0.7348254130670423))"
]
},
"execution_count": 2,
@@ -125,7 +125,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 2,
"id": "73c6599b",
"metadata": {},
"outputs": [
@@ -133,8 +133,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.456003 0.020668 22.063024 7.163019e-108 0.415494 0.496512\n"
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d 0.485919 0.021214 22.905716 4.075038e-116 0.44434 0.527497\n"
]
}
],
@@ -151,24 +151,22 @@
"ml_l = clone(learner)\n",
"ml_m = clone(learner)\n",
"\n",
- "obj_panel = DoubleMLPanelData(cre_data,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " t_col='time',\n",
- " id_col='id',\n",
- " x_cols=x_cols,\n",
- " static_panel=True)\n",
- "\n",
- "dml_panel_plpr = DoubleMLPLPR(obj_panel, ml_l=ml_l, ml_m=ml_m,\n",
- " approach='cre_general', n_folds=5\n",
- " )\n",
- "dml_panel_plpr.fit()\n",
- "print(dml_panel_plpr.summary)"
+ "panel_data_obj = DoubleMLPanelData(cre_data,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " t_col='time',\n",
+ " id_col='id',\n",
+ " x_cols=x_cols,\n",
+ " static_panel=True)\n",
+ "\n",
+ "dml_plpr_obj = DoubleMLPLPR(panel_data_obj, ml_l=ml_l, ml_m=ml_m, approach='cre_general', n_folds=5)\n",
+ "dml_plpr_obj.fit()\n",
+ "print(dml_plpr_obj.summary)"
]
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 5,
"id": "44387af3",
"metadata": {},
"outputs": [
@@ -177,7 +175,7 @@
"output_type": "stream",
"text": [
" coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.515035 0.020056 25.679326 1.989663e-145 0.475725 0.554345\n"
+ "d 0.480439 0.020407 23.542639 1.493410e-122 0.440442 0.520436\n"
]
}
],
@@ -190,274 +188,18 @@
"ml_l = clone(learner)\n",
"ml_m = clone(learner)\n",
"\n",
- "obj_panel = DoubleMLPanelData(data,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " t_col='time',\n",
- " id_col='id',\n",
- " x_cols=x_cols,\n",
- " static_panel=True)\n",
- "\n",
- "dml_panel_plpr = DoubleMLPLPR(obj_panel, ml_l=ml_l, ml_m=ml_m,\n",
- " approach='cre_general', n_folds=5\n",
- " )\n",
- "dml_panel_plpr.fit()\n",
- "print(dml_panel_plpr.summary)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "id": "a8592f23",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[[(array([ 0, 1, 2, ..., 2497, 2498, 2499], shape=(2000,)),\n",
- " array([ 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 100,\n",
- " 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,\n",
- " 112, 113, 114, 115, 116, 117, 118, 119, 160, 161, 162,\n",
- " 163, 164, 165, 166, 167, 168, 169, 180, 181, 182, 183,\n",
- " 184, 185, 186, 187, 188, 189, 270, 271, 272, 273, 274,\n",
- " 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285,\n",
- " 286, 287, 288, 289, 310, 311, 312, 313, 314, 315, 316,\n",
- " 317, 318, 319, 360, 361, 362, 363, 364, 365, 366, 367,\n",
- " 368, 369, 410, 411, 412, 413, 414, 415, 416, 417, 418,\n",
- " 419, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639,\n",
- " 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 770,\n",
- " 771, 772, 773, 774, 775, 776, 777, 778, 779, 810, 811,\n",
- " 812, 813, 814, 815, 816, 817, 818, 819, 870, 871, 872,\n",
- " 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883,\n",
- " 884, 885, 886, 887, 888, 889, 990, 991, 992, 993, 994,\n",
- " 995, 996, 997, 998, 999, 1050, 1051, 1052, 1053, 1054, 1055,\n",
- " 1056, 1057, 1058, 1059, 1160, 1161, 1162, 1163, 1164, 1165, 1166,\n",
- " 1167, 1168, 1169, 1230, 1231, 1232, 1233, 1234, 1235, 1236, 1237,\n",
- " 1238, 1239, 1240, 1241, 1242, 1243, 1244, 1245, 1246, 1247, 1248,\n",
- " 1249, 1260, 1261, 1262, 1263, 1264, 1265, 1266, 1267, 1268, 1269,\n",
- " 1330, 1331, 1332, 1333, 1334, 1335, 1336, 1337, 1338, 1339, 1390,\n",
- " 1391, 1392, 1393, 1394, 1395, 1396, 1397, 1398, 1399, 1540, 1541,\n",
- " 1542, 1543, 1544, 1545, 1546, 1547, 1548, 1549, 1570, 1571, 1572,\n",
- " 1573, 1574, 1575, 1576, 1577, 1578, 1579, 1660, 1661, 1662, 1663,\n",
- " 1664, 1665, 1666, 1667, 1668, 1669, 1680, 1681, 1682, 1683, 1684,\n",
- " 1685, 1686, 1687, 1688, 1689, 1760, 1761, 1762, 1763, 1764, 1765,\n",
- " 1766, 1767, 1768, 1769, 1770, 1771, 1772, 1773, 1774, 1775, 1776,\n",
- " 1777, 1778, 1779, 1780, 1781, 1782, 1783, 1784, 1785, 1786, 1787,\n",
- " 1788, 1789, 1790, 1791, 1792, 1793, 1794, 1795, 1796, 1797, 1798,\n",
- " 1799, 1800, 1801, 1802, 1803, 1804, 1805, 1806, 1807, 1808, 1809,\n",
- " 1870, 1871, 1872, 1873, 1874, 1875, 1876, 1877, 1878, 1879, 1880,\n",
- " 1881, 1882, 1883, 1884, 1885, 1886, 1887, 1888, 1889, 1910, 1911,\n",
- " 1912, 1913, 1914, 1915, 1916, 1917, 1918, 1919, 1970, 1971, 1972,\n",
- " 1973, 1974, 1975, 1976, 1977, 1978, 1979, 1980, 1981, 1982, 1983,\n",
- " 1984, 1985, 1986, 1987, 1988, 1989, 2000, 2001, 2002, 2003, 2004,\n",
- " 2005, 2006, 2007, 2008, 2009, 2040, 2041, 2042, 2043, 2044, 2045,\n",
- " 2046, 2047, 2048, 2049, 2050, 2051, 2052, 2053, 2054, 2055, 2056,\n",
- " 2057, 2058, 2059, 2070, 2071, 2072, 2073, 2074, 2075, 2076, 2077,\n",
- " 2078, 2079, 2100, 2101, 2102, 2103, 2104, 2105, 2106, 2107, 2108,\n",
- " 2109, 2210, 2211, 2212, 2213, 2214, 2215, 2216, 2217, 2218, 2219,\n",
- " 2230, 2231, 2232, 2233, 2234, 2235, 2236, 2237, 2238, 2239, 2260,\n",
- " 2261, 2262, 2263, 2264, 2265, 2266, 2267, 2268, 2269, 2300, 2301,\n",
- " 2302, 2303, 2304, 2305, 2306, 2307, 2308, 2309, 2310, 2311, 2312,\n",
- " 2313, 2314, 2315, 2316, 2317, 2318, 2319, 2340, 2341, 2342, 2343,\n",
- " 2344, 2345, 2346, 2347, 2348, 2349, 2370, 2371, 2372, 2373, 2374,\n",
- " 2375, 2376, 2377, 2378, 2379])),\n",
- " (array([ 0, 1, 2, ..., 2487, 2488, 2489], shape=(2000,)),\n",
- " array([ 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 50,\n",
- " 51, 52, 53, 54, 55, 56, 57, 58, 59, 70, 71,\n",
- " 72, 73, 74, 75, 76, 77, 78, 79, 120, 121, 122,\n",
- " 123, 124, 125, 126, 127, 128, 129, 150, 151, 152, 153,\n",
- " 154, 155, 156, 157, 158, 159, 170, 171, 172, 173, 174,\n",
- " 175, 176, 177, 178, 179, 300, 301, 302, 303, 304, 305,\n",
- " 306, 307, 308, 309, 320, 321, 322, 323, 324, 325, 326,\n",
- " 327, 328, 329, 540, 541, 542, 543, 544, 545, 546, 547,\n",
- " 548, 549, 650, 651, 652, 653, 654, 655, 656, 657, 658,\n",
- " 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669,\n",
- " 710, 711, 712, 713, 714, 715, 716, 717, 718, 719, 730,\n",
- " 731, 732, 733, 734, 735, 736, 737, 738, 739, 840, 841,\n",
- " 842, 843, 844, 845, 846, 847, 848, 849, 960, 961, 962,\n",
- " 963, 964, 965, 966, 967, 968, 969, 970, 971, 972, 973,\n",
- " 974, 975, 976, 977, 978, 979, 1000, 1001, 1002, 1003, 1004,\n",
- " 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015,\n",
- " 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024, 1025, 1026,\n",
- " 1027, 1028, 1029, 1060, 1061, 1062, 1063, 1064, 1065, 1066, 1067,\n",
- " 1068, 1069, 1090, 1091, 1092, 1093, 1094, 1095, 1096, 1097, 1098,\n",
- " 1099, 1100, 1101, 1102, 1103, 1104, 1105, 1106, 1107, 1108, 1109,\n",
- " 1130, 1131, 1132, 1133, 1134, 1135, 1136, 1137, 1138, 1139, 1180,\n",
- " 1181, 1182, 1183, 1184, 1185, 1186, 1187, 1188, 1189, 1250, 1251,\n",
- " 1252, 1253, 1254, 1255, 1256, 1257, 1258, 1259, 1270, 1271, 1272,\n",
- " 1273, 1274, 1275, 1276, 1277, 1278, 1279, 1280, 1281, 1282, 1283,\n",
- " 1284, 1285, 1286, 1287, 1288, 1289, 1320, 1321, 1322, 1323, 1324,\n",
- " 1325, 1326, 1327, 1328, 1329, 1340, 1341, 1342, 1343, 1344, 1345,\n",
- " 1346, 1347, 1348, 1349, 1350, 1351, 1352, 1353, 1354, 1355, 1356,\n",
- " 1357, 1358, 1359, 1410, 1411, 1412, 1413, 1414, 1415, 1416, 1417,\n",
- " 1418, 1419, 1420, 1421, 1422, 1423, 1424, 1425, 1426, 1427, 1428,\n",
- " 1429, 1470, 1471, 1472, 1473, 1474, 1475, 1476, 1477, 1478, 1479,\n",
- " 1490, 1491, 1492, 1493, 1494, 1495, 1496, 1497, 1498, 1499, 1610,\n",
- " 1611, 1612, 1613, 1614, 1615, 1616, 1617, 1618, 1619, 1650, 1651,\n",
- " 1652, 1653, 1654, 1655, 1656, 1657, 1658, 1659, 1700, 1701, 1702,\n",
- " 1703, 1704, 1705, 1706, 1707, 1708, 1709, 1750, 1751, 1752, 1753,\n",
- " 1754, 1755, 1756, 1757, 1758, 1759, 1920, 1921, 1922, 1923, 1924,\n",
- " 1925, 1926, 1927, 1928, 1929, 2080, 2081, 2082, 2083, 2084, 2085,\n",
- " 2086, 2087, 2088, 2089, 2140, 2141, 2142, 2143, 2144, 2145, 2146,\n",
- " 2147, 2148, 2149, 2170, 2171, 2172, 2173, 2174, 2175, 2176, 2177,\n",
- " 2178, 2179, 2240, 2241, 2242, 2243, 2244, 2245, 2246, 2247, 2248,\n",
- " 2249, 2250, 2251, 2252, 2253, 2254, 2255, 2256, 2257, 2258, 2259,\n",
- " 2320, 2321, 2322, 2323, 2324, 2325, 2326, 2327, 2328, 2329, 2330,\n",
- " 2331, 2332, 2333, 2334, 2335, 2336, 2337, 2338, 2339, 2350, 2351,\n",
- " 2352, 2353, 2354, 2355, 2356, 2357, 2358, 2359, 2360, 2361, 2362,\n",
- " 2363, 2364, 2365, 2366, 2367, 2368, 2369, 2390, 2391, 2392, 2393,\n",
- " 2394, 2395, 2396, 2397, 2398, 2399, 2490, 2491, 2492, 2493, 2494,\n",
- " 2495, 2496, 2497, 2498, 2499])),\n",
- " (array([ 20, 21, 22, ..., 2497, 2498, 2499], shape=(2000,)),\n",
- " array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,\n",
- " 11, 12, 13, 14, 15, 16, 17, 18, 19, 190, 191,\n",
- " 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202,\n",
- " 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213,\n",
- " 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224,\n",
- " 225, 226, 227, 228, 229, 290, 291, 292, 293, 294, 295,\n",
- " 296, 297, 298, 299, 340, 341, 342, 343, 344, 345, 346,\n",
- " 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357,\n",
- " 358, 359, 440, 441, 442, 443, 444, 445, 446, 447, 448,\n",
- " 449, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469,\n",
- " 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 490,\n",
- " 491, 492, 493, 494, 495, 496, 497, 498, 499, 530, 531,\n",
- " 532, 533, 534, 535, 536, 537, 538, 539, 560, 561, 562,\n",
- " 563, 564, 565, 566, 567, 568, 569, 580, 581, 582, 583,\n",
- " 584, 585, 586, 587, 588, 589, 600, 601, 602, 603, 604,\n",
- " 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615,\n",
- " 616, 617, 618, 619, 690, 691, 692, 693, 694, 695, 696,\n",
- " 697, 698, 699, 750, 751, 752, 753, 754, 755, 756, 757,\n",
- " 758, 759, 780, 781, 782, 783, 784, 785, 786, 787, 788,\n",
- " 789, 800, 801, 802, 803, 804, 805, 806, 807, 808, 809,\n",
- " 900, 901, 902, 903, 904, 905, 906, 907, 908, 909, 920,\n",
- " 921, 922, 923, 924, 925, 926, 927, 928, 929, 950, 951,\n",
- " 952, 953, 954, 955, 956, 957, 958, 959, 980, 981, 982,\n",
- " 983, 984, 985, 986, 987, 988, 989, 1110, 1111, 1112, 1113,\n",
- " 1114, 1115, 1116, 1117, 1118, 1119, 1120, 1121, 1122, 1123, 1124,\n",
- " 1125, 1126, 1127, 1128, 1129, 1150, 1151, 1152, 1153, 1154, 1155,\n",
- " 1156, 1157, 1158, 1159, 1190, 1191, 1192, 1193, 1194, 1195, 1196,\n",
- " 1197, 1198, 1199, 1200, 1201, 1202, 1203, 1204, 1205, 1206, 1207,\n",
- " 1208, 1209, 1450, 1451, 1452, 1453, 1454, 1455, 1456, 1457, 1458,\n",
- " 1459, 1460, 1461, 1462, 1463, 1464, 1465, 1466, 1467, 1468, 1469,\n",
- " 1480, 1481, 1482, 1483, 1484, 1485, 1486, 1487, 1488, 1489, 1580,\n",
- " 1581, 1582, 1583, 1584, 1585, 1586, 1587, 1588, 1589, 1730, 1731,\n",
- " 1732, 1733, 1734, 1735, 1736, 1737, 1738, 1739, 1740, 1741, 1742,\n",
- " 1743, 1744, 1745, 1746, 1747, 1748, 1749, 1820, 1821, 1822, 1823,\n",
- " 1824, 1825, 1826, 1827, 1828, 1829, 1840, 1841, 1842, 1843, 1844,\n",
- " 1845, 1846, 1847, 1848, 1849, 1890, 1891, 1892, 1893, 1894, 1895,\n",
- " 1896, 1897, 1898, 1899, 1930, 1931, 1932, 1933, 1934, 1935, 1936,\n",
- " 1937, 1938, 1939, 1950, 1951, 1952, 1953, 1954, 1955, 1956, 1957,\n",
- " 1958, 1959, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018,\n",
- " 2019, 2030, 2031, 2032, 2033, 2034, 2035, 2036, 2037, 2038, 2039,\n",
- " 2130, 2131, 2132, 2133, 2134, 2135, 2136, 2137, 2138, 2139, 2180,\n",
- " 2181, 2182, 2183, 2184, 2185, 2186, 2187, 2188, 2189, 2280, 2281,\n",
- " 2282, 2283, 2284, 2285, 2286, 2287, 2288, 2289, 2400, 2401, 2402,\n",
- " 2403, 2404, 2405, 2406, 2407, 2408, 2409, 2440, 2441, 2442, 2443,\n",
- " 2444, 2445, 2446, 2447, 2448, 2449, 2450, 2451, 2452, 2453, 2454,\n",
- " 2455, 2456, 2457, 2458, 2459])),\n",
- " (array([ 0, 1, 2, ..., 2497, 2498, 2499], shape=(2000,)),\n",
- " array([ 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 80,\n",
- " 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91,\n",
- " 92, 93, 94, 95, 96, 97, 98, 99, 230, 231, 232,\n",
- " 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243,\n",
- " 244, 245, 246, 247, 248, 249, 260, 261, 262, 263, 264,\n",
- " 265, 266, 267, 268, 269, 370, 371, 372, 373, 374, 375,\n",
- " 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386,\n",
- " 387, 388, 389, 400, 401, 402, 403, 404, 405, 406, 407,\n",
- " 408, 409, 480, 481, 482, 483, 484, 485, 486, 487, 488,\n",
- " 489, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689,\n",
- " 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 740,\n",
- " 741, 742, 743, 744, 745, 746, 747, 748, 749, 760, 761,\n",
- " 762, 763, 764, 765, 766, 767, 768, 769, 820, 821, 822,\n",
- " 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833,\n",
- " 834, 835, 836, 837, 838, 839, 850, 851, 852, 853, 854,\n",
- " 855, 856, 857, 858, 859, 930, 931, 932, 933, 934, 935,\n",
- " 936, 937, 938, 939, 1030, 1031, 1032, 1033, 1034, 1035, 1036,\n",
- " 1037, 1038, 1039, 1040, 1041, 1042, 1043, 1044, 1045, 1046, 1047,\n",
- " 1048, 1049, 1070, 1071, 1072, 1073, 1074, 1075, 1076, 1077, 1078,\n",
- " 1079, 1140, 1141, 1142, 1143, 1144, 1145, 1146, 1147, 1148, 1149,\n",
- " 1170, 1171, 1172, 1173, 1174, 1175, 1176, 1177, 1178, 1179, 1210,\n",
- " 1211, 1212, 1213, 1214, 1215, 1216, 1217, 1218, 1219, 1290, 1291,\n",
- " 1292, 1293, 1294, 1295, 1296, 1297, 1298, 1299, 1300, 1301, 1302,\n",
- " 1303, 1304, 1305, 1306, 1307, 1308, 1309, 1360, 1361, 1362, 1363,\n",
- " 1364, 1365, 1366, 1367, 1368, 1369, 1370, 1371, 1372, 1373, 1374,\n",
- " 1375, 1376, 1377, 1378, 1379, 1400, 1401, 1402, 1403, 1404, 1405,\n",
- " 1406, 1407, 1408, 1409, 1440, 1441, 1442, 1443, 1444, 1445, 1446,\n",
- " 1447, 1448, 1449, 1500, 1501, 1502, 1503, 1504, 1505, 1506, 1507,\n",
- " 1508, 1509, 1530, 1531, 1532, 1533, 1534, 1535, 1536, 1537, 1538,\n",
- " 1539, 1550, 1551, 1552, 1553, 1554, 1555, 1556, 1557, 1558, 1559,\n",
- " 1560, 1561, 1562, 1563, 1564, 1565, 1566, 1567, 1568, 1569, 1600,\n",
- " 1601, 1602, 1603, 1604, 1605, 1606, 1607, 1608, 1609, 1620, 1621,\n",
- " 1622, 1623, 1624, 1625, 1626, 1627, 1628, 1629, 1630, 1631, 1632,\n",
- " 1633, 1634, 1635, 1636, 1637, 1638, 1639, 1690, 1691, 1692, 1693,\n",
- " 1694, 1695, 1696, 1697, 1698, 1699, 1710, 1711, 1712, 1713, 1714,\n",
- " 1715, 1716, 1717, 1718, 1719, 1850, 1851, 1852, 1853, 1854, 1855,\n",
- " 1856, 1857, 1858, 1859, 1990, 1991, 1992, 1993, 1994, 1995, 1996,\n",
- " 1997, 1998, 1999, 2090, 2091, 2092, 2093, 2094, 2095, 2096, 2097,\n",
- " 2098, 2099, 2110, 2111, 2112, 2113, 2114, 2115, 2116, 2117, 2118,\n",
- " 2119, 2120, 2121, 2122, 2123, 2124, 2125, 2126, 2127, 2128, 2129,\n",
- " 2160, 2161, 2162, 2163, 2164, 2165, 2166, 2167, 2168, 2169, 2200,\n",
- " 2201, 2202, 2203, 2204, 2205, 2206, 2207, 2208, 2209, 2380, 2381,\n",
- " 2382, 2383, 2384, 2385, 2386, 2387, 2388, 2389, 2410, 2411, 2412,\n",
- " 2413, 2414, 2415, 2416, 2417, 2418, 2419, 2470, 2471, 2472, 2473,\n",
- " 2474, 2475, 2476, 2477, 2478, 2479, 2480, 2481, 2482, 2483, 2484,\n",
- " 2485, 2486, 2487, 2488, 2489])),\n",
- " (array([ 0, 1, 2, ..., 2497, 2498, 2499], shape=(2000,)),\n",
- " array([ 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 130,\n",
- " 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,\n",
- " 142, 143, 144, 145, 146, 147, 148, 149, 250, 251, 252,\n",
- " 253, 254, 255, 256, 257, 258, 259, 330, 331, 332, 333,\n",
- " 334, 335, 336, 337, 338, 339, 390, 391, 392, 393, 394,\n",
- " 395, 396, 397, 398, 399, 420, 421, 422, 423, 424, 425,\n",
- " 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436,\n",
- " 437, 438, 439, 450, 451, 452, 453, 454, 455, 456, 457,\n",
- " 458, 459, 500, 501, 502, 503, 504, 505, 506, 507, 508,\n",
- " 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519,\n",
- " 520, 521, 522, 523, 524, 525, 526, 527, 528, 529, 550,\n",
- " 551, 552, 553, 554, 555, 556, 557, 558, 559, 570, 571,\n",
- " 572, 573, 574, 575, 576, 577, 578, 579, 590, 591, 592,\n",
- " 593, 594, 595, 596, 597, 598, 599, 620, 621, 622, 623,\n",
- " 624, 625, 626, 627, 628, 629, 640, 641, 642, 643, 644,\n",
- " 645, 646, 647, 648, 649, 670, 671, 672, 673, 674, 675,\n",
- " 676, 677, 678, 679, 790, 791, 792, 793, 794, 795, 796,\n",
- " 797, 798, 799, 860, 861, 862, 863, 864, 865, 866, 867,\n",
- " 868, 869, 890, 891, 892, 893, 894, 895, 896, 897, 898,\n",
- " 899, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919,\n",
- " 940, 941, 942, 943, 944, 945, 946, 947, 948, 949, 1080,\n",
- " 1081, 1082, 1083, 1084, 1085, 1086, 1087, 1088, 1089, 1220, 1221,\n",
- " 1222, 1223, 1224, 1225, 1226, 1227, 1228, 1229, 1310, 1311, 1312,\n",
- " 1313, 1314, 1315, 1316, 1317, 1318, 1319, 1380, 1381, 1382, 1383,\n",
- " 1384, 1385, 1386, 1387, 1388, 1389, 1430, 1431, 1432, 1433, 1434,\n",
- " 1435, 1436, 1437, 1438, 1439, 1510, 1511, 1512, 1513, 1514, 1515,\n",
- " 1516, 1517, 1518, 1519, 1520, 1521, 1522, 1523, 1524, 1525, 1526,\n",
- " 1527, 1528, 1529, 1590, 1591, 1592, 1593, 1594, 1595, 1596, 1597,\n",
- " 1598, 1599, 1640, 1641, 1642, 1643, 1644, 1645, 1646, 1647, 1648,\n",
- " 1649, 1670, 1671, 1672, 1673, 1674, 1675, 1676, 1677, 1678, 1679,\n",
- " 1720, 1721, 1722, 1723, 1724, 1725, 1726, 1727, 1728, 1729, 1810,\n",
- " 1811, 1812, 1813, 1814, 1815, 1816, 1817, 1818, 1819, 1830, 1831,\n",
- " 1832, 1833, 1834, 1835, 1836, 1837, 1838, 1839, 1860, 1861, 1862,\n",
- " 1863, 1864, 1865, 1866, 1867, 1868, 1869, 1900, 1901, 1902, 1903,\n",
- " 1904, 1905, 1906, 1907, 1908, 1909, 1940, 1941, 1942, 1943, 1944,\n",
- " 1945, 1946, 1947, 1948, 1949, 1960, 1961, 1962, 1963, 1964, 1965,\n",
- " 1966, 1967, 1968, 1969, 2020, 2021, 2022, 2023, 2024, 2025, 2026,\n",
- " 2027, 2028, 2029, 2060, 2061, 2062, 2063, 2064, 2065, 2066, 2067,\n",
- " 2068, 2069, 2150, 2151, 2152, 2153, 2154, 2155, 2156, 2157, 2158,\n",
- " 2159, 2190, 2191, 2192, 2193, 2194, 2195, 2196, 2197, 2198, 2199,\n",
- " 2220, 2221, 2222, 2223, 2224, 2225, 2226, 2227, 2228, 2229, 2270,\n",
- " 2271, 2272, 2273, 2274, 2275, 2276, 2277, 2278, 2279, 2290, 2291,\n",
- " 2292, 2293, 2294, 2295, 2296, 2297, 2298, 2299, 2420, 2421, 2422,\n",
- " 2423, 2424, 2425, 2426, 2427, 2428, 2429, 2430, 2431, 2432, 2433,\n",
- " 2434, 2435, 2436, 2437, 2438, 2439, 2460, 2461, 2462, 2463, 2464,\n",
- " 2465, 2466, 2467, 2468, 2469]))]]"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "dml_panel_plpr.smpls"
+ "panel_data_obj = DoubleMLPanelData(data,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " t_col='time',\n",
+ " id_col='id',\n",
+ " x_cols=x_cols,\n",
+ " static_panel=True)\n",
+ "\n",
+ "dml_plpr_obj = DoubleMLPLPR(panel_data_obj, ml_l=ml_l, ml_m=ml_m, approach='cre_general', n_folds=5)\n",
+ " \n",
+ "dml_plpr_obj.fit()\n",
+ "print(dml_plpr_obj.summary)"
]
},
{
@@ -477,8 +219,8 @@
},
{
"cell_type": "code",
- "execution_count": 12,
- "id": "11bbc988",
+ "execution_count": 6,
+ "id": "48d4dbd8",
"metadata": {},
"outputs": [
{
@@ -490,7 +232,7 @@
"------------------ Data Summary ------------------\n",
"Outcome variable: y\n",
"Treatment variable(s): ['d']\n",
- "Covariates: ['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15', 'x16', 'x17', 'x18', 'x19', 'x20', 'x21', 'x22', 'x23', 'x24', 'x25', 'x26', 'x27', 'x28', 'x29', 'x30', 'm_x1', 'm_x2', 'm_x3', 'm_x4', 'm_x5', 'm_x6', 'm_x7', 'm_x8', 'm_x9', 'm_x10', 'm_x11', 'm_x12', 'm_x13', 'm_x14', 'm_x15', 'm_x16', 'm_x17', 'm_x18', 'm_x19', 'm_x20', 'm_x21', 'm_x22', 'm_x23', 'm_x24', 'm_x25', 'm_x26', 'm_x27', 'm_x28', 'm_x29', 'm_x30']\n",
+ "Covariates: ['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15', 'x16', 'x17', 'x18', 'x19', 'x20', 'x21', 'x22', 'x23', 'x24', 'x25', 'x26', 'x27', 'x28', 'x29', 'x30']\n",
"Instrument variable(s): None\n",
"Time variable: time\n",
"Id variable: id\n",
@@ -508,8 +250,8 @@
"Learner ml_m: LassoCV()\n",
"Out-of-sample Performance:\n",
"Regression:\n",
- "Learner ml_l RMSE: [[1.72960098]]\n",
- "Learner ml_m RMSE: [[0.95035703]]\n",
+ "Learner ml_l RMSE: [[1.76970493]]\n",
+ "Learner ml_m RMSE: [[0.94373673]]\n",
"\n",
"------------------ Resampling ------------------\n",
"No. folds per cluster: 5\n",
@@ -518,93 +260,91 @@
"\n",
"------------------ Fit Summary ------------------\n",
" coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.498106 0.020829 23.913515 2.215861e-126 0.457281 0.538931\n"
+ "d 0.480439 0.020407 23.542639 1.493410e-122 0.440442 0.520436\n",
+ "\n",
+ "------------------ Additional Information ------------------\n",
+ "Post Transformation Data Summary:\n",
+ "\n",
+ "Outcome variable: y\n",
+ "Treatment variable(s): ['d']\n",
+ "Covariates: ['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15', 'x16', 'x17', 'x18', 'x19', 'x20', 'x21', 'x22', 'x23', 'x24', 'x25', 'x26', 'x27', 'x28', 'x29', 'x30', 'x1_mean', 'x2_mean', 'x3_mean', 'x4_mean', 'x5_mean', 'x6_mean', 'x7_mean', 'x8_mean', 'x9_mean', 'x10_mean', 'x11_mean', 'x12_mean', 'x13_mean', 'x14_mean', 'x15_mean', 'x16_mean', 'x17_mean', 'x18_mean', 'x19_mean', 'x20_mean', 'x21_mean', 'x22_mean', 'x23_mean', 'x24_mean', 'x25_mean', 'x26_mean', 'x27_mean', 'x28_mean', 'x29_mean', 'x30_mean']\n",
+ "No. Observations: 2500\n",
+ "\n"
]
}
],
"source": [
- "print(dml_panel_plpr)"
+ "print(dml_plpr_obj)"
]
},
{
"cell_type": "code",
- "execution_count": 9,
- "id": "48d4dbd8",
+ "execution_count": 16,
+ "id": "24f06d62",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "================== DoubleMLPLPR Object ==================\n",
- "\n",
- "------------------ Data summary ------------------\n",
- "Outcome variable: y\n",
- "Treatment variable(s): ['d']\n",
- "Covariates: ['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15', 'x16', 'x17', 'x18', 'x19', 'x20', 'x21', 'x22', 'x23', 'x24', 'x25', 'x26', 'x27', 'x28', 'x29', 'x30', 'm_x1', 'm_x2', 'm_x3', 'm_x4', 'm_x5', 'm_x6', 'm_x7', 'm_x8', 'm_x9', 'm_x10', 'm_x11', 'm_x12', 'm_x13', 'm_x14', 'm_x15', 'm_x16', 'm_x17', 'm_x18', 'm_x19', 'm_x20', 'm_x21', 'm_x22', 'm_x23', 'm_x24', 'm_x25', 'm_x26', 'm_x27', 'm_x28', 'm_x29', 'm_x30']\n",
- "Instrument variable(s): None\n",
- "Time variable: time\n",
- "Id variable: id\n",
- "Static panel data: True\n",
- "No. Unique Ids: 250\n",
- "No. Observations: 2500\n",
- "\n",
- "------------------ Score & algorithm ------------------\n",
- "Score function: partialling out\n",
- "Static panel model approach: cre_general\n",
- "\n",
- "------------------ Machine learner ------------------\n",
- "Learner ml_l: LassoCV()\n",
- "Learner ml_m: LassoCV()\n",
- "Out-of-sample Performance:\n",
- "Regression:\n",
- "Learner ml_l RMSE: [[1.72960098]]\n",
- "Learner ml_m RMSE: [[0.95035703]]\n",
- "\n",
- "------------------ Resampling ------------------\n",
- "No. folds per cluster: 5\n",
- "No. folds: 5\n",
- "No. repeated sample splits: 1\n",
- "\n",
- "------------------ Fit summary ------------------\n",
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.498106 0.020829 23.913515 2.215861e-126 0.457281 0.538931\n"
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d 0.503726 0.022932 21.965879 6.106437e-107 0.45878 0.548672\n"
]
}
],
"source": [
- "print(dml_panel_plpr)"
+ "# cre normality assumption\n",
+ "data = make_static_panel_CP2025(dgp_type='dgp1')\n",
+ "cre_data = cre_fct(data)\n",
+ "\n",
+ "x_cols = [col for col in cre_data.columns if \"x\" in col]\n",
+ "\n",
+ "panel_data_obj = DoubleMLPanelData(cre_data,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " t_col='time',\n",
+ " id_col='id',\n",
+ " x_cols=x_cols,\n",
+ " static_panel=True)\n",
+ "\n",
+ "# learner = LassoCV()\n",
+ "learner = make_pipeline(StandardScaler(), LassoCV())\n",
+ "\n",
+ "ml_l = clone(learner)\n",
+ "ml_m = clone(learner)\n",
+ "\n",
+ "dml_plpr_obj = DoubleMLPLPR(panel_data_obj, ml_l, ml_m, approach='cre_normal')\n",
+ "dml_plpr_obj.fit()\n",
+ "print(dml_plpr_obj.summary)"
]
},
{
"cell_type": "code",
- "execution_count": 16,
- "id": "24f06d62",
+ "execution_count": null,
+ "id": "135a38ab",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.503726 0.022932 21.965879 6.106437e-107 0.45878 0.548672\n"
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d 0.535179 0.020567 26.021098 2.858305e-149 0.494868 0.57549\n"
]
}
],
"source": [
- "# cre normality assumption\n",
"data = make_static_panel_CP2025(dgp_type='dgp1')\n",
- "cre_data = cre_fct(data)\n",
"\n",
- "x_cols = [col for col in cre_data.columns if \"x\" in col]\n",
+ "x_cols = [col for col in data.columns if \"x\" in col]\n",
"\n",
- "obj_dml_data_pdml = DoubleMLPanelData(cre_data,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " t_col='time',\n",
- " id_col='id',\n",
- " x_cols=x_cols,\n",
- " static_panel=True)\n",
+ "panel_data_obj = DoubleMLPanelData(data,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " t_col='time',\n",
+ " id_col='id',\n",
+ " x_cols=x_cols,\n",
+ " static_panel=True)\n",
"\n",
"# learner = LassoCV()\n",
"learner = make_pipeline(StandardScaler(), LassoCV())\n",
@@ -612,9 +352,9 @@
"ml_l = clone(learner)\n",
"ml_m = clone(learner)\n",
"\n",
- "obj_dml_plpr = DoubleMLPLPR(obj_dml_data_pdml, ml_l, ml_m, approach='cre_normal')\n",
- "obj_dml_plpr.fit()\n",
- "print(obj_dml_plpr.summary)"
+ "dml_plpr_obj = DoubleMLPLPR(panel_data_obj, ml_l, ml_m, approach='cre_normal')\n",
+ "dml_plpr_obj.fit()\n",
+ "print(dml_plpr_obj.summary)"
]
},
{
@@ -636,63 +376,57 @@
"data = make_static_panel_CP2025(dgp_type='dgp1')\n",
"fd_data = fd_fct(data)\n",
"\n",
- "obj_dml_data_pdml = DoubleMLPanelData(fd_data,\n",
- " y_col='y_diff',\n",
- " d_cols='d_diff',\n",
- " t_col='time',\n",
- " id_col='id',\n",
- " x_cols=[col for col in fd_data.columns if col.startswith(\"x\")],\n",
- " static_panel=True)\n",
+ "panel_data_obj = DoubleMLPanelData(fd_data,\n",
+ " y_col='y_diff',\n",
+ " d_cols='d_diff',\n",
+ " t_col='time',\n",
+ " id_col='id',\n",
+ " x_cols=[col for col in fd_data.columns if col.startswith(\"x\")],\n",
+ " static_panel=True)\n",
"\n",
"learner = LassoCV()\n",
"ml_l = clone(learner)\n",
"ml_m = clone(learner)\n",
"\n",
- "obj_dml_plpr = DoubleMLPLPR(obj_dml_data_pdml, ml_l, ml_m, approach='fd_exact')\n",
- "obj_dml_plpr.fit()\n",
- "print(obj_dml_plpr.summary)"
+ "dml_plpr_obj = DoubleMLPLPR(panel_data_obj, ml_l, ml_m, approach='fd_exact')\n",
+ "dml_plpr_obj.fit()\n",
+ "print(dml_plpr_obj.summary)"
]
},
{
"cell_type": "code",
- "execution_count": 5,
- "id": "ff0e322b",
+ "execution_count": null,
+ "id": "d869e493",
"metadata": {},
"outputs": [
{
- "ename": "AssertionError",
- "evalue": "",
- "output_type": "error",
- "traceback": [
- "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
- "\u001b[31mAssertionError\u001b[39m Traceback (most recent call last)",
- "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[5]\u001b[39m\u001b[32m, line 16\u001b[39m\n\u001b[32m 13\u001b[39m ml_m = clone(learner)\n\u001b[32m 15\u001b[39m obj_dml_plpr = DoubleMLPLPR(obj_dml_data_pdml, ml_l, ml_m, approach=\u001b[33m'\u001b[39m\u001b[33mfd_exact\u001b[39m\u001b[33m'\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m16\u001b[39m \u001b[43mobj_dml_plpr\u001b[49m\u001b[43m.\u001b[49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 17\u001b[39m \u001b[38;5;28mprint\u001b[39m(obj_dml_plpr.summary)\n",
- "\u001b[36mFile \u001b[39m\u001b[32m~\\PycharmProjects\\doubleml-for-py\\doubleml\\double_ml.py:570\u001b[39m, in \u001b[36mDoubleML.fit\u001b[39m\u001b[34m(self, n_jobs_cv, store_predictions, external_predictions, store_models)\u001b[39m\n\u001b[32m 567\u001b[39m \u001b[38;5;28mself\u001b[39m._dml_data.set_x_d(\u001b[38;5;28mself\u001b[39m._dml_data.d_cols[i_d])\n\u001b[32m 569\u001b[39m \u001b[38;5;66;03m# predictions have to be stored in loop for sensitivity analysis\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m570\u001b[39m nuisance_predictions = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_fit_nuisance_and_score_elements\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 571\u001b[39m \u001b[43m \u001b[49m\u001b[43mn_jobs_cv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstore_predictions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mexternal_predictions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstore_models\u001b[49m\n\u001b[32m 572\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 574\u001b[39m \u001b[38;5;28mself\u001b[39m._solve_score_and_estimate_se()\n\u001b[32m 576\u001b[39m \u001b[38;5;66;03m# sensitivity elements can depend on the estimated parameter\u001b[39;00m\n",
- "\u001b[36mFile \u001b[39m\u001b[32m~\\PycharmProjects\\doubleml-for-py\\doubleml\\double_ml.py:1088\u001b[39m, in \u001b[36mDoubleML._fit_nuisance_and_score_elements\u001b[39m\u001b[34m(self, n_jobs_cv, store_predictions, external_predictions, store_models)\u001b[39m\n\u001b[32m 1083\u001b[39m ext_prediction_dict = _set_external_predictions(\n\u001b[32m 1084\u001b[39m external_predictions, learners=\u001b[38;5;28mself\u001b[39m.params_names, treatment=\u001b[38;5;28mself\u001b[39m._dml_data.d_cols[\u001b[38;5;28mself\u001b[39m._i_treat], i_rep=\u001b[38;5;28mself\u001b[39m._i_rep\n\u001b[32m 1085\u001b[39m )\n\u001b[32m 1087\u001b[39m \u001b[38;5;66;03m# ml estimation of nuisance models and computation of score elements\u001b[39;00m\n\u001b[32m-> \u001b[39m\u001b[32m1088\u001b[39m score_elements, preds = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_nuisance_est\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 1089\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m__smpls\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_jobs_cv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mexternal_predictions\u001b[49m\u001b[43m=\u001b[49m\u001b[43mext_prediction_dict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_models\u001b[49m\u001b[43m=\u001b[49m\u001b[43mstore_models\u001b[49m\n\u001b[32m 1090\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1092\u001b[39m \u001b[38;5;28mself\u001b[39m._set_score_elements(score_elements, \u001b[38;5;28mself\u001b[39m._i_rep, \u001b[38;5;28mself\u001b[39m._i_treat)\n\u001b[32m 1094\u001b[39m \u001b[38;5;66;03m# calculate nuisance losses and store predictions and targets of the nuisance models\u001b[39;00m\n",
- "\u001b[36mFile \u001b[39m\u001b[32m~\\PycharmProjects\\doubleml-for-py\\doubleml\\plm\\plpr.py:254\u001b[39m, in \u001b[36mDoubleMLPLPR._nuisance_est\u001b[39m\u001b[34m(self, smpls, n_jobs_cv, external_predictions, return_models)\u001b[39m\n\u001b[32m 252\u001b[39m l_hat = {\u001b[33m\"\u001b[39m\u001b[33mpreds\u001b[39m\u001b[33m\"\u001b[39m: \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[33m\"\u001b[39m\u001b[33mtargets\u001b[39m\u001b[33m\"\u001b[39m: \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[33m\"\u001b[39m\u001b[33mmodels\u001b[39m\u001b[33m\"\u001b[39m: \u001b[38;5;28;01mNone\u001b[39;00m}\n\u001b[32m 253\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m254\u001b[39m l_hat = \u001b[43m_dml_cv_predict\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 255\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_learner\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mml_l\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 256\u001b[39m \u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 257\u001b[39m \u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 258\u001b[39m \u001b[43m \u001b[49m\u001b[43msmpls\u001b[49m\u001b[43m=\u001b[49m\u001b[43msmpls\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 259\u001b[39m \u001b[43m \u001b[49m\u001b[43mn_jobs\u001b[49m\u001b[43m=\u001b[49m\u001b[43mn_jobs_cv\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 260\u001b[39m \u001b[43m \u001b[49m\u001b[43mest_params\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_get_params\u001b[49m\u001b[43m(\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mml_l\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 261\u001b[39m \u001b[43m 
\u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_predict_method\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mml_l\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 262\u001b[39m \u001b[43m \u001b[49m\u001b[43mreturn_models\u001b[49m\u001b[43m=\u001b[49m\u001b[43mreturn_models\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 263\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 264\u001b[39m _check_finite_predictions(l_hat[\u001b[33m\"\u001b[39m\u001b[33mpreds\u001b[39m\u001b[33m\"\u001b[39m], \u001b[38;5;28mself\u001b[39m._learner[\u001b[33m\"\u001b[39m\u001b[33mml_l\u001b[39m\u001b[33m\"\u001b[39m], \u001b[33m\"\u001b[39m\u001b[33mml_l\u001b[39m\u001b[33m\"\u001b[39m, smpls)\n\u001b[32m 266\u001b[39m \u001b[38;5;66;03m# nuisance m\u001b[39;00m\n",
- "\u001b[36mFile \u001b[39m\u001b[32m~\\PycharmProjects\\doubleml-for-py\\doubleml\\utils\\_estimation.py:76\u001b[39m, in \u001b[36m_dml_cv_predict\u001b[39m\u001b[34m(estimator, x, y, smpls, n_jobs, est_params, method, return_train_preds, return_models)\u001b[39m\n\u001b[32m 74\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m smpls_is_partition:\n\u001b[32m 75\u001b[39m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m fold_specific_target, \u001b[33m\"\u001b[39m\u001b[33mcombination of fold-specific y and no cross-fitting not implemented yet\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m---> \u001b[39m\u001b[32m76\u001b[39m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(smpls) == \u001b[32m1\u001b[39m\n\u001b[32m 78\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m method == \u001b[33m\"\u001b[39m\u001b[33mpredict_proba\u001b[39m\u001b[33m\"\u001b[39m:\n\u001b[32m 79\u001b[39m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m fold_specific_target \u001b[38;5;66;03m# fold_specific_target only needed for PLIV.partialXZ\u001b[39;00m\n",
- "\u001b[31mAssertionError\u001b[39m: "
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d 0.493737 0.024198 20.403722 1.549613e-92 0.446309 0.541165\n"
]
}
],
"source": [
"data = make_static_panel_CP2025(dgp_type='dgp1')\n",
"\n",
- "obj_dml_data_pdml = DoubleMLPanelData(data,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " t_col='time',\n",
- " id_col='id',\n",
- " x_cols=[col for col in data.columns if col.startswith(\"x\")],\n",
- " static_panel=True)\n",
+ "panel_data_obj = DoubleMLPanelData(data,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " t_col='time',\n",
+ " id_col='id',\n",
+ " x_cols=[col for col in data.columns if col.startswith(\"x\")],\n",
+ " static_panel=True)\n",
"\n",
"learner = LassoCV()\n",
"ml_l = clone(learner)\n",
"ml_m = clone(learner)\n",
"\n",
- "obj_dml_plpr = DoubleMLPLPR(obj_dml_data_pdml, ml_l, ml_m, approach='fd_exact')\n",
- "obj_dml_plpr.fit()\n",
- "print(obj_dml_plpr.summary)"
+ "dml_plpr_obj = DoubleMLPLPR(panel_data_obj, ml_l, ml_m, approach='fd_exact')\n",
+ "\n",
+ "dml_plpr_obj.fit()\n",
+ "print(dml_plpr_obj.summary)"
]
},
{
@@ -714,21 +448,56 @@
"data = make_static_panel_CP2025(dgp_type='dgp1')\n",
"wd_data = wd_fct(data)\n",
"\n",
- "obj_dml_data_pdml = DoubleMLPanelData(wd_data,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " t_col='time',\n",
- " id_col='id',\n",
- " x_cols=[col for col in wd_data.columns if col.startswith(\"x\")],\n",
- " static_panel=True)\n",
+ "panel_data_obj = DoubleMLPanelData(wd_data,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " t_col='time',\n",
+ " id_col='id',\n",
+ " x_cols=[col for col in wd_data.columns if col.startswith(\"x\")],\n",
+ " static_panel=True)\n",
"\n",
"learner = LassoCV()\n",
"ml_l = clone(learner)\n",
"ml_m = clone(learner)\n",
"\n",
- "obj_dml_plpr = DoubleMLPLPR(obj_dml_data_pdml, ml_l, ml_m, approach='wg_approx')\n",
- "obj_dml_plpr.fit()\n",
- "print(obj_dml_plpr.summary)"
+ "dml_plpr_obj = DoubleMLPLPR(panel_data_obj, ml_l, ml_m, approach='wg_approx')\n",
+ "dml_plpr_obj.fit()\n",
+ "print(dml_plpr_obj.summary)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "579b43fa",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d 0.498587 0.020292 24.570284 2.625882e-133 0.458815 0.538359\n"
+ ]
+ }
+ ],
+ "source": [
+ "data = make_static_panel_CP2025(dgp_type='dgp1')\n",
+ "\n",
+ "panel_data_obj = DoubleMLPanelData(data,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " t_col='time',\n",
+ " id_col='id',\n",
+ " x_cols=[col for col in data.columns if col.startswith(\"x\")],\n",
+ " static_panel=True)\n",
+ "\n",
+ "learner = LassoCV()\n",
+ "ml_l = clone(learner)\n",
+ "ml_m = clone(learner)\n",
+ "\n",
+ "dml_plpr_obj = DoubleMLPLPR(panel_data_obj, ml_l, ml_m, approach='wg_approx')\n",
+ "dml_plpr_obj.fit()\n",
+ "print(dml_plpr_obj.summary)"
]
},
{
@@ -892,6 +661,158 @@
" index=['CRE general', 'CRE normal', \n",
" 'FD exact', 'WG approx'])"
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "id": "18127ee6",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Processing: 100.0 %"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Coef | \n",
+ " Bias | \n",
+ " Coverage | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | CRE general | \n",
+ " 0.513181 | \n",
+ " 0.013181 | \n",
+ " 0.93 | \n",
+ "
\n",
+ " \n",
+ " | CRE normal | \n",
+ " 0.534498 | \n",
+ " 0.034498 | \n",
+ " 0.87 | \n",
+ "
\n",
+ " \n",
+ " | FD exact | \n",
+ " 0.501095 | \n",
+ " 0.001095 | \n",
+ " 0.94 | \n",
+ "
\n",
+ " \n",
+ " | WG approx | \n",
+ " 0.497307 | \n",
+ " -0.002693 | \n",
+ " 0.94 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Coef Bias Coverage\n",
+ "CRE general 0.513181 0.013181 0.93\n",
+ "CRE normal 0.534498 0.034498 0.87\n",
+ "FD exact 0.501095 0.001095 0.94\n",
+ "WG approx 0.497307 -0.002693 0.94"
+ ]
+ },
+ "execution_count": 33,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# simulation with built-in transformations\n",
+ "\n",
+ "n_reps = 100\n",
+ "theta = 0.5\n",
+ "\n",
+ "learner = make_pipeline(StandardScaler(), LassoCV())\n",
+ "\n",
+ "res_cre_general = np.full((n_reps, 3), np.nan)\n",
+ "res_cre_normal = np.full((n_reps, 3), np.nan)\n",
+ "res_fd = np.full((n_reps, 3), np.nan)\n",
+ "res_wd = np.full((n_reps, 3), np.nan)\n",
+ "\n",
+ "np.random.seed(1)\n",
+ "\n",
+ "for i in range(n_reps):\n",
+ " print(f\"\\rProcessing: {round((i+1)/n_reps*100, 3)} %\", end=\"\")\n",
+ " data = make_static_panel_CP2025(num_n=100, theta=theta, dgp_type='dgp1')\n",
+ "\n",
+ " dml_data = DoubleMLPanelData(data, y_col='y', d_cols='d', t_col='time', id_col='id', \n",
+ " x_cols=[col for col in data.columns if \"x\" in col],\n",
+ " static_panel=True)\n",
+ " \n",
+ " # CRE general Lasso\n",
+ " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
+ " approach='cre_general')\n",
+ " dml_plpr.fit()\n",
+ " res_cre_general[i, 0] = dml_plpr.coef[0]\n",
+ " res_cre_general[i, 1] = dml_plpr.coef[0] - theta\n",
+ " confint = dml_plpr.confint()\n",
+ " res_cre_general[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ "\n",
+ " # CRE normality\n",
+ " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
+ " approach='cre_normal')\n",
+ " dml_plpr.fit()\n",
+ " res_cre_normal[i, 0] = dml_plpr.coef[0]\n",
+ " res_cre_normal[i, 1] = dml_plpr.coef[0] - theta\n",
+ " confint = dml_plpr.confint()\n",
+ " res_cre_normal[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ "\n",
+ " # FD approach\n",
+ " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
+ " approach='fd_exact')\n",
+ " dml_plpr.fit()\n",
+ " res_fd[i, 0] = dml_plpr.coef[0]\n",
+ " res_fd[i, 1] = dml_plpr.coef[0] - theta\n",
+ " confint = dml_plpr.confint()\n",
+ " res_fd[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ " \n",
+ " # WD approach, for now need new data object as FD approach overwrites _cluster_vars \n",
+ " dml_data = DoubleMLPanelData(data, y_col='y', d_cols='d', t_col='time', id_col='id', \n",
+ " x_cols=[col for col in data.columns if \"x\" in col],\n",
+ " static_panel=True)\n",
+ "\n",
+ " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
+ " approach='wg_approx')\n",
+ " dml_plpr.fit()\n",
+ " res_wd[i, 0] = dml_plpr.coef[0]\n",
+ " res_wd[i, 1] = dml_plpr.coef[0] - theta\n",
+ " confint = dml_plpr.confint()\n",
+ " res_wd[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ "\n",
+ "\n",
+ "pd.DataFrame(np.vstack([res_cre_general.mean(axis=0), res_cre_normal.mean(axis=0), \n",
+ " res_fd.mean(axis=0), res_wd.mean(axis=0)]), \n",
+ " columns=['Coef', 'Bias', 'Coverage'], \n",
+ " index=['CRE general', 'CRE normal', \n",
+ " 'FD exact', 'WG approx'])"
+ ]
}
],
"metadata": {
@@ -910,7 +831,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.12.6"
+ "version": "3.13.9"
}
},
"nbformat": 4,
From 3ccfb34060759af95fc82014de2ca74c95f6fa3d Mon Sep 17 00:00:00 2001
From: Julian Diefenbacher
<93731825+JulianDiefenbacher@users.noreply.github.com>
Date: Tue, 11 Nov 2025 17:25:09 +0100
Subject: [PATCH 20/33] clearer TODO description
---
doubleml/plm/plpr.py | 9 +++++----
1 file changed, 5 insertions(+), 4 deletions(-)
diff --git a/doubleml/plm/plpr.py b/doubleml/plm/plpr.py
index c301b6468..18a45518d 100644
--- a/doubleml/plm/plpr.py
+++ b/doubleml/plm/plpr.py
@@ -272,12 +272,13 @@ def _initialize_fd_model(self):
if self._approach == "fd_exact":
self._smpls = None
self._smpls_cluster = None
- # TODO: # overwrites data property _cluster_vars, but then data object can't be reused with other approach
- # when using a new model specific _cluster_vars_fd, _se_causal_pars() does not run as it uses
- # self._dml_data.cluster_vars, where n_obs dimension does not match psi arrays.
+ # TODO: currently overwrites data property _cluster_vars, but then the data object can't be reused with other approaches.
+ # When using a new model specific _cluster_vars_fd, like
+ # self._cluster_vars_fd = self._data_transform.loc[:, self._dml_data.cluster_cols].values
+ # _se_causal_pars() does not run anymore as it uses self._dml_data.cluster_vars, where n_obs dimension does not match
+ # dimension of psi arrays.
# overwrite _se_causal_pars?
self._dml_data._cluster_vars = self._data_transform.loc[:, self._dml_data.cluster_cols]
- self._cluster_vars_fd = self._data_transform.loc[:, self._dml_data.cluster_cols].values
# initialize model again
self._score_dim = (self._data_transform.shape[0], self.n_rep, self._dml_data.n_coefs)
From ab9f83f2cd445506f7a490bd66d583e4f1d9563e Mon Sep 17 00:00:00 2001
From: JulianDiefenbacher
Date: Thu, 13 Nov 2025 10:09:00 +0100
Subject: [PATCH 21/33] move data transformation before init
---
doubleml/plm/plpr.py | 139 ++++----
doubleml/plm/sim/example_sim.ipynb | 515 ++++++++++++++++-------------
2 files changed, 345 insertions(+), 309 deletions(-)
diff --git a/doubleml/plm/plpr.py b/doubleml/plm/plpr.py
index 18a45518d..768850c75 100644
--- a/doubleml/plm/plpr.py
+++ b/doubleml/plm/plpr.py
@@ -10,7 +10,6 @@
from ..double_ml_score_mixins import LinearScoreMixin
from ..utils._checks import _check_binary_predictions, _check_finite_predictions, _check_is_propensity, _check_score
from ..utils._estimation import _dml_cv_predict, _dml_tune
-from ..utils.resampling import DoubleMLClusterResampling
class DoubleMLPLPR(LinearScoreMixin, DoubleML):
@@ -86,17 +85,32 @@ def __init__(
approach="fd_exact",
draw_sample_splitting=True,
):
- super().__init__(obj_dml_data, n_folds, n_rep, score, draw_sample_splitting)
-
- self._check_data(self._dml_data)
- # TODO: assert cluster?
- valid_scores = ["IV-type", "partialling out"]
- _check_score(self.score, valid_scores, allow_callable=True)
+ self._check_data(obj_dml_data)
+ self._original_dml_data = obj_dml_data
valid_approach = ["cre_general", "cre_normal", "fd_exact", "wg_approx"]
self._check_approach(approach, valid_approach)
self._approach = approach
+ # pass transformed data as DoubleMLPanelData to init
+ self._data_transform, self._transform_cols = self._transform_data()
+ obj_dml_data_transform = DoubleMLPanelData(
+ self._data_transform,
+ y_col=self._transform_cols["y_col"],
+ d_cols=self._transform_cols["d_cols"],
+ t_col=self._original_dml_data._t_col,
+ id_col=self._original_dml_data._id_col,
+ x_cols=self._transform_cols["x_cols"],
+ z_cols=self._original_dml_data._z_cols,
+ static_panel=True,
+ use_other_treat_as_covariate=self._original_dml_data._use_other_treat_as_covariate,
+ force_all_x_finite=self._original_dml_data._force_all_x_finite,
+ )
+ super().__init__(obj_dml_data_transform, n_folds, n_rep, score, draw_sample_splitting)
+
+ valid_scores = ["IV-type", "partialling out"]
+ _check_score(self.score, valid_scores, allow_callable=True)
+ # TODO: update learner checks
_ = self._check_learner(ml_l, "ml_l", regressor=True, classifier=False)
ml_m_is_classifier = self._check_learner(ml_m, "ml_m", regressor=True, classifier=True)
self._learner = {"ml_l": ml_l, "ml_m": ml_m}
@@ -135,42 +149,27 @@ def __init__(
self._sensitivity_implemented = False
self._external_predictions_implemented = True
- # get transformed data depending on approach
- (
- self._data_transform,
- self._transform_col_names,
- ) = self._transform_data(self._approach)
- # save transformed data parts for ML estimation
- # TODO: check d_cols dimension issue, for now as for panel data only one treatment allowed currently
- self._y_data_transform = self.data_transform.loc[:, self.transform_col_names["y_col"]].values
- self._d_data_transform = self.data_transform.loc[:, self.transform_col_names["d_cols"]].values.flatten()
- self._x_data_transform = self.data_transform.loc[:, self.transform_col_names["x_cols"]].values
-
- # TODO: for fd_exact, n_obs changes, smpls originally drawn are not working anymore
- self._n_obs_transform = self._data_transform.shape[0]
- self._initialize_fd_model()
-
def _format_score_info_str(self):
score_approach_info = f"Score function: {str(self.score)}\n" f"Static panel model approach: {str(self.approach)}"
return score_approach_info
def _format_additional_info_str(self):
"""
- Includes information on the transformed features based on the estimation approach.
+ Includes information on the original data before transformation.
"""
- data_transform_summary = (
- f"Post Transformation Data Summary:\n\n"
- f"Outcome variable: {self.transform_col_names['y_col']}\n"
- f"Treatment variable(s): {self.transform_col_names['d_cols']}\n"
- f"Covariates: {self.transform_col_names['x_cols']}\n"
- f"No. Observations: {self._n_obs_transform}\n"
+ data_original_summary = (
+ f"Original Data Summary Pre-transformation:\n\n"
+ f"Outcome variable: {self._original_dml_data.y_col}\n"
+ f"Treatment variable(s): {self._original_dml_data.d_cols}\n"
+ f"Covariates: {self._original_dml_data.x_cols}\n"
+ f"No. Observations: {self._original_dml_data.n_obs}\n"
)
- return data_transform_summary
+ return data_original_summary
@property
def approach(self):
"""
- The score function.
+ The static panel approach.
"""
return self._approach
@@ -182,18 +181,11 @@ def data_transform(self):
return self._data_transform
@property
- def transform_col_names(self):
+ def transform_cols(self):
"""
The column names of the transformed static panel data.
"""
- return self._transform_col_names
-
- @property
- def n_obs_transform(self):
- """
- The number of observations after data transformation.
- """
- return self._n_obs_transform
+ return self._transform_cols
def _initialize_ml_nuisance_params(self):
self._params = {learner: {key: [None] * self.n_rep for key in self._dml_data.d_cols} for learner in self._learner}
@@ -220,21 +212,21 @@ def _check_approach(self, approach, valid_approach):
raise TypeError(f"approach should be a string. {str(approach)} was passed.")
return
- def _transform_data(self, approach):
- df = self._dml_data.data.copy()
+ def _transform_data(self):
+ df = self._original_dml_data.data.copy()
- y_col = self._dml_data.y_col
- d_cols = self._dml_data.d_cols
- x_cols = self._dml_data.x_cols
- t_col = self._dml_data.t_col
- id_col = self._dml_data.id_col
+ y_col = self._original_dml_data.y_col
+ d_cols = self._original_dml_data.d_cols
+ x_cols = self._original_dml_data.x_cols
+ t_col = self._original_dml_data.t_col
+ id_col = self._original_dml_data.id_col
- if approach in ["cre_general", "cre_normal"]:
+ if self._approach in ["cre_general", "cre_normal"]:
df_id_means = df[[id_col] + x_cols].groupby(id_col).transform("mean")
df_means = df_id_means.add_suffix("_mean")
data = pd.concat([df, df_means], axis=1)
- col_names = {"y_col": y_col, "d_cols": d_cols, "x_cols": x_cols + [f"{x}_mean" for x in x_cols]}
- elif approach == "fd_exact":
+ cols = {"y_col": y_col, "d_cols": d_cols, "x_cols": x_cols + [f"{x}_mean" for x in x_cols]}
+ elif self._approach == "fd_exact":
# TODO: potential issues with unbalanced panels/missing periods, right now the
# last available is used for the lag and first difference. Maybe reindex to a complete time grid per id.
df = df.sort_values([id_col, t_col])
@@ -247,12 +239,12 @@ def _transform_data(self, approach):
df_fd = df_fd.rename(columns=cols_rename_dict)
# drop rows for first period
data = df_fd.dropna(subset=[x_cols[0] + "_lag"]).reset_index(drop=True)
- col_names = {
+ cols = {
"y_col": f"{y_col}_diff",
"d_cols": [f"{d}_diff" for d in d_cols],
"x_cols": x_cols + [f"{x}_lag" for x in x_cols],
}
- elif approach == "wg_approx":
+ elif self._approach == "wg_approx":
cols_to_demean = [y_col] + d_cols + x_cols
# compute group and grand means for within means
group_means = df.groupby(id_col)[cols_to_demean].transform("mean")
@@ -260,42 +252,17 @@ def _transform_data(self, approach):
within_means = df[cols_to_demean] - group_means + grand_means
within_means = within_means.add_suffix("_demean")
data = pd.concat([df[[id_col, t_col]], within_means], axis=1)
- col_names = {
+ cols = {
"y_col": f"{y_col}_demean",
"d_cols": [f"{d}_demean" for d in d_cols],
"x_cols": [f"{x}_demean" for x in x_cols],
}
- return data, col_names
-
- def _initialize_fd_model(self):
- if self._approach == "fd_exact":
- self._smpls = None
- self._smpls_cluster = None
- # TODO: currently overwrites data property _cluster_vars, but then the data object can't be reused with other approaches.
- # When using a new model specific _cluster_vars_fd, like
- # self._cluster_vars_fd = self._data_transform.loc[:, self._dml_data.cluster_cols].values
- # _se_causal_pars() does not run anymore as it uses self._dml_data.cluster_vars, where n_obs dimension does not match
- # dimension of psi arrays.
- # overwrite _se_causal_pars?
- self._dml_data._cluster_vars = self._data_transform.loc[:, self._dml_data.cluster_cols]
-
- # initialize model again
- self._score_dim = (self._data_transform.shape[0], self.n_rep, self._dml_data.n_coefs)
- self._initialize_dml_model()
- # draw smpls for first difference transformed data
- obj_dml_resampling = DoubleMLClusterResampling(
- n_folds=self._n_folds_per_cluster,
- n_rep=self.n_rep,
- n_obs=self._n_obs_transform,
- n_cluster_vars=self._dml_data.n_cluster_vars,
- cluster_vars=self._dml_data.cluster_vars,
- )
- self._smpls, self._smpls_cluster = obj_dml_resampling.split_samples()
+ return data, cols
def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=False):
- x, y = check_X_y(self._x_data_transform, self._y_data_transform, force_all_finite=False)
- x, d = check_X_y(x, self._d_data_transform, force_all_finite=False)
+ x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False)
+ x, d = check_X_y(x, self._dml_data.d, force_all_finite=False)
m_external = external_predictions["ml_m"] is not None
l_external = external_predictions["ml_l"] is not None
if "ml_g" in self._learner:
@@ -413,7 +380,15 @@ def _sensitivity_element_est(self, preds):
pass
def _nuisance_tuning(
- self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search
+ # TODO: include mean_x for cre approach
+ self,
+ smpls,
+ param_grids,
+ scoring_methods,
+ n_folds_tune,
+ n_jobs_cv,
+ search_mode,
+ n_iter_randomized_search,
):
x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False)
x, d = check_X_y(x, self._dml_data.d, force_all_finite=False)
diff --git a/doubleml/plm/sim/example_sim.ipynb b/doubleml/plm/sim/example_sim.ipynb
index 16f4036e5..e4eb0c6f0 100644
--- a/doubleml/plm/sim/example_sim.ipynb
+++ b/doubleml/plm/sim/example_sim.ipynb
@@ -83,6 +83,8 @@
"from sklearn.linear_model import LassoCV, LinearRegression\n",
"from sklearn.base import clone\n",
"from sklearn.preprocessing import StandardScaler\n",
+ "from sklearn.preprocessing import PolynomialFeatures\n",
+ "from sklearn.compose import ColumnTransformer\n",
"from sklearn.pipeline import make_pipeline\n",
"from doubleml.plm.utils._plpr_util import cre_fct, fd_fct, wd_fct, extend_data\n",
"from doubleml.plm.datasets.dgp_static_panel_CP2025 import make_static_panel_CP2025\n",
@@ -92,7 +94,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 3,
"id": "3c061cf1",
"metadata": {},
"outputs": [
@@ -104,7 +106,7 @@
" np.float64(0.7348254130670423))"
]
},
- "execution_count": 2,
+ "execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
@@ -125,48 +127,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
- "id": "73c6599b",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.485919 0.021214 22.905716 4.075038e-116 0.44434 0.527497\n"
- ]
- }
- ],
- "source": [
- "# cre general\n",
- "\n",
- "# np.random.seed(1)\n",
- "data = make_static_panel_CP2025(dgp_type='dgp1')\n",
- "cre_data = cre_fct(data)\n",
- "\n",
- "x_cols = [col for col in cre_data.columns if \"x\" in col]\n",
- "\n",
- "learner = LassoCV()\n",
- "ml_l = clone(learner)\n",
- "ml_m = clone(learner)\n",
- "\n",
- "panel_data_obj = DoubleMLPanelData(cre_data,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " t_col='time',\n",
- " id_col='id',\n",
- " x_cols=x_cols,\n",
- " static_panel=True)\n",
- "\n",
- "dml_plpr_obj = DoubleMLPLPR(panel_data_obj, ml_l=ml_l, ml_m=ml_m, approach='cre_general', n_folds=5)\n",
- "dml_plpr_obj.fit()\n",
- "print(dml_plpr_obj.summary)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
"id": "44387af3",
"metadata": {},
"outputs": [
@@ -175,14 +136,14 @@
"output_type": "stream",
"text": [
" coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.480439 0.020407 23.542639 1.493410e-122 0.440442 0.520436\n"
+ "d 0.519934 0.020986 24.774727 1.678998e-135 0.478802 0.561067\n"
]
}
],
"source": [
+ "# cre general\n",
"data = make_static_panel_CP2025(dgp_type='dgp1')\n",
- "\n",
- "x_cols = [col for col in data.columns if \"x\" in col]\n",
+ "cre_data = cre_fct(data)\n",
"\n",
"learner = LassoCV()\n",
"ml_l = clone(learner)\n",
@@ -193,7 +154,7 @@
" d_cols='d',\n",
" t_col='time',\n",
" id_col='id',\n",
- " x_cols=x_cols,\n",
+ " x_cols=[col for col in data.columns if \"x\" in col],\n",
" static_panel=True)\n",
"\n",
"dml_plpr_obj = DoubleMLPLPR(panel_data_obj, ml_l=ml_l, ml_m=ml_m, approach='cre_general', n_folds=5)\n",
@@ -204,7 +165,7 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 5,
"id": "7deabe55",
"metadata": {},
"outputs": [],
@@ -232,7 +193,7 @@
"------------------ Data Summary ------------------\n",
"Outcome variable: y\n",
"Treatment variable(s): ['d']\n",
- "Covariates: ['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15', 'x16', 'x17', 'x18', 'x19', 'x20', 'x21', 'x22', 'x23', 'x24', 'x25', 'x26', 'x27', 'x28', 'x29', 'x30']\n",
+ "Covariates: ['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15', 'x16', 'x17', 'x18', 'x19', 'x20', 'x21', 'x22', 'x23', 'x24', 'x25', 'x26', 'x27', 'x28', 'x29', 'x30', 'x1_mean', 'x2_mean', 'x3_mean', 'x4_mean', 'x5_mean', 'x6_mean', 'x7_mean', 'x8_mean', 'x9_mean', 'x10_mean', 'x11_mean', 'x12_mean', 'x13_mean', 'x14_mean', 'x15_mean', 'x16_mean', 'x17_mean', 'x18_mean', 'x19_mean', 'x20_mean', 'x21_mean', 'x22_mean', 'x23_mean', 'x24_mean', 'x25_mean', 'x26_mean', 'x27_mean', 'x28_mean', 'x29_mean', 'x30_mean']\n",
"Instrument variable(s): None\n",
"Time variable: time\n",
"Id variable: id\n",
@@ -250,8 +211,8 @@
"Learner ml_m: LassoCV()\n",
"Out-of-sample Performance:\n",
"Regression:\n",
- "Learner ml_l RMSE: [[1.76970493]]\n",
- "Learner ml_m RMSE: [[0.94373673]]\n",
+ "Learner ml_l RMSE: [[1.63784321]]\n",
+ "Learner ml_m RMSE: [[0.96294553]]\n",
"\n",
"------------------ Resampling ------------------\n",
"No. folds per cluster: 5\n",
@@ -260,14 +221,14 @@
"\n",
"------------------ Fit Summary ------------------\n",
" coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.480439 0.020407 23.542639 1.493410e-122 0.440442 0.520436\n",
+ "d 0.519934 0.020986 24.774727 1.678998e-135 0.478802 0.561067\n",
"\n",
"------------------ Additional Information ------------------\n",
- "Post Transformation Data Summary:\n",
+ "Original Data Summary Pre-transformation:\n",
"\n",
"Outcome variable: y\n",
"Treatment variable(s): ['d']\n",
- "Covariates: ['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15', 'x16', 'x17', 'x18', 'x19', 'x20', 'x21', 'x22', 'x23', 'x24', 'x25', 'x26', 'x27', 'x28', 'x29', 'x30', 'x1_mean', 'x2_mean', 'x3_mean', 'x4_mean', 'x5_mean', 'x6_mean', 'x7_mean', 'x8_mean', 'x9_mean', 'x10_mean', 'x11_mean', 'x12_mean', 'x13_mean', 'x14_mean', 'x15_mean', 'x16_mean', 'x17_mean', 'x18_mean', 'x19_mean', 'x20_mean', 'x21_mean', 'x22_mean', 'x23_mean', 'x24_mean', 'x25_mean', 'x26_mean', 'x27_mean', 'x28_mean', 'x29_mean', 'x30_mean']\n",
+ "Covariates: ['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15', 'x16', 'x17', 'x18', 'x19', 'x20', 'x21', 'x22', 'x23', 'x24', 'x25', 'x26', 'x27', 'x28', 'x29', 'x30']\n",
"No. Observations: 2500\n",
"\n"
]
@@ -279,45 +240,90 @@
},
{
"cell_type": "code",
- "execution_count": 16,
- "id": "24f06d62",
+ "execution_count": 56,
+ "id": "8731057f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.503726 0.022932 21.965879 6.106437e-107 0.45878 0.548672\n"
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d 0.551055 0.020727 26.586753 9.659490e-156 0.510432 0.591679\n"
]
}
],
"source": [
- "# cre normality assumption\n",
+ "# cre general, extend features\n",
"data = make_static_panel_CP2025(dgp_type='dgp1')\n",
- "cre_data = cre_fct(data)\n",
"\n",
- "x_cols = [col for col in cre_data.columns if \"x\" in col]\n",
- "\n",
- "panel_data_obj = DoubleMLPanelData(cre_data,\n",
+ "panel_data_obj = DoubleMLPanelData(data,\n",
" y_col='y',\n",
" d_cols='d',\n",
" t_col='time',\n",
" id_col='id',\n",
- " x_cols=x_cols,\n",
+ " x_cols=[col for col in data.columns if \"x\" in col],\n",
" static_panel=True)\n",
"\n",
- "# learner = LassoCV()\n",
- "learner = make_pipeline(StandardScaler(), LassoCV())\n",
+ "n_features = len(panel_data_obj.x_cols)\n",
+ "#\n",
+ "indices_x = [i for i in range(n_features)]\n",
+ "indices_x_mean = [i for i in range(n_features, 2 * n_features)]\n",
+ "\n",
+ "preprocessor = ColumnTransformer([\n",
+ " ('poly_x', make_pipeline(\n",
+ " PolynomialFeatures(degree=2, include_bias=False, interaction_only=False)\n",
+ " ), indices_x), \n",
+ " ('poly_x_mean', make_pipeline(\n",
+ " PolynomialFeatures(degree=2, include_bias=False, interaction_only=False)\n",
+ " ), indices_x_mean) \n",
+ "])\n",
+ "\n",
+ "learner = make_pipeline(\n",
+ " preprocessor,\n",
+ " StandardScaler(),\n",
+ " LassoCV()\n",
+ ")\n",
"\n",
"ml_l = clone(learner)\n",
"ml_m = clone(learner)\n",
"\n",
- "dml_plpr_obj = DoubleMLPLPR(panel_data_obj, ml_l, ml_m, approach='cre_normal')\n",
- "dml_plpr_obj.fit()\n",
+ "dml_plpr_obj = DoubleMLPLPR(panel_data_obj, ml_l=ml_l, ml_m=ml_m, approach='cre_general', n_folds=5)\n",
+ "dml_plpr_obj.fit(store_models=True)\n",
"print(dml_plpr_obj.summary)"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": 62,
+ "id": "d7936e42",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "990"
+ ]
+ },
+ "execution_count": 62,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dml_plpr_obj.models['ml_l']['d'][0][0].named_steps['lassocv'].n_features_in_"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 69,
+ "id": "26039cea",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# dml_plpr_obj.models['ml_l']['d'][0][0].named_steps['columntransformer']['poly_x'].get_feature_names_out()"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -328,22 +334,22 @@
"name": "stdout",
"output_type": "stream",
"text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.535179 0.020567 26.021098 2.858305e-149 0.494868 0.57549\n"
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d 0.548085 0.02097 26.136802 1.392306e-150 0.506985 0.589186\n"
]
}
],
"source": [
+ "# cre normality assumption\n",
"data = make_static_panel_CP2025(dgp_type='dgp1')\n",
- "\n",
- "x_cols = [col for col in data.columns if \"x\" in col]\n",
+ "cre_data = cre_fct(data)\n",
"\n",
"panel_data_obj = DoubleMLPanelData(data,\n",
" y_col='y',\n",
" d_cols='d',\n",
" t_col='time',\n",
" id_col='id',\n",
- " x_cols=x_cols,\n",
+ " x_cols=[col for col in data.columns if \"x\" in col],\n",
" static_panel=True)\n",
"\n",
"# learner = LassoCV()\n",
@@ -359,29 +365,30 @@
},
{
"cell_type": "code",
- "execution_count": 25,
- "id": "61a72563",
+ "execution_count": 8,
+ "id": "d869e493",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d_diff 0.512959 0.025253 20.313176 9.833638e-92 0.463465 0.562453\n"
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d_diff 0.492487 0.025352 19.42565 4.684180e-84 0.442797 0.542176\n"
]
}
],
"source": [
+ "# First difference approach\n",
"data = make_static_panel_CP2025(dgp_type='dgp1')\n",
"fd_data = fd_fct(data)\n",
"\n",
- "panel_data_obj = DoubleMLPanelData(fd_data,\n",
- " y_col='y_diff',\n",
- " d_cols='d_diff',\n",
+ "panel_data_obj = DoubleMLPanelData(data,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
" t_col='time',\n",
" id_col='id',\n",
- " x_cols=[col for col in fd_data.columns if col.startswith(\"x\")],\n",
+ " x_cols=[col for col in data.columns if \"x\" in col],\n",
" static_panel=True)\n",
"\n",
"learner = LassoCV()\n",
@@ -389,6 +396,7 @@
"ml_m = clone(learner)\n",
"\n",
"dml_plpr_obj = DoubleMLPLPR(panel_data_obj, ml_l, ml_m, approach='fd_exact')\n",
+ "\n",
"dml_plpr_obj.fit()\n",
"print(dml_plpr_obj.summary)"
]
@@ -396,196 +404,196 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "d869e493",
+ "id": "579b43fa",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.493737 0.024198 20.403722 1.549613e-92 0.446309 0.541165\n"
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d_demean 0.545586 0.021263 25.659316 3.327986e-145 0.503912 0.58726\n"
]
}
],
"source": [
+ "# Within group approach\n",
"data = make_static_panel_CP2025(dgp_type='dgp1')\n",
+ "wd_data = wd_fct(data)\n",
"\n",
"panel_data_obj = DoubleMLPanelData(data,\n",
" y_col='y',\n",
" d_cols='d',\n",
" t_col='time',\n",
" id_col='id',\n",
- " x_cols=[col for col in data.columns if col.startswith(\"x\")],\n",
+ " x_cols=[col for col in data.columns if \"x\" in col],\n",
" static_panel=True)\n",
"\n",
"learner = LassoCV()\n",
"ml_l = clone(learner)\n",
"ml_m = clone(learner)\n",
"\n",
- "dml_plpr_obj = DoubleMLPLPR(panel_data_obj, ml_l, ml_m, approach='fd_exact')\n",
- "\n",
+ "dml_plpr_obj = DoubleMLPLPR(panel_data_obj, ml_l, ml_m, approach='wg_approx')\n",
"dml_plpr_obj.fit()\n",
"print(dml_plpr_obj.summary)"
]
},
{
"cell_type": "code",
- "execution_count": 31,
- "id": "aeb00efe",
+ "execution_count": 19,
+ "id": "32e5bc9f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.513121 0.020666 24.828987 4.361657e-136 0.472616 0.553626\n"
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d_demean 0.472861 0.022779 20.758218 1.033280e-95 0.428214 0.517507\n"
]
}
],
"source": [
+ "# Within group approach, polynomials\n",
"data = make_static_panel_CP2025(dgp_type='dgp1')\n",
- "wd_data = wd_fct(data)\n",
"\n",
- "panel_data_obj = DoubleMLPanelData(wd_data,\n",
+ "panel_data_obj = DoubleMLPanelData(data,\n",
" y_col='y',\n",
" d_cols='d',\n",
" t_col='time',\n",
" id_col='id',\n",
- " x_cols=[col for col in wd_data.columns if col.startswith(\"x\")],\n",
+ " x_cols=[col for col in data.columns if \"x\" in col],\n",
" static_panel=True)\n",
"\n",
- "learner = LassoCV()\n",
+ "# learner = LassoCV()\n",
+ "\n",
+ "# preprocessor = ColumnTransformer([\n",
+ "# ('poly', make_pipeline(\n",
+ "# PolynomialFeatures(degree=2, include_bias=False, interaction_only=False),\n",
+ "# StandardScaler()\n",
+ "# ), ['x1', 'x2']), # Columns to expand\n",
+ "# ('pass', 'passthrough', ['cat']) # Columns to keep unchanged\n",
+ "# ])\n",
+ "\n",
+ "# learner = make_pipeline(\n",
+ "# PolynomialFeatures(degree=2, include_bias=False, interaction_only=False),\n",
+ "# StandardScaler(),\n",
+ "# LassoCV()\n",
+ "# )\n",
+ "\n",
+ "preprocessor = ColumnTransformer([\n",
+ " ('poly', make_pipeline(\n",
+ " PolynomialFeatures(degree=2, include_bias=False)\n",
+ " ), [0, 1])\n",
+ "], remainder='passthrough')\n",
+ "\n",
+ "learner = make_pipeline(\n",
+ " preprocessor,\n",
+ " StandardScaler(),\n",
+ " LassoCV()\n",
+ ")\n",
+ "\n",
"ml_l = clone(learner)\n",
"ml_m = clone(learner)\n",
"\n",
"dml_plpr_obj = DoubleMLPLPR(panel_data_obj, ml_l, ml_m, approach='wg_approx')\n",
- "dml_plpr_obj.fit()\n",
- "print(dml_plpr_obj.summary)"
+ "dml_plpr_obj.fit(store_models=True)\n",
+ "print(dml_plpr_obj.summary)\n",
+ "\n",
+ "# dml_plpr_obj.transform_cols['x_cols']"
]
},
{
"cell_type": "code",
- "execution_count": 18,
- "id": "579b43fa",
+ "execution_count": 41,
+ "id": "c43c504b",
"metadata": {},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.498587 0.020292 24.570284 2.625882e-133 0.458815 0.538359\n"
- ]
+ "data": {
+ "text/plain": [
+ "[0, 1, 11]"
+ ]
+ },
+ "execution_count": 41,
+ "metadata": {},
+ "output_type": "execute_result"
}
],
"source": [
- "data = make_static_panel_CP2025(dgp_type='dgp1')\n",
+ "x_cols_tranform = dml_plpr_obj.transform_cols['x_cols']\n",
"\n",
- "panel_data_obj = DoubleMLPanelData(data,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " t_col='time',\n",
- " id_col='id',\n",
- " x_cols=[col for col in data.columns if col.startswith(\"x\")],\n",
- " static_panel=True)\n",
+ "x_cols_for_poly = ['x1_demean', 'x2_demean', 'x12_demean']\n",
"\n",
- "learner = LassoCV()\n",
- "ml_l = clone(learner)\n",
- "ml_m = clone(learner)\n",
- "\n",
- "dml_plpr_obj = DoubleMLPLPR(panel_data_obj, ml_l, ml_m, approach='wg_approx')\n",
- "dml_plpr_obj.fit()\n",
- "print(dml_plpr_obj.summary)"
+ "indices = [i for i, c in enumerate(x_cols_tranform) if c in x_cols_for_poly]\n",
+ "indices"
]
},
{
"cell_type": "code",
- "execution_count": 20,
- "id": "586d5edf",
+ "execution_count": 32,
+ "id": "1a0e629d",
"metadata": {},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Processing: 100.0 %"
- ]
+ "data": {
+ "text/plain": [
+ "33"
+ ]
+ },
+ "execution_count": 32,
+ "metadata": {},
+ "output_type": "execute_result"
}
],
"source": [
- "n_reps = 100\n",
- "theta = 0.5\n",
- "\n",
- "learner = make_pipeline(StandardScaler(), LassoCV())\n",
- "\n",
- "res_cre_general = np.full((n_reps, 3), np.nan)\n",
- "res_cre_normal = np.full((n_reps, 3), np.nan)\n",
- "res_fd = np.full((n_reps, 3), np.nan)\n",
- "res_wd = np.full((n_reps, 3), np.nan)\n",
- "\n",
- "np.random.seed(1)\n",
- "\n",
- "for i in range(n_reps):\n",
- " print(f\"\\rProcessing: {round((i+1)/n_reps*100, 3)} %\", end=\"\")\n",
- " data = make_static_panel_CP2025(num_n=100, theta=theta, dgp_type='dgp1')\n",
- "\n",
- " # CRE general Lasso\n",
- " cre_data = cre_fct(data)\n",
- " dml_data = DoubleMLPanelData(cre_data, y_col='y', d_cols='d', t_col='time', id_col='id', \n",
- " x_cols=[col for col in cre_data.columns if \"x\" in col],\n",
- " static_panel=True)\n",
- " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
- " approach='cre_general')\n",
- " dml_plpr.fit()\n",
- " res_cre_general[i, 0] = dml_plpr.coef[0]\n",
- " res_cre_general[i, 1] = dml_plpr.coef[0] - theta\n",
- " confint = dml_plpr.confint()\n",
- " res_cre_general[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
- "\n",
- " # CRE normality\n",
- " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
- " approach='cre_normal')\n",
- " dml_plpr.fit()\n",
- " res_cre_normal[i, 0] = dml_plpr.coef[0]\n",
- " res_cre_normal[i, 1] = dml_plpr.coef[0] - theta\n",
- " confint = dml_plpr.confint()\n",
- " res_cre_normal[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
- "\n",
- " # FD approach\n",
- " fd_data = fd_fct(data)\n",
- " dml_data = DoubleMLPanelData(fd_data, y_col='y_diff', d_cols='d_diff', t_col='time', id_col='id',\n",
- " x_cols=[col for col in fd_data.columns if col.startswith(\"x\")],\n",
- " static_panel=True)\n",
- " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
- " approach='fd_exact')\n",
- " dml_plpr.fit()\n",
- " res_fd[i, 0] = dml_plpr.coef[0]\n",
- " res_fd[i, 1] = dml_plpr.coef[0] - theta\n",
- " confint = dml_plpr.confint()\n",
- " res_fd[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
- " \n",
- " # WD approach\n",
- " wd_data = wd_fct(data)\n",
- " dml_data = DoubleMLPanelData(wd_data, y_col='y', d_cols='d', t_col='time', id_col='id',\n",
- " x_cols=[col for col in wd_data.columns if \"x\" in col],\n",
- " static_panel=True)\n",
- " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
- " approach='wg_approx')\n",
- " dml_plpr.fit()\n",
- " res_wd[i, 0] = dml_plpr.coef[0]\n",
- " res_wd[i, 1] = dml_plpr.coef[0] - theta\n",
- " confint = dml_plpr.confint()\n",
- " res_wd[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)"
+ "dml_plpr_obj.models['ml_l']['d_demean'][0][0].named_steps['lassocv'].n_features_in_"
]
},
{
"cell_type": "code",
- "execution_count": 21,
- "id": "33119186",
+ "execution_count": 38,
+ "id": "68f20f79",
"metadata": {},
"outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array(['x0', 'x1', 'x0^2', 'x0 x1', 'x1^2'], dtype=object)"
+ ]
+ },
+ "execution_count": 38,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dml_plpr_obj.models['ml_l']['d_demean'][0][0].named_steps['columntransformer']['poly'].get_feature_names_out()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 42,
+ "id": "94e89dc1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# dml_plpr_obj.models['ml_l']['d_demean'][0][0].named_steps['polynomialfeatures'].get_feature_names_out()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c5a9c137",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Processing: 100.0 %"
+ ]
+ },
{
"data": {
"text/html": [
@@ -649,23 +657,82 @@
"WG approx 0.502006 0.002006 0.94"
]
},
- "execution_count": 21,
+ "execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
+ "# simulation with built-in transformations\n",
+ "\n",
+ "n_reps = 100\n",
+ "theta = 0.5\n",
+ "\n",
+ "learner = make_pipeline(StandardScaler(), LassoCV())\n",
+ "\n",
+ "res_cre_general = np.full((n_reps, 3), np.nan)\n",
+ "res_cre_normal = np.full((n_reps, 3), np.nan)\n",
+ "res_fd = np.full((n_reps, 3), np.nan)\n",
+ "res_wd = np.full((n_reps, 3), np.nan)\n",
+ "\n",
+ "np.random.seed(1)\n",
+ "\n",
+ "for i in range(n_reps):\n",
+ " print(f\"\\rProcessing: {round((i+1)/n_reps*100, 3)} %\", end=\"\")\n",
+ " data = make_static_panel_CP2025(num_n=100, theta=theta, dgp_type='dgp1')\n",
+ "\n",
+ " dml_data = DoubleMLPanelData(data, y_col='y', d_cols='d', t_col='time', id_col='id', \n",
+ " x_cols=[col for col in data.columns if \"x\" in col],\n",
+ " static_panel=True)\n",
+ " \n",
+ " # CRE general Lasso\n",
+ " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
+ " approach='cre_general')\n",
+ " dml_plpr.fit()\n",
+ " res_cre_general[i, 0] = dml_plpr.coef[0]\n",
+ " res_cre_general[i, 1] = dml_plpr.coef[0] - theta\n",
+ " confint = dml_plpr.confint()\n",
+ " res_cre_general[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ "\n",
+ " # CRE normality\n",
+ " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
+ " approach='cre_normal')\n",
+ " dml_plpr.fit()\n",
+ " res_cre_normal[i, 0] = dml_plpr.coef[0]\n",
+ " res_cre_normal[i, 1] = dml_plpr.coef[0] - theta\n",
+ " confint = dml_plpr.confint()\n",
+ " res_cre_normal[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ "\n",
+ " # FD approach\n",
+ " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
+ " approach='fd_exact')\n",
+ " dml_plpr.fit()\n",
+ " res_fd[i, 0] = dml_plpr.coef[0]\n",
+ " res_fd[i, 1] = dml_plpr.coef[0] - theta\n",
+ " confint = dml_plpr.confint()\n",
+ " res_fd[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ " \n",
+ " # WD approach, reuses the data object since the model no longer overwrites _cluster_vars\n",
+ " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
+ " approach='wg_approx')\n",
+ " dml_plpr.fit()\n",
+ " res_wd[i, 0] = dml_plpr.coef[0]\n",
+ " res_wd[i, 1] = dml_plpr.coef[0] - theta\n",
+ " confint = dml_plpr.confint()\n",
+ " res_wd[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ "\n",
+ "\n",
"pd.DataFrame(np.vstack([res_cre_general.mean(axis=0), res_cre_normal.mean(axis=0), \n",
- " res_fd.mean(axis=0), res_wd.mean(axis=0)]), \n",
- " columns=['Coef', 'Bias', 'Coverage'], \n",
- " index=['CRE general', 'CRE normal', \n",
- " 'FD exact', 'WG approx'])"
+ " res_fd.mean(axis=0), res_wd.mean(axis=0)]), \n",
+ " columns=['Coef', 'Bias', 'Coverage'], \n",
+ " index=['CRE general', 'CRE normal', \n",
+ " 'FD exact', 'WG approx'])"
]
},
{
"cell_type": "code",
- "execution_count": 33,
- "id": "18127ee6",
+ "execution_count": 13,
+ "id": "24ef531c",
"metadata": {},
"outputs": [
{
@@ -704,27 +771,27 @@
" \n",
" \n",
" | CRE general | \n",
- " 0.513181 | \n",
- " 0.013181 | \n",
- " 0.93 | \n",
+ " 0.498318 | \n",
+ " -0.001682 | \n",
+ " 0.94 | \n",
"
\n",
" \n",
" | CRE normal | \n",
- " 0.534498 | \n",
- " 0.034498 | \n",
- " 0.87 | \n",
+ " 0.497383 | \n",
+ " -0.002617 | \n",
+ " 0.96 | \n",
"
\n",
" \n",
" | FD exact | \n",
- " 0.501095 | \n",
- " 0.001095 | \n",
- " 0.94 | \n",
+ " 0.494321 | \n",
+ " -0.005679 | \n",
+ " 0.93 | \n",
"
\n",
" \n",
" | WG approx | \n",
- " 0.497307 | \n",
- " -0.002693 | \n",
- " 0.94 | \n",
+ " 0.497754 | \n",
+ " -0.002246 | \n",
+ " 0.95 | \n",
"
\n",
" \n",
"\n",
@@ -732,31 +799,29 @@
],
"text/plain": [
" Coef Bias Coverage\n",
- "CRE general 0.513181 0.013181 0.93\n",
- "CRE normal 0.534498 0.034498 0.87\n",
- "FD exact 0.501095 0.001095 0.94\n",
- "WG approx 0.497307 -0.002693 0.94"
+ "CRE general 0.498318 -0.001682 0.94\n",
+ "CRE normal 0.497383 -0.002617 0.96\n",
+ "FD exact 0.494321 -0.005679 0.93\n",
+ "WG approx 0.497754 -0.002246 0.95"
]
},
- "execution_count": 33,
+ "execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "# simulation with built-in transformations\n",
- "\n",
"n_reps = 100\n",
"theta = 0.5\n",
"\n",
- "learner = make_pipeline(StandardScaler(), LassoCV())\n",
+ "learner = LinearRegression()\n",
"\n",
"res_cre_general = np.full((n_reps, 3), np.nan)\n",
"res_cre_normal = np.full((n_reps, 3), np.nan)\n",
"res_fd = np.full((n_reps, 3), np.nan)\n",
"res_wd = np.full((n_reps, 3), np.nan)\n",
"\n",
- "np.random.seed(1)\n",
+ "np.random.seed(12)\n",
"\n",
"for i in range(n_reps):\n",
" print(f\"\\rProcessing: {round((i+1)/n_reps*100, 3)} %\", end=\"\")\n",
@@ -793,11 +858,7 @@
" confint = dml_plpr.confint()\n",
" res_fd[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
" \n",
- " # WD approach, for now need new data object as FD approach overwrites _cluster_vars \n",
- " dml_data = DoubleMLPanelData(data, y_col='y', d_cols='d', t_col='time', id_col='id', \n",
- " x_cols=[col for col in data.columns if \"x\" in col],\n",
- " static_panel=True)\n",
- "\n",
+ " # WG approach; reuses dml_data (no longer rebuilt after the FD fit, which previously overwrote _cluster_vars)\n",
" dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
" approach='wg_approx')\n",
" dml_plpr.fit()\n",
@@ -831,7 +892,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.13.9"
+ "version": "3.13.5"
}
},
"nbformat": 4,
From f27ebeba09cf94e035bfc7f814ca49cbd5c485a8 Mon Sep 17 00:00:00 2001
From: Julian Diefenbacher
<93731825+JulianDiefenbacher@users.noreply.github.com>
Date: Thu, 13 Nov 2025 22:23:56 +0100
Subject: [PATCH 22/33] update logic for cre_normal approach in estimation and
tuning
---
doubleml/plm/plpr.py | 37 +++--
doubleml/plm/sim/example_sim.ipynb | 244 +++++++++++++++++++++++------
2 files changed, 227 insertions(+), 54 deletions(-)
diff --git a/doubleml/plm/plpr.py b/doubleml/plm/plpr.py
index 768850c75..9c7d2cb40 100644
--- a/doubleml/plm/plpr.py
+++ b/doubleml/plm/plpr.py
@@ -148,6 +148,7 @@ def __init__(
self._initialize_ml_nuisance_params()
self._sensitivity_implemented = False
self._external_predictions_implemented = True
+ self._set_d_mean()
def _format_score_info_str(self):
score_approach_info = f"Score function: {str(self.score)}\n" f"Static panel model approach: {str(self.approach)}"
@@ -213,7 +214,7 @@ def _check_approach(self, approach, valid_approach):
return
def _transform_data(self):
- df = self._original_dml_data.data.copy()
+ df = self._original_dml_data.data
y_col = self._original_dml_data.y_col
d_cols = self._original_dml_data.d_cols
@@ -229,7 +230,7 @@ def _transform_data(self):
elif self._approach == "fd_exact":
# TODO: potential issues with unbalanced panels/missing periods, right now the
# last available is used for the lag and first difference. Maybe reindex to a complete time grid per id.
- df = df.sort_values([id_col, t_col])
+ df = df.sort_values([id_col, t_col]).copy()
shifted = df[[id_col] + x_cols].groupby(id_col).shift(1).add_suffix("_lag")
first_diff = df[[id_col] + [y_col] + d_cols].groupby(id_col).diff().add_suffix("_diff")
df_fd = pd.concat([df, shifted], axis=1)
@@ -260,6 +261,17 @@ def _transform_data(self):
return data, cols
+ def _set_d_mean(self):
+ if self._approach == "cre_normal":
+ data = self._original_dml_data.data
+ d_cols = self._original_dml_data.d_cols
+ id_col = self._original_dml_data.id_col
+ help_d_mean = data.loc[:, [id_col] + d_cols]
+ d_mean = help_d_mean.groupby(id_col).transform("mean").values
+ self._d_mean = d_mean
+ else:
+ self._d_mean = None
+
def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=False):
x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False)
x, d = check_X_y(x, self._dml_data.d, force_all_finite=False)
@@ -292,16 +304,15 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
if m_external:
m_hat = {"preds": external_predictions["ml_m"], "targets": None, "models": None}
else:
- # TODO: update this section
- # cre using m_d + x for m_hat, otherwise only x
if self._approach == "cre_normal":
- help_d_mean = pd.DataFrame({"id": self._dml_data.id_var, "d": d})
- d_mean = help_d_mean.groupby(["id"]).transform("mean").values
- x = np.column_stack((x, d_mean))
+ d_mean = self._d_mean[:, self._i_treat]
+ x_m = np.column_stack((x, d_mean))
+ else:
+ x_m = x
m_hat = _dml_cv_predict(
self._learner["ml_m"],
- x,
+ x_m,
d,
smpls=smpls,
n_jobs=n_jobs_cv,
@@ -311,6 +322,7 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
)
# general cre adjustment
+ # TODO: update this section
if self._approach == "cre_general":
help_data = pd.DataFrame({"id": self._dml_data.id_var, "m_hat": m_hat["preds"], "d": d})
group_means = help_data.groupby(["id"])[["m_hat", "d"]].transform("mean")
@@ -380,7 +392,6 @@ def _sensitivity_element_est(self, preds):
pass
def _nuisance_tuning(
- # TODO: include mean_x for cre approach
self,
smpls,
param_grids,
@@ -409,9 +420,15 @@ def _nuisance_tuning(
search_mode,
n_iter_randomized_search,
)
+ if self._approach == "cre_normal":
+ d_mean = self._d_mean[:, self._i_treat]
+ x_m = np.column_stack((x, d_mean))
+ else:
+ x_m = x
+
m_tune_res = _dml_tune(
d,
- x,
+ x_m,
train_inds,
self._learner["ml_m"],
param_grids["ml_m"],
diff --git a/doubleml/plm/sim/example_sim.ipynb b/doubleml/plm/sim/example_sim.ipynb
index e4eb0c6f0..af90919fa 100644
--- a/doubleml/plm/sim/example_sim.ipynb
+++ b/doubleml/plm/sim/example_sim.ipynb
@@ -92,18 +92,61 @@
"warnings.filterwarnings(\"ignore\")"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "8bf5c99c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.base import BaseEstimator, TransformerMixin\n",
+ "\n",
+ "class PolyPlus(BaseEstimator, TransformerMixin):\n",
+ " \"\"\"PolynomialFeatures(degree=k) and additional terms x_i^(k+1).\"\"\"\n",
+ "\n",
+ " def __init__(self, degree=2, interaction_only=False, include_bias=False):\n",
+ " self.degree = degree\n",
+ " self.extra_degree = degree + 1\n",
+ " self.interaction_only = interaction_only\n",
+ " self.include_bias = include_bias\n",
+ " self.poly = PolynomialFeatures(degree=degree, interaction_only=interaction_only, include_bias=include_bias)\n",
+ "\n",
+ " def fit(self, X, y=None):\n",
+ " self.poly.fit(X)\n",
+ " self.n_features_in_ = X.shape[1]\n",
+ " return self\n",
+ "\n",
+ " def transform(self, X):\n",
+ " X = np.asarray(X)\n",
+ " X_poly = self.poly.transform(X)\n",
+ " X_extra = X ** self.extra_degree\n",
+ " return np.hstack([X_poly, X_extra])\n",
+ "\n",
+ " def get_feature_names_out(self, input_features=None):\n",
+ " input_features = np.array(\n",
+ " input_features\n",
+ " if input_features is not None\n",
+ " else [f\"x{i}\" for i in range(self.n_features_in_)]\n",
+ " )\n",
+ " poly_names = self.poly.get_feature_names_out(input_features)\n",
+ " extra_names = [f\"{name}^{self.extra_degree}\" for name in input_features]\n",
+ " return np.concatenate([poly_names, extra_names])"
+ ]
+ },
{
"cell_type": "code",
"execution_count": 3,
- "id": "3c061cf1",
+ "id": "7c5cd6bc",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "(np.float64(0.6719371174913908),\n",
- " np.float64(0.6090488219157394),\n",
- " np.float64(0.7348254130670423))"
+ "array(['poly_x__x0', 'poly_x__x1', 'poly_x__x0^2', 'poly_x__x0 x1',\n",
+ " 'poly_x__x1^2', 'poly_x__x0^3', 'poly_x__x1^3', 'poly_x_mean__x2',\n",
+ " 'poly_x_mean__x3', 'poly_x_mean__x2^2', 'poly_x_mean__x2 x3',\n",
+ " 'poly_x_mean__x3^2', 'poly_x_mean__x2^3', 'poly_x_mean__x3^3',\n",
+ " 'remainder__x4'], dtype=object)"
]
},
"execution_count": 3,
@@ -111,6 +154,54 @@
"output_type": "execute_result"
}
],
+ "source": [
+ "preprocessor = ColumnTransformer([\n",
+ " ('poly_x', make_pipeline(\n",
+ " PolyPlus(include_bias=False, interaction_only=False)\n",
+ " ), [0, 1]), \n",
+ " ('poly_x_mean', make_pipeline(\n",
+ " PolyPlus(include_bias=False, interaction_only=False)\n",
+ " ), [2, 3]) \n",
+ "], remainder='passthrough')\n",
+ "\n",
+ "learner = make_pipeline(\n",
+ " preprocessor,\n",
+ " StandardScaler(),\n",
+ " LinearRegression()\n",
+ ")\n",
+ "\n",
+ "np.random.seed(1)\n",
+ "data = make_static_panel_CP2025(dgp_type='dgp1')\n",
+ "\n",
+ "y = np.array(data['y'])\n",
+ "X = np.array(data[['x1', 'x2', 'x3', 'x4', 'x5']])\n",
+ "\n",
+ "pred = learner.fit(X, y).predict(X)\n",
+ "\n",
+ "learner.named_steps['columntransformer'].get_feature_names_out()\n",
+ "# learner.named_steps['columntransformer']['poly_x'].get_feature_names_out()\n",
+ "# learner.named_steps['linearregression'].n_features_in_"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "3c061cf1",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(np.float64(0.6719371174913912),\n",
+ " np.float64(0.6090488219157397),\n",
+ " np.float64(0.7348254130670426))"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"np.random.seed(1)\n",
"data = make_static_panel_CP2025(dgp_type='dgp1')\n",
@@ -127,7 +218,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 37,
"id": "44387af3",
"metadata": {},
"outputs": [
@@ -136,7 +227,7 @@
"output_type": "stream",
"text": [
" coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.519934 0.020986 24.774727 1.678998e-135 0.478802 0.561067\n"
+ "d 0.510898 0.020339 25.119106 3.075876e-139 0.471034 0.550762\n"
]
}
],
@@ -240,7 +331,7 @@
},
{
"cell_type": "code",
- "execution_count": 56,
+ "execution_count": 15,
"id": "8731057f",
"metadata": {},
"outputs": [
@@ -248,14 +339,14 @@
"name": "stdout",
"output_type": "stream",
"text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.551055 0.020727 26.586753 9.659490e-156 0.510432 0.591679\n"
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d 0.51012 0.034875 14.627007 1.889172e-48 0.441766 0.578475\n"
]
}
],
"source": [
"# cre general, extend features\n",
- "data = make_static_panel_CP2025(dgp_type='dgp1')\n",
+ "data = make_static_panel_CP2025(num_n=100, dgp_type='dgp3')\n",
"\n",
"panel_data_obj = DoubleMLPanelData(data,\n",
" y_col='y',\n",
@@ -265,19 +356,34 @@
" x_cols=[col for col in data.columns if \"x\" in col],\n",
" static_panel=True)\n",
"\n",
- "n_features = len(panel_data_obj.x_cols)\n",
- "#\n",
- "indices_x = [i for i in range(n_features)]\n",
- "indices_x_mean = [i for i in range(n_features, 2 * n_features)]\n",
+ "dim_x = len(panel_data_obj.x_cols)\n",
+ "indices_x = [x for x in range(dim_x)]\n",
+ "indices_x_mean = [x + dim_x for x in indices_x]\n",
+ "\n",
+ "# x_cols = panel_data_obj.x_cols\n",
+ "# dim_x = len(x_cols)\n",
+ "# x_cols_for_poly = ['x1', 'x11', 'x22']\n",
+ "\n",
+ "# indices_x = [i for i, c in enumerate(x_cols) if c in x_cols_for_poly]\n",
+ "# indices_x_mean = [x + dim_x for x in indices_x]\n",
+ "\n",
+ "# preprocessor = ColumnTransformer([\n",
+ "# ('poly_x', make_pipeline(\n",
+ "# PolynomialFeatures(degree=2, include_bias=False, interaction_only=False)\n",
+ "# ), indices_x), \n",
+ "# ('poly_x_mean', make_pipeline(\n",
+ "# PolynomialFeatures(degree=2, include_bias=False, interaction_only=False)\n",
+ "# ), indices_x_mean) \n",
+ "# ], remainder='passthrough') # passthrough is needed for cre_normal (keeps the appended d_mean column)\n",
"\n",
"preprocessor = ColumnTransformer([\n",
" ('poly_x', make_pipeline(\n",
- " PolynomialFeatures(degree=2, include_bias=False, interaction_only=False)\n",
+ " PolyPlus(degree=2, include_bias=False, interaction_only=False)\n",
" ), indices_x), \n",
" ('poly_x_mean', make_pipeline(\n",
- " PolynomialFeatures(degree=2, include_bias=False, interaction_only=False)\n",
+ " PolyPlus(degree=2, include_bias=False, interaction_only=False)\n",
" ), indices_x_mean) \n",
- "])\n",
+ "], remainder='passthrough')\n",
"\n",
"learner = make_pipeline(\n",
" preprocessor,\n",
@@ -285,6 +391,8 @@
" LassoCV()\n",
")\n",
"\n",
+ "# learner = LassoCV()\n",
+ "\n",
"ml_l = clone(learner)\n",
"ml_m = clone(learner)\n",
"\n",
@@ -295,17 +403,17 @@
},
{
"cell_type": "code",
- "execution_count": 62,
+ "execution_count": 16,
"id": "d7936e42",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "990"
+ "1050"
]
},
- "execution_count": 62,
+ "execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
@@ -316,17 +424,75 @@
},
{
"cell_type": "code",
- "execution_count": 69,
- "id": "26039cea",
+ "execution_count": 17,
+ "id": "20f158d6",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "1050"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
- "# dml_plpr_obj.models['ml_l']['d'][0][0].named_steps['columntransformer']['poly_x'].get_feature_names_out()"
+ "dml_plpr_obj.models['ml_m']['d'][0][0].named_steps['lassocv'].n_features_in_"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 20,
+ "id": "9edf129b",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d_demean 1.195553 0.048717 24.540889 5.410930e-133 1.10007 1.291036\n"
+ ]
+ }
+ ],
+ "source": [
+ "from lightgbm import LGBMRegressor\n",
+ "\n",
+ "ml_boost = LGBMRegressor(verbose=-1, \n",
+ " n_estimators=100, \n",
+ " learning_rate=0.3,\n",
+ " min_child_samples=1) \n",
+ "\n",
+ "ml_boost_grid= {'ml_l': {'reg_lambda': np.random.choice(np.arange(0.2, 2 , 0.2), 5, replace=False), \n",
+ " 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)},\n",
+ " 'ml_m': {'reg_lambda': np.random.choice(np.arange(0.2, 2 , 0.2), 5, replace=False),\n",
+ " 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)}}\n",
+ "\n",
+ "ml_l = clone(ml_boost)\n",
+ "ml_m = clone(ml_boost)\n",
+ "\n",
+ "data = make_static_panel_CP2025(num_n=100, dgp_type='dgp3')\n",
+ "\n",
+ "panel_data_obj = DoubleMLPanelData(data,\n",
+ " y_col='y',\n",
+ " d_cols='d',\n",
+ " t_col='time',\n",
+ " id_col='id',\n",
+ " x_cols=[col for col in data.columns if \"x\" in col],\n",
+ " static_panel=True)\n",
+ "\n",
+ "dml_plpr_obj = DoubleMLPLPR(panel_data_obj, ml_l=ml_l, ml_m=ml_m, approach='wg_approx', n_folds=5)\n",
+ "dml_plpr_obj.tune(param_grids=ml_boost_grid, n_jobs_cv=5)\n",
+ "dml_plpr_obj.fit(store_models=True)\n",
+ "print(dml_plpr_obj.summary)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
"id": "135a38ab",
"metadata": {},
"outputs": [
@@ -334,8 +500,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.548085 0.02097 26.136802 1.392306e-150 0.506985 0.589186\n"
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d 0.564418 0.020749 27.201758 6.191263e-163 0.52375 0.605086\n"
]
}
],
@@ -365,7 +531,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 26,
"id": "d869e493",
"metadata": {},
"outputs": [
@@ -373,8 +539,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d_diff 0.492487 0.025352 19.42565 4.684180e-84 0.442797 0.542176\n"
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d_diff 0.495522 0.023769 20.847281 1.613320e-96 0.448936 0.542109\n"
]
}
],
@@ -403,7 +569,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 30,
"id": "579b43fa",
"metadata": {},
"outputs": [
@@ -411,8 +577,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d_demean 0.545586 0.021263 25.659316 3.327986e-145 0.503912 0.58726\n"
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d_demean 0.507186 0.022073 22.977214 7.878187e-117 0.463922 0.550449\n"
]
}
],
@@ -571,16 +737,6 @@
"dml_plpr_obj.models['ml_l']['d_demean'][0][0].named_steps['columntransformer']['poly'].get_feature_names_out()"
]
},
- {
- "cell_type": "code",
- "execution_count": 42,
- "id": "94e89dc1",
- "metadata": {},
- "outputs": [],
- "source": [
- "# dml_plpr_obj.models['ml_l']['d_demean'][0][0].named_steps['polynomialfeatures'].get_feature_names_out()"
- ]
- },
{
"cell_type": "code",
"execution_count": null,
@@ -892,7 +1048,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.13.5"
+ "version": "3.12.6"
}
},
"nbformat": 4,
From 011f8cb124fcb48be015df2e0c81b32967d8d063 Mon Sep 17 00:00:00 2001
From: Julian Diefenbacher
<93731825+JulianDiefenbacher@users.noreply.github.com>
Date: Mon, 24 Nov 2025 18:10:52 +0100
Subject: [PATCH 23/33] add simulation replication
---
doubleml/plm/sim/learners_sim.ipynb | 1463 +++++++++++++--------------
1 file changed, 697 insertions(+), 766 deletions(-)
diff --git a/doubleml/plm/sim/learners_sim.ipynb b/doubleml/plm/sim/learners_sim.ipynb
index c29e03a6b..d6662488b 100644
--- a/doubleml/plm/sim/learners_sim.ipynb
+++ b/doubleml/plm/sim/learners_sim.ipynb
@@ -9,50 +9,101 @@
"source": [
"import numpy as np\n",
"import pandas as pd\n",
- "from doubleml.data.base_data import DoubleMLData\n",
+ "from doubleml.data.panel_data import DoubleMLPanelData\n",
"from doubleml.plm.plpr import DoubleMLPLPR\n",
- "from sklearn.linear_model import LassoCV, LinearRegression\n",
+ "from sklearn.linear_model import LassoCV\n",
"from sklearn.base import clone\n",
"from sklearn.tree import DecisionTreeRegressor\n",
- "from sklearn.ensemble import RandomForestRegressor\n",
"from lightgbm import LGBMRegressor\n",
- "from doubleml.plm.utils._plpr_util import extend_data, cre_fct, fd_fct, wd_fct\n",
+ "# from doubleml.plm.utils._plpr_util import extend_data, cre_fct, fd_fct, wd_fct\n",
"from doubleml.plm.datasets.dgp_static_panel_CP2025 import make_static_panel_CP2025\n",
- "from sklearn.preprocessing import StandardScaler\n",
+ "from sklearn.preprocessing import StandardScaler, PolynomialFeatures\n",
"from sklearn.pipeline import make_pipeline\n",
+ "from sklearn.base import BaseEstimator, TransformerMixin\n",
+ "from sklearn.compose import ColumnTransformer\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
"cell_type": "code",
- "execution_count": 2,
- "id": "2715990b",
+ "execution_count": 3,
+ "id": "2650a1dd",
"metadata": {},
"outputs": [],
"source": [
- "ml_ols = LinearRegression()\n",
- "\n",
- "ml_lasso = make_pipeline(StandardScaler(), LassoCV())\n",
+ "class PolyPlus(BaseEstimator, TransformerMixin):\n",
+ " \"\"\"PolynomialFeatures(degree=k) and additional terms x_i^(k+1).\"\"\"\n",
+ "\n",
+ " def __init__(self, degree=2, interaction_only=False, include_bias=False):\n",
+ " self.degree = degree\n",
+ " self.extra_degree = degree + 1\n",
+ " self.interaction_only = interaction_only\n",
+ " self.include_bias = include_bias\n",
+ " self.poly = PolynomialFeatures(degree=degree, interaction_only=interaction_only, include_bias=include_bias)\n",
+ "\n",
+ " def fit(self, X, y=None):\n",
+ " self.poly.fit(X)\n",
+ " self.n_features_in_ = X.shape[1]\n",
+ " return self\n",
+ "\n",
+ " def transform(self, X):\n",
+ " X = np.asarray(X)\n",
+ " X_poly = self.poly.transform(X)\n",
+ " X_extra = X ** self.extra_degree\n",
+ " return np.hstack([X_poly, X_extra])\n",
+ "\n",
+ " def get_feature_names_out(self, input_features=None):\n",
+ " input_features = np.array(\n",
+ " input_features\n",
+ " if input_features is not None\n",
+ " else [f\"x{i}\" for i in range(self.n_features_in_)]\n",
+ " )\n",
+ " poly_names = self.poly.get_feature_names_out(input_features)\n",
+ " extra_names = [f\"{name}^{self.extra_degree}\" for name in input_features]\n",
+ " return np.concatenate([poly_names, extra_names])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "68cab57c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dim_x = 30\n",
+ "indices_x = [x for x in range(dim_x)]\n",
+ "indices_x_tr = [x + dim_x for x in indices_x]\n",
+ "\n",
+ "preprocessor = ColumnTransformer([\n",
+ " ('poly_x', make_pipeline(\n",
+ " PolyPlus(degree=2, include_bias=False, interaction_only=False)\n",
+ " ), indices_x), \n",
+ " ('poly_x_mean', make_pipeline(\n",
+ " PolyPlus(degree=2, include_bias=False, interaction_only=False)\n",
+ " ), indices_x_tr) \n",
+ "], remainder='passthrough')\n",
+ "\n",
+ "preprocessor_wg = ColumnTransformer([\n",
+ " ('poly_x', make_pipeline(\n",
+ " PolyPlus(degree=2, include_bias=False, interaction_only=False)\n",
+ " ), indices_x), \n",
+ "], remainder='passthrough')\n",
+ "\n",
+ "ml_lasso = make_pipeline(\n",
+ " preprocessor,\n",
+ " StandardScaler(),\n",
+ " LassoCV(n_alphas=20, cv=2, n_jobs=5)\n",
+ ")\n",
+ "\n",
+ "ml_lasso_wg = make_pipeline(\n",
+ " preprocessor_wg,\n",
+ " StandardScaler(),\n",
+ " LassoCV(n_alphas=20, cv=2, n_jobs=5)\n",
+ ")\n",
"\n",
"ml_cart = DecisionTreeRegressor()\n",
"\n",
- "ml_rf = RandomForestRegressor(n_estimators=100, \n",
- " max_features=1.0, \n",
- " min_samples_leaf=5)\n",
- "\n",
- "# Rf\n",
- "# ml_rf_grid = {'ml_l': {'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)},\n",
- "# 'ml_m': {'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)}}\n",
- "\n",
- "# dml_plpr = DoubleMLPLPR(data_pdml, clone(ml_rf), clone(ml_rf), pdml_approach='cre', n_folds=5)\n",
- "# dml_plpr.tune(param_grids=ml_rf_grid, search_mode='randomized_search', n_iter_randomized_search=5, n_jobs_cv=5)\n",
- "# dml_plpr.fit(n_jobs_cv=5)\n",
- "# res_cre_rf[i, 0] = dml_plpr.coef[0] - theta\n",
- "# res_cre_rf[i, 1] = dml_plpr.se[0] \n",
- "# confint = dml_plpr.confint()\n",
- "# res_cre_rf[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
- "\n",
"ml_boost = LGBMRegressor(verbose=-1, \n",
" n_estimators=100, \n",
" learning_rate=0.3,\n",
@@ -61,7 +112,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 5,
"id": "dca81b0b",
"metadata": {},
"outputs": [
@@ -88,8 +139,8 @@
" | \n",
" id | \n",
" time | \n",
- " d | \n",
" y | \n",
+ " d | \n",
" x1 | \n",
" x2 | \n",
" x3 | \n",
@@ -114,121 +165,121 @@
" 0 | \n",
" 1 | \n",
" 1 | \n",
- " -8.112787 | \n",
- " -8.912584 | \n",
- " -5.796365 | \n",
- " -0.601492 | \n",
- " -3.487003 | \n",
- " 4.357256 | \n",
- " -3.527997 | \n",
- " -7.455948 | \n",
+ " -1.767568 | \n",
+ " -1.348237 | \n",
+ " 0.492677 | \n",
+ " 0.281933 | \n",
+ " -2.142618 | \n",
+ " 0.791443 | \n",
+ " 2.976620 | \n",
+ " -1.001715 | \n",
" ... | \n",
- " 5.577388 | \n",
- " -1.605127 | \n",
- " -0.814059 | \n",
- " -3.103182 | \n",
- " 2.631538 | \n",
- " -4.643003 | \n",
- " 5.162550 | \n",
- " 3.740774 | \n",
- " 2.113925 | \n",
- " 2.026183 | \n",
+ " -1.198500 | \n",
+ " -0.049211 | \n",
+ " 0.600889 | \n",
+ " 2.435667 | \n",
+ " -1.387149 | \n",
+ " 3.034459 | \n",
+ " -0.062419 | \n",
+ " 0.258166 | \n",
+ " -1.168477 | \n",
+ " -1.061057 | \n",
" \n",
" \n",
" | 1 | \n",
" 1 | \n",
" 2 | \n",
- " -6.949439 | \n",
- " -11.955038 | \n",
- " -3.906188 | \n",
- " 2.728437 | \n",
- " -4.309356 | \n",
- " 4.652335 | \n",
- " 4.837147 | \n",
- " 5.113480 | \n",
+ " -5.095199 | \n",
+ " -3.566642 | \n",
+ " -1.608388 | \n",
+ " -0.819905 | \n",
+ " -3.570497 | \n",
+ " 1.583374 | \n",
+ " 1.644214 | \n",
+ " -4.221177 | \n",
" ... | \n",
- " -6.215166 | \n",
- " -1.291356 | \n",
- " 1.542859 | \n",
- " -5.832660 | \n",
- " -6.999235 | \n",
- " -1.041017 | \n",
- " 0.388897 | \n",
- " 0.135666 | \n",
- " -5.257444 | \n",
- " -4.460909 | \n",
+ " 4.517936 | \n",
+ " 0.413499 | \n",
+ " 2.150563 | \n",
+ " -2.971910 | \n",
+ " 0.922270 | \n",
+ " -2.628696 | \n",
+ " -1.772420 | \n",
+ " -3.851087 | \n",
+ " 3.270008 | \n",
+ " 0.820763 | \n",
"
\n",
" \n",
" | 2 | \n",
" 1 | \n",
" 3 | \n",
- " -4.068573 | \n",
- " -6.083197 | \n",
- " 1.199280 | \n",
- " 1.113007 | \n",
- " -3.238536 | \n",
- " 5.611841 | \n",
- " -3.096405 | \n",
- " 7.262224 | \n",
+ " 6.437985 | \n",
+ " 4.202518 | \n",
+ " -1.050850 | \n",
+ " -1.400580 | \n",
+ " 3.183289 | \n",
+ " 3.513685 | \n",
+ " 1.861339 | \n",
+ " 0.888485 | \n",
" ... | \n",
- " -6.793106 | \n",
- " 5.217539 | \n",
- " 4.765350 | \n",
- " 3.238961 | \n",
- " -3.244586 | \n",
- " 0.046503 | \n",
- " 7.297417 | \n",
- " 5.151098 | \n",
- " 0.353556 | \n",
- " -6.192547 | \n",
+ " 1.163815 | \n",
+ " -0.069711 | \n",
+ " -0.202117 | \n",
+ " -1.262765 | \n",
+ " 1.133570 | \n",
+ " 0.884130 | \n",
+ " 0.484024 | \n",
+ " 3.124910 | \n",
+ " 0.004369 | \n",
+ " 0.349072 | \n",
"
\n",
" \n",
" | 3 | \n",
" 1 | \n",
" 4 | \n",
- " 4.268473 | \n",
- " 8.099756 | \n",
- " -3.690119 | \n",
- " -3.551698 | \n",
- " 7.695905 | \n",
- " 3.349990 | \n",
- " -3.575687 | \n",
- " -9.272200 | \n",
+ " 1.692969 | \n",
+ " 0.776318 | \n",
+ " -0.275264 | \n",
+ " -0.787588 | \n",
+ " -1.492324 | \n",
+ " -3.920095 | \n",
+ " -2.246768 | \n",
+ " -1.655923 | \n",
" ... | \n",
- " 2.183245 | \n",
- " -9.719218 | \n",
- " -3.691420 | \n",
- " -4.724887 | \n",
- " -2.681429 | \n",
- " -3.256659 | \n",
- " 2.039591 | \n",
- " -5.688881 | \n",
- " -1.675406 | \n",
- " -1.537060 | \n",
+ " -1.120645 | \n",
+ " -1.726098 | \n",
+ " -2.561617 | \n",
+ " -2.247641 | \n",
+ " 0.685799 | \n",
+ " 3.943749 | \n",
+ " 2.891479 | \n",
+ " 5.381948 | \n",
+ " 1.455669 | \n",
+ " -2.480590 | \n",
"
\n",
" \n",
" | 4 | \n",
" 1 | \n",
" 5 | \n",
- " -8.490611 | \n",
- " -13.074335 | \n",
- " -8.383416 | \n",
- " 1.125561 | \n",
- " -4.826987 | \n",
- " 1.226380 | \n",
- " 0.565376 | \n",
- " 1.337693 | \n",
+ " -0.760349 | \n",
+ " -0.423762 | \n",
+ " -5.334104 | \n",
+ " -3.650439 | \n",
+ " 0.447345 | \n",
+ " -4.053885 | \n",
+ " 1.367370 | \n",
+ " 0.752763 | \n",
" ... | \n",
- " -1.622405 | \n",
- " -11.514240 | \n",
- " -4.995206 | \n",
- " -0.293343 | \n",
- " 5.670162 | \n",
- " 5.218059 | \n",
- " -10.535997 | \n",
- " -0.007612 | \n",
- " 4.940226 | \n",
- " -2.512659 | \n",
+ " -1.666836 | \n",
+ " -0.607192 | \n",
+ " 3.658921 | \n",
+ " -1.153617 | \n",
+ " 1.338251 | \n",
+ " 2.990290 | \n",
+ " -0.717240 | \n",
+ " 2.494413 | \n",
+ " -0.576748 | \n",
+ " -0.049214 | \n",
"
\n",
" \n",
" | ... | \n",
@@ -258,121 +309,121 @@
" 995 | \n",
" 100 | \n",
" 6 | \n",
- " 7.979518 | \n",
- " 13.313478 | \n",
- " 0.743929 | \n",
- " 0.479841 | \n",
- " 8.463661 | \n",
- " -3.785925 | \n",
- " 3.066799 | \n",
- " -5.972398 | \n",
+ " 6.736806 | \n",
+ " 3.652808 | \n",
+ " 0.935941 | \n",
+ " -1.046372 | \n",
+ " 2.972970 | \n",
+ " -3.188139 | \n",
+ " -1.368655 | \n",
+ " -1.973138 | \n",
" ... | \n",
- " -8.675939 | \n",
- " -0.339098 | \n",
- " 0.200580 | \n",
- " 4.741587 | \n",
- " 3.884253 | \n",
- " 0.082965 | \n",
- " -3.765886 | \n",
- " 2.210837 | \n",
- " -2.203842 | \n",
- " 9.350995 | \n",
+ " -0.693739 | \n",
+ " 1.142071 | \n",
+ " -1.017755 | \n",
+ " 4.427896 | \n",
+ " 0.988486 | \n",
+ " -0.082238 | \n",
+ " 1.047705 | \n",
+ " 1.222772 | \n",
+ " 3.264437 | \n",
+ " 1.173541 | \n",
"
\n",
" \n",
" | 996 | \n",
" 100 | \n",
" 7 | \n",
- " 4.525037 | \n",
- " 7.323752 | \n",
- " 2.795891 | \n",
- " -0.028399 | \n",
- " 3.351155 | \n",
- " -13.480410 | \n",
- " 4.504775 | \n",
- " 2.866025 | \n",
+ " 6.023375 | \n",
+ " 3.868859 | \n",
+ " -0.667464 | \n",
+ " 5.690558 | \n",
+ " 4.230361 | \n",
+ " 0.512239 | \n",
+ " -0.016779 | \n",
+ " -1.998309 | \n",
" ... | \n",
- " 2.935810 | \n",
- " -6.909156 | \n",
- " -6.092518 | \n",
- " 7.090190 | \n",
- " -0.192387 | \n",
- " -0.971816 | \n",
- " 2.114409 | \n",
- " 7.572450 | \n",
- " -3.337941 | \n",
- " 4.831238 | \n",
+ " 0.429095 | \n",
+ " -0.879970 | \n",
+ " -0.333486 | \n",
+ " 0.856337 | \n",
+ " 3.159868 | \n",
+ " -2.772002 | \n",
+ " 1.782697 | \n",
+ " 1.158639 | \n",
+ " 1.373919 | \n",
+ " -3.298531 | \n",
"
\n",
" \n",
" | 997 | \n",
" 100 | \n",
" 8 | \n",
- " 2.510815 | \n",
- " 3.504373 | \n",
- " 4.272010 | \n",
- " -3.236265 | \n",
- " 1.253958 | \n",
- " 1.062489 | \n",
- " -7.690689 | \n",
- " 6.750913 | \n",
+ " 3.110866 | \n",
+ " 0.955741 | \n",
+ " 1.789798 | \n",
+ " -1.256050 | \n",
+ " -0.369970 | \n",
+ " -1.965363 | \n",
+ " 0.805850 | \n",
+ " -0.837638 | \n",
" ... | \n",
- " -9.397734 | \n",
- " 1.931898 | \n",
- " 7.888287 | \n",
- " 0.276521 | \n",
- " 3.114361 | \n",
- " 4.152857 | \n",
- " 0.079838 | \n",
- " 2.297878 | \n",
- " 9.451616 | \n",
- " -1.324771 | \n",
+ " 0.314910 | \n",
+ " 2.215582 | \n",
+ " 1.695434 | \n",
+ " -0.092416 | \n",
+ " -0.691905 | \n",
+ " 2.864254 | \n",
+ " 1.673074 | \n",
+ " -0.839339 | \n",
+ " 1.389325 | \n",
+ " 1.714632 | \n",
"
\n",
" \n",
" | 998 | \n",
" 100 | \n",
" 9 | \n",
- " -4.087541 | \n",
- " -3.451450 | \n",
- " 0.115834 | \n",
- " -2.387410 | \n",
- " -1.961343 | \n",
- " -4.106975 | \n",
- " 4.037239 | \n",
- " -3.903956 | \n",
+ " 4.572219 | \n",
+ " 3.239447 | \n",
+ " -3.129920 | \n",
+ " -1.654972 | \n",
+ " 3.222430 | \n",
+ " -1.193018 | \n",
+ " 0.287887 | \n",
+ " -0.382851 | \n",
" ... | \n",
- " -5.021652 | \n",
- " 1.694328 | \n",
- " -1.283313 | \n",
- " 7.283484 | \n",
- " 8.015243 | \n",
- " 6.879811 | \n",
- " -7.213541 | \n",
- " -2.226587 | \n",
- " -0.305480 | \n",
- " -1.568153 | \n",
+ " -0.481192 | \n",
+ " 1.541299 | \n",
+ " 1.153674 | \n",
+ " -4.382081 | \n",
+ " 4.017794 | \n",
+ " 1.117018 | \n",
+ " -1.648193 | \n",
+ " -1.139779 | \n",
+ " 2.748231 | \n",
+ " 3.032575 | \n",
"
\n",
" \n",
" | 999 | \n",
" 100 | \n",
" 10 | \n",
- " -8.074941 | \n",
- " -12.453872 | \n",
- " -0.695072 | \n",
- " -1.788528 | \n",
- " -7.955557 | \n",
- " 4.716530 | \n",
- " 5.760638 | \n",
- " -6.033057 | \n",
+ " -1.170401 | \n",
+ " -0.463559 | \n",
+ " -5.695356 | \n",
+ " -1.765611 | \n",
+ " 0.205248 | \n",
+ " -0.671551 | \n",
+ " 3.462028 | \n",
+ " -2.128120 | \n",
" ... | \n",
- " 2.323859 | \n",
- " 0.301849 | \n",
- " 0.853097 | \n",
- " 3.270169 | \n",
- " 3.749521 | \n",
- " -2.260064 | \n",
- " 5.343868 | \n",
- " -0.764016 | \n",
- " 2.769752 | \n",
- " -4.067194 | \n",
+ " 3.930542 | \n",
+ " -1.196083 | \n",
+ " 1.066743 | \n",
+ " 0.690258 | \n",
+ " 0.791818 | \n",
+ " 2.491745 | \n",
+ " 0.168359 | \n",
+ " 2.278172 | \n",
+ " -1.443654 | \n",
+ " -5.695066 | \n",
"
\n",
" \n",
"\n",
@@ -380,295 +431,172 @@
""
],
"text/plain": [
- " id time d y x1 x2 x3 x4 \\\n",
- "0 1 1 -8.112787 -8.912584 -5.796365 -0.601492 -3.487003 4.357256 \n",
- "1 1 2 -6.949439 -11.955038 -3.906188 2.728437 -4.309356 4.652335 \n",
- "2 1 3 -4.068573 -6.083197 1.199280 1.113007 -3.238536 5.611841 \n",
- "3 1 4 4.268473 8.099756 -3.690119 -3.551698 7.695905 3.349990 \n",
- "4 1 5 -8.490611 -13.074335 -8.383416 1.125561 -4.826987 1.226380 \n",
- ".. ... ... ... ... ... ... ... ... \n",
- "995 100 6 7.979518 13.313478 0.743929 0.479841 8.463661 -3.785925 \n",
- "996 100 7 4.525037 7.323752 2.795891 -0.028399 3.351155 -13.480410 \n",
- "997 100 8 2.510815 3.504373 4.272010 -3.236265 1.253958 1.062489 \n",
- "998 100 9 -4.087541 -3.451450 0.115834 -2.387410 -1.961343 -4.106975 \n",
- "999 100 10 -8.074941 -12.453872 -0.695072 -1.788528 -7.955557 4.716530 \n",
+ " id time y d x1 x2 x3 x4 \\\n",
+ "0 1 1 -1.767568 -1.348237 0.492677 0.281933 -2.142618 0.791443 \n",
+ "1 1 2 -5.095199 -3.566642 -1.608388 -0.819905 -3.570497 1.583374 \n",
+ "2 1 3 6.437985 4.202518 -1.050850 -1.400580 3.183289 3.513685 \n",
+ "3 1 4 1.692969 0.776318 -0.275264 -0.787588 -1.492324 -3.920095 \n",
+ "4 1 5 -0.760349 -0.423762 -5.334104 -3.650439 0.447345 -4.053885 \n",
+ ".. ... ... ... ... ... ... ... ... \n",
+ "995 100 6 6.736806 3.652808 0.935941 -1.046372 2.972970 -3.188139 \n",
+ "996 100 7 6.023375 3.868859 -0.667464 5.690558 4.230361 0.512239 \n",
+ "997 100 8 3.110866 0.955741 1.789798 -1.256050 -0.369970 -1.965363 \n",
+ "998 100 9 4.572219 3.239447 -3.129920 -1.654972 3.222430 -1.193018 \n",
+ "999 100 10 -1.170401 -0.463559 -5.695356 -1.765611 0.205248 -0.671551 \n",
"\n",
- " x5 x6 ... x21 x22 x23 x24 \\\n",
- "0 -3.527997 -7.455948 ... 5.577388 -1.605127 -0.814059 -3.103182 \n",
- "1 4.837147 5.113480 ... -6.215166 -1.291356 1.542859 -5.832660 \n",
- "2 -3.096405 7.262224 ... -6.793106 5.217539 4.765350 3.238961 \n",
- "3 -3.575687 -9.272200 ... 2.183245 -9.719218 -3.691420 -4.724887 \n",
- "4 0.565376 1.337693 ... -1.622405 -11.514240 -4.995206 -0.293343 \n",
- ".. ... ... ... ... ... ... ... \n",
- "995 3.066799 -5.972398 ... -8.675939 -0.339098 0.200580 4.741587 \n",
- "996 4.504775 2.866025 ... 2.935810 -6.909156 -6.092518 7.090190 \n",
- "997 -7.690689 6.750913 ... -9.397734 1.931898 7.888287 0.276521 \n",
- "998 4.037239 -3.903956 ... -5.021652 1.694328 -1.283313 7.283484 \n",
- "999 5.760638 -6.033057 ... 2.323859 0.301849 0.853097 3.270169 \n",
+ " x5 x6 ... x21 x22 x23 x24 \\\n",
+ "0 2.976620 -1.001715 ... -1.198500 -0.049211 0.600889 2.435667 \n",
+ "1 1.644214 -4.221177 ... 4.517936 0.413499 2.150563 -2.971910 \n",
+ "2 1.861339 0.888485 ... 1.163815 -0.069711 -0.202117 -1.262765 \n",
+ "3 -2.246768 -1.655923 ... -1.120645 -1.726098 -2.561617 -2.247641 \n",
+ "4 1.367370 0.752763 ... -1.666836 -0.607192 3.658921 -1.153617 \n",
+ ".. ... ... ... ... ... ... ... \n",
+ "995 -1.368655 -1.973138 ... -0.693739 1.142071 -1.017755 4.427896 \n",
+ "996 -0.016779 -1.998309 ... 0.429095 -0.879970 -0.333486 0.856337 \n",
+ "997 0.805850 -0.837638 ... 0.314910 2.215582 1.695434 -0.092416 \n",
+ "998 0.287887 -0.382851 ... -0.481192 1.541299 1.153674 -4.382081 \n",
+ "999 3.462028 -2.128120 ... 3.930542 -1.196083 1.066743 0.690258 \n",
"\n",
- " x25 x26 x27 x28 x29 x30 \n",
- "0 2.631538 -4.643003 5.162550 3.740774 2.113925 2.026183 \n",
- "1 -6.999235 -1.041017 0.388897 0.135666 -5.257444 -4.460909 \n",
- "2 -3.244586 0.046503 7.297417 5.151098 0.353556 -6.192547 \n",
- "3 -2.681429 -3.256659 2.039591 -5.688881 -1.675406 -1.537060 \n",
- "4 5.670162 5.218059 -10.535997 -0.007612 4.940226 -2.512659 \n",
- ".. ... ... ... ... ... ... \n",
- "995 3.884253 0.082965 -3.765886 2.210837 -2.203842 9.350995 \n",
- "996 -0.192387 -0.971816 2.114409 7.572450 -3.337941 4.831238 \n",
- "997 3.114361 4.152857 0.079838 2.297878 9.451616 -1.324771 \n",
- "998 8.015243 6.879811 -7.213541 -2.226587 -0.305480 -1.568153 \n",
- "999 3.749521 -2.260064 5.343868 -0.764016 2.769752 -4.067194 \n",
+ " x25 x26 x27 x28 x29 x30 \n",
+ "0 -1.387149 3.034459 -0.062419 0.258166 -1.168477 -1.061057 \n",
+ "1 0.922270 -2.628696 -1.772420 -3.851087 3.270008 0.820763 \n",
+ "2 1.133570 0.884130 0.484024 3.124910 0.004369 0.349072 \n",
+ "3 0.685799 3.943749 2.891479 5.381948 1.455669 -2.480590 \n",
+ "4 1.338251 2.990290 -0.717240 2.494413 -0.576748 -0.049214 \n",
+ ".. ... ... ... ... ... ... \n",
+ "995 0.988486 -0.082238 1.047705 1.222772 3.264437 1.173541 \n",
+ "996 3.159868 -2.772002 1.782697 1.158639 1.373919 -3.298531 \n",
+ "997 -0.691905 2.864254 1.673074 -0.839339 1.389325 1.714632 \n",
+ "998 4.017794 1.117018 -1.648193 -1.139779 2.748231 3.032575 \n",
+ "999 0.791818 2.491745 0.168359 2.278172 -1.443654 -5.695066 \n",
"\n",
"[1000 rows x 34 columns]"
]
},
- "execution_count": 3,
+ "execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "data = make_static_panel_CP2025(num_n=100, dgp_type='dgp1', x_var=5**2, a_var=0.95**2)\n",
+ "data = make_static_panel_CP2025(num_n=100, dgp_type='dgp1', x_var=5, a_var=0.95)\n",
"data"
]
},
- {
- "cell_type": "code",
- "execution_count": 4,
- "id": "d4580f3d",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Index(['id', 'time', 'd', 'y', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8',\n",
- " 'x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15', 'x16', 'x17', 'x18',\n",
- " 'x19', 'x20', 'x21', 'x22', 'x23', 'x24', 'x25', 'x26', 'x27', 'x28',\n",
- " 'x29', 'x30'],\n",
- " dtype='object')"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "wd_fct(data).columns"
- ]
- },
{
"cell_type": "code",
"execution_count": null,
- "id": "f1342648",
+ "id": "b15dcb8d",
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Processing: 100.0 %"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "n_reps = 20\n",
- "theta = 0.5\n",
- "dgp = 'dgp3'\n",
- "\n",
- "res_cre_ols = np.full((n_reps, 3), np.nan)\n",
- "res_cre_lasso = np.full((n_reps, 3), np.nan)\n",
- "res_cre_cart = np.full((n_reps, 3), np.nan)\n",
- "res_cre_boost = np.full((n_reps, 3), np.nan)\n",
- "\n",
- "res_fd_ols = np.full((n_reps, 3), np.nan)\n",
- "res_fd_lasso = np.full((n_reps, 3), np.nan)\n",
- "res_fd_cart = np.full((n_reps, 3), np.nan)\n",
- "res_fd_boost = np.full((n_reps, 3), np.nan)\n",
+ "def run_sim(n_reps, num_n, dim_x=30, theta=0.5, dgp_type='dgp3'):\n",
"\n",
- "res_wd_ols = np.full((n_reps, 3), np.nan)\n",
- "res_wd_lasso = np.full((n_reps, 3), np.nan)\n",
- "res_wd_cart = np.full((n_reps, 3), np.nan)\n",
- "res_wd_boost = np.full((n_reps, 3), np.nan)\n",
- "\n",
- "\n",
- "np.random.seed(123)\n",
+ " approaches = [\"cre_general\", \"cre_normal\", \"fd_exact\", \"wg_approx\"]\n",
+ " models = [\"lasso\", \"cart\", \"boost\"]\n",
"\n",
- "for i in range(n_reps):\n",
- " print(f\"\\rProcessing: {round((i+1)/n_reps*100, 3)} %\", end=\"\")\n",
+ " res = {\n",
+ " d: {\n",
+ " m: np.full((n_reps, 3), np.nan)\n",
+ " for m in models\n",
+ " }\n",
+ " for d in approaches\n",
+ " }\n",
"\n",
- " ml_cart_grid = {'ml_l': {'ccp_alpha': np.random.choice(np.arange(0.002, 0.052 , 0.002), 5, replace=False),\n",
- " 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)},\n",
- " 'ml_m': {'ccp_alpha': np.random.choice(np.arange(0.002, 0.052 , 0.002), 5, replace=False),\n",
- " 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)}}\n",
+ " x_cols = [f\"x{i+1}\" for i in range(dim_x)]\n",
"\n",
- " ml_boost_grid= {'ml_l': {'reg_lambda': np.random.choice(np.arange(0.2, 2 , 0.2), 5, replace=False), \n",
- " 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)},\n",
- " 'ml_m': {'reg_lambda': np.random.choice(np.arange(0.2, 2 , 0.2), 5, replace=False),\n",
- " 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)}}\n",
- " \n",
- " data = make_static_panel_CP2025(num_n=100, theta=theta, dgp_type=dgp, x_var=5**2, a_var=0.95**2)\n",
+ " def run_single_dml(dml_data, ml, approach, grid=None):\n",
+ " est = DoubleMLPLPR(dml_data, clone(ml), clone(ml), approach=approach, n_folds=5)\n",
"\n",
- " ## CRE\n",
- " cre_data = cre_fct(data)\n",
+ " if grid is not None:\n",
+ " est.tune(param_grids=grid, search_mode='randomized_search',\n",
+ " n_iter_randomized_search=5, n_jobs_cv=5)\n",
"\n",
- " data_cre_pdml = DoubleMLData(cre_data,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " cluster_cols='id',\n",
- " x_cols=[col for col in cre_data.columns if \"x\" in col])\n",
+ " est.fit()\n",
"\n",
- " # OLS\n",
- " dml_plpr = DoubleMLPLPR(data_cre_pdml, clone(ml_ols), clone(ml_ols), pdml_approach='cre_general', n_folds=5)\n",
- " dml_plpr.fit()\n",
- " res_cre_ols[i, 0] = dml_plpr.coef[0] - theta\n",
- " res_cre_ols[i, 1] = dml_plpr.se[0] \n",
- " confint = dml_plpr.confint()\n",
- " res_cre_ols[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ " coef_err = est.coef[0] - theta\n",
+ " se = est.se[0]\n",
+ " conf = est.confint()\n",
+ " covered = (conf['2.5 %'].iloc[0] <= theta) & (conf['97.5 %'].iloc[0] >= theta)\n",
"\n",
- " # Lasso\n",
- " cre_data_ext = extend_data(cre_data)\n",
- " data_cre_pdml_ext = DoubleMLData(cre_data_ext,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " cluster_cols='id',\n",
- " x_cols=[col for col in cre_data_ext.columns if \"x\" in col]\n",
- " )\n",
+ " return coef_err, se, covered\n",
"\n",
- " dml_plpr = DoubleMLPLPR(data_cre_pdml_ext, clone(ml_lasso), clone(ml_lasso), pdml_approach='cre_general', n_folds=5)\n",
- " dml_plpr.fit(n_jobs_cv=5)\n",
- " res_cre_lasso[i, 0] = dml_plpr.coef[0] - theta\n",
- " res_cre_lasso[i, 1] = dml_plpr.se[0] \n",
- " confint = dml_plpr.confint()\n",
- " res_cre_lasso[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ " for i in range(n_reps):\n",
"\n",
- " # Cart\n",
- " dml_plpr = DoubleMLPLPR(data_cre_pdml, clone(ml_cart), clone(ml_cart), pdml_approach='cre_general', n_folds=5)\n",
- " dml_plpr.tune(param_grids=ml_cart_grid, search_mode='randomized_search', n_iter_randomized_search=5, n_jobs_cv=5)\n",
- " dml_plpr.fit(n_jobs_cv=5)\n",
- " res_cre_cart[i, 0] = dml_plpr.coef[0] - theta\n",
- " res_cre_cart[i, 1] = dml_plpr.se[0] \n",
- " confint = dml_plpr.confint()\n",
- " res_cre_cart[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ " print(f\"\\rProcessing: {round((i+1)/n_reps*100, 3)} %\", end=\"\")\n",
"\n",
- " # Boost\n",
- " dml_plpr = DoubleMLPLPR(data_cre_pdml, clone(ml_boost), clone(ml_boost), pdml_approach='cre_general', n_folds=5)\n",
- " dml_plpr.tune(param_grids=ml_boost_grid, search_mode='randomized_search', n_iter_randomized_search=5, n_jobs_cv=5)\n",
- " dml_plpr.fit(n_jobs_cv=5)\n",
- " res_cre_boost[i, 0] = dml_plpr.coef[0] - theta\n",
- " res_cre_boost[i, 1] = dml_plpr.se[0] \n",
- " confint = dml_plpr.confint()\n",
- " res_cre_boost[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ " cart_grid = {'ml_l': {'ccp_alpha': np.random.choice(np.arange(0.002, 0.052 , 0.002), 5, replace=False),\n",
+ " 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)},\n",
+ " 'ml_m': {'ccp_alpha': np.random.choice(np.arange(0.002, 0.052 , 0.002), 5, replace=False),\n",
+ " 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)}}\n",
"\n",
- " ## FD\n",
- " fd_data = fd_fct(data)\n",
+ " boost_grid= {'ml_l': {'reg_lambda': np.random.choice(np.arange(0.2, 2 , 0.2), 5, replace=False), \n",
+ " 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)},\n",
+ " 'ml_m': {'reg_lambda': np.random.choice(np.arange(0.2, 2 , 0.2), 5, replace=False),\n",
+ " 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)}}\n",
"\n",
- " data_fd_pdml = DoubleMLData(fd_data,\n",
- " y_col='y_diff',\n",
- " d_cols='d_diff',\n",
- " cluster_cols='id',\n",
- " x_cols=[col for col in fd_data.columns if \"x\" in col])\n",
+ " data = make_static_panel_CP2025(num_n=num_n, dim_x=dim_x, theta=theta, dgp_type=dgp_type, x_var=5, a_var=0.95)\n",
+ " dml_data = DoubleMLPanelData(data, y_col='y', d_cols='d', t_col='time', id_col='id', x_cols=x_cols, static_panel=True)\n",
"\n",
- " # OLS\n",
- " dml_plpr = DoubleMLPLPR(data_fd_pdml, clone(ml_ols), clone(ml_ols), pdml_approach='transform', n_folds=5)\n",
- " dml_plpr.fit()\n",
- " res_fd_ols[i, 0] = dml_plpr.coef[0] - theta\n",
- " res_fd_ols[i, 1] = dml_plpr.se[0] \n",
- " confint = dml_plpr.confint()\n",
- " res_fd_ols[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ " # CRE general\n",
+ " res['cre_general']['lasso'][i, :] = run_single_dml(dml_data, ml_lasso, 'cre_general', grid=None)\n",
+ " res['cre_general']['cart'][i, :] = run_single_dml(dml_data, ml_cart, 'cre_general', grid=cart_grid)\n",
+ " res['cre_general']['boost'][i, :] = run_single_dml(dml_data, ml_boost, 'cre_general', grid=boost_grid)\n",
"\n",
- " # Lasso\n",
- " fd_data_ext = extend_data(fd_data)\n",
- " data_fd_pdml_ext = DoubleMLData(fd_data_ext,\n",
- " y_col='y_diff',\n",
- " d_cols='d_diff',\n",
- " cluster_cols='id',\n",
- " x_cols=[col for col in fd_data_ext.columns if \"x\" in col]\n",
- " )\n",
+ " # CRE normal\n",
+ " res['cre_normal']['lasso'][i, :] = run_single_dml(dml_data, ml_lasso, 'cre_normal', grid=None)\n",
+ " res['cre_normal']['cart'][i, :] = run_single_dml(dml_data, ml_cart, 'cre_normal', grid=cart_grid)\n",
+ " res['cre_normal']['boost'][i, :] = run_single_dml(dml_data, ml_boost, 'cre_normal', grid=boost_grid)\n",
"\n",
- " dml_plpr = DoubleMLPLPR(data_fd_pdml_ext, clone(ml_lasso), clone(ml_lasso), pdml_approach='transform', n_folds=5)\n",
- " dml_plpr.fit(n_jobs_cv=5)\n",
- " res_fd_lasso[i, 0] = dml_plpr.coef[0] - theta\n",
- " res_fd_lasso[i, 1] = dml_plpr.se[0] \n",
- " confint = dml_plpr.confint()\n",
- " res_fd_lasso[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ " # FD\n",
+ " res['fd_exact']['lasso'][i, :] = run_single_dml(dml_data, ml_lasso, 'fd_exact', grid=None)\n",
+ " res['fd_exact']['cart'][i, :] = run_single_dml(dml_data, ml_cart, 'fd_exact', grid=cart_grid)\n",
+ " res['fd_exact']['boost'][i, :] = run_single_dml(dml_data, ml_boost, 'fd_exact', grid=boost_grid)\n",
"\n",
- " # Cart\n",
- " dml_plpr = DoubleMLPLPR(data_fd_pdml, clone(ml_cart), clone(ml_cart), pdml_approach='transform', n_folds=5)\n",
- " dml_plpr.tune(param_grids=ml_cart_grid, search_mode='randomized_search', n_iter_randomized_search=5, n_jobs_cv=5)\n",
- " dml_plpr.fit(n_jobs_cv=5)\n",
- " res_fd_cart[i, 0] = dml_plpr.coef[0] - theta\n",
- " res_fd_cart[i, 1] = dml_plpr.se[0] \n",
- " confint = dml_plpr.confint()\n",
- " res_fd_cart[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ " # WD\n",
+ " res['wg_approx']['lasso'][i, :] = run_single_dml(dml_data, ml_lasso_wg, 'wg_approx', grid=None)\n",
+ " res['wg_approx']['cart'][i, :] = run_single_dml(dml_data, ml_cart, 'wg_approx', grid=cart_grid)\n",
+ " res['wg_approx']['boost'][i, :] = run_single_dml(dml_data, ml_boost, 'wg_approx', grid=boost_grid)\n",
"\n",
- " # Boost\n",
- " dml_plpr = DoubleMLPLPR(data_fd_pdml, clone(ml_boost), clone(ml_boost), pdml_approach='transform', n_folds=5)\n",
- " dml_plpr.tune(param_grids=ml_boost_grid, search_mode='randomized_search', n_iter_randomized_search=5, n_jobs_cv=5)\n",
- " dml_plpr.fit(n_jobs_cv=5)\n",
- " res_fd_boost[i, 0] = dml_plpr.coef[0] - theta\n",
- " res_fd_boost[i, 1] = dml_plpr.se[0] \n",
- " confint = dml_plpr.confint()\n",
- " res_fd_boost[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ " # summary\n",
+ " rows = []\n",
+ " index = []\n",
"\n",
- " ## WD\n",
- " wd_data = wd_fct(data)\n",
+ " for approach, models_dict in res.items():\n",
+ " for model, arr in models_dict.items():\n",
"\n",
- " data_wd_pdml = DoubleMLData(wd_data,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " cluster_cols='id',\n",
- " x_cols=[col for col in wd_data.columns if \"x\" in col])\n",
+ " bias = np.mean(arr[:, 0])\n",
+ " se_mean = np.mean(arr[:, 1])\n",
+ " sd = np.std(arr[:, 1])\n",
+ " coverage = np.mean(arr[:, 2])\n",
+ " se_over_sd = sd / se_mean if se_mean > 0 else np.nan\n",
+ " rmse = np.sqrt(np.mean(arr[:, 0]**2))\n",
"\n",
- " # OLS\n",
- " dml_plpr = DoubleMLPLPR(data_wd_pdml, clone(ml_ols), clone(ml_ols), pdml_approach='transform', n_folds=5)\n",
- " dml_plpr.fit()\n",
- " res_wd_ols[i, 0] = dml_plpr.coef[0] - theta\n",
- " res_wd_ols[i, 1] = dml_plpr.se[0] \n",
- " confint = dml_plpr.confint()\n",
- " res_wd_ols[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
+ " rows.append([bias, rmse, se_over_sd, coverage])\n",
+ " index.append((approach, model)) \n",
"\n",
- " # Lasso\n",
- " wd_data_ext = extend_data(wd_data)\n",
- " data_wd_pdml_ext = DoubleMLData(wd_data_ext,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " cluster_cols='id',\n",
- " x_cols=[col for col in wd_data_ext.columns if \"x\" in col]\n",
- " )\n",
+ " summary = pd.DataFrame(\n",
+ " rows,\n",
+ " index=pd.MultiIndex.from_tuples(index, names=[\"Approach\", \"ML Model\"]),\n",
+ " columns=[\"Bias\", \"RMSE\", \"SE/SD\", \"Coverage\"]\n",
+ " )\n",
"\n",
- " dml_plpr = DoubleMLPLPR(data_wd_pdml_ext, clone(ml_lasso), clone(ml_lasso), pdml_approach='transform', n_folds=5)\n",
- " dml_plpr.fit(n_jobs_cv=5)\n",
- " res_wd_lasso[i, 0] = dml_plpr.coef[0] - theta\n",
- " res_wd_lasso[i, 1] = dml_plpr.se[0] \n",
- " confint = dml_plpr.confint()\n",
- " res_wd_lasso[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
- "\n",
- " # Cart\n",
- " dml_plpr = DoubleMLPLPR(data_wd_pdml, clone(ml_cart), clone(ml_cart), pdml_approach='transform', n_folds=5)\n",
- " dml_plpr.tune(param_grids=ml_cart_grid, search_mode='randomized_search', n_iter_randomized_search=5, n_jobs_cv=5)\n",
- " dml_plpr.fit(n_jobs_cv=5)\n",
- " res_wd_cart[i, 0] = dml_plpr.coef[0] - theta\n",
- " res_wd_cart[i, 1] = dml_plpr.se[0] \n",
- " confint = dml_plpr.confint()\n",
- " res_wd_cart[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
- "\n",
- " # Boost\n",
- " dml_plpr = DoubleMLPLPR(data_wd_pdml, clone(ml_boost), clone(ml_boost), pdml_approach='transform', n_folds=5)\n",
- " dml_plpr.tune(param_grids=ml_boost_grid, search_mode='randomized_search', n_iter_randomized_search=5, n_jobs_cv=5)\n",
- " dml_plpr.fit(n_jobs_cv=5)\n",
- " res_wd_boost[i, 0] = dml_plpr.coef[0] - theta\n",
- " res_wd_boost[i, 1] = dml_plpr.se[0] \n",
- " confint = dml_plpr.confint()\n",
- " res_wd_boost[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)"
+ " return summary"
]
},
{
"cell_type": "code",
- "execution_count": 10,
- "id": "cf49d96e",
+ "execution_count": 14,
+ "id": "6503f741",
"metadata": {},
"outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Processing: 100.0 %"
+ ]
+ },
{
"data": {
"text/html": [
@@ -690,179 +618,147 @@
" \n",
" \n",
" | \n",
+ " | \n",
" Bias | \n",
- " SE | \n",
- " Coverage | \n",
- " SE/SD | \n",
" RMSE | \n",
+ " SE/SD | \n",
+ " Coverage | \n",
+ "
\n",
+ " \n",
+ " | Approach | \n",
+ " ML Model | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
"
\n",
" \n",
" \n",
" \n",
- " | OLS (CRE) | \n",
- " 0.992168 | \n",
- " 0.004074 | \n",
- " 0.0 | \n",
- " 0.111677 | \n",
- " 0.992171 | \n",
+ " cre_general | \n",
+ " lasso | \n",
+ " 0.024030 | \n",
+ " 0.040756 | \n",
+ " 0.076669 | \n",
+ " 0.92 | \n",
"
\n",
" \n",
- " | Lasso (CRE) | \n",
- " 0.006154 | \n",
- " 0.033672 | \n",
- " 1.0 | \n",
- " 0.090553 | \n",
- " 0.028267 | \n",
+ " cart | \n",
+ " -0.033091 | \n",
+ " 0.069709 | \n",
+ " 0.112711 | \n",
+ " 0.75 | \n",
"
\n",
" \n",
- " | Cart (CRE) | \n",
- " 0.665847 | \n",
- " 0.087393 | \n",
- " 0.0 | \n",
- " 0.405625 | \n",
- " 0.702282 | \n",
+ " boost | \n",
+ " -0.041728 | \n",
+ " 0.062508 | \n",
+ " 0.114469 | \n",
+ " 0.76 | \n",
"
\n",
" \n",
- " | Boost (CRE) | \n",
- " 0.667262 | \n",
- " 0.066499 | \n",
- " 0.0 | \n",
- " 0.225313 | \n",
- " 0.672432 | \n",
+ " cre_normal | \n",
+ " lasso | \n",
+ " 0.094727 | \n",
+ " 0.103281 | \n",
+ " 0.111451 | \n",
+ " 0.41 | \n",
"
\n",
" \n",
- " | OLS (FD) | \n",
- " 0.990554 | \n",
- " 0.004807 | \n",
- " 0.0 | \n",
- " 0.096511 | \n",
- " 0.990562 | \n",
+ " cart | \n",
+ " -0.015669 | \n",
+ " 0.072878 | \n",
+ " 0.125903 | \n",
+ " 0.89 | \n",
"
\n",
" \n",
- " | Lasso (FD) | \n",
- " 0.025886 | \n",
- " 0.039459 | \n",
- " 1.0 | \n",
- " 0.104909 | \n",
- " 0.037748 | \n",
+ " boost | \n",
+ " 0.051158 | \n",
+ " 0.081405 | \n",
+ " 0.128427 | \n",
+ " 0.90 | \n",
"
\n",
" \n",
- " | Cart (FD) | \n",
- " 0.851024 | \n",
- " 0.045228 | \n",
- " 0.0 | \n",
- " 0.402826 | \n",
- " 0.854973 | \n",
+ " fd_exact | \n",
+ " lasso | \n",
+ " 0.021691 | \n",
+ " 0.044563 | \n",
+ " 0.084716 | \n",
+ " 0.90 | \n",
"
\n",
" \n",
- " | Boost (FD) | \n",
- " 0.835220 | \n",
- " 0.040769 | \n",
- " 0.0 | \n",
- " 0.087226 | \n",
- " 0.837982 | \n",
+ " cart | \n",
+ " 0.076125 | \n",
+ " 0.105962 | \n",
+ " 0.099474 | \n",
+ " 0.62 | \n",
"
\n",
" \n",
- " | OLS (WD) | \n",
- " 0.992407 | \n",
- " 0.004118 | \n",
- " 0.0 | \n",
- " 0.111584 | \n",
- " 0.992413 | \n",
+ " boost | \n",
+ " 0.010003 | \n",
+ " 0.049025 | \n",
+ " 0.116624 | \n",
+ " 0.87 | \n",
"
\n",
" \n",
- " | Lasso (WD) | \n",
- " 0.969836 | \n",
- " 0.010887 | \n",
- " 0.0 | \n",
- " 0.084016 | \n",
- " 0.969867 | \n",
+ " wg_approx | \n",
+ " lasso | \n",
+ " 0.003884 | \n",
+ " 0.032382 | \n",
+ " 0.077256 | \n",
+ " 0.97 | \n",
"
\n",
" \n",
- " | Cart (WD) | \n",
- " 0.676682 | \n",
- " 0.067397 | \n",
- " 0.0 | \n",
- " 0.365097 | \n",
- " 0.709067 | \n",
+ " cart | \n",
+ " -0.002517 | \n",
+ " 0.048548 | \n",
+ " 0.090029 | \n",
+ " 0.87 | \n",
"
\n",
" \n",
- " | Boost (WD) | \n",
- " 0.813528 | \n",
- " 0.031009 | \n",
- " 0.0 | \n",
- " 0.057072 | \n",
- " 0.816461 | \n",
+ " boost | \n",
+ " -0.042819 | \n",
+ " 0.058042 | \n",
+ " 0.086614 | \n",
+ " 0.71 | \n",
"
\n",
" \n",
"\n",
""
],
"text/plain": [
- " Bias SE Coverage SE/SD RMSE\n",
- "OLS (CRE) 0.992168 0.004074 0.0 0.111677 0.992171\n",
- "Lasso (CRE) 0.006154 0.033672 1.0 0.090553 0.028267\n",
- "Cart (CRE) 0.665847 0.087393 0.0 0.405625 0.702282\n",
- "Boost (CRE) 0.667262 0.066499 0.0 0.225313 0.672432\n",
- "OLS (FD) 0.990554 0.004807 0.0 0.096511 0.990562\n",
- "Lasso (FD) 0.025886 0.039459 1.0 0.104909 0.037748\n",
- "Cart (FD) 0.851024 0.045228 0.0 0.402826 0.854973\n",
- "Boost (FD) 0.835220 0.040769 0.0 0.087226 0.837982\n",
- "OLS (WD) 0.992407 0.004118 0.0 0.111584 0.992413\n",
- "Lasso (WD) 0.969836 0.010887 0.0 0.084016 0.969867\n",
- "Cart (WD) 0.676682 0.067397 0.0 0.365097 0.709067\n",
- "Boost (WD) 0.813528 0.031009 0.0 0.057072 0.816461"
+ " Bias RMSE SE/SD Coverage\n",
+ "Approach ML Model \n",
+ "cre_general lasso 0.024030 0.040756 0.076669 0.92\n",
+ " cart -0.033091 0.069709 0.112711 0.75\n",
+ " boost -0.041728 0.062508 0.114469 0.76\n",
+ "cre_normal lasso 0.094727 0.103281 0.111451 0.41\n",
+ " cart -0.015669 0.072878 0.125903 0.89\n",
+ " boost 0.051158 0.081405 0.128427 0.90\n",
+ "fd_exact lasso 0.021691 0.044563 0.084716 0.90\n",
+ " cart 0.076125 0.105962 0.099474 0.62\n",
+ " boost 0.010003 0.049025 0.116624 0.87\n",
+ "wg_approx lasso 0.003884 0.032382 0.077256 0.97\n",
+ " cart -0.002517 0.048548 0.090029 0.87\n",
+ " boost -0.042819 0.058042 0.086614 0.71"
]
},
- "execution_count": 10,
+ "execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "# cre general, dgp square\n",
- "tab_dat = np.vstack([res_cre_ols.mean(axis=0), res_cre_lasso.mean(axis=0), \n",
- " res_cre_cart.mean(axis=0), res_cre_boost.mean(axis=0),\n",
- " res_fd_ols.mean(axis=0), res_fd_lasso.mean(axis=0), \n",
- " res_fd_cart.mean(axis=0), res_fd_boost.mean(axis=0),\n",
- " res_wd_ols.mean(axis=0), res_wd_lasso.mean(axis=0), \n",
- " res_wd_cart.mean(axis=0), res_wd_boost.mean(axis=0)])\n",
- "\n",
- "tab_sd = np.vstack([res_cre_ols[:,1].std(), res_cre_lasso[:,1].std(), \n",
- " res_cre_cart[:,1].std(), res_cre_boost[:,1].std(),\n",
- " res_fd_ols[:,1].std(), res_fd_lasso[:,1].std(), \n",
- " res_fd_cart[:,1].std(), res_fd_boost[:,1].std(),\n",
- " res_wd_ols[:,1].std(), res_wd_lasso[:,1].std(), \n",
- " res_wd_cart[:,1].std(), res_wd_boost[:,1].std()])\n",
- "\n",
- "tab_se = np.column_stack([res_cre_ols[:,1], res_cre_lasso[:,1], \n",
- " res_cre_cart[:,1], res_cre_boost[:,1],\n",
- " res_fd_ols[:,1], res_fd_lasso[:,1], \n",
- " res_fd_cart[:,1], res_fd_boost[:,1],\n",
- " res_wd_ols[:,1], res_wd_lasso[:,1], \n",
- " res_wd_cart[:,1], res_wd_boost[:,1]])\n",
- "\n",
- "tab_rmse = np.vstack([np.sqrt(np.mean(res_cre_ols[:,0]**2)), np.sqrt(np.mean(res_cre_lasso[:,0]**2)), \n",
- " np.sqrt(np.mean(res_cre_cart[:,0]**2)), np.sqrt(np.mean(res_cre_boost[:,0]**2)),\n",
- " np.sqrt(np.mean(res_fd_ols[:,0]**2)), np.sqrt(np.mean(res_fd_lasso[:,0]**2)), \n",
- " np.sqrt(np.mean(res_fd_cart[:,0]**2)), np.sqrt(np.mean(res_fd_boost[:,0]**2)),\n",
- " np.sqrt(np.mean(res_wd_ols[:,0]**2)), np.sqrt(np.mean(res_wd_lasso[:,0]**2)), \n",
- " np.sqrt(np.mean(res_wd_cart[:,0]**2)), np.sqrt(np.mean(res_wd_boost[:,0]**2))])\n",
- "\n",
- "se_sd = tab_sd / tab_dat[:,1].reshape((-1,1))\n",
- "\n",
- "tab_dat = np.column_stack((tab_dat, se_sd, tab_rmse))\n",
+ "np.random.seed(123)\n",
"\n",
- "pd.DataFrame(tab_dat, columns=['Bias', 'SE', 'Coverage', 'SE/SD', 'RMSE'], \n",
- " index=['OLS (CRE)', 'Lasso (CRE)', 'Cart (CRE)', 'Boost (CRE)',\n",
- " 'OLS (FD)', 'Lasso (FD)', 'Cart (FD)', 'Boost (FD)',\n",
- " 'OLS (WD)', 'Lasso (WD)', 'Cart (WD)', 'Boost (WD)'])"
+ "res_dgp1 = run_sim(n_reps=100, num_n=100, theta=0.5, dgp_type='dgp1')\n",
+ "res_dgp1"
]
},
{
"cell_type": "code",
- "execution_count": null,
- "id": "93efa3c9",
+ "execution_count": 16,
+ "id": "7a7b3fa2",
"metadata": {},
"outputs": [
{
@@ -871,94 +767,7 @@
"text": [
"Processing: 100.0 %"
]
- }
- ],
- "source": [
- "n_reps = 20\n",
- "theta = 0.5\n",
- "dgp = 'dgp3'\n",
- "\n",
- "res_cre_ols = np.full((n_reps, 3), np.nan)\n",
- "res_cre_lasso = np.full((n_reps, 3), np.nan)\n",
- "res_cre_cart = np.full((n_reps, 3), np.nan)\n",
- "res_cre_boost = np.full((n_reps, 3), np.nan)\n",
- "\n",
- "np.random.seed(123)\n",
- "\n",
- "for i in range(n_reps):\n",
- " print(f\"\\rProcessing: {round((i+1)/n_reps*100, 3)} %\", end=\"\")\n",
- "\n",
- " ml_cart_grid = {'ml_l': {'ccp_alpha': np.random.choice(np.arange(0.002, 0.052 , 0.002), 5, replace=False),\n",
- " 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)},\n",
- " 'ml_m': {'ccp_alpha': np.random.choice(np.arange(0.002, 0.052 , 0.002), 5, replace=False),\n",
- " 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)}}\n",
- "\n",
- " ml_boost_grid= {'ml_l': {'reg_lambda': np.random.choice(np.arange(0.2, 2 , 0.2), 5, replace=False), \n",
- " 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)},\n",
- " 'ml_m': {'reg_lambda': np.random.choice(np.arange(0.2, 2 , 0.2), 5, replace=False),\n",
- " 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)}}\n",
- " \n",
- " data = make_static_panel_CP2025(num_n=4000, theta=theta, dgp_type=dgp, x_var=5**2, a_var=0.95**2)\n",
- "\n",
- " ## CRE\n",
- " cre_data = cre_fct(data)\n",
- "\n",
- " data_cre_pdml = DoubleMLData(cre_data,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " cluster_cols='id',\n",
- " x_cols=[col for col in cre_data.columns if \"x\" in col]\n",
- " )\n",
- "\n",
- " # OLS\n",
- " dml_plpr = DoubleMLPLPR(data_cre_pdml, clone(ml_ols), clone(ml_ols), pdml_approach='cre_general', n_folds=5)\n",
- " dml_plpr.fit()\n",
- " res_cre_ols[i, 0] = dml_plpr.coef[0] - theta\n",
- " res_cre_ols[i, 1] = dml_plpr.se[0] \n",
- " confint = dml_plpr.confint()\n",
- " res_cre_ols[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
- "\n",
- " # Lasso\n",
- " cre_data_ext = extend_data(cre_data)\n",
- " data_cre_pdml_ext = DoubleMLData(cre_data_ext,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " cluster_cols='id',\n",
- " x_cols=[col for col in cre_data_ext.columns if \"x\" in col]\n",
- " )\n",
- "\n",
- " dml_plpr = DoubleMLPLPR(data_cre_pdml_ext, clone(ml_lasso), clone(ml_lasso), pdml_approach='cre_general', n_folds=5)\n",
- " dml_plpr.fit(n_jobs_cv=5)\n",
- " res_cre_lasso[i, 0] = dml_plpr.coef[0] - theta\n",
- " res_cre_lasso[i, 1] = dml_plpr.se[0] \n",
- " confint = dml_plpr.confint()\n",
- " res_cre_lasso[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
- "\n",
- " # Cart\n",
- " dml_plpr = DoubleMLPLPR(data_cre_pdml, clone(ml_cart), clone(ml_cart), pdml_approach='cre_general', n_folds=5)\n",
- " dml_plpr.tune(param_grids=ml_cart_grid, search_mode='randomized_search', n_iter_randomized_search=5, n_jobs_cv=5)\n",
- " dml_plpr.fit(n_jobs_cv=5)\n",
- " res_cre_cart[i, 0] = dml_plpr.coef[0] - theta\n",
- " res_cre_cart[i, 1] = dml_plpr.se[0] \n",
- " confint = dml_plpr.confint()\n",
- " res_cre_cart[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
- "\n",
- " # Boost\n",
- " dml_plpr = DoubleMLPLPR(data_cre_pdml, clone(ml_boost), clone(ml_boost), pdml_approach='cre_general', n_folds=5)\n",
- " dml_plpr.tune(param_grids=ml_boost_grid, search_mode='randomized_search', n_iter_randomized_search=5, n_jobs_cv=5)\n",
- " dml_plpr.fit(n_jobs_cv=5)\n",
- " res_cre_boost[i, 0] = dml_plpr.coef[0] - theta\n",
- " res_cre_boost[i, 1] = dml_plpr.se[0] \n",
- " confint = dml_plpr.confint()\n",
- " res_cre_boost[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "id": "9eda614d",
- "metadata": {},
- "outputs": [
+ },
{
"data": {
"text/html": [
@@ -980,91 +789,156 @@
" \n",
" \n",
" | \n",
+ " | \n",
" Bias | \n",
- " SE | \n",
- " Coverage | \n",
- " SE/SD | \n",
" RMSE | \n",
+ " SE/SD | \n",
+ " Coverage | \n",
+ "
\n",
+ " \n",
+ " | Approach | \n",
+ " ML Model | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
"
\n",
" \n",
" \n",
" \n",
- " | OLS (CRE) | \n",
- " 0.992847 | \n",
- " 0.000642 | \n",
- " 0.0 | \n",
- " 0.010548 | \n",
- " 0.992848 | \n",
+ " cre_general | \n",
+ " lasso | \n",
+ " -0.005823 | \n",
+ " 0.027244 | \n",
+ " 0.073960 | \n",
+ " 0.95 | \n",
+ "
\n",
+ " \n",
+ " | cart | \n",
+ " -0.109613 | \n",
+ " 0.122162 | \n",
+ " 0.106904 | \n",
+ " 0.24 | \n",
+ "
\n",
+ " \n",
+ " | boost | \n",
+ " -0.063039 | \n",
+ " 0.075676 | \n",
+ " 0.081278 | \n",
+ " 0.50 | \n",
"
\n",
" \n",
- " | Lasso (CRE) | \n",
- " 0.025631 | \n",
- " 0.012231 | \n",
- " 0.4 | \n",
- " 0.024428 | \n",
- " 0.027179 | \n",
+ " cre_normal | \n",
+ " lasso | \n",
+ " 0.070362 | \n",
+ " 0.076794 | \n",
+ " 0.087415 | \n",
+ " 0.40 | \n",
"
\n",
" \n",
- " | Cart (CRE) | \n",
- " -0.008529 | \n",
- " 0.035518 | \n",
- " 0.2 | \n",
- " 0.293714 | \n",
- " 0.174186 | \n",
+ " cart | \n",
+ " -0.028230 | \n",
+ " 0.073113 | \n",
+ " 0.182335 | \n",
+ " 0.94 | \n",
"
\n",
" \n",
- " | Boost (CRE) | \n",
- " -0.194521 | \n",
- " 0.018691 | \n",
- " 0.0 | \n",
- " 0.235885 | \n",
- " 0.197435 | \n",
+ " boost | \n",
+ " -0.016764 | \n",
+ " 0.060976 | \n",
+ " 0.183879 | \n",
+ " 0.89 | \n",
+ "
\n",
+ " \n",
+ " | fd_exact | \n",
+ " lasso | \n",
+ " -0.002715 | \n",
+ " 0.032797 | \n",
+ " 0.091710 | \n",
+ " 0.95 | \n",
+ "
\n",
+ " \n",
+ " | cart | \n",
+ " -0.074209 | \n",
+ " 0.083379 | \n",
+ " 0.099775 | \n",
+ " 0.47 | \n",
+ "
\n",
+ " \n",
+ " | boost | \n",
+ " -0.062527 | \n",
+ " 0.076963 | \n",
+ " 0.085480 | \n",
+ " 0.58 | \n",
+ "
\n",
+ " \n",
+ " | wg_approx | \n",
+ " lasso | \n",
+ " -0.003015 | \n",
+ " 0.027330 | \n",
+ " 0.074295 | \n",
+ " 0.95 | \n",
+ "
\n",
+ " \n",
+ " | cart | \n",
+ " -0.024405 | \n",
+ " 0.038327 | \n",
+ " 0.079555 | \n",
+ " 0.90 | \n",
+ "
\n",
+ " \n",
+ " | boost | \n",
+ " -0.056258 | \n",
+ " 0.067245 | \n",
+ " 0.083798 | \n",
+ " 0.52 | \n",
"
\n",
" \n",
"\n",
""
],
"text/plain": [
- " Bias SE Coverage SE/SD RMSE\n",
- "OLS (CRE) 0.992847 0.000642 0.0 0.010548 0.992848\n",
- "Lasso (CRE) 0.025631 0.012231 0.4 0.024428 0.027179\n",
- "Cart (CRE) -0.008529 0.035518 0.2 0.293714 0.174186\n",
- "Boost (CRE) -0.194521 0.018691 0.0 0.235885 0.197435"
+ " Bias RMSE SE/SD Coverage\n",
+ "Approach ML Model \n",
+ "cre_general lasso -0.005823 0.027244 0.073960 0.95\n",
+ " cart -0.109613 0.122162 0.106904 0.24\n",
+ " boost -0.063039 0.075676 0.081278 0.50\n",
+ "cre_normal lasso 0.070362 0.076794 0.087415 0.40\n",
+ " cart -0.028230 0.073113 0.182335 0.94\n",
+ " boost -0.016764 0.060976 0.183879 0.89\n",
+ "fd_exact lasso -0.002715 0.032797 0.091710 0.95\n",
+ " cart -0.074209 0.083379 0.099775 0.47\n",
+ " boost -0.062527 0.076963 0.085480 0.58\n",
+ "wg_approx lasso -0.003015 0.027330 0.074295 0.95\n",
+ " cart -0.024405 0.038327 0.079555 0.90\n",
+ " boost -0.056258 0.067245 0.083798 0.52"
]
},
- "execution_count": 15,
+ "execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "# cre normal, dgp square\n",
- "tab_dat = np.vstack([res_cre_ols.mean(axis=0), res_cre_lasso.mean(axis=0), \n",
- " res_cre_cart.mean(axis=0), res_cre_boost.mean(axis=0)])\n",
- "\n",
- "tab_sd = np.vstack([res_cre_ols[:,1].std(), res_cre_lasso[:,1].std(), \n",
- " res_cre_cart[:,1].std(), res_cre_boost[:,1].std()])\n",
- "\n",
- "tab_se = np.column_stack([res_cre_ols[:,1], res_cre_lasso[:,1], \n",
- " res_cre_cart[:,1], res_cre_boost[:,1]])\n",
- "\n",
- "tab_rmse = np.vstack([np.sqrt(np.mean(res_cre_ols[:,0]**2)), np.sqrt(np.mean(res_cre_lasso[:,0]**2)), \n",
- " np.sqrt(np.mean(res_cre_cart[:,0]**2)), np.sqrt(np.mean(res_cre_boost[:,0]**2))])\n",
- "\n",
- "se_sd = tab_sd / tab_dat[:,1].reshape((-1,1))\n",
- "\n",
- "tab_dat = np.column_stack((tab_dat, se_sd, tab_rmse))\n",
+ "np.random.seed(123)\n",
"\n",
- "pd.DataFrame(tab_dat, columns=['Bias', 'SE', 'Coverage', 'SE/SD', 'RMSE'], \n",
- " index=['OLS (CRE)', 'Lasso (CRE)', 'Cart (CRE)', 'Boost (CRE)'])"
+ "res_dgp2 = run_sim(n_reps=100, num_n=100, theta=0.5, dgp_type='dgp2')\n",
+ "res_dgp2"
]
},
{
"cell_type": "code",
"execution_count": 17,
- "id": "cad7f1f9",
+ "id": "f8f0abee",
"metadata": {},
"outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Processing: 100.0 %"
+ ]
+ },
{
"data": {
"text/html": [
@@ -1086,56 +960,129 @@
" \n",
" \n",
" | \n",
+ " | \n",
" Bias | \n",
- " SE | \n",
- " Coverage | \n",
- " SE/SD | \n",
" RMSE | \n",
+ " SE/SD | \n",
+ " Coverage | \n",
+ "
\n",
+ " \n",
+ " | Approach | \n",
+ " ML Model | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
"
\n",
" \n",
" \n",
" \n",
- " | OLS (CRE) | \n",
- " 0.992847 | \n",
- " 0.000642 | \n",
- " 0.0 | \n",
- " 0.010548 | \n",
- " 0.992848 | \n",
+ " cre_general | \n",
+ " lasso | \n",
+ " 0.016582 | \n",
+ " 0.037272 | \n",
+ " 0.080832 | \n",
+ " 0.94 | \n",
+ "
\n",
+ " \n",
+ " | cart | \n",
+ " 0.493280 | \n",
+ " 0.534554 | \n",
+ " 0.275800 | \n",
+ " 0.05 | \n",
+ "
\n",
+ " \n",
+ " | boost | \n",
+ " 0.443519 | \n",
+ " 0.451916 | \n",
+ " 0.180075 | \n",
+ " 0.00 | \n",
+ "
\n",
+ " \n",
+ " | cre_normal | \n",
+ " lasso | \n",
+ " 0.083721 | \n",
+ " 0.100538 | \n",
+ " 0.129024 | \n",
+ " 0.71 | \n",
+ "
\n",
+ " \n",
+ " | cart | \n",
+ " 0.413418 | \n",
+ " 0.468509 | \n",
+ " 0.218055 | \n",
+ " 0.13 | \n",
+ "
\n",
+ " \n",
+ " | boost | \n",
+ " 0.385419 | \n",
+ " 0.393097 | \n",
+ " 0.185301 | \n",
+ " 0.00 | \n",
+ "
\n",
+ " \n",
+ " | fd_exact | \n",
+ " lasso | \n",
+ " 0.014759 | \n",
+ " 0.041217 | \n",
+ " 0.086594 | \n",
+ " 0.92 | \n",
"
\n",
" \n",
- " | Lasso (CRE) | \n",
- " -0.008373 | \n",
- " 0.005419 | \n",
- " 0.6 | \n",
- " 0.008168 | \n",
- " 0.009689 | \n",
+ " cart | \n",
+ " 0.658754 | \n",
+ " 0.673590 | \n",
+ " 0.259508 | \n",
+ " 0.00 | \n",
"
\n",
" \n",
- " | Cart (CRE) | \n",
- " 0.059338 | \n",
- " 0.059713 | \n",
- " 0.4 | \n",
- " 0.452969 | \n",
- " 0.187953 | \n",
+ " boost | \n",
+ " 0.588552 | \n",
+ " 0.593496 | \n",
+ " 0.134829 | \n",
+ " 0.00 | \n",
"
\n",
" \n",
- " | Boost (CRE) | \n",
- " 0.036213 | \n",
- " 0.031977 | \n",
- " 0.7 | \n",
- " 0.247513 | \n",
- " 0.057652 | \n",
+ " wg_approx | \n",
+ " lasso | \n",
+ " 0.630003 | \n",
+ " 0.630942 | \n",
+ " 0.099154 | \n",
+ " 0.00 | \n",
+ "
\n",
+ " \n",
+ " | cart | \n",
+ " 0.515894 | \n",
+ " 0.532497 | \n",
+ " 0.227144 | \n",
+ " 0.01 | \n",
+ "
\n",
+ " \n",
+ " | boost | \n",
+ " 0.593587 | \n",
+ " 0.597220 | \n",
+ " 0.137237 | \n",
+ " 0.00 | \n",
"
\n",
" \n",
"\n",
""
],
"text/plain": [
- " Bias SE Coverage SE/SD RMSE\n",
- "OLS (CRE) 0.992847 0.000642 0.0 0.010548 0.992848\n",
- "Lasso (CRE) -0.008373 0.005419 0.6 0.008168 0.009689\n",
- "Cart (CRE) 0.059338 0.059713 0.4 0.452969 0.187953\n",
- "Boost (CRE) 0.036213 0.031977 0.7 0.247513 0.057652"
+ " Bias RMSE SE/SD Coverage\n",
+ "Approach ML Model \n",
+ "cre_general lasso 0.016582 0.037272 0.080832 0.94\n",
+ " cart 0.493280 0.534554 0.275800 0.05\n",
+ " boost 0.443519 0.451916 0.180075 0.00\n",
+ "cre_normal lasso 0.083721 0.100538 0.129024 0.71\n",
+ " cart 0.413418 0.468509 0.218055 0.13\n",
+ " boost 0.385419 0.393097 0.185301 0.00\n",
+ "fd_exact lasso 0.014759 0.041217 0.086594 0.92\n",
+ " cart 0.658754 0.673590 0.259508 0.00\n",
+ " boost 0.588552 0.593496 0.134829 0.00\n",
+ "wg_approx lasso 0.630003 0.630942 0.099154 0.00\n",
+ " cart 0.515894 0.532497 0.227144 0.01\n",
+ " boost 0.593587 0.597220 0.137237 0.00"
]
},
"execution_count": 17,
@@ -1144,25 +1091,9 @@
}
],
"source": [
- "# cre general, dgp square\n",
- "tab_dat = np.vstack([res_cre_ols.mean(axis=0), res_cre_lasso.mean(axis=0), \n",
- " res_cre_cart.mean(axis=0), res_cre_boost.mean(axis=0)])\n",
- "\n",
- "tab_sd = np.vstack([res_cre_ols[:,1].std(), res_cre_lasso[:,1].std(), \n",
- " res_cre_cart[:,1].std(), res_cre_boost[:,1].std()])\n",
- "\n",
- "tab_se = np.column_stack([res_cre_ols[:,1], res_cre_lasso[:,1], \n",
- " res_cre_cart[:,1], res_cre_boost[:,1]])\n",
- "\n",
- "tab_rmse = np.vstack([np.sqrt(np.mean(res_cre_ols[:,0]**2)), np.sqrt(np.mean(res_cre_lasso[:,0]**2)), \n",
- " np.sqrt(np.mean(res_cre_cart[:,0]**2)), np.sqrt(np.mean(res_cre_boost[:,0]**2))])\n",
- "\n",
- "se_sd = tab_sd / tab_dat[:,1].reshape((-1,1))\n",
- "\n",
- "tab_dat = np.column_stack((tab_dat, se_sd, tab_rmse))\n",
+ "np.random.seed(123)\n",
"\n",
- "pd.DataFrame(tab_dat, columns=['Bias', 'SE', 'Coverage', 'SE/SD', 'RMSE'], \n",
- " index=['OLS (CRE)', 'Lasso (CRE)', 'Cart (CRE)', 'Boost (CRE)'])"
+ "run_sim(n_reps=100, num_n=100, theta=0.5, dgp_type='dgp3')"
]
}
],
From 3f8eef66b5d3976166f5040b66d7fba3535d6d6e Mon Sep 17 00:00:00 2001
From: Julian Diefenbacher
<93731825+JulianDiefenbacher@users.noreply.github.com>
Date: Mon, 24 Nov 2025 18:13:21 +0100
Subject: [PATCH 24/33] update PLPR model
---
doubleml/plm/__init__.py | 2 +
doubleml/plm/datasets/__init__.py | 2 +
doubleml/plm/plpr.py | 40 ++++++++++++-----
doubleml/plm/sim/example_sim.ipynb | 70 +++++++++++++++++-------------
4 files changed, 75 insertions(+), 39 deletions(-)
diff --git a/doubleml/plm/__init__.py b/doubleml/plm/__init__.py
index e81f00c52..e0949d6e1 100644
--- a/doubleml/plm/__init__.py
+++ b/doubleml/plm/__init__.py
@@ -4,8 +4,10 @@
from .pliv import DoubleMLPLIV
from .plr import DoubleMLPLR
+from .plpr import DoubleMLPLPR
__all__ = [
"DoubleMLPLR",
"DoubleMLPLIV",
+ "DoubleMLPLPR",
]
diff --git a/doubleml/plm/datasets/__init__.py b/doubleml/plm/datasets/__init__.py
index b2bb7df0e..17e97e3bd 100644
--- a/doubleml/plm/datasets/__init__.py
+++ b/doubleml/plm/datasets/__init__.py
@@ -8,6 +8,7 @@
from .dgp_pliv_multiway_cluster_CKMS2021 import make_pliv_multiway_cluster_CKMS2021
from .dgp_plr_CCDDHNR2018 import make_plr_CCDDHNR2018
from .dgp_plr_turrell2018 import make_plr_turrell2018
+from .dgp_static_panel_CP2025 import make_static_panel_CP2025
__all__ = [
"make_plr_CCDDHNR2018",
@@ -16,4 +17,5 @@
"make_pliv_CHS2015",
"make_pliv_multiway_cluster_CKMS2021",
"_make_pliv_data",
+ "make_static_panel_CP2025",
]
diff --git a/doubleml/plm/plpr.py b/doubleml/plm/plpr.py
index 9c7d2cb40..b98e14aed 100644
--- a/doubleml/plm/plpr.py
+++ b/doubleml/plm/plpr.py
@@ -60,12 +60,24 @@ class DoubleMLPLPR(LinearScoreMixin, DoubleML):
Indicates whether the sample splitting should be drawn during initialization of the object.
Default is ``True``.
- TODO: include example and notes
+ TODO: include notes
Examples
--------
>>> import numpy as np
>>> import doubleml as dml
-
+ >>> from doubleml.plm.datasets import make_static_panel_CP2025
+ >>> from sklearn.linear_model import LassoCV
+ >>> from sklearn.base import clone
+ >>> np.random.seed(3142)
+ >>> learner = LassoCV()
+ >>> ml_l = clone(learner)
+ >>> ml_m = clone(learner)
+ >>> data = make_static_panel_CP2025(num_n=250, num_t=10, dim_x=30, theta=0.5, dgp_type='dgp1')
+ >>> obj_dml_data = DoubleMLPanelData(data, 'y', 'd', 'time', 'id', static_panel=True)
+ >>> dml_plpr_obj = DoubleMLPLPR(obj_dml_data, ml_l, ml_m)
+ >>> dml_plpr_obj.fit().summary
+ coef std err t P>|t| 2.5 % 97.5 %
+ d_diff 0.511626 0.024615 20.784933 5.924636e-96 0.463381 0.559871
Notes
-----
@@ -110,9 +122,9 @@ def __init__(
valid_scores = ["IV-type", "partialling out"]
_check_score(self.score, valid_scores, allow_callable=True)
- # TODO: update learner checks
_ = self._check_learner(ml_l, "ml_l", regressor=True, classifier=False)
ml_m_is_classifier = self._check_learner(ml_m, "ml_m", regressor=True, classifier=True)
+ # TODO: maybe warning for binary treatment with approaches 'fd_exact' and 'wg_approx'
self._learner = {"ml_l": ml_l, "ml_m": ml_m}
if ml_g is not None:
@@ -138,10 +150,16 @@ def __init__(
if self._dml_data.binary_treats.all():
self._predict_method["ml_m"] = "predict_proba"
else:
- raise ValueError(
+ msg = (
f"The ml_m learner {str(ml_m)} was identified as classifier "
"but at least one treatment variable is not binary with values 0 and 1."
)
+ if self._approach in ["fd_exact", "wg_approx"]:
+ msg += (
+ " Note: In case of binary input treatment variable(s), approaches 'fd_exact' and "
+ "'wg_approx' tansform the treatment variable(s), such that they are no longer binary."
+ )
+ raise ValueError(msg)
else:
self._predict_method["ml_m"] = "predict"
@@ -158,8 +176,10 @@ def _format_additional_info_str(self):
"""
Includes information on the original data before transformation.
"""
+ # TODO: adjust header length of additional info in double_ml.py
data_original_summary = (
- f"Original Data Summary Pre-transformation:\n\n"
+ f"Cluster variable(s): {self._original_dml_data.cluster_cols}\n"
+ f"\nPre-Transformation Data Summary: \n"
f"Outcome variable: {self._original_dml_data.y_col}\n"
f"Treatment variable(s): {self._original_dml_data.d_cols}\n"
f"Covariates: {self._original_dml_data.x_cols}\n"
@@ -262,7 +282,7 @@ def _transform_data(self):
return data, cols
def _set_d_mean(self):
- if self._approach == "cre_normal":
+ if self._approach in ["cre_general", "cre_normal"]:
data = self._original_dml_data.data
d_cols = self._original_dml_data.d_cols
id_col = self._original_dml_data.id_col
@@ -322,11 +342,11 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
)
# general cre adjustment
- # TODO: update this section
if self._approach == "cre_general":
- help_data = pd.DataFrame({"id": self._dml_data.id_var, "m_hat": m_hat["preds"], "d": d})
- group_means = help_data.groupby(["id"])[["m_hat", "d"]].transform("mean")
- m_hat_star = m_hat["preds"] + group_means["d"] - group_means["m_hat"]
+ d_mean = self._d_mean[:, self._i_treat]
+ df_m_hat = pd.DataFrame({"id": self._dml_data.id_var, "m_hat": m_hat["preds"]})
+ m_hat_mean = df_m_hat.groupby(["id"]).transform("mean")
+ m_hat_star = m_hat["preds"] + d_mean - m_hat_mean["m_hat"]
m_hat["preds"] = m_hat_star
_check_finite_predictions(m_hat["preds"], self._learner["ml_m"], "ml_m", smpls)
diff --git a/doubleml/plm/sim/example_sim.ipynb b/doubleml/plm/sim/example_sim.ipynb
index af90919fa..aabccc688 100644
--- a/doubleml/plm/sim/example_sim.ipynb
+++ b/doubleml/plm/sim/example_sim.ipynb
@@ -80,7 +80,7 @@
"from doubleml.data.base_data import DoubleMLData \n",
"from doubleml.data.panel_data import DoubleMLPanelData\n",
"from doubleml.plm.plpr import DoubleMLPLPR\n",
- "from sklearn.linear_model import LassoCV, LinearRegression\n",
+ "from sklearn.linear_model import LassoCV, LinearRegression, LogisticRegressionCV\n",
"from sklearn.base import clone\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
@@ -218,7 +218,7 @@
},
{
"cell_type": "code",
- "execution_count": 37,
+ "execution_count": 7,
"id": "44387af3",
"metadata": {},
"outputs": [
@@ -226,8 +226,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.510898 0.020339 25.119106 3.075876e-139 0.471034 0.550762\n"
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d 0.506545 0.020722 24.444675 5.733308e-132 0.465931 0.54716\n"
]
}
],
@@ -254,6 +254,27 @@
"print(dml_plpr_obj.summary)"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": 49,
+ "id": "7f3b2faa",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "True"
+ ]
+ },
+ "execution_count": 49,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dml_plpr_obj._is_cluster_data"
+ ]
+ },
{
"cell_type": "code",
"execution_count": 5,
@@ -271,7 +292,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 3,
"id": "48d4dbd8",
"metadata": {},
"outputs": [
@@ -282,28 +303,28 @@
"================== DoubleMLPLPR Object ==================\n",
"\n",
"------------------ Data Summary ------------------\n",
- "Outcome variable: y\n",
- "Treatment variable(s): ['d']\n",
- "Covariates: ['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15', 'x16', 'x17', 'x18', 'x19', 'x20', 'x21', 'x22', 'x23', 'x24', 'x25', 'x26', 'x27', 'x28', 'x29', 'x30', 'x1_mean', 'x2_mean', 'x3_mean', 'x4_mean', 'x5_mean', 'x6_mean', 'x7_mean', 'x8_mean', 'x9_mean', 'x10_mean', 'x11_mean', 'x12_mean', 'x13_mean', 'x14_mean', 'x15_mean', 'x16_mean', 'x17_mean', 'x18_mean', 'x19_mean', 'x20_mean', 'x21_mean', 'x22_mean', 'x23_mean', 'x24_mean', 'x25_mean', 'x26_mean', 'x27_mean', 'x28_mean', 'x29_mean', 'x30_mean']\n",
+ "Outcome variable: y_diff\n",
+ "Treatment variable(s): ['d_diff']\n",
+ "Covariates: ['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15', 'x16', 'x17', 'x18', 'x19', 'x20', 'x21', 'x22', 'x23', 'x24', 'x25', 'x26', 'x27', 'x28', 'x29', 'x30', 'x1_lag', 'x2_lag', 'x3_lag', 'x4_lag', 'x5_lag', 'x6_lag', 'x7_lag', 'x8_lag', 'x9_lag', 'x10_lag', 'x11_lag', 'x12_lag', 'x13_lag', 'x14_lag', 'x15_lag', 'x16_lag', 'x17_lag', 'x18_lag', 'x19_lag', 'x20_lag', 'x21_lag', 'x22_lag', 'x23_lag', 'x24_lag', 'x25_lag', 'x26_lag', 'x27_lag', 'x28_lag', 'x29_lag', 'x30_lag']\n",
"Instrument variable(s): None\n",
"Time variable: time\n",
"Id variable: id\n",
"Static panel data: True\n",
"No. Unique Ids: 250\n",
- "No. Observations: 2500\n",
+ "No. Observations: 2250\n",
"\n",
"\n",
"------------------ Score & Algorithm ------------------\n",
"Score function: partialling out\n",
- "Static panel model approach: cre_general\n",
+ "Static panel model approach: fd_exact\n",
"\n",
"------------------ Machine Learner ------------------\n",
"Learner ml_l: LassoCV()\n",
"Learner ml_m: LassoCV()\n",
"Out-of-sample Performance:\n",
"Regression:\n",
- "Learner ml_l RMSE: [[1.63784321]]\n",
- "Learner ml_m RMSE: [[0.96294553]]\n",
+ "Learner ml_l RMSE: [[1.57005957]]\n",
+ "Learner ml_m RMSE: [[0.46246153]]\n",
"\n",
"------------------ Resampling ------------------\n",
"No. folds per cluster: 5\n",
@@ -311,12 +332,13 @@
"No. repeated sample splits: 1\n",
"\n",
"------------------ Fit Summary ------------------\n",
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.519934 0.020986 24.774727 1.678998e-135 0.478802 0.561067\n",
+ " coef std err t P>|t| 2.5 % 97.5 %\n",
+ "d_diff 0.62711 0.078221 8.017152 1.082254e-15 0.4738 0.780421\n",
"\n",
"------------------ Additional Information ------------------\n",
- "Original Data Summary Pre-transformation:\n",
+ "Cluster variable(s): ['id']\n",
"\n",
+ "Pre-Transformation Data Summary: \n",
"Outcome variable: y\n",
"Treatment variable(s): ['d']\n",
"Covariates: ['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15', 'x16', 'x17', 'x18', 'x19', 'x20', 'x21', 'x22', 'x23', 'x24', 'x25', 'x26', 'x27', 'x28', 'x29', 'x30']\n",
@@ -606,7 +628,7 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": null,
"id": "32e5bc9f",
"metadata": {},
"outputs": [
@@ -631,8 +653,6 @@
" x_cols=[col for col in data.columns if \"x\" in col],\n",
" static_panel=True)\n",
"\n",
- "# learner = LassoCV()\n",
- "\n",
"# preprocessor = ColumnTransformer([\n",
"# ('poly', make_pipeline(\n",
"# PolynomialFeatures(degree=2, include_bias=False, interaction_only=False),\n",
@@ -641,12 +661,6 @@
"# ('pass', 'passthrough', ['cat']) # Columns to keep unchanged\n",
"# ])\n",
"\n",
- "# learner = make_pipeline(\n",
- "# PolynomialFeatures(degree=2, include_bias=False, interaction_only=False),\n",
- "# StandardScaler(),\n",
- "# LassoCV()\n",
- "# )\n",
- "\n",
"preprocessor = ColumnTransformer([\n",
" ('poly', make_pipeline(\n",
" PolynomialFeatures(degree=2, include_bias=False)\n",
@@ -664,9 +678,7 @@
"\n",
"dml_plpr_obj = DoubleMLPLPR(panel_data_obj, ml_l, ml_m, approach='wg_approx')\n",
"dml_plpr_obj.fit(store_models=True)\n",
- "print(dml_plpr_obj.summary)\n",
- "\n",
- "# dml_plpr_obj.transform_cols['x_cols']"
+ "print(dml_plpr_obj.summary)"
]
},
{
@@ -697,7 +709,7 @@
},
{
"cell_type": "code",
- "execution_count": 32,
+ "execution_count": null,
"id": "1a0e629d",
"metadata": {},
"outputs": [
@@ -718,7 +730,7 @@
},
{
"cell_type": "code",
- "execution_count": 38,
+ "execution_count": null,
"id": "68f20f79",
"metadata": {},
"outputs": [
From 677c5a6619cf9a978ae94c43791fb73b89e4229d Mon Sep 17 00:00:00 2001
From: Julian Diefenbacher
<93731825+JulianDiefenbacher@users.noreply.github.com>
Date: Mon, 24 Nov 2025 18:13:59 +0100
Subject: [PATCH 25/33] allow binary treatment for PLPR
---
doubleml/data/panel_data.py | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/doubleml/data/panel_data.py b/doubleml/data/panel_data.py
index 1655b410b..9a113ebc5 100644
--- a/doubleml/data/panel_data.py
+++ b/doubleml/data/panel_data.py
@@ -107,8 +107,10 @@ def __init__(
if not self.static_panel:
cluster_cols = None
+ force_all_d_finite = False
else:
cluster_cols = id_col
+ force_all_d_finite = True
DoubleMLData.__init__(
self,
@@ -120,7 +122,7 @@ def __init__(
cluster_cols=cluster_cols,
use_other_treat_as_covariate=use_other_treat_as_covariate,
force_all_x_finite=force_all_x_finite,
- force_all_d_finite=False,
+ force_all_d_finite=force_all_d_finite,
)
# reset index to ensure a simple RangeIndex
From 7c7bd430c6775b2d6ca42b8da0aa66a2e5e8ed4f Mon Sep 17 00:00:00 2001
From: Julian Diefenbacher
<93731825+JulianDiefenbacher@users.noreply.github.com>
Date: Thu, 27 Nov 2025 15:47:11 +0100
Subject: [PATCH 26/33] update plpr dgp function
---
doubleml/plm/datasets/dgp_plpr_CP2025.py | 134 ++++++++++++++++++
.../plm/datasets/dgp_static_panel_CP2025.py | 85 -----------
2 files changed, 134 insertions(+), 85 deletions(-)
create mode 100644 doubleml/plm/datasets/dgp_plpr_CP2025.py
delete mode 100644 doubleml/plm/datasets/dgp_static_panel_CP2025.py
diff --git a/doubleml/plm/datasets/dgp_plpr_CP2025.py b/doubleml/plm/datasets/dgp_plpr_CP2025.py
new file mode 100644
index 000000000..f6c16f30e
--- /dev/null
+++ b/doubleml/plm/datasets/dgp_plpr_CP2025.py
@@ -0,0 +1,134 @@
+import numpy as np
+import pandas as pd
+
+
+def make_plpr_CP2025(num_id=250, num_t=10, dim_x=30, theta=0.5, dgp_type="dgp1"):
+ """
+ Generates synthetic data for a partially linear panel regression model, based on Clarke and Polselli (2025).
+ The data generating process is defined as
+
+ .. math::
+
+ Y_{it} &= D_{it} \\theta + l_0(X_{it}) + \\alpha_i + U_{it} m_0(x_i), & &U_{it} \\sim \\mathcal{N}(0,1),
+
+ D_{it} &= m_0(X_{it}) + c_i + V_{it}, & &V_{it} \\sim \\mathcal{N}(0,1),
+
+ \\alpha_i &= 0.25 \\left(\\frac{1}{T} \\sum_{t=1}^{T} D_{it} - \\bar{D} \\right)
+ + 0.25 \\frac{1}{T} \\sum_{t=1}^{T} \\sum_{k \\in \\mathcal{K}} X_{it,k} + a_i
+
+
+ with :math:`a_i \\sim \\mathcal{N}(0,0.95)`, :math:`X_{it,p} \\sim \\mathcal{N}(0,5)`, :math:`c_i \\sim \\mathcal{N}(0,1)`.
+ Here :math:`\\mathcal{K} = \\{1,3\\}` is the index set of the relevant (non-zero) confounding variables, and :math:`p` is
+ the total number of confounding variables.
+
+ Clarke and Polselli (2025) consider three functional forms of the confounders to model the nuisance functions :math:`l_0`
+ and :math:`m_0` with varying levels of non-linearity and non-smoothness:
+
+ Design 1. (dgp1): Linear in the nuisance parameters
+
+ .. math::
+
+ l_0(X_{it}) &= a X_{it,1} + X_{it,3}
+
+ m_0(X_{it}) &= a X_{it,1} + X_{it,3}
+
+ Design 2. (dgp2): Non-linear and smooth in the nuisance parameters
+
+ .. math::
+
+ l_0(X_{it}) &= \\frac{\\exp(X_{it,1})}{1 + \\exp(X_{it,1})} + a \\cos(X_{it,3})
+
+ m_0(X_{it}) &= \\cos(X_{it,1}) + a \\frac{\\exp(X_{it,3})}{1 + \\exp(X_{it,3})}
+
+ Design 3. (dgp3): Non-linear and discontinuous in the nuisance parameters
+
+ .. math::
+
+ l_0(X_{it}) &= b (X_{it,1} \\cdot X_{it,3}) + a (X_{it,3} \\cdot 1\\{X_{it,3} > 0\\})
+
+ m_0(X_{it}) &= a (X_{it,1} \\cdot 1\\{X_{it,1} > 0\\}) + b (X_{it,1} \\cdot X_{it,3}),
+
+ where :math:`a = 0.25`, :math:`b = 0.5`.
+
+ Parameters
+ ----------
+ num_id :
+ The number of units in the panel.
+ num_t :
+ The number of time periods in the panel.
+ dim_x :
+ The number of confounding variables.
+ theta :
+ The value of the causal parameter.
+ dgp_type :
+ The type of DGP design to be used. Default is ``'dgp1'``, other options are ``'dgp2'`` and ``'dgp3'``.
+
+ Returns
+ -------
+ pandas.DataFrame
+ DataFrame containing the simulated static panel data.
+
+ References
+ ----------
+ Clarke, P. S. and Polselli, A. (2025),
+ Double machine learning for static panel models with fixed effects. The Econometrics Journal, utaf011,
+ doi:`10.1093/ectj/utaf011 <https://doi.org/10.1093/ectj/utaf011>`_.
+ """
+
+ # parameters
+ a = 0.25
+ b = 0.5
+ sigma2_a = 0.95
+ sigma2_x = 5
+
+ # id and time vectors
+ id = np.repeat(np.arange(1, num_id + 1), num_t)
+ time = np.tile(np.arange(1, num_t + 1), num_id)
+
+ # individual fixed effects
+ a_i = np.repeat(np.random.normal(0, np.sqrt(sigma2_a), num_id), num_t)
+ c_i = np.repeat(np.random.standard_normal(num_id), num_t)
+
+ # covariates and errors
+ x_mean = 0
+ x_it = np.random.normal(loc=x_mean, scale=np.sqrt(sigma2_x), size=(num_id * num_t, dim_x))
+ u_it = np.random.standard_normal(num_id * num_t)
+ v_it = np.random.standard_normal(num_id * num_t)
+
+ # functional forms in nuisance functions
+ if dgp_type == "dgp1":
+ l_0 = a * x_it[:, 0] + x_it[:, 2]
+ m_0 = a * x_it[:, 0] + x_it[:, 2]
+ elif dgp_type == "dgp2":
+ l_0 = np.divide(np.exp(x_it[:, 0]), 1 + np.exp(x_it[:, 0])) + a * np.cos(x_it[:, 2])
+ m_0 = np.cos(x_it[:, 0]) + a * np.divide(np.exp(x_it[:, 2]), 1 + np.exp(x_it[:, 2]))
+ elif dgp_type == "dgp3":
+ l_0 = b * (x_it[:, 0] * x_it[:, 2]) + a * (x_it[:, 2] * np.where(x_it[:, 2] > 0, 1, 0))
+ m_0 = a * (x_it[:, 0] * np.where(x_it[:, 0] > 0, 1, 0)) + b * (x_it[:, 0] * x_it[:, 2])
+ else:
+ raise ValueError("Invalid dgp")
+
+ # treatment
+ d_it = m_0 + c_i + v_it
+
+ def alpha_i(x_it, d_it, a_i, num_n, num_t):
+ d_i = np.array_split(d_it, num_n)
+ d_i_term = np.repeat(np.mean(d_i, axis=1), num_t) - np.mean(d_it)
+
+ x_i = np.array_split(np.sum(x_it[:, [0, 2]], axis=1), num_n)
+ x_i_mean = np.mean(x_i, axis=1)
+ x_i_term = np.repeat(x_i_mean, num_t)
+
+ alpha_term = 0.25 * d_i_term + 0.25 * x_i_term + a_i
+ return alpha_term
+
+ # outcome
+ y_it = d_it * theta + l_0 + alpha_i(x_it, d_it, a_i, num_id, num_t) + u_it
+
+ x_cols = [f"x{i + 1}" for i in np.arange(dim_x)]
+
+ data = pd.DataFrame(np.column_stack((id, time, y_it, d_it, x_it)), columns=["id", "time", "y", "d"] + x_cols).astype(
+ {"id": "int64", "time": "int64"}
+ )
+
+ return data
diff --git a/doubleml/plm/datasets/dgp_static_panel_CP2025.py b/doubleml/plm/datasets/dgp_static_panel_CP2025.py
deleted file mode 100644
index ade148c41..000000000
--- a/doubleml/plm/datasets/dgp_static_panel_CP2025.py
+++ /dev/null
@@ -1,85 +0,0 @@
-import numpy as np
-import pandas as pd
-
-
-def make_static_panel_CP2025(num_n=250, num_t=10, dim_x=30, theta=0.5, dgp_type='dgp1', x_var=5, a_var=0.95):
- """
- Generates static panel data from the simulation dgp in Clarke and Polselli (2025).
-
- Parameters
- ----------
- num_n :
- The number of unit in the panel.
- num_t :
- The number of time periods in the panel.
- num_x :
- The number of of covariates.
- theta :
- The value of the causal parameter.
- dgp_type :
- The type of DGP design to be used. Default is ``'dgp1'``, other options are ``'dgp2'`` and ``'dgp3'``.
- x_var :
- The variance of the covariates.
- a_var :
- The variance of the individual fixed effect on outcome
-
- Returns
- -------
- pandas.DataFrame
- DataFrame containing the simulated static panel data.
- """
-
- # parameters
- a = 0.25
- b = 0.5
-
- # id and time vectors
- id = np.repeat(np.arange(1, num_n+1), num_t)
- time = np.tile(np.arange(1, num_t+1), num_n)
-
- # individual fixed effects
- a_i = np.repeat(np.random.normal(0, np.sqrt(a_var), num_n), num_t)
- c_i = np.repeat(np.random.standard_normal(num_n), num_t)
-
- # covariates and errors
- x_mean = 0
- x_it = np.random.normal(loc=x_mean, scale=np.sqrt(x_var), size=(num_n*num_t, dim_x))
- u_it = np.random.standard_normal(num_n*num_t)
- v_it = np.random.standard_normal(num_n*num_t)
-
- # functional forms in nuisance functions
- if dgp_type == 'dgp1':
- l_0 = a * x_it[:,0] + x_it[:,2]
- m_0 = a * x_it[:,0] + x_it[:,2]
- elif dgp_type == 'dgp2':
- l_0 = np.divide(np.exp(x_it[:,0]), 1 + np.exp(x_it[:,0])) + a * np.cos(x_it[:,2])
- m_0 = np.cos(x_it[:,0]) + a * np.divide(np.exp(x_it[:,2]), 1 + np.exp(x_it[:,2]))
- elif dgp_type == 'dgp3':
- l_0 = b * (x_it[:,0] * x_it[:,2]) + a * (x_it[:,2] * np.where(x_it[:,2] > 0, 1, 0))
- m_0 = a * (x_it[:,0] * np.where(x_it[:,0] > 0, 1, 0)) + b * (x_it[:,0] * x_it[:,2])
- else:
- raise ValueError('Invalid dgp')
-
- # treatment
- d_it = m_0 + c_i + v_it
-
- def alpha_i(x_it, d_it, a_i, num_n, num_t):
- d_i = np.array_split(d_it, num_n)
- d_i_term = np.repeat(np.mean(d_i, axis=1), num_t) - np.mean(d_it)
-
- x_i = np.array_split(np.sum(x_it[:, [0, 2]], axis=1), num_n)
- x_i_mean = np.mean(x_i, axis=1)
- x_i_term = np.repeat(x_i_mean, num_t)
-
- alpha_term = 0.25 * d_i_term + 0.25 * x_i_term + a_i
- return alpha_term
-
- # outcome
- y_it = d_it * theta + l_0 + alpha_i(x_it, d_it, a_i, num_n, num_t) + u_it
-
- x_cols = [f'x{i + 1}' for i in np.arange(dim_x)]
-
- data = pd.DataFrame(np.column_stack((id, time, y_it, d_it, x_it)),
- columns=['id', 'time', 'y', 'd'] + x_cols).astype({'id': 'int64', 'time': 'int64'})
-
- return data
\ No newline at end of file
From ecce2b5310646ad630f7845d0c293570cc58301a Mon Sep 17 00:00:00 2001
From: Julian Diefenbacher
<93731825+JulianDiefenbacher@users.noreply.github.com>
Date: Thu, 27 Nov 2025 17:00:45 +0100
Subject: [PATCH 27/33] update make_plpr use
---
doubleml/plm/datasets/__init__.py | 4 +-
doubleml/plm/datasets/dgp_plpr_CP2025.py | 2 +-
doubleml/plm/plpr.py | 17 +-
doubleml/plm/sim/example_sim.ipynb | 42 +--
doubleml/plm/sim/learners_sim.ipynb | 448 +++++++++++------------
5 files changed, 261 insertions(+), 252 deletions(-)
diff --git a/doubleml/plm/datasets/__init__.py b/doubleml/plm/datasets/__init__.py
index 17e97e3bd..2d02510c0 100644
--- a/doubleml/plm/datasets/__init__.py
+++ b/doubleml/plm/datasets/__init__.py
@@ -6,9 +6,9 @@
from .dgp_confounded_plr_data import make_confounded_plr_data
from .dgp_pliv_CHS2015 import make_pliv_CHS2015
from .dgp_pliv_multiway_cluster_CKMS2021 import make_pliv_multiway_cluster_CKMS2021
+from .dgp_plpr_CP2025 import make_plpr_CP2025
from .dgp_plr_CCDDHNR2018 import make_plr_CCDDHNR2018
from .dgp_plr_turrell2018 import make_plr_turrell2018
-from .dgp_static_panel_CP2025 import make_static_panel_CP2025
__all__ = [
"make_plr_CCDDHNR2018",
@@ -17,5 +17,5 @@
"make_pliv_CHS2015",
"make_pliv_multiway_cluster_CKMS2021",
"_make_pliv_data",
- "make_static_panel_CP2025",
+ "make_plpr_CP2025",
]
diff --git a/doubleml/plm/datasets/dgp_plpr_CP2025.py b/doubleml/plm/datasets/dgp_plpr_CP2025.py
index f6c16f30e..438745c0a 100644
--- a/doubleml/plm/datasets/dgp_plpr_CP2025.py
+++ b/doubleml/plm/datasets/dgp_plpr_CP2025.py
@@ -9,7 +9,7 @@ def make_plpr_CP2025(num_id=250, num_t=10, dim_x=30, theta=0.5, dgp_type="dgp1")
.. math::
- Y_{it} &= D_{it} \\theta + l_0(X_{it}) + \\alpha_i + U_{it} m_0(x_i), & &U_{it} \\sim \\mathcal{N}(0,1),
+ Y_{it} &= D_{it} \\theta + l_0(X_{it}) + \\alpha_i + U_{it}, & &U_{it} \\sim \\mathcal{N}(0,1),
D_{it} &= m_0(X_{it}) + c_i + V_{it}, & &V_{it} \\sim \\mathcal{N}(0,1),
diff --git a/doubleml/plm/plpr.py b/doubleml/plm/plpr.py
index b98e14aed..ab35436f5 100644
--- a/doubleml/plm/plpr.py
+++ b/doubleml/plm/plpr.py
@@ -60,29 +60,38 @@ class DoubleMLPLPR(LinearScoreMixin, DoubleML):
Indicates whether the sample splitting should be drawn during initialization of the object.
Default is ``True``.
- TODO: include notes
Examples
--------
>>> import numpy as np
>>> import doubleml as dml
- >>> from doubleml.plm.datasets import make_static_panel_CP2025
+ >>> from doubleml.plm.datasets import make_plpr_CP2025
>>> from sklearn.linear_model import LassoCV
>>> from sklearn.base import clone
>>> np.random.seed(3142)
>>> learner = LassoCV()
>>> ml_l = clone(learner)
>>> ml_m = clone(learner)
- >>> data = make_static_panel_CP2025(num_n=250, num_t=10, dim_x=30, theta=0.5, dgp_type='dgp1')
+ >>> data = make_plpr_CP2025(num_id=250, num_t=10, dim_x=30, theta=0.5, dgp_type='dgp1')
>>> obj_dml_data = DoubleMLPanelData(data, 'y', 'd', 'time', 'id', static_panel=True)
>>> dml_plpr_obj = DoubleMLPLPR(obj_dml_data, ml_l, ml_m)
>>> dml_plpr_obj.fit().summary
- coef std err t P>|t| 2.5 % 97.5 %
+ coef std err t P>|t| 2.5 % 97.5 %
d_diff 0.511626 0.024615 20.784933 5.924636e-96 0.463381 0.559871
Notes
-----
**Partially linear panel regression (PLPR)** models take the form
+ .. math::
+
+ Y_{it} &= D_{it} \\theta_0 + l_0(X_{it}) + \\alpha_i + U_{it}, & &\\mathbb{E}(U_{it} | D_{it},X_{it},\\alpha_i) = 0,
+
+ D_{it} &= m_0(X_{it}) + \\gamma_i + V_{it}, & &\\mathbb{E}(V_{it} | X_{it},\\gamma_i) = 0,
+
+ where :math:`Y_{it}` is the outcome variable and :math:`D_{it}` is the policy variable of interest.
+ The high-dimensional vector :math:`X_{it} = (X_{it,1}, \\ldots, X_{it,p})` consists of other confounding covariates,
+ :math:`\\alpha_i` and :math:`\\gamma_i` are the unobserved individual heterogeneity correlated with the included covariates,
+ and :math:`U_{it}` and :math:`V_{it}` are stochastic errors.
"""
def __init__(
diff --git a/doubleml/plm/sim/example_sim.ipynb b/doubleml/plm/sim/example_sim.ipynb
index aabccc688..a5f8db936 100644
--- a/doubleml/plm/sim/example_sim.ipynb
+++ b/doubleml/plm/sim/example_sim.ipynb
@@ -87,7 +87,7 @@
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.pipeline import make_pipeline\n",
"from doubleml.plm.utils._plpr_util import cre_fct, fd_fct, wd_fct, extend_data\n",
- "from doubleml.plm.datasets.dgp_static_panel_CP2025 import make_static_panel_CP2025\n",
+ "from doubleml.plm.datasets.dgp_plpr_CP2025 import make_plpr_CP2025\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")"
]
@@ -135,7 +135,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": null,
"id": "7c5cd6bc",
"metadata": {},
"outputs": [
@@ -171,7 +171,7 @@
")\n",
"\n",
"np.random.seed(1)\n",
- "data = make_static_panel_CP2025(dgp_type='dgp1')\n",
+ "data = make_plpr_CP2025(dgp_type='dgp1')\n",
"\n",
"y = np.array(data['y'])\n",
"X = np.array(data[['x1', 'x2', 'x3', 'x4', 'x5']])\n",
@@ -185,7 +185,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
"id": "3c061cf1",
"metadata": {},
"outputs": [
@@ -204,7 +204,7 @@
],
"source": [
"np.random.seed(1)\n",
- "data = make_static_panel_CP2025(dgp_type='dgp1')\n",
+ "data = make_plpr_CP2025(dgp_type='dgp1')\n",
"\n",
"x_cols = [col for col in data.columns if \"x\" in col]\n",
"\n",
@@ -218,7 +218,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"id": "44387af3",
"metadata": {},
"outputs": [
@@ -233,7 +233,7 @@
],
"source": [
"# cre general\n",
- "data = make_static_panel_CP2025(dgp_type='dgp1')\n",
+ "data = make_plpr_CP2025(dgp_type='dgp1')\n",
"cre_data = cre_fct(data)\n",
"\n",
"learner = LassoCV()\n",
@@ -353,7 +353,7 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": null,
"id": "8731057f",
"metadata": {},
"outputs": [
@@ -368,7 +368,7 @@
],
"source": [
"# cre general, extend features\n",
- "data = make_static_panel_CP2025(num_n=100, dgp_type='dgp3')\n",
+ "data = make_plpr_CP2025(num_id=100, dgp_type='dgp3')\n",
"\n",
"panel_data_obj = DoubleMLPanelData(data,\n",
" y_col='y',\n",
@@ -467,7 +467,7 @@
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": null,
"id": "9edf129b",
"metadata": {},
"outputs": [
@@ -496,7 +496,7 @@
"ml_l = clone(ml_boost)\n",
"ml_m = clone(ml_boost)\n",
"\n",
- "data = make_static_panel_CP2025(num_n=100, dgp_type='dgp3')\n",
+ "data = make_plpr_CP2025(num_id=100, dgp_type='dgp3')\n",
"\n",
"panel_data_obj = DoubleMLPanelData(data,\n",
" y_col='y',\n",
@@ -514,7 +514,7 @@
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": null,
"id": "135a38ab",
"metadata": {},
"outputs": [
@@ -529,7 +529,7 @@
],
"source": [
"# cre normality assumption\n",
- "data = make_static_panel_CP2025(dgp_type='dgp1')\n",
+ "data = make_plpr_CP2025(dgp_type='dgp1')\n",
"cre_data = cre_fct(data)\n",
"\n",
"panel_data_obj = DoubleMLPanelData(data,\n",
@@ -553,7 +553,7 @@
},
{
"cell_type": "code",
- "execution_count": 26,
+ "execution_count": null,
"id": "d869e493",
"metadata": {},
"outputs": [
@@ -568,7 +568,7 @@
],
"source": [
"# First difference approach\n",
- "data = make_static_panel_CP2025(dgp_type='dgp1')\n",
+ "data = make_plpr_CP2025(dgp_type='dgp1')\n",
"fd_data = fd_fct(data)\n",
"\n",
"panel_data_obj = DoubleMLPanelData(data,\n",
@@ -591,7 +591,7 @@
},
{
"cell_type": "code",
- "execution_count": 30,
+ "execution_count": null,
"id": "579b43fa",
"metadata": {},
"outputs": [
@@ -606,7 +606,7 @@
],
"source": [
"# Within group approach\n",
- "data = make_static_panel_CP2025(dgp_type='dgp1')\n",
+ "data = make_plpr_CP2025(dgp_type='dgp1')\n",
"wd_data = wd_fct(data)\n",
"\n",
"panel_data_obj = DoubleMLPanelData(data,\n",
@@ -643,7 +643,7 @@
],
"source": [
"# Within group approach, polynomials\n",
- "data = make_static_panel_CP2025(dgp_type='dgp1')\n",
+ "data = make_plpr_CP2025(dgp_type='dgp1')\n",
"\n",
"panel_data_obj = DoubleMLPanelData(data,\n",
" y_col='y',\n",
@@ -847,7 +847,7 @@
"\n",
"for i in range(n_reps):\n",
" print(f\"\\rProcessing: {round((i+1)/n_reps*100, 3)} %\", end=\"\")\n",
- " data = make_static_panel_CP2025(num_n=100, theta=theta, dgp_type='dgp1')\n",
+ " data = make_plpr_CP2025(num_id=100, theta=theta, dgp_type='dgp1')\n",
"\n",
" dml_data = DoubleMLPanelData(data, y_col='y', d_cols='d', t_col='time', id_col='id', \n",
" x_cols=[col for col in data.columns if \"x\" in col],\n",
@@ -899,7 +899,7 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": null,
"id": "24ef531c",
"metadata": {},
"outputs": [
@@ -993,7 +993,7 @@
"\n",
"for i in range(n_reps):\n",
" print(f\"\\rProcessing: {round((i+1)/n_reps*100, 3)} %\", end=\"\")\n",
- " data = make_static_panel_CP2025(num_n=100, theta=theta, dgp_type='dgp1')\n",
+ " data = make_plpr_CP2025(num_id=100, theta=theta, dgp_type='dgp1')\n",
"\n",
" dml_data = DoubleMLPanelData(data, y_col='y', d_cols='d', t_col='time', id_col='id', \n",
" x_cols=[col for col in data.columns if \"x\" in col],\n",
diff --git a/doubleml/plm/sim/learners_sim.ipynb b/doubleml/plm/sim/learners_sim.ipynb
index d6662488b..e463dd817 100644
--- a/doubleml/plm/sim/learners_sim.ipynb
+++ b/doubleml/plm/sim/learners_sim.ipynb
@@ -16,7 +16,7 @@
"from sklearn.tree import DecisionTreeRegressor\n",
"from lightgbm import LGBMRegressor\n",
"# from doubleml.plm.utils._plpr_util import extend_data, cre_fct, fd_fct, wd_fct\n",
- "from doubleml.plm.datasets.dgp_static_panel_CP2025 import make_static_panel_CP2025\n",
+ "from doubleml.plm.datasets.dgp_plpr_CP2025 import make_plpr_CP2025\n",
"from sklearn.preprocessing import StandardScaler, PolynomialFeatures\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.base import BaseEstimator, TransformerMixin\n",
@@ -27,7 +27,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 2,
"id": "2650a1dd",
"metadata": {},
"outputs": [],
@@ -66,7 +66,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 3,
"id": "68cab57c",
"metadata": {},
"outputs": [],
@@ -112,7 +112,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 4,
"id": "dca81b0b",
"metadata": {},
"outputs": [
@@ -165,121 +165,121 @@
" 0 | \n",
" 1 | \n",
" 1 | \n",
- " -1.767568 | \n",
- " -1.348237 | \n",
- " 0.492677 | \n",
- " 0.281933 | \n",
- " -2.142618 | \n",
- " 0.791443 | \n",
- " 2.976620 | \n",
- " -1.001715 | \n",
+ " -5.891762 | \n",
+ " -4.574437 | \n",
+ " -1.203377 | \n",
+ " -0.087056 | \n",
+ " -3.875989 | \n",
+ " -2.060015 | \n",
+ " 1.924513 | \n",
+ " 1.564501 | \n",
" ... | \n",
- " -1.198500 | \n",
- " -0.049211 | \n",
- " 0.600889 | \n",
- " 2.435667 | \n",
- " -1.387149 | \n",
- " 3.034459 | \n",
- " -0.062419 | \n",
- " 0.258166 | \n",
- " -1.168477 | \n",
- " -1.061057 | \n",
+ " -1.697513 | \n",
+ " -2.509320 | \n",
+ " -0.727138 | \n",
+ " -2.393134 | \n",
+ " 0.334781 | \n",
+ " 2.097534 | \n",
+ " 1.942009 | \n",
+ " 1.649557 | \n",
+ " -0.612257 | \n",
+ " -4.331109 | \n",
" \n",
" \n",
" | 1 | \n",
" 1 | \n",
" 2 | \n",
- " -5.095199 | \n",
- " -3.566642 | \n",
- " -1.608388 | \n",
- " -0.819905 | \n",
- " -3.570497 | \n",
- " 1.583374 | \n",
- " 1.644214 | \n",
- " -4.221177 | \n",
+ " 0.601641 | \n",
+ " 1.217127 | \n",
+ " -1.076318 | \n",
+ " 2.226439 | \n",
+ " 0.379887 | \n",
+ " -2.491481 | \n",
+ " -1.446766 | \n",
+ " -0.000182 | \n",
" ... | \n",
- " 4.517936 | \n",
- " 0.413499 | \n",
- " 2.150563 | \n",
- " -2.971910 | \n",
- " 0.922270 | \n",
- " -2.628696 | \n",
- " -1.772420 | \n",
- " -3.851087 | \n",
- " 3.270008 | \n",
- " 0.820763 | \n",
+ " -0.185932 | \n",
+ " -0.491894 | \n",
+ " 1.320808 | \n",
+ " -2.888978 | \n",
+ " 0.296153 | \n",
+ " -0.209147 | \n",
+ " -1.066396 | \n",
+ " -2.232003 | \n",
+ " 3.217619 | \n",
+ " -0.660709 | \n",
"
\n",
" \n",
" | 2 | \n",
" 1 | \n",
" 3 | \n",
- " 6.437985 | \n",
- " 4.202518 | \n",
- " -1.050850 | \n",
- " -1.400580 | \n",
- " 3.183289 | \n",
- " 3.513685 | \n",
- " 1.861339 | \n",
- " 0.888485 | \n",
+ " -5.336432 | \n",
+ " -4.084917 | \n",
+ " 1.684777 | \n",
+ " -2.117207 | \n",
+ " -4.038299 | \n",
+ " 1.196702 | \n",
+ " 4.320428 | \n",
+ " 1.543303 | \n",
" ... | \n",
- " 1.163815 | \n",
- " -0.069711 | \n",
- " -0.202117 | \n",
- " -1.262765 | \n",
- " 1.133570 | \n",
- " 0.884130 | \n",
- " 0.484024 | \n",
- " 3.124910 | \n",
- " 0.004369 | \n",
- " 0.349072 | \n",
+ " -1.896694 | \n",
+ " 2.950372 | \n",
+ " 2.266257 | \n",
+ " -1.962670 | \n",
+ " 1.913956 | \n",
+ " -3.847482 | \n",
+ " 0.914604 | \n",
+ " -1.721561 | \n",
+ " -0.954810 | \n",
+ " 0.407410 | \n",
"
\n",
" \n",
" | 3 | \n",
" 1 | \n",
" 4 | \n",
- " 1.692969 | \n",
- " 0.776318 | \n",
- " -0.275264 | \n",
- " -0.787588 | \n",
- " -1.492324 | \n",
- " -3.920095 | \n",
- " -2.246768 | \n",
- " -1.655923 | \n",
+ " -0.478058 | \n",
+ " 0.043192 | \n",
+ " 0.978425 | \n",
+ " -2.568042 | \n",
+ " -1.001187 | \n",
+ " -2.151350 | \n",
+ " 0.973693 | \n",
+ " -1.286461 | \n",
" ... | \n",
- " -1.120645 | \n",
- " -1.726098 | \n",
- " -2.561617 | \n",
- " -2.247641 | \n",
- " 0.685799 | \n",
- " 3.943749 | \n",
- " 2.891479 | \n",
- " 5.381948 | \n",
- " 1.455669 | \n",
- " -2.480590 | \n",
+ " -1.148923 | \n",
+ " -3.388272 | \n",
+ " 1.121507 | \n",
+ " 4.753065 | \n",
+ " 1.424797 | \n",
+ " -2.345737 | \n",
+ " -0.693004 | \n",
+ " -1.618859 | \n",
+ " 1.668621 | \n",
+ " 4.571664 | \n",
"
\n",
" \n",
" | 4 | \n",
" 1 | \n",
" 5 | \n",
- " -0.760349 | \n",
- " -0.423762 | \n",
- " -5.334104 | \n",
- " -3.650439 | \n",
- " 0.447345 | \n",
- " -4.053885 | \n",
- " 1.367370 | \n",
- " 0.752763 | \n",
+ " -1.223138 | \n",
+ " -1.544670 | \n",
+ " 1.398526 | \n",
+ " 3.567711 | \n",
+ " -1.353898 | \n",
+ " -2.226735 | \n",
+ " 3.713345 | \n",
+ " 0.675101 | \n",
" ... | \n",
- " -1.666836 | \n",
- " -0.607192 | \n",
- " 3.658921 | \n",
- " -1.153617 | \n",
- " 1.338251 | \n",
- " 2.990290 | \n",
- " -0.717240 | \n",
- " 2.494413 | \n",
- " -0.576748 | \n",
- " -0.049214 | \n",
+ " -1.837662 | \n",
+ " 2.941762 | \n",
+ " 2.061895 | \n",
+ " 0.853080 | \n",
+ " -0.244278 | \n",
+ " 1.263040 | \n",
+ " -2.011630 | \n",
+ " -0.826488 | \n",
+ " -0.887181 | \n",
+ " -1.935609 | \n",
"
\n",
" \n",
" | ... | \n",
@@ -309,121 +309,121 @@
" 995 | \n",
" 100 | \n",
" 6 | \n",
- " 6.736806 | \n",
- " 3.652808 | \n",
- " 0.935941 | \n",
- " -1.046372 | \n",
- " 2.972970 | \n",
- " -3.188139 | \n",
- " -1.368655 | \n",
- " -1.973138 | \n",
+ " -2.734367 | \n",
+ " -0.265121 | \n",
+ " -3.045957 | \n",
+ " 1.798310 | \n",
+ " -1.485007 | \n",
+ " -2.107084 | \n",
+ " 2.983506 | \n",
+ " 1.469852 | \n",
" ... | \n",
- " -0.693739 | \n",
- " 1.142071 | \n",
- " -1.017755 | \n",
- " 4.427896 | \n",
- " 0.988486 | \n",
- " -0.082238 | \n",
- " 1.047705 | \n",
- " 1.222772 | \n",
- " 3.264437 | \n",
- " 1.173541 | \n",
+ " 4.836053 | \n",
+ " 1.932549 | \n",
+ " 2.307781 | \n",
+ " -2.536377 | \n",
+ " 1.150598 | \n",
+ " 1.052909 | \n",
+ " -0.969657 | \n",
+ " 1.266473 | \n",
+ " -3.177021 | \n",
+ " -3.070155 | \n",
"
\n",
" \n",
" | 996 | \n",
" 100 | \n",
" 7 | \n",
- " 6.023375 | \n",
- " 3.868859 | \n",
- " -0.667464 | \n",
- " 5.690558 | \n",
- " 4.230361 | \n",
- " 0.512239 | \n",
- " -0.016779 | \n",
- " -1.998309 | \n",
+ " 5.307034 | \n",
+ " 3.585023 | \n",
+ " -2.274700 | \n",
+ " -2.499113 | \n",
+ " 2.116755 | \n",
+ " 0.478902 | \n",
+ " -0.248561 | \n",
+ " -0.957826 | \n",
" ... | \n",
- " 0.429095 | \n",
- " -0.879970 | \n",
- " -0.333486 | \n",
- " 0.856337 | \n",
- " 3.159868 | \n",
- " -2.772002 | \n",
- " 1.782697 | \n",
- " 1.158639 | \n",
- " 1.373919 | \n",
- " -3.298531 | \n",
+ " -2.799372 | \n",
+ " 1.598447 | \n",
+ " 1.972620 | \n",
+ " -1.888645 | \n",
+ " 1.237270 | \n",
+ " 3.644984 | \n",
+ " 0.054862 | \n",
+ " 0.615274 | \n",
+ " -0.432120 | \n",
+ " -0.046949 | \n",
"
\n",
" \n",
" | 997 | \n",
" 100 | \n",
" 8 | \n",
- " 3.110866 | \n",
- " 0.955741 | \n",
- " 1.789798 | \n",
- " -1.256050 | \n",
- " -0.369970 | \n",
- " -1.965363 | \n",
- " 0.805850 | \n",
- " -0.837638 | \n",
+ " -4.476127 | \n",
+ " -2.751544 | \n",
+ " -0.716859 | \n",
+ " -1.263491 | \n",
+ " -2.826469 | \n",
+ " -0.954049 | \n",
+ " -1.237438 | \n",
+ " -1.200074 | \n",
" ... | \n",
- " 0.314910 | \n",
- " 2.215582 | \n",
- " 1.695434 | \n",
- " -0.092416 | \n",
- " -0.691905 | \n",
- " 2.864254 | \n",
- " 1.673074 | \n",
- " -0.839339 | \n",
- " 1.389325 | \n",
- " 1.714632 | \n",
+ " 1.527836 | \n",
+ " -1.918927 | \n",
+ " -0.381272 | \n",
+ " 2.065848 | \n",
+ " 0.723859 | \n",
+ " -0.711546 | \n",
+ " 0.930980 | \n",
+ " 0.883152 | \n",
+ " -0.324217 | \n",
+ " 0.053768 | \n",
"
\n",
" \n",
" | 998 | \n",
" 100 | \n",
" 9 | \n",
- " 4.572219 | \n",
- " 3.239447 | \n",
- " -3.129920 | \n",
- " -1.654972 | \n",
- " 3.222430 | \n",
- " -1.193018 | \n",
- " 0.287887 | \n",
- " -0.382851 | \n",
+ " 5.191475 | \n",
+ " 4.985730 | \n",
+ " -2.840246 | \n",
+ " 0.931855 | \n",
+ " 3.070040 | \n",
+ " 2.700103 | \n",
+ " 1.214848 | \n",
+ " 2.846577 | \n",
" ... | \n",
- " -0.481192 | \n",
- " 1.541299 | \n",
- " 1.153674 | \n",
- " -4.382081 | \n",
- " 4.017794 | \n",
- " 1.117018 | \n",
- " -1.648193 | \n",
- " -1.139779 | \n",
- " 2.748231 | \n",
- " 3.032575 | \n",
+ " -1.662106 | \n",
+ " 0.583185 | \n",
+ " 2.117253 | \n",
+ " -0.429837 | \n",
+ " -1.983224 | \n",
+ " -1.249148 | \n",
+ " 5.170035 | \n",
+ " 3.022710 | \n",
+ " 3.091618 | \n",
+ " -2.210554 | \n",
"
\n",
" \n",
" | 999 | \n",
" 100 | \n",
" 10 | \n",
- " -1.170401 | \n",
- " -0.463559 | \n",
- " -5.695356 | \n",
- " -1.765611 | \n",
- " 0.205248 | \n",
- " -0.671551 | \n",
- " 3.462028 | \n",
- " -2.128120 | \n",
+ " -4.046210 | \n",
+ " 0.194942 | \n",
+ " -0.998063 | \n",
+ " 1.876639 | \n",
+ " -0.401315 | \n",
+ " -2.042387 | \n",
+ " -0.389824 | \n",
+ " -0.388875 | \n",
" ... | \n",
- " 3.930542 | \n",
- " -1.196083 | \n",
- " 1.066743 | \n",
- " 0.690258 | \n",
- " 0.791818 | \n",
- " 2.491745 | \n",
- " 0.168359 | \n",
- " 2.278172 | \n",
- " -1.443654 | \n",
- " -5.695066 | \n",
+ " 0.951100 | \n",
+ " 0.987539 | \n",
+ " -2.201030 | \n",
+ " 0.144916 | \n",
+ " -1.977077 | \n",
+ " -2.538484 | \n",
+ " -1.978323 | \n",
+ " 2.068496 | \n",
+ " -2.546201 | \n",
+ " 2.218969 | \n",
"
\n",
" \n",
"\n",
@@ -432,54 +432,54 @@
],
"text/plain": [
" id time y d x1 x2 x3 x4 \\\n",
- "0 1 1 -1.767568 -1.348237 0.492677 0.281933 -2.142618 0.791443 \n",
- "1 1 2 -5.095199 -3.566642 -1.608388 -0.819905 -3.570497 1.583374 \n",
- "2 1 3 6.437985 4.202518 -1.050850 -1.400580 3.183289 3.513685 \n",
- "3 1 4 1.692969 0.776318 -0.275264 -0.787588 -1.492324 -3.920095 \n",
- "4 1 5 -0.760349 -0.423762 -5.334104 -3.650439 0.447345 -4.053885 \n",
+ "0 1 1 -5.891762 -4.574437 -1.203377 -0.087056 -3.875989 -2.060015 \n",
+ "1 1 2 0.601641 1.217127 -1.076318 2.226439 0.379887 -2.491481 \n",
+ "2 1 3 -5.336432 -4.084917 1.684777 -2.117207 -4.038299 1.196702 \n",
+ "3 1 4 -0.478058 0.043192 0.978425 -2.568042 -1.001187 -2.151350 \n",
+ "4 1 5 -1.223138 -1.544670 1.398526 3.567711 -1.353898 -2.226735 \n",
".. ... ... ... ... ... ... ... ... \n",
- "995 100 6 6.736806 3.652808 0.935941 -1.046372 2.972970 -3.188139 \n",
- "996 100 7 6.023375 3.868859 -0.667464 5.690558 4.230361 0.512239 \n",
- "997 100 8 3.110866 0.955741 1.789798 -1.256050 -0.369970 -1.965363 \n",
- "998 100 9 4.572219 3.239447 -3.129920 -1.654972 3.222430 -1.193018 \n",
- "999 100 10 -1.170401 -0.463559 -5.695356 -1.765611 0.205248 -0.671551 \n",
+ "995 100 6 -2.734367 -0.265121 -3.045957 1.798310 -1.485007 -2.107084 \n",
+ "996 100 7 5.307034 3.585023 -2.274700 -2.499113 2.116755 0.478902 \n",
+ "997 100 8 -4.476127 -2.751544 -0.716859 -1.263491 -2.826469 -0.954049 \n",
+ "998 100 9 5.191475 4.985730 -2.840246 0.931855 3.070040 2.700103 \n",
+ "999 100 10 -4.046210 0.194942 -0.998063 1.876639 -0.401315 -2.042387 \n",
"\n",
" x5 x6 ... x21 x22 x23 x24 \\\n",
- "0 2.976620 -1.001715 ... -1.198500 -0.049211 0.600889 2.435667 \n",
- "1 1.644214 -4.221177 ... 4.517936 0.413499 2.150563 -2.971910 \n",
- "2 1.861339 0.888485 ... 1.163815 -0.069711 -0.202117 -1.262765 \n",
- "3 -2.246768 -1.655923 ... -1.120645 -1.726098 -2.561617 -2.247641 \n",
- "4 1.367370 0.752763 ... -1.666836 -0.607192 3.658921 -1.153617 \n",
+ "0 1.924513 1.564501 ... -1.697513 -2.509320 -0.727138 -2.393134 \n",
+ "1 -1.446766 -0.000182 ... -0.185932 -0.491894 1.320808 -2.888978 \n",
+ "2 4.320428 1.543303 ... -1.896694 2.950372 2.266257 -1.962670 \n",
+ "3 0.973693 -1.286461 ... -1.148923 -3.388272 1.121507 4.753065 \n",
+ "4 3.713345 0.675101 ... -1.837662 2.941762 2.061895 0.853080 \n",
".. ... ... ... ... ... ... ... \n",
- "995 -1.368655 -1.973138 ... -0.693739 1.142071 -1.017755 4.427896 \n",
- "996 -0.016779 -1.998309 ... 0.429095 -0.879970 -0.333486 0.856337 \n",
- "997 0.805850 -0.837638 ... 0.314910 2.215582 1.695434 -0.092416 \n",
- "998 0.287887 -0.382851 ... -0.481192 1.541299 1.153674 -4.382081 \n",
- "999 3.462028 -2.128120 ... 3.930542 -1.196083 1.066743 0.690258 \n",
+ "995 2.983506 1.469852 ... 4.836053 1.932549 2.307781 -2.536377 \n",
+ "996 -0.248561 -0.957826 ... -2.799372 1.598447 1.972620 -1.888645 \n",
+ "997 -1.237438 -1.200074 ... 1.527836 -1.918927 -0.381272 2.065848 \n",
+ "998 1.214848 2.846577 ... -1.662106 0.583185 2.117253 -0.429837 \n",
+ "999 -0.389824 -0.388875 ... 0.951100 0.987539 -2.201030 0.144916 \n",
"\n",
" x25 x26 x27 x28 x29 x30 \n",
- "0 -1.387149 3.034459 -0.062419 0.258166 -1.168477 -1.061057 \n",
- "1 0.922270 -2.628696 -1.772420 -3.851087 3.270008 0.820763 \n",
- "2 1.133570 0.884130 0.484024 3.124910 0.004369 0.349072 \n",
- "3 0.685799 3.943749 2.891479 5.381948 1.455669 -2.480590 \n",
- "4 1.338251 2.990290 -0.717240 2.494413 -0.576748 -0.049214 \n",
+ "0 0.334781 2.097534 1.942009 1.649557 -0.612257 -4.331109 \n",
+ "1 0.296153 -0.209147 -1.066396 -2.232003 3.217619 -0.660709 \n",
+ "2 1.913956 -3.847482 0.914604 -1.721561 -0.954810 0.407410 \n",
+ "3 1.424797 -2.345737 -0.693004 -1.618859 1.668621 4.571664 \n",
+ "4 -0.244278 1.263040 -2.011630 -0.826488 -0.887181 -1.935609 \n",
".. ... ... ... ... ... ... \n",
- "995 0.988486 -0.082238 1.047705 1.222772 3.264437 1.173541 \n",
- "996 3.159868 -2.772002 1.782697 1.158639 1.373919 -3.298531 \n",
- "997 -0.691905 2.864254 1.673074 -0.839339 1.389325 1.714632 \n",
- "998 4.017794 1.117018 -1.648193 -1.139779 2.748231 3.032575 \n",
- "999 0.791818 2.491745 0.168359 2.278172 -1.443654 -5.695066 \n",
+ "995 1.150598 1.052909 -0.969657 1.266473 -3.177021 -3.070155 \n",
+ "996 1.237270 3.644984 0.054862 0.615274 -0.432120 -0.046949 \n",
+ "997 0.723859 -0.711546 0.930980 0.883152 -0.324217 0.053768 \n",
+ "998 -1.983224 -1.249148 5.170035 3.022710 3.091618 -2.210554 \n",
+ "999 -1.977077 -2.538484 -1.978323 2.068496 -2.546201 2.218969 \n",
"\n",
"[1000 rows x 34 columns]"
]
},
- "execution_count": 5,
+ "execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "data = make_static_panel_CP2025(num_n=100, dgp_type='dgp1', x_var=5, a_var=0.95)\n",
+ "data = make_plpr_CP2025(num_id=100, dgp_type='dgp1')\n",
"data"
]
},
@@ -490,7 +490,7 @@
"metadata": {},
"outputs": [],
"source": [
- "def run_sim(n_reps, num_n, dim_x=30, theta=0.5, dgp_type='dgp3'):\n",
+ "def run_sim(n_reps, num_id, dim_x=30, theta=0.5, dgp_type='dgp3'):\n",
"\n",
" approaches = [\"cre_general\", \"cre_normal\", \"fd_exact\", \"wg_approx\"]\n",
" models = [\"lasso\", \"cart\", \"boost\"]\n",
@@ -535,7 +535,7 @@
" 'ml_m': {'reg_lambda': np.random.choice(np.arange(0.2, 2 , 0.2), 5, replace=False),\n",
" 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)}}\n",
"\n",
- " data = make_static_panel_CP2025(num_n=num_n, dim_x=dim_x, theta=theta, dgp_type=dgp_type, x_var=5, a_var=0.95)\n",
+ " data = make_plpr_CP2025(num_id=num_id, dim_x=dim_x, theta=theta, dgp_type=dgp_type)\n",
" dml_data = DoubleMLPanelData(data, y_col='y', d_cols='d', t_col='time', id_col='id', x_cols=x_cols, static_panel=True)\n",
"\n",
" # CRE general\n",
@@ -586,7 +586,7 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": null,
"id": "6503f741",
"metadata": {},
"outputs": [
@@ -751,13 +751,13 @@
"source": [
"np.random.seed(123)\n",
"\n",
- "res_dgp1 = run_sim(n_reps=100, num_n=100, theta=0.5, dgp_type='dgp1')\n",
+ "res_dgp1 = run_sim(n_reps=100, num_id=100, theta=0.5, dgp_type='dgp1')\n",
"res_dgp1"
]
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": null,
"id": "7a7b3fa2",
"metadata": {},
"outputs": [
@@ -922,13 +922,13 @@
"source": [
"np.random.seed(123)\n",
"\n",
- "res_dgp2 = run_sim(n_reps=100, num_n=100, theta=0.5, dgp_type='dgp2')\n",
+ "res_dgp2 = run_sim(n_reps=100, num_id=100, theta=0.5, dgp_type='dgp2')\n",
"res_dgp2"
]
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": null,
"id": "f8f0abee",
"metadata": {},
"outputs": [
@@ -1093,7 +1093,7 @@
"source": [
"np.random.seed(123)\n",
"\n",
- "run_sim(n_reps=100, num_n=100, theta=0.5, dgp_type='dgp3')"
+ "run_sim(n_reps=100, num_id=100, theta=0.5, dgp_type='dgp3')"
]
}
],
From 453cbf6595082ee030403b401b5137773b36a445 Mon Sep 17 00:00:00 2001
From: Julian Diefenbacher
<93731825+JulianDiefenbacher@users.noreply.github.com>
Date: Sun, 30 Nov 2025 15:12:58 +0100
Subject: [PATCH 28/33] add basic plpr tests
---
doubleml/__init__.py | 2 +
doubleml/plm/__init__.py | 4 +-
doubleml/plm/plpr.py | 24 +-
doubleml/plm/tests/test_datasets.py | 8 +
doubleml/plm/tests/test_model_defaults.py | 22 +-
doubleml/plm/tests/test_plpr.py | 76 ++++++
doubleml/plm/tests/test_plpr_exceptions.py | 276 +++++++++++++++++++++
doubleml/plm/tests/test_plpr_tune.py | 104 ++++++++
doubleml/plm/tests/test_return_types.py | 1 +
9 files changed, 501 insertions(+), 16 deletions(-)
create mode 100644 doubleml/plm/tests/test_plpr.py
create mode 100644 doubleml/plm/tests/test_plpr_exceptions.py
create mode 100644 doubleml/plm/tests/test_plpr_tune.py
diff --git a/doubleml/__init__.py b/doubleml/__init__.py
index cb3891bac..06d4cd964 100644
--- a/doubleml/__init__.py
+++ b/doubleml/__init__.py
@@ -15,6 +15,7 @@
from .irm.ssm import DoubleMLSSM
from .plm.lplr import DoubleMLLPLR
from .plm.pliv import DoubleMLPLIV
+from .plm.plpr import DoubleMLPLPR
from .plm.plr import DoubleMLPLR
from .utils.blp import DoubleMLBLP
from .utils.policytree import DoubleMLPolicyTree
@@ -44,6 +45,7 @@
"DoubleMLPolicyTree",
"DoubleMLSSM",
"DoubleMLLPLR",
+ "DoubleMLPLPR",
]
__version__ = importlib.metadata.version("doubleml")
diff --git a/doubleml/plm/__init__.py b/doubleml/plm/__init__.py
index 6100e78cd..f4ce2b83d 100644
--- a/doubleml/plm/__init__.py
+++ b/doubleml/plm/__init__.py
@@ -4,7 +4,7 @@
from .lplr import DoubleMLLPLR
from .pliv import DoubleMLPLIV
-from .plr import DoubleMLPLR
from .plpr import DoubleMLPLPR
+from .plr import DoubleMLPLR
-__all__ = ["DoubleMLPLR", "DoubleMLPLIV", "DoubleMLLPLR"]
+__all__ = ["DoubleMLPLR", "DoubleMLPLIV", "DoubleMLLPLR", "DoubleMLPLPR"]
diff --git a/doubleml/plm/plpr.py b/doubleml/plm/plpr.py
index ab35436f5..398dd5fca 100644
--- a/doubleml/plm/plpr.py
+++ b/doubleml/plm/plpr.py
@@ -17,19 +17,19 @@ class DoubleMLPLPR(LinearScoreMixin, DoubleML):
Parameters
----------
- obj_dml_data : :class:`DoubleMLData` object
+ obj_dml_data : :class:`DoubleMLPanelData` object
The :class:`DoubleMLData` object providing the data and specifying the variables for the causal model.
ml_l : estimator implementing ``fit()`` and ``predict()``
A machine learner implementing ``fit()`` and ``predict()`` methods (e.g.
- :py:class:`sklearn.ensemble.RandomForestRegressor`) for the nuisance function :math:`\\ell_0(X) = E[Y|X]`.
+ :py:class:`sklearn.ensemble.RandomForestRegressor`) for the nuisance function :math:`\\ell_0(X) = E[Y|X]`.
ml_m : estimator implementing ``fit()`` and ``predict()``
A machine learner implementing ``fit()`` and ``predict()`` methods (e.g.
:py:class:`sklearn.ensemble.RandomForestRegressor`) for the nuisance function :math:`m_0(X) = E[D|X]`.
- For binary treatment variables :math:`D` (with values 0 and 1), a classifier implementing ``fit()`` and
- ``predict_proba()`` can also be specified. If :py:func:`sklearn.base.is_classifier` returns ``True``,
- ``predict_proba()`` is used otherwise ``predict()``.
+ For binary treatment variables :math:`D` (with values 0 and 1) and the CRE approaches, a classifier
+ implementing ``fit()`` and ``predict_proba()`` can also be specified. If :py:func:`sklearn.base.is_classifier`
+ returns ``True``, ``predict_proba()`` is used otherwise ``predict()``.
ml_g : estimator implementing ``fit()`` and ``predict()``
A machine learner implementing ``fit()`` and ``predict()`` methods (e.g.
@@ -90,8 +90,8 @@ class DoubleMLPLPR(LinearScoreMixin, DoubleML):
where :math:`Y_{it}` is the outcome variable and :math:`D_{it}` is the policy variable of interest.
The high-dimensional vector :math:`X_{it} = (X_{it,1}, \\ldots, X_{it,p})` consists of other confounding covariates,
- :math:`\\alpha_i` and :math:`\\gamma_i` are the unobserved individual heterogeneity correlated with the included covariates,
- and :math:`\\U_{it}` and :math:` V_{it}` are stochastic errors.
+ :math:`\\alpha_i` and :math:`\\gamma_i` are the unobserved individual heterogeneity correlated with the included
+ covariates, and :math:`U_{it}` and :math:`V_{it}` are stochastic errors.
"""
def __init__(
@@ -302,8 +302,8 @@ def _set_d_mean(self):
self._d_mean = None
def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=False):
- x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False)
- x, d = check_X_y(x, self._dml_data.d, force_all_finite=False)
+ x, y = check_X_y(self._dml_data.x, self._dml_data.y, ensure_all_finite=False)
+ x, d = check_X_y(x, self._dml_data.d, ensure_all_finite=False)
m_external = external_predictions["ml_m"] is not None
l_external = external_predictions["ml_l"] is not None
if "ml_g" in self._learner:
@@ -430,8 +430,8 @@ def _nuisance_tuning(
search_mode,
n_iter_randomized_search,
):
- x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False)
- x, d = check_X_y(x, self._dml_data.d, force_all_finite=False)
+ x, y = check_X_y(self._dml_data.x, self._dml_data.y, ensure_all_finite=False)
+ x, d = check_X_y(x, self._dml_data.d, ensure_all_finite=False)
if scoring_methods is None:
scoring_methods = {"ml_l": None, "ml_m": None, "ml_g": None}
@@ -478,7 +478,7 @@ def _nuisance_tuning(
m_hat = np.full_like(d, np.nan)
for idx, (train_index, _) in enumerate(smpls):
l_hat[train_index] = l_tune_res[idx].predict(x[train_index, :])
- m_hat[train_index] = m_tune_res[idx].predict(x[train_index, :])
+ m_hat[train_index] = m_tune_res[idx].predict(x_m[train_index, :])
psi_a = -np.multiply(d - m_hat, d - m_hat)
psi_b = np.multiply(d - m_hat, y - l_hat)
theta_initial = -np.nanmean(psi_b) / np.nanmean(psi_a)
diff --git a/doubleml/plm/tests/test_datasets.py b/doubleml/plm/tests/test_datasets.py
index 5e16b9acf..2726ea269 100644
--- a/doubleml/plm/tests/test_datasets.py
+++ b/doubleml/plm/tests/test_datasets.py
@@ -9,6 +9,7 @@
make_lplr_LZZ2020,
make_pliv_CHS2015,
make_pliv_multiway_cluster_CKMS2021,
+ make_plpr_CP2025,
make_plr_CCDDHNR2018,
make_plr_turrell2018,
)
@@ -149,3 +150,10 @@ def test_make_lplr_LZZ2020_variants():
res = make_lplr_LZZ2020(n_obs=100, balanced_r0=False)
_, y_unique = np.unique(res.y, return_counts=True)
assert np.abs(y_unique[0] - y_unique[1]) > 10
+
+
+@pytest.mark.ci
+def test_make_plpr_CP2025_return_types():
+ np.random.seed(3141)
+ res = make_plpr_CP2025(num_id=100)
+ assert isinstance(res, pd.DataFrame)
diff --git a/doubleml/plm/tests/test_model_defaults.py b/doubleml/plm/tests/test_model_defaults.py
index b555f5ad5..6d8ef3226 100644
--- a/doubleml/plm/tests/test_model_defaults.py
+++ b/doubleml/plm/tests/test_model_defaults.py
@@ -1,14 +1,26 @@
import pytest
from sklearn.linear_model import LinearRegression, LogisticRegression
-from doubleml import DoubleMLLPLR
-from doubleml.plm.datasets import make_lplr_LZZ2020
+from doubleml import DoubleMLLPLR, DoubleMLPanelData, DoubleMLPLPR
+from doubleml.plm.datasets import make_lplr_LZZ2020, make_plpr_CP2025
from doubleml.utils._check_defaults import _check_basic_defaults_after_fit, _check_basic_defaults_before_fit, _fit_bootstrap
dml_data_lplr = make_lplr_LZZ2020(n_obs=100)
dml_lplr_obj = DoubleMLLPLR(dml_data_lplr, LogisticRegression(), LinearRegression(), LinearRegression())
+plpr_data = make_plpr_CP2025(num_id=100)
+dml_data_plpr = DoubleMLPanelData(
+ plpr_data,
+ y_col="y",
+ d_cols="d",
+ t_col="time",
+ id_col="id",
+ static_panel=True,
+)
+
+dml_plpr_obj = DoubleMLPLPR(dml_data_plpr, LinearRegression(), LinearRegression())
+
@pytest.mark.ci
def test_lplr_defaults():
@@ -17,3 +29,9 @@ def test_lplr_defaults():
_fit_bootstrap(dml_lplr_obj)
_check_basic_defaults_after_fit(dml_lplr_obj)
+
+
+@pytest.mark.ci
+def test_plpr_defaults():
+ _check_basic_defaults_before_fit(dml_plpr_obj)
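+ # TODO: also fit and run the after-fit default checks (as in test_lplr_defaults) once confirmed to work with clustered panel data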
diff --git a/doubleml/plm/tests/test_plpr.py b/doubleml/plm/tests/test_plpr.py
new file mode 100644
index 000000000..3a5061588
--- /dev/null
+++ b/doubleml/plm/tests/test_plpr.py
@@ -0,0 +1,76 @@
+import numpy as np
+import pytest
+from sklearn.base import clone
+from sklearn.linear_model import Lasso, LinearRegression
+
+import doubleml as dml
+
+from ..datasets import make_plpr_CP2025
+
+
+@pytest.fixture(scope="module", params=[LinearRegression(), Lasso(alpha=0.1)])
+def learner(request):
+ return request.param
+
+
+@pytest.fixture(scope="module", params=["IV-type", "partialling out"])
+def score(request):
+ return request.param
+
+
+@pytest.fixture(scope="module", params=["cre_general", "cre_normal", "fd_exact", "wg_approx"])
+def approach(request):
+ return request.param
+
+
+@pytest.fixture(scope="module")
+def dml_plpr_fixture(
+ learner,
+ score,
+ approach,
+):
+ n_folds = 5
+ theta = 0.5
+
+ ml_l = clone(learner)
+ ml_m = clone(learner)
+ ml_g = clone(learner)
+
+ np.random.seed(3141)
+ plpr_data = make_plpr_CP2025(theta=theta)
+ obj_dml_data = dml.DoubleMLPanelData(
+ plpr_data,
+ y_col="y",
+ d_cols="d",
+ t_col="time",
+ id_col="id",
+ static_panel=True,
+ )
+ dml_plpr_obj = dml.DoubleMLPLPR(
+ obj_dml_data,
+ ml_l,
+ ml_m,
+ ml_g,
+ n_folds=n_folds,
+ score=score,
+ approach=approach,
+ )
+
+ dml_plpr_obj.fit()
+
+ res_dict = {
+ "coef": dml_plpr_obj.coef[0],
+ "se": dml_plpr_obj.se[0],
+ "true_coef": theta,
+ }
+
+ return res_dict
+
+
+@pytest.mark.ci
+def test_dml_selection_coef(dml_plpr_fixture):
+ # true_coef should lie within three standard errors of the estimate
+ coef = dml_plpr_fixture["coef"]
+ se = dml_plpr_fixture["se"]
+ true_coef = dml_plpr_fixture["true_coef"]
+ assert abs(coef - true_coef) <= 3.0 * se
diff --git a/doubleml/plm/tests/test_plpr_exceptions.py b/doubleml/plm/tests/test_plpr_exceptions.py
new file mode 100644
index 000000000..f11811d23
--- /dev/null
+++ b/doubleml/plm/tests/test_plpr_exceptions.py
@@ -0,0 +1,276 @@
+import copy
+
+import numpy as np
+import pandas as pd
+import pytest
+from sklearn.base import BaseEstimator
+from sklearn.linear_model import Lasso, LogisticRegression
+
+from doubleml import DoubleMLPanelData, DoubleMLPLPR
+from doubleml.plm.datasets import make_plpr_CP2025
+
+np.random.seed(3141)
+num_id = 100
+# create test data and basic learners
+plpr_data = make_plpr_CP2025(num_id=num_id, theta=0.5, dim_x=30)
+plpr_data_binary = plpr_data.copy()
+plpr_data_binary["d"] = np.where(plpr_data_binary["d"] > 0, 1, 0)
+
+x_cols = [col for col in plpr_data.columns if "x" in col]
+dml_data = DoubleMLPanelData(
+ plpr_data,
+ y_col="y",
+ d_cols="d",
+ t_col="time",
+ id_col="id",
+ static_panel=True,
+)
+dml_data_iv = DoubleMLPanelData(
+ plpr_data,
+ y_col="y",
+ d_cols="d",
+ t_col="time",
+ id_col="id",
+ x_cols=x_cols[:-1],
+ z_cols=x_cols[-1],
+ static_panel=True,
+)
+dml_data_binary = DoubleMLPanelData(
+ plpr_data_binary,
+ y_col="y",
+ d_cols="d",
+ t_col="time",
+ id_col="id",
+ static_panel=True,
+)
+ml_l = Lasso(alpha=0.1)
+ml_m = Lasso(alpha=0.1)
+ml_g = Lasso(alpha=0.1)
+dml_plpr = DoubleMLPLPR(dml_data, ml_l, ml_m)
+dml_plpr_iv_type = DoubleMLPLPR(dml_data, ml_l, ml_m, ml_g, score="IV-type")
+
+
+@pytest.mark.ci
+def test_plpr_exception_data():
+ msg = "The data must be of DoubleMLPanelData type. <class 'pandas.core.frame.DataFrame'> was passed."
+ with pytest.raises(TypeError, match=msg):
+ _ = DoubleMLPLPR(pd.DataFrame(), ml_l, ml_m)
+
+ # instrument
+ msg = (
+ r"Incompatible data. x30 have been set as instrumental variable\(s\). "
+ "DoubleMLPLPR currently does not support instrumental variables."
+ )
+ with pytest.raises(ValueError, match=msg):
+ _ = DoubleMLPLPR(dml_data_iv, ml_l, ml_m)
+
+
+@pytest.mark.ci
+def test_plpr_exception_scores():
+ msg = "Invalid score IV. Valid score IV-type or partialling out."
+ with pytest.raises(ValueError, match=msg):
+ _ = DoubleMLPLPR(dml_data, ml_l, ml_m, score="IV")
+ msg = "score should be either a string or a callable. 0 was passed."
+ with pytest.raises(TypeError, match=msg):
+ _ = DoubleMLPLPR(dml_data, ml_l, ml_m, score=0)
+
+
+@pytest.mark.ci
+def test_plpr_exception_approach():
+ # PLPR valid approaches are 'cre_general', 'cre_normal', 'fd_exact', and 'wg_approx'
+ msg = "Invalid approach cre. Valid approach cre_general or cre_normal or fd_exact or wg_approx."
+ with pytest.raises(ValueError, match=msg):
+ _ = DoubleMLPLPR(dml_data, ml_l, ml_m, approach="cre")
+ msg = "approach should be a string. 4 was passed."
+ with pytest.raises(TypeError, match=msg):
+ _ = DoubleMLPLPR(dml_data, ml_l, ml_m, approach=4)
+
+
+@pytest.mark.ci
+def test_plpr_exception_resampling():
+ msg = "The number of folds must be of int type. 1.5 of type <class 'float'> was passed."
+ with pytest.raises(TypeError, match=msg):
+ _ = DoubleMLPLPR(dml_data, ml_l, ml_m, n_folds=1.5)
+ msg = "The number of repetitions for the sample splitting must be of int type. 1.5 of type <class 'float'> was passed."
+ with pytest.raises(TypeError, match=msg):
+ _ = DoubleMLPLPR(dml_data, ml_l, ml_m, n_rep=1.5)
+ msg = "The number of folds must be positive. 0 was passed."
+ with pytest.raises(ValueError, match=msg):
+ _ = DoubleMLPLPR(dml_data, ml_l, ml_m, n_folds=0)
+ msg = "The number of repetitions for the sample splitting must be positive. 0 was passed."
+ with pytest.raises(ValueError, match=msg):
+ _ = DoubleMLPLPR(dml_data, ml_l, ml_m, n_rep=0)
+ msg = "draw_sample_splitting must be True or False. Got true."
+ with pytest.raises(TypeError, match=msg):
+ _ = DoubleMLPLPR(dml_data, ml_l, ml_m, draw_sample_splitting="true")
+
+
+@pytest.mark.ci
+def test_plpr_exception_get_params():
+ msg = "Invalid nuisance learner ml_r. Valid nuisance learner ml_l or ml_m."
+ with pytest.raises(ValueError, match=msg):
+ dml_plpr.get_params("ml_r")
+ msg = "Invalid nuisance learner ml_g. Valid nuisance learner ml_l or ml_m."
+ with pytest.raises(ValueError, match=msg):
+ dml_plpr.get_params("ml_g")
+ msg = "Invalid nuisance learner ml_r. Valid nuisance learner ml_l or ml_m or ml_g."
+ with pytest.raises(ValueError, match=msg):
+ dml_plpr_iv_type.get_params("ml_r")
+
+
+# TODO: test_doubleml_exception_onefold(): for plpr?
+@pytest.mark.ci
+def test_plpr_exception_smpls():
+ msg = (
+ "Sample splitting not specified. "
+ r"Either draw samples via .draw_sample splitting\(\) or set external samples via .set_sample_splitting\(\)."
+ )
+ dml_plpr_no_smpls = DoubleMLPLPR(dml_data, ml_l, ml_m, draw_sample_splitting=False)
+ with pytest.raises(ValueError, match=msg):
+ _ = dml_plpr_no_smpls.smpls
+
+ dml_plpr_cluster = dml_plpr
+ smpls = dml_plpr.smpls
+ msg = "For cluster data, all_smpls_cluster must be provided."
+ with pytest.raises(ValueError, match=msg):
+ _ = dml_plpr_cluster.set_sample_splitting(smpls)
+
+ all_smpls_cluster = copy.deepcopy(dml_plpr_cluster.smpls_cluster)
+ all_smpls_cluster.append(all_smpls_cluster[0])
+ msg = "Invalid samples provided. Number of repetitions for all_smpls and all_smpls_cluster must be the same."
+ with pytest.raises(ValueError, match=msg):
+ _ = dml_plpr_cluster.set_sample_splitting(all_smpls=dml_plpr_cluster.smpls, all_smpls_cluster=all_smpls_cluster)
+
+ all_smpls_cluster = copy.deepcopy(dml_plpr_cluster.smpls_cluster)
+ all_smpls_cluster.append(all_smpls_cluster[0])
+ msg = "Invalid samples provided. Number of repetitions for all_smpls and all_smpls_cluster must be the same."
+ with pytest.raises(ValueError, match=msg):
+ _ = dml_plpr_cluster.set_sample_splitting(all_smpls=dml_plpr_cluster.smpls, all_smpls_cluster=all_smpls_cluster)
+
+ all_smpls_cluster = copy.deepcopy(dml_plpr_cluster.smpls_cluster)
+ all_smpls_cluster[0][0][1][0] = np.append(all_smpls_cluster[0][0][1][0], [11], axis=0)
+ msg = "Invalid cluster partition provided. At least one inner list does not form a partition."
+ with pytest.raises(ValueError, match=msg):
+ _ = dml_plpr_cluster.set_sample_splitting(all_smpls=dml_plpr_cluster.smpls, all_smpls_cluster=all_smpls_cluster)
+
+ all_smpls_cluster = copy.deepcopy(dml_plpr_cluster.smpls_cluster)
+ all_smpls_cluster[0][0][1][0][1] = 11
+ msg = "Invalid cluster partition provided. At least one inner list does not form a partition."
+ with pytest.raises(ValueError, match=msg):
+ _ = dml_plpr_cluster.set_sample_splitting(all_smpls=dml_plpr_cluster.smpls, all_smpls_cluster=all_smpls_cluster)
+
+
+@pytest.mark.ci
+def test_plpr_exception_fit():
+ msg = "The number of CPUs used to fit the learners must be of int type. 5 of type <class 'str'> was passed."
+ with pytest.raises(TypeError, match=msg):
+ dml_plpr.fit(n_jobs_cv="5")
+ msg = "store_predictions must be True or False. Got 1."
+ with pytest.raises(TypeError, match=msg):
+ dml_plpr.fit(store_predictions=1)
+ msg = "store_models must be True or False. Got 1."
+ with pytest.raises(TypeError, match=msg):
+ dml_plpr.fit(store_models=1)
+
+
+@pytest.mark.ci
+def test_plpr_exception_set_ml_nuisance_params():
+ msg = "Invalid nuisance learner g. Valid nuisance learner ml_l or ml_m."
+ with pytest.raises(ValueError, match=msg):
+ dml_plpr.set_ml_nuisance_params("g", "d", {"alpha": 0.1})
+ msg = "Invalid treatment variable y. Valid treatment variable d_diff."
+ with pytest.raises(ValueError, match=msg):
+ dml_plpr.set_ml_nuisance_params("ml_l", "y", {"alpha": 0.1})
+
+
+class _DummyNoSetParams:
+ def fit(self):
+ pass
+
+
+class _DummyNoGetParams(_DummyNoSetParams):
+ def set_params(self):
+ pass
+
+
+class _DummyNoClassifier(_DummyNoGetParams, BaseEstimator):
+ def get_params(self, deep=True):
+ return {}
+
+ def predict_proba(self):
+ pass
+
+
+class LogisticRegressionManipulatedPredict(LogisticRegression):
+ def __sklearn_tags__(self):
+ tags = super().__sklearn_tags__()
+ tags.estimator_type = None
+ return tags
+
+ def predict(self, X):
+ if self.max_iter == 314:
+ preds = super().predict_proba(X)[:, 1]
+ else:
+ preds = super().predict(X)
+ return preds
+
+
+@pytest.mark.ci
+def test_plpr_exception_learner():
+ err_msg_prefix = "Invalid learner provided for ml_l: "
+
+ msg = err_msg_prefix + "provide an instance of a learner instead of a class."
+ with pytest.raises(TypeError, match=msg):
+ _ = DoubleMLPLPR(dml_data, Lasso, ml_m)
+ msg = err_msg_prefix + r"BaseEstimator\(\) has no method .fit\(\)."
+ with pytest.raises(TypeError, match=msg):
+ _ = DoubleMLPLPR(dml_data, BaseEstimator(), ml_m)
+ # msg = err_msg_prefix + r'_DummyNoSetParams\(\) has no method .set_params\(\).'
+ with pytest.raises(TypeError):
+ _ = DoubleMLPLPR(dml_data, _DummyNoSetParams(), ml_m)
+ # msg = err_msg_prefix + r'_DummyNoGetParams\(\) has no method .get_params\(\).'
+ with pytest.raises(TypeError):
+ _ = DoubleMLPLPR(dml_data, _DummyNoGetParams(), ml_m)
+
+ # we allow classifiers for ml_m in PLPR, but only for binary treatment variables
+ msg = (
+ r"The ml_m learner LogisticRegression\(\) was identified as classifier "
+ "but at least one treatment variable is not binary with values 0 and 1."
+ )
+ with pytest.raises(ValueError, match=msg):
+ _ = DoubleMLPLPR(dml_data, Lasso(), LogisticRegression())
+
+ msg = r"For score = 'IV-type', learners ml_l and ml_g should be specified. Set ml_g = clone\(ml_l\)."
+ with pytest.warns(UserWarning, match=msg):
+ _ = DoubleMLPLPR(dml_data, ml_l=Lasso(), ml_m=ml_m, score="IV-type")
+
+ msg = 'A learner ml_g has been provided for score = "partialling out" but will be ignored.'
+ with pytest.warns(UserWarning, match=msg):
+ _ = DoubleMLPLPR(dml_data, ml_l=Lasso(), ml_m=Lasso(), ml_g=Lasso(), score="partialling out")
+
+ # construct a classifier which is not identifiable as classifier via is_classifier by sklearn
+ # it then predicts labels and therefore an exception will be thrown
+ # TODO: cases with approaches cre_general, fd_exact, wg_approx
+ log_reg = LogisticRegressionManipulatedPredict()
+ msg_warn = (
+ r"Learner provided for ml_m is probably invalid: LogisticRegressionManipulatedPredict\(\) is \(probably\) "
+ "neither a regressor nor a classifier. Method predict is used for prediction."
+ )
+ with pytest.warns(UserWarning, match=msg_warn):
+ dml_plpr_hidden_classifier = DoubleMLPLPR(dml_data_binary, Lasso(), log_reg, approach="cre_normal")
+ msg = (
+ r"For the binary variable d, predictions obtained with the ml_m learner LogisticRegressionManipulatedPredict\(\) "
+ "are also observed to be binary with values 0 and 1. Make sure that for classifiers probabilities and not "
+ "labels are predicted."
+ )
+ with pytest.warns(UserWarning, match=msg_warn):
+ with pytest.raises(ValueError, match=msg):
+ dml_plpr_hidden_classifier.fit()
+
+
+@pytest.mark.ci
+@pytest.mark.filterwarnings("ignore:Learner provided for")
+def test_plpr_exception_and_warning_learner():
+ # msg = err_msg_prefix + r'_DummyNoClassifier\(\) has no method .predict\(\).'
+ with pytest.raises(TypeError):
+ _ = DoubleMLPLPR(dml_data, _DummyNoClassifier(), Lasso())
diff --git a/doubleml/plm/tests/test_plpr_tune.py b/doubleml/plm/tests/test_plpr_tune.py
new file mode 100644
index 000000000..411a9bd90
--- /dev/null
+++ b/doubleml/plm/tests/test_plpr_tune.py
@@ -0,0 +1,104 @@
+import numpy as np
+import pytest
+from sklearn.base import clone
+from sklearn.linear_model import Lasso
+
+import doubleml as dml
+
+from ..datasets import make_plpr_CP2025
+
+
+@pytest.fixture(scope="module", params=[Lasso()])
+def learner_l(request):
+ return request.param
+
+
+@pytest.fixture(scope="module", params=[Lasso()])
+def learner_m(request):
+ return request.param
+
+
+@pytest.fixture(scope="module", params=[Lasso()])
+def learner_g(request):
+ return request.param
+
+
+@pytest.fixture(scope="module", params=["partialling out", "IV-type"])
+def score(request):
+ return request.param
+
+
+@pytest.fixture(scope="module", params=["cre_general", "cre_normal", "fd_exact", "wg_approx"])
+def approach(request):
+ return request.param
+
+
+def get_par_grid():
+ par_grid = {"alpha": np.linspace(0.05, 0.95, 7)}
+ return par_grid
+
+
+@pytest.fixture(scope="module")
+def dml_plpr_fixture(
+ learner_l,
+ learner_m,
+ learner_g,
+ score,
+ approach,
+ tune_on_folds=False,
+):
+ par_grid = {
+ "ml_l": get_par_grid(),
+ "ml_m": get_par_grid(),
+ "ml_g": get_par_grid(),
+ }
+ n_folds_tune = 4
+ n_folds = 5
+ theta = 0.5
+
+ ml_l = clone(learner_l)
+ ml_m = clone(learner_m)
+ ml_g = clone(learner_g)
+
+ np.random.seed(3141)
+ plpr_data = make_plpr_CP2025(theta=theta)
+ obj_dml_data = dml.DoubleMLPanelData(
+ plpr_data,
+ y_col="y",
+ d_cols="d",
+ t_col="time",
+ id_col="id",
+ static_panel=True,
+ )
+ dml_sel_obj = dml.DoubleMLPLPR(
+ obj_dml_data,
+ ml_l,
+ ml_m,
+ ml_g,
+ n_folds=n_folds,
+ score=score,
+ approach=approach,
+ )
+
+ # tune hyperparameters
+ tune_res = dml_sel_obj.tune(par_grid, tune_on_folds=tune_on_folds, n_folds_tune=n_folds_tune, return_tune_res=False)
+ assert isinstance(tune_res, dml.DoubleMLPLPR)
+
+ dml_sel_obj.fit()
+
+ res_dict = {
+ "coef": dml_sel_obj.coef[0],
+ "se": dml_sel_obj.se[0],
+ "true_coef": theta,
+ }
+
+ return res_dict
+
+
+@pytest.mark.ci
+def test_dml_selection_coef(dml_plpr_fixture):
+ # true_coef should lie within three standard errors of the estimate
+ coef = dml_plpr_fixture["coef"]
+ se = dml_plpr_fixture["se"]
+ true_coef = dml_plpr_fixture["true_coef"]
+ assert abs(coef - true_coef) <= 3.0 * se
diff --git a/doubleml/plm/tests/test_return_types.py b/doubleml/plm/tests/test_return_types.py
index cb32f5433..62075fc37 100644
--- a/doubleml/plm/tests/test_return_types.py
+++ b/doubleml/plm/tests/test_return_types.py
@@ -41,6 +41,7 @@
(dml_lplr_obj, DoubleMLLPLR),
(dml_lplr_obj_binary, DoubleMLLPLR),
]
+# TODO: plpr with cluster data return type tests? n_obs is changed for fd_exact approach
@pytest.mark.ci
From 86c947aeda049daf6247ffc32c85b1d25a7abad7 Mon Sep 17 00:00:00 2001
From: Julian Diefenbacher
<93731825+JulianDiefenbacher@users.noreply.github.com>
Date: Mon, 1 Dec 2025 15:43:28 +0100
Subject: [PATCH 29/33] remove notebooks
---
doubleml/plm/sim/example_sim.ipynb | 1068 -------------------------
doubleml/plm/sim/learners_sim.ipynb | 1121 ---------------------------
doubleml/plm/utils/_plpr_util.py | 73 --
3 files changed, 2262 deletions(-)
delete mode 100644 doubleml/plm/sim/example_sim.ipynb
delete mode 100644 doubleml/plm/sim/learners_sim.ipynb
delete mode 100644 doubleml/plm/utils/_plpr_util.py
diff --git a/doubleml/plm/sim/example_sim.ipynb b/doubleml/plm/sim/example_sim.ipynb
deleted file mode 100644
index a5f8db936..000000000
--- a/doubleml/plm/sim/example_sim.ipynb
+++ /dev/null
@@ -1,1068 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "id": "01469cb2",
- "metadata": {},
- "source": [
- "### Double Machine Learning for Static Panel Models with Fixed Effects\n",
- "\n",
- "Extending the partially linear model to panel data by introducing fixed effects $\\alpha_i$ to give the partially linear panel regression (PLPR) model.\n",
- "\n",
- "Partialled-out PLPR (PO-PLPR) model:\n",
- "\n",
- "\\begin{align*}\n",
- " Y_{it} &= \\theta_0 D_{it} + g_0(X_{it}) + \\alpha_i + U_{it} \\\\\n",
- " D_{it} &= m_0(X_{it}) + \\gamma_i + V_{it}\n",
- "\\end{align*}\n",
- "\n",
- "- $Y_{it}$ outcome, $D_{it}$ treatment, $X_{it}$ covariates, $\\theta_0$ causal treatment effect\n",
- "- $g_0$ and $m_0$ nuisance functions\n",
- "- $\\alpha_i$, $\\gamma_i$ unobserved individual heterogeneity, correlated with covariates\n",
- "- $U_{it}$, $V_{it}$ error terms\n",
- "\n",
- "Further note $E[U_{it} \\mid D_{it}, X_{it}, \\alpha_i] = 0$, but $E[\\alpha_i \\mid D_{it}, X_{it}] \\neq 0$, and $E[V_{it} \\mid X_{it}, \\gamma_i]=0$\n",
- "\n",
- "#### 1 Correlated Random Effect Approach\n",
- "\n",
- "##### 1.1. General case:\n",
- "\n",
- "- Learning $g_0$ from $\\{ Y_{it}, X_{it}, \\bar{X}_i : t=1,\\dots, T \\}_{i=1}^N$\n",
- "- First learning $\\tilde{m}_0({\\cdot})$ from $\\{ D_{it}, X_{it}, \\bar{X}_i : t=1,\\dots, T \\}_{i=1}^N$ with prediction $\\hat{m}_{0it} = \\tilde{m}_0 (X_{it}, \\bar{X}_i) $\n",
- " - Calculate $\\hat{\\bar{m}}_i = T^{-1} \\sum_{t=1}^T \\hat{m}_{0it} $\n",
- " - Calculate final nuisance part as $ \\hat{m}^*_0 (X_{it}, \\bar{X}_i, \\bar{D}_i) = \\hat{m}_{0it} + \\bar{D}_i - \\hat{\\bar{m}}_i $ \n",
- "\n",
- "##### 1.2. Normal assumption:\n",
- "\n",
- "(conditional distribution $ D_{i1}, \\dots, D_{iT} \\mid X_{i1}, \\dots X_{iT} $ is multivariate normal)\n",
- "- Learn $m^*_{0}$ from $\\{ D_{it}, X_{it}, \\bar{X}_i, \\bar{D}_i: t=1,\\dots, T \\}_{i=1}^N$\n",
- "\n",
- "#### 2. Transformation Approaches\n",
- "\n",
- "##### 2.1. First Difference (FD) Transformation - Exact\n",
- "\n",
- "Consider FD transformation $Q(Y_{it})= Y_{it} - Y_{it-1} $, under Assumptions 3.1-3.5, transformed nuisance function can be learnt as\n",
- "\n",
- "- $ \\Delta g_0 (X_{it-1}, X_{it}) $ from $ \\{ Y_{it}-Y_{it-1}, X_{it-1}, X_{it} : t=2, \\dots , T \\}_{i=1}^N $\n",
- "- $ \\Delta m_0 (X_{it-1}, X_{it}) $ from $ \\{ D_{it}-D_{it-1}, X_{it-1}, X_{it} : t=2, \\dots , T \\}_{i=1}^N $\n",
- "\n",
- "##### 2.2. Within Group (WG) Transformation - Approximate\n",
- "\n",
- "For WG transformation $Q(X_{it})= X_{it} - \\bar{X}_{i} $, where $ \\bar{X}_{i} = T^{-1} \\sum_{t=1}^T X_{it} $. Approximate model\n",
- "\\begin{align*}\n",
- " Q(Y_{it}) &\\approx \\theta_0 Q(D_{it}) + g_0 (Q(X_{it})) + Q(U_{it}) \\\\\n",
- " Q(D_{it}) &\\approx m_0 (Q(X_{it})) + Q(V_{it})\n",
- "\\end{align*}\n",
- "\n",
- "- $g_0$ can be learnt from transformed data $ \\{ Q(Y_{it}), Q(X_{it}) : t=1,\\dots,T \\}_{i=1}^N $\n",
- "- $m_0$ can be learnt from transformed data $ \\{ Q(D_{it}), Q(X_{it}) : t=1,\\dots,T \\}_{i=1}^N $\n",
- "\n",
- "#### Implementation\n",
- "\n",
- "- Using block-k-fold cross-fitting, where the entire time series of the sampled unit is allocated to one fold to allow for possible serial correlation\n",
- "within each unit as is common with panel data\n",
- "\n",
- "- Cluster robust standard error\n",
- "\n",
- "$\\Rightarrow$ using id variable as cluster for DML"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "id": "6dfa56df",
- "metadata": {},
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "import pandas as pd\n",
- "import statsmodels.api as sm\n",
- "from doubleml.data.base_data import DoubleMLData \n",
- "from doubleml.data.panel_data import DoubleMLPanelData\n",
- "from doubleml.plm.plpr import DoubleMLPLPR\n",
- "from sklearn.linear_model import LassoCV, LinearRegression, LogisticRegressionCV\n",
- "from sklearn.base import clone\n",
- "from sklearn.preprocessing import StandardScaler\n",
- "from sklearn.preprocessing import PolynomialFeatures\n",
- "from sklearn.compose import ColumnTransformer\n",
- "from sklearn.pipeline import make_pipeline\n",
- "from doubleml.plm.utils._plpr_util import cre_fct, fd_fct, wd_fct, extend_data\n",
- "from doubleml.plm.datasets.dgp_plpr_CP2025 import make_plpr_CP2025\n",
- "import warnings\n",
- "warnings.filterwarnings(\"ignore\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "id": "8bf5c99c",
- "metadata": {},
- "outputs": [],
- "source": [
- "from sklearn.base import BaseEstimator, TransformerMixin\n",
- "\n",
- "class PolyPlus(BaseEstimator, TransformerMixin):\n",
- " \"\"\"PolynomialFeatures(degree=k) and additional terms x_i^(k+1).\"\"\"\n",
- "\n",
- " def __init__(self, degree=2, interaction_only=False, include_bias=False):\n",
- " self.degree = degree\n",
- " self.extra_degree = degree + 1\n",
- " self.interaction_only = interaction_only\n",
- " self.include_bias = include_bias\n",
- " self.poly = PolynomialFeatures(degree=degree, interaction_only=interaction_only, include_bias=include_bias)\n",
- "\n",
- " def fit(self, X, y=None):\n",
- " self.poly.fit(X)\n",
- " self.n_features_in_ = X.shape[1]\n",
- " return self\n",
- "\n",
- " def transform(self, X):\n",
- " X = np.asarray(X)\n",
- " X_poly = self.poly.transform(X)\n",
- " X_extra = X ** self.extra_degree\n",
- " return np.hstack([X_poly, X_extra])\n",
- "\n",
- " def get_feature_names_out(self, input_features=None):\n",
- " input_features = np.array(\n",
- " input_features\n",
- " if input_features is not None\n",
- " else [f\"x{i}\" for i in range(self.n_features_in_)]\n",
- " )\n",
- " poly_names = self.poly.get_feature_names_out(input_features)\n",
- " extra_names = [f\"{name}^{self.extra_degree}\" for name in input_features]\n",
- " return np.concatenate([poly_names, extra_names])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "7c5cd6bc",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "array(['poly_x__x0', 'poly_x__x1', 'poly_x__x0^2', 'poly_x__x0 x1',\n",
- " 'poly_x__x1^2', 'poly_x__x0^3', 'poly_x__x1^3', 'poly_x_mean__x2',\n",
- " 'poly_x_mean__x3', 'poly_x_mean__x2^2', 'poly_x_mean__x2 x3',\n",
- " 'poly_x_mean__x3^2', 'poly_x_mean__x2^3', 'poly_x_mean__x3^3',\n",
- " 'remainder__x4'], dtype=object)"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "preprocessor = ColumnTransformer([\n",
- " ('poly_x', make_pipeline(\n",
- " PolyPlus(include_bias=False, interaction_only=False)\n",
- " ), [0, 1]), \n",
- " ('poly_x_mean', make_pipeline(\n",
- " PolyPlus(include_bias=False, interaction_only=False)\n",
- " ), [2, 3]) \n",
- "], remainder='passthrough')\n",
- "\n",
- "learner = make_pipeline(\n",
- " preprocessor,\n",
- " StandardScaler(),\n",
- " LinearRegression()\n",
- ")\n",
- "\n",
- "np.random.seed(1)\n",
- "data = make_plpr_CP2025(dgp_type='dgp1')\n",
- "\n",
- "y = np.array(data['y'])\n",
- "X = np.array(data[['x1', 'x2', 'x3', 'x4', 'x5']])\n",
- "\n",
- "pred = learner.fit(X, y).predict(X)\n",
- "\n",
- "learner.named_steps['columntransformer'].get_feature_names_out()\n",
- "# learner.named_steps['columntransformer']['poly_x'].get_feature_names_out()\n",
- "# learner.named_steps['linearregression'].n_features_in_"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "3c061cf1",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "(np.float64(0.6719371174913912),\n",
- " np.float64(0.6090488219157397),\n",
- " np.float64(0.7348254130670426))"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "np.random.seed(1)\n",
- "data = make_plpr_CP2025(dgp_type='dgp1')\n",
- "\n",
- "x_cols = [col for col in data.columns if \"x\" in col]\n",
- "\n",
- "X = sm.add_constant(data[['d'] + x_cols])\n",
- "y = data['y']\n",
- "clusters = data['id']\n",
- "\n",
- "ols_model = sm.OLS(y, X).fit(cov_type='cluster', cov_kwds={'groups': clusters})\n",
- "ols_model.params['d'], ols_model.conf_int().loc['d'][0], ols_model.conf_int().loc['d'][1]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "44387af3",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.506545 0.020722 24.444675 5.733308e-132 0.465931 0.54716\n"
- ]
- }
- ],
- "source": [
- "# cre general\n",
- "data = make_plpr_CP2025(dgp_type='dgp1')\n",
- "cre_data = cre_fct(data)\n",
- "\n",
- "learner = LassoCV()\n",
- "ml_l = clone(learner)\n",
- "ml_m = clone(learner)\n",
- "\n",
- "panel_data_obj = DoubleMLPanelData(data,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " t_col='time',\n",
- " id_col='id',\n",
- " x_cols=[col for col in data.columns if \"x\" in col],\n",
- " static_panel=True)\n",
- "\n",
- "dml_plpr_obj = DoubleMLPLPR(panel_data_obj, ml_l=ml_l, ml_m=ml_m, approach='cre_general', n_folds=5)\n",
- " \n",
- "dml_plpr_obj.fit()\n",
- "print(dml_plpr_obj.summary)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 49,
- "id": "7f3b2faa",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "True"
- ]
- },
- "execution_count": 49,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "dml_plpr_obj._is_cluster_data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "id": "7deabe55",
- "metadata": {},
- "outputs": [],
- "source": [
- "# model rmse\n",
- "\n",
- "# u_hat = dml_panel_plpr._dml_data.y - dml_panel_plpr.predictions['ml_l'].flatten()\n",
- "# v_hat = dml_panel_plpr._dml_data.d - dml_panel_plpr.predictions['ml_m'].flatten()\n",
- "\n",
- "# np.sqrt(np.mean(np.square(u_hat - (dml_panel_plpr.coef[0] * v_hat))))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "id": "48d4dbd8",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "================== DoubleMLPLPR Object ==================\n",
- "\n",
- "------------------ Data Summary ------------------\n",
- "Outcome variable: y_diff\n",
- "Treatment variable(s): ['d_diff']\n",
- "Covariates: ['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15', 'x16', 'x17', 'x18', 'x19', 'x20', 'x21', 'x22', 'x23', 'x24', 'x25', 'x26', 'x27', 'x28', 'x29', 'x30', 'x1_lag', 'x2_lag', 'x3_lag', 'x4_lag', 'x5_lag', 'x6_lag', 'x7_lag', 'x8_lag', 'x9_lag', 'x10_lag', 'x11_lag', 'x12_lag', 'x13_lag', 'x14_lag', 'x15_lag', 'x16_lag', 'x17_lag', 'x18_lag', 'x19_lag', 'x20_lag', 'x21_lag', 'x22_lag', 'x23_lag', 'x24_lag', 'x25_lag', 'x26_lag', 'x27_lag', 'x28_lag', 'x29_lag', 'x30_lag']\n",
- "Instrument variable(s): None\n",
- "Time variable: time\n",
- "Id variable: id\n",
- "Static panel data: True\n",
- "No. Unique Ids: 250\n",
- "No. Observations: 2250\n",
- "\n",
- "\n",
- "------------------ Score & Algorithm ------------------\n",
- "Score function: partialling out\n",
- "Static panel model approach: fd_exact\n",
- "\n",
- "------------------ Machine Learner ------------------\n",
- "Learner ml_l: LassoCV()\n",
- "Learner ml_m: LassoCV()\n",
- "Out-of-sample Performance:\n",
- "Regression:\n",
- "Learner ml_l RMSE: [[1.57005957]]\n",
- "Learner ml_m RMSE: [[0.46246153]]\n",
- "\n",
- "------------------ Resampling ------------------\n",
- "No. folds per cluster: 5\n",
- "No. folds: 5\n",
- "No. repeated sample splits: 1\n",
- "\n",
- "------------------ Fit Summary ------------------\n",
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d_diff 0.62711 0.078221 8.017152 1.082254e-15 0.4738 0.780421\n",
- "\n",
- "------------------ Additional Information ------------------\n",
- "Cluster variable(s): ['id']\n",
- "\n",
- "Pre-Transformation Data Summary: \n",
- "Outcome variable: y\n",
- "Treatment variable(s): ['d']\n",
- "Covariates: ['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15', 'x16', 'x17', 'x18', 'x19', 'x20', 'x21', 'x22', 'x23', 'x24', 'x25', 'x26', 'x27', 'x28', 'x29', 'x30']\n",
- "No. Observations: 2500\n",
- "\n"
- ]
- }
- ],
- "source": [
- "print(dml_plpr_obj)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "8731057f",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.51012 0.034875 14.627007 1.889172e-48 0.441766 0.578475\n"
- ]
- }
- ],
- "source": [
- "# cre general, extend features\n",
- "data = make_plpr_CP2025(num_id=100, dgp_type='dgp3')\n",
- "\n",
- "panel_data_obj = DoubleMLPanelData(data,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " t_col='time',\n",
- " id_col='id',\n",
- " x_cols=[col for col in data.columns if \"x\" in col],\n",
- " static_panel=True)\n",
- "\n",
- "dim_x = len(panel_data_obj.x_cols)\n",
- "indices_x = [x for x in range(dim_x)]\n",
- "indices_x_mean = [x + dim_x for x in indices_x]\n",
- "\n",
- "# x_cols = panel_data_obj.x_cols\n",
- "# dim_x = len(x_cols)\n",
- "# x_cols_for_poly = ['x1', 'x11', 'x22']\n",
- "\n",
- "# indices_x = [i for i, c in enumerate(x_cols) if c in x_cols_for_poly]\n",
- "# indices_x_mean = [x + dim_x for x in indices_x]\n",
- "\n",
- "# preprocessor = ColumnTransformer([\n",
- "# ('poly_x', make_pipeline(\n",
- "# PolynomialFeatures(degree=2, include_bias=False, interaction_only=False)\n",
- "# ), indices_x), \n",
- "# ('poly_x_mean', make_pipeline(\n",
- "# PolynomialFeatures(degree=2, include_bias=False, interaction_only=False)\n",
- "# ), indices_x_mean) \n",
- "# ], remainder='passthrough') # add passthrough for cre_normal needed\n",
- "\n",
- "preprocessor = ColumnTransformer([\n",
- " ('poly_x', make_pipeline(\n",
- " PolyPlus(degree=2, include_bias=False, interaction_only=False)\n",
- " ), indices_x), \n",
- " ('poly_x_mean', make_pipeline(\n",
- " PolyPlus(degree=2, include_bias=False, interaction_only=False)\n",
- " ), indices_x_mean) \n",
- "], remainder='passthrough')\n",
- "\n",
- "learner = make_pipeline(\n",
- " preprocessor,\n",
- " StandardScaler(),\n",
- " LassoCV()\n",
- ")\n",
- "\n",
- "# learner = LassoCV()\n",
- "\n",
- "ml_l = clone(learner)\n",
- "ml_m = clone(learner)\n",
- "\n",
- "dml_plpr_obj = DoubleMLPLPR(panel_data_obj, ml_l=ml_l, ml_m=ml_m, approach='cre_general', n_folds=5)\n",
- "dml_plpr_obj.fit(store_models=True)\n",
- "print(dml_plpr_obj.summary)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "id": "d7936e42",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "1050"
- ]
- },
- "execution_count": 16,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "dml_plpr_obj.models['ml_l']['d'][0][0].named_steps['lassocv'].n_features_in_"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "id": "20f158d6",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "1050"
- ]
- },
- "execution_count": 17,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "dml_plpr_obj.models['ml_m']['d'][0][0].named_steps['lassocv'].n_features_in_"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "9edf129b",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d_demean 1.195553 0.048717 24.540889 5.410930e-133 1.10007 1.291036\n"
- ]
- }
- ],
- "source": [
- "from lightgbm import LGBMRegressor\n",
- "\n",
- "ml_boost = LGBMRegressor(verbose=-1, \n",
- " n_estimators=100, \n",
- " learning_rate=0.3,\n",
- " min_child_samples=1) \n",
- "\n",
- "ml_boost_grid= {'ml_l': {'reg_lambda': np.random.choice(np.arange(0.2, 2 , 0.2), 5, replace=False), \n",
- " 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)},\n",
- " 'ml_m': {'reg_lambda': np.random.choice(np.arange(0.2, 2 , 0.2), 5, replace=False),\n",
- " 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)}}\n",
- "\n",
- "ml_l = clone(ml_boost)\n",
- "ml_m = clone(ml_boost)\n",
- "\n",
- "data = make_plpr_CP2025(num_id=100, dgp_type='dgp3')\n",
- "\n",
- "panel_data_obj = DoubleMLPanelData(data,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " t_col='time',\n",
- " id_col='id',\n",
- " x_cols=[col for col in data.columns if \"x\" in col],\n",
- " static_panel=True)\n",
- "\n",
- "dml_plpr_obj = DoubleMLPLPR(panel_data_obj, ml_l=ml_l, ml_m=ml_m, approach='wg_approx', n_folds=5)\n",
- "dml_plpr_obj.tune(param_grids=ml_boost_grid, n_jobs_cv=5)\n",
- "dml_plpr_obj.fit(store_models=True)\n",
- "print(dml_plpr_obj.summary)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "135a38ab",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d 0.564418 0.020749 27.201758 6.191263e-163 0.52375 0.605086\n"
- ]
- }
- ],
- "source": [
- "# cre normality assumption\n",
- "data = make_plpr_CP2025(dgp_type='dgp1')\n",
- "cre_data = cre_fct(data)\n",
- "\n",
- "panel_data_obj = DoubleMLPanelData(data,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " t_col='time',\n",
- " id_col='id',\n",
- " x_cols=[col for col in data.columns if \"x\" in col],\n",
- " static_panel=True)\n",
- "\n",
- "# learner = LassoCV()\n",
- "learner = make_pipeline(StandardScaler(), LassoCV())\n",
- "\n",
- "ml_l = clone(learner)\n",
- "ml_m = clone(learner)\n",
- "\n",
- "dml_plpr_obj = DoubleMLPLPR(panel_data_obj, ml_l, ml_m, approach='cre_normal')\n",
- "dml_plpr_obj.fit()\n",
- "print(dml_plpr_obj.summary)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "d869e493",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d_diff 0.495522 0.023769 20.847281 1.613320e-96 0.448936 0.542109\n"
- ]
- }
- ],
- "source": [
- "# First difference approach\n",
- "data = make_plpr_CP2025(dgp_type='dgp1')\n",
- "fd_data = fd_fct(data)\n",
- "\n",
- "panel_data_obj = DoubleMLPanelData(data,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " t_col='time',\n",
- " id_col='id',\n",
- " x_cols=[col for col in data.columns if \"x\" in col],\n",
- " static_panel=True)\n",
- "\n",
- "learner = LassoCV()\n",
- "ml_l = clone(learner)\n",
- "ml_m = clone(learner)\n",
- "\n",
- "dml_plpr_obj = DoubleMLPLPR(panel_data_obj, ml_l, ml_m, approach='fd_exact')\n",
- "\n",
- "dml_plpr_obj.fit()\n",
- "print(dml_plpr_obj.summary)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "579b43fa",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d_demean 0.507186 0.022073 22.977214 7.878187e-117 0.463922 0.550449\n"
- ]
- }
- ],
- "source": [
- "# Within group approach\n",
- "data = make_plpr_CP2025(dgp_type='dgp1')\n",
- "wd_data = wd_fct(data)\n",
- "\n",
- "panel_data_obj = DoubleMLPanelData(data,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " t_col='time',\n",
- " id_col='id',\n",
- " x_cols=[col for col in data.columns if \"x\" in col],\n",
- " static_panel=True)\n",
- "\n",
- "learner = LassoCV()\n",
- "ml_l = clone(learner)\n",
- "ml_m = clone(learner)\n",
- "\n",
- "dml_plpr_obj = DoubleMLPLPR(panel_data_obj, ml_l, ml_m, approach='wg_approx')\n",
- "dml_plpr_obj.fit()\n",
- "print(dml_plpr_obj.summary)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "32e5bc9f",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- " coef std err t P>|t| 2.5 % 97.5 %\n",
- "d_demean 0.472861 0.022779 20.758218 1.033280e-95 0.428214 0.517507\n"
- ]
- }
- ],
- "source": [
- "# Within group approach, polynomials\n",
- "data = make_plpr_CP2025(dgp_type='dgp1')\n",
- "\n",
- "panel_data_obj = DoubleMLPanelData(data,\n",
- " y_col='y',\n",
- " d_cols='d',\n",
- " t_col='time',\n",
- " id_col='id',\n",
- " x_cols=[col for col in data.columns if \"x\" in col],\n",
- " static_panel=True)\n",
- "\n",
- "# preprocessor = ColumnTransformer([\n",
- "# ('poly', make_pipeline(\n",
- "# PolynomialFeatures(degree=2, include_bias=False, interaction_only=False),\n",
- "# StandardScaler()\n",
- "# ), ['x1', 'x2']), # Columns to expand\n",
- "# ('pass', 'passthrough', ['cat']) # Columns to keep unchanged\n",
- "# ])\n",
- "\n",
- "preprocessor = ColumnTransformer([\n",
- " ('poly', make_pipeline(\n",
- " PolynomialFeatures(degree=2, include_bias=False)\n",
- " ), [0, 1])\n",
- "], remainder='passthrough')\n",
- "\n",
- "learner = make_pipeline(\n",
- " preprocessor,\n",
- " StandardScaler(),\n",
- " LassoCV()\n",
- ")\n",
- "\n",
- "ml_l = clone(learner)\n",
- "ml_m = clone(learner)\n",
- "\n",
- "dml_plpr_obj = DoubleMLPLPR(panel_data_obj, ml_l, ml_m, approach='wg_approx')\n",
- "dml_plpr_obj.fit(store_models=True)\n",
- "print(dml_plpr_obj.summary)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 41,
- "id": "c43c504b",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[0, 1, 11]"
- ]
- },
- "execution_count": 41,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "x_cols_tranform = dml_plpr_obj.transform_cols['x_cols']\n",
- "\n",
- "x_cols_for_poly = ['x1_demean', 'x2_demean', 'x12_demean']\n",
- "\n",
- "indices = [i for i, c in enumerate(x_cols_tranform) if c in x_cols_for_poly]\n",
- "indices"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "1a0e629d",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "33"
- ]
- },
- "execution_count": 32,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "dml_plpr_obj.models['ml_l']['d_demean'][0][0].named_steps['lassocv'].n_features_in_"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "68f20f79",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "array(['x0', 'x1', 'x0^2', 'x0 x1', 'x1^2'], dtype=object)"
- ]
- },
- "execution_count": 38,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "dml_plpr_obj.models['ml_l']['d_demean'][0][0].named_steps['columntransformer']['poly'].get_feature_names_out()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "c5a9c137",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Processing: 100.0 %"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " Coef | \n",
- " Bias | \n",
- " Coverage | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " | CRE general | \n",
- " 0.516684 | \n",
- " 0.016684 | \n",
- " 0.92 | \n",
- "
\n",
- " \n",
- " | CRE normal | \n",
- " 0.541518 | \n",
- " 0.041518 | \n",
- " 0.78 | \n",
- "
\n",
- " \n",
- " | FD exact | \n",
- " 0.504094 | \n",
- " 0.004094 | \n",
- " 0.94 | \n",
- "
\n",
- " \n",
- " | WG approx | \n",
- " 0.502006 | \n",
- " 0.002006 | \n",
- " 0.94 | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " Coef Bias Coverage\n",
- "CRE general 0.516684 0.016684 0.92\n",
- "CRE normal 0.541518 0.041518 0.78\n",
- "FD exact 0.504094 0.004094 0.94\n",
- "WG approx 0.502006 0.002006 0.94"
- ]
- },
- "execution_count": 11,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# simulation with built-in transformations\n",
- "\n",
- "n_reps = 100\n",
- "theta = 0.5\n",
- "\n",
- "learner = make_pipeline(StandardScaler(), LassoCV())\n",
- "\n",
- "res_cre_general = np.full((n_reps, 3), np.nan)\n",
- "res_cre_normal = np.full((n_reps, 3), np.nan)\n",
- "res_fd = np.full((n_reps, 3), np.nan)\n",
- "res_wd = np.full((n_reps, 3), np.nan)\n",
- "\n",
- "np.random.seed(1)\n",
- "\n",
- "for i in range(n_reps):\n",
- " print(f\"\\rProcessing: {round((i+1)/n_reps*100, 3)} %\", end=\"\")\n",
- " data = make_plpr_CP2025(num_id=100, theta=theta, dgp_type='dgp1')\n",
- "\n",
- " dml_data = DoubleMLPanelData(data, y_col='y', d_cols='d', t_col='time', id_col='id', \n",
- " x_cols=[col for col in data.columns if \"x\" in col],\n",
- " static_panel=True)\n",
- " \n",
- " # CRE general Lasso\n",
- " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
- " approach='cre_general')\n",
- " dml_plpr.fit()\n",
- " res_cre_general[i, 0] = dml_plpr.coef[0]\n",
- " res_cre_general[i, 1] = dml_plpr.coef[0] - theta\n",
- " confint = dml_plpr.confint()\n",
- " res_cre_general[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
- "\n",
- " # CRE normality\n",
- " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
- " approach='cre_normal')\n",
- " dml_plpr.fit()\n",
- " res_cre_normal[i, 0] = dml_plpr.coef[0]\n",
- " res_cre_normal[i, 1] = dml_plpr.coef[0] - theta\n",
- " confint = dml_plpr.confint()\n",
- " res_cre_normal[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
- "\n",
- " # FD approach\n",
- " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
- " approach='fd_exact')\n",
- " dml_plpr.fit()\n",
- " res_fd[i, 0] = dml_plpr.coef[0]\n",
- " res_fd[i, 1] = dml_plpr.coef[0] - theta\n",
- " confint = dml_plpr.confint()\n",
- " res_fd[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
- " \n",
- " # WD approach, for now need new data object as FD approach overwrites _cluster_vars\n",
- " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
- " approach='wg_approx')\n",
- " dml_plpr.fit()\n",
- " res_wd[i, 0] = dml_plpr.coef[0]\n",
- " res_wd[i, 1] = dml_plpr.coef[0] - theta\n",
- " confint = dml_plpr.confint()\n",
- " res_wd[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
- "\n",
- "\n",
- "pd.DataFrame(np.vstack([res_cre_general.mean(axis=0), res_cre_normal.mean(axis=0), \n",
- " res_fd.mean(axis=0), res_wd.mean(axis=0)]), \n",
- " columns=['Coef', 'Bias', 'Coverage'], \n",
- " index=['CRE general', 'CRE normal', \n",
- " 'FD exact', 'WG approx'])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "24ef531c",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Processing: 100.0 %"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " Coef | \n",
- " Bias | \n",
- " Coverage | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " | CRE general | \n",
- " 0.498318 | \n",
- " -0.001682 | \n",
- " 0.94 | \n",
- "
\n",
- " \n",
- " | CRE normal | \n",
- " 0.497383 | \n",
- " -0.002617 | \n",
- " 0.96 | \n",
- "
\n",
- " \n",
- " | FD exact | \n",
- " 0.494321 | \n",
- " -0.005679 | \n",
- " 0.93 | \n",
- "
\n",
- " \n",
- " | WG approx | \n",
- " 0.497754 | \n",
- " -0.002246 | \n",
- " 0.95 | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " Coef Bias Coverage\n",
- "CRE general 0.498318 -0.001682 0.94\n",
- "CRE normal 0.497383 -0.002617 0.96\n",
- "FD exact 0.494321 -0.005679 0.93\n",
- "WG approx 0.497754 -0.002246 0.95"
- ]
- },
- "execution_count": 13,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "n_reps = 100\n",
- "theta = 0.5\n",
- "\n",
- "learner = LinearRegression()\n",
- "\n",
- "res_cre_general = np.full((n_reps, 3), np.nan)\n",
- "res_cre_normal = np.full((n_reps, 3), np.nan)\n",
- "res_fd = np.full((n_reps, 3), np.nan)\n",
- "res_wd = np.full((n_reps, 3), np.nan)\n",
- "\n",
- "np.random.seed(12)\n",
- "\n",
- "for i in range(n_reps):\n",
- " print(f\"\\rProcessing: {round((i+1)/n_reps*100, 3)} %\", end=\"\")\n",
- " data = make_plpr_CP2025(num_id=100, theta=theta, dgp_type='dgp1')\n",
- "\n",
- " dml_data = DoubleMLPanelData(data, y_col='y', d_cols='d', t_col='time', id_col='id', \n",
- " x_cols=[col for col in data.columns if \"x\" in col],\n",
- " static_panel=True)\n",
- " \n",
- " # CRE general Lasso\n",
- " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
- " approach='cre_general')\n",
- " dml_plpr.fit()\n",
- " res_cre_general[i, 0] = dml_plpr.coef[0]\n",
- " res_cre_general[i, 1] = dml_plpr.coef[0] - theta\n",
- " confint = dml_plpr.confint()\n",
- " res_cre_general[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
- "\n",
- " # CRE normality\n",
- " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
- " approach='cre_normal')\n",
- " dml_plpr.fit()\n",
- " res_cre_normal[i, 0] = dml_plpr.coef[0]\n",
- " res_cre_normal[i, 1] = dml_plpr.coef[0] - theta\n",
- " confint = dml_plpr.confint()\n",
- " res_cre_normal[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
- "\n",
- " # FD approach\n",
- " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
- " approach='fd_exact')\n",
- " dml_plpr.fit()\n",
- " res_fd[i, 0] = dml_plpr.coef[0]\n",
- " res_fd[i, 1] = dml_plpr.coef[0] - theta\n",
- " confint = dml_plpr.confint()\n",
- " res_fd[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
- " \n",
- " # WD approach, for now need new data object as FD approach overwrites _cluster_vars\n",
- " dml_plpr = DoubleMLPLPR(dml_data, clone(learner), clone(learner), n_folds=5, \n",
- " approach='wg_approx')\n",
- " dml_plpr.fit()\n",
- " res_wd[i, 0] = dml_plpr.coef[0]\n",
- " res_wd[i, 1] = dml_plpr.coef[0] - theta\n",
- " confint = dml_plpr.confint()\n",
- " res_wd[i, 2] = (confint['2.5 %'].iloc[0] <= theta) & (confint['97.5 %'].iloc[0] >= theta)\n",
- "\n",
- "\n",
- "pd.DataFrame(np.vstack([res_cre_general.mean(axis=0), res_cre_normal.mean(axis=0), \n",
- " res_fd.mean(axis=0), res_wd.mean(axis=0)]), \n",
- " columns=['Coef', 'Bias', 'Coverage'], \n",
- " index=['CRE general', 'CRE normal', \n",
- " 'FD exact', 'WG approx'])"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": ".venv",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.6"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/doubleml/plm/sim/learners_sim.ipynb b/doubleml/plm/sim/learners_sim.ipynb
deleted file mode 100644
index e463dd817..000000000
--- a/doubleml/plm/sim/learners_sim.ipynb
+++ /dev/null
@@ -1,1121 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "id": "fe0a50cb",
- "metadata": {},
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "import pandas as pd\n",
- "from doubleml.data.panel_data import DoubleMLPanelData\n",
- "from doubleml.plm.plpr import DoubleMLPLPR\n",
- "from sklearn.linear_model import LassoCV\n",
- "from sklearn.base import clone\n",
- "from sklearn.tree import DecisionTreeRegressor\n",
- "from lightgbm import LGBMRegressor\n",
- "# from doubleml.plm.utils._plpr_util import extend_data, cre_fct, fd_fct, wd_fct\n",
- "from doubleml.plm.datasets.dgp_plpr_CP2025 import make_plpr_CP2025\n",
- "from sklearn.preprocessing import StandardScaler, PolynomialFeatures\n",
- "from sklearn.pipeline import make_pipeline\n",
- "from sklearn.base import BaseEstimator, TransformerMixin\n",
- "from sklearn.compose import ColumnTransformer\n",
- "import warnings\n",
- "warnings.filterwarnings(\"ignore\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "id": "2650a1dd",
- "metadata": {},
- "outputs": [],
- "source": [
- "class PolyPlus(BaseEstimator, TransformerMixin):\n",
- " \"\"\"PolynomialFeatures(degree=k) and additional terms x_i^(k+1).\"\"\"\n",
- "\n",
- " def __init__(self, degree=2, interaction_only=False, include_bias=False):\n",
- " self.degree = degree\n",
- " self.extra_degree = degree + 1\n",
- " self.interaction_only = interaction_only\n",
- " self.include_bias = include_bias\n",
- " self.poly = PolynomialFeatures(degree=degree, interaction_only=interaction_only, include_bias=include_bias)\n",
- "\n",
- " def fit(self, X, y=None):\n",
- " self.poly.fit(X)\n",
- " self.n_features_in_ = X.shape[1]\n",
- " return self\n",
- "\n",
- " def transform(self, X):\n",
- " X = np.asarray(X)\n",
- " X_poly = self.poly.transform(X)\n",
- " X_extra = X ** self.extra_degree\n",
- " return np.hstack([X_poly, X_extra])\n",
- "\n",
- " def get_feature_names_out(self, input_features=None):\n",
- " input_features = np.array(\n",
- " input_features\n",
- " if input_features is not None\n",
- " else [f\"x{i}\" for i in range(self.n_features_in_)]\n",
- " )\n",
- " poly_names = self.poly.get_feature_names_out(input_features)\n",
- " extra_names = [f\"{name}^{self.extra_degree}\" for name in input_features]\n",
- " return np.concatenate([poly_names, extra_names])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "id": "68cab57c",
- "metadata": {},
- "outputs": [],
- "source": [
- "dim_x = 30\n",
- "indices_x = [x for x in range(dim_x)]\n",
- "indices_x_tr = [x + dim_x for x in indices_x]\n",
- "\n",
- "preprocessor = ColumnTransformer([\n",
- " ('poly_x', make_pipeline(\n",
- " PolyPlus(degree=2, include_bias=False, interaction_only=False)\n",
- " ), indices_x), \n",
- " ('poly_x_mean', make_pipeline(\n",
- " PolyPlus(degree=2, include_bias=False, interaction_only=False)\n",
- " ), indices_x_tr) \n",
- "], remainder='passthrough')\n",
- "\n",
- "preprocessor_wg = ColumnTransformer([\n",
- " ('poly_x', make_pipeline(\n",
- " PolyPlus(degree=2, include_bias=False, interaction_only=False)\n",
- " ), indices_x), \n",
- "], remainder='passthrough')\n",
- "\n",
- "ml_lasso = make_pipeline(\n",
- " preprocessor,\n",
- " StandardScaler(),\n",
- " LassoCV(n_alphas=20, cv=2, n_jobs=5)\n",
- ")\n",
- "\n",
- "ml_lasso_wg = make_pipeline(\n",
- " preprocessor_wg,\n",
- " StandardScaler(),\n",
- " LassoCV(n_alphas=20, cv=2, n_jobs=5)\n",
- ")\n",
- "\n",
- "ml_cart = DecisionTreeRegressor()\n",
- "\n",
- "ml_boost = LGBMRegressor(verbose=-1, \n",
- " n_estimators=100, \n",
- " learning_rate=0.3,\n",
- " min_child_samples=1) "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "id": "dca81b0b",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " id | \n",
- " time | \n",
- " y | \n",
- " d | \n",
- " x1 | \n",
- " x2 | \n",
- " x3 | \n",
- " x4 | \n",
- " x5 | \n",
- " x6 | \n",
- " ... | \n",
- " x21 | \n",
- " x22 | \n",
- " x23 | \n",
- " x24 | \n",
- " x25 | \n",
- " x26 | \n",
- " x27 | \n",
- " x28 | \n",
- " x29 | \n",
- " x30 | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " | 0 | \n",
- " 1 | \n",
- " 1 | \n",
- " -5.891762 | \n",
- " -4.574437 | \n",
- " -1.203377 | \n",
- " -0.087056 | \n",
- " -3.875989 | \n",
- " -2.060015 | \n",
- " 1.924513 | \n",
- " 1.564501 | \n",
- " ... | \n",
- " -1.697513 | \n",
- " -2.509320 | \n",
- " -0.727138 | \n",
- " -2.393134 | \n",
- " 0.334781 | \n",
- " 2.097534 | \n",
- " 1.942009 | \n",
- " 1.649557 | \n",
- " -0.612257 | \n",
- " -4.331109 | \n",
- "
\n",
- " \n",
- " | 1 | \n",
- " 1 | \n",
- " 2 | \n",
- " 0.601641 | \n",
- " 1.217127 | \n",
- " -1.076318 | \n",
- " 2.226439 | \n",
- " 0.379887 | \n",
- " -2.491481 | \n",
- " -1.446766 | \n",
- " -0.000182 | \n",
- " ... | \n",
- " -0.185932 | \n",
- " -0.491894 | \n",
- " 1.320808 | \n",
- " -2.888978 | \n",
- " 0.296153 | \n",
- " -0.209147 | \n",
- " -1.066396 | \n",
- " -2.232003 | \n",
- " 3.217619 | \n",
- " -0.660709 | \n",
- "
\n",
- " \n",
- " | 2 | \n",
- " 1 | \n",
- " 3 | \n",
- " -5.336432 | \n",
- " -4.084917 | \n",
- " 1.684777 | \n",
- " -2.117207 | \n",
- " -4.038299 | \n",
- " 1.196702 | \n",
- " 4.320428 | \n",
- " 1.543303 | \n",
- " ... | \n",
- " -1.896694 | \n",
- " 2.950372 | \n",
- " 2.266257 | \n",
- " -1.962670 | \n",
- " 1.913956 | \n",
- " -3.847482 | \n",
- " 0.914604 | \n",
- " -1.721561 | \n",
- " -0.954810 | \n",
- " 0.407410 | \n",
- "
\n",
- " \n",
- " | 3 | \n",
- " 1 | \n",
- " 4 | \n",
- " -0.478058 | \n",
- " 0.043192 | \n",
- " 0.978425 | \n",
- " -2.568042 | \n",
- " -1.001187 | \n",
- " -2.151350 | \n",
- " 0.973693 | \n",
- " -1.286461 | \n",
- " ... | \n",
- " -1.148923 | \n",
- " -3.388272 | \n",
- " 1.121507 | \n",
- " 4.753065 | \n",
- " 1.424797 | \n",
- " -2.345737 | \n",
- " -0.693004 | \n",
- " -1.618859 | \n",
- " 1.668621 | \n",
- " 4.571664 | \n",
- "
\n",
- " \n",
- " | 4 | \n",
- " 1 | \n",
- " 5 | \n",
- " -1.223138 | \n",
- " -1.544670 | \n",
- " 1.398526 | \n",
- " 3.567711 | \n",
- " -1.353898 | \n",
- " -2.226735 | \n",
- " 3.713345 | \n",
- " 0.675101 | \n",
- " ... | \n",
- " -1.837662 | \n",
- " 2.941762 | \n",
- " 2.061895 | \n",
- " 0.853080 | \n",
- " -0.244278 | \n",
- " 1.263040 | \n",
- " -2.011630 | \n",
- " -0.826488 | \n",
- " -0.887181 | \n",
- " -1.935609 | \n",
- "
\n",
- " \n",
- " | ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- "
\n",
- " \n",
- " | 995 | \n",
- " 100 | \n",
- " 6 | \n",
- " -2.734367 | \n",
- " -0.265121 | \n",
- " -3.045957 | \n",
- " 1.798310 | \n",
- " -1.485007 | \n",
- " -2.107084 | \n",
- " 2.983506 | \n",
- " 1.469852 | \n",
- " ... | \n",
- " 4.836053 | \n",
- " 1.932549 | \n",
- " 2.307781 | \n",
- " -2.536377 | \n",
- " 1.150598 | \n",
- " 1.052909 | \n",
- " -0.969657 | \n",
- " 1.266473 | \n",
- " -3.177021 | \n",
- " -3.070155 | \n",
- "
\n",
- " \n",
- " | 996 | \n",
- " 100 | \n",
- " 7 | \n",
- " 5.307034 | \n",
- " 3.585023 | \n",
- " -2.274700 | \n",
- " -2.499113 | \n",
- " 2.116755 | \n",
- " 0.478902 | \n",
- " -0.248561 | \n",
- " -0.957826 | \n",
- " ... | \n",
- " -2.799372 | \n",
- " 1.598447 | \n",
- " 1.972620 | \n",
- " -1.888645 | \n",
- " 1.237270 | \n",
- " 3.644984 | \n",
- " 0.054862 | \n",
- " 0.615274 | \n",
- " -0.432120 | \n",
- " -0.046949 | \n",
- "
\n",
- " \n",
- " | 997 | \n",
- " 100 | \n",
- " 8 | \n",
- " -4.476127 | \n",
- " -2.751544 | \n",
- " -0.716859 | \n",
- " -1.263491 | \n",
- " -2.826469 | \n",
- " -0.954049 | \n",
- " -1.237438 | \n",
- " -1.200074 | \n",
- " ... | \n",
- " 1.527836 | \n",
- " -1.918927 | \n",
- " -0.381272 | \n",
- " 2.065848 | \n",
- " 0.723859 | \n",
- " -0.711546 | \n",
- " 0.930980 | \n",
- " 0.883152 | \n",
- " -0.324217 | \n",
- " 0.053768 | \n",
- "
\n",
- " \n",
- " | 998 | \n",
- " 100 | \n",
- " 9 | \n",
- " 5.191475 | \n",
- " 4.985730 | \n",
- " -2.840246 | \n",
- " 0.931855 | \n",
- " 3.070040 | \n",
- " 2.700103 | \n",
- " 1.214848 | \n",
- " 2.846577 | \n",
- " ... | \n",
- " -1.662106 | \n",
- " 0.583185 | \n",
- " 2.117253 | \n",
- " -0.429837 | \n",
- " -1.983224 | \n",
- " -1.249148 | \n",
- " 5.170035 | \n",
- " 3.022710 | \n",
- " 3.091618 | \n",
- " -2.210554 | \n",
- "
\n",
- " \n",
- " | 999 | \n",
- " 100 | \n",
- " 10 | \n",
- " -4.046210 | \n",
- " 0.194942 | \n",
- " -0.998063 | \n",
- " 1.876639 | \n",
- " -0.401315 | \n",
- " -2.042387 | \n",
- " -0.389824 | \n",
- " -0.388875 | \n",
- " ... | \n",
- " 0.951100 | \n",
- " 0.987539 | \n",
- " -2.201030 | \n",
- " 0.144916 | \n",
- " -1.977077 | \n",
- " -2.538484 | \n",
- " -1.978323 | \n",
- " 2.068496 | \n",
- " -2.546201 | \n",
- " 2.218969 | \n",
- "
\n",
- " \n",
- "
\n",
- "
1000 rows × 34 columns
\n",
- "
"
- ],
- "text/plain": [
- " id time y d x1 x2 x3 x4 \\\n",
- "0 1 1 -5.891762 -4.574437 -1.203377 -0.087056 -3.875989 -2.060015 \n",
- "1 1 2 0.601641 1.217127 -1.076318 2.226439 0.379887 -2.491481 \n",
- "2 1 3 -5.336432 -4.084917 1.684777 -2.117207 -4.038299 1.196702 \n",
- "3 1 4 -0.478058 0.043192 0.978425 -2.568042 -1.001187 -2.151350 \n",
- "4 1 5 -1.223138 -1.544670 1.398526 3.567711 -1.353898 -2.226735 \n",
- ".. ... ... ... ... ... ... ... ... \n",
- "995 100 6 -2.734367 -0.265121 -3.045957 1.798310 -1.485007 -2.107084 \n",
- "996 100 7 5.307034 3.585023 -2.274700 -2.499113 2.116755 0.478902 \n",
- "997 100 8 -4.476127 -2.751544 -0.716859 -1.263491 -2.826469 -0.954049 \n",
- "998 100 9 5.191475 4.985730 -2.840246 0.931855 3.070040 2.700103 \n",
- "999 100 10 -4.046210 0.194942 -0.998063 1.876639 -0.401315 -2.042387 \n",
- "\n",
- " x5 x6 ... x21 x22 x23 x24 \\\n",
- "0 1.924513 1.564501 ... -1.697513 -2.509320 -0.727138 -2.393134 \n",
- "1 -1.446766 -0.000182 ... -0.185932 -0.491894 1.320808 -2.888978 \n",
- "2 4.320428 1.543303 ... -1.896694 2.950372 2.266257 -1.962670 \n",
- "3 0.973693 -1.286461 ... -1.148923 -3.388272 1.121507 4.753065 \n",
- "4 3.713345 0.675101 ... -1.837662 2.941762 2.061895 0.853080 \n",
- ".. ... ... ... ... ... ... ... \n",
- "995 2.983506 1.469852 ... 4.836053 1.932549 2.307781 -2.536377 \n",
- "996 -0.248561 -0.957826 ... -2.799372 1.598447 1.972620 -1.888645 \n",
- "997 -1.237438 -1.200074 ... 1.527836 -1.918927 -0.381272 2.065848 \n",
- "998 1.214848 2.846577 ... -1.662106 0.583185 2.117253 -0.429837 \n",
- "999 -0.389824 -0.388875 ... 0.951100 0.987539 -2.201030 0.144916 \n",
- "\n",
- " x25 x26 x27 x28 x29 x30 \n",
- "0 0.334781 2.097534 1.942009 1.649557 -0.612257 -4.331109 \n",
- "1 0.296153 -0.209147 -1.066396 -2.232003 3.217619 -0.660709 \n",
- "2 1.913956 -3.847482 0.914604 -1.721561 -0.954810 0.407410 \n",
- "3 1.424797 -2.345737 -0.693004 -1.618859 1.668621 4.571664 \n",
- "4 -0.244278 1.263040 -2.011630 -0.826488 -0.887181 -1.935609 \n",
- ".. ... ... ... ... ... ... \n",
- "995 1.150598 1.052909 -0.969657 1.266473 -3.177021 -3.070155 \n",
- "996 1.237270 3.644984 0.054862 0.615274 -0.432120 -0.046949 \n",
- "997 0.723859 -0.711546 0.930980 0.883152 -0.324217 0.053768 \n",
- "998 -1.983224 -1.249148 5.170035 3.022710 3.091618 -2.210554 \n",
- "999 -1.977077 -2.538484 -1.978323 2.068496 -2.546201 2.218969 \n",
- "\n",
- "[1000 rows x 34 columns]"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "data = make_plpr_CP2025(num_id=100, dgp_type='dgp1')\n",
- "data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b15dcb8d",
- "metadata": {},
- "outputs": [],
- "source": [
- "def run_sim(n_reps, num_id, dim_x=30, theta=0.5, dgp_type='dgp3'):\n",
- "\n",
- " approaches = [\"cre_general\", \"cre_normal\", \"fd_exact\", \"wg_approx\"]\n",
- " models = [\"lasso\", \"cart\", \"boost\"]\n",
- "\n",
- " res = {\n",
- " d: {\n",
- " m: np.full((n_reps, 3), np.nan)\n",
- " for m in models\n",
- " }\n",
- " for d in approaches\n",
- " }\n",
- "\n",
- " x_cols = [f\"x{i+1}\" for i in range(dim_x)]\n",
- "\n",
- " def run_single_dml(dml_data, ml, approach, grid=None):\n",
- " est = DoubleMLPLPR(dml_data, clone(ml), clone(ml), approach=approach, n_folds=5)\n",
- "\n",
- " if grid is not None:\n",
- " est.tune(param_grids=grid, search_mode='randomized_search',\n",
- " n_iter_randomized_search=5, n_jobs_cv=5)\n",
- "\n",
- " est.fit()\n",
- "\n",
- " coef_err = est.coef[0] - theta\n",
- " se = est.se[0]\n",
- " conf = est.confint()\n",
- " covered = (conf['2.5 %'].iloc[0] <= theta) & (conf['97.5 %'].iloc[0] >= theta)\n",
- "\n",
- " return coef_err, se, covered\n",
- "\n",
- " for i in range(n_reps):\n",
- "\n",
- " print(f\"\\rProcessing: {round((i+1)/n_reps*100, 3)} %\", end=\"\")\n",
- "\n",
- " cart_grid = {'ml_l': {'ccp_alpha': np.random.choice(np.arange(0.002, 0.052 , 0.002), 5, replace=False),\n",
- " 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)},\n",
- " 'ml_m': {'ccp_alpha': np.random.choice(np.arange(0.002, 0.052 , 0.002), 5, replace=False),\n",
- " 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)}}\n",
- "\n",
- " boost_grid= {'ml_l': {'reg_lambda': np.random.choice(np.arange(0.2, 2 , 0.2), 5, replace=False), \n",
- " 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)},\n",
- " 'ml_m': {'reg_lambda': np.random.choice(np.arange(0.2, 2 , 0.2), 5, replace=False),\n",
- " 'max_depth': np.random.choice(np.arange(2, 11 , 1), 5, replace=False)}}\n",
- "\n",
- " data = make_plpr_CP2025(num_id=num_id, dim_x=dim_x, theta=theta, dgp_type=dgp_type)\n",
- " dml_data = DoubleMLPanelData(data, y_col='y', d_cols='d', t_col='time', id_col='id', x_cols=x_cols, static_panel=True)\n",
- "\n",
- " # CRE general\n",
- " res['cre_general']['lasso'][i, :] = run_single_dml(dml_data, ml_lasso, 'cre_general', grid=None)\n",
- " res['cre_general']['cart'][i, :] = run_single_dml(dml_data, ml_cart, 'cre_general', grid=cart_grid)\n",
- " res['cre_general']['boost'][i, :] = run_single_dml(dml_data, ml_boost, 'cre_general', grid=boost_grid)\n",
- "\n",
- " # CRE normal\n",
- " res['cre_normal']['lasso'][i, :] = run_single_dml(dml_data, ml_lasso, 'cre_normal', grid=None)\n",
- " res['cre_normal']['cart'][i, :] = run_single_dml(dml_data, ml_cart, 'cre_normal', grid=cart_grid)\n",
- " res['cre_normal']['boost'][i, :] = run_single_dml(dml_data, ml_boost, 'cre_normal', grid=boost_grid)\n",
- "\n",
- " # FD\n",
- " res['fd_exact']['lasso'][i, :] = run_single_dml(dml_data, ml_lasso, 'fd_exact', grid=None)\n",
- " res['fd_exact']['cart'][i, :] = run_single_dml(dml_data, ml_cart, 'fd_exact', grid=cart_grid)\n",
- " res['fd_exact']['boost'][i, :] = run_single_dml(dml_data, ml_boost, 'fd_exact', grid=boost_grid)\n",
- "\n",
- " # WD\n",
- " res['wg_approx']['lasso'][i, :] = run_single_dml(dml_data, ml_lasso_wg, 'wg_approx', grid=None)\n",
- " res['wg_approx']['cart'][i, :] = run_single_dml(dml_data, ml_cart, 'wg_approx', grid=cart_grid)\n",
- " res['wg_approx']['boost'][i, :] = run_single_dml(dml_data, ml_boost, 'wg_approx', grid=boost_grid)\n",
- "\n",
- " # summary\n",
- " rows = []\n",
- " index = []\n",
- "\n",
- " for approach, models_dict in res.items():\n",
- " for model, arr in models_dict.items():\n",
- "\n",
- " bias = np.mean(arr[:, 0])\n",
- " se_mean = np.mean(arr[:, 1])\n",
- " sd = np.std(arr[:, 1])\n",
- " coverage = np.mean(arr[:, 2])\n",
- " se_over_sd = sd / se_mean if se_mean > 0 else np.nan\n",
- " rmse = np.sqrt(np.mean(arr[:, 0]**2))\n",
- "\n",
- " rows.append([bias, rmse, se_over_sd, coverage])\n",
- " index.append((approach, model)) \n",
- "\n",
- " summary = pd.DataFrame(\n",
- " rows,\n",
- " index=pd.MultiIndex.from_tuples(index, names=[\"Approach\", \"ML Model\"]),\n",
- " columns=[\"Bias\", \"RMSE\", \"SE/SD\", \"Coverage\"]\n",
- " )\n",
- "\n",
- " return summary"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "6503f741",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Processing: 100.0 %"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " | \n",
- " Bias | \n",
- " RMSE | \n",
- " SE/SD | \n",
- " Coverage | \n",
- "
\n",
- " \n",
- " | Approach | \n",
- " ML Model | \n",
- " | \n",
- " | \n",
- " | \n",
- " | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " | cre_general | \n",
- " lasso | \n",
- " 0.024030 | \n",
- " 0.040756 | \n",
- " 0.076669 | \n",
- " 0.92 | \n",
- "
\n",
- " \n",
- " | cart | \n",
- " -0.033091 | \n",
- " 0.069709 | \n",
- " 0.112711 | \n",
- " 0.75 | \n",
- "
\n",
- " \n",
- " | boost | \n",
- " -0.041728 | \n",
- " 0.062508 | \n",
- " 0.114469 | \n",
- " 0.76 | \n",
- "
\n",
- " \n",
- " | cre_normal | \n",
- " lasso | \n",
- " 0.094727 | \n",
- " 0.103281 | \n",
- " 0.111451 | \n",
- " 0.41 | \n",
- "
\n",
- " \n",
- " | cart | \n",
- " -0.015669 | \n",
- " 0.072878 | \n",
- " 0.125903 | \n",
- " 0.89 | \n",
- "
\n",
- " \n",
- " | boost | \n",
- " 0.051158 | \n",
- " 0.081405 | \n",
- " 0.128427 | \n",
- " 0.90 | \n",
- "
\n",
- " \n",
- " | fd_exact | \n",
- " lasso | \n",
- " 0.021691 | \n",
- " 0.044563 | \n",
- " 0.084716 | \n",
- " 0.90 | \n",
- "
\n",
- " \n",
- " | cart | \n",
- " 0.076125 | \n",
- " 0.105962 | \n",
- " 0.099474 | \n",
- " 0.62 | \n",
- "
\n",
- " \n",
- " | boost | \n",
- " 0.010003 | \n",
- " 0.049025 | \n",
- " 0.116624 | \n",
- " 0.87 | \n",
- "
\n",
- " \n",
- " | wg_approx | \n",
- " lasso | \n",
- " 0.003884 | \n",
- " 0.032382 | \n",
- " 0.077256 | \n",
- " 0.97 | \n",
- "
\n",
- " \n",
- " | cart | \n",
- " -0.002517 | \n",
- " 0.048548 | \n",
- " 0.090029 | \n",
- " 0.87 | \n",
- "
\n",
- " \n",
- " | boost | \n",
- " -0.042819 | \n",
- " 0.058042 | \n",
- " 0.086614 | \n",
- " 0.71 | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " Bias RMSE SE/SD Coverage\n",
- "Approach ML Model \n",
- "cre_general lasso 0.024030 0.040756 0.076669 0.92\n",
- " cart -0.033091 0.069709 0.112711 0.75\n",
- " boost -0.041728 0.062508 0.114469 0.76\n",
- "cre_normal lasso 0.094727 0.103281 0.111451 0.41\n",
- " cart -0.015669 0.072878 0.125903 0.89\n",
- " boost 0.051158 0.081405 0.128427 0.90\n",
- "fd_exact lasso 0.021691 0.044563 0.084716 0.90\n",
- " cart 0.076125 0.105962 0.099474 0.62\n",
- " boost 0.010003 0.049025 0.116624 0.87\n",
- "wg_approx lasso 0.003884 0.032382 0.077256 0.97\n",
- " cart -0.002517 0.048548 0.090029 0.87\n",
- " boost -0.042819 0.058042 0.086614 0.71"
- ]
- },
- "execution_count": 14,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "np.random.seed(123)\n",
- "\n",
- "res_dgp1 = run_sim(n_reps=100, num_id=100, theta=0.5, dgp_type='dgp1')\n",
- "res_dgp1"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "7a7b3fa2",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Processing: 100.0 %"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " | \n",
- " Bias | \n",
- " RMSE | \n",
- " SE/SD | \n",
- " Coverage | \n",
- "
\n",
- " \n",
- " | Approach | \n",
- " ML Model | \n",
- " | \n",
- " | \n",
- " | \n",
- " | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " | cre_general | \n",
- " lasso | \n",
- " -0.005823 | \n",
- " 0.027244 | \n",
- " 0.073960 | \n",
- " 0.95 | \n",
- "
\n",
- " \n",
- " | cart | \n",
- " -0.109613 | \n",
- " 0.122162 | \n",
- " 0.106904 | \n",
- " 0.24 | \n",
- "
\n",
- " \n",
- " | boost | \n",
- " -0.063039 | \n",
- " 0.075676 | \n",
- " 0.081278 | \n",
- " 0.50 | \n",
- "
\n",
- " \n",
- " | cre_normal | \n",
- " lasso | \n",
- " 0.070362 | \n",
- " 0.076794 | \n",
- " 0.087415 | \n",
- " 0.40 | \n",
- "
\n",
- " \n",
- " | cart | \n",
- " -0.028230 | \n",
- " 0.073113 | \n",
- " 0.182335 | \n",
- " 0.94 | \n",
- "
\n",
- " \n",
- " | boost | \n",
- " -0.016764 | \n",
- " 0.060976 | \n",
- " 0.183879 | \n",
- " 0.89 | \n",
- "
\n",
- " \n",
- " | fd_exact | \n",
- " lasso | \n",
- " -0.002715 | \n",
- " 0.032797 | \n",
- " 0.091710 | \n",
- " 0.95 | \n",
- "
\n",
- " \n",
- " | cart | \n",
- " -0.074209 | \n",
- " 0.083379 | \n",
- " 0.099775 | \n",
- " 0.47 | \n",
- "
\n",
- " \n",
- " | boost | \n",
- " -0.062527 | \n",
- " 0.076963 | \n",
- " 0.085480 | \n",
- " 0.58 | \n",
- "
\n",
- " \n",
- " | wg_approx | \n",
- " lasso | \n",
- " -0.003015 | \n",
- " 0.027330 | \n",
- " 0.074295 | \n",
- " 0.95 | \n",
- "
\n",
- " \n",
- " | cart | \n",
- " -0.024405 | \n",
- " 0.038327 | \n",
- " 0.079555 | \n",
- " 0.90 | \n",
- "
\n",
- " \n",
- " | boost | \n",
- " -0.056258 | \n",
- " 0.067245 | \n",
- " 0.083798 | \n",
- " 0.52 | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " Bias RMSE SE/SD Coverage\n",
- "Approach ML Model \n",
- "cre_general lasso -0.005823 0.027244 0.073960 0.95\n",
- " cart -0.109613 0.122162 0.106904 0.24\n",
- " boost -0.063039 0.075676 0.081278 0.50\n",
- "cre_normal lasso 0.070362 0.076794 0.087415 0.40\n",
- " cart -0.028230 0.073113 0.182335 0.94\n",
- " boost -0.016764 0.060976 0.183879 0.89\n",
- "fd_exact lasso -0.002715 0.032797 0.091710 0.95\n",
- " cart -0.074209 0.083379 0.099775 0.47\n",
- " boost -0.062527 0.076963 0.085480 0.58\n",
- "wg_approx lasso -0.003015 0.027330 0.074295 0.95\n",
- " cart -0.024405 0.038327 0.079555 0.90\n",
- " boost -0.056258 0.067245 0.083798 0.52"
- ]
- },
- "execution_count": 16,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "np.random.seed(123)\n",
- "\n",
- "res_dgp2 = run_sim(n_reps=100, num_id=100, theta=0.5, dgp_type='dgp2')\n",
- "res_dgp2"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "f8f0abee",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Processing: 100.0 %"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " | \n",
- " Bias | \n",
- " RMSE | \n",
- " SE/SD | \n",
- " Coverage | \n",
- "
\n",
- " \n",
- " | Approach | \n",
- " ML Model | \n",
- " | \n",
- " | \n",
- " | \n",
- " | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " | cre_general | \n",
- " lasso | \n",
- " 0.016582 | \n",
- " 0.037272 | \n",
- " 0.080832 | \n",
- " 0.94 | \n",
- "
\n",
- " \n",
- " | cart | \n",
- " 0.493280 | \n",
- " 0.534554 | \n",
- " 0.275800 | \n",
- " 0.05 | \n",
- "
\n",
- " \n",
- " | boost | \n",
- " 0.443519 | \n",
- " 0.451916 | \n",
- " 0.180075 | \n",
- " 0.00 | \n",
- "
\n",
- " \n",
- " | cre_normal | \n",
- " lasso | \n",
- " 0.083721 | \n",
- " 0.100538 | \n",
- " 0.129024 | \n",
- " 0.71 | \n",
- "
\n",
- " \n",
- " | cart | \n",
- " 0.413418 | \n",
- " 0.468509 | \n",
- " 0.218055 | \n",
- " 0.13 | \n",
- "
\n",
- " \n",
- " | boost | \n",
- " 0.385419 | \n",
- " 0.393097 | \n",
- " 0.185301 | \n",
- " 0.00 | \n",
- "
\n",
- " \n",
- " | fd_exact | \n",
- " lasso | \n",
- " 0.014759 | \n",
- " 0.041217 | \n",
- " 0.086594 | \n",
- " 0.92 | \n",
- "
\n",
- " \n",
- " | cart | \n",
- " 0.658754 | \n",
- " 0.673590 | \n",
- " 0.259508 | \n",
- " 0.00 | \n",
- "
\n",
- " \n",
- " | boost | \n",
- " 0.588552 | \n",
- " 0.593496 | \n",
- " 0.134829 | \n",
- " 0.00 | \n",
- "
\n",
- " \n",
- " | wg_approx | \n",
- " lasso | \n",
- " 0.630003 | \n",
- " 0.630942 | \n",
- " 0.099154 | \n",
- " 0.00 | \n",
- "
\n",
- " \n",
- " | cart | \n",
- " 0.515894 | \n",
- " 0.532497 | \n",
- " 0.227144 | \n",
- " 0.01 | \n",
- "
\n",
- " \n",
- " | boost | \n",
- " 0.593587 | \n",
- " 0.597220 | \n",
- " 0.137237 | \n",
- " 0.00 | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " Bias RMSE SE/SD Coverage\n",
- "Approach ML Model \n",
- "cre_general lasso 0.016582 0.037272 0.080832 0.94\n",
- " cart 0.493280 0.534554 0.275800 0.05\n",
- " boost 0.443519 0.451916 0.180075 0.00\n",
- "cre_normal lasso 0.083721 0.100538 0.129024 0.71\n",
- " cart 0.413418 0.468509 0.218055 0.13\n",
- " boost 0.385419 0.393097 0.185301 0.00\n",
- "fd_exact lasso 0.014759 0.041217 0.086594 0.92\n",
- " cart 0.658754 0.673590 0.259508 0.00\n",
- " boost 0.588552 0.593496 0.134829 0.00\n",
- "wg_approx lasso 0.630003 0.630942 0.099154 0.00\n",
- " cart 0.515894 0.532497 0.227144 0.01\n",
- " boost 0.593587 0.597220 0.137237 0.00"
- ]
- },
- "execution_count": 17,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "np.random.seed(123)\n",
- "\n",
- "run_sim(n_reps=100, num_id=100, theta=0.5, dgp_type='dgp3')"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": ".venv",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.6"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/doubleml/plm/utils/_plpr_util.py b/doubleml/plm/utils/_plpr_util.py
deleted file mode 100644
index 9045ffb08..000000000
--- a/doubleml/plm/utils/_plpr_util.py
+++ /dev/null
@@ -1,73 +0,0 @@
-import numpy as np
-import pandas as pd
-from sklearn.preprocessing import PolynomialFeatures
-
-
-def extend_data(data):
- data = data.copy()
- poly = PolynomialFeatures(2, include_bias=False)
-
- xdat = data.loc[:,data.columns.str.startswith('x') & ~data.columns.str.contains('lag')]
- xpoly = poly.fit_transform(xdat)
- x_p3 = xdat**3
-
- x_pol_nam = poly.get_feature_names_out()
- x_cols_p3 = [f'x{i + 1}^3' for i in np.arange(xdat.shape[1])]
-
- if data.columns.str.startswith('m_x').any():
- xmdat = data.loc[:,data.columns.str.startswith('m_x')]
- xmpoly = poly.fit_transform(xmdat)
- xm_p3 = xmdat**3
-
- xm_pol_nam = poly.get_feature_names_out()
- xm_cols_p3 = [f'm_x{i + 1}^3' for i in np.arange(xmdat.shape[1])]
-
- X_all = np.column_stack((xpoly, x_p3, xmpoly, xm_p3))
- x_df = pd.DataFrame(X_all, columns = list(x_pol_nam) + x_cols_p3 + list(xm_pol_nam) + xm_cols_p3)
- df_ext = data[['id', 'time', 'd', 'y', 'm_d']].join(x_df)
-
- elif data.columns.str.contains('_lag').any():
- xldat = data.loc[:,data.columns.str.contains('_lag')]
- xlpoly = poly.fit_transform(xldat)
- xl_p3 = xldat**3
-
- xl_pol_nam = poly.get_feature_names_out()
- xl_cols_p3 = [f'x{i + 1}_lag^3' for i in np.arange(xldat.shape[1])]
-
- X_all = np.column_stack((xpoly, x_p3, xlpoly, xl_p3))
- x_df = pd.DataFrame(X_all, columns = list(x_pol_nam) + x_cols_p3 + list(xl_pol_nam) + xl_cols_p3)
- df_ext = data[['id', 'time', 'd_diff', 'y_diff']].join(x_df)
-
- else:
- X_all = np.column_stack((xpoly, x_p3))
- x_df = pd.DataFrame(X_all, columns = list(x_pol_nam) + x_cols_p3)
- df_ext = data[['id', 'time', 'd', 'y']].join(x_df)
-
- return df_ext
-
-
-def cre_fct(data):
- df = data.copy()
- id_means = df.loc[:,~df.columns.isin(['time', 'y'])].groupby(["id"]).transform('mean')
- df = df.join(id_means.rename(columns=lambda x: "m_" + x))
- return df
-
-
-def fd_fct(data):
- df = data.copy()
- shifted = df.loc[:,~df.columns.isin(['d', 'y', 'time'])].groupby(["id"]).shift(1)
- first_diff = df.loc[:,df.columns.isin(['id', 'd', 'y'])].groupby(["id"]).diff()
- df_fd = df.join(shifted.rename(columns=lambda x: x +"_lag"))
- df_fd = df_fd.join(first_diff.rename(columns=lambda x: x +"_diff"))
- df = df_fd.dropna(subset=['x1_lag']).reset_index(drop=True)
- return df
-
-
-def wd_fct(data):
- df = data.copy()
- df_demean = df.loc[:,~df.columns.isin(['time'])].groupby(["id"]).transform(lambda x: x - x.mean())
- # add xbar (the grand mean allows a consistent estimate of the constant term)
- within_means = df_demean + df.loc[:,~df.columns.isin(['id','time'])].mean()
- df_wd = df.loc[:,df.columns.isin(['id','time'])]
- df = df_wd.join(within_means)
- return df
\ No newline at end of file
From 7d01cd1b36464676aaef4b79013a7ce07c480a38 Mon Sep 17 00:00:00 2001
From: Julian Diefenbacher
<93731825+JulianDiefenbacher@users.noreply.github.com>
Date: Mon, 1 Dec 2025 19:37:55 +0100
Subject: [PATCH 30/33] fix id var issue
---
doubleml/plm/datasets/dgp_plpr_CP2025.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/doubleml/plm/datasets/dgp_plpr_CP2025.py b/doubleml/plm/datasets/dgp_plpr_CP2025.py
index 438745c0a..4905819d9 100644
--- a/doubleml/plm/datasets/dgp_plpr_CP2025.py
+++ b/doubleml/plm/datasets/dgp_plpr_CP2025.py
@@ -82,7 +82,7 @@ def make_plpr_CP2025(num_id=250, num_t=10, dim_x=30, theta=0.5, dgp_type="dgp1")
sigma2_x = 5
# id and time vectors
- id = np.repeat(np.arange(1, num_id + 1), num_t)
+ id_var = np.repeat(np.arange(1, num_id + 1), num_t)
time = np.tile(np.arange(1, num_t + 1), num_id)
# individual fixed effects
@@ -127,7 +127,7 @@ def alpha_i(x_it, d_it, a_i, num_n, num_t):
x_cols = [f"x{i + 1}" for i in np.arange(dim_x)]
- data = pd.DataFrame(np.column_stack((id, time, y_it, d_it, x_it)), columns=["id", "time", "y", "d"] + x_cols).astype(
+ data = pd.DataFrame(np.column_stack((id_var, time, y_it, d_it, x_it)), columns=["id", "time", "y", "d"] + x_cols).astype(
{"id": "int64", "time": "int64"}
)
From 7c24a99928a25d75ea99240bb2ca3c4aaa92dfea Mon Sep 17 00:00:00 2001
From: JulianDiefenbacher
Date: Tue, 2 Dec 2025 22:16:08 +0100
Subject: [PATCH 31/33] complete plpr dataset tests
---
doubleml/plm/datasets/dgp_plpr_CP2025.py | 2 +-
doubleml/plm/tests/test_datasets.py | 17 +++++++++++++++--
doubleml/plm/tests/test_plpr.py | 2 +-
doubleml/plm/tests/test_plpr_tune.py | 2 +-
4 files changed, 18 insertions(+), 5 deletions(-)
diff --git a/doubleml/plm/datasets/dgp_plpr_CP2025.py b/doubleml/plm/datasets/dgp_plpr_CP2025.py
index 4905819d9..8e07ff487 100644
--- a/doubleml/plm/datasets/dgp_plpr_CP2025.py
+++ b/doubleml/plm/datasets/dgp_plpr_CP2025.py
@@ -106,7 +106,7 @@ def make_plpr_CP2025(num_id=250, num_t=10, dim_x=30, theta=0.5, dgp_type="dgp1")
l_0 = b * (x_it[:, 0] * x_it[:, 2]) + a * (x_it[:, 2] * np.where(x_it[:, 2] > 0, 1, 0))
m_0 = a * (x_it[:, 0] * np.where(x_it[:, 0] > 0, 1, 0)) + b * (x_it[:, 0] * x_it[:, 2])
else:
- raise ValueError("Invalid dgp")
+ raise ValueError("Invalid dgp type.")
# treatment
d_it = m_0 + c_i + v_it
diff --git a/doubleml/plm/tests/test_datasets.py b/doubleml/plm/tests/test_datasets.py
index 2726ea269..b255228cc 100644
--- a/doubleml/plm/tests/test_datasets.py
+++ b/doubleml/plm/tests/test_datasets.py
@@ -16,6 +16,13 @@
msg_inv_return_type = "Invalid return_type."
+msg_inv_dgp_type = "Invalid dgp type."
+
+
+@pytest.fixture(scope="module", params=["dgp1", "dgp2", "dgp3"])
+def dgp_type(request):
+ return request.param
+
@pytest.mark.ci
def test_make_plr_CCDDHNR2018_return_types():
@@ -153,7 +160,13 @@ def test_make_lplr_LZZ2020_variants():
@pytest.mark.ci
-def test_make_plpr_CP2025_return_types():
+def test_make_plpr_CP2025_return_types(dgp_type):
np.random.seed(3141)
- res = make_plpr_CP2025(num_id=100)
+ res = make_plpr_CP2025(num_id=100, dgp_type=dgp_type)
assert isinstance(res, pd.DataFrame)
+
+
+@pytest.mark.ci
+def test_make_plpr_CP2025_invalid_dgp_type():
+ with pytest.raises(ValueError, match=msg_inv_dgp_type):
+ _ = make_plpr_CP2025(num_id=100, dgp_type="dgp4")
\ No newline at end of file
diff --git a/doubleml/plm/tests/test_plpr.py b/doubleml/plm/tests/test_plpr.py
index 3a5061588..d072ebdf9 100644
--- a/doubleml/plm/tests/test_plpr.py
+++ b/doubleml/plm/tests/test_plpr.py
@@ -68,7 +68,7 @@ def dml_plpr_fixture(
@pytest.mark.ci
-def test_dml_selection_coef(dml_plpr_fixture):
+def test_dml_plpr_coef(dml_plpr_fixture):
# true_coef should lie within three standard deviations of the estimate
coef = dml_plpr_fixture["coef"]
se = dml_plpr_fixture["se"]
diff --git a/doubleml/plm/tests/test_plpr_tune.py b/doubleml/plm/tests/test_plpr_tune.py
index 411a9bd90..e1dcef2da 100644
--- a/doubleml/plm/tests/test_plpr_tune.py
+++ b/doubleml/plm/tests/test_plpr_tune.py
@@ -96,7 +96,7 @@ def dml_plpr_fixture(
@pytest.mark.ci
-def test_dml_selection_coef(dml_plpr_fixture):
+def test_dml_plpr_coef(dml_plpr_fixture):
# true_coef should lie within three standard deviations of the estimate
coef = dml_plpr_fixture["coef"]
se = dml_plpr_fixture["se"]
From fb56a276bfc63024c1479da7e835542859304809 Mon Sep 17 00:00:00 2001
From: JulianDiefenbacher
Date: Tue, 2 Dec 2025 22:26:27 +0100
Subject: [PATCH 32/33] fix formatting
---
doubleml/plm/tests/test_datasets.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/doubleml/plm/tests/test_datasets.py b/doubleml/plm/tests/test_datasets.py
index b255228cc..0762ab47c 100644
--- a/doubleml/plm/tests/test_datasets.py
+++ b/doubleml/plm/tests/test_datasets.py
@@ -169,4 +169,4 @@ def test_make_plpr_CP2025_return_types(dgp_type):
@pytest.mark.ci
def test_make_plpr_CP2025_invalid_dgp_type():
with pytest.raises(ValueError, match=msg_inv_dgp_type):
- _ = make_plpr_CP2025(num_id=100, dgp_type="dgp4")
\ No newline at end of file
+ _ = make_plpr_CP2025(num_id=100, dgp_type="dgp4")
From b15571ce7af14caf9486181c687da1817e6d26fa Mon Sep 17 00:00:00 2001
From: JulianDiefenbacher
Date: Wed, 3 Dec 2025 21:14:07 +0100
Subject: [PATCH 33/33] add external pred test, complete model default test
---
doubleml/plm/tests/test_model_defaults.py | 34 +++++-
.../tests/test_plpr_external_predictions.py | 106 ++++++++++++++++++
2 files changed, 138 insertions(+), 2 deletions(-)
create mode 100644 doubleml/plm/tests/test_plpr_external_predictions.py
diff --git a/doubleml/plm/tests/test_model_defaults.py b/doubleml/plm/tests/test_model_defaults.py
index 6d8ef3226..bc9e8a502 100644
--- a/doubleml/plm/tests/test_model_defaults.py
+++ b/doubleml/plm/tests/test_model_defaults.py
@@ -1,7 +1,9 @@
+import numpy as np
import pytest
from sklearn.linear_model import LinearRegression, LogisticRegression
from doubleml import DoubleMLLPLR, DoubleMLPanelData, DoubleMLPLPR
+from doubleml.double_ml import DoubleML
from doubleml.plm.datasets import make_lplr_LZZ2020, make_plpr_CP2025
from doubleml.utils._check_defaults import _check_basic_defaults_after_fit, _check_basic_defaults_before_fit, _fit_bootstrap
@@ -9,7 +11,7 @@
dml_lplr_obj = DoubleMLLPLR(dml_data_lplr, LogisticRegression(), LinearRegression(), LinearRegression())
-plpr_data = make_plpr_CP2025(num_id=100)
+plpr_data = make_plpr_CP2025(num_id=100, dgp_type="dgp1")
dml_data_plpr = DoubleMLPanelData(
plpr_data,
y_col="y",
@@ -34,4 +36,32 @@ def test_lplr_defaults():
@pytest.mark.ci
def test_plpr_defaults():
_check_basic_defaults_before_fit(dml_plpr_obj)
- # TODO: fit for cluster?
+
+    # manual fit and default check after fit
+    dml_plpr_obj.fit()
+    assert dml_plpr_obj.n_folds == 5
+    assert dml_plpr_obj.n_rep == 1
+    assert dml_plpr_obj.framework is not None
+
+    # coefs and se
+    assert isinstance(dml_plpr_obj.coef, np.ndarray)
+    assert isinstance(dml_plpr_obj.se, np.ndarray)
+    assert isinstance(dml_plpr_obj.all_coef, np.ndarray)
+    assert isinstance(dml_plpr_obj.all_se, np.ndarray)
+    assert isinstance(dml_plpr_obj.t_stat, np.ndarray)
+    assert isinstance(dml_plpr_obj.pval, np.ndarray)
+
+    # bootstrap and p_adjust method skipped
+
+    # sensitivity: only fit() has been called on the object so far, so the
+    # parameters are expected to be unset; a follow-up check of
+    # sensitivity_elements could never run after this assert and is omitted
+    assert dml_plpr_obj.sensitivity_params is None
+
+    # fit method
+    if isinstance(dml_plpr_obj, DoubleML):
+        assert dml_plpr_obj.predictions is not None
+        assert dml_plpr_obj.models is None
+
+    # confint method
+    assert dml_plpr_obj.confint().equals(dml_plpr_obj.confint(joint=False, level=0.95))
diff --git a/doubleml/plm/tests/test_plpr_external_predictions.py b/doubleml/plm/tests/test_plpr_external_predictions.py
new file mode 100644
index 000000000..cd43cec15
--- /dev/null
+++ b/doubleml/plm/tests/test_plpr_external_predictions.py
@@ -0,0 +1,106 @@
+import math
+
+import numpy as np
+import pytest
+from sklearn.linear_model import LinearRegression
+
+from doubleml import DoubleMLPanelData, DoubleMLPLPR
+from doubleml.plm.datasets import make_plpr_CP2025
+from doubleml.utils import DMLDummyRegressor
+
+treat_label = {"cre_general": "d", "cre_normal": "d", "fd_exact": "d_diff", "wg_approx": "d_demean"}
+
+
+@pytest.fixture(scope="module", params=["IV-type", "partialling out"])
+def plpr_score(request):
+ return request.param
+
+
+@pytest.fixture(scope="module", params=["cre_general", "cre_normal", "fd_exact", "wg_approx"])
+def plpr_approach(request):
+ return request.param
+
+
+@pytest.fixture(scope="module", params=[1, 3])
+def n_rep(request):
+ return request.param
+
+
+@pytest.fixture(scope="module", params=[True, False])
+def set_ml_m_ext(request):
+ return request.param
+
+
+@pytest.fixture(scope="module", params=[True, False])
+def set_ml_l_ext(request):
+ return request.param
+
+
+@pytest.fixture(scope="module", params=[True, False])
+def set_ml_g_ext(request):
+ return request.param
+
+
+@pytest.fixture(scope="module")
+def doubleml_plpr_fixture(plpr_score, plpr_approach, n_rep, set_ml_m_ext, set_ml_l_ext, set_ml_g_ext):
+    """Fit DoubleMLPLPR normally and again with its own stored predictions supplied externally."""
+    # external predictions are keyed by the treatment label that belongs to the chosen approach
+    d_label = treat_label[plpr_approach]
+    ext_predictions = {d_label: {}}
+
+    # seed before drawing the data so that a failing parametrization can be reproduced
+    np.random.seed(3141)
+    plpr_data = make_plpr_CP2025(num_id=100, theta=0.5, dgp_type="dgp1")
+
+    np.random.seed(3141)
+    dml_data_plpr = DoubleMLPanelData(
+        plpr_data,
+        y_col="y",
+        d_cols="d",
+        t_col="time",
+        id_col="id",
+        static_panel=True,
+    )
+
+    kwargs = {"obj_dml_data": dml_data_plpr, "score": plpr_score, "approach": plpr_approach, "n_rep": n_rep}
+
+    if plpr_score == "IV-type":
+        kwargs["ml_g"] = LinearRegression()
+
+    dml_plpr = DoubleMLPLPR(ml_m=LinearRegression(), ml_l=LinearRegression(), **kwargs)
+
+    np.random.seed(3141)
+    dml_plpr.fit(store_predictions=True)
+
+    if set_ml_m_ext:
+        ext_predictions[d_label]["ml_m"] = dml_plpr.predictions["ml_m"][:, :, 0]
+        ml_m = DMLDummyRegressor()
+    else:
+        ml_m = LinearRegression()
+
+    if set_ml_l_ext:
+        ext_predictions[d_label]["ml_l"] = dml_plpr.predictions["ml_l"][:, :, 0]
+        ml_l = DMLDummyRegressor()
+    else:
+        ml_l = LinearRegression()
+
+    if plpr_score == "IV-type" and set_ml_g_ext:
+        ext_predictions[d_label]["ml_g"] = dml_plpr.predictions["ml_g"][:, :, 0]
+        kwargs["ml_g"] = DMLDummyRegressor()
+        # special case if ml_l is not needed: hand over a dummy learner instead
+        if not set_ml_l_ext:
+            ml_l = DMLDummyRegressor()
+    elif plpr_score == "IV-type":
+        kwargs["ml_g"] = LinearRegression()
+
+    dml_plpr_ext = DoubleMLPLPR(ml_m=ml_m, ml_l=ml_l, **kwargs)
+
+    np.random.seed(3141)
+    dml_plpr_ext.fit(external_predictions=ext_predictions)
+
+    return {"coef_normal": dml_plpr.coef[0], "coef_ext": dml_plpr_ext.coef[0]}
+
+
+@pytest.mark.ci
+def test_doubleml_plpr_coef(doubleml_plpr_fixture):
+ assert math.isclose(doubleml_plpr_fixture["coef_normal"], doubleml_plpr_fixture["coef_ext"], rel_tol=1e-9, abs_tol=1e-4)