Skip to content
This repository has been archived by the owner on Nov 17, 2021. It is now read-only.

Fixing misuse of the standard math library (performance audit) #41

Merged
merged 9 commits into from
Mar 17, 2017
6 changes: 3 additions & 3 deletions matrix/AxisAngle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ class AxisAngle : public Vector<Type, 3>
Vector<Type, 3>()
{
AxisAngle &v = *this;
Type ang = Type(2.0f)*acosf(q(0));
Type mag = sinf(ang/2.0f);
if (fabsf(mag) > 0) {
Type ang = Type(2.0f)*acos(q(0));
Type mag = sin(ang/2.0f);
if (fabs(mag) > 0) {
v(0) = ang*q(1)/mag;
v(1) = ang*q(2)/mag;
v(2) = ang*q(3)/mag;
Expand Down
1 change: 0 additions & 1 deletion matrix/Matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

#pragma once

#include <cmath>
#include <cstdio>
#include <cstring>

Expand Down
18 changes: 9 additions & 9 deletions matrix/Quaternion.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,28 +97,28 @@ class Quaternion : public Vector<Type, 4>
Quaternion &q = *this;
Type t = R.trace();
if (t > Type(0)) {
t = sqrtf(Type(1) + t);
t = sqrt(Type(1) + t);
q(0) = Type(0.5) * t;
t = Type(0.5) / t;
q(1) = (R(2,1) - R(1,2)) * t;
q(2) = (R(0,2) - R(2,0)) * t;
q(3) = (R(1,0) - R(0,1)) * t;
} else if (R(0,0) > R(1,1) && R(0,0) > R(2,2)) {
t = sqrtf(Type(1) + R(0,0) - R(1,1) - R(2,2));
t = sqrt(Type(1) + R(0,0) - R(1,1) - R(2,2));
q(1) = Type(0.5) * t;
t = Type(0.5) / t;
q(0) = (R(2,1) - R(1,2)) * t;
q(2) = (R(1,0) + R(0,1)) * t;
q(3) = (R(0,2) + R(2,0)) * t;
} else if (R(1,1) > R(2,2)) {
t = sqrtf(Type(1) - R(0,0) + R(1,1) - R(2,2));
t = sqrt(Type(1) - R(0,0) + R(1,1) - R(2,2));
q(2) = Type(0.5) * t;
t = Type(0.5) / t;
q(0) = (R(0,2) - R(2,0)) * t;
q(1) = (R(1,0) + R(0,1)) * t;
q(3) = (R(2,1) + R(1,2)) * t;
} else {
t = sqrtf(Type(1) - R(0,0) - R(1,1) + R(2,2));
t = sqrt(Type(1) - R(0,0) - R(1,1) + R(2,2));
q(3) = Type(0.5) * t;
t = Type(0.5) / t;
q(0) = (R(1,0) - R(0,1)) * t;
Expand Down Expand Up @@ -171,8 +171,8 @@ class Quaternion : public Vector<Type, 4>
q(0) = Type(1.0);
q(1) = q(2) = q(3) = 0;
} else {
Type magnitude = sinf(angle / 2.0f);
q(0) = cosf(angle / 2.0f);
Type magnitude = sin(angle / 2.0f);
q(0) = cos(angle / 2.0f);
q(1) = axis(0) * magnitude;
q(2) = axis(1) * magnitude;
q(3) = axis(2) * magnitude;
Expand Down Expand Up @@ -389,9 +389,9 @@ class Quaternion : public Vector<Type, 4>
q(1) = q(2) = q(3) = 0;
}

Type magnitude = sinf(theta / 2.0f);
Type magnitude = sin(theta / 2.0f);

q(0) = cosf(theta / 2.0f);
q(0) = cos(theta / 2.0f);
q(1) = axis(0) * magnitude;
q(2) = axis(1) * magnitude;
q(3) = axis(2) * magnitude;
Expand All @@ -418,7 +418,7 @@ class Quaternion : public Vector<Type, 4>

if (axis_magnitude >= Type(1e-10)) {
vec = vec / axis_magnitude;
vec = vec * wrap_pi(Type(2.0) * atan2f(axis_magnitude, q(0)));
vec = vec * wrap_pi(Type(2.0) * atan2(axis_magnitude, q(0)));
}

return vec;
Expand Down
10 changes: 5 additions & 5 deletions matrix/SquareMatrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,12 @@ bool inv(const SquareMatrix<Type, M> & A, SquareMatrix<Type, M> & inv)
for (size_t n = 0; n < M; n++) {

// if diagonal is zero, swap with row below
if (fabsf(static_cast<float>(U(n, n))) < 1e-8f) {
if (fabs(static_cast<float>(U(n, n))) < 1e-8f) {
//printf("trying pivot for row %d\n",n);
for (size_t i = n + 1; i < M; i++) {

//printf("\ttrying row %d\n",i);
if (fabsf(static_cast<float>(U(i, n))) > 1e-8f) {
if (fabs(static_cast<float>(U(i, n))) > 1e-8f) {
//printf("swapped %d\n",i);
U.swapRows(i, n);
P.swapRows(i, n);
Expand All @@ -157,11 +157,11 @@ bool inv(const SquareMatrix<Type, M> & A, SquareMatrix<Type, M> & inv)
//printf("U:\n"); U.print();
//printf("P:\n"); P.print();
//fflush(stdout);
//ASSERT(fabsf(U(n, n)) > 1e-8f);
//ASSERT(fabs(U(n, n)) > 1e-8f);
#endif

// failsafe, return zero matrix
if (fabsf(static_cast<float>(U(n, n))) < 1e-8f) {
if (fabs(static_cast<float>(U(n, n))) < 1e-8f) {
return false;
}

Expand Down Expand Up @@ -280,7 +280,7 @@ SquareMatrix <Type, M> cholesky(const SquareMatrix<Type, M> & A)
if (res <= 0) {
L(j, j) = 0;
} else {
L(j, j) = sqrtf(res);
L(j, j) = sqrt(res);
}
} else {
float sum = 0;
Expand Down
2 changes: 0 additions & 2 deletions matrix/Vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

#pragma once

#include <cmath>

#include "math.hpp"

namespace matrix
Expand Down
1 change: 0 additions & 1 deletion matrix/helper_functions.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#pragma once

#include "math.hpp"
#include <cmath>

namespace matrix
{
Expand Down
1 change: 1 addition & 0 deletions matrix/math.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "stdlib_imports.hpp"
#ifdef __PX4_QURT
#include "dspal_math.h"
#endif
Expand Down
130 changes: 130 additions & 0 deletions matrix/stdlib_imports.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/**
* @file stdlib_imports.hpp
*
* This file is needed to shadow the C standard library math functions with ones provided by the C++ standard library.
* This way we can guarantee that unwanted functions from the C library will never creep back in unexpectedly.
*
* @author Pavel Kirienko <pavel.kirienko@zubax.com>
*/

#pragma once

#include <cmath>
#include <cstdlib>
#include <cinttypes>

namespace matrix {

#if defined(__PX4_NUTTX)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to detect NuttX in general instead of the PX4 flavor of it?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pavel-kirienko

For the record no there is none. So far, a change to add NUTTX to the generated config.h will not be accepted upstream. So this has to be set in the build by the flags, and we do that with __PX4_NUTTX.

/*
* NuttX has no usable C++ math library, so we need to provide the needed definitions here manually.
*/
#define MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(name) \
inline float name(float x) { return ::name##f(x); } \
inline double name(double x) { return ::name(x); } \
inline long double name(long double x) { return ::name##l(x); }

#define MATRIX_NUTTX_WRAP_MATH_FUN_BINARY(name) \
inline float name(float x, float y) { return ::name##f(x, y); } \
inline double name(double x, double y) { return ::name(x, y); } \
inline long double name(long double x, long double y) { return ::name##l(x, y); }

MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(fabs)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(log)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(log10)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(exp)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(sqrt)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(sin)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(cos)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(tan)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(asin)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(acos)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(atan)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(sinh)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(cosh)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(tanh)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(ceil)
MATRIX_NUTTX_WRAP_MATH_FUN_UNARY(floor)

MATRIX_NUTTX_WRAP_MATH_FUN_BINARY(pow)
MATRIX_NUTTX_WRAP_MATH_FUN_BINARY(atan2)

#else // Not NuttX, using the C++ standard library

using std::abs;
using std::div;
using std::fabs;
using std::fmod;
using std::exp;
using std::log;
using std::log10;
using std::pow;
using std::sqrt;
using std::sin;
using std::cos;
using std::tan;
using std::asin;
using std::acos;
using std::atan;
using std::atan2;
using std::sinh;
using std::cosh;
using std::tanh;
using std::ceil;
using std::floor;
using std::frexp;
using std::ldexp;
using std::modf;

# if (__cplusplus >= 201103L)

using std::imaxabs;
using std::imaxdiv;
using std::remainder;
using std::remquo;
using std::fma;
using std::fmax;
using std::fmin;
using std::fdim;
using std::nan;
using std::nanf;
using std::nanl;
using std::exp2;
using std::expm1;
using std::log2;
using std::log1p;
using std::cbrt;
using std::hypot;
using std::asinh;
using std::acosh;
using std::atanh;
using std::erf;
using std::erfc;
using std::tgamma;
using std::lgamma;
using std::trunc;
using std::round;
using std::nearbyint;
using std::rint;
using std::scalbn;
using std::ilogb;
using std::logb;
using std::nextafter;
using std::copysign;
using std::fpclassify;
using std::isfinite;
using std::isinf;
using std::isnan;
using std::isnormal;
using std::signbit;
using std::isgreater;
using std::isgreaterequal;
using std::isless;
using std::islessequal;
using std::islessgreater;
using std::isunordered;

# endif
#endif

}