Skip to content

Commit

Permalink
fix build at ubuntu 22.04
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghaolong committed Mar 20, 2023
1 parent 4ade75c commit 0d659f3
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/cpu/rnn/jit_uni_rnn_common_postgemm_dispatcher.hpp
Expand Up @@ -232,7 +232,7 @@ using rnn_postgemm_bwd_f32_t = rnn_postgemm_dispatcher<prop_kind::backward,
using rnn_postgemm_fwd_bf16_t = rnn_postgemm_dispatcher<prop_kind::forward,
data_type::bf16, data_type::f32>;
using rnn_postgemm_bwd_bf16_t = rnn_postgemm_dispatcher<prop_kind::backward,
data_type::bf16, data_type::f32>;
data_type::bf16, data_type::bf16>;
using rnn_postgemm_fwd_u8_t = rnn_postgemm_dispatcher<prop_kind::forward,
data_type::u8, data_type::s32>;

Expand Down
14 changes: 8 additions & 6 deletions src/cpu/rnn/ref_rnn.hpp
Expand Up @@ -61,12 +61,13 @@ void gates_reduction(const rnn_utils::rnn_conf_t &rnn,
template <prop_kind_t aprop, impl::data_type_t src_type,
impl::data_type_t weights_type, impl::data_type_t acc_type>
struct _ref_rnn_common_t : public primitive_impl_t {
static constexpr impl::data_type_t scratch_type
= aprop == prop_kind::forward ? acc_type : src_type;

typedef typename prec_traits<src_type>::type src_data_t;
typedef typename prec_traits<weights_type>::type weights_data_t;
typedef typename prec_traits<acc_type>::type acc_data_t;

typedef typename utils::conditional<aprop == prop_kind::forward, acc_data_t,
src_data_t>::type scratch_data_t;
typedef typename prec_traits<scratch_type>::type scratch_data_t;

using class_name
= _ref_rnn_common_t<aprop, src_type, weights_type, acc_type>;
Expand Down Expand Up @@ -224,8 +225,9 @@ struct _ref_rnn_common_t : public primitive_impl_t {
set_gemm_funcs(pd()->rnn_.use_layer_packed_gemm, gemm_layer_func,
weights_layer_assign_func);

rnn_postgemm_ = new rnn_postgemm_dispatcher<aprop, src_type, acc_type>(
pd()->rnn_, pd());
rnn_postgemm_
= new rnn_postgemm_dispatcher<aprop, src_type, scratch_type>(
pd()->rnn_, pd());
assert(rnn_postgemm_ != nullptr);
switch (pd()->cell_kind()) {
case alg_kind::vanilla_rnn:
Expand Down Expand Up @@ -308,7 +310,7 @@ struct _ref_rnn_common_t : public primitive_impl_t {
size_t ws_grid_comp_offset_;
size_t scratch_gates_offset_;
size_t scratch_cell_offset_;
rnn_postgemm_dispatcher<aprop, src_type, acc_type> *rnn_postgemm_;
rnn_postgemm_dispatcher<aprop, src_type, scratch_type> *rnn_postgemm_;

grid_execution_f grid_computation;
cell_execution_f cell_func;
Expand Down

0 comments on commit 0d659f3

Please sign in to comment.