Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: Ability to change rounding mode for all floats (cf. #2976) #3149

Merged
merged 11 commits into from
Sep 2, 2013
1 change: 1 addition & 0 deletions base/.gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/pcre_h.jl
/errno_h.jl
/build_h.jl
/fenv_constants.jl
/file_constants.jl
/uv_constants.jl
6 changes: 5 additions & 1 deletion base/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,17 @@ dirty = $(shell [ -z "$(shell git status --porcelain 2>/dev/null)" ] && echo "fa



all: pcre_h.jl errno_h.jl build_h.jl.phony file_constants.jl uv_constants.jl
all: pcre_h.jl errno_h.jl build_h.jl.phony fenv_constants.jl file_constants.jl uv_constants.jl

pcre_h.jl:
@$(PRINT_PERL) $(CPP) -dM $(shell $(PCRE_CONFIG) --prefix)/include/pcre.h | perl -nle '/^\s*#define\s+PCRE_(\w*)\s*\(?($(PCRE_CONST))\)?\s*$$/ and print "const $$1 = uint32($$2)"' | sort > $@

errno_h.jl:
@$(PRINT_PERL) echo '#include "errno.h"' | $(CPP) -dM - | perl -nle 'print "const $$1 = int32($$2)" if /^#define\s+(E\w+)\s+(\d+)\s*$$/' | sort > $@

fenv_constants.jl: ../src/fenv_constants.h
$(QUIET_PERL) ${CC} -E -P -DJULIA ../src/fenv_constants.h | tail -n 8 > $@

file_constants.jl: ../src/file_constants.h
@$(PRINT_PERL) $(CPP) -P -DJULIA ../src/file_constants.h | perl -nle 'print "$$1 0o$$2" if /^(\s*const\s+[A-z_]+\s+=)\s+(0[0-9]*)\s*$$/; print "$$1" if /^\s*(const\s+[A-z_]+\s+=\s+([1-9]|0x)[0-9A-z]*)\s*$$/' > $@

Expand Down Expand Up @@ -98,5 +101,6 @@ clean:
rm -f pcre_h.jl
rm -f errno_h.jl
rm -f build_h.jl
rm -f fenv_constants.jl
rm -f uv_constants.jl
rm -f file_constants.jl
9 changes: 9 additions & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ export
Reverse,
RevString,
RopeString,
RoundFromZero,
RoundDown,
RoundingMode,
RoundNearest,
RoundToZero,
RoundUp,
Schur,
Set,
SparseMatrixCSC,
Expand Down Expand Up @@ -833,6 +839,9 @@ export
get_bigfloat_rounding,
set_bigfloat_rounding,
with_bigfloat_rounding,
get_rounding,
set_rounding,
with_rounding,

# statistics
cor,
Expand Down
8 changes: 8 additions & 0 deletions base/fenv_constants.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
const JL_FE_UNDERFLOW = 0x0010
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this file shouldn't be part of the commit

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, I thought I removed it, but now I see I just added the file to .gitignore. Done, thanks.

const JL_FE_OVERFLOW = 0x0008
const JL_FE_DIVBYZERO = 0x0004
const JL_FE_INVALID = 0x0001
const JL_FE_TONEAREST = 0x0000
const JL_FE_UPWARD = 0x0800
const JL_FE_DOWNWARD = 0x0400
const JL_FE_TOWARDZERO = 0x0c00
42 changes: 27 additions & 15 deletions base/mpfr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module MPFR

export
BigFloat,
RoundFromZero,
get_bigfloat_precision,
set_bigfloat_precision,
with_bigfloat_precision,
Expand All @@ -19,19 +20,17 @@ import
gamma, lgamma, digamma, erf, erfc, zeta, log1p, airyai, iceil, ifloor,
itrunc, eps, signbit, sin, cos, tan, sec, csc, cot, acos, asin, atan,
cosh, sinh, tanh, sech, csch, coth, acosh, asinh, atanh, atan2,
serialize, deserialize, inf, nan, hash
serialize, deserialize, inf, nan, hash,
RoundingMode, RoundDown, RoundingMode, RoundNearest, RoundToZero,
RoundUp

import Base.Math.lgamma_r

const ROUNDING_MODE = [0]
const DEFAULT_PRECISION = [256]

# Rounding modes
const RoundToNearest = 0
const RoundToZero = 1
const RoundUp = 2
const RoundDown = 3
const RoundAwayZero = 4
type RoundFromZero <: RoundingMode end

# Basic type and initialization definitions

Expand Down Expand Up @@ -98,7 +97,7 @@ for to in (Int8, Int16, Int32, Int64)
function convert(::Type{$to}, x::BigFloat)
(isinteger(x) && (typemin($to) <= x <= typemax($to))) || throw(InexactError())
convert($to, ccall((:mpfr_get_si,:libmpfr),
Clong, (Ptr{BigFloat}, Int32), &x, RoundToZero))
Clong, (Ptr{BigFloat}, Int32), &x, 0))
end
end
end
Expand All @@ -108,7 +107,7 @@ for to in (Uint8, Uint16, Uint32, Uint64)
function convert(::Type{$to}, x::BigFloat)
(isinteger(x) && (typemin($to) <= x <= typemax($to))) || throw(InexactError())
convert($to, ccall((:mpfr_get_ui,:libmpfr),
Culong, (Ptr{BigFloat}, Int32), &x, RoundToZero))
Culong, (Ptr{BigFloat}, Int32), &x, 0))
end
end
end
Expand Down Expand Up @@ -597,13 +596,26 @@ function set_bigfloat_precision(x::Int)
DEFAULT_PRECISION[end] = x
end

