Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
129 commits
Select commit Hold shift + click to select a range
2ece803
add optuna tuning to package
JanTeichertKluge Oct 13, 2025
9b846e3
Merge branch 'main' into j-optuna
JanTeichertKluge Oct 14, 2025
cf5cdf4
update optuna impl.
JanTeichertKluge Oct 15, 2025
f56c0de
update optuna tuning
JanTeichertKluge Oct 21, 2025
4835629
upd gitignore
JanTeichertKluge Oct 21, 2025
fefc21d
Remove examples/ from tracking (now gitignored)
JanTeichertKluge Oct 21, 2025
81b9924
Merge branch 'main' into j-optuna
JanTeichertKluge Oct 21, 2025
0ead5d2
update optuna implementation
JanTeichertKluge Oct 30, 2025
8fec084
add name to studies created in _create_study
JanTeichertKluge Oct 30, 2025
d69c70f
update optuna implementation
JanTeichertKluge Oct 30, 2025
9081fd6
update optuna implementation
JanTeichertKluge Oct 30, 2025
e97c13d
update optuna implementation
JanTeichertKluge Oct 30, 2025
2eda2c8
update optuna implementation
JanTeichertKluge Oct 30, 2025
3489cf3
update optuna implementation
JanTeichertKluge Oct 30, 2025
9585ba4
update optuna implementation
JanTeichertKluge Oct 30, 2025
8318eee
update optuna implementation
JanTeichertKluge Oct 30, 2025
65a8ebd
update optuna implementation
JanTeichertKluge Oct 30, 2025
9e60a62
update optuna implementation
JanTeichertKluge Oct 30, 2025
906e01b
update optuna implementation, add unit tests
JanTeichertKluge Oct 30, 2025
13a7fd1
adjust PLIV optuna tuning
JanTeichertKluge Oct 30, 2025
0acec3e
update for single set of hyperparams instead of fold-specifics
JanTeichertKluge Nov 10, 2025
21825f8
update tests
JanTeichertKluge Nov 10, 2025
4610ef5
Merge branch 'main' into j-optuna
JanTeichertKluge Nov 10, 2025
826236a
del cache files
JanTeichertKluge Nov 10, 2025
fa58939
revert estimation.py since everything optuna related is moved to sep.…
JanTeichertKluge Nov 10, 2025
f6287ee
rename `params` to `ml_param_space` in tune_optuna method
JanTeichertKluge Nov 10, 2025
3a28380
update ´cv´ object handling, update for partial tuning
JanTeichertKluge Nov 10, 2025
d5c9586
renaming tune_optuna to tune_ml_models
JanTeichertKluge Nov 10, 2025
21164da
change return type of nuisance tuning methods to container objects only
JanTeichertKluge Nov 10, 2025
f2c759b
formatting
JanTeichertKluge Nov 10, 2025
4914b4e
change print statements to logging statements in _tune_optuna.py
JanTeichertKluge Nov 10, 2025
70501ec
add depreciation warning for tune method in doubleml.py
JanTeichertKluge Nov 10, 2025
7bf2b1e
add depreciation warning for tune method in doubleml.py
JanTeichertKluge Nov 10, 2025
cbe1a8c
add optuna to dev dependencies in pyproject.toml
JanTeichertKluge Nov 10, 2025
0752b18
revert changes in std. tune method
JanTeichertKluge Nov 10, 2025
129f9aa
simplify optuna tuning logic
JanTeichertKluge Nov 10, 2025
3b04016
update docstrings
JanTeichertKluge Nov 10, 2025
0a11bcf
revert changes from first implementation of optuna tuning
JanTeichertKluge Nov 10, 2025
c46afcd
formatting
JanTeichertKluge Nov 10, 2025
90de845
formatting, reverting changes from previous implementation.
JanTeichertKluge Nov 10, 2025
20aabd8
add pseudo method to util class in tests
JanTeichertKluge Nov 10, 2025
a8940bc
update optuna setting validation
JanTeichertKluge Nov 11, 2025
4e9de2b
fix docstring
JanTeichertKluge Nov 11, 2025
9915be3
add doctest execution to pytest job in CI workflow
SvenKlaassen Nov 11, 2025
9aa36d8
fix tune_ml_models docsting
SvenKlaassen Nov 11, 2025
ae19cf0
refactor scoring methods resolution into a separate method
SvenKlaassen Nov 11, 2025
fc73db5
refactor optuna settings validation logic
SvenKlaassen Nov 11, 2025
d556627
refactor _create_objective function by removing unused learner_name p…
SvenKlaassen Nov 11, 2025
ff81970
fix the input for objective; remove learner_label
SvenKlaassen Nov 11, 2025
a838bad
remove unnecessary test
SvenKlaassen Nov 11, 2025
7c45a89
simplify set param logic
SvenKlaassen Nov 11, 2025
9d070f0
remove sampler test
SvenKlaassen Nov 11, 2025
edecd4e
move optuna dependency to the main dependencies section
SvenKlaassen Nov 11, 2025
8ec446d
update optuna tuning methods to not use n_jobs_cv since optuna has it…
JanTeichertKluge Nov 12, 2025
0528e97
adjsut tests
JanTeichertKluge Nov 12, 2025
4a483e9
update test cases
JanTeichertKluge Nov 12, 2025
d0062de
adjust params / learner specific settings / param spaces merging
JanTeichertKluge Nov 13, 2025
4ed93d1
adjust params / learner specific settings / param spaces merging
JanTeichertKluge Nov 13, 2025
d490274
fix messages in _resolve_optuna_scoring
JanTeichertKluge Nov 13, 2025
0de85a9
update tests for optuna
JanTeichertKluge Nov 13, 2025
966d306
update DMLOptunaResult import DMLOptunaResult
JanTeichertKluge Nov 13, 2025
7137511
bug fix
JanTeichertKluge Nov 13, 2025
61bd137
adjust naming, differentiate between learner name and params name
JanTeichertKluge Nov 13, 2025
ff4c878
fix best_params_ change to best_params
JanTeichertKluge Nov 13, 2025
9445cd3
run pre-commit
JanTeichertKluge Nov 13, 2025
f2697c1
fix setting-merge bug in _get_optuna_settings
JanTeichertKluge Nov 13, 2025
477fdb5
fix issue for joining param_spaces
JanTeichertKluge Nov 13, 2025
0e7e7c3
add docstring for class DoubleMLAPOS
JanTeichertKluge Nov 13, 2025
e283e85
SKIP docstring test output for model classes, adjust optuna tuning
JanTeichertKluge Nov 13, 2025
f4bcef3
fix unit tests to reduce computation time
JanTeichertKluge Nov 13, 2025
84dfa64
SKIP docstring test output for doubleml model
JanTeichertKluge Nov 13, 2025
0913b1b
refactor _get_optuna_settings to simplify learner-specific settings r…
SvenKlaassen Nov 14, 2025
4274f85
validate optuna settings to ensure keys are dictionaries
SvenKlaassen Nov 14, 2025
d6aa3a2
skip doctest output for tuning results in DoubleML class
SvenKlaassen Nov 14, 2025
56dc711
add random seed initialization in test_solve_quadratic_inequation for…
SvenKlaassen Nov 14, 2025
40f2262
remove unused import of _dml_tune_optuna in DoubleMLPLIV class
SvenKlaassen Nov 14, 2025
37edf6b
ensure all parameter spaces are returned
SvenKlaassen Nov 14, 2025
303a1bb
add missing newline in DMLOptunaResult docstring
SvenKlaassen Nov 14, 2025
1ab7843
update deprecation notice for tune specify version 0.13.0
SvenKlaassen Nov 14, 2025
448c959
enhance DMLOptunaResult docstring with detailed attributes and examples
SvenKlaassen Nov 14, 2025
fe26cdf
simplify unit tests for optuna
JanTeichertKluge Nov 14, 2025
17f8710
Merge branch 'j-optuna' of https://github.com/DoubleML/doubleml-for-p…
JanTeichertKluge Nov 14, 2025
87911b8
update tests for optuna tuning
JanTeichertKluge Nov 14, 2025
7d0864c
first implementation of tune_ml_models for APOS
JanTeichertKluge Nov 14, 2025
5923723
formatting
JanTeichertKluge Nov 14, 2025
b114c87
test shared docstring for tune_ml_models method(s)
JanTeichertKluge Nov 14, 2025
daa5e25
remove fixed random_state from KFold in resolve_optuna_cv for more fl…
SvenKlaassen Nov 14, 2025
6124ca6
Merge branch 'j-optuna' of https://github.com/DoubleML/doubleml-for-p…
JanTeichertKluge Nov 17, 2025
2c3cc71
add optuna tuning to QTE and DiDMulti
JanTeichertKluge Nov 17, 2025
40d3c3a
add unit tests for wrapper models
JanTeichertKluge Nov 17, 2025
3a7140d
create re-usable docstring for tune_ml_models methods
JanTeichertKluge Nov 17, 2025
c03d86c
fix issue for using tuned models from output container
JanTeichertKluge Nov 17, 2025
fe9ab58
remove exampole
JanTeichertKluge Nov 17, 2025
97f1c8c
check docstring tests for tune_ml_models methods
JanTeichertKluge Nov 17, 2025
d5bb48b
remove double import
JanTeichertKluge Nov 17, 2025
ae2f263
adjust preliminary theta estimate for optuna tuning
JanTeichertKluge Nov 17, 2025
2685725
adjust pre-docstring for shared docstring in tune_ml_models
JanTeichertKluge Nov 17, 2025
98a6643
fix partial_z naming
SvenKlaassen Nov 17, 2025
7ae51a4
refactor: replace pseudo_target with cross_val_predict for improved p…
SvenKlaassen Nov 17, 2025
6ca20fa
fix code warning
SvenKlaassen Nov 17, 2025
85595d4
remove unnecessary test
SvenKlaassen Nov 17, 2025
342c83b
remove test for tune_res.predict since estimator is not fitted anymore
JanTeichertKluge Nov 17, 2025
349095b
Merge branch 'j-optuna' of https://github.com/DoubleML/doubleml-for-p…
JanTeichertKluge Nov 17, 2025
205b705
fix typo in tune_ml_model docstring
JanTeichertKluge Nov 17, 2025
f885492
reduce trials in docstring for tune_ml_models method
JanTeichertKluge Nov 17, 2025
9dcefcc
fix workflow to check only docstrings and tests marked as ci
JanTeichertKluge Nov 17, 2025
0034969
Merge branch 'main' into j-optuna
SvenKlaassen Nov 24, 2025
042e76c
add _nuisance_tuning_optuna placeholder for DoubleMLLPLR class
SvenKlaassen Nov 24, 2025
731e458
remove usage of `study.set_metric_names` since it causes warnings in …
JanTeichertKluge Nov 24, 2025
896d117
update docstring since we allowing for mixed specifications in parame…
JanTeichertKluge Nov 24, 2025
67e7481
change arg `force_all_finite` to `ensure_all_finite` in check_X_y cal…
JanTeichertKluge Nov 24, 2025
270ac04
change arg `force_all_finite` to `ensure_all_finite` for logistic mod…
JanTeichertKluge Nov 24, 2025
43359d0
add _tune_optuna for lplr
SvenKlaassen Nov 26, 2025
970a0b6
refactor: move Optuna utility functions to a new module for better or…
SvenKlaassen Nov 26, 2025
350c87d
move tuning test into submodules
SvenKlaassen Nov 26, 2025
4b5631d
fix M_hat tuning calculation
SvenKlaassen Nov 26, 2025
7cff553
add pytest.mark.ci to various tuning test cases for continuous integr…
SvenKlaassen Nov 26, 2025
2308750
refactor: update model parameters in tuning test
SvenKlaassen Nov 26, 2025
02c3535
test: add NotImplementedError test for nonignorable score in DoubleMLSSM
SvenKlaassen Nov 26, 2025
fd0bf1b
remove check for tune_optuna and nonignorable score in SSM.py
SvenKlaassen Nov 26, 2025
e764d95
refactor: update min_samples_leaf and max_leaf_nodes for DecisionTree…
SvenKlaassen Nov 26, 2025
81799fd
update tests for overfitting
SvenKlaassen Nov 26, 2025
4b9d96b
refactor: add min_samples_split and max_depth parameters to DecisionT…
SvenKlaassen Nov 26, 2025
2f8b1c6
fix: increase n_trials from 5 to 10 in basic Optuna settings
SvenKlaassen Nov 26, 2025
56c32bf
update cvar setting to allow more overfitting
SvenKlaassen Nov 26, 2025
29b7fd7
standardize tuning tests
SvenKlaassen Nov 27, 2025
ff5455d
update tuning tests: increase n_obs to 500, dgp_type to 4, and set n_…
SvenKlaassen Nov 27, 2025
6ba4c30
formatting
SvenKlaassen Nov 27, 2025
80c731f
update treatment assignment logic in DGP functions: update probabilit…
SvenKlaassen Dec 3, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ jobs:
matrix.config.os != 'ubuntu-latest' ||
matrix.config.python-version != '3.12'
run: |
pytest --doctest-modules --ignore-glob="doubleml/**/tests/*" --ignore-glob="doubleml/tests/*"
pytest -m ci
pytest -m ci_rdd
Expand Down
1 change: 1 addition & 0 deletions doubleml/data/tests/test_dml_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,7 @@ def test_dml_data_w_missings(generate_data_irm_w_missings):
assert dml_data.force_all_x_finite == "allow-nan"


