Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify non-integer code #698

Closed
wants to merge 2 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
166 changes: 81 additions & 85 deletions shelley/chain-and-ledger/dependencies/non-integer/src/NonIntegral.hs
Expand Up @@ -14,54 +14,60 @@ data CompareResult a = BELOW a Int
| UNKNOWN
deriving (Show, Eq)

scaleExp :: (RealFrac b) => b -> (Integer, b)
scaleExp x = (ceiling x, x / fromIntegral (ceiling x :: Integer))
scaleExp :: (RealFrac a) => a -> (Integer, a)
scaleExp x = (x', x / fromIntegral x')
where x' = ceiling x

-- | Exponentiation
(***) :: (RealFrac a, Enum a, Show a) => a -> a -> a
a *** b
| a == 0 = if b == 0 then 1 else 0
| b == 0 = 1
| a == 0 = 0
| a == 1 = 1
| otherwise = exp' l
where l = b * ln' a
| otherwise = exp' (b * ln' a)

ipow' :: Num a => a -> Integer -> a
ipow' x n
| n == 0 = 1
| mod n 2 == 0 = let res = ipow' x (div n 2) in res * res
| otherwise = x * ipow' x (n - 1)
| n == 0 = 1
| m == 0 = let y = ipow' x d in y * y
| otherwise = x * ipow' x (n - 1)
where (d,m) = divMod n 2

ipow :: Fractional a => a -> Integer -> a
ipow x n
| n < 0 = 1 / (ipow x (-n))
| otherwise = ipow' x n
| n < 0 = 1 / ipow' x (-n)
| otherwise = ipow' x n

logAs :: (Num a) => a -> [a]
logAs a = a' : a' : logAs (a + 1)
where
a' = a * a

-- | Approximate ln(1+x) for x \in [0, \infty)
fln :: (Fractional a, Enum a, Ord a, Show a) => Int -> a -> a
fln maxN x = if x < 0
then error ("x = " ++ show x ++ "is not inside domain [0, ..) ")
else cf maxN 0 eps Nothing 1 0 0 1 (x : [a * x | a <- logAs 1]) [1,2 ..]
-- a_1 = x, a_{2k} = a_{2k+1} = x·k^2, k >= 1
-- b_n = n, n >= 0
lncf :: (Fractional a, Enum a, Ord a, Show a) => Int -> a -> a
lncf maxN x
| x < 0 = error ("x = " ++ show x ++ " is not inside domain [0,..)")
| otherwise = cf maxN 0 eps Nothing 1 0 0 1 as [1,2..]
where as = x : map (*x) (logAs 1)

eps :: (Fractional a) => a
eps = 1 / 10^(24::Int)

-- | Compute continued fraction using max steps or bounded list of a/b factors.
-- The 'maxN' parameter gives the maximum recursion depth, 'n' gives the current
-- rursion depth, 'lastVal' is the optional last value ('Nothing' for the first
-- iteration). 'aNm2', 'bNm2' are a_{m-2} and b_{m-2}, and 'aNm1' / 'bNm1' are
-- a_{m-1} / b_{m-1} respectively, 'an' / 'bn' are lists of succecsive a_n / b_n
-- values for the recurrence relation:
-- iteration). 'aNm2' / 'bNm2' are A_{n-2} / B_{n-2}, 'aNm1' / 'bNm1' are
-- A_{n-1} / B_{n-1}, and 'aN' / 'bN' are A_n / B_n respectively, 'an' / 'bn'
-- are lists of succecsive a_n / b_n values for the recurrence relation:
--
-- A_{-1} = 1 B_{-1} = 0
-- A_0 = b_0 B_0 = 1
-- A_{n+1} = b_{n+1}*A_n + a_{n+1}*A_{n-1} B_{n+1} = b_{n+1}*B_n + a_{n+1}*B_{n-1}
-- A_{-1} = 1, A_0 = b_0
-- B_{-1} = 0, B_0 = 1
-- A_n = b_n*A_{n-1} + a_n*A_{n-2}
-- B_n = b_n*B_{n-1} + a_n*B_{n-2}
--
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've kept this hunk because the expressions are simpler ('n' vs 'n+1'), vertical alignment highlights its similarities and differences, is the same form in which are written in the doc, and also matches the parameters of the function implementation.

-- the convergent is calculated as x_n = A_n/B_n
-- The convergent 'xn' is calculated as x_n = A_n/B_n
--
-- a_1
-- result = b_0 + ---------------------
Expand All @@ -74,7 +80,7 @@ eps = 1 / 10^(24::Int)
-- .
--
-- The recursion stops once 'maxN' iterations have been reached, or either the
-- list 'as' or 'bs' is exhausted or 'lastVal' differs less than 'eps' from the
-- list 'as' or 'bs' is exhausted or 'lastVal' differs less than 'epsilon' from the
-- new convergent.
cf ::
(Fractional a, Ord a, Show a)
Expand All @@ -89,24 +95,16 @@ cf ::
-> [a]
-> [a]
-> a
cf _ _ _ _ _ _ aNm1 bNm1 _ [] = aNm1 / bNm1
cf _ _ _ _ _ _ aNm1 bNm1 [] _ = aNm1 / bNm1
cf maxN n epsilon lastVal aNm2 bNm2 aNm1 bNm1 (an:as) (bn:bs)
| maxN == n = convergent
| otherwise =
case lastVal of
Nothing -> cf maxN (n + 1) epsilon (Just convergent) aNm1 bNm1 aN bN as bs
Just c' -> if abs(convergent - c') < epsilon
then convergent
else cf maxN (n + 1) epsilon (Just convergent) aNm1 bNm1 aN bN as bs
| maxN == n = xn
| converges = xn
| otherwise = cf maxN (n + 1) epsilon (Just xn) aNm1 bNm1 aN bN as bs
where
ba = bn * aNm1
aa = an * aNm2
aN = ba + aa
bb = bn * bNm1
ab = an * bNm2
bN = bb + ab
convergent = aN / bN
converges = maybe False (\x -> abs (x - xn) < epsilon) lastVal
xn = aN / bN -- convergent
aN = bn * aNm1 + an * aNm2
bN = bn * bNm1 + an * bNm2
cf _ _ _ _ _ _ aN bN _ _ = aN / bN

