Extract y and yhat for each test fold from results of evaluate! #575
Well, you can give …
Thanks! That's a pretty good solution. However, I think it would be worth adding a method to the MLJ API that does this out of the box. Something like the following could be pretty useful:

```julia
# Per-fold predictions and ground truth collected during cross-validation
struct CVPredictions{T}
    ŷ::Vector{<:AbstractVector{T}}   # predictions on each test fold
    y::Vector{<:AbstractVector{T}}   # ground truth on each test fold
end

function cv_predict(model, X, y)
    # fit on each training fold, predict on the matching test fold,
    # and return a CVPredictions object
end

function cv_predict!(mach)
    # same as above, for an existing machine; returns a CVPredictions object
end

function evaluate(cvp::CVPredictions; measure)
    # evaluate `measure` on each fold
    # return the model evaluation
end

function evaluate(model, X, y; measure)
    cvp = cv_predict(model, X, y)
    evaluate(cvp; measure = measure)
end

function evaluate!(mach; measure)
    cvp = cv_predict!(mach)
    evaluate(cvp; measure = measure)
end

export cv_predict, cv_predict!, evaluate, evaluate!
```

**Motivating example**

Here's one example where it would be nice to have the separate …

To extend the example, suppose I define a measure like this:

```julia
cost = let
    cost_matrix = [0 10;
                   100 0]
    function cost(ŷ, y)
        confusion = confusion_matrix(ŷ, y)
        sum(confusion .* cost_matrix)
    end
end
```

Then if I later decide that I want to change the cost matrix, it would be nice if I could just run …
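To make the benefit of the proposal concrete, here is a minimal, self-contained sketch in plain Julia (no MLJ) of re-scoring stored per-fold predictions under two different cost matrices without refitting anything. Everything here is hypothetical: `CVPredictions` is a simplified stand-in for the proposed type, and `confusion_counts`, `make_cost` and `evaluate_folds` are illustrative helpers, not MLJ functions. Classes are assumed to be encoded as integers `1..k`, and the toy fold data is made up.

```julia
# Simplified stand-in for the proposed CVPredictions type:
# one vector of predictions and one of ground truth per test fold.
struct CVPredictions{T}
    ŷ::Vector{Vector{T}}
    y::Vector{Vector{T}}
end

# Hypothetical helper: counts[i, j] = number of observations predicted
# as class i whose true class is j (classes encoded as 1..k).
function confusion_counts(ŷ, y, k)
    counts = zeros(Int, k, k)
    for (p, t) in zip(ŷ, y)
        counts[p, t] += 1
    end
    return counts
end

# Build a cost measure from a cost matrix; the matrix is captured in a closure.
make_cost(cost_matrix) =
    (ŷ, y) -> sum(confusion_counts(ŷ, y, size(cost_matrix, 1)) .* cost_matrix)

# Apply a measure to every stored fold; no model is refitted here.
evaluate_folds(cvp::CVPredictions, measure) =
    [measure(ŷ, y) for (ŷ, y) in zip(cvp.ŷ, cvp.y)]

# Toy predictions for two test folds (made-up data for illustration).
cvp = CVPredictions([[1, 2, 2], [1, 1]], [[1, 1, 2], [2, 1]])

# Changing the cost matrix only re-scores the stored predictions.
per_fold_a = evaluate_folds(cvp, make_cost([0 10; 100 0]))
per_fold_b = evaluate_folds(cvp, make_cost([0 1; 5 0]))
```

The design point is that the expensive step (fitting on each training fold) happens once when the `CVPredictions` object is built, while any number of measures can be applied to it afterwards.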
So basically you just want to insert a new interface point. Sounds like a good idea. A few comments:
Sometimes one wants to look at the actual and predicted y values for each test fold in a cross-validation. For example, one might want to make a plot of the residuals versus the predicted values. As far as I can tell, there's not an easy way to do that right now.
This is mentioned in #89, but I thought it would be good to have a more specific issue.
The scikit-learn equivalent of this feature request is `cross_val_predict()`.
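For reference, the behaviour being requested (out-of-fold predictions kept per fold, as with scikit-learn's `cross_val_predict()`) can be sketched in plain Julia without MLJ. This is only an illustration under stated assumptions: `kfold_indices`, `fit_line` and `cv_predict_sketch` are hypothetical helpers, the folds are contiguous and unshuffled, and an ordinary least-squares line stands in for an arbitrary model.

```julia
# Hypothetical helper: split 1:n into k contiguous test folds (no shuffling).
function kfold_indices(n, k)
    edges = round.(Int, range(0, n; length = k + 1))
    return [(edges[i] + 1):edges[i + 1] for i in 1:k]
end

# Fit y ≈ a + b*x by least squares and return a prediction function.
function fit_line(x, y)
    A = [ones(length(x)) x]
    coef = A \ y
    return xnew -> coef[1] .+ coef[2] .* xnew
end

# For each fold: train on the other folds, predict on the held-out fold,
# and keep the predictions and ground truth for that fold separately.
function cv_predict_sketch(x, y, k)
    ŷs = Vector{Vector{Float64}}()
    ys = Vector{Vector{Float64}}()
    for test in kfold_indices(length(x), k)
        train = setdiff(1:length(x), test)
        predictor = fit_line(x[train], y[train])
        push!(ŷs, predictor(x[test]))
        push!(ys, y[test])
    end
    return ŷs, ys
end

x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
y = 2 .* x .+ 1                      # exactly linear toy data
ŷs, ys = cv_predict_sketch(x, y, 3)

# Per-fold residuals, e.g. for a residuals-vs-predicted plot.
residuals = [yf .- ŷf for (yf, ŷf) in zip(ys, ŷs)]
```

Because the toy data is exactly linear, every residual is zero up to floating-point error; with real data, pairing `residuals[i]` with `ŷs[i]` gives the residuals-versus-predicted view for fold `i`.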