
Commit cfb08ae

Fixing constructor bug for pytensor<..., 0>

Rank-0 (scalar) pytensor construction went through init_tensor, which declared npy_intp python_strides[N]; with N == 0 that is a zero-length array. Stride storage now goes through a detail::numpy_strides<N> helper whose N == 0 specialization holds a null strides pointer instead, and new C++ and Python tests exercise constructing and returning 0-D pytensors.
1 parent 9aa58f8 commit cfb08ae

File tree

4 files changed: +39 −4

include/xtensor-python/pytensor.hpp

+19 −4
@@ -100,11 +100,26 @@ namespace pybind11
             }
         };
 
-    }
+    } // namespace detail
 }
 
 namespace xt
 {
+    namespace detail {
+
+        template <std::size_t N, typename = void>
+        struct numpy_strides
+        {
+            npy_intp value[N];
+        };
+
+        template <std::size_t N>
+        struct numpy_strides<N, std::enable_if_t<(N == 0)>>
+        {
+            npy_intp* value = nullptr;
+        };
+
+    } // namespace detail
 
     template <class T, std::size_t N, layout_type L>
     struct xiterable_inner_types<pytensor<T, N, L>>
@@ -433,8 +448,8 @@ namespace xt
     template <class T, std::size_t N, layout_type L>
     inline void pytensor<T, N, L>::init_tensor(const shape_type& shape, const strides_type& strides)
     {
-        npy_intp python_strides[N];
-        std::transform(strides.begin(), strides.end(), python_strides,
+        detail::numpy_strides<N> python_strides;
+        std::transform(strides.begin(), strides.end(), python_strides.value,
                        [](auto v) { return sizeof(T) * v; });
         int flags = NPY_ARRAY_ALIGNED;
         if (!std::is_const<T>::value)
@@ -445,7 +460,7 @@ namespace xt
 
         auto tmp = pybind11::reinterpret_steal<pybind11::object>(
             PyArray_NewFromDescr(&PyArray_Type, (PyArray_Descr*) dtype.release().ptr(), static_cast<int>(shape.size()),
-                                 const_cast<npy_intp*>(shape.data()), python_strides,
+                                 const_cast<npy_intp*>(shape.data()), python_strides.value,
                                  nullptr, flags, nullptr));
 
         if (!tmp)
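The detail::numpy_strides<N> helper above exists because the old code declared npy_intp python_strides[N], which for N == 0 is a zero-length array and not valid standard C++ (it only compiles as a compiler extension). Below is a minimal, self-contained sketch of the same dispatch pattern, independent of NumPy and not part of the commit: strides_buffer and std::ptrdiff_t are stand-ins for detail::numpy_strides and npy_intp (requires C++14 for std::enable_if_t).

#include <algorithm>
#include <array>
#include <cstddef>
#include <iostream>
#include <type_traits>

// Primary template: rank N > 0, strides live in a fixed-size stack array.
template <std::size_t N, typename = void>
struct strides_buffer
{
    std::ptrdiff_t value[N];
};

// Specialization selected when N == 0: there are no strides to store, so
// expose a null pointer that can be forwarded to an API that accepts a
// (possibly null) strides pointer, as PyArray_NewFromDescr does.
template <std::size_t N>
struct strides_buffer<N, std::enable_if_t<(N == 0)>>
{
    std::ptrdiff_t* value = nullptr;
};

int main()
{
    // Rank-3 case: fill the buffer the same way init_tensor does, scaling
    // element strides to byte strides.
    std::array<std::ptrdiff_t, 3> strides = {12, 4, 1};
    strides_buffer<3> s3;
    std::transform(strides.begin(), strides.end(), s3.value,
                   [](std::ptrdiff_t v) { return static_cast<std::ptrdiff_t>(sizeof(double)) * v; });
    std::cout << "rank 3 byte strides: " << s3.value[0] << ", "
              << s3.value[1] << ", " << s3.value[2] << "\n";

    // Rank-0 case: the specialization is chosen and value is just nullptr.
    strides_buffer<0> s0;
    static_assert(std::is_pointer<decltype(s0.value)>::value,
                  "rank-0 buffer holds a pointer, not an array");
    std::cout << "rank 0 strides pointer is "
              << (s0.value == nullptr ? "null" : "non-null") << "\n";
    return 0;
}

The second template parameter of the primary template defaults to void, and std::enable_if_t<(N == 0)> is void exactly when N is zero, so the pointer-based specialization is selected only for rank-0 tensors while every other rank keeps its stack array.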

test/test_pytensor.cpp

+9
@@ -65,6 +65,15 @@ namespace xt
         EXPECT_THROW(pyt3::from_shape(shp), std::runtime_error);
     }
 
+    TEST(pytensor, scalar_from_shape)
+    {
+        std::array<size_t, 0> shape;
+        auto a = pytensor<double, 0>::from_shape(shape);
+        pytensor<double, 0> b(1.2);
+        EXPECT_TRUE(a.size() == b.size());
+        EXPECT_TRUE(xt::has_shape(a, b.shape()));
+    }
+
     TEST(pytensor, strided_constructor)
     {
         central_major_result<container_type> cmr;

test_python/main.cpp

+7
@@ -227,6 +227,11 @@ void col_major_array(xt::pyarray<double, xt::layout_type::column_major>& arg)
     }
 }
 
+xt::pytensor<int, 0> xscalar(const xt::pytensor<int, 1>& arg)
+{
+    return xt::sum(arg);
+}
+
 template <class T>
 using ndarray = xt::pyarray<T, xt::layout_type::row_major>;
 
@@ -285,6 +290,8 @@ PYBIND11_MODULE(xtensor_python_test, m)
     m.def("col_major_array", col_major_array);
     m.def("row_major_tensor", row_major_tensor);
 
+    m.def("xscalar", xscalar);
+
     py::class_<C>(m, "C")
         .def(py::init<>())
         .def_property_readonly(
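The xscalar binding above returns xt::sum(arg), a lazy reduction over all axes, as a rank-0 pytensor. The following sketch shows the same idea with plain xt::xtensor so it can be compiled without a Python interpreter; it is illustrative only, not part of the commit, and assumes xtensor's classic header layout (xtensor/xtensor.hpp, xtensor/xmath.hpp).

#include <iostream>
#include <xtensor/xtensor.hpp>
#include <xtensor/xmath.hpp>

int main()
{
    // A rank-1 tensor of integers, analogous to the pytensor<int, 1> argument.
    xt::xtensor<int, 1> arg = {1, 2, 3, 4, 5};

    // xt::sum over all axes is a lazy 0-D expression; assigning it to a
    // rank-0 tensor evaluates it, mirroring the return as xt::pytensor<int, 0>
    // in the binding above.
    xt::xtensor<int, 0> total = xt::sum(arg);

    // A rank-0 tensor is read with an empty index list.
    std::cout << total() << "\n";  // prints 15
    return 0;
}

Evaluating the reduction into a rank-0 container is what lets the Python test below compare the result directly against np.sum.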

test_python/test_pyarray.py

+4
@@ -151,6 +151,10 @@ def test_col_row_major(self):
         xt.col_major_array(varF)
         xt.col_major_array(varF[:, :, 0])  # still col major!
 
+    def test_xscalar(self):
+        var = np.arange(50, dtype=int)
+        self.assertTrue(np.sum(var) == xt.xscalar(var))
+
     def test_bad_argument_call(self):
         with self.assertRaises(TypeError):
             xt.simple_array("foo")
