Skip to content

Commit

Permalink
docs; add/use liftLin2 and indexDefault to simplify/speed
Browse files Browse the repository at this point in the history
  • Loading branch information
Barak A. Pearlmutter committed Apr 12, 2009
1 parent dac0571 commit 67c3f71
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 44 deletions.
10 changes: 10 additions & 0 deletions List/Uttl.hs
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,13 @@ _ !!~ i | i<0 = error "negative index"
[x] !!~ _ = x
(x:_) !!~ 0 = x
(x:xs) !!~ i = xs !!~ (i-1)

-- | Index into a list in the manner of @(!!)@, except that walking
-- off the end of the list yields the supplied default value instead
-- of failing.  A negative index is still an error.
indexDefault :: a -> [a] -> Int -> a
indexDefault fallback items n
  | n < 0     = error "negative index"
  | otherwise = go items n
  where
    -- The index is known to be non-negative here, so it is simply
    -- counted down to zero while walking the list.
    go []       _ = fallback
    go (y : _)  0 = y
    go (_ : ys) k = go ys (k - 1)
106 changes: 62 additions & 44 deletions Numeric/FAD.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@ Bj&#246;rn Buckwalter (<bjorn.buckwalter@gmail.com>)
Notes:
Each invocation of the differentiation function introduces a
distinct perturbation, which requires a distinct dual number type.
In order to prevent these from being confused, tagging, called
Each invocation of the differentiation function introduces a distinct
perturbation, which requires a distinct derivative-carrying number
type. In order to prevent these from being confused, tagging, called
branding in the Haskell community, is used. This seems to prevent
perturbation confusion, although it would be nice to have an actual
proof of this. The technique does require adding invocations of
lift at appropriate places when nesting is present.
proof of this. The technique does require adding invocations of lift
at appropriate places when nesting is present, and degrades modularity
by exposing "forall" types in type signatures.
-}

Expand All @@ -56,7 +57,7 @@ tagging to allow dynamic nesting, if the type system would allow.

