Skip to content

Commit ebe87bc

Browse files
committed
add layout template parameter to pyarray and pytensor
1 parent 360ac92 commit ebe87bc

File tree

5 files changed

+162
-154
lines changed

5 files changed

+162
-154
lines changed

include/xtensor-python/pyarray.hpp

+79-79
Original file line numberDiff line numberDiff line change
@@ -23,27 +23,27 @@
2323

2424
namespace xt
2525
{
26-
template <class T>
26+
template <class T, layout_type L = layout_type::dynamic>
2727
class pyarray;
2828
}
2929

3030
namespace pybind11
3131
{
3232
namespace detail
3333
{
34-
template <class T>
35-
struct handle_type_name<xt::pyarray<T>>
34+
template <class T, xt::layout_type L>
35+
struct handle_type_name<xt::pyarray<T, L>>
3636
{
3737
static PYBIND11_DESCR name()
3838
{
3939
return _("numpy.ndarray[") + npy_format_descriptor<T>::name() + _("]");
4040
}
4141
};
4242

43-
template <typename T>
44-
struct pyobject_caster<xt::pyarray<T>>
43+
template <typename T, xt::layout_type L>
44+
struct pyobject_caster<xt::pyarray<T, L>>
4545
{
46-
using type = xt::pyarray<T>;
46+
using type = xt::pyarray<T, L>;
4747

4848
bool load(handle src, bool convert)
4949
{
@@ -72,10 +72,10 @@ namespace pybind11
7272
};
7373

7474
// Type caster for casting ndarray to xexpression<pyarray>
75-
template <typename T>
76-
struct type_caster<xt::xexpression<xt::pyarray<T>>> : pyobject_caster<xt::pyarray<T>>
75+
template <typename T, xt::layout_type L>
76+
struct type_caster<xt::xexpression<xt::pyarray<T, L>>> : pyobject_caster<xt::pyarray<T, L>>
7777
{
78-
using Type = xt::xexpression<xt::pyarray<T>>;
78+
using Type = xt::xexpression<xt::pyarray<T, L>>;
7979

8080
operator Type&()
8181
{
@@ -89,8 +89,8 @@ namespace pybind11
8989
};
9090

9191
// Type caster for casting xarray to ndarray
92-
template <class T>
93-
struct type_caster<xt::xarray<T>> : xtensor_type_caster_base<xt::xarray<T>>
92+
template <class T, xt::layout_type L>
93+
struct type_caster<xt::xarray<T, L>> : xtensor_type_caster_base<xt::xarray<T, L>>
9494
{
9595
};
9696
}
@@ -282,24 +282,24 @@ namespace xt
282282
const array_type* p_a;
283283
};
284284

285-
template <class T>
286-
struct xiterable_inner_types<pyarray<T>>
287-
: xcontainer_iterable_types<pyarray<T>>
285+
template <class T, layout_type L>
286+
struct xiterable_inner_types<pyarray<T, L>>
287+
: xcontainer_iterable_types<pyarray<T, L>>
288288
{
289289
};
290290

291-
template <class T>
292-
struct xcontainer_inner_types<pyarray<T>>
291+
template <class T, layout_type L>
292+
struct xcontainer_inner_types<pyarray<T, L>>
293293
{
294294
using storage_type = xbuffer_adaptor<T*>;
295295
using shape_type = std::vector<typename storage_type::size_type>;
296296
using strides_type = shape_type;
297-
using backstrides_type = pyarray_backstrides<pyarray<T>>;
297+
using backstrides_type = pyarray_backstrides<pyarray<T, L>>;
298298
using inner_shape_type = xbuffer_adaptor<std::size_t*>;
299299
using inner_strides_type = pystrides_adaptor<sizeof(T)>;
300300
using inner_backstrides_type = backstrides_type;
301-
using temporary_type = pyarray<T>;
302-
static constexpr layout_type layout = layout_type::dynamic;
301+
using temporary_type = pyarray<T, L>;
302+
static constexpr layout_type layout = L;
303303
};
304304

305305
/**
@@ -312,13 +312,13 @@ namespace xt
312312
* @tparam T The type of the element stored in the pyarray.
313313
* @sa pytensor
314314
*/
315-
template <class T>
316-
class pyarray : public pycontainer<pyarray<T>>,
317-
public xcontainer_semantic<pyarray<T>>
315+
template <class T, layout_type L>
316+
class pyarray : public pycontainer<pyarray<T, L>>,
317+
public xcontainer_semantic<pyarray<T, L>>
318318
{
319319
public:
320320

321-
using self_type = pyarray<T>;
321+
using self_type = pyarray<T, L>;
322322
using semantic_base = xcontainer_semantic<self_type>;
323323
using base_type = pycontainer<self_type>;
324324
using storage_type = typename base_type::storage_type;
@@ -386,8 +386,8 @@ namespace xt
386386
storage_type& storage_impl() noexcept;
387387
const storage_type& storage_impl() const noexcept;
388388

389-
friend class xcontainer<pyarray<T>>;
390-
friend class pycontainer<pyarray<T>>;
389+
friend class xcontainer<pyarray<T, L>>;
390+
friend class pycontainer<pyarray<T, L>>;
391391
};
392392

393393
/**************************************
@@ -469,8 +469,8 @@ namespace xt
469469
* @name Constructors
470470
*/
471471
//@{
472-
template <class T>
473-
inline pyarray<T>::pyarray()
472+
template <class T, layout_type L>
473+
inline pyarray<T, L>::pyarray()
474474
: base_type()
475475
{
476476
// TODO: avoid allocation
@@ -483,70 +483,70 @@ namespace xt
483483
/**
484484
* Allocates a pyarray with nested initializer lists.
485485
*/
486-
template <class T>
487-
inline pyarray<T>::pyarray(const value_type& t)
486+
template <class T, layout_type L>
487+
inline pyarray<T, L>::pyarray(const value_type& t)
488488
: base_type()
489489
{
490490
base_type::resize(xt::shape<shape_type>(t), layout_type::row_major);
491491
nested_copy(m_storage.begin(), t);
492492
}
493493

494-
template <class T>
495-
inline pyarray<T>::pyarray(nested_initializer_list_t<T, 1> t)
494+
template <class T, layout_type L>
495+
inline pyarray<T, L>::pyarray(nested_initializer_list_t<T, 1> t)
496496
: base_type()
497497
{
498498
base_type::resize(xt::shape<shape_type>(t), layout_type::row_major);
499499
nested_copy(m_storage.begin(), t);
500500
}
501501

502-
template <class T>
503-
inline pyarray<T>::pyarray(nested_initializer_list_t<T, 2> t)
502+
template <class T, layout_type L>
503+
inline pyarray<T, L>::pyarray(nested_initializer_list_t<T, 2> t)
504504
: base_type()
505505
{
506506
base_type::resize(xt::shape<shape_type>(t), layout_type::row_major);
507507
nested_copy(m_storage.begin(), t);
508508
}
509509

510-
template <class T>
511-
inline pyarray<T>::pyarray(nested_initializer_list_t<T, 3> t)
510+
template <class T, layout_type L>
511+
inline pyarray<T, L>::pyarray(nested_initializer_list_t<T, 3> t)
512512
: base_type()
513513
{
514514
base_type::resize(xt::shape<shape_type>(t), layout_type::row_major);
515515
nested_copy(m_storage.begin(), t);
516516
}
517517

518-
template <class T>
519-
inline pyarray<T>::pyarray(nested_initializer_list_t<T, 4> t)
518+
template <class T, layout_type L>
519+
inline pyarray<T, L>::pyarray(nested_initializer_list_t<T, 4> t)
520520
: base_type()
521521
{
522522
base_type::resize(xt::shape<shape_type>(t), layout_type::row_major);
523523
nested_copy(m_storage.begin(), t);
524524
}
525525

526-
template <class T>
527-
inline pyarray<T>::pyarray(nested_initializer_list_t<T, 5> t)
526+
template <class T, layout_type L>
527+
inline pyarray<T, L>::pyarray(nested_initializer_list_t<T, 5> t)
528528
: base_type()
529529
{
530530
base_type::resize(xt::shape<shape_type>(t), layout_type::row_major);
531531
nested_copy(m_storage.begin(), t);
532532
}
533533

534-
template <class T>
535-
inline pyarray<T>::pyarray(pybind11::handle h, pybind11::object::borrowed_t b)
534+
template <class T, layout_type L>
535+
inline pyarray<T, L>::pyarray(pybind11::handle h, pybind11::object::borrowed_t b)
536536
: base_type(h, b)
537537
{
538538
init_from_python();
539539
}
540540

541-
template <class T>
542-
inline pyarray<T>::pyarray(pybind11::handle h, pybind11::object::stolen_t s)
541+
template <class T, layout_type L>
542+
inline pyarray<T, L>::pyarray(pybind11::handle h, pybind11::object::stolen_t s)
543543
: base_type(h, s)
544544
{
545545
init_from_python();
546546
}
547547

548-
template <class T>
549-
inline pyarray<T>::pyarray(const pybind11::object& o)
548+
template <class T, layout_type L>
549+
inline pyarray<T, L>::pyarray(const pybind11::object& o)
550550
: base_type(o)
551551
{
552552
init_from_python();
@@ -558,8 +558,8 @@ namespace xt
558558
* @param shape the shape of the pyarray
559559
* @param l the layout of the pyarray
560560
*/
561-
template <class T>
562-
inline pyarray<T>::pyarray(const shape_type& shape, layout_type l)
561+
template <class T, layout_type L>
562+
inline pyarray<T, L>::pyarray(const shape_type& shape, layout_type l)
563563
: base_type()
564564
{
565565
strides_type strides(shape.size());
@@ -574,8 +574,8 @@ namespace xt
574574
* @param value the value of the elements
575575
* @param l the layout of the pyarray
576576
*/
577-
template <class T>
578-
inline pyarray<T>::pyarray(const shape_type& shape, const_reference value, layout_type l)
577+
template <class T, layout_type L>
578+
inline pyarray<T, L>::pyarray(const shape_type& shape, const_reference value, layout_type l)
579579
: base_type()
580580
{
581581
strides_type strides(shape.size());
@@ -591,8 +591,8 @@ namespace xt
591591
* @param strides the strides of the pyarray
592592
* @param value the value of the elements
593593
*/
594-
template <class T>
595-
inline pyarray<T>::pyarray(const shape_type& shape, const strides_type& strides, const_reference value)
594+
template <class T, layout_type L>
595+
inline pyarray<T, L>::pyarray(const shape_type& shape, const strides_type& strides, const_reference value)
596596
: base_type()
597597
{
598598
init_array(shape, strides);
@@ -604,8 +604,8 @@ namespace xt
604604
* @param shape the shape of the pyarray
605605
* @param strides the strides of the pyarray
606606
*/
607-
template <class T>
608-
inline pyarray<T>::pyarray(const shape_type& shape, const strides_type& strides)
607+
template <class T, layout_type L>
608+
inline pyarray<T, L>::pyarray(const shape_type& shape, const strides_type& strides)
609609
: base_type()
610610
{
611611
init_array(shape, strides);
@@ -619,8 +619,8 @@ namespace xt
619619
/**
620620
* The copy constructor.
621621
*/
622-
template <class T>
623-
inline pyarray<T>::pyarray(const self_type& rhs)
622+
template <class T, layout_type L>
623+
inline pyarray<T, L>::pyarray(const self_type& rhs)
624624
: base_type(), semantic_base(rhs)
625625
{
626626
auto tmp = pybind11::reinterpret_steal<pybind11::object>(
@@ -639,8 +639,8 @@ namespace xt
639639
/**
640640
* The assignment operator.
641641
*/
642-
template <class T>
643-
inline auto pyarray<T>::operator=(const self_type& rhs) -> self_type&
642+
template <class T, layout_type L>
643+
inline auto pyarray<T, L>::operator=(const self_type& rhs) -> self_type&
644644
{
645645
self_type tmp(rhs);
646646
*this = std::move(tmp);
@@ -656,9 +656,9 @@ namespace xt
656656
/**
657657
* The extended copy constructor.
658658
*/
659-
template <class T>
659+
template <class T, layout_type L>
660660
template <class E>
661-
inline pyarray<T>::pyarray(const xexpression<E>& e)
661+
inline pyarray<T, L>::pyarray(const xexpression<E>& e)
662662
: base_type()
663663
{
664664
// TODO: prevent intermediary shape allocation
@@ -672,28 +672,28 @@ namespace xt
672672
/**
673673
* The extended assignment operator.
674674
*/
675-
template <class T>
675+
template <class T, layout_type L>
676676
template <class E>
677-
inline auto pyarray<T>::operator=(const xexpression<E>& e) -> self_type&
677+
inline auto pyarray<T, L>::operator=(const xexpression<E>& e) -> self_type&
678678
{
679679
return semantic_base::operator=(e);
680680
}
681681
//@}
682682

683-
template <class T>
684-
inline auto pyarray<T>::ensure(pybind11::handle h) -> self_type
683+
template <class T, layout_type L>
684+
inline auto pyarray<T, L>::ensure(pybind11::handle h) -> self_type
685685
{
686686
return base_type::ensure(h);
687687
}
688688

689-
template <class T>
690-
inline bool pyarray<T>::check_(pybind11::handle h)
689+
template <class T, layout_type L>
690+
inline bool pyarray<T, L>::check_(pybind11::handle h)
691691
{
692692
return base_type::check_(h);
693693
}
694694

695-
template <class T>
696-
inline void pyarray<T>::init_array(const shape_type& shape, const strides_type& strides)
695+
template <class T, layout_type L>
696+
inline void pyarray<T, L>::init_array(const shape_type& shape, const strides_type& strides)
697697
{
698698
strides_type adapted_strides(strides);
699699

@@ -722,8 +722,8 @@ namespace xt
722722
init_from_python();
723723
}
724724

725-
template <class T>
726-
inline void pyarray<T>::init_from_python()
725+
template <class T, layout_type L>
726+
inline void pyarray<T, L>::init_from_python()
727727
{
728728
m_shape = inner_shape_type(reinterpret_cast<size_type*>(PyArray_SHAPE(this->python_array())),
729729
static_cast<size_type>(PyArray_NDIM(this->python_array())));
@@ -734,20 +734,20 @@ namespace xt
734734
this->get_min_stride() * static_cast<size_type>(PyArray_SIZE(this->python_array())));
735735
}
736736

737-
template <class T>
738-
inline auto pyarray<T>::shape_impl() const noexcept -> const inner_shape_type&
737+
template <class T, layout_type L>
738+
inline auto pyarray<T, L>::shape_impl() const noexcept -> const inner_shape_type&
739739
{
740740
return m_shape;
741741
}
742742

743-
template <class T>
744-
inline auto pyarray<T>::strides_impl() const noexcept -> const inner_strides_type&
743+
template <class T, layout_type L>
744+
inline auto pyarray<T, L>::strides_impl() const noexcept -> const inner_strides_type&
745745
{
746746
return m_strides;
747747
}
748748

749-
template <class T>
750-
inline auto pyarray<T>::backstrides_impl() const noexcept -> const inner_backstrides_type&
749+
template <class T, layout_type L>
750+
inline auto pyarray<T, L>::backstrides_impl() const noexcept -> const inner_backstrides_type&
751751
{
752752
// m_backstrides wraps the numpy array backstrides, which is a raw pointer.
753753
// The address of the raw pointer stored in the wrapper would be invalidated when the pyarray is copied.
@@ -756,14 +756,14 @@ namespace xt
756756
return m_backstrides;
757757
}
758758

759-
template <class T>
760-
inline auto pyarray<T>::storage_impl() noexcept -> storage_type&
759+
template <class T, layout_type L>
760+
inline auto pyarray<T, L>::storage_impl() noexcept -> storage_type&
761761
{
762762
return m_storage;
763763
}
764764

765-
template <class T>
766-
inline auto pyarray<T>::storage_impl() const noexcept -> const storage_type&
765+
template <class T, layout_type L>
766+
inline auto pyarray<T, L>::storage_impl() const noexcept -> const storage_type&
767767
{
768768
return m_storage;
769769
}

0 commit comments

Comments
 (0)