Skip to content

Commit

Permalink
Update operator== of Mat, add support for nan and Bit datatype
Browse files Browse the repository at this point in the history
  • Loading branch information
daquexian committed May 29, 2019
1 parent 5eda6d2 commit 0678418
Showing 1 changed file with 25 additions and 3 deletions.
28 changes: 25 additions & 3 deletions dabnn/mat.h
Original file line number Diff line number Diff line change
Expand Up @@ -271,10 +271,32 @@ inline bool Mat::operator==(const Mat &m) const {
h == m.h && c == m.c && data_type == m.data_type)) {
return false;
}
FORZ(i, total()) {
if (std::abs(static_cast<float *>(data)[i] - m[i]) > 1e-5) {
return false;
if (m.data_type == DataType::Float) {
FORZ(i, total()) {
const auto elem = static_cast<float *>(data)[i];
if (std::isnan(elem) && !std::isnan(m[i])) {
PNT(elem, m[i]);
return false;
}
if (!std::isnan(elem) && std::isnan(m[i])) {
PNT(elem, m[i]);
return false;
}
if (std::abs(elem - m[i]) > 1e-5) {
PNT(i, elem, m[i]);
return false;
}
}
} else if (m.data_type == DataType::Bit) {
FORZ(i, total()) {
const auto elem = static_cast<uint64_t *>(data)[i];
if (elem != m[i]) {
PNT(elem, m[i]);
return false;
}
}
} else {
throw std::invalid_argument("Unknown datatype");
}
return true;
}
Expand Down

0 comments on commit 0678418

Please sign in to comment.