Skip to content

Commit 81674b1

Browse files
committed
reshape fixed
1 parent c47abe0 commit 81674b1

File tree

4 files changed

+15
-7
lines changed

4 files changed

+15
-7
lines changed

include/xtensor-python/pycontainer.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ namespace xt
225225
template <class S>
226226
inline void pycontainer<D>::reshape(const S& shape)
227227
{
228-
if (shape.size() != this->dimension() || !std::equal(shape.begin(), shape.end(), this->shape().begin()))
228+
if (shape.size() != this->dimension() || !std::equal(std::begin(shape), std::end(shape), std::begin(this->shape())))
229229
{
230230
reshape(shape, layout_type::row_major);
231231
}
@@ -254,7 +254,7 @@ namespace xt
254254
template <class S>
255255
inline void pycontainer<D>::reshape(const S& shape, const strides_type& strides)
256256
{
257-
derived_type tmp(shape, strides);
257+
derived_type tmp(xtl::forward_sequence<shape_type>(shape), strides);
258258
*static_cast<derived_type*>(this) = std::move(tmp);
259259
}
260260

test/test_common.hpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ namespace xt
130130
m_data = {-1, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19};
131131
m_layout = layout_type::dynamic;
132132
m_assigner.resize(m_shape[0]);
133-
for (std::size_t i = 0; i < m_shape[0]; ++i)
133+
for (std::size_t i = 0; i < std::size_t(m_shape[0]); ++i)
134134
{
135135
m_assigner[i].resize(m_shape[1]);
136136
}
@@ -184,10 +184,10 @@ namespace xt
184184
auto v_copy_b = vec;
185185
std::array<std::size_t, 3> ar = {3, 2, 4};
186186
std::vector<std::size_t> vr = {3, 2, 4};
187-
// v_copy_a.reshape(ar, true);
188-
// compare_shape(v_copy_a, rm);
189-
// v_copy_b.reshape(vr, true);
190-
// compare_shape(v_copy_b, rm);
187+
v_copy_a.reshape(ar);
188+
compare_shape(v_copy_a, rm);
189+
v_copy_b.reshape(vr);
190+
compare_shape(v_copy_b, rm);
191191
}
192192

193193
{

test/test_pyarray.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,10 @@ namespace xt
161161
{
162162
pyarray<int> a;
163163
test_reshape(a);
164+
165+
pyarray<int> b = { {1, 2}, {3, 4} };
166+
a.reshape(b.shape());
167+
EXPECT_EQ(a.shape(), b.shape());
164168
}
165169

166170
TEST(pyarray, transpose)

test/test_pytensor.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,10 @@ namespace xt
160160
{
161161
pytensor<int, 3> a;
162162
test_reshape<pytensor<int, 3>, container_type>(a);
163+
164+
pytensor<int, 3> b = { { { 1, 2 },{ 3, 4 } } };
165+
a.reshape(b.shape());
166+
EXPECT_EQ(a.shape(), b.shape());
163167
}
164168

165169
TEST(pytensor, transpose)

0 commit comments

Comments
 (0)