Skip to content

Commit

Permalink
Initial matmul simd support (#245)
Browse files Browse the repository at this point in the history
* fix eval output & context forwarding

* add bit_width metafunction for simd

* add matmul evaluator

* add fmadd support

* add simd matmul indexing

* add contiguous_axis metafunction

* refactor ndarray row/major offset metafunction

* rename operands to array for matmul view

* add matmul simd tests

* mark offset constructor to constexpr

* fix offset metafunction for constexpr

* fix offset metafunction for constexpr

* fix offset metafunction for constexpr
  • Loading branch information
alifahrri committed Oct 13, 2023
1 parent fa8a45b commit 6d6f75e
Show file tree
Hide file tree
Showing 23 changed files with 2,496 additions and 489 deletions.
5 changes: 4 additions & 1 deletion include/nmtools/array/array/copy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ namespace nmtools::array
, context_t&& context=context_t{}, output_t&& output=output_t{})
{
auto ref_ = view::ref(array);
return eval(ref_,nmtools::forward<context_t>(context),nmtools::forward<output_t>(output));
return eval(ref_
,nmtools::forward<context_t>(context)
,nmtools::forward<output_t>(output)
);
} // copy
} // namespace nmtools::array

Expand Down
38 changes: 29 additions & 9 deletions include/nmtools/array/eval.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ namespace nmtools::array
// TODO: perfect forwarding for context
using ctx_t = meta::remove_cvref_t<context_t>;
using evaluator_type = evaluator_t<view_t,ctx_t,resolver_t>;
return evaluator_type{view,context};
return evaluator_type{view,nmtools::forward<context_t>(context)};
} // evaluator

namespace detail
Expand All @@ -178,8 +178,16 @@ namespace nmtools::array
using left_t = meta::get_either_left_t<view_t>;
using right_t = meta::get_either_right_t<view_t>;
// deduce return type for each type
using rleft_t = decltype(detail::eval(meta::declval<left_t>(),context,output,resolver));
using rright_t = decltype(detail::eval(meta::declval<right_t>(),context,output,resolver));
using rleft_t = decltype(detail::eval(
meta::declval<left_t>()
,nmtools::forward<context_t>(context)
,nmtools::forward<output_t>(output)
,resolver));
using rright_t = decltype(detail::eval(
meta::declval<right_t>()
,nmtools::forward<context_t>(context)
,nmtools::forward<output_t>(output)
,resolver));
constexpr auto vtype = [](){
if constexpr (meta::is_same_v<rleft_t,rright_t>) {
return meta::as_value_v<rleft_t>;
Expand All @@ -191,24 +199,36 @@ namespace nmtools::array
using return_t = meta::type_t<decltype(vtype)>;
// match either type at runtime
if (auto view_ptr = nmtools::get_if<left_t>(&view)) {
return return_t{detail::eval(*view_ptr,context,output,resolver)};
return return_t{detail::eval(*view_ptr
,nmtools::forward<context_t>(context)
,nmtools::forward<output_t>(output)
,resolver)};
} else /* if (auto view_ptr = get_if<right_t>(&view)) */ {
auto view_rptr = nmtools::get_if<right_t>(&view);
return return_t{detail::eval(*view_rptr,context,output,resolver)};
return return_t{detail::eval(*view_rptr
,nmtools::forward<context_t>(context)
,nmtools::forward<output_t>(output)
,resolver)};
}
} else if constexpr (meta::is_maybe_v<view_t>) {
using view_type = meta::get_maybe_type_t<view_t>;
using result_type = decltype(detail::eval(meta::declval<view_type>(),context,output,resolver));
using result_type = decltype(detail::eval(meta::declval<view_type>()
,nmtools::forward<context_t>(context)
,nmtools::forward<output_t>(output)
,resolver));
using return_type = nmtools_maybe<result_type>;
if (static_cast<bool>(view)) {
auto result = detail::eval(*view,context,output,resolver);
auto result = detail::eval(*view
,nmtools::forward<context_t>(context)
,nmtools::forward<output_t>(output)
,resolver);
return return_type{result};
} else {
return return_type{meta::Nothing};
}
} else /* if constexpr (meta::is_ndarray_v<view_t> || meta::is_num_v<view_t>) */ {
auto evaluator_ = evaluator<resolver_t>(view,context);
return evaluator_(output);
auto evaluator_ = evaluator<resolver_t>(view,nmtools::forward<context_t>(context));
return evaluator_(nmtools::forward<output_t>(output));
}
} // eval

Expand Down
24 changes: 24 additions & 0 deletions include/nmtools/array/eval/simd/bit_width.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef NMTOOLS_ARRAY_EVAL_SIMD_BIT_WIDTH_HPP
#define NMTOOLS_ARRAY_EVAL_SIMD_BIT_WIDTH_HPP

#include "nmtools/meta.hpp"

namespace nmtools::meta
{
namespace error
{
template <typename...>
struct BIT_WIDTH_UNSUPPORTED : detail::fail_t {};
}

template <typename T>
struct bit_width
{
static constexpr auto value = error::BIT_WIDTH_UNSUPPORTED<T>{};
};

template <typename T>
constexpr inline auto bit_width_v = bit_width<T>::value;
} // namespace nmtools::meta

#endif // NMTOOLS_ARRAY_EVAL_SIMD_BIT_WIDTH_HPP
Loading

0 comments on commit 6d6f75e

Please sign in to comment.