@pytest.mark.ci
def test_dml_data_w_missing_d(generate_data1):
data = generate_data1
np.random.seed(3141)
Expand Down
41 changes: 34 additions & 7 deletions doubleml/did/datasets/dgp_did_CS2021.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,35 @@ def make_did_CS2021(n_obs=1000, dgp_type=1, include_never_treated=True, time_typ
6. Treatment assignment:
For non-experimental settings (DGP 1-4), the probability of being in treatment group :math:`g` is:
For non-experimental settings (DGP 1-4), the probability of being in treatment group :math:`g` is computed as follows:
.. math::
- Compute group-specific logits for each observation:
.. math::
\\text{logit}_{i,g} = f_{ps,g}(W_{ps})
The logits are clipped to the range [-2.5, 2.5] for numerical stability.
- Convert logits to uncapped probabilities via softmax:
.. math::
p^{\\text{uncapped}}_{i,g} = \\frac{\\exp(\\text{logit}_{i,g})}{\\sum_{g'} \\exp(\\text{logit}_{i,g'})}
- Clip uncapped probabilities to the range [0.05, 0.95]:
.. math::
p^{\\text{clipped}}_{i,g} = \\min(\\max(p^{\\text{uncapped}}_{i,g}, 0.05), 0.95)
- Renormalize clipped probabilities so they sum to 1 for each observation:
.. math::
p_{i,g} = \\frac{p^{\text{clipped}}_{i,g}}{\\sum_{g'} p^{\\text{clipped}}_{i,g'}}
P(G_i = g) = \\frac{\\exp(f_{ps,g}(W_{ps}))}{\\sum_{g'} \\exp(f_{ps,g'}(W_{ps}))}
- Assign each observation to a treatment group by sampling from the categorical distribution defined by :math:`p_{i,g}`.
For experimental settings (DGP 5-6), each treatment group (including never-treated) has equal probability:
Expand Down Expand Up @@ -159,7 +183,7 @@ def make_did_CS2021(n_obs=1000, dgp_type=1, include_never_treated=True, time_typ
`dim_x` (int, default=4):
Dimension of feature vectors.
`xi` (float, default=0.9):
`xi` (float, default=0.5):
Scale parameter for the propensity score function.
`n_periods` (int, default=5):
Expand Down Expand Up @@ -188,7 +212,7 @@ def make_did_CS2021(n_obs=1000, dgp_type=1, include_never_treated=True, time_typ

c = kwargs.get("c", 0.0)
dim_x = kwargs.get("dim_x", 4)
xi = kwargs.get("xi", 0.9)
xi = kwargs.get("xi", 0.75)
n_periods = kwargs.get("n_periods", 5)
anticipation_periods = kwargs.get("anticipation_periods", 0)
n_pre_treat_periods = kwargs.get("n_pre_treat_periods", 2)
Expand Down Expand Up @@ -228,8 +252,11 @@ def make_did_CS2021(n_obs=1000, dgp_type=1, include_never_treated=True, time_typ
p = np.ones(n_treatment_groups) / n_treatment_groups
d_index = np.random.choice(n_treatment_groups, size=n_obs, p=p)
else:
unnormalized_p = np.exp(_f_ps_groups(features_ps, xi, n_groups=n_treatment_groups))
p = unnormalized_p / unnormalized_p.sum(1, keepdims=True)
logits = np.clip(_f_ps_groups(features_ps, xi, n_groups=n_treatment_groups), a_min=-2.5, a_max=2.5)
unnormalized_p = np.exp(logits)
p_uncapped = unnormalized_p / unnormalized_p.sum(1, keepdims=True)
p_clipped = np.clip(p_uncapped, a_min=0.05, a_max=0.95)
p = p_clipped / p_clipped.sum(1, keepdims=True)
d_index = np.array([np.random.choice(n_treatment_groups, p=p_row) for p_row in p])

# fixed effects (shape (n_obs, n_time_periods))
Expand Down
32 changes: 28 additions & 4 deletions doubleml/did/datasets/dgp_did_cs_CS2021.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,35 @@ def make_did_cs_CS2021(n_obs=1000, dgp_type=1, include_never_treated=True, lambd
6. Treatment assignment:
For non-experimental settings (DGP 1-4), the probability of being in treatment group :math:`g` is:
For non-experimental settings (DGP 1-4), the probability of being in treatment group :math:`g` is computed as follows:
.. math::
- Compute group-specific logits for each observation:
.. math::
\\text{logit}_{i,g} = f_{ps,g}(W_{ps})
The logits are clipped to the range [-2.5, 2.5] for numerical stability.
- Convert logits to uncapped probabilities via softmax:
.. math::
p^{\\text{uncapped}}_{i,g} = \\frac{\\exp(\\text{logit}_{i,g})}{\\sum_{g'} \\exp(\\text{logit}_{i,g'})}
- Clip uncapped probabilities to the range [0.05, 0.95]:
.. math::
p^{\\text{clipped}}_{i,g} = \\min(\\max(p^{\\text{uncapped}}_{i,g}, 0.05), 0.95)
- Renormalize clipped probabilities so they sum to 1 for each observation:
.. math::
p_{i,g} = \\frac{p^{\text{clipped}}_{i,g}}{\\sum_{g'} p^{\\text{clipped}}_{i,g'}}
P(G_i = g) = \\frac{\\exp(f_{ps,g}(W_{ps}))}{\\sum_{g'} \\exp(f_{ps,g'}(W_{ps}))}
- Assign each observation to a treatment group by sampling from the categorical distribution defined by :math:`p_{i,g}`.
For experimental settings (DGP 5-6), each treatment group (including never-treated) has equal probability:
Expand Down Expand Up @@ -148,7 +172,7 @@ def make_did_cs_CS2021(n_obs=1000, dgp_type=1, include_never_treated=True, lambd
`dim_x` (int, default=4):
Dimension of feature vectors.
`xi` (float, default=0.9):
`xi` (float, default=0.5):
Scale parameter for the propensity score function.
`n_periods` (int, default=5):
Expand Down
78 changes: 78 additions & 0 deletions doubleml/did/did.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from doubleml.double_ml_score_mixins import LinearScoreMixin
from doubleml.utils._checks import _check_finite_predictions, _check_is_propensity, _check_score
from doubleml.utils._estimation import _dml_cv_predict, _dml_tune, _get_cond_smpls
from doubleml.utils._tune_optuna import _dml_tune_optuna


# TODO: Remove DoubleMLDIDData with version 0.12.0
Expand Down Expand Up @@ -427,6 +428,83 @@ def _nuisance_tuning(

return res

def _nuisance_tuning_optuna(
self,
optuna_params,
scoring_methods,
cv,
optuna_settings,
):
"""
Optuna-based hyperparameter tuning for DID nuisance models.

Performs tuning once on the whole dataset using cross-validation,
returning the same optimal parameters for all folds.
"""

x, y = check_X_y(self._dml_data.x, self._dml_data.y, ensure_all_finite=False)
x, d = check_X_y(x, self._dml_data.d, ensure_all_finite=False)

if scoring_methods is None:
if self.score == "observational":
scoring_methods = {"ml_g0": None, "ml_g1": None, "ml_m": None}
else:
scoring_methods = {"ml_g0": None, "ml_g1": None}

# Separate data by treatment status for conditional mean tuning
mask_d0 = d == 0
mask_d1 = d == 1

x_d0 = x[mask_d0, :]
y_d0 = y[mask_d0]
g0_tune_res = _dml_tune_optuna(
y_d0,
x_d0,
self._learner["ml_g"],
optuna_params["ml_g0"],
scoring_methods["ml_g0"],
cv,
optuna_settings,
learner_name="ml_g",
params_name="ml_g0",
)

x_d1 = x[mask_d1, :]
y_d1 = y[mask_d1]
g1_tune_res = _dml_tune_optuna(
y_d1,
x_d1,
self._learner["ml_g"],
optuna_params["ml_g1"],
scoring_methods["ml_g1"],
cv,
optuna_settings,
learner_name="ml_g",
params_name="ml_g1",
)

# Tune propensity score on full dataset for observational score
m_tune_res = None
if self.score == "observational":
m_tune_res = _dml_tune_optuna(
d,
x,
self._learner["ml_m"],
optuna_params["ml_m"],
scoring_methods["ml_m"],
cv,
optuna_settings,
learner_name="ml_m",
params_name="ml_m",
)

if self.score == "observational":
results = {"ml_g0": g0_tune_res, "ml_g1": g1_tune_res, "ml_m": m_tune_res}
else:
results = {"ml_g0": g0_tune_res, "ml_g1": g1_tune_res}

return results

def sensitivity_benchmark(self, benchmarking_set, fit_args=None):
"""
Computes a benchmark for a given set of features.
Expand Down
74 changes: 74 additions & 0 deletions doubleml/did/did_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
_check_score,
)
from doubleml.utils._estimation import _dml_cv_predict, _dml_tune, _get_cond_smpls
from doubleml.utils._tune_optuna import _dml_tune_optuna
from doubleml.utils.propensity_score_processing import PSProcessorConfig, init_ps_processor


Expand Down Expand Up @@ -666,6 +667,79 @@ def _nuisance_tuning(

return res

def _nuisance_tuning_optuna(
self,
optuna_params,
scoring_methods,
cv,
optuna_settings,
):

x, y = check_X_y(self._x_data_subset, self._y_data_subset, ensure_all_finite=False)
x, d = check_X_y(x, self._g_data_subset, ensure_all_finite=False)

if scoring_methods is None:
if self.score == "observational":
scoring_methods = {"ml_g0": None, "ml_g1": None, "ml_m": None}
else:
scoring_methods = {"ml_g0": None, "ml_g1": None}

mask_d0 = d == 0
mask_d1 = d == 1

x_d0 = x[mask_d0, :]
y_d0 = y[mask_d0]
g0_param_grid = optuna_params["ml_g0"]
g0_scoring = scoring_methods["ml_g0"]
g0_tune_res = _dml_tune_optuna(
y_d0,
x_d0,
self._learner["ml_g"],
g0_param_grid,
g0_scoring,
cv,
optuna_settings,
learner_name="ml_g",
params_name="ml_g0",
)

x_d1 = x[mask_d1, :]
y_d1 = y[mask_d1]
g1_param_grid = optuna_params["ml_g1"]
g1_scoring = scoring_methods["ml_g1"]
g1_tune_res = _dml_tune_optuna(
y_d1,
x_d1,
self._learner["ml_g"],
g1_param_grid,
g1_scoring,
cv,
optuna_settings,
learner_name="ml_g",
params_name="ml_g1",
)

m_tune_res = None
if self.score == "observational":
m_tune_res = _dml_tune_optuna(
d,
x,
self._learner["ml_m"],
optuna_params["ml_m"],
scoring_methods["ml_m"],
cv,
optuna_settings,
learner_name="ml_m",
params_name="ml_m",
)

if self.score == "observational":
results = {"ml_g0": g0_tune_res, "ml_g1": g1_tune_res, "ml_m": m_tune_res}
else:
results = {"ml_g0": g0_tune_res, "ml_g1": g1_tune_res}

return results

def _sensitivity_element_est(self, preds):
y = self._y_data_subset
d = self._g_data_subset
Expand Down
Loading
Loading