get_bigfloat_rounding() = ROUNDING_MODE[end]
function set_bigfloat_rounding(x::Int)
if x < 0 || x > 4
throw(DomainError())
function get_bigfloat_rounding()
if ROUNDING_MODE[end] == 0
return RoundNearest
elseif ROUNDING_MODE[end] == 1
return RoundToZero
elseif ROUNDING_MODE[end] == 2
return RoundUp
elseif ROUNDING_MODE[end] == 3
return RoundDown
elseif ROUNDING_MODE[end] == 4
return RoundFromZero
else
error("Invalid rounding mode")
end
ROUNDING_MODE[end] = x
end
set_bigfloat_rounding(::Type{RoundNearest}) = ROUNDING_MODE[end] = 0
set_bigfloat_rounding(::Type{RoundToZero}) = ROUNDING_MODE[end] = 1
set_bigfloat_rounding(::Type{RoundUp}) = ROUNDING_MODE[end] = 2
set_bigfloat_rounding(::Type{RoundDown}) = ROUNDING_MODE[end] = 3
set_bigfloat_rounding(::Type{RoundFromZero}) = ROUNDING_MODE[end] = 4

function copysign(x::BigFloat, y::BigFloat)
z = BigFloat()
Expand Down Expand Up @@ -635,7 +647,7 @@ end

function itrunc(x::BigFloat)
z = BigInt()
ccall((:mpfr_get_z, :libmpfr), Int32, (Ptr{BigInt}, Ptr{BigFloat}, Int32), &z, &x, RoundToZero)
ccall((:mpfr_get_z, :libmpfr), Int32, (Ptr{BigInt}, Ptr{BigFloat}, Int32), &z, &x, 0)
return z
end

Expand Down Expand Up @@ -682,7 +694,7 @@ function with_bigfloat_precision(f::Function, precision::Integer)
end
end

function with_bigfloat_rounding(f::Function, rounding::Integer)
function with_bigfloat_rounding{T<:RoundingMode}(f::Function, rounding::Type{T})
old_rounding = get_bigfloat_rounding()
set_bigfloat_rounding(rounding)
try
Expand Down
45 changes: 45 additions & 0 deletions base/rounding.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
module Rounding
include("fenv_constants.jl")

export
RoundingMode, RoundNearest, RoundToZero, RoundUp, RoundDown,
get_rounding, set_rounding, with_rounding

## rounding modes ##
abstract RoundingMode
type RoundNearest <: RoundingMode end
type RoundToZero <: RoundingMode end
type RoundUp <: RoundingMode end
type RoundDown <: RoundingMode end

set_rounding(::Type{RoundNearest}) = ccall(:fesetround, Cint, (Cint, ), JL_FE_TONEAREST)
set_rounding(::Type{RoundToZero}) = ccall(:fesetround, Cint, (Cint, ), JL_FE_TOWARDZERO)
set_rounding(::Type{RoundUp}) = ccall(:fesetround, Cint, (Cint, ), JL_FE_UPWARD)
set_rounding(::Type{RoundDown}) = ccall(:fesetround, Cint, (Cint, ), JL_FE_DOWNWARD)

function get_rounding()
r = ccall(:fegetround, Cint, ())
if r == JL_FE_TONEAREST
return RoundNearest
elseif r == JL_FE_DOWNWARD
return RoundDown
elseif r == JL_FE_UPWARD
return RoundUp
elseif r == JL_FE_TOWARDZERO
return RoundToZero
else
error()
end
end

function with_rounding{T<:RoundingMode}(f::Function, rounding::Type{T})
old_rounding = get_rounding()
set_rounding(rounding)
try
return f()
finally
set_rounding(old_rounding)
end
end

end #module
4 changes: 4 additions & 0 deletions base/sysimg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,10 @@ include("fftw.jl")
include("dsp.jl")
importall .DSP

# rounding utilities
include("rounding.jl")
importall .Rounding

