Skip to content

Commit

Permalink
Merge pull request #2524 from 9il/ctfe-math-2
Browse files Browse the repository at this point in the history
findRoot optimization and bugfix
  • Loading branch information
H. S. Teoh committed Sep 19, 2014
2 parents 6e626e0 + e125d03 commit c09c178
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 49 deletions.
8 changes: 4 additions & 4 deletions std/math.d
Expand Up @@ -130,14 +130,14 @@ version(unittest)
if (signbit(x) != signbit(y))
return 0;

if (isinf(x) && isinf(y))
if (isInfinity(x) && isInfinity(y))
return 1;
if (isinf(x) || isinf(y))
if (isInfinity(x) || isInfinity(y))
return 0;

if (isnan(x) && isnan(y))
if (isNaN(x) && isNaN(y))
return 1;
if (isnan(x) || isnan(y))
if (isNaN(x) || isNaN(y))
return 0;

char[30] bufx;
Expand Down
107 changes: 62 additions & 45 deletions std/numeric.d
Expand Up @@ -788,7 +788,7 @@ body {
// Test the function at point c; update brackets accordingly
void bracket(T c)
{
T fc = f(c);
R fc = f(c);
if (fc == 0 || fc.isNaN) { // Exact solution, or NaN
a = c;
fa = fc;
Expand All @@ -815,22 +815,28 @@ body {
a and b differ so wildly in magnitude that the result would be meaningless,
perform a bisection instead.
*/
T secant_interpolate(T a, T b, T fa, T fb)
static T secant_interpolate(T a, T b, R fa, R fb)
{
if (( ((a - b) == a) && b!=0) || (a!=0 && ((b - a) == b))) {
// Catastrophic cancellation
if (a == 0) a = copysign(0.0L, b);
else if (b == 0) b = copysign(0.0L, a);
else if (signbit(a) != signbit(b)) return 0;
if (a == 0)
a = copysign(T(0), b);
else if (b == 0)
b = copysign(T(0), a);
else if (signbit(a) != signbit(b))
return 0;
T c = ieeeMean(a, b);
return c;
}
// avoid overflow
if (b - a > T.max) return b / 2.0 + a / 2.0;
if (fb - fa > T.max) return a - (b - a) / 2;
T c = a - (fa / (fb - fa)) * (b - a);
if (c == a || c == b) return (a + b) / 2;
return c;
// avoid overflow
if (b - a > T.max)
return b / 2 + a / 2;
if (fb - fa > R.max)
return a - (b - a) / 2;
T c = a - (fa / (fb - fa)) * (b - a);
if (c == a || c == b)
return (a + b) / 2;
return c;
}

/* Uses 'numsteps' newton steps to approximate the zero in [a..b] of the
Expand All @@ -841,19 +847,21 @@ body {
T newtonQuadratic(int numsteps)
{
// Find the coefficients of the quadratic polynomial.
T a0 = fa;
T a1 = (fb - fa)/(b - a);
T a2 = ((fd - fb)/(d - b) - a1)/(d - a);
immutable T a0 = fa;
immutable T a1 = (fb - fa)/(b - a);
immutable T a2 = ((fd - fb)/(d - b) - a1)/(d - a);

// Determine the starting point of newton steps.
T c = oppositeSigns(a2, fa) ? a : b;

// start the safeguarded newton steps.
for (int i = 0; i<numsteps; ++i) {
T pc = a0 + (a1 + a2 * (c - b))*(c - a);
T pdc = a1 + a2*((2.0 * c) - (a + b));
if (pdc == 0) return a - a0 / a1;
else c = c - pc / pdc;
foreach (int i; 0..numsteps) {
immutable T pc = a0 + (a1 + a2 * (c - b))*(c - a);
immutable T pdc = a1 + a2*((2 * c) - (a + b));
if (pdc == 0)
return a - a0 / a1;
else
c = c - pc / pdc;
}
return c;
}
Expand Down Expand Up @@ -881,7 +889,7 @@ whileloop:
T a0 = a, b0 = b; // record the brackets

// Do two higher-order (cubic or parabolic) interpolation steps.
for (int QQ = 0; QQ < 2; ++QQ) {
foreach (int QQ; 0..2) {
// Cubic inverse interpolation requires that
// all four function values fa, fb, fd, and fe are distinct;
// otherwise use quadratic interpolation.
Expand All @@ -892,16 +900,16 @@ whileloop:
bool ok = distinct;
if (distinct) {
// Cubic inverse interpolation of f(x) at a, b, d, and e
real q11 = (d - e) * fd / (fe - fd);
real q21 = (b - d) * fb / (fd - fb);
real q31 = (a - b) * fa / (fb - fa);
real d21 = (b - d) * fd / (fd - fb);
real d31 = (a - b) * fb / (fb - fa);

real q22 = (d21 - q11) * fb / (fe - fb);
real q32 = (d31 - q21) * fa / (fd - fa);
real d32 = (d31 - q21) * fd / (fd - fa);
real q33 = (d32 - q22) * fa / (fe - fa);
immutable q11 = (d - e) * fd / (fe - fd);
immutable q21 = (b - d) * fb / (fd - fb);
immutable q31 = (a - b) * fa / (fb - fa);
immutable d21 = (b - d) * fd / (fd - fb);
immutable d31 = (a - b) * fb / (fb - fa);

immutable q22 = (d21 - q11) * fb / (fe - fb);
immutable q32 = (d31 - q21) * fa / (fd - fa);
immutable d32 = (d31 - q21) * fd / (fd - fa);
immutable q33 = (d32 - q22) * fa / (fe - fa);
c = a + (q31 + q32 + q33);
if (c.isNaN || (c <= a) || (c >= b)) {
// DAC: If the interpolation predicts a or b, it's
Expand Down Expand Up @@ -950,11 +958,15 @@ whileloop:
// probably false.
if(c==a || c==b || c.isNaN || fabs(c - u) > (b - a) / 2) {
if ((a-b) == a || (b-a) == b) {
if ( (a>0 && b<0) || (a<0 && b>0) ) c = 0;
if ( (a>0 && b<0) || (a<0 && b>0) )
c = 0;
else {
if (a==0) c = ieeeMean(cast(T)copysign(0.0L, b), b);
else if (b==0) c = ieeeMean(cast(T)copysign(0.0L, a), a);
else c = ieeeMean(a, b);
if (a==0)
c = ieeeMean(copysign(T(0), b), b);
else if (b==0)
c = ieeeMean(copysign(T(0), a), a);
else
c = ieeeMean(a, b);
}
} else {
c = a + (b - a) / 2;
Expand All @@ -973,17 +985,18 @@ whileloop:
// perform a binary chop.

if( (a==0 || b==0 ||
(fabs(a) >= 0.5 * fabs(b) && fabs(b) >= 0.5 * fabs(a)))
&& (b - a) < 0.25 * (b0 - a0)) {
(fabs(a) >= T(0.5) * fabs(b) && fabs(b) >= T(0.5) * fabs(a)))
&& (b - a) < T(0.25) * (b0 - a0)) {
baditer = 1;
continue;
}
// DAC: If this happens on consecutive iterations, we probably have a
// pathological function. Perform a number of bisections equal to the
// total number of consecutive bad iterations.

if ((b - a) < 0.25 * (b0 - a0)) baditer = 1;
for (int QQ = 0; QQ < baditer ;++QQ) {
if ((b - a) < T(0.25) * (b0 - a0))
baditer = 1;
foreach (int QQ; 0..baditer) {
e = d;
fe = fd;

Expand All @@ -992,8 +1005,10 @@ whileloop:
else {
T usea = a;
T useb = b;
if (a == 0) usea = copysign(0.0L, b);
else if (b == 0) useb = copysign(0.0L, a);
if (a == 0)
usea = copysign(T(0), b);
else if (b == 0)
useb = copysign(T(0), a);
w = ieeeMean(usea, useb);
}
bracket(w);
Expand Down Expand Up @@ -1025,11 +1040,13 @@ unittest

// Test functions
real cubicfn (real x) {
++numCalls;
if (x>float.max) x = float.max;
if (x<-double.max) x = -double.max;
// This has a single real root at -59.286543284815
return 0.386*x*x*x + 23*x*x + 15.7*x + 525.2;
++numCalls;
if (x>float.max)
x = float.max;
if (x<-double.max)
x = -double.max;
// This has a single real root at -59.286543284815
return 0.386*x*x*x + 23*x*x + 15.7*x + 525.2;
}
// Test a function with more than one root.
real multisine(real x) { ++numCalls; return sin(x); }
Expand Down

0 comments on commit c09c178

Please sign in to comment.