
reproducibility and estimated measurements #24

Closed
parmentelat opened this issue Jun 19, 2024 · 11 comments

@parmentelat
Contributor

a final word about the paper's reproducibility

some values in the figure appear with a circle; about those, the legend states:

A square means that the experiment was run until the time per iteration had stabilized and used to forecast the time usage if the experiment was run to completion

can you please comment further on the practical means, if any, to achieve that, and on possible ways to automate that process?

@Sm00thix
Owner

Hi @parmentelat,

Sure! Some of the experiments with leave-one-out cross-validation (LOOCV) would take extremely long (up to several decades in one instance) to run to completion - especially the ones using the scikit-learn implementation.

For this reason, I opted to let the experiments run for a while and then use the ratio (time elapsed / cross-validation splits completed) to forecast how much time the entire experiment would have taken if run to completion. This is sensible because the splits are balanced, so each split is expected to take the same amount of time to complete.

I waited until some multiple of n_jobs cross-validation splits had completed, noted the time, and manually forecast the expected time to finish. The timings used in this process came from sklearn.model_selection.cross_validate's internal timers, which are printed when setting verbose=1.
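
In code, the manual forecast boils down to something like this (a minimal sketch, not taken from the repository; elapsed_seconds and splits_completed are read off the verbose cross_validate output):

```python
# Minimal sketch of the manual forecast described above (not part of timings.py).
# elapsed_seconds and splits_completed are read off the verbose output of
# sklearn.model_selection.cross_validate; total_splits is n for LOOCV.
def forecast_total_seconds(elapsed_seconds, splits_completed, total_splits):
    # The splits are balanced, so the mean time per completed split is
    # assumed to hold for the remaining splits as well.
    seconds_per_split = elapsed_seconds / splits_completed
    return seconds_per_split * total_splits
```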

Automating this approach may be feasible for the regular CPU implementations by removing some of the splits produced by cv=KFold() (currently defined on line 288 in timings.py).

Automation for the fast cross-validation implementation is likely not necessary as it is extremely fast - even for LOOCV with 1e6 samples. However, it can likely be supported by simply removing some multiple of n_jobs indices from cv_splits after line 334 and before line 335 in timings.py.

Automation for the JAX implementations can probably be achieved in the same manner as for the fast cross-validation implementation, by removing some of the indices in cv_splits after line 439 and before line 440 in timings.py.
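
A rough sketch of that "truncate cv_splits and scale up" idea (run_splits and the other names here are hypothetical placeholders, not the actual code in timings.py):

```python
import time

# Hypothetical sketch of the truncation idea; run_splits stands in for whatever
# runs cross-validation on a given list of splits in timings.py.
def estimate_by_truncation(cv_splits, run_splits, n_jobs, multiple=2):
    n_total = len(cv_splits)
    n_kept = multiple * n_jobs  # keep only a multiple of n_jobs splits
    start = time.perf_counter()
    run_splits(cv_splits[:n_kept])
    elapsed = time.perf_counter() - start
    # Scale the measured time linearly to the full number of splits.
    return elapsed / n_kept * n_total
```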

These are my three ideas for automating the estimation process. However, I have not actually tried them, and I may have missed some detail that makes automation more challenging.

Please let me know if you will accept the manual approach. Otherwise, I will try to implement my ideas as stated above.

@parmentelat
Contributor Author

hi @Sm00thix
I'm also pinging @basileMarchand as he expressed interest in the matter

I must admit it would be very nice to be able to script the process of re-obtaining all the numbers depicted in the paper's main figure
in this respect I believe your ideas above deserve at least a shot, since you seem to deem them feasible
and even if full automation turns out to be out of reach, any progress that makes it easier for others to reproduce your numbers for the estimated runs would be welcome

regardless, once you are done with the improvements (if any), I would recommend enriching paper/README.md to either give instructions on how to reproduce (if you improve the current situation), or at the very least explain how you proceeded - i.e. the details above - if not

Sm00thix added a commit that referenced this issue Jun 26, 2024
…so modified paper/README.md to clarify the estimation process. Also updated the notebook to reflect the newly added --estimate flag in time_pls.py. This is related to #24
@Sm00thix
Owner

Hi @basileMarchand and @parmentelat,

To accommodate your request to automate the benchmark estimation process, I have attempted to implement the ideas I described in my previous comment in this thread.

I was successful for the regular CPU implementations. These are scikit-learn's NIPALS and my own NumPy-based IKPLS implementations - i.e., the ones that can be benchmarked with time_pls.py using the flags -model sk, -model np1, and -model np2, respectively. They can now be estimated by adding the --estimate flag to the call to time_pls.py in the benchmark branch of the IKPLS repo.
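
For example, an estimated run of the scikit-learn implementation would look roughly like this (only the flag names are real; the parameter values below are placeholders):

python3 time_pls.py -o timings/user_timings.csv -model sk -n_components 30 -n 100000 -k 500 -m 1 -n_splits 100000 -n_jobs -1 --estimate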

For the JAX implementations and the fast cross-validation implementation, I was unable to automate the estimation process. Instead, I updated paper/README.md to clarify the manual approach, as suggested by @parmentelat.

These changes are visible in 841dc4a.

If you agree with these changes, I will merge them into main.
Alternatively, if you think it is confusing to have automation only available for a subset of the implementations, I will remove the possibility to automate the regular CPU implementations and instead only clarify the manual approach in paper/README.md.

Please let me know which option you prefer. I will then proceed as you wish and close the issue.

Below are my explanations for why I was unable to automate the benchmark estimation of the JAX implementations and the fast cross-validation implementation:

I was not successful for the JAX implementations. The reason is the way I implemented the cross-validation: simply removing elements from cv_splits removes them from both the training and validation splits, effectively just making n smaller. I do not think I can make this work. At the very least, I would have to rewrite cross_validate (and possibly related methods) in jax_ikpls_base.py. I sincerely do not think I should modify the implementations with the sole goal of making automation of benchmark estimation possible.

For the fast cross-validation algorithm, timing a small number of cross-validation splits, computing the time/split ratio, and then linearly forecasting the time to compute the whole cross-validation fails.
The reason lies in the nature of the fast cross-validation algorithm: it initially performs one relatively expensive computation and then performs relatively cheap operations while iterating over the folds. Linear extrapolation from the first few folds therefore counts that one-time expensive computation as if it were repeated for every fold, which inflates the estimate. Details about the algorithm can be found in this paper.
In practice, estimating the fast cross-validation algorithm seems somewhat unnecessary as it is, as the name implies, fast. Even for the somewhat extreme case of leave-one-out cross-validation with a million samples, 500 features, 10 targets, and 30 PLS components, it only takes 20 minutes to complete. I also did not estimate its runtime in my experiments - as evidenced by the plots in paper/timings/timings.png.
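
To make the failure mode concrete, here is a toy comparison of the two cost models (purely illustrative; none of these names appear in the repository):

```python
def naive_estimate(elapsed, splits_done, total_splits):
    # Linear extrapolation: treats every split as equally expensive, so the
    # one-time precomputation is implicitly counted once per split.
    return elapsed / splits_done * total_splits

def split_aware_estimate(t_init, t_per_split, total_splits):
    # Separates the one-time expensive precomputation (t_init) from the
    # cheap per-fold work (t_per_split), matching the algorithm's structure.
    return t_init + t_per_split * total_splits
```

Since elapsed = t_init + splits_done * t_per_split, the naive estimate overshoots the true total by roughly t_init * (total_splits / splits_done - 1), which is large precisely when the one-time computation dominates and only a few splits have been timed.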

@parmentelat
Contributor Author

hi @Sm00thix

thanks for your work in this area, I really believe this kind of apparently unrewarding task contributes to much better material for others to toy with :)

My advice would be to

  • keep the automated-estimation code, as it helps increase the reproducibility coverage
  • merge the PR I just filed a minute ago, reproducibility improvement #30, where I have superficially improved the readme and refactored the reproducing notebook to make it more user-friendly
  • maybe, at that point, improve the very last paragraph of the readme a little, because as it stands it remains really vague on how to carry out estimations for the runs where this is not automated; maybe simply with a digest of the material above ...

thanks again for bearing with me on this one :)

Sm00thix added a commit that referenced this issue Jun 27, 2024
* Added support for estimation of benchmarking for sk, np1, and np2. Also modified paper/README.md to clarify the estimation process. Also updated the notebook to reflect the newly added --estimate flag in time_pls.py. This is related to #24

* reproducibility improvement (#30)

* clarify the estimation methods

* refactor the notebook for reproducing results

more options available from the command line
more consistent way to implement the various filtering stages

* reproducing notebook: replace -s (bool) with -s <n>

this way we can go for more and more complex scenarios
also the shortcut for --dry-run becomes more traditional -n

* Clarified instructions for manual benchmarking in paper/README.md. Related to #24

---------

Co-authored-by: parmentelat <thierry.parmentelat@inria.fr>
@Sm00thix
Owner

Hi @parmentelat,

Thanks for your assistance on this one. I agree that we have made it easier for others to toy around!

I have merged your pull request to the benchmark branch and, in turn, merged the benchmark branch with main. I have also tried to improve the very last paragraph of paper/README.md as per your instructions.

All the changes are merged to main in ee3de3e

If you agree with these changes, I think we can close this issue :-)

@parmentelat
Contributor Author

parmentelat commented Jun 27, 2024

hey @Sm00thix

actually I am currently trying to recompute all the data from the figure
this is something that I'll put on the back burner, hopefully just waiting for the job to complete
for now I have secured a GPU-capable container to that end

in the process I have run into 2 separate issues

  • as of ee3de3e, my repro script is bugged wrt the -g option, which is broken (GPU-needing runs are always skipped)
  • also there's a need to define
    # Allow JAX to use 64-bit floating point precision.
    jax.config.update("jax_enable_x64", True)
    

I'll get this repro script to work as far as I can, and will file a PR once I'm done; please keep this open until then

@parmentelat
Contributor Author

on a side note, I am seeing this during my reproducibility attempts, with jax-related runs
feel free to open a separate issue if need be

# 13/282: (jax2 x 30 x 1 x 100000 - False) - expect 4.27
python3 time_pls.py -o timings/user_timings.csv -model jax2 -n_components 30 -n_splits 1 -n 100000 -k500 -m 1 -n_jobs -1 
Fitting JAX Improved Kernel PLS Algorithm #2 with 30 components on 100000 samples with 500 features and 1 targets.
/home/ubuntu/miniconda3/lib/python3.12/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.
  warnings.warn(
/home/ubuntu/miniconda3/lib/python3.12/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.
  warnings.warn(
Time: 4.175951265729964

Sm00thix mentioned this issue Jun 27, 2024
Merged
@Sm00thix
Owner

Sm00thix commented Jun 27, 2024

Hi @parmentelat,

Regarding your benchmarks: sounds good. Depending on your machine, running all the benchmarks may take a few weeks if I recall correctly. That is, if you estimate the same ones I did and run to completion the same ones I did.

I'm sorry if I broke your notebook with some of my own changes. Be aware that a NaN in the 'inferred' column in paper/timings/timings.csv is to be regarded as False. That is, I (admittedly somewhat foolishly) did not write anything in that column if I did not estimate the value. Doing boolean logic with NaNs is error-prone, and I tried to account for that when I modified your script. Please let me know if I can be of any help in this regard. I will keep this issue open until that is resolved.
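
For anyone scripting against the CSV, one way to handle this is a sketch along these lines (only the file path and column name come from the paragraph above):

```python
import pandas as pd

timings = pd.read_csv("paper/timings/timings.csv")
# Treat a missing value in the 'inferred' column as False before any boolean logic.
timings["inferred"] = timings["inferred"].fillna(False).astype(bool)
```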

In relation to the FutureWarning that you mention in your comment, I found the culprit and fixed it. I commented on the details in 9cf6cc7. I also ran a few of the JAX benchmarks and noticed no difference when compared to my original benchmarks.

I initially forgot to bump the ikpls version number. I did this in fc694c0.

All of these changes are currently in the dev branch which I will merge to main once all tests pass.

Alright. I had made a couple of errors which caused the computation of gradients using reverse mode differentiation to fail. I fixed those in 97c9024.

Update

I merged the dev branch to main.

@parmentelat
Contributor Author

the paper, or at least timings.csv, holds 606 measurement points; right now I have collected somewhere in the 360s
I might not let it run long enough for it to take weeks, but I'll try to maximize the number of data points that I can gather
I will file a PR with the latest changes that I made in the notebook now that this script works fine for me

no worries about the notebook, the first version was very rough and needed ironing out anyway; plus your code was actually working, so...

also for clarity, none of the warnings that I reported should be deemed showstoppers, it's just FYI in case you had missed them

@Sm00thix
Owner

Hi @parmentelat,

I merged your PR in b27201b. Thanks for your work on this one! Do you want to wait until your benchmarks have completed before we close this issue? Otherwise, I suggest we close it now :-)

@parmentelat
Contributor Author

ok for closing now
