Skip to content

Commit

Permalink
Use constructor with parameters for IntSimdMatrix
Browse files Browse the repository at this point in the history
Signed-off-by: Stefan Weil <sw@weilnetz.de>
  • Loading branch information
stweil committed Jan 14, 2019
1 parent e237a38 commit 26be7c5
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 30 deletions.
2 changes: 1 addition & 1 deletion src/arch/intsimdmatrix.cpp
Expand Up @@ -36,7 +36,7 @@ const IntSimdMatrix* IntSimdMatrix::GetFastestMultiplier() {
multiplier = new IntSimdMatrixSSE();
} else {
// Default c++ implementation.
multiplier = new IntSimdMatrix();
multiplier = new IntSimdMatrix(1, 1, 1, 1, 1, {});
}
return multiplier;
}
Expand Down
42 changes: 24 additions & 18 deletions src/arch/intsimdmatrix.h
Expand Up @@ -60,14 +60,30 @@ namespace tesseract {
// is required to allow the base class implementation to do all the work.
class IntSimdMatrix {
public:
// Constructor should set the data members to indicate the sizes.
// NOTE: Base constructor public only for test purposes.
IntSimdMatrix()
: num_outputs_per_register_(1),
max_output_registers_(1),
num_inputs_per_register_(1),
num_inputs_per_group_(1),
num_input_groups_(1) {}
// Function to compute part of a matrix.vector multiplication. The weights
// are in a very specific order (see above) in w, which is multiplied by
// u of length num_in, to produce output v after scaling the integer results
// by the corresponding member of scales.
// The amount of w and scales consumed is fixed and not available to the
// caller. The number of outputs written to v will be at most num_out.
typedef void (*PartialFunc)(const int8_t* w, const double* scales,
const int8_t* u, int num_in, int num_out,
double* v);

IntSimdMatrix(int num_outputs_per_register, int max_output_registers, int num_inputs_per_register, int num_inputs_per_group, int num_input_groups, std::vector<PartialFunc> partial_funcs) :
// Number of 32 bit outputs held in each register.
num_outputs_per_register_(num_outputs_per_register),
// Maximum number of registers that we will use to hold outputs.
max_output_registers_(max_output_registers),
// Number of 8 bit inputs in the inputs register.
num_inputs_per_register_(num_inputs_per_register),
// Number of inputs in each weight group.
num_inputs_per_group_(num_inputs_per_group),
// Number of groups of inputs to be broadcast.
num_input_groups_(num_input_groups),
// A series of functions to compute a partial result.
partial_funcs_(partial_funcs)
{}

// Factory makes and returns an IntSimdMatrix (sub)class of the best
// available type for the current architecture.
Expand Down Expand Up @@ -100,16 +116,6 @@ class IntSimdMatrix {
double* v) const;

protected:
// Function to compute part of a matrix.vector multiplication. The weights
// are in a very specific order (see above) in w, which is multiplied by
// u of length num_in, to produce output v after scaling the integer results
// by the corresponding member of scales.
// The amount of w and scales consumed is fixed and not available to the
// caller. The number of outputs written to v will be at most num_out.
typedef void (*PartialFunc)(const int8_t* w, const double* scales,
const int8_t* u, int num_in, int num_out,
double* v);

// Rounds the input up to a multiple of the given factor.
static int Roundup(int input, int factor) {
return (input + factor - 1) / factor * factor;
Expand Down
14 changes: 6 additions & 8 deletions src/arch/intsimdmatrixavx2.cpp
Expand Up @@ -269,16 +269,14 @@ static void PartialMatrixDotVector8(const int8_t* wi, const double* scales,
namespace tesseract {
#endif // __AVX2__

IntSimdMatrixAVX2::IntSimdMatrixAVX2() {
IntSimdMatrixAVX2::IntSimdMatrixAVX2()
#ifdef __AVX2__
num_outputs_per_register_ = kNumOutputsPerRegister;
max_output_registers_ = kMaxOutputRegisters;
num_inputs_per_register_ = kNumInputsPerRegister;
num_inputs_per_group_ = kNumInputsPerGroup;
num_input_groups_ = kNumInputGroups;
partial_funcs_ = {PartialMatrixDotVector64, PartialMatrixDotVector32,
PartialMatrixDotVector16, PartialMatrixDotVector8};
: IntSimdMatrix(kNumOutputsPerRegister, kMaxOutputRegisters, kNumInputsPerRegister, kNumInputsPerGroup, kNumInputGroups, {PartialMatrixDotVector64, PartialMatrixDotVector32,
PartialMatrixDotVector16, PartialMatrixDotVector8})
#else
: IntSimdMatrix(1, 1, 1, 1, 1, {})
#endif // __AVX2__
{
}

} // namespace tesseract.
7 changes: 5 additions & 2 deletions src/arch/intsimdmatrixsse.cpp
Expand Up @@ -33,10 +33,13 @@ static void PartialMatrixDotVector1(const int8_t* wi, const double* scales,
}
#endif // __SSE4_1__

IntSimdMatrixSSE::IntSimdMatrixSSE() {
IntSimdMatrixSSE::IntSimdMatrixSSE()
#ifdef __SSE4_1__
partial_funcs_ = {PartialMatrixDotVector1};
: IntSimdMatrix(1, 1, 1, 1, 1, {PartialMatrixDotVector1})
#else
: IntSimdMatrix(1, 1, 1, 1, 1, {})
#endif // __SSE4_1__
{
}

} // namespace tesseract.
2 changes: 1 addition & 1 deletion unittest/intsimdmatrix_test.cc
Expand Up @@ -82,7 +82,7 @@ class IntSimdMatrixTest : public ::testing::Test {
}

TRand random_;
IntSimdMatrix base_;
IntSimdMatrix base_ = IntSimdMatrix(1, 1, 1, 1, 1, {});
};

// Test the C++ implementation without SIMD.
Expand Down

0 comments on commit 26be7c5

Please sign in to comment.