
Generating non-IEEE floats - such as bfloat16, bfloat32, or various float8_* types #3959

Open
qthequartermasterman opened this issue Apr 27, 2024 · 3 comments
Labels
enhancement it's not broken, but we want it to be better

Comments

@qthequartermasterman
Contributor

Thank you for hypothesis! It's a wonderful library that I use frequently.

While investigating a bug report from a user of my extension library (qthequartermasterman/hypothesis-torch#20), I realized that there is no obvious way (at least to me) to build a Hypothesis strategy that generates only floats in a format other than the standard IEEE float16, float32, or float64, such as bfloat16 or bfloat32.

Is there a straightforward way to construct such a strategy?

There is a mention of a desire to support these types after the float overhaul, but I cannot find any other references to such work ever being done.

I looked into how the internals of floats currently handle width, and it looks like floats_of could be extended using a similar technique of struct packing/unpacking. Before I implement that, however, I wanted to check whether there is a more straightforward way.

Is there a straightforward way to generate only bfloat16-valid floats using hypothesis?

@Zac-HD Zac-HD changed the title Questions about generating floats with width not equal to 16, 32, or 64, such as bfloat16 or bfloat32. Generating non-IEEE floats - such as bfloat16, bfloat32, or various float8_* types Apr 27, 2024
@Zac-HD
Member

Zac-HD commented Apr 27, 2024

As an immediately-available workaround, I'll observe that bfloat16 and bfloat32 are simply truncated versions of the standard float32 and float64 respectively - so you could generate those and throw away the lower half of the bytes (mostly mantissa), with some minimal work to fix up NaNs (which would otherwise turn into +/- inf).
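
A minimal sketch of that workaround in Python (the helper name truncate_to_bfloat16 is illustrative, not a Hypothesis API):

```python
import math
import struct

from hypothesis import strategies as st


def truncate_to_bfloat16(x: float) -> float:
    """Truncate a float32-representable value to a bfloat16-representable one."""
    # bfloat16 is the top 16 bits of an IEEE float32: pack as float32,
    # zero out the low 16 bits (mostly mantissa), and unpack again.
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    result = struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]
    # If a NaN's payload lived entirely in the discarded bits, truncation
    # yields +/- inf; restore the NaN explicitly.
    if math.isnan(x) and not math.isnan(result):
        return math.copysign(math.nan, x)
    return result


bfloat16_floats = st.floats(width=32).map(truncate_to_bfloat16)
```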


In the longer term, I'd like to support these increasingly popular types directly. We'll need to think of this in relation to #3921, which gives us two main options:

  • Have draw_float() always produce a float64, and derive narrower types from that value.
    • This makes life easier for backends, but introduces a lot of redundancy into the IR for every smaller float type.
  • Add a dtype: Literal["float64", "float32", "float16", "bfloat32", "bfloat16"] = "float64" argument to draw_float()
    • Annoying for backends, but less redundant. Overall I think this is the way to go; we can mitigate the challenge by providing a downcast function that backends can use or copy (if e.g. they only support IEEE floats) - see the sketch after this list.
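
One possible shape for such a helper, purely as an illustration (using numpy plus the ml_dtypes package for bfloat16; the names here are hypothetical, not a settled API):

```python
import numpy as np
import ml_dtypes  # third-party package providing bfloat16 and float8 numpy dtypes

# Hypothetical mapping from the dtype strings above to numpy-compatible types.
_DTYPES = {
    "float64": np.float64,
    "float32": np.float32,
    "float16": np.float16,
    "bfloat16": ml_dtypes.bfloat16,
}


def downcast(value: float, dtype: str = "float64") -> float:
    """Round a drawn float64 to the nearest value representable in `dtype`."""
    return float(np.asarray(value, dtype=_DTYPES[dtype]))
```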

In either case, the float8_* types are complicated and diverse (see e.g. jax-ml/ml_dtypes) but each has only 256 distinct values, so I think we should handle them 'above' the IR layer with something more like a sampled-from strategy.
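
For example, a sampled-from strategy for one of these formats could look roughly like this (again assuming the ml_dtypes package):

```python
import numpy as np
import ml_dtypes
from hypothesis import strategies as st

# Every float8 format has exactly 256 bit patterns, so enumerate them all by
# reinterpreting the bytes 0..255; float8_e4m3fn is just one of the variants.
_all_float8 = np.arange(256, dtype=np.uint8).view(ml_dtypes.float8_e4m3fn)

float8_e4m3fn_values = st.sampled_from(list(_all_float8))
```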

The user-facing API would be to generalize the current st.floats(width=...) argument to also accept any of the strings indicating a known float dtype. We could also consider renaming the argument; the compatibility issues are manageable.

@qthequartermasterman
Contributor Author

Thanks for the speedy response.

I experimented with manually truncating the lower half of the bytes (using struct.pack/struct.unpack) and then passing the results into torch, but I found that this was significantly slower than simply filtering out values above/below the bfloat16 min/max and letting torch implicitly coerce the values. Although manually generating correct values looks asymptotically faster on average (because it avoids the expensive filtering), it took thousands of examples (at least on my machine) before that advantage over torch's own conversion became noticeable (though I admit I didn't do a systematic study). I suspect this is because torch can do the down-casting very close to the hardware, whereas the struct solution goes through high-level Python code and multiple conversions. I realize that delegating that functionality to torch is not an option for many use cases.
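
For reference, the filtering approach is roughly the following (a simplified sketch assuming torch; not necessarily exactly what hypothesis-torch does):

```python
import torch
from hypothesis import strategies as st

BF16 = torch.finfo(torch.bfloat16)

# Generate float64s within the representable range of bfloat16 and let torch
# round them to the nearest bfloat16 value when constructing the tensor.
bfloat16_tensors = st.floats(min_value=BF16.min, max_value=BF16.max).map(
    lambda x: torch.tensor(x, dtype=torch.bfloat16)
)
```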

I definitely look forward to future Hypothesis development on this topic, especially if native support can do that down-casting quickly.

@Zac-HD
Member

Zac-HD commented Apr 28, 2024

"Compile to native code" is definitely on the wishlist: #3074 (comment) - glad it's working in the meantime though!
