Skip to content

Commit

Permalink
Replace ASSERT_HOST by assert
Browse files Browse the repository at this point in the history
Signed-off-by: Stefan Weil <sw@weilnetz.de>
  • Loading branch information
stweil committed Jan 14, 2019
1 parent f75b2c1 commit c79d613
Showing 1 changed file with 16 additions and 15 deletions.
31 changes: 16 additions & 15 deletions src/lstm/weightmatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "weightmatrix.h"

#include <cassert> // for assert
#include "intsimdmatrix.h"
#include "simddetect.h" // for DotProduct
#include "statistc.h"
Expand Down Expand Up @@ -238,21 +239,21 @@ bool WeightMatrix::DeSerializeOld(bool training, TFile* fp) {
// implement the bias, but it doesn't actually have it.
// Asserts that the call matches what we have.
// Computes matrix.vector v = Wu, using the double-precision weights wf_.
// u is of size W.dim2() - 1; an implicit trailing 1 supplies the bias term.
// Asserts that the matrix was not converted to integer mode, since wf_ is
// only valid in float mode.
void WeightMatrix::MatrixDotVector(const double* u, double* v) const {
  assert(!int_mode_);
  MatrixDotVectorInternal(wf_, true, false, u, v);
}

// Computes matrix.vector v = Wu using the quantized int8 weights wi_,
// rescaling the result with scales_. Requires the matrix to have been
// converted to integer mode and a SIMD multiplier to have been selected.
void WeightMatrix::MatrixDotVector(const int8_t* u, double* v) const {
  assert(int_mode_);
  assert(multiplier_ != nullptr);
  multiplier_->MatrixDotVector(wi_, scales_, u, v);
}

// MatrixDotVector for peep weights, MultiplyAccumulate adds the
// component-wise products of *this[0] and v to inout.
void WeightMatrix::MultiplyAccumulate(const double* v, double* inout) {
ASSERT_HOST(!int_mode_);
ASSERT_HOST(wf_.dim1() == 1);
assert(!int_mode_);
assert(wf_.dim1() == 1);
int n = wf_.dim2();
const double* u = wf_[0];
for (int i = 0; i < n; ++i) {
Expand All @@ -265,7 +266,7 @@ void WeightMatrix::MultiplyAccumulate(const double* v, double* inout) {
// The last result is discarded, as v is assumed to have an imaginary
// last value of 1, as with MatrixDotVector.
// Computes vector.matrix v = uW using the transposed weights wf_t_.
// The last result is discarded, as v is assumed to have an imaginary
// last value of 1, as with MatrixDotVector. Float mode only.
void WeightMatrix::VectorDotMatrix(const double* u, double* v) const {
  assert(!int_mode_);
  MatrixDotVectorInternal(wf_t_, false, true, u, v);
}

Expand All @@ -277,14 +278,14 @@ void WeightMatrix::VectorDotMatrix(const double* u, double* v) const {
void WeightMatrix::SumOuterTransposed(const TransposedArray& u,
const TransposedArray& v,
bool in_parallel) {
ASSERT_HOST(!int_mode_);
assert(!int_mode_);
int num_outputs = dw_.dim1();
ASSERT_HOST(u.dim1() == num_outputs);
ASSERT_HOST(u.dim2() == v.dim2());
assert(u.dim1() == num_outputs);
assert(u.dim2() == v.dim2());
int num_inputs = dw_.dim2() - 1;
int num_samples = u.dim2();
// v is missing the last element in dim1.
ASSERT_HOST(v.dim1() == num_inputs);
assert(v.dim1() == num_inputs);
#ifdef _OPENMP
#pragma omp parallel for num_threads(4) if (in_parallel)
#endif
Expand All @@ -306,7 +307,7 @@ void WeightMatrix::SumOuterTransposed(const TransposedArray& u,
// use_adam_ is true.
void WeightMatrix::Update(double learning_rate, double momentum,
double adam_beta, int num_samples) {
ASSERT_HOST(!int_mode_);
assert(!int_mode_);
if (use_adam_ && num_samples > 0 && num_samples < kAdamCorrectionIterations) {
learning_rate *= sqrt(1.0 - pow(adam_beta, num_samples));
learning_rate /= 1.0 - pow(momentum, num_samples);
Expand All @@ -328,8 +329,8 @@ void WeightMatrix::Update(double learning_rate, double momentum,

// Adds the dw_ in other to the dw_ in *this.
// Accumulates other's weight deltas (dw_) into this matrix's dw_.
// Both matrices must have identical dimensions; used to merge gradients
// computed in parallel.
void WeightMatrix::AddDeltas(const WeightMatrix& other) {
  assert(dw_.dim1() == other.dw_.dim1());
  assert(dw_.dim2() == other.dw_.dim2());
  dw_ += other.dw_;
}

Expand All @@ -340,8 +341,8 @@ void WeightMatrix::CountAlternators(const WeightMatrix& other, double* same,
double* changed) const {
int num_outputs = updates_.dim1();
int num_inputs = updates_.dim2();
ASSERT_HOST(num_outputs == other.updates_.dim1());
ASSERT_HOST(num_inputs == other.updates_.dim2());
assert(num_outputs == other.updates_.dim1());
assert(num_inputs == other.updates_.dim2());
for (int i = 0; i < num_outputs; ++i) {
const double* this_i = updates_[i];
const double* other_i = other.updates_[i];
Expand Down

0 comments on commit c79d613

Please sign in to comment.