Skip to content

Commit

Permalink
fix sycl compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
alifahrri committed May 12, 2024
1 parent f89260e commit 0b21749
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 11 deletions.
19 changes: 8 additions & 11 deletions include/nmtools/array/eval/kernel_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,22 +120,12 @@ namespace nmtools::array
return array;
}

template <typename type, typename shape_type>
nmtools_func_attribute
auto create_mutable_array(type* data_ptr, const shape_type& shape)
{
const auto numel = index::product(shape);

auto ref = view::mutable_ref(data_ptr,numel);
return view::reshape(ref,shape);
}

template <auto DIM=0, typename dim_type=nm_index_t, typename size_type=nm_index_t, typename type>
nmtools_func_attribute
auto create_mutable_array(type* data_ptr, const size_type* shape_ptr, dim_type dim)
{
const auto shape = create_vector<DIM>(shape_ptr,dim);
return create_mutable_array(data_ptr,shape);
return device_array(data_ptr,shape,dim);
}

template <typename size_type=nm_size_t>
Expand Down Expand Up @@ -182,12 +172,19 @@ namespace nmtools::array
return;
}
assign_result(output,*result,thread_id,block_id,block_size);
} else if constexpr (meta::is_maybe_v<mutable_array_t>) {
if (!static_cast<bool>(output)) {
return;
}
assign_result(*output,result,thread_id,block_id,block_size);
} else {
auto size = nmtools::size(output);
auto idx = compute_offset(thread_id,block_id,block_size);
if (idx < size) {
auto flat_lhs = view::mutable_flatten(output);
auto flat_rhs = view::flatten(result);
static_assert( !meta::is_maybe_v<decltype(flat_lhs)> );
static_assert( !meta::is_maybe_v<decltype(flat_rhs)> );
const auto rhs = flat_rhs(idx);
auto& lhs = flat_lhs(idx);
lhs = rhs;
Expand Down
43 changes: 43 additions & 0 deletions include/nmtools/array/view/flatten.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,47 @@ namespace nmtools::view
} // flatten
} // namespace nmtools::view

namespace nmtools::array
{
template <typename...args_t, auto max_dim>
struct as_static_t<
view::flatten_t<args_t...>, max_dim
>
{
using attribute_type = view::flatten_t<args_t...>;

attribute_type attribute;

auto operator()() const
{
auto src_shape = as_static<max_dim>(attribute.src_shape);
auto src_size = as_static<max_dim>(attribute.src_size);
return view::flatten_t{src_shape,src_size};
}
};
} // namespace nmtools::array

#if NMTOOLS_HAS_STRING

namespace nmtools::utils::impl
{
template <typename...args_t, auto...fmt_args>
struct to_string_t<
view::flatten_t<args_t...>, fmt_string_t<fmt_args...>
> {
using result_type = nmtools_string;

auto operator()(const view::flatten_t<args_t...>& kwargs) const noexcept
{
nmtools_string str;
str += "flatten{";
str += ".src_shape="; str += to_string(kwargs.src_shape,Compact);
str += ".src_size="; str += to_string(kwargs.src_size,Compact);
str += "}";
}
};
}

#endif // NMTOOLS_HAS_STRING

#endif // NMTOOLS_ARRAY_VIEW_FLATTEN_HPP
20 changes: 20 additions & 0 deletions include/nmtools/array/view/reshape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,26 @@ namespace nmtools::view

} // namespace nmtools::view

namespace nmtools::array
{
template <typename...args_t, auto max_dim>
struct as_static_t<
view::reshape_t<args_t...>, max_dim
> {
using attribute_type = view::reshape_t<args_t...>;

attribute_type attribute;

auto operator()() const
{
auto src_shape = as_static<max_dim>(attribute.src_shape);
auto dst_shape = as_static<max_dim>(attribute.dst_shape);
auto src_size = as_static<max_dim>(attribute.src_size);
return view::reshape_t{src_shape,dst_shape,src_size};
}
};
}

#if NMTOOLS_HAS_STRING

namespace nmtools::utils::impl
Expand Down

0 comments on commit 0b21749

Please sign in to comment.