Skip to content

Commit

Permalink
Fix the issue openvinotoolkit#24114
Browse files Browse the repository at this point in the history
  • Loading branch information
Aryan8912 committed Apr 28, 2024
1 parent 2ce0e4a commit f8e620b
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,49 @@ void jit_select_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const
h->mov(dst.b16, aux.b16);
}

// SoftSign

jit_softsign_emitter::jit_softsign_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node)
: jit_emitter(host, host_isa, get_arithmetic_binary_exec_precision(node)) {}

jit_softsign_emitter::jit_softsign_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const ov::element::Type exec_prc)
: jit_emitter(host, host_isa, exec_prc) {}

size_t jit_softsign_emitter::get_inputs_count() const { return 1; }
size_t jit_softsign_emitter::get_aux_vecs_count() const { return 1; }

std::set<std::vector<element::Type>> jit_softsign_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& /*node*/) {
return {{element::f32}};
}

void jit_softsign_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_vec_idxs, out_vec_idxs);
} else {
OV_CPU_JIT_EMITTER_THROW("Can't create jit eltwise kernel");
}
}

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_softsign_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string());
using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
TReg src = TReg(in_vec_idxs[0]);
TReg dst = TReg(out_vec_idxs[0]);
TReg aux = TReg(aux_vec_idxs[0]);
h->movi(aux.s, 1);
h->fdiv(dst.s, src.s, aux.s);
h->fadd(dst.s, dst.s, src.s);
h->fdiv(dst.s, dst.s, aux.s);
h->fsub(dst.s, aux.s, dst.s);
h->fmul(dst.s, dst.s, src.s);
}


/// SIGMOID ///
jit_sigmoid_emitter::jit_sigmoid_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,27 @@ class jit_select_emitter : public jit_emitter {
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
};

class jit_softsign_emitter : public jit_emitter {
public:
jit_softsign_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const ov::element::Type exec_prc = ov::element::f32);

jit_softsign_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node);

size_t get_inputs_count() const override;
size_t get_aux_vecs_count() const override;

static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);

private:
void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
};
class jit_sigmoid_emitter : public jit_emitter {
public:
jit_sigmoid_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
Expand Down

0 comments on commit f8e620b

Please sign in to comment.