
Set best or user selected IntSimdMatrix
Signed-off-by: Stefan Weil <sw@weilnetz.de>
stweil committed Jan 14, 2019
1 parent 605b4d6 commit d36231e
Showing 7 changed files with 25 additions and 50 deletions.
17 changes: 1 addition & 16 deletions src/arch/intsimdmatrix.cpp
@@ -23,25 +23,10 @@

namespace tesseract {

const IntSimdMatrix* IntSimdMatrix::intSimdMatrix = nullptr;
const IntSimdMatrix IntSimdMatrix::IntSimdMatrixNative =
IntSimdMatrix(1, 1, 1, 1, 1, {});

// Factory makes and returns an IntSimdMatrix (sub)class of the best
// available type for the current architecture.
/* static */
const IntSimdMatrix* IntSimdMatrix::GetFastestMultiplier() {
const IntSimdMatrix* multiplier;
if (SIMDDetect::IsAVX2Available()) {
multiplier = &IntSimdMatrixAVX2;
} else if (SIMDDetect::IsSSEAvailable()) {
multiplier = &IntSimdMatrixSSE;
} else {
// Default c++ implementation.
multiplier = &IntSimdMatrixNative;
}
return multiplier;
}

// Computes a reshaped copy of the weight matrix w. If there are no
// partial_funcs_, it does nothing.
void IntSimdMatrix::Init(const GENERIC_2D_ARRAY<int8_t>& w, std::vector<int8_t>& shaped_w) const {
5 changes: 1 addition & 4 deletions src/arch/intsimdmatrix.h
@@ -85,10 +85,6 @@ class IntSimdMatrix {
partial_funcs_(partial_funcs)
{}

// Factory makes and returns an IntSimdMatrix (sub)class of the best
// available type for the current architecture.
static const IntSimdMatrix* GetFastestMultiplier();

// Computes a reshaped copy of the weight matrix w. If there are no
// partial_funcs_, it does nothing.
void Init(const GENERIC_2D_ARRAY<int8_t>& w, std::vector<int8_t>& shaped_w) const;
@@ -115,6 +111,7 @@ class IntSimdMatrix {
const GenericVector<double>& scales, const int8_t* u,
double* v) const;

static const IntSimdMatrix* intSimdMatrix;
static const IntSimdMatrix IntSimdMatrixAVX2;
static const IntSimdMatrix IntSimdMatrixSSE;
static const IntSimdMatrix IntSimdMatrixNative;
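
A minimal sketch (not part of this commit) of the access pattern the remaining files rely on: SIMDDetect assigns the static pointer once at startup, and every caller treats nullptr as "no SIMD implementation selected". The names wi and shaped_w below are placeholders for a weight matrix and its reshaped copy.

const IntSimdMatrix* m = IntSimdMatrix::intSimdMatrix;  // assigned by SIMDDetect
if (m != nullptr) {
  m->Init(wi, shaped_w);  // reshape the int8 weights for the selected kernel
} else {
  // No SIMD matrix available: callers fall back to the generic code paths.
}
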
14 changes: 8 additions & 6 deletions src/arch/simddetect.cpp
@@ -19,6 +19,7 @@
#include "dotproduct.h"
#include "dotproductavx.h"
#include "dotproductsse.h"
#include "intsimdmatrix.h" // for IntSimdMatrix
#include "params.h" // for STRING_VAR
#include "tprintf.h" // for tprintf

@@ -68,8 +69,9 @@ static double DotProductGeneric(const double* u, const double* v, int n) {
return total;
}

static void SetDotProduct(DotProductFunction function) {
DotProduct = function;
static void SetDotProduct(DotProductFunction f, const IntSimdMatrix* m = nullptr) {
DotProduct = f;
IntSimdMatrix::intSimdMatrix = m;
}

// Constructor.
@@ -126,12 +128,12 @@ SIMDDetect::SIMDDetect() {
#if defined(AVX)
} else if (avx_available_) {
// AVX detected.
SetDotProduct(DotProductAVX);
SetDotProduct(DotProductAVX, &IntSimdMatrix::IntSimdMatrixAVX2);
#endif
#if defined(SSE4_1)
} else if (sse_available_) {
// SSE detected.
SetDotProduct(DotProductSSE);
SetDotProduct(DotProductSSE, &IntSimdMatrix::IntSimdMatrixSSE);
#endif
}
}
@@ -153,13 +155,13 @@ void SIMDDetect::Update() {
#if defined(AVX)
} else if (!strcmp(dotproduct.string(), "avx")) {
// AVX selected by config variable.
SetDotProduct(DotProductAVX);
SetDotProduct(DotProductAVX, &IntSimdMatrix::IntSimdMatrixAVX2);
dotproduct_method = "avx";
#endif
#if defined(SSE4_1)
} else if (!strcmp(dotproduct.string(), "sse")) {
// SSE selected by config variable.
SetDotProduct(DotProductSSE);
SetDotProduct(DotProductSSE, &IntSimdMatrix::IntSimdMatrixSSE);
dotproduct_method = "sse";
#endif
} else {
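
Because the second parameter defaults to nullptr, the net effect (shown here as a sketch, not an excerpt from the diff) is that any call passing only a dot-product function also clears the matrix pointer, so selecting a non-SIMD dot product disables the SIMD matrix multiplier as well:

SetDotProduct(DotProductGeneric);                                 // IntSimdMatrix::intSimdMatrix == nullptr
SetDotProduct(DotProductAVX, &IntSimdMatrix::IntSimdMatrixAVX2);  // dot product and multiplier selected together
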
14 changes: 6 additions & 8 deletions src/lstm/networkio.cpp
@@ -31,11 +31,6 @@ const float kMinCertainty = -20.0f;
// Probability corresponding to kMinCertainty.
const float kMinProb = exp(kMinCertainty);

// Holds the optimal integer multiplier for this machine.
// This is a leaked, lazily initialized singleton, and is used for computing
// padding to apply to i_ for SIMD use.
const IntSimdMatrix* NetworkIO::multiplier_ = nullptr;

// Resizes to a specific size as a 2-d temp buffer. No batches, no y-dim.
void NetworkIO::Resize2d(bool int_mode, int width, int num_features) {
stride_map_ = StrideMap();
@@ -985,9 +980,12 @@ void NetworkIO::ClipVector(int t, float range) {
// for the SIMD operations to be safe.
/* static */
int NetworkIO::GetPadding(int num_features) {
if (multiplier_ == nullptr)
multiplier_ = IntSimdMatrix::GetFastestMultiplier();
return multiplier_->RoundInputs(num_features) - num_features;
int padding = 0;
if (IntSimdMatrix::intSimdMatrix) {
padding =
IntSimdMatrix::intSimdMatrix->RoundInputs(num_features) - num_features;
}
return padding;
}

} // namespace tesseract.
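
For illustration only (the group size here is hypothetical; the real factor comes from the selected IntSimdMatrix), this is the rounding RoundInputs() performs and the padding derived from it:

// Round n up to the next multiple of the SIMD input group size.
static int RoundUpToGroup(int n, int group) { return (n + group - 1) / group * group; }
// With a group size of 32: RoundUpToGroup(100, 32) == 128, so GetPadding(100)
// would return 28 padding elements; when IntSimdMatrix::intSimdMatrix is
// nullptr, GetPadding() returns 0 and no padding is added.
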
4 changes: 0 additions & 4 deletions src/lstm/networkio.h
@@ -338,10 +338,6 @@ class NetworkIO {
bool int_mode_;
// Stride for 2d input data.
StrideMap stride_map_;
// Holds the optimal integer multiplier for this machine.
// This is a leaked, lazily initialized singleton, and is used for computing
// padding to apply to i_ for SIMD use.
static const IntSimdMatrix* multiplier_;
};

} // namespace tesseract.
11 changes: 5 additions & 6 deletions src/lstm/weightmatrix.cpp
@@ -143,8 +143,8 @@ void WeightMatrix::ConvertToInt() {
}
wf_.Resize(1, 1, 0.0);
int_mode_ = true;
multiplier_ = IntSimdMatrix::GetFastestMultiplier();
multiplier_->Init(wi_, shaped_w_);
if (IntSimdMatrix::intSimdMatrix)
IntSimdMatrix::intSimdMatrix->Init(wi_, shaped_w_);
}

// Allocates any needed memory for running Backward, and zeroes the deltas,
Expand Down Expand Up @@ -196,8 +196,8 @@ bool WeightMatrix::DeSerialize(bool training, TFile* fp) {
if (int_mode_) {
if (!wi_.DeSerialize(fp)) return false;
if (!scales_.DeSerialize(fp)) return false;
multiplier_ = IntSimdMatrix::GetFastestMultiplier();
multiplier_->Init(wi_, shaped_w_);
if (IntSimdMatrix::intSimdMatrix)
IntSimdMatrix::intSimdMatrix->Init(wi_, shaped_w_);
} else {
if (!wf_.DeSerialize(fp)) return false;
if (training) {
@@ -245,8 +245,7 @@ void WeightMatrix::MatrixDotVector(const double* u, double* v) const {

void WeightMatrix::MatrixDotVector(const int8_t* u, double* v) const {
assert(int_mode_);
assert(multiplier_ != nullptr);
multiplier_->MatrixDotVector(wi_, shaped_w_, scales_, u, v);
IntSimdMatrix::intSimdMatrix->MatrixDotVector(wi_, shaped_w_, scales_, u, v);
}

// MatrixDotVector for peep weights, MultiplyAccumulate adds the
10 changes: 4 additions & 6 deletions src/lstm/weightmatrix.h
@@ -64,7 +64,7 @@ class TransposedArray : public GENERIC_2D_ARRAY<double> {
// backward steps with the matrix and updates to the weights.
class WeightMatrix {
public:
WeightMatrix() : int_mode_(false), use_adam_(false), multiplier_(nullptr) {}
WeightMatrix() : int_mode_(false), use_adam_(false) {}
// Sets up the network for training. Initializes weights using weights of
// scale `range` picked according to the random number generator `randomizer`.
// Note the order is outputs, inputs, as this is the order of indices to
@@ -85,13 +85,13 @@ class WeightMatrix {
// Scale so the max absolute value becomes INT8_MAX.
// Round to integer.
// Store a multiplicative scale factor (as a float) that will reproduce
// the original value, subject to rounding errors.
// the original value, subject to rounding errors.
void ConvertToInt();
// Returns the size rounded up to an internal factor used by the SIMD
// implementation for its input.
int RoundInputs(int size) const {
if (multiplier_ == nullptr) return size;
return multiplier_->RoundInputs(size);
if (!int_mode_ || !IntSimdMatrix::intSimdMatrix) return size;
return IntSimdMatrix::intSimdMatrix->RoundInputs(size);
}

// Accessors.
@@ -178,8 +178,6 @@ class WeightMatrix {
GENERIC_2D_ARRAY<double> dw_sq_sum_;
// The weights matrix reorganized in whatever way suits this instance.
std::vector<int8_t> shaped_w_;
// Holds the optimal integer multiplier for this machine.
const IntSimdMatrix* multiplier_;
};

} // namespace tesseract.
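
A short sketch (independent of this diff) of the quantization contract documented for ConvertToInt() above, assuming the scale is stored per output row and the row is not all zeros; the names are placeholders:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Quantize one row of double weights to int8 and return the scale that
// reproduces the original values up to rounding: w[i] ≈ wi[i] * scale.
double QuantizeRow(const std::vector<double>& w, std::vector<int8_t>& wi) {
  double max_abs = 0.0;
  for (double x : w) max_abs = std::max(max_abs, std::fabs(x));
  double scale = max_abs / INT8_MAX;  // multiplicative scale factor to store
  wi.resize(w.size());
  for (size_t i = 0; i < w.size(); ++i)
    wi[i] = static_cast<int8_t>(std::lround(w[i] / scale));
  return scale;
}
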
