Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[BUGFIX] Fix SupportDNNL for multiple inputs (#21102)
Browse files Browse the repository at this point in the history
* Fix SupportDNNL for multiple inputs

* Fix SupportDNNL condition for multiple inputs
  • Loading branch information
agrabow committed Jul 22, 2022
1 parent ecb5026 commit 5e5e0e3
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/operator/nn/dnnl/dnnl_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,16 +184,16 @@ enum DNNLTensorsDtypes { AllSame = 0, Mixed = 1 };

template <int MinNdim, int MaxNdim, DNNLTypeMode TypeMode, DNNLTensorsDtypes MixedTensors>
static inline bool SupportDNNL(const std::vector<NDArray>& inputs) {
int dtype = MixedTensors ? -1 : inputs[0].dtype();
if (!SupportDNNLType<TypeMode>(dtype)) {
if (!SupportDNNLType<TypeMode>(inputs[0].dtype())) {
return false;
}
int dtype = MixedTensors ? -1 : inputs[0].dtype();
for (NDArray input : inputs) {
if (dtype == -1) {
if (!SupportDNNL<MinNdim, MaxNdim, TypeMode>(input))
return false;
} else {
if (input.dtype() != dtype && !SupportDNNLShape<MinNdim, MaxNdim>(input.shape()))
if (input.dtype() != dtype || !SupportDNNLShape<MinNdim, MaxNdim>(input.shape()))
return false;
}
}
Expand Down

0 comments on commit 5e5e0e3

Please sign in to comment.