early_stopping_rounds? #158

Open
Moelf opened this issue Jan 11, 2023 · 18 comments

Comments

@Moelf
Contributor

Moelf commented Jan 11, 2023

I don't see it being explicitly tested or supported, but if it's part of the C++ library it should just work?

@Moelf
Contributor Author

Moelf commented Jan 11, 2023

[screenshot omitted]

:(

@ExpandingMan
Collaborator

I don't see that parameter listed here... what makes you think it exists?

Also, I'm not sure what the difference is supposed to be between this and num_rounds.

@Moelf
Contributor Author

Moelf commented Jan 11, 2023

yeah, I think unfortunately it's not part of the model parameters

num_rounds is, in some sense, the maximum number of rounds; early_stopping_rounds is the number of rounds without improvement after which the fit halts
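
For reference, the stopping rule itself is tiny; a sketch in plain Julia (should_stop, history, and patience are purely illustrative names, nothing from XGBoost.jl):

```julia
# Illustrative only: stop once the validation metric has not improved
# for `patience` consecutive rounds (assuming lower is better, e.g. rmse).
function should_stop(history::Vector{Float64}, patience::Integer)
    length(history) > patience || return false
    return length(history) - argmin(history) >= patience
end
```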

@Moelf
Contributor Author

Moelf commented Jan 11, 2023

btw, related question: is it possible to error when unsupported keywords get passed to XGBoost? For example early_stopping_rounds above; because it didn't error, I keep worrying about whether I spelled colsample_bytree correctly

@ExpandingMan
Collaborator

btw, related question: is it possible to error when unsupported keywords get passed to XGBoost?

I would really like to; this has been an annoyance of mine as well, and I've been burned by it many times.
Unfortunately, right now there are no good options: the C API does not return an error code for invalid parameters, and the only way to retrieve a list of valid parameters is from the documentation or source code (i.e. there is no way to incorporate it into XGBoost.jl without a maintenance nightmare that would cause more harm than good).

I just opened an issue for this.

@bobaronoff

From past work in R, early stopping is a convenience feature provided by the package rather than a libxgboost function. The feature relies on capturing the evaluation log and applying a criterion for when to stop. The optimum num_round value can change with other parameters, so it is difficult to tune num_round independently of the others. There are implications for how one does a grid search.

If it's important enough to a workflow, it is possible to write one's own early stopping function. In order to plot learning curves from cross validation, without duplicating the metrics already available in libxgboost, it was necessary to 'intercept' and parse the evaluation log. I posted my modest attempt (i.e. hack) in #148. It has been working quite well, although I am sure it could be improved. All are welcome to use and adapt it to their needs. For early stopping one would need to modify it and parse each line of the log as it is recorded; the posted function parses all the lines together after the run.

Using this approach I've been able to create a function to duplicate the cv learning curves one is accustomed to in R (if that is one's workflow). An example ....

[example cross-validation learning curve plot omitted]

@Moelf
Contributor Author

Moelf commented Jan 19, 2023

yeah, well, we have to do this externally in Julia because we can't pass a callable to libxgboost for evaluation set evals

@ExpandingMan
Collaborator

This isn't that hard to do with current XGBoost.jl, but yes, it is annoying that one can't access the evaluation metrics it's already computing via watchlist and instead has to compute them separately.

I'd be happy to help get a PR through that makes this more ergonomic. I'm open to replacing the watchlist mechanism completely, though that might be pretty tough without breaking changes.
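
For concreteness, doing it externally might look roughly like the sketch below. It assumes update!(b, dtrain) advances the booster by exactly one round and that predict(b, dtest) returns the current test-set predictions (check the actual signatures in booster.jl); fit_early_stop and patience are illustrative names:

```julia
using XGBoost, Statistics

# Sketch of external early stopping with the existing API (see assumptions above).
function fit_early_stop(dtrain::DMatrix, dtest::DMatrix, ytest::AbstractVector;
                        max_rounds::Integer=500, patience::Integer=10, kw...)
    b = Booster(dtrain; kw...)
    history = Float64[]
    for r in 1:max_rounds
        update!(b, dtrain)                                    # one boosting round
        rmse = sqrt(mean((predict(b, dtest) .- ytest) .^ 2))  # hand-rolled metric
        push!(history, rmse)
        if length(history) - argmin(history) >= patience
            @info "stopping early" round=r best_round=argmin(history)
            break
        end
    end
    return b, history
end
```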

@bobaronoff

Not sure I understand. I am able to set the evaluation metrics with the eval_metric keyword parameter. The results show up in the log and are parseable. It does require a bit of an end-around, writing a personal version of one or two of the functions in booster.jl, but the crux of it is only a handful of lines.

@bobaronoff

Recopied from #148. This is my function to capture the logs.

```julia
using XGBoost

# Run num_round boosting iterations while capturing the evaluation log
# (the strings libxgboost would normally print) into a vector.
# `parsethelog` is the log-parsing helper posted in #148.
function xgboost_log(traindm::DMatrix, a...;
                        testdm::Any=[],
                        num_round::Integer=10,
                        kw...
                    )

    Xy = XGBoost.DMatrix(traindm)
    b = XGBoost.Booster(Xy; kw...)
    update_feature_names::Bool = false
    if testdm isa DMatrix
        watchlist = Dict("train"=>traindm, "test"=>testdm)
    else
        watchlist = Dict("train"=>traindm)
    end
    names = collect(Iterators.map(string, keys(watchlist)))
    watch = collect(Iterators.map(x -> x.handle, values(watchlist)))
    thelog = Vector{String}(undef, 0)
    for j in 1:num_round
        # one boosting round
        XGBoost.xgbcall(XGBoost.XGBoosterUpdateOneIter, b.handle, j, Xy.handle)
        # evaluate on the watchlist and keep the raw log line
        o = Ref{Ptr{Int8}}()
        XGBoost.xgbcall(XGBoost.XGBoosterEvalOneIter, b.handle, j, watch, names, length(watch), o)
        push!(thelog, unsafe_string(o[]))
        XGBoost._maybe_update_feature_names!(b, Xy, update_feature_names)
    end
    return (booster=b, log=parsethelog(thelog))
end
```
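
To turn this into early stopping, one could parse each line as it is recorded and break out of the loop. A rough sketch of that check, assuming log lines of the usual libxgboost form `[3]\ttrain-rmse:0.41\ttest-rmse:0.52` and a lower-is-better metric (testmetric, history, and patience are illustrative names):

```julia
# Illustrative: pull the test metric out of one eval-log line.
function testmetric(line::AbstractString, metric::AbstractString="rmse")
    m = match(Regex("test-$metric:([0-9.eE+-]+)"), line)
    return m === nothing ? nothing : parse(Float64, m.captures[1])
end

# Inside the round loop of xgboost_log one might then do something like:
#     v = testmetric(last(thelog))
#     v === nothing || push!(history, v)
#     !isempty(history) && length(history) - argmin(history) >= patience && break
```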

@ExpandingMan
Collaborator

I'm not enthused about the prospect of opening the Pandora's box of parsing log output; it's bad enough that we are already trying to do this to identify warning messages. In my opinion it would be much nicer to just have it run evaluation normally so the whole thing is accessible and controllable. That said, I wouldn't stand in the way of it, particularly if you left it flexible enough to allow later improvements; it could be replaced with a non-parsing version at some point in the future.

@bobaronoff

I am not advocating to add this to the package, just offering a solution to those who 'need' the functionality. I am not a programmer and am not sophisticated enough to know the downsides. I needed this functionality and did what I could to make it work. The only difference between my function and what is already in booster.jl is that instead of sending a string to INFO, it is saved in a vector. The rest is just combining three existing booster.jl functions into one. The R package does essentially the same thing but has a callback system. Creating that sort of thing is over my head, so I opted for what I could manage. It is working quite well for me. Parsing turned out to be easier than I feared, even with multiple metrics. Of course, if libxgboost changes their reporting string structure I will have to revise.

Before going this route I did write my own evaluation routines, but it became unmanageable. Each objective has its own set of possible metrics, and after 3 or 4 it became tedious; there are a lot of objective/metric pairs. I am still trying to get the hang of MLJ, and I suppose that is out there for the user. Since this approach uses libxgboost for the metrics, everything is in line with their documentation.

I am quite fine with how you have XGBoost.jl structured. I've learned where most things are; if I need a wrinkle here or there I am happy to personalize my functions; I thought that was kind of the Julia way??? You do an excellent job maintaining this very important package and providing prompt feedback to users. I thank you!!!!

@Moelf
Contributor Author

Moelf commented Jan 20, 2023

actually

it's bad enough that we are already trying to do this to identify warning messages,

I'm not sure if that's for or against the idea of parsing logs to get loss during training; I mean, given we already do this, we might as well utilize the existing infrastructure :)

@ExpandingMan
Collaborator

I'm not sure if that's for or against the idea of parsing logs to get loss during training; I mean, given we already do this, we might as well utilize the existing infrastructure :)

It's not infrastructure, it's one badly written regex, and it's done because there is no alternative. For evaluations there is an alternative; there is no clear reason why this needs to be done via the watchlist arg, since we already have plenty of library functions available for doing it.

Again, I'm not going to stand in the way of a PR that provides parsing of the logs to extract evaluation results as a feature, but I think we should keep it flexible enough that it could be made more reasonable in the future.

@Moelf
Contributor Author

Moelf commented Jan 20, 2023

one reason not to do it in Julia would be performance; also keep in mind we'd have to somehow evaluate it with the same metric as the training and handle GPU, etc.

@ExpandingMan
Collaborator

one reason not to do it in Julia would be performance

Why should it be more expensive to run via a separate call to the predict function? The overhead should be minuscule. If that's not the case, something is very wrong.

also keep in mind we'd have to somehow evaluate it with the same metric as the training and handle GPU, etc.

We already have calls to do all this stuff; it shouldn't be hard. I don't think it should be more complicated than providing a function argument to update! and providing reasonable defaults.
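
Roughly what I have in mind, as a sketch only (update_with_callback! is a made-up name, not existing API; the internals are borrowed from the snippet posted above):

```julia
using XGBoost

# Hypothetical sketch, not current XGBoost.jl API: an update! variant that
# takes a per-round callback and stops as soon as it returns `false`.
function update_with_callback!(b::Booster, Xy::DMatrix;
                               num_round::Integer=10,
                               callback = (booster, round) -> true)
    for r in 1:num_round
        # one boosting iteration via the existing C wrapper
        XGBoost.xgbcall(XGBoost.XGBoosterUpdateOneIter, b.handle, r, Xy.handle)
        callback(b, r) || break   # e.g. an early-stopping check
    end
    return b
end
```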

@bobaronoff

Actually, it was the need for performance that led me to explore parsing the log. One could argue that learning curves are old fashioned, but that is how I was trained to build a gradient boost model, so for me a learning curve it is. I did write a cross validation function based on predict and it works okay. However, consider the task. In a typical 10-fold learning curve one requires 20 calls to predict per round (10 for train and 10 for test). A typical model has 500 rounds. That's 10,000 calls to predict and 10,000 calls to a function that processes the return into a metric. Although predict allows one to specify a particular round, the lower rounds are not cached: predicting round 100 drops through all trees 1-99 even though it went through the same trees when predicting for round 99. For even a simple model this is a lot of work.

By pre-allocating space for the predictions and broadcasting as efficiently as I could, the processing time was considerably improved but still a bit long. Reducing the curve to every 10 rounds brought the processing time to what I considered appropriate. I haven't figured out yet how to do this with MLJ. I am sure it will be better at the 1,000 metric calculations, but the 1,000 calls to libxgboost for prediction will be the same.

In comparison, the time to parse an evaluation log is trivial. After I re-wrote my learning curve function with data from the parsed log, the run times with evaluations every round are comparable to the 'predict' method every 10 rounds. Also, I don't need to keep track of or code the default metrics. For my needs this seemed the better solution. I did observe that booster.jl grows the trees one round at a time (a call to libxgboost for each round). I was wondering/hoping whether libxgboost reserves any cache space for prior rounds' predictions, since that is needed to grow the next tree, but that is off topic.

My best effort at a 'predict' approach may not be the best point of comparison, but it is what I have; parsing the evaluation log won out for my purpose and level of skill. This is all predicated on simple cross validation. If one is into nested cross validation, comparing different algorithms, then MLJ would seem to be the best (if not only) way to go.

Aside from my bashfulness, the reason I hesitate to submit a PR is that I am next to useless when it comes to git and GitHub. It's all I can do to keep my personal packages straight and not corrupt my Julia environment. I am afraid if I fork and download, my laptop might blow up due to my ignorance.

All in all, XGBoost.jl is just fine.

@david-sun-1
Contributor

david-sun-1 commented Oct 8, 2023

Hey all, I too have had a need for early stopping rounds.
I've implemented something that closely follows @bobaronoff's solution above.

I understand that there could be a few ways to approach it, but it's a start and I want to work towards getting this functionality added.

#193
