diff --git a/doubleml/irm/ssm.py b/doubleml/irm/ssm.py
index e7e5d83c..6e8b9563 100644
--- a/doubleml/irm/ssm.py
+++ b/doubleml/irm/ssm.py
@@ -428,86 +428,116 @@ def _nuisance_tuning(
     ):
         x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False)
         x, d = check_X_y(x, self._dml_data.d, force_all_finite=False)
-        # time indicator is used for selection (selection not available in DoubleMLData yet)
         x, s = check_X_y(x, self._dml_data.s, force_all_finite=False)
 
         if self._score == "nonignorable":
             z, _ = check_X_y(self._dml_data.z, y, force_all_finite=False)
-            dx = np.column_stack((x, d, z))
-        else:
-            dx = np.column_stack((x, d))
 
         if scoring_methods is None:
             scoring_methods = {"ml_g": None, "ml_pi": None, "ml_m": None}
 
-        # nuisance training sets conditional on d
-        _, smpls_d0_s1, _, smpls_d1_s1 = _get_cond_smpls_2d(smpls, d, s)
-        train_inds = [train_index for (train_index, _) in smpls]
-        train_inds_d0_s1 = [train_index for (train_index, _) in smpls_d0_s1]
-        train_inds_d1_s1 = [train_index for (train_index, _) in smpls_d1_s1]
-
-        # hyperparameter tuning for ML
-        g_d0_tune_res = _dml_tune(
-            y,
-            x,
-            train_inds_d0_s1,
-            self._learner["ml_g"],
-            param_grids["ml_g"],
-            scoring_methods["ml_g"],
-            n_folds_tune,
-            n_jobs_cv,
-            search_mode,
-            n_iter_randomized_search,
-        )
-        g_d1_tune_res = _dml_tune(
-            y,
-            x,
-            train_inds_d1_s1,
-            self._learner["ml_g"],
-            param_grids["ml_g"],
-            scoring_methods["ml_g"],
-            n_folds_tune,
-            n_jobs_cv,
-            search_mode,
-            n_iter_randomized_search,
-        )
-        pi_tune_res = _dml_tune(
-            s,
-            dx,
-            train_inds,
-            self._learner["ml_pi"],
-            param_grids["ml_pi"],
-            scoring_methods["ml_pi"],
-            n_folds_tune,
-            n_jobs_cv,
-            search_mode,
-            n_iter_randomized_search,
-        )
-        m_tune_res = _dml_tune(
-            d,
-            x,
-            train_inds,
-            self._learner["ml_m"],
-            param_grids["ml_m"],
-            scoring_methods["ml_m"],
-            n_folds_tune,
-            n_jobs_cv,
-            search_mode,
-            n_iter_randomized_search,
-        )
+        # nested helpers shared by both scores
+        def tune_learner(target, features, train_indices, learner_key):
+            return _dml_tune(
+                target,
+                features,
+                train_indices,
+                self._learner[learner_key],
+                param_grids[learner_key],
+                scoring_methods[learner_key],
+                n_folds_tune,
+                n_jobs_cv,
+                search_mode,
+                n_iter_randomized_search,
+            )
 
-        g_d0_best_params = [xx.best_params_ for xx in g_d0_tune_res]
-        g_d1_best_params = [xx.best_params_ for xx in g_d1_tune_res]
-        pi_best_params = [xx.best_params_ for xx in pi_tune_res]
-        m_best_params = [xx.best_params_ for xx in m_tune_res]
+        def split_inner_folds(train_inds, d, s, random_state=42):
+            # split each training set into two halves, stratified on the four (d, s) cells
+            inner_train0_inds, inner_train1_inds = [], []
+            for train_index in train_inds:
+                stratify_vec = d[train_index] + 2 * s[train_index]
+                inner0, inner1 = train_test_split(train_index, test_size=0.5, stratify=stratify_vec, random_state=random_state)
+                inner_train0_inds.append(inner0)
+                inner_train1_inds.append(inner1)
+            return inner_train0_inds, inner_train1_inds
+
+        def filter_by_ds(inner_train1_inds, d, s):
+            # restrict the second halves to selected observations (s=1) with d=0 and d=1
+            inner1_d0_s1, inner1_d1_s1 = [], []
+            for inner1 in inner_train1_inds:
+                d_fold, s_fold = d[inner1], s[inner1]
+                mask_d0_s1 = (d_fold == 0) & (s_fold == 1)
+                mask_d1_s1 = (d_fold == 1) & (s_fold == 1)
+
+                inner1_d0_s1.append(inner1[mask_d0_s1])
+                inner1_d1_s1.append(inner1[mask_d1_s1])
+            return inner1_d0_s1, inner1_d1_s1
 
-        params = {"ml_g_d0": g_d0_best_params, "ml_g_d1": g_d1_best_params, "ml_pi": pi_best_params, "ml_m": m_best_params}
-        tune_res = {"g_d0_tune": g_d0_tune_res, "g_d1_tune": g_d1_tune_res, "pi_tune": pi_tune_res, "m_tune": m_tune_res}
+        if self._score == "nonignorable":
+            train_inds = [train_index for (train_index, _) in smpls]
+
+            # inner folds: split each train set into two halves (pi-tuning vs. m/g-tuning)
+            inner_train0_inds, inner_train1_inds = split_inner_folds(train_inds, d, s)
+            # split inner1 by (d, s) to build g-models for treated/control
+            inner_train1_d0_s1, inner_train1_d1_s1 = filter_by_ds(inner_train1_inds, d, s)
+
+            # tune ml_pi on the first halves and predict pi-hat on the second halves
+            x_d_z = np.column_stack((x, d, z))
+            pi_tune_res = []
+            pi_hat_full = np.full(shape=s.shape, fill_value=np.nan)
+            for inner0, inner1 in zip(inner_train0_inds, inner_train1_inds):
+                res = tune_learner(s, x_d_z, [inner0], "ml_pi")
+                best_params = res[0].best_params_
+
+                # fit the tuned model and predict the selection propensity
+                ml_pi_temp = clone(self._learner["ml_pi"])
+                ml_pi_temp.set_params(**best_params)
+                ml_pi_temp.fit(x_d_z[inner0], s[inner0])
+                pi_hat_full[inner1] = _predict_zero_one_propensity(ml_pi_temp, x_d_z)[inner1]
+                pi_tune_res.append(res[0])
+
+            # tune ml_m on x augmented with the pi-hats
+            x_pi = np.column_stack([x, pi_hat_full.reshape(-1, 1)])
+            m_tune_res = tune_learner(d, x_pi, inner_train1_inds, "ml_m")
+
+            # tune ml_g for d=0 and d=1
+            x_pi_d = np.column_stack([x, d.reshape(-1, 1), pi_hat_full.reshape(-1, 1)])
+            g_d0_tune_res = tune_learner(y, x_pi_d, inner_train1_d0_s1, "ml_g")
+            g_d1_tune_res = tune_learner(y, x_pi_d, inner_train1_d1_s1, "ml_g")
+        else:
+            # nuisance training sets conditional on d
+            _, smpls_d0_s1, _, smpls_d1_s1 = _get_cond_smpls_2d(smpls, d, s)
+            train_inds = [train_index for (train_index, _) in smpls]
+            train_inds_d0_s1 = [train_index for (train_index, _) in smpls_d0_s1]
+            train_inds_d1_s1 = [train_index for (train_index, _) in smpls_d1_s1]
+
+            # tune ml_g for d=0 and d=1
+            g_d0_tune_res = tune_learner(y, x, train_inds_d0_s1, "ml_g")
+            g_d1_tune_res = tune_learner(y, x, train_inds_d1_s1, "ml_g")
+
+            # tune ml_pi and ml_m
+            x_d = np.column_stack((x, d))
+            pi_tune_res = tune_learner(s, x_d, train_inds, "ml_pi")
+            m_tune_res = tune_learner(d, x, train_inds, "ml_m")
 
-        res = {"params": params, "tune_res": tune_res}
+        # collect results
+        params = {
+            "ml_g_d0": [res.best_params_ for res in g_d0_tune_res],
+            "ml_g_d1": [res.best_params_ for res in g_d1_tune_res],
+            "ml_pi": [res.best_params_ for res in pi_tune_res],
+            "ml_m": [res.best_params_ for res in m_tune_res],
+        }
+
+        tune_res = {
+            "g_d0_tune": g_d0_tune_res,
+            "g_d1_tune": g_d1_tune_res,
+            "pi_tune": pi_tune_res,
+            "m_tune": m_tune_res,
+        }
 
-        return res
+        return {"params": params, "tune_res": tune_res}
 
     def _sensitivity_element_est(self, preds):
         pass
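The inner-split scheme the new nonignorable branch relies on, as a minimal self-contained sketch (the toy data and sizes are illustrative assumptions; only the stratified `train_test_split` step mirrors the code above):

```python
# Sketch of the inner-fold split for the nonignorable score: ml_pi is tuned and
# fitted on one half of each training set, and its out-of-fold pi-hats then
# enter the feature matrices used to tune ml_m and ml_g on the other half.
import numpy as np
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)  # toy data, not part of the diff
n = 400
d = rng.integers(0, 2, n)  # treatment indicator
s = rng.integers(0, 2, n)  # selection indicator
train_index = np.arange(n)

# stratify on the four (d, s) cells so both halves preserve the joint distribution
stratify_vec = d[train_index] + 2 * s[train_index]
inner0, inner1 = train_test_split(train_index, test_size=0.5, stratify=stratify_vec, random_state=42)

# inner0 -> tune/fit ml_pi; inner1 -> receives pi-hat predictions and is used to
# tune ml_m (on [x, pi_hat]) and ml_g (on [x, d, pi_hat], within the (d, s) cells)
assert len(inner0) + len(inner1) == n
```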
diff --git a/doubleml/irm/tests/_utils_ssm_manual.py b/doubleml/irm/tests/_utils_ssm_manual.py
index e27f7fe1..f14a1f66 100644
--- a/doubleml/irm/tests/_utils_ssm_manual.py
+++ b/doubleml/irm/tests/_utils_ssm_manual.py
@@ -273,17 +273,14 @@ def var_selection(theta, psi_a, psi_b, n_obs):
     return var
 
 
-def tune_nuisance_ssm(y, x, d, z, s, ml_g, ml_pi, ml_m, smpls, score, n_folds_tune, param_grid_g, param_grid_pi, param_grid_m):
+def tune_nuisance_ssm_mar(y, x, d, z, s, ml_g, ml_pi, ml_m, smpls, n_folds_tune, param_grid_g, param_grid_pi, param_grid_m):
     d0_s1 = np.intersect1d(np.where(d == 0)[0], np.where(s == 1)[0])
     d1_s1 = np.intersect1d(np.where(d == 1)[0], np.where(s == 1)[0])
 
     g0_tune_res = tune_grid_search(y, x, ml_g, smpls, param_grid_g, n_folds_tune, train_cond=d0_s1)
     g1_tune_res = tune_grid_search(y, x, ml_g, smpls, param_grid_g, n_folds_tune, train_cond=d1_s1)
 
-    if score == "nonignorable":
-        dx = np.column_stack((x, d, z))
-    else:
-        dx = np.column_stack((x, d))
+    dx = np.column_stack((x, d))
 
     pi_tune_res = tune_grid_search(s, dx, ml_pi, smpls, param_grid_pi, n_folds_tune)
@@ -295,3 +292,45 @@ def tune_nuisance_ssm(y, x, d, z, s, ml_g, ml_pi, ml_m, smpls, score, n_folds_tu
     m_best_params = [xx.best_params_ for xx in m_tune_res]
 
     return g0_best_params, g1_best_params, pi_best_params, m_best_params
+
+
+def tune_nuisance_ssm_nonignorable(
+    y, x, d, z, s, ml_g, ml_pi, ml_m, smpls, n_folds_tune, param_grid_g, param_grid_pi, param_grid_m
+):
+    train_inds = [tr for (tr, _) in smpls]
+
+    # inner split: tune ml_pi on one half, ml_m and ml_g on the other
+    inner0_list, inner1_list = [], []
+    for tr in train_inds:
+        i0, i1 = train_test_split(tr, test_size=0.5, stratify=d[tr] + 2 * s[tr], random_state=42)
+        inner0_list.append(i0)
+        inner1_list.append(i1)
+
+    X_dz = np.c_[x, d.reshape(-1, 1), z.reshape(-1, 1)]
+    pi_tune_res = tune_grid_search(s, X_dz, ml_pi, [(i0, np.array([])) for i0 in inner0_list], param_grid_pi, n_folds_tune)
+    pi_best_params = [gs.best_params_ for gs in pi_tune_res]
+
+    # refit the tuned ml_pi on inner0 and predict pi-hat on inner1
+    pi_hat_full = np.full_like(s, np.nan, dtype=float)
+    for i0, i1, gs in zip(inner0_list, inner1_list, pi_tune_res):
+        ml_pi_temp = clone(ml_pi)
+        ml_pi_temp.set_params(**gs.best_params_)
+        ml_pi_temp.fit(X_dz[i0], s[i0])
+        ph = _predict_zero_one_propensity(ml_pi_temp, X_dz)
+        pi_hat_full[i1] = ph[i1]
+
+    X_pi = np.c_[x, pi_hat_full]
+    m_tune_res = tune_grid_search(d, X_pi, ml_m, [(i1, np.array([])) for i1 in inner1_list], param_grid_m, n_folds_tune)
+    m_best_params = [gs.best_params_ for gs in m_tune_res]
+
+    X_pi_d = np.c_[x, d.reshape(-1, 1), pi_hat_full.reshape(-1, 1)]
+    inner1_d0_s1 = [i1[(d[i1] == 0) & (s[i1] == 1)] for i1 in inner1_list]
+    inner1_d1_s1 = [i1[(d[i1] == 1) & (s[i1] == 1)] for i1 in inner1_list]
+
+    g0_tune_res = tune_grid_search(y, X_pi_d, ml_g, [(idx, np.array([])) for idx in inner1_d0_s1], param_grid_g, n_folds_tune)
+    g1_tune_res = tune_grid_search(y, X_pi_d, ml_g, [(idx, np.array([])) for idx in inner1_d1_s1], param_grid_g, n_folds_tune)
+
+    g0_best_params = [gs.best_params_ for gs in g0_tune_res]
+    g1_best_params = [gs.best_params_ for gs in g1_tune_res]
+
+    return g0_best_params, g1_best_params, pi_best_params, m_best_params
diff --git a/doubleml/irm/tests/test_ssm_tune.py b/doubleml/irm/tests/test_ssm_tune.py
index 0fafbc13..0d285da9 100644
--- a/doubleml/irm/tests/test_ssm_tune.py
+++ b/doubleml/irm/tests/test_ssm_tune.py
@@ -9,7 +9,7 @@ import doubleml as dml
 
 from ...tests._utils import draw_smpls
-from ._utils_ssm_manual import fit_selection, tune_nuisance_ssm
+from ._utils_ssm_manual import fit_selection, tune_nuisance_ssm_mar, tune_nuisance_ssm_nonignorable
 
 
 @pytest.fixture(scope="module", params=[RandomForestRegressor(random_state=42)])
@@ -115,41 +115,73 @@ def dml_ssm_fixture(
     np.random.seed(42)
     smpls = all_smpls[0]
 
     if tune_on_folds:
-        g0_best_params, g1_best_params, pi_best_params, m_best_params = tune_nuisance_ssm(
-            y,
-            x,
-            d,
-            z,
-            s,
-            clone(learner_g),
-            clone(learner_m),
-            clone(learner_m),
-            smpls,
-            score,
-            n_folds_tune,
-            par_grid["ml_g"],
-            par_grid["ml_pi"],
-            par_grid["ml_m"],
-        )
+        if score == "missing-at-random":
+            g0_best_params, g1_best_params, pi_best_params, m_best_params = tune_nuisance_ssm_mar(
+                y,
+                x,
+                d,
+                z,
+                s,
+                clone(learner_g),
+                clone(learner_m),
+                clone(learner_m),
+                smpls,
+                n_folds_tune,
+                par_grid["ml_g"],
+                par_grid["ml_pi"],
+                par_grid["ml_m"],
+            )
+        elif score == "nonignorable":
+            g0_best_params, g1_best_params, pi_best_params, m_best_params = tune_nuisance_ssm_nonignorable(
+                y,
+                x,
+                d,
+                z,
+                s,
+                clone(learner_g),
+                clone(learner_m),
+                clone(learner_m),
+                smpls,
+                n_folds_tune,
+                par_grid["ml_g"],
+                par_grid["ml_pi"],
+                par_grid["ml_m"],
+            )
     else:
         xx = [(np.arange(len(y)), np.array([]))]
-        g0_best_params, g1_best_params, pi_best_params, m_best_params = tune_nuisance_ssm(
-            y,
-            x,
-            d,
-            z,
-            s,
-            clone(learner_g),
-            clone(learner_m),
-            clone(learner_m),
-            xx,
-            score,
-            n_folds_tune,
-            par_grid["ml_g"],
-            par_grid["ml_pi"],
-            par_grid["ml_m"],
-        )
+        if score == "missing-at-random":
+            g0_best_params, g1_best_params, pi_best_params, m_best_params = tune_nuisance_ssm_mar(
+                y,
+                x,
+                d,
+                z,
+                s,
+                clone(learner_g),
+                clone(learner_m),
+                clone(learner_m),
+                xx,
+                n_folds_tune,
+                par_grid["ml_g"],
+                par_grid["ml_pi"],
+                par_grid["ml_m"],
+            )
+        elif score == "nonignorable":
+            g0_best_params, g1_best_params, pi_best_params, m_best_params = tune_nuisance_ssm_nonignorable(
+                y,
+                x,
+                d,
+                z,
+                s,
+                clone(learner_g),
+                clone(learner_m),
+                clone(learner_m),
+                xx,
+                n_folds_tune,
+                par_grid["ml_g"],
+                par_grid["ml_pi"],
+                par_grid["ml_m"],
+            )
 
         g0_best_params = g0_best_params * n_folds
         g1_best_params = g1_best_params * n_folds