From f2f8fbc99e21a4b30593e3418e32525f03db4c08 Mon Sep 17 00:00:00 2001 From: Adam Reichold Date: Wed, 21 Jun 2023 11:59:15 +0200 Subject: [PATCH 1/2] Add minimal support for BFloat16 dtype. --- CHANGELOG.md | 1 + src/dtype.rs | 20 +++++++++++++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b6758c2a1..7d948ef34 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ - Unreleased - Increase MSRV to 1.56 released in October 2021 and available in Debain 12, RHEL 9 and Alpine 3.17 following the same change for PyO3. ([#378](https://github.com/PyO3/rust-numpy/pull/378)) - Add support for ASCII (`PyFixedString`) and Unicode (`PyFixedUnicode`) string arrays, i.e. dtypes `SN` and `UN` where `N` is the number of characters. ([#378](https://github.com/PyO3/rust-numpy/pull/378)) + - Add support for the `bfloat16` dtype by extending the optional integration with the `half` crate. Note that the `bfloat16` dtype is not part of NumPy itself so that usage requires third-party packages like Tensorflow. ([#381](https://github.com/PyO3/rust-numpy/pull/381)) - v0.19.0 - Add `PyUntypedArray` as an untyped base type for `PyArray` which can be used to inspect arguments before more targeted downcasts. This is accompanied by some methods like `dtype` and `shape` moving from `PyArray` to `PyUntypedArray`. They are still accessible though, as `PyArray` dereferences to `PyUntypedArray` via the `Deref` trait. ([#369](https://github.com/PyO3/rust-numpy/pull/369)) diff --git a/src/dtype.rs b/src/dtype.rs index 7bbd257f6..c32113df3 100644 --- a/src/dtype.rs +++ b/src/dtype.rs @@ -5,7 +5,7 @@ use std::os::raw::{ use std::ptr; #[cfg(feature = "half")] -use half::f16; +use half::{bf16, f16}; use num_traits::{Bounded, Zero}; use pyo3::{ exceptions::{PyIndexError, PyValueError}, @@ -15,6 +15,8 @@ use pyo3::{ AsPyPointer, FromPyObject, FromPyPointer, IntoPyPointer, PyAny, PyNativeType, PyObject, PyResult, PyTypeInfo, Python, ToPyObject, }; +#[cfg(feature = "half")] +use pyo3::{sync::GILOnceCell, IntoPy, Py}; use crate::npyffi::{ NpyTypes, PyArray_Descr, NPY_ALIGNED_STRUCT, NPY_BYTEORDER_CHAR, NPY_ITEM_HASOBJECT, NPY_TYPES, @@ -477,6 +479,22 @@ impl_element_scalar!(f64 => NPY_DOUBLE); #[cfg(feature = "half")] impl_element_scalar!(f16 => NPY_HALF); +#[cfg(feature = "half")] +unsafe impl Element for bf16 { + const IS_COPY: bool = true; + + fn get_dtype(py: Python) -> &PyArrayDescr { + static DTYPE: GILOnceCell> = GILOnceCell::new(); + + DTYPE + .get_or_init(py, || { + PyArrayDescr::new(py, "bfloat16").expect("A package which provides a `bfloat16` data type for NumPy is required to use the `half::bf16` element type.").into_py(py) + }) + .clone() + .into_ref(py) + } +} + impl_element_scalar!(Complex32 => NPY_CFLOAT, #[doc = "Complex type with `f32` components which maps to `numpy.csingle` (`numpy.complex64`)."]); impl_element_scalar!(Complex64 => NPY_CDOUBLE, From 1a03f5850a41a589cb024013fae504e76cbc6422 Mon Sep 17 00:00:00 2001 From: Adam Reichold Date: Thu, 22 Jun 2023 08:31:47 +0200 Subject: [PATCH 2/2] Test support for bfloat16 using ml_dtypes. --- .github/workflows/ci.yml | 8 +++---- tests/array.rs | 47 +++++++++++++++++++++++++++++++++++++--- 2 files changed, 48 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 376ded49c..14713e7d7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -67,7 +67,7 @@ jobs: shell: python - name: Test run: | - pip install numpy + pip install numpy ml_dtypes cargo test --all-features # Not on PyPy, because no embedding API if: ${{ !startsWith(matrix.python-version, 'pypy') }} @@ -101,7 +101,7 @@ jobs: continue-on-error: true - uses: taiki-e/install-action@valgrind - run: | - pip install numpy + pip install numpy ml_dtypes cargo test --all-features --release env: CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUNNER: valgrind --leak-check=no --error-exitcode=1 @@ -115,7 +115,7 @@ jobs: - uses: Swatinem/rust-cache@v2 continue-on-error: true - run: | - pip install numpy + pip install numpy ml_dtypes cargo install --locked cargo-careful cargo careful test --all-features @@ -201,7 +201,7 @@ jobs: python-version: 3.9 architecture: x64 - name: Install numpy - run: pip install numpy + run: pip install numpy ml_dtypes - uses: Swatinem/rust-cache@v2 continue-on-error: true - uses: dtolnay/rust-toolchain@stable diff --git a/tests/array.rs b/tests/array.rs index 6cfa8ac63..3564c9c76 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -1,7 +1,7 @@ use std::mem::size_of; #[cfg(feature = "half")] -use half::f16; +use half::{bf16, f16}; use ndarray::{array, s, Array1, Dim}; use numpy::{ dtype, get_array_module, npyffi::NPY_ORDER, pyarray, PyArray, PyArray1, PyArray2, PyArrayDescr, @@ -527,7 +527,7 @@ fn reshape() { #[cfg(feature = "half")] #[test] -fn half_works() { +fn half_f16_works() { Python::with_gil(|py| { let np = py.eval("__import__('numpy')", None, None).unwrap(); let locals = [("np", np)].into_py_dict(py); @@ -558,7 +558,48 @@ fn half_works() { py_run!( py, array np, - "np.testing.assert_array_almost_equal(array, np.array([[2, 4], [6, 8]], dtype='float16'))" + "assert np.all(array == np.array([[2, 4], [6, 8]], dtype='float16'))" + ); + }); +} + +#[cfg(feature = "half")] +#[test] +fn half_bf16_works() { + Python::with_gil(|py| { + let np = py.eval("__import__('numpy')", None, None).unwrap(); + // NumPy itself does not provide a `bfloat16` dtype itself, + // so we import ml_dtypes which does register such a dtype. + let mldt = py.eval("__import__('ml_dtypes')", None, None).unwrap(); + let locals = [("np", np), ("mldt", mldt)].into_py_dict(py); + + let array = py + .eval( + "np.array([[1, 2], [3, 4]], dtype='bfloat16')", + None, + Some(locals), + ) + .unwrap() + .downcast::>() + .unwrap(); + + assert_eq!( + array.readonly().as_array(), + array![ + [bf16::from_f32(1.0), bf16::from_f32(2.0)], + [bf16::from_f32(3.0), bf16::from_f32(4.0)] + ] + ); + + array + .readwrite() + .as_array_mut() + .map_inplace(|value| *value *= bf16::from_f32(2.0)); + + py_run!( + py, + array np, + "assert np.all(array == np.array([[2, 4], [6, 8]], dtype='bfloat16'))" ); }); }