Numba #710

Merged
merged 59 commits into from
Jul 4, 2019

Conversation

Contributor

@Ban-zee Ban-zee commented Jun 18, 2019

Numba has been kept optional. Added benchmark tests for the modified methods (mainly variance, histogram and stats.circstd). These modified methods are twice as fast as the original numpy and scipy methods. Along with this, I have added a class in utils to allow the user to enable or disable numba at will. The required tests have been added as well.

@@ -510,13 +511,22 @@ def geweke(ary, first=0.1, last=0.5, intervals=20):
last_slice = ary[int(end - last * (end - start)) :]

z_score = first_slice.mean() - last_slice.mean()
z_score /= np.sqrt(first_slice.var() + last_slice.var())
if numba_check():
Member

This comment is not required. While this if/else pattern works, I challenge us both to come up with a "nicer" pattern. I have some ideas and we can talk about this together.

Again, don't worry about it right now, but I did want to make the comment.

Member

@Ban-zee In the last 20 minutes I thought about this, and I think we should change this on this PR before we move on.

We shouldn't do a numba check every time a function is called. I think a preferred pattern would be to check whether Numba is installed once, when ArviZ is loaded, and then use that constant throughout the entire program.

I also think this will fit in well because users can optionally turn numba off, even if they have it installed.

I can set up some time to talk about this synchronously if you want :)

Contributor Author

This can be done. I have a few questions in mind and I am up for discussion anytime. :)

@canyon289 canyon289 added the GSOC label Jun 19, 2019


class Hist:
    def time_histogram(self):
Member

These look like tabs not spaces :)

Can you run black over this file for consistent formatting of all .py files?

@canyon289
Member

@Ban-zee I'm not getting the mock iterable error mentioned in Slack, but a different one. Was the mock iterable error fixed?

test_utils.py:150 (test_conditional_vect_numba_decorator_keyword)
monkeypatch = <_pytest.monkeypatch.MonkeyPatch object at 0x19e0b6c18>

def test_conditional_vect_numba_decorator_keyword(monkeypatch):
    """Checks else statement and vect keyword argument"""
    from arviz import utils

    # Mock import lib to return numba with hit method which returns a function that returns kwargs
    numba_mock = Mock()
    monkeypatch.setattr(utils.importlib, "import_module", lambda x: numba_mock)

    def vectorize(**kwargs):
        """overwrite numba.vectorize function"""
        return lambda x: (x(), kwargs)

    numba_mock.vectorize = vectorize
  @utils.conditional_vect(keyword_argument="A keyword argument")
    def placeholder_func(a, b):

test_utils.py:165:


../utils.py:104: in wrapper
return numba.vectorize(**kwargs)(function)


x = <function test_conditional_vect_numba_decorator_keyword..placeholder_func at 0x1a2916950>

return lambda x: (x(), kwargs)
E TypeError: placeholder_func() missing 2 required positional arguments: 'a' and 'b'

test_utils.py:161: TypeError
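
Reading the traceback: the mocked `vectorize` returns `lambda x: (x(), kwargs)`, which calls the decorated function with no arguments, while `placeholder_func(a, b)` needs two. The sketch below reproduces that outside pytest and shows one possible alternative mock that returns the function uncalled. Both mock names are hypothetical, and this is not necessarily the fix adopted in the PR.

```python
def vectorize_calls_function(**kwargs):
    """Mock in the style of the failing test: calls the decorated function with no arguments."""
    return lambda func: (func(), kwargs)


def vectorize_returns_function(**kwargs):
    """Alternative mock: hand the function back uncalled, next to the captured kwargs."""
    return lambda func: (func, kwargs)


def placeholder_func(a, b):
    return a + b


try:
    vectorize_calls_function(keyword_argument="A keyword argument")(placeholder_func)
    error_message = None
except TypeError as err:
    # Same failure as in the traceback: a and b were never supplied.
    error_message = str(err)

func, captured = vectorize_returns_function(keyword_argument="A keyword argument")(
    placeholder_func
)
```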



class CircStd:
    def time_circ_std(self):
Member

Looks like it's still tabs in here. I ask that this be converted to spaces.

Contributor Author

Sure, looks like something went wrong with black.

return np.histogram(data, bins=100)


class Variance:
Member

Testing independent logic is a good start and makes it easy to see how numba should be implemented. Is it possible to also benchmark the arviz functions themselves with and without numba?

Contributor Author

This is an issue. This was my initial plan but importing arviz breaks the build and the benchmarks do not execute. In my notebook, the reduction in the timings of the arviz methods was somewhat proportionate to the reduction in the time of the modified function; so I thought that if importing arviz does not work then I can benchmark the modified methods to ensure that the things are going in the right direction.

Member

The strategy makes sense. Let us know how we can help you figure out how to keep the build from breaking. It doesn't need to be done on this PR.

Member

@canyon289 canyon289 left a comment

Requesting changes on if numba_check() pattern

class Hist:
    def time_histogram(self):
        try:
            data = np.random.rand(10000, 1000)
Member

Small thing. In these checks the data line is repeated but it seems it doesn't have to be

Contributor Author

It can be changed. It's just that asv takes into account the data initialization time as well so I had to include it in both blocks so as to nullify this effect.
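
For reference, asv also supports a `setup` method on benchmark classes: it runs before the timed method and its runtime is left out of the measurement. The sketch below shows that alternative to repeating the data line; it is not what the PR does, and the array size is reduced for illustration.

```python
import numpy as np


class Hist:
    """asv-style benchmark class; asv calls setup() before timing each time_* method."""

    def setup(self):
        # Data creation happens here, outside the timed region.
        self.data = np.random.rand(1000, 100)

    def time_histogram(self):
        # Only this call is timed by asv.
        np.histogram(self.data, bins=100)
```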

Member

Got it. Today I learned something

@OriolAbril
Member

Could it be that asv.conf.json should be ignored and not added to ArviZ? It is full of local paths and .asv/ was added to .gitignore

@OriolAbril
Member

Also, out of curiosity, have you done any experiments or benchmarks with numba on make_ufunc?

@Ban-zee
Contributor Author

Ban-zee commented Jun 19, 2019

I have removed the local path; I had added it for my ease and forgot to replace it :|
Regarding its inclusion, the conf file is needed for running the benchmarks locally. .asv contains environments and the results of the tests so I decided to gitignore it.

Unfortunately, I haven't tested make_ufunc yet :(.

asv.conf.json Outdated
//"install_timeout": 600,

// the base URL to show a commit for the project.
"show_commit_url": "https://github.com/Ban-zee/arviz/commit/",
Member

Should this be modified to ArviZ instead of fork?

Contributor Author

Don't know how I missed that :( , I'll change it by my next commit

@canyon289
Member

Just curious who is @ABM8? @Ban-zee is this another account that you have?

@Ban-zee
Contributor Author

Ban-zee commented Jun 19, 2019

That's me. Looks like I did not rebase :(. I'll resolve this after I wake up.

@canyon289
Member

> That's me. Looks like I did not rebase :(. I'll resolve this after I wake up.

Out of curiosity why not use one account?

@Ban-zee
Contributor Author

Ban-zee commented Jun 20, 2019

It has stuck with me ever since I made the account and I did not bother changing it. :|

@Ban-zee Ban-zee force-pushed the numba branch 2 times, most recently from c09fbf5 to a3fd796 Compare June 20, 2019 22:29
@Ban-zee
Contributor Author

Ban-zee commented Jun 20, 2019

I just need to add one or two more tests to increase the code coverage. Apart from that, the PR is up for review.

from ..utils import _var_names

from ..utils import _var_names, conditional_jit, conditional_vect
from .. import Numba
Member

I think that the cyclic import problem comes from this line. This line is in both diagnostics.py and stats.py, thus, it would explain why the cyclic import warnings finish always in one of these two files, and it looks like it is importing from all ArviZ, hence the cycle. I don't really know how to import things from the init file from within the same library though.

Contributor Author

__init__ also imports stats, plots and data, hence the cyclic import, I guess.
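
One common way to avoid this kind of cycle is to import the shared object from the module that defines it rather than from the package root. The sketch below builds a throwaway three-file package (all names hypothetical, standing in for `__init__.py`, `utils.py` and `stats.py`) to show the layout. It is an illustration under those assumptions, not the change made in this PR.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Hypothetical three-file package: __init__ imports stats, and stats needs Numba.
FILES = {
    "__init__.py": "from .stats import flag_state\nfrom .utils import Numba\n",
    "utils.py": "class Numba:\n    numba_flag = False\n",
    # Import from the defining module, not from the package root ("from . import Numba").
    "stats.py": (
        "from .utils import Numba\n\n\n"
        "def flag_state():\n    return Numba.numba_flag\n"
    ),
}


def build_and_import(package_name="demo_pkg_no_cycle"):
    """Write the package to a temporary directory and import it."""
    root = Path(tempfile.mkdtemp())
    package_dir = root / package_name
    package_dir.mkdir()
    for name, source in FILES.items():
        (package_dir / name).write_text(source)
    sys.path.insert(0, str(root))
    try:
        return importlib.import_module(package_name)
    finally:
        sys.path.remove(str(root))


demo = build_and_import()
```

With `from . import Numba` in `stats.py` instead, the import would depend on the order of statements in `__init__.py`, which is what makes the package-root import fragile.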

@Ban-zee
Contributor Author

Ban-zee commented Jun 21, 2019

Up for review!!

Member

@OriolAbril OriolAbril left a comment

Some minor comments. It looks really good!

@@ -12,7 +12,8 @@
format_coords_as_labels,
set_xticklabels,
)
from ..stats import waic, loo, ELPDData
from ..stats import waic, loo
from ..stats import ELPDData
Member

from ..stats import waic, loo, ELPDData

@@ -4,6 +4,7 @@
import numpy as np
from numpy.testing import assert_almost_equal, assert_array_almost_equal
import pandas as pd
import scipy.stats as st
Member

from scipy.stats import circstd? I may have missed some call to st, but it looks like only circstd is used

arviz/utils.py Outdated
if not cls.numba_flag and numba_check():
    cls.numba_flag = True
else:
    raise ValueError("Numba is already enabled or not installed")
Member

Should an error be raised when calling Numba.enable_numba() if numba is already enabled? It feels more natural to just leave numba activated and do nothing (like calling %matplotlib inline when the inline backend is already set).

Contributor Author

Wanted to play safe here but after looking at it again, I think this won't lead to any breakdown.
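
A sketch of the idempotent behaviour suggested in this thread: enabling twice is a no-op, and an error is raised only when numba cannot be found. The class below is a hypothetical stand-in written for illustration, not the code from `arviz/utils.py`.

```python
import importlib.util


class Numba:
    """Hypothetical stand-in for the toggle class discussed above."""

    numba_flag = importlib.util.find_spec("numba") is not None

    @classmethod
    def enable_numba(cls):
        """Turn numba on; a no-op if already on, an error only if numba is missing."""
        if importlib.util.find_spec("numba") is None:
            raise ValueError("Numba is not installed")
        cls.numba_flag = True

    @classmethod
    def disable_numba(cls):
        """Turn numba off; calling it twice is harmless."""
        cls.numba_flag = False
```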

Contributor

@ahartikainen ahartikainen left a comment

Added some comments

@@ -445,6 +441,12 @@ def mcse(data, *, var_names=None, method="mean", prob=None):
)


@conditional_vect
def _sqr(a_a, b_b):
    return math.sqrt(a_a + b_b)
Contributor

Is math really needed? Also, should the name be _sqrt?

(a_a + b_b)**0.5

sd = np.std(ary, ddof=1)
if _numba_flag:
    ary = np.ravel(ary)
    sd = np.sqrt(svar(ary, ddof=1))
Contributor

_sqrt?

ary.ravel() to same line?

sd = np.std(ary, ddof=1)
if _numba_flag:
    ary = np.ravel(ary)
    sd = np.sqrt(svar(ary, ddof=1))
Contributor

Should there be ssd/sstd function?

Contributor Author

I guess we can use _sqr, as you suggested.

b_b = b_b + i * i
var = b_b / (len(data)) - ((a_a / (len(data))) ** 2)
var = var * (len(data) / (len(data) - ddof))
return var
Contributor

Can we test this when there are huge and tiny values in the same array?
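
The concern can be illustrated with a small experiment. The excerpt above appears to compute the variance from a running sum and a running sum of squares; the function below is a pure-Python reconstruction of that formula (an assumption, since only the tail of the original is shown, and without numba). Subtracting two nearly equal large numbers loses the small spread when the data sits on a huge offset.

```python
import numpy as np


def one_pass_var(data, ddof=0):
    """Variance from a running sum and sum of squares (textbook one-pass formula)."""
    total = 0.0
    total_sq = 0.0
    for value in data:
        total += value
        total_sq += value * value
    n = len(data)
    var = total_sq / n - (total / n) ** 2
    return var * (n / (n - ddof))


small = np.array([1.0, 2.0, 3.0, 4.0])
shifted = small + 1e9  # same spread, huge offset

# Distance from NumPy's two-pass result on each data set.
small_gap = abs(one_pass_var(small) - np.var(small))
shifted_gap = abs(one_pass_var(shifted) - np.var(shifted))
```

On the well-scaled data the two results agree, while on the shifted data the one-pass formula is off by more than the true variance, so a test with mixed magnitudes would catch this.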

@@ -150,7 +152,7 @@ def test_deterministic(self):
assert (abs(reference["rhat_rank"] - arviz_data["rhat_rank"]) < 6e-5).all(None)
assert abs(np.median(reference["rhat_rank"] - arviz_data["rhat_rank"]) < 1e-14).all(None)
not_rhat = [col for col in reference.columns if col != "rhat_rank"]
assert (abs(reference[not_rhat] - arviz_data[not_rhat]) < 1e-11).all(None)
assert (abs((reference[not_rhat] - arviz_data[not_rhat])).values < 1e-8).all(None)
Contributor

Is this just random noise or did the function change?

Contributor Author

There was a small change :( so I had to adjust the tests.

Contributor

Ok. (The accuracy for rhat is something like 1e-6, so this change is fine)

assert_almost_equal(ess_mean_hat, ess_mean_hat_)
assert_almost_equal(ess_sd_hat, ess_sd_hat_)
assert_almost_equal(ess_bulk_hat, ess_bulk_hat_)
assert_almost_equal(ess_tail_hat, ess_tail_hat_)
Contributor

Is there some difference between the results?

Contributor Author

One of the results was the rounded version of the other one. The rounding took place at around the 10th or 11th decimal place, iirc.

Contributor

Can you check the summary code and see if it needs an update? (It is basically a copy of the main functions, just trying to minimize duplicate calculations.)



def test_numba_bfmi():
    state = Numba.numba_flag
Member

All these test methods need docstrings

@@ -530,7 +543,11 @@ def ks_summary(pareto_tail_indices):
df_k : dataframe
Dataframe containing k diagnostic values.
"""
kcounts, _ = np.histogram(pareto_tail_indices, bins=[-np.Inf, 0.5, 0.7, 1, np.Inf])
_numba_flag = Numba.numba_flag
Member

Can this be set at the top once instead of being called in every func?

Contributor

Would that work?

Contributor Author

It's only being called twice; will setting it at the top make any difference?

Member

If it's only being called twice then I don't think it's worth it.

@@ -510,13 +514,22 @@ def geweke(ary, first=0.1, last=0.5, intervals=20):
last_slice = ary[int(end - last * (end - start)) :]

z_score = first_slice.mean() - last_slice.mean()
z_score /= np.sqrt(first_slice.var() + last_slice.var())
if _numba_flag:
Member

My next challenge for you is to get rid of all these if else statements.

There might be a pattern where we can do something like

z_score = _numba(f1, f2, args)

This way all the numba specific logic is contained in the Numba class and all the math logic is left here. What do you think?
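
A rough sketch of what such a dispatcher could look like. `_numba_dispatch`, `_var_loop` and the `Numba` class here are hypothetical stand-ins, and the "numba" branch is plain Python so the example runs without numba installed.

```python
import numpy as np


class Numba:
    """Hypothetical flag holder; the real class lives in the PR's utils module."""

    numba_flag = False


def _numba_dispatch(numba_function, fallback_function, *args, **kwargs):
    """Call numba_function when the flag is on, otherwise fallback_function."""
    if Numba.numba_flag:
        return numba_function(*args, **kwargs)
    return fallback_function(*args, **kwargs)


def _var_loop(data):
    """Stand-in for a jitted implementation (plain Python here)."""
    mean = sum(data) / len(data)
    return sum((value - mean) ** 2 for value in data) / len(data)


data = np.array([1.0, 2.0, 3.0, 4.0])
result_fallback = _numba_dispatch(_var_loop, np.var, data)  # flag off: np.var runs
Numba.numba_flag = True
result_flagged = _numba_dispatch(_var_loop, np.var, data)  # flag on: _var_loop runs
Numba.numba_flag = False
```

The calling code then contains a single line per computation, and the branching lives in one helper.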

Contributor Author

Sure.

@Ban-zee
Contributor Author

Ban-zee commented Jun 22, 2019

Screenshot of the benchmark tests

@ColCarroll
Member

This is cool -- just so I'm clear about reading this, numba is making circ_std ~30% faster, histogram ~60% faster, and variance ~50% faster?

@Ban-zee
Contributor Author

Ban-zee commented Jun 22, 2019

Approximately yes, on my machine.

Contributor

@ahartikainen ahartikainen left a comment

Small comment on float formatting; not sure if it's a real issue or not.

@@ -993,7 +990,7 @@ def _mc_error(ary, batches=5, circular=False):
std = stats.circstd(ary, high=np.pi, low=-np.pi)
else:
if _numba_flag:
std = np.sqrt(svar(ary))
std = float(_sqrt(svar(ary), np.zeros(1)))
Contributor

Float over array does not work
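
The remark does not say which failure is meant. One reading, offered as an assumption: `float()` raises a `TypeError` for arrays with more than one element, and recent NumPy releases also discourage it for one-element arrays. The sketch shows that failure and the explicit `.item()` alternative for a single-element result.

```python
import numpy as np

values = np.sqrt(np.array([4.0, 9.0]))

try:
    float(values)  # more than one element: no single scalar to return
    conversion_error = None
except TypeError as err:
    conversion_error = err

# For a one-element result, .item() extracts the Python scalar explicitly.
single = np.sqrt(np.array([4.0])).item()
```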

Member

@OriolAbril OriolAbril left a comment

Some final tweaks and comments. They don't need to be included in this PR; the next one is fine. This is mainly so we don't forget them.

Also, I just remembered that I included a call to np.histogram in ELPDData.__str__ (very similar to ks_summary). Should it be numbified too?

CONTRIBUTING.md Outdated

- `cd arviz/`

- `asv run`
Member

what about?

$ cd arviz/
$ asv run

minor comment, it is only to follow the convention from the rest of the contributing file.

Contributor Author

Sure, and I'll change ELPD data too

return (a_a + b_b) ** 0.5


@conditional_jit
def geweke(ary, first=0.1, last=0.5, intervals=20):
Member

I have no idea why nor how to solve it, I point it out just in case. If you build the docs locally, you will see that the geweke link on the api page is not working anymore. It shows the text, but it does not link to the geweke page.

Member

Oh I think we might need to do a functools wrap. Decorators "eat" the docstrings of their decorated functions

https://www.thecodeship.com/patterns/guide-to-python-function-decorators/

Member

I think we should do this on this PR because we don't want to regress the docs for every function that we use conditional_jit for

@@ -66,6 +66,7 @@ Here is the citation in BibTeX format
Quickstart<notebooks/Introduction>
Example Gallery<examples/index>
Cookbook<notebooks/InferenceDataCookbook>
Numba<notebooks/Numba>
Member

Completely unrelated. Should this toctree have the same elements as the sticky nav-bar on top of the website?

If so, this seems like a good occasion to add InferenceData<notebooks/XarrayforArviZ> here and ("Numba", "notebooks/Numba"), between lines 137-138 of doc/conf.py. Adding numba to the nav-bar is probably more important; otherwise it can be hard to find.

@Ban-zee
Contributor Author

Ban-zee commented Jun 29, 2019

Strange, the tests work fine locally. Any idea why this might be happening?

@ahartikainen
Contributor

ahartikainen commented Jun 29, 2019

with pytest.raises(ValueError) as err:
    get_coords(data, coords)
    
>       assert "Coords {'NOT_A_COORD_NAME'} are invalid coordinate keys" in str(err)

arviz/tests/test_plot_utils.py:112: AssertionError

It says AssertionError, not ValueError. Did pytest update?
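
One possible explanation, offered as an assumption: pytest 5.0 changed `str()` on the object returned by `pytest.raises` so that it no longer yields the exception message, which would make an `in str(err)` check fail with an AssertionError. The sketch below shows two ways of checking the message that do not depend on that behaviour; `get_coords_stub` is a hypothetical stand-in for the real function.

```python
import pytest


def get_coords_stub(coords):
    """Hypothetical stand-in that fails the way the test above expects."""
    raise ValueError("Coords {} are invalid coordinate keys".format(set(coords)))


def check_with_value():
    with pytest.raises(ValueError) as err:
        get_coords_stub(["NOT_A_COORD_NAME"])
    # Inspect the exception itself rather than the ExceptionInfo wrapper.
    assert "Coords {'NOT_A_COORD_NAME'} are invalid coordinate keys" in str(err.value)


def check_with_match():
    # match= runs a regular-expression search against the exception message.
    with pytest.raises(ValueError, match="invalid coordinate keys"):
        get_coords_stub(["NOT_A_COORD_NAME"])
```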

@Ban-zee
Contributor Author

Ban-zee commented Jul 1, 2019

Sorry for the delay :)

Member

@ColCarroll ColCarroll left a comment

This is a lot of careful work, @Ban-zee! It looks good.

@ColCarroll ColCarroll merged commit acdba3f into arviz-devs:master Jul 4, 2019
OriolAbril pushed a commit to OriolAbril/arviz that referenced this pull request Jul 4, 2019
* Modified gitignore

* Added numba

* Added numba

* Implemented black

* Implemented black

* Added tests

* Added tests

* Added benchmark and util tests

* Added benchmark and util tests

* Fixed benchmark tests

* Modified tests

* Fixed Lint

* Fixed Lint

* Another attempt at lint

* Another attempt at lint

* Black

* Changed tests

* Changed tests

* Forgot black

* Black benchmarks

* Lint

* Lint

* Fixed the tests

* Removed relative path

* Modified asv conf

* Made changes

* made the modifications

* Added tests and numba global function

* Black

* lint

* Increased code coverage

* Added tests and lint fix

* lint

* Sphinx

* Attemp 1 at circular import

* Cyclic import fix arviz-devs#2

* Circ Import 3

* Attempt 4

* Undoing changes in init

* Added tests

* Fixed init

* Fixed imports

* LINT

* Added one more test

* Grammar

* Made review changes

* Replaced most of the if-else blocks

* float

* update

* updated docstring

* Fixed docs

* Added instructions

* notebook

* Test-1

* Test-2

* TEST-3

* Possible fix

* Changes

* Added files

5 participants