Skip to content

Commit

Permalink
r.learn.ml2 fix issue cross-validation with newer scikit-learn versio…
Browse files Browse the repository at this point in the history
…ns (#704)

* Fix r.learn.ml2 test_append error caused by attempting remove temporary map from another mapset

* Update r.learn.train.py

Fix issue with using > 3 positional arguments for cross_val_predict

* Black formatting changes
  • Loading branch information
stevenpawley committed Feb 26, 2022
1 parent d050707 commit 99f15f0
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/raster/r.learn.ml2/r.learn.train/r.learn.train.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,13 @@ def main():
from sklearn.model_selection import cross_val_predict

preds = cross_val_predict(
estimator, X, y, group_id, cv=outer, n_jobs=n_jobs, fit_params=fit_params
estimator=estimator,
X=X,
y=y,
groups=group_id,
cv=outer,
n_jobs=n_jobs,
fit_params=fit_params,
)

test_idx = [test for train, test in outer.split(X, y)]
Expand Down

0 comments on commit 99f15f0

Please sign in to comment.