# BigInts and BigFloats
include("gmp.jl")
importall .GMP
Expand Down
10 changes: 10 additions & 0 deletions src/fenv_constants.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#include <fenv.h>
const JL_FE_INEXACT = FE_INEXACT
const JL_FE_UNDERFLOW = FE_UNDERFLOW
const JL_FE_OVERFLOW = FE_OVERFLOW
const JL_FE_DIVBYZERO = FE_DIVBYZERO
const JL_FE_INVALID = FE_INVALID
const JL_FE_TONEAREST = FE_TONEAREST
const JL_FE_UPWARD = FE_UPWARD
const JL_FE_DOWNWARD = FE_DOWNWARD
const JL_FE_TOWARDZERO = FE_TOWARDZERO
2 changes: 1 addition & 1 deletion test/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ TESTS = all core keywordargs numbers strings unicode collections hashing \
math functional bigint sorting statistics spawn parallel arpack file \
git pkg resolve suitesparse complex version pollfd mpfr broadcast \
socket floatapprox priorityqueue readdlm regex float16 combinatorics \
sysinfo
sysinfo rounding

$(TESTS) ::
@$(PRINT_JULIA) $(call spawn,$(JULIA_EXECUTABLE)) ./runtests.jl $@
Expand Down
6 changes: 3 additions & 3 deletions test/mpfr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,14 @@ z = BigFloat(30)
# rounding modes
with_bigfloat_precision(4) do
# default mode is round to nearest
down, up = with_bigfloat_rounding(MPFR.RoundToNearest) do
down, up = with_bigfloat_rounding(RoundNearest) do
BigFloat("0.0938"), BigFloat("0.102")
end
with_bigfloat_rounding(MPFR.RoundDown) do
with_bigfloat_rounding(RoundDown) do
@test BigFloat(0.1) == down
@test BigFloat(0.1) != up
end
with_bigfloat_rounding(MPFR.RoundUp) do
with_bigfloat_rounding(RoundUp) do
@test BigFloat(0.1) != down
@test BigFloat(0.1) == up
end
Expand Down
89 changes: 89 additions & 0 deletions test/rounding.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Small sanity tests to ensure changing the rounding of float functions work
using Base.Test

## Float64 checks
# a + b returns a number exactly between prevfloat(1.) and 1., so its
# final result depends strongly on the utilized rounding direction.
a = prevfloat(0.5)
b = 0.5
c = 0x1p-54
d = prevfloat(1.)

# Default rounding direction, RoundNearest
@test a + b === 1.
@test - a - b === -1.
@test a - b === -c
@test b - a === c

# RoundToZero
with_rounding(RoundToZero) do
@test a + b === d
@test - a - b === -d
@test a - b === -c
@test b - a === c
end

# Sanity check to see if we have returned to RoundNearest
@test a + b === 1.
@test - a - b === -1.
@test a - b == -c
@test b - a == c

# RoundUp
with_rounding(RoundUp) do
@test a + b === 1.
@test - a - b === -d
@test a - b === -c
@test b - a === c
end

# RoundDown
with_rounding(RoundDown) do
@test a + b === d
@test - a - b === -1.
@test a - b === -c
@test b - a === c
end

## Float32 checks

a32 = prevfloat(0.5f0)
b32 = 0.5f0
c32 = (1.f0 - prevfloat(1.f0))/2
d32 = prevfloat(1.0f0)

# Default rounding direction, RoundNearest
@test a32 + b32 === 1.0f0
@test - a32 - b32 === -1.0f0
@test a32 - b32 === -c32
@test b32 - a32 === c32

# RoundToZero
with_rounding(RoundToZero) do
@test a32 + b32 === d32
@test - a32 - b32 === -d32
@test a32 - b32 === -c32
@test b32 - a32 === c32
end

# Sanity check to see if we have returned to RoundNearest
@test a32 + b32 === 1.0f0
@test - a32 - b32 === -1.0f0
@test a32 - b32 == -c32
@test b32 - a32 == c32

# RoundUp
with_rounding(RoundUp) do
@test a32 + b32 === 1.0f0
@test - a32 - b32 === -d32
@test a32 - b32 === -c32
@test b32 - a32 === c32
end

# RoundDown
with_rounding(RoundDown) do
@test a32 + b32 === d32
@test - a32 - b32 === -1.0f0
@test a32 - b32 === -c32
@test b32 - a32 === c32
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ testnames = ["core", "keywordargs", "numbers", "strings", "unicode",
"arpack", "file", "suitesparse", "version",
"resolve", "pollfd", "mpfr", "broadcast", "complex",
"socket", "floatapprox", "readdlm", "regex", "float16",
"combinatorics", "sysinfo"]
"combinatorics", "sysinfo", "rounding"]

tests = ARGS==["all"] ? testnames : ARGS

Expand Down