-
-
Notifications
You must be signed in to change notification settings - Fork 393
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Numba #710
Numba #710
Conversation
arviz/stats/diagnostics.py
Outdated
@@ -510,13 +511,22 @@ def geweke(ary, first=0.1, last=0.5, intervals=20): | |||
last_slice = ary[int(end - last * (end - start)) :] | |||
|
|||
z_score = first_slice.mean() - last_slice.mean() | |||
z_score /= np.sqrt(first_slice.var() + last_slice.var()) | |||
if numba_check(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment is not required. While this if else pattern works I challenge us both to come up with a "nicer" pattern. I have some ideas and we can talk together about this.
Again don't worry about it right now but did want to make the comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Ban-zee In the last 20 minutes I thought about this and think we should change on this PR before we move.
We shouldn't do a numba check everytime a function is called. I think a preferred pattern would be to check if Numba is installed once when ArviZ is loaded and then use that constant throughout the entire program.
I also think this will fit in well because users can optionally turn numba off, even if they have it installed.
If you want I can set some time to talk about this synchronously if you want :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can be done. I have a few questions in mind and I am up for discussion anytime. :)
benchmarks/benchmarks.py
Outdated
|
||
|
||
class Hist: | ||
def time_histogram(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These look like tabs not spaces :)
Can you run black over this file for consistent formatting of all .py files?
@Ban-zee Im not getting the mock iterable error mentioned in Slack, but another one. Was the mock iterable error fixed?
|
benchmarks/benchmarks.py
Outdated
|
||
|
||
class CircStd: | ||
def time_circ_std(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like its still tabs in here. I ask that this be converted to spaces
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, look's like something went wrong with black.
return np.histogram(data, bins=100) | ||
|
||
|
||
class Variance: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Testing independent logic is a good start and makes it easy to see how numba should be implemented. Is it possible to also benchmark the arviz functions themselves with and without numba?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is an issue. This was my initial plan but importing arviz breaks the build and the benchmarks do not execute. In my notebook, the reduction in the timings of the arviz methods was somewhat proportionate to the reduction in the time of the modified function; so I thought that if importing arviz does not work then I can benchmark the modified methods to ensure that the things are going in the right direction.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Strategy makes sense. Let us know how we can help you figure out how to keep the build from breaking. Doesn't need to be done on this pr
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Requesting changes on if numba_check()
pattern
class Hist: | ||
def time_histogram(self): | ||
try: | ||
data = np.random.rand(10000, 1000) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Small thing. In these checks the data line is repeated but it seems it doesn't have to be
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It can be changed. It's just that asv takes into account the data initialization time as well so I had to include it in both blocks so as to nullify this effect.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it. Today I learned something
Could it be that |
Also, out of curiosity, have you done any experiments or benchmarks with numba on |
I have removed the local path; I had added it for my ease and forgot to replace it :| Unfortunately, I haven't tested |
asv.conf.json
Outdated
//"install_timeout": 600, | ||
|
||
// the base URL to show a commit for the project. | ||
"show_commit_url": "https://github.com/Ban-zee/arviz/commit/", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be modified to ArviZ instead of fork?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't know how I missed that :( , I'll change it by my next commit
Just curious who is @ABM8? @Ban-zee is this another account that you have? |
That's me. Looks like I did not rebase :(. I'll resolve this after I wake up. |
Out of curiosity why not use one account? |
It has stuck with me ever since I made the account and I did not bother changing it. :| |
c09fbf5
to
a3fd796
Compare
Just need to add one or two more tests for increasing the code coverage. Apart from that, the pr is up for review. |
arviz/stats/diagnostics.py
Outdated
from ..utils import _var_names | ||
|
||
from ..utils import _var_names, conditional_jit, conditional_vect | ||
from .. import Numba |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that the cyclic import problem comes from this line. This line is in both diagnostics.py and stats.py, thus, it would explain why the cyclic import warnings finish always in one of these two files, and it looks like it is importing from all ArviZ, hence the cycle. I don't really know how to import things from the init file from within the same library though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
init too imports stats, plots and data. Hence the cyclic import I guess.
Up for review!! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some minor comments. It looks really good!
arviz/plots/elpdplot.py
Outdated
@@ -12,7 +12,8 @@ | |||
format_coords_as_labels, | |||
set_xticklabels, | |||
) | |||
from ..stats import waic, loo, ELPDData | |||
from ..stats import waic, loo | |||
from ..stats import ELPDData |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
from ..stats import waic, loo, ELPDData
arviz/tests/test_diagnostics.py
Outdated
@@ -4,6 +4,7 @@ | |||
import numpy as np | |||
from numpy.testing import assert_almost_equal, assert_array_almost_equal | |||
import pandas as pd | |||
import scipy.stats as st |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
from scipy.stats import circstd
? I may have missed some call to st, but it looks like only circstd is used
arviz/utils.py
Outdated
if not cls.numba_flag and numba_check(): | ||
cls.numba_flag = True | ||
else: | ||
raise ValueError("Numba is already enabled or not installed") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should an error be raised when calling Numba.enable_numba()
if numba is already enables? It feels more natural to just leave numba activated and do nothing (like calling %matplotlib inline
when inline backend is already set.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wanted to play safe here but after looking at it again, I think this won't lead to any breakdown.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added some comments
arviz/stats/diagnostics.py
Outdated
@@ -445,6 +441,12 @@ def mcse(data, *, var_names=None, method="mean", prob=None): | |||
) | |||
|
|||
|
|||
@conditional_vect | |||
def _sqr(a_a, b_b): | |||
return math.sqrt(a_a + b_b) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is math really needed? also name should be _sqrt
?
(a_a + b_b)**0.5
arviz/stats/diagnostics.py
Outdated
sd = np.std(ary, ddof=1) | ||
if _numba_flag: | ||
ary = np.ravel(ary) | ||
sd = np.sqrt(svar(ary, ddof=1)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_sqrt?
ary.ravel() to same line?
arviz/stats/diagnostics.py
Outdated
sd = np.std(ary, ddof=1) | ||
if _numba_flag: | ||
ary = np.ravel(ary) | ||
sd = np.sqrt(svar(ary, ddof=1)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should there be ssd/sstd function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess we can use _sqr, as you suggested.
b_b = b_b + i * i | ||
var = b_b / (len(data)) - ((a_a / (len(data))) ** 2) | ||
var = var * (len(data) / (len(data) - ddof)) | ||
return var |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we test this if there are huge and tiny values in the same array
@@ -150,7 +152,7 @@ def test_deterministic(self): | |||
assert (abs(reference["rhat_rank"] - arviz_data["rhat_rank"]) < 6e-5).all(None) | |||
assert abs(np.median(reference["rhat_rank"] - arviz_data["rhat_rank"]) < 1e-14).all(None) | |||
not_rhat = [col for col in reference.columns if col != "rhat_rank"] | |||
assert (abs(reference[not_rhat] - arviz_data[not_rhat]) < 1e-11).all(None) | |||
assert (abs((reference[not_rhat] - arviz_data[not_rhat])).values < 1e-8).all(None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this just random noise or did the function change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a small change :( so I had to adjust the tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok. (The accuracy for rhat is something like 1e-6, so this change is fine)
assert_almost_equal(ess_mean_hat, ess_mean_hat_) | ||
assert_almost_equal(ess_sd_hat, ess_sd_hat_) | ||
assert_almost_equal(ess_bulk_hat, ess_bulk_hat_) | ||
assert_almost_equal(ess_tail_hat, ess_tail_hat_) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there some difference between the results?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
one of the results was the rounded version of the other one. This rounding off took place at around 10th or 11th place of the decimal iirc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you check the summary code and see if it needs update? (It's is basically copy from the main functions, just trying to minimize duplicate calculations)
|
||
|
||
def test_numba_bfmi(): | ||
state = Numba.numba_flag |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All these test methods need docstrings
@@ -530,7 +543,11 @@ def ks_summary(pareto_tail_indices): | |||
df_k : dataframe | |||
Dataframe containing k diagnostic values. | |||
""" | |||
kcounts, _ = np.histogram(pareto_tail_indices, bins=[-np.Inf, 0.5, 0.7, 1, np.Inf]) | |||
_numba_flag = Numba.numba_flag |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this be set at the top once instead of being called in every func?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would that work?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's only being called twice; will setting this it at the top make any difference?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If its only being called twice then I don't think its worth it
@@ -510,13 +514,22 @@ def geweke(ary, first=0.1, last=0.5, intervals=20): | |||
last_slice = ary[int(end - last * (end - start)) :] | |||
|
|||
z_score = first_slice.mean() - last_slice.mean() | |||
z_score /= np.sqrt(first_slice.var() + last_slice.var()) | |||
if _numba_flag: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My next challenge for you is to get rid of all these if else statements.
There might be a pattern where we can do something like
z_score = _numba(f1, f2, args)
This way all the numba specific logic is contained in the Numba class and all the math logic is left here. What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure.
This is cool -- just so I'm clear about reading this, |
Approximately yes, on my machine. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
small comment on float formatting, not sure if its real issue or not
arviz/stats/diagnostics.py
Outdated
@@ -993,7 +990,7 @@ def _mc_error(ary, batches=5, circular=False): | |||
std = stats.circstd(ary, high=np.pi, low=-np.pi) | |||
else: | |||
if _numba_flag: | |||
std = np.sqrt(svar(ary)) | |||
std = float(_sqrt(svar(ary), np.zeros(1))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Float over array does not work
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some final tweaks comments. They don't need to be included in this PR, in the next one is fine. It is mainly to not forget them.
Also, I just remembered that I included a call to np.histogram
in ELPDData.__str__
(very similar to ks_summary`) should it be numbified too?
CONTRIBUTING.md
Outdated
|
||
- `cd arviz/` | ||
|
||
- `asv run` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what about?
$ cd arviz/
$ asv run
minor comment, it is only to follow the convention from the rest of the contributing file.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, and I'll change ELPD data too
return (a_a + b_b) ** 0.5 | ||
|
||
|
||
@conditional_jit | ||
def geweke(ary, first=0.1, last=0.5, intervals=20): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have no idea why nor how to solve it, I point it out just in case. If you build the docs locally, you will see that the geweke
link on the api page is not working anymore. It shows the text, but it does not link to the geweke
page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh I think we might need to do a functools wrap. Decorators "eat" the docstrings of their decorated functions
https://www.thecodeship.com/patterns/guide-to-python-function-decorators/
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should do this on this PR because we don't want to regress the docs for every function that we use conditional_jit
for
@@ -66,6 +66,7 @@ Here is the citation in BibTeX format | |||
Quickstart<notebooks/Introduction> | |||
Example Gallery<examples/index> | |||
Cookbook<notebooks/InferenceDataCookbook> | |||
Numba<notebooks/Numba> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Completely unrelated. Should this toctree have the same elements as the sticky nav-bar on top of the website?
If so this seems like a good ocasion to add InferenceData<notebooks/XarrayforArviZ>
here and ("Numba", "notebooks/Numba"),
between lines 137-138 of doc/conf.py
. Adding numba to the nav-bar is probably more important, otherwise it can be hard to find.
Strange, the tests work fine locally. Any idea why this might be happening? |
It says AssertationError is not ValueError. Did pytest update? |
Sorry for the delay :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a lot of careful work, @Ban-zee! It looks good.
* Modified gitignore * Added numba * Added numba * Implemented black * Implemented black * Added tests * Added tests * Added benchmark and util tests * Added benchmark and util tests * Fixed benchmark tests * Modified tests * Fixed Lint * Fixed Lint * Another attempt at lint * Another attempt at lint * Black * Changed tests * Changed tests * Forgot black * Black benchmarks * Lint * Lint * Fixed the tests * Removed relative path * Modified asv conf * Made changes * made the modifications * Added tests and numba global function * Black * lint * Increased code coverage * Added tests and lint fix * lint * Sphinx * Attemp 1 at circular import * Cyclic import fix arviz-devs#2 * Circ Import 3 * Attempt 4 * Undoing changes in init * Added tests * Fixed init * Fixed imports * LINT * Added one more test * Grammar * Made review changes * Replaced most of the if-else blocks * float * update * updated docstring * Fixed docs * Added instructions * notebook * Test-1 * Test-2 * TEST-3 * Possible fix * Changes * Added files
Numba has been kept optional. Added benchmarks tests for the modified methods(mainly variance, histogram and stats.circstd). These modified methods are twice as fast as the original numpy and scipy methods. Along with this, I have added a class in utils to allow the user to enable or disable numba at will. The required tests have been added as well.