-- Forward Automatic Differentiation
module Numeric.FAD (
-- * Higher-Order Generalized Dual Numbers
-- * Derivative Towers: Higher-Order Generalized Dual Numbers
Tower, lift, primal,

-- * First-Order Differentiation Operators
Expand Down Expand Up @@ -86,7 +87,7 @@ where
import Data.List (transpose)
import Data.Foldable (Foldable)
import qualified Data.Foldable (all)
import List.Uttl (zipWithDefaults)
import List.Uttl (zipWithDefaults, indexDefault)
import Data.Function (on)

-- To Do:
Expand All @@ -99,16 +100,24 @@ import Data.Function (on)

-- Notes:

-- The constructor is "Bundle" because dual numbers are tangent-vector
-- bundles, in the terminology of differential geometry. For the same
-- reason, the accessor for the first derivative is "tangent".
-- This package implements forward automatic differentiation,
-- generalized to produce not only first derivatives, but a tower of
-- all higher-order derivatives. This is done by replacing a base (or
-- "primal") numeric type by a numeric type that holds the primal value
-- but also carries along the derivative(s). If we produced only
-- first derivatives, the augmented type would be a "Dual Number".
-- And Dual Numbers are tangent-vector bundles, in the terminology of
-- differential geometry. For this reason, we call the accessor
-- for the first derivative "tangent". We also sometimes refer to the
-- augmented numbers as "bundles", since they bundle together a primal
-- value and some derivative information.

-- The multivariate case is handled as a list on inputs, but an
-- arbitrary functor on outputs. This asymmetry is because Haskell
-- provides fmap but not fzipWith.

-- The derivative towers can be truncated, using Zero. Care is taken
-- to preserve said trunction, when possible.
-- The derivative towers can be truncated. Care is taken to preserve
-- said truncation whenever possible.


-- Other quirks:
Expand Down Expand Up @@ -161,13 +170,13 @@ newtype Tower tag a = Tower [a] deriving Show
-- Injectors and accessors for derivative towers

-- | The 'lift' function injects a primal number into the domain of
-- dual numbers, with a zero tower. If dual numbers were a monad,
-- 'lift' would be 'return'.
-- derivative towers, with a zero tower. If generalized dual numbers
-- were a monad, 'lift' would be 'return'.
lift :: Num a => a -> Tower tag a
-- Bundle the primal value together with the 'zero' tower.
lift x = bundle x zero

-- | The 'bundle' function takes a primal number and a dual number
-- tower and returns a dual number tower with the given tower shifted
-- | The 'bundle' function takes a primal number and a derivative
-- tower and returns a derivative tower with the given tower shifted
-- up one and the new primal inserted.
--
-- Property: @x = bundle (primal x) (tangentTower x)@
Expand All @@ -181,42 +190,43 @@ zero :: Num a => Tower tag a
zero = toTower []

-- | The 'apply' function applies a function to a number lifted from
-- the primal domain to the dual number domain, with derivative 1,
-- thus calculating the generalized push-forward, in the differential
-- geometric sense, of the given function at the given point.
-- the primal domain to the derivative tower domain, with unit
-- derivative, thus calculating the generalized push-forward, in the
-- differential geometric sense, of the given function at the given
-- point.
apply :: Num a => (Tower tag a -> b) -> a -> b
-- Seed the argument with 1 as its tower before handing it to f.
apply f x = f (bundle x 1)

-- | The 'towerElt' function finds the i-th element of a dual number
-- | The 'towerElt' function finds the i-th element of a derivative
-- tower, where the 0-th element is the primal value, the 1-st
-- element is the first derivative, etc.
towerElt :: Num a => Int -> Tower tag a -> a
towerElt i (Tower xs) = zeroPad xs !! i
towerElt i (Tower xs) = xs !!!! i

-- | The 'fromTower' function converts a dual number tower to a list
-- of values with the i-th derivatives, i=0,1,..., possibly truncated
-- | The 'fromTower' function converts a derivative tower to a list of
-- values with the i-th derivatives, i=0,1,..., possibly truncated
-- when all remaining values in the tower are zero.
fromTower :: Tower tag a -> [a]
-- Unwrap the constructor to expose the underlying list.
fromTower tower = case tower of
  Tower elts -> elts

-- | The 'toTower' function converts a list of numbers into a dual
-- | number tower.
-- | The 'toTower' function converts a list of numbers into a
-- derivative tower.
toTower :: [a] -> Tower tag a
-- Wrap the list in the constructor, unchanged.
toTower elts = Tower elts

-- | The 'primal' function finds the primal value from a dual number
-- | The 'primal' function finds the primal value from a derivative
-- tower. The inverse of 'lift'.
primal :: Num a => Tower tag a -> a
-- The primal value is element 0 of the tower.
primal tower = towerElt 0 tower

-- | The 'tangent' function finds the tangent value of a dual number
-- | The 'tangent' function finds the tangent value of a derivative
-- tower, i.e., the first-order derivative.
tangent :: Num a => Tower tag a -> a
-- The tangent is element 1 of the tower.
tangent tower = towerElt 1 tower

-- | The 'tangentTower' function finds the entire tower of tangent
-- values of a dual number tower, starting at the 1st derivative.
-- This is equivalent, in an appropriate sense, to taking the first
-- values of a derivative tower, starting at the 1st derivative. This
-- is equivalent, in an appropriate sense, to taking the first
-- derivative.
tangentTower :: Num a => Tower tag a -> Tower tag a
tangentTower (Tower []) = zero
Expand Down Expand Up @@ -292,25 +302,30 @@ liftA1disc = (. primal)
liftA2disc :: (Num a) => (a -> a -> b) -> Tower tag a -> Tower tag a -> b
-- Only the primal values of the two towers reach f.
liftA2disc f x y = f (primal x) (primal y)

-- | The 'liftLin' function lifts a linear scalar function from the
-- primal domain into the derivative tower domain. WARNING: the
-- | The 'liftLin' function lifts a scalar linear function from the
-- primal domain into the derivative tower domain. WARNING: The
-- restriction to linear functions is not enforced by the type system.
liftLin :: (a -> b) -> Tower tag a -> Tower tag b
-- Apply f to every element of the tower's underlying list.
liftLin f tower = toTower (map f (fromTower tower))

-- | The 'liftLin2' function lifts a binary linear function from the
-- primal domain into the derivative tower domain.  The function is
-- applied to the two towers' underlying lists via 'zipWithDefaults',
-- with 0 supplied as the default on either side.
--
-- WARNING 1: The restriction to linear functions is not enforced by
-- the type system.
--
-- WARNING 2: Binary linear means linear in both arguments together,
-- not bilinear.
liftLin2 :: (Num a, Num b) =>
            (a -> a -> b) -> Tower tag a -> Tower tag a -> Tower tag b
liftLin2 f x y =
  toTower (zipWithDefaults f 0 0 (fromTower x) (fromTower y))

-- Numeric operations on derivative towers.

instance Num a => Num (Tower tag a) where
(Tower []) + y = y
x + (Tower []) = x
x + y = bundle (primal x + primal y) (tangentTower x + tangentTower y)
x - (Tower []) = x
(Tower []) - x = negate x
x - y = bundle (primal x - primal y) (tangentTower x - tangentTower y)
(Tower []) * _ = zero
_ * (Tower []) = zero
x * y = liftA2 (*) (flip (,)) x y
negate = liftLin negate
(+) = liftLin2 (+)
(-) = liftLin2 (-)
(*) (Tower []) _ = zero
(*) _ (Tower []) = zero
(*) x y = liftA2 (*) (flip (,)) x y
negate = liftLin negate
abs = liftA1 abs
(\x->let x0 = primal x
in
Expand Down Expand Up @@ -359,8 +374,8 @@ instance Floating a => Floating (Tower tag a) where
exp = liftA1_ exp const
sqrt = liftA1_ sqrt (const . recip . (2*))
log = liftA1 log recip
-- Bug on zero base, e.g., (0**2), since derivative is fine but
-- can get division by 0 and log 0, oops. Need special cases, ick.
-- Bug on zero base, e.g., diffUU (**2) 0 = NaN, which is wrong.
-- Need special cases to bypass avoidable division by 0 and log 0.
-- Here are some untested ideas:
-- (**) x (Tower []) = 1
-- (**) x y@(Tower [y0]) = liftA1 (**y0) ((y*) . (**(y-1))) x
Expand Down Expand Up @@ -544,6 +559,9 @@ zeroPad = (++ repeat 0)
-- | The 'zeroPadF' function extends a list of functor-shaped values
-- with infinitely many copies of its first element in which every
-- entry has been replaced by 0.
--
-- The first element is needed as a template for the shape of the
-- padding, so an empty list cannot be padded; that case is reported
-- with an explicit error rather than an anonymous pattern-match
-- failure.
zeroPadF :: (Num a, Functor f) => [f a] -> [f a]
zeroPadF [] =
  error "zeroPadF: empty list (no element to take the zero shape from)"
zeroPadF fxs@(fx:_) = fxs ++ repeat (fmap (const 0) fx)

(!!!!) :: Num a => [a] -> Int -> a
-- Indexing that falls back to 0, by way of 'indexDefault'.
xs !!!! i = indexDefault 0 xs i

-- | The 'transposePad' function is like Data.List.transpose except
-- that it fills in missing elements with 0 rather than skipping them.
-- It can give a ragged output to a ragged input, but the lengths in
Expand All @@ -561,7 +579,7 @@ transposePadF :: (Num a, Foldable f, Functor f) => f [a] -> [f a]
transposePadF fx =
if Data.Foldable.all null fx
then []
else (fmap ((!!0) . zeroPad) fx) : (transposePadF (fmap (drop 1) fx))
else (fmap (!!!!0) fx) : (transposePadF (fmap (drop 1) fx))

-- The 'transposeF' function transposes w/ infinite zero row padding.

Expand Down

0 comments on commit 67c3f71

Please sign in to comment.