23
23
24
24
namespace xt
25
25
{
26
- template <class T >
26
+ template <class T , layout_type L = layout_type::dynamic >
27
27
class pyarray ;
28
28
}
29
29
30
30
namespace pybind11
31
31
{
32
32
namespace detail
33
33
{
34
- template <class T >
35
- struct handle_type_name <xt::pyarray<T>>
34
+ template <class T , xt::layout_type L >
35
+ struct handle_type_name <xt::pyarray<T, L >>
36
36
{
37
37
static PYBIND11_DESCR name ()
38
38
{
39
39
return _ (" numpy.ndarray[" ) + npy_format_descriptor<T>::name () + _ (" ]" );
40
40
}
41
41
};
42
42
43
- template <typename T>
44
- struct pyobject_caster <xt::pyarray<T>>
43
+ template <typename T, xt::layout_type L >
44
+ struct pyobject_caster <xt::pyarray<T, L >>
45
45
{
46
- using type = xt::pyarray<T>;
46
+ using type = xt::pyarray<T, L >;
47
47
48
48
bool load (handle src, bool convert)
49
49
{
@@ -72,10 +72,10 @@ namespace pybind11
72
72
};
73
73
74
74
// Type caster for casting ndarray to xexpression<pyarray>
75
- template <typename T>
76
- struct type_caster <xt::xexpression<xt::pyarray<T>>> : pyobject_caster<xt::pyarray<T>>
75
+ template <typename T, xt::layout_type L >
76
+ struct type_caster <xt::xexpression<xt::pyarray<T, L >>> : pyobject_caster<xt::pyarray<T, L >>
77
77
{
78
- using Type = xt::xexpression<xt::pyarray<T>>;
78
+ using Type = xt::xexpression<xt::pyarray<T, L >>;
79
79
80
80
operator Type&()
81
81
{
@@ -89,8 +89,8 @@ namespace pybind11
89
89
};
90
90
91
91
// Type caster for casting xarray to ndarray
92
- template <class T >
93
- struct type_caster <xt::xarray<T>> : xtensor_type_caster_base<xt::xarray<T>>
92
+ template <class T , xt::layout_type L >
93
+ struct type_caster <xt::xarray<T, L >> : xtensor_type_caster_base<xt::xarray<T, L >>
94
94
{
95
95
};
96
96
}
@@ -282,24 +282,24 @@ namespace xt
282
282
const array_type* p_a;
283
283
};
284
284
285
- template <class T >
286
- struct xiterable_inner_types <pyarray<T>>
287
- : xcontainer_iterable_types<pyarray<T>>
285
+ template <class T , layout_type L >
286
+ struct xiterable_inner_types <pyarray<T, L >>
287
+ : xcontainer_iterable_types<pyarray<T, L >>
288
288
{
289
289
};
290
290
291
- template <class T >
292
- struct xcontainer_inner_types <pyarray<T>>
291
+ template <class T , layout_type L >
292
+ struct xcontainer_inner_types <pyarray<T, L >>
293
293
{
294
294
using storage_type = xbuffer_adaptor<T*>;
295
295
using shape_type = std::vector<typename storage_type::size_type>;
296
296
using strides_type = shape_type;
297
- using backstrides_type = pyarray_backstrides<pyarray<T>>;
297
+ using backstrides_type = pyarray_backstrides<pyarray<T, L >>;
298
298
using inner_shape_type = xbuffer_adaptor<std::size_t *>;
299
299
using inner_strides_type = pystrides_adaptor<sizeof (T)>;
300
300
using inner_backstrides_type = backstrides_type;
301
- using temporary_type = pyarray<T>;
302
- static constexpr layout_type layout = layout_type::dynamic ;
301
+ using temporary_type = pyarray<T, L >;
302
+ static constexpr layout_type layout = L ;
303
303
};
304
304
305
305
/* *
@@ -312,13 +312,13 @@ namespace xt
312
312
* @tparam T The type of the element stored in the pyarray.
313
313
* @sa pytensor
314
314
*/
315
- template <class T >
316
- class pyarray : public pycontainer <pyarray<T>>,
317
- public xcontainer_semantic<pyarray<T>>
315
+ template <class T , layout_type L >
316
+ class pyarray : public pycontainer <pyarray<T, L >>,
317
+ public xcontainer_semantic<pyarray<T, L >>
318
318
{
319
319
public:
320
320
321
- using self_type = pyarray<T>;
321
+ using self_type = pyarray<T, L >;
322
322
using semantic_base = xcontainer_semantic<self_type>;
323
323
using base_type = pycontainer<self_type>;
324
324
using storage_type = typename base_type::storage_type;
@@ -386,8 +386,8 @@ namespace xt
386
386
storage_type& storage_impl () noexcept ;
387
387
const storage_type& storage_impl () const noexcept ;
388
388
389
- friend class xcontainer <pyarray<T>>;
390
- friend class pycontainer <pyarray<T>>;
389
+ friend class xcontainer <pyarray<T, L >>;
390
+ friend class pycontainer <pyarray<T, L >>;
391
391
};
392
392
393
393
/* *************************************
@@ -469,8 +469,8 @@ namespace xt
469
469
* @name Constructors
470
470
*/
471
471
// @{
472
- template <class T >
473
- inline pyarray<T>::pyarray()
472
+ template <class T , layout_type L >
473
+ inline pyarray<T, L >::pyarray()
474
474
: base_type()
475
475
{
476
476
// TODO: avoid allocation
@@ -483,70 +483,70 @@ namespace xt
483
483
/* *
484
484
* Allocates a pyarray with nested initializer lists.
485
485
*/
486
- template <class T >
487
- inline pyarray<T>::pyarray(const value_type& t)
486
+ template <class T , layout_type L >
487
+ inline pyarray<T, L >::pyarray(const value_type& t)
488
488
: base_type()
489
489
{
490
490
base_type::resize (xt::shape<shape_type>(t), layout_type::row_major);
491
491
nested_copy (m_storage.begin (), t);
492
492
}
493
493
494
- template <class T >
495
- inline pyarray<T>::pyarray(nested_initializer_list_t <T, 1 > t)
494
+ template <class T , layout_type L >
495
+ inline pyarray<T, L >::pyarray(nested_initializer_list_t <T, 1 > t)
496
496
: base_type()
497
497
{
498
498
base_type::resize (xt::shape<shape_type>(t), layout_type::row_major);
499
499
nested_copy (m_storage.begin (), t);
500
500
}
501
501
502
- template <class T >
503
- inline pyarray<T>::pyarray(nested_initializer_list_t <T, 2 > t)
502
+ template <class T , layout_type L >
503
+ inline pyarray<T, L >::pyarray(nested_initializer_list_t <T, 2 > t)
504
504
: base_type()
505
505
{
506
506
base_type::resize (xt::shape<shape_type>(t), layout_type::row_major);
507
507
nested_copy (m_storage.begin (), t);
508
508
}
509
509
510
- template <class T >
511
- inline pyarray<T>::pyarray(nested_initializer_list_t <T, 3 > t)
510
+ template <class T , layout_type L >
511
+ inline pyarray<T, L >::pyarray(nested_initializer_list_t <T, 3 > t)
512
512
: base_type()
513
513
{
514
514
base_type::resize (xt::shape<shape_type>(t), layout_type::row_major);
515
515
nested_copy (m_storage.begin (), t);
516
516
}
517
517
518
- template <class T >
519
- inline pyarray<T>::pyarray(nested_initializer_list_t <T, 4 > t)
518
+ template <class T , layout_type L >
519
+ inline pyarray<T, L >::pyarray(nested_initializer_list_t <T, 4 > t)
520
520
: base_type()
521
521
{
522
522
base_type::resize (xt::shape<shape_type>(t), layout_type::row_major);
523
523
nested_copy (m_storage.begin (), t);
524
524
}
525
525
526
- template <class T >
527
- inline pyarray<T>::pyarray(nested_initializer_list_t <T, 5 > t)
526
+ template <class T , layout_type L >
527
+ inline pyarray<T, L >::pyarray(nested_initializer_list_t <T, 5 > t)
528
528
: base_type()
529
529
{
530
530
base_type::resize (xt::shape<shape_type>(t), layout_type::row_major);
531
531
nested_copy (m_storage.begin (), t);
532
532
}
533
533
534
- template <class T >
535
- inline pyarray<T>::pyarray(pybind11::handle h, pybind11::object::borrowed_t b)
534
+ template <class T , layout_type L >
535
+ inline pyarray<T, L >::pyarray(pybind11::handle h, pybind11::object::borrowed_t b)
536
536
: base_type(h, b)
537
537
{
538
538
init_from_python ();
539
539
}
540
540
541
- template <class T >
542
- inline pyarray<T>::pyarray(pybind11::handle h, pybind11::object::stolen_t s)
541
+ template <class T , layout_type L >
542
+ inline pyarray<T, L >::pyarray(pybind11::handle h, pybind11::object::stolen_t s)
543
543
: base_type(h, s)
544
544
{
545
545
init_from_python ();
546
546
}
547
547
548
- template <class T >
549
- inline pyarray<T>::pyarray(const pybind11::object& o)
548
+ template <class T , layout_type L >
549
+ inline pyarray<T, L >::pyarray(const pybind11::object& o)
550
550
: base_type(o)
551
551
{
552
552
init_from_python ();
@@ -558,8 +558,8 @@ namespace xt
558
558
* @param shape the shape of the pyarray
559
559
* @param l the layout of the pyarray
560
560
*/
561
- template <class T >
562
- inline pyarray<T>::pyarray(const shape_type& shape, layout_type l)
561
+ template <class T , layout_type L >
562
+ inline pyarray<T, L >::pyarray(const shape_type& shape, layout_type l)
563
563
: base_type()
564
564
{
565
565
strides_type strides (shape.size ());
@@ -574,8 +574,8 @@ namespace xt
574
574
* @param value the value of the elements
575
575
* @param l the layout of the pyarray
576
576
*/
577
- template <class T >
578
- inline pyarray<T>::pyarray(const shape_type& shape, const_reference value, layout_type l)
577
+ template <class T , layout_type L >
578
+ inline pyarray<T, L >::pyarray(const shape_type& shape, const_reference value, layout_type l)
579
579
: base_type()
580
580
{
581
581
strides_type strides (shape.size ());
@@ -591,8 +591,8 @@ namespace xt
591
591
* @param strides the strides of the pyarray
592
592
* @param value the value of the elements
593
593
*/
594
- template <class T >
595
- inline pyarray<T>::pyarray(const shape_type& shape, const strides_type& strides, const_reference value)
594
+ template <class T , layout_type L >
595
+ inline pyarray<T, L >::pyarray(const shape_type& shape, const strides_type& strides, const_reference value)
596
596
: base_type()
597
597
{
598
598
init_array (shape, strides);
@@ -604,8 +604,8 @@ namespace xt
604
604
* @param shape the shape of the pyarray
605
605
* @param strides the strides of the pyarray
606
606
*/
607
- template <class T >
608
- inline pyarray<T>::pyarray(const shape_type& shape, const strides_type& strides)
607
+ template <class T , layout_type L >
608
+ inline pyarray<T, L >::pyarray(const shape_type& shape, const strides_type& strides)
609
609
: base_type()
610
610
{
611
611
init_array (shape, strides);
@@ -619,8 +619,8 @@ namespace xt
619
619
/* *
620
620
* The copy constructor.
621
621
*/
622
- template <class T >
623
- inline pyarray<T>::pyarray(const self_type& rhs)
622
+ template <class T , layout_type L >
623
+ inline pyarray<T, L >::pyarray(const self_type& rhs)
624
624
: base_type(), semantic_base(rhs)
625
625
{
626
626
auto tmp = pybind11::reinterpret_steal<pybind11::object>(
@@ -639,8 +639,8 @@ namespace xt
639
639
/* *
640
640
* The assignment operator.
641
641
*/
642
- template <class T >
643
- inline auto pyarray<T>::operator =(const self_type& rhs) -> self_type&
642
+ template <class T , layout_type L >
643
+ inline auto pyarray<T, L >::operator =(const self_type& rhs) -> self_type&
644
644
{
645
645
self_type tmp (rhs);
646
646
*this = std::move (tmp);
@@ -656,9 +656,9 @@ namespace xt
656
656
/* *
657
657
* The extended copy constructor.
658
658
*/
659
- template <class T >
659
+ template <class T , layout_type L >
660
660
template <class E >
661
- inline pyarray<T>::pyarray(const xexpression<E>& e)
661
+ inline pyarray<T, L >::pyarray(const xexpression<E>& e)
662
662
: base_type()
663
663
{
664
664
// TODO: prevent intermediary shape allocation
@@ -672,28 +672,28 @@ namespace xt
672
672
/* *
673
673
* The extended assignment operator.
674
674
*/
675
- template <class T >
675
+ template <class T , layout_type L >
676
676
template <class E >
677
- inline auto pyarray<T>::operator =(const xexpression<E>& e) -> self_type&
677
+ inline auto pyarray<T, L >::operator =(const xexpression<E>& e) -> self_type&
678
678
{
679
679
return semantic_base::operator =(e);
680
680
}
681
681
// @}
682
682
683
- template <class T >
684
- inline auto pyarray<T>::ensure(pybind11::handle h) -> self_type
683
+ template <class T , layout_type L >
684
+ inline auto pyarray<T, L >::ensure(pybind11::handle h) -> self_type
685
685
{
686
686
return base_type::ensure (h);
687
687
}
688
688
689
- template <class T >
690
- inline bool pyarray<T>::check_(pybind11::handle h)
689
+ template <class T , layout_type L >
690
+ inline bool pyarray<T, L >::check_(pybind11::handle h)
691
691
{
692
692
return base_type::check_ (h);
693
693
}
694
694
695
- template <class T >
696
- inline void pyarray<T>::init_array(const shape_type& shape, const strides_type& strides)
695
+ template <class T , layout_type L >
696
+ inline void pyarray<T, L >::init_array(const shape_type& shape, const strides_type& strides)
697
697
{
698
698
strides_type adapted_strides (strides);
699
699
@@ -722,8 +722,8 @@ namespace xt
722
722
init_from_python ();
723
723
}
724
724
725
- template <class T >
726
- inline void pyarray<T>::init_from_python()
725
+ template <class T , layout_type L >
726
+ inline void pyarray<T, L >::init_from_python()
727
727
{
728
728
m_shape = inner_shape_type (reinterpret_cast <size_type*>(PyArray_SHAPE (this ->python_array ())),
729
729
static_cast <size_type>(PyArray_NDIM (this ->python_array ())));
@@ -734,20 +734,20 @@ namespace xt
734
734
this ->get_min_stride () * static_cast <size_type>(PyArray_SIZE (this ->python_array ())));
735
735
}
736
736
737
- template <class T >
738
- inline auto pyarray<T>::shape_impl() const noexcept -> const inner_shape_type&
737
+ template <class T , layout_type L >
738
+ inline auto pyarray<T, L >::shape_impl() const noexcept -> const inner_shape_type&
739
739
{
740
740
return m_shape;
741
741
}
742
742
743
- template <class T >
744
- inline auto pyarray<T>::strides_impl() const noexcept -> const inner_strides_type&
743
+ template <class T , layout_type L >
744
+ inline auto pyarray<T, L >::strides_impl() const noexcept -> const inner_strides_type&
745
745
{
746
746
return m_strides;
747
747
}
748
748
749
- template <class T >
750
- inline auto pyarray<T>::backstrides_impl() const noexcept -> const inner_backstrides_type&
749
+ template <class T , layout_type L >
750
+ inline auto pyarray<T, L >::backstrides_impl() const noexcept -> const inner_backstrides_type&
751
751
{
752
752
// m_backstrides wraps the numpy array backstrides, which is a raw pointer.
753
753
// The address of the raw pointer stored in the wrapper would be invalidated when the pyarray is copied.
@@ -756,14 +756,14 @@ namespace xt
756
756
return m_backstrides;
757
757
}
758
758
759
- template <class T >
760
- inline auto pyarray<T>::storage_impl() noexcept -> storage_type&
759
+ template <class T , layout_type L >
760
+ inline auto pyarray<T, L >::storage_impl() noexcept -> storage_type&
761
761
{
762
762
return m_storage;
763
763
}
764
764
765
- template <class T >
766
- inline auto pyarray<T>::storage_impl() const noexcept -> const storage_type&
765
+ template <class T , layout_type L >
766
+ inline auto pyarray<T, L >::storage_impl() const noexcept -> const storage_type&
767
767
{
768
768
return m_storage;
769
769
}
0 commit comments