-- | Simple way to find integer powers that bound x. At every step the bounds
-- are doubled. Assumption x > 0, the calculated bound is `factor^l <= x <=
Expand All @@ -122,7 +120,7 @@ bound ::
-> Integer
-> (Integer, Integer)
bound factor x x' x'' l u
| x' <= x && x'' >= x = (l, u)
| x' <= x && x <= x'' = (l, u)
| otherwise = bound factor x (x' * x') (x'' * x'') (2 * l) (2 * u)

-- | Bisect bounds to find the smallest integer power such that
Expand All @@ -134,68 +132,66 @@ contract ::
-> Integer
-> Integer
-> Integer
contract factor x l u
| l + 1 == u = l
| otherwise =
if x < x'
then contract factor x l mid
else contract factor x mid u
contract factor x = go
where
mid = l + ((u - l) `div` 2)
x' = ipow factor mid

e :: (RealFrac a, Show a) => a
e = exp' 1
go l u
| l + 1 == u = l
| otherwise =
if x < x'
then go l mid
else go mid u
where
mid = l + ((u - l) `div` 2)
x' = ipow factor mid

exp1 :: (RealFrac a, Show a) => a
exp1 = exp' 1

-- | find n with `e^n<=x<e^(n+1)`
findE :: (RealFrac a) => a -> a -> Integer
findE eone x = contract eone x lower upper
findE e x = contract e x lower upper
where
(lower, upper) = bound eone x (1 / eone) eone (-1) 1
(lower, upper) = bound e x (1/e) e (-1) 1

-- | Compute natural logarithm via continued fraction, first splitting integral
-- part and then using continued fractions approximation for `ln(1+x)`
ln' :: (RealFrac a, Enum a, Show a) => a -> a
ln' x = if x == 0
then error "0 is not in domain of ln"
else fromIntegral n + approxln
ln' x
| x <= 0 = error (show x ++ " is not in domain of ln")
| otherwise = fromIntegral n + lncf 1000 x'
where (n, x') = splitLn x
approxln = fln 1000 x'

splitLn :: (RealFrac b, Show b) => b -> (Integer, b)
splitLn :: (RealFrac a, Show a) => a -> (Integer, a)
splitLn x = (n, x')
where n = findE e x
y' = exp' (fromIntegral n)
x' = (x / y') - 1 -- x / e^n > 1!
where n = findE exp1 x
y' = ipow exp1 n
x' = (x / y') - 1 -- x / e^n > 1!

exp' :: (RealFrac a, Show a) => a -> a
exp' x
| x < 0 = 1 / exp' (-x)
| otherwise = ipow x' n
where (n, x_) = scaleExp x
x' = taylorExp 1000 1 x_ 1 1 1
| x < 0 = 1 / exp' (-x)
| otherwise = ipow x' n
where (n, x_) = scaleExp x
x' = taylorExp 1000 1 x_ 1 1 1

taylorExp :: (RealFrac a, Show a) => Int -> Int -> a -> a -> a -> a -> a
taylorExp maxN currentN x lastX acc divisor
| maxN == currentN = acc
| abs nextX < eps = acc
| otherwise = taylorExp maxN (currentN + 1) x nextX (acc + nextX) (divisor + 1)
where nextX = (lastX * x) / divisor
taylorExp maxN n x lastX acc divisor
| maxN == n = acc
| abs nextX < eps = acc
| otherwise = taylorExp maxN (n + 1) x nextX (acc + nextX) (divisor + 1)
where nextX = (lastX * x) / divisor

taylorExpCmp :: (RealFrac a, Show a) => a -> a -> a -> CompareResult a
csoroz marked this conversation as resolved.
Show resolved Hide resolved
taylorExpCmp boundX cmp x =
taylorExpCmp' 1000 0 boundX cmp x x 1 1

taylorExpCmp' :: (RealFrac a, Show a) => Int -> Int -> a -> a -> a -> a -> a -> a -> CompareResult a
taylorExpCmp' maxN currentN boundX cmp x err acc divisor
| maxN == currentN = UNKNOWN
| abs nextX < eps = UNKNOWN
| otherwise =
if cmp >= acc' + errorTerm then ABOVE acc' (currentN + 1)
else if cmp < acc' - errorTerm then BELOW acc' (currentN + 1)
else taylorExpCmp' maxN (currentN + 1) boundX cmp x error' acc' divisor'
where divisor' = divisor + 1
errorTerm = error' * boundX
nextX = err
error' = (err * x) / divisor'
acc' = acc + nextX
taylorExpCmp boundX cmp x = go 1000 0 x 1 1
where
go maxN n err acc divisor
| maxN == n = UNKNOWN
| abs nextX < eps = UNKNOWN
| cmp > acc' + errorTerm = ABOVE acc' (n + 1)
| cmp < acc' - errorTerm = BELOW acc' (n + 1)
| otherwise = go maxN (n + 1) err' acc' divisor'
where errorTerm = err' * boundX
divisor' = divisor + 1
nextX = err
err' = (err * x) / divisor'
acc' = acc + nextX