tests now pass on julia 0.5, docs updated
fredo-dedup committed Sep 27, 2016
1 parent fd7563e commit bc322cd
Showing 11 changed files with 195 additions and 227 deletions.
9 changes: 4 additions & 5 deletions README.md
@@ -14,15 +14,15 @@ This version of automated differentiation operates at the source level (provided
Usage examples:
- derivative of x³
```
julia> rdiff( :(x^3) , x=2.) # 'x=2.' indicates the type of x to rdiff
julia> rdiff( :(x^3) , x=Float64) # 'x=Float64' indicates the type of x to rdiff
:(begin
(x^3,3 * x^2.0) # expression calculates a tuple of (value, derivative)
end)
```

- first 10 derivatives of `sin(x)` (notice the simplifications)
```
julia> rdiff( :(sin(x)) , order=10, x=2.) # derivatives up to order 10
julia> rdiff( :(sin(x)) , order=10, x=Float64) # derivatives up to order 10
:(begin
_tmp1 = sin(x)
_tmp2 = cos(x)
@@ -36,7 +36,7 @@ Usage examples:
- works on functions too
```
julia> rosenbrock(x) = (1 - x[1])^2 + 100(x[2] - x[1]^2)^2 # function to be derived
julia> rosen2 = rdiff(rosenbrock, (ones(2),), order=2) # orders up to 2
julia> rosen2 = rdiff(rosenbrock, (Vector{Float64},), order=2) # orders up to 2
(anonymous function)
```

@@ -54,7 +54,6 @@ Usage examples:
w1, w2, w3 = randn(10,10), randn(10,10), randn(1,10)
x1 = randn(10)
dann = rdiff(ann, (w1, w2, w3, x1))
dann = m.rdiff(ann, (Matrix{Float64}, Matrix{Float64}, Matrix{Float64}, Vector{Float64}))
dann(w1, w2, w3, x1) # network output + gradient on w1, w2, w3 and x1
```
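
The README hunks above replace reference values (`x=2.`, `ones(2)`) with their types (`x=Float64`, `Vector{Float64}`). As a rough illustration of that mapping, and not part of the commit, the sketch below applies `typeof` to values of the kind the old calls passed. The names `old_values` and `new_types` are hypothetical.

```julia
# Reference values of the kind the old call style passed to rdiff
# (illustrative values only, not taken from the package's tests).
old_values = (2.0, ones(2), randn(10, 10))

# The new call style expects the corresponding types instead.
new_types = map(typeof, old_values)

# Vector{Float64} and Matrix{Float64} are aliases for Array{Float64,1}
# and Array{Float64,2}, so the comparison holds however the types print.
println(new_types == (Float64, Vector{Float64}, Matrix{Float64}))  # true
```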

14 changes: 7 additions & 7 deletions doc/maincall.rst
@@ -14,7 +14,7 @@ Arguments

:order: (default = 1) is an integer indicating the derivation order (1 for 1st order, etc.). Order 0 is allowed and will produce an expression that is a processed version of ``ex`` with some variable names rewritten and possibly some optimizations.

:init: (multiple keyword arguments) is one or several symbol / value pairs indicating a reference value for variables appearing in ``ex`` (a reference value for each variable is needed in order to fully evaluate ``ex``, this is a requirement of the derivation algorithm). By default the generated expression will yield the derivative for each variable given unless the variable is listed in the ``ignore`` argument.
:init: (multiple keyword arguments) is one or several symbol / DataType pairs used to indicate for which variable a derivative is needed and how they should be interpreted. By default the generated expression will yield the derivative for each variable given unless the variable is listed in the ``ignore`` argument.

:evalmod: (default=Main) module where the expression is meant to be evaluated. External variables and functions should be evaluable in this module.

@@ -41,19 +41,19 @@ All the variables appearing in the ``init`` argument are considered as the expre

For orders >= 2 *only a single variable, of type Real or Vector, is allowed*. For orders 0 and 1 variables can be of type Real, Vector or Matrix and can be in an unlimited number::

julia> rdiff( :(x^3) , x=2.) # first order
julia> rdiff( :(x^3) , x=Float64) # first order
:(begin
(x^3,3 * x^2.0)
end)

julia> rdiff( :(x^3) , order = 3, x=2.) # orders up to 3
julia> rdiff( :(x^3) , order=3, x=Float64) # orders up to 3
:(begin
(x^3,3 * x^2.0,2.0 * (x * 3),6.0)
end)

``rdiff`` runs several simplification heuristics on the generated code to remove neutral statements and factorize repeated calculations. For instance calculating the derivatives of ``sin(x)`` for large orders will reduce to the calculations of ``sin(x)`` and ``cos(x)``::

julia> rdiff( :(sin(x)) , order=10, x=2.) # derivatives up to order 10
julia> rdiff( :(sin(x)) , order=10, x=Float64) # derivatives up to order 10
:(begin
_tmp1 = sin(x)
_tmp2 = cos(x)
@@ -63,17 +63,17 @@ For orders >= 2 *only a single variable, of type Real or Vector, is allowed*. Fo
(_tmp1,_tmp2,_tmp3,_tmp4,_tmp5,_tmp2,_tmp3,_tmp4,_tmp5,_tmp2,_tmp3)
end)

The expression produced can readily be turned into a function with the ``@eval`` macro::
The expression produced can easily be turned into a function with the ``@eval`` macro::

julia> res = rdiff( :(sin(x)) , order=10, x=2.)
julia> res = rdiff( :(sin(x)) , order=10, x=Float64)
julia> @eval foo(x) = $res
julia> foo(2.)
(0.9092974268256817,-0.4161468365471424,-0.9092974268256817,0.4161468365471424,0.9092974268256817,-0.4161468365471424,-0.9092974268256817,0.4161468365471424,0.9092974268256817,-0.4161468365471424,-0.9092974268256817)
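
The `@eval` step above can be tried without the package installed. The sketch below is an assumption-laden stand-in: `res` is written by hand to match the first-order output documented for `:(x^3)`, rather than produced by `rdiff`.

```julia
# Hand-written stand-in for the expression rdiff is documented to return
# for :(x^3) at order 1 (assumption: copied from the docs, not generated).
res = :(begin
    (x^3, 3 * x^2.0)
end)

# Splice the expression into a method definition; `x` inside `res`
# becomes the argument of `foo`.
@eval foo(x) = $res

println(foo(2.0))  # (8.0, 12.0): value and first derivative at x = 2
```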

When a second derivative expression is needed, only a single derivation variable is allowed. If you are dealing with a function of several (scalar) variables you will have to aggregate them into a vector::

julia> ex = :( (1 - x[1])^2 + 100(x[2] - x[1]^2)^2 ) # the rosenbrock function
julia> res = rdiff(ex, x=zeros(2), order=2)
julia> res = rdiff(ex, x=Vector{Float64}, order=2)
:(begin
_tmp1 = 1
_tmp2 = 2
6 changes: 3 additions & 3 deletions doc/maincall2.rst
@@ -10,7 +10,7 @@ Arguments

:func: is a Julia generic function.

:init: is a tuple containing initial values for each parameter of ``func``. These reference values are needed to to fully evaluate ``ex``, this is a requirement of the derivation algorithm). By default the generated expression will yield the derivative for each variable given unless the variable is listed in the ``ignore`` argument.
:init: is a tuple containing the types for each parameter of ``func``. These types are necessary to pick the right method of the given function. By default the generated expression will yield the derivative for each variable given unless the variable is listed in the ``ignore`` argument.

:order: (keyword arg, default = 1) is an integer indicating the derivation order (1 for 1st order, etc.). Order 0 is allowed and will produce a function that is a processed version of ``ex`` with some variable names rewritten and possibly some optimizations.

@@ -20,7 +20,7 @@ Arguments

:allorders: (default=true) tells rdiff whether to generate the code for all orders up to ``order`` (true) or only the last order.

:ignore: (default=[]) do not differentiate against the listed variables, useful if you are not interested in having the derivative of one of several variables in ``init``.
:ignore: (default=[]) do not differentiate against the listed variables (identified by their position index), useful if you are not interested in having the derivative of one of several variables in ``init``.


Output
@@ -35,7 +35,7 @@ Usage
``rdiff`` takes a function defined with the same subset of Julia statements (assignments, getindex, setindex!, for loops, function calls) as the Expression variant of ``rdiff()`` and transforms it into another function whose call will return the derivatives at all orders between 0 and the order specified::

julia> rosenbrock(x) = (1 - x[1])^2 + 100(x[2] - x[1]^2)^2 # function to be derived
julia> rosen2 = rdiff(rosenbrock, (ones(2),), order=2) # orders up to 2
julia> rosen2 = rdiff(rosenbrock, (Vector{Float64},), order=2) # orders up to 2
(anonymous function)
julia> rosen2([1,2])
(100,[-400.0,200.0],
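
The first two entries of the output shown above (the value and the gradient at `[1,2]`) can be cross-checked by hand. This sketch does not call `rdiff`; `rosenbrock_grad` is a hypothetical helper whose formula was derived manually from the function definition.

```julia
rosenbrock(x) = (1 - x[1])^2 + 100(x[2] - x[1]^2)^2

# Gradient derived by hand (not produced by rdiff):
#   d/dx1 = -2(1 - x1) - 400 x1 (x2 - x1^2)
#   d/dx2 = 200 (x2 - x1^2)
rosenbrock_grad(x) = [-2 * (1 - x[1]) - 400 * x[1] * (x[2] - x[1]^2),
                      200 * (x[2] - x[1]^2)]

println(rosenbrock([1, 2]))       # 100
println(rosenbrock_grad([1, 2]))  # [-400, 200]
```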
96 changes: 48 additions & 48 deletions src/ReverseDiffSource.jl
@@ -8,55 +8,55 @@ __precompile__(false)

module ReverseDiffSource

import Base.show, Base.copy, Base.length

using Compat # for compatibility across julia versions

# naming conventions
const TEMP_NAME = "_tmp" # prefix of new variables
const DERIV_PREFIX = "d" # prefix of gradient variables

## misc functions
dprefix(v::Union{Symbol, AbstractString, Char}) = Symbol("$DERIV_PREFIX$v")

isSymbol(ex) = isa(ex, Symbol)
isDot(ex) = isa(ex, Expr) && ex.head == :. && isa(ex.args[1], Symbol)
isRef(ex) = isa(ex, Expr) && ex.head == :ref && isa(ex.args[1], Symbol)

## temp var name generator
let
vcount = Dict()
global newvar
function newvar(radix::Union{AbstractString, Symbol}=TEMP_NAME)
vcount[radix] = haskey(vcount, radix) ? vcount[radix]+1 : 1
return Symbol("$(radix)$(vcount[radix])")
import Base.show, Base.copy, Base.length

using Compat # for compatibility across julia versions

# naming conventions
const TEMP_NAME = "_tmp" # prefix of new variables
const DERIV_PREFIX = "d" # prefix of gradient variables

## misc functions
dprefix(v::Union{Symbol, AbstractString, Char}) = Symbol("$DERIV_PREFIX$v")

isSymbol(ex) = isa(ex, Symbol)
isDot(ex) = isa(ex, Expr) && ex.head == :. && isa(ex.args[1], Symbol)
isRef(ex) = isa(ex, Expr) && ex.head == :ref && isa(ex.args[1], Symbol)

## temp var name generator
let
vcount = Dict()
global newvar
function newvar(radix::Union{AbstractString, Symbol}=TEMP_NAME)
vcount[radix] = haskey(vcount, radix) ? vcount[radix]+1 : 1
return Symbol("$(radix)$(vcount[radix])")
end

global resetvar
function resetvar()
vcount = Dict()
end
end

global resetvar
function resetvar()
vcount = Dict()
end
end

###### Includes ######
include("node.jl")
include("bidict.jl")
include("graph.jl")
include("plot.jl")
include("simplify.jl")
include("tograph.jl")
include("tocode.jl")
include("zeronode.jl")
include("reversegraph.jl")
include("deriv_rule.jl")
include("base_rules.jl")
include("rdiff.jl")
include("frdiff.jl")

###### Exports ######
export
rdiff,
@deriv_rule, deriv_rule,
@typeequiv, typeequiv
###### Includes ######
include("node.jl")
include("bidict.jl")
include("graph.jl")
include("plot.jl")
include("simplify.jl")
include("tograph.jl")
include("tocode.jl")
include("zeronode.jl")
include("reversegraph.jl")
include("deriv_rule.jl")
include("base_rules.jl")
include("rdiff.jl")
include("frdiff.jl")

###### Exports ######
export
rdiff,
@deriv_rule, deriv_rule,
@typeequiv, typeequiv

end # module ReverseDiffSource
84 changes: 42 additions & 42 deletions src/bidict.jl
@@ -13,34 +13,34 @@ import Base: setindex!, getindex, haskey, delete!,


type BiDict{K,V}
kv::Dict{K,V}
vk::Dict{V,K}

BiDict() = new(Dict{K,V}(), Dict{V,K}())

function BiDict(ks, vs)
n = length(ks)
length(unique(ks)) != n && error("Duplicate keys")
length(unique(vs)) != n && error("Duplicate values")
h = BiDict{K,V}()
for i=1:n
h.kv[ks[i]] = vs[i]
h.vk[vs[i]] = ks[i]
kv::Dict{K,V}
vk::Dict{V,K}

BiDict() = new(Dict{K,V}(), Dict{V,K}())

function BiDict(ks, vs)
n = length(ks)
length(unique(ks)) != n && error("Duplicate keys")
length(unique(vs)) != n && error("Duplicate values")
h = BiDict{K,V}()
for i=1:n
h.kv[ks[i]] = vs[i]
h.vk[vs[i]] = ks[i]
end
return h
end
return h
end

function BiDict(d)
n = length(d)
vs = values(d)
length(unique(vs)) != n && error("Duplicate values")
h = BiDict{K,V}()
for (k,v) in d
h.kv[k] = v
h.vk[v] = k

function BiDict(d)
n = length(d)
vs = values(d)
length(unique(vs)) != n && error("Duplicate values")
h = BiDict{K,V}()
for (k,v) in d
h.kv[k] = v
h.vk[v] = k
end
return h
end
return h
end
end

# BiDict() = BiDict{Any,Any}()
@@ -51,27 +51,27 @@ end

# use first dict (kv) by default for base functions
function setindex!(bd::BiDict, v, k)
k2 = get(bd.vk, v, nothing) # existing key for v ?
if k2 != nothing
delete!(bd.kv, k2)
delete!(bd.vk, v)
end
k2 = get(bd.vk, v, nothing) # existing key for v ?
if k2 != nothing
delete!(bd.kv, k2)
delete!(bd.vk, v)
end

v2 = get(bd.kv, k, nothing) # existing value for k ?
if v2 != nothing
delete!(bd.kv, k)
delete!(bd.vk, v2)
end
v2 = get(bd.kv, k, nothing) # existing value for k ?
if v2 != nothing
delete!(bd.kv, k)
delete!(bd.vk, v2)
end

bd.kv[k] = v
bd.vk[v] = k
bd.kv[k] = v
bd.vk[v] = k
end

function delete!(bd::BiDict, k)
!haskey(bd.kv, k) && error("unknown key $k")
v = bd.kv[k]
delete!(bd.kv, k)
delete!(bd.vk, v)
!haskey(bd.kv, k) && error("unknown key $k")
v = bd.kv[k]
delete!(bd.kv, k)
delete!(bd.vk, v)
end

length(bd::BiDict) = length(bd.kv)
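
For readers on current Julia, where `type` has become `struct` or `mutable struct`, the one-to-one invariant kept by `setindex!` above can be sketched as follows. `MiniBiDict` is a hypothetical name, and the sketch checks membership with `haskey` instead of comparing `get(...)` against `nothing` as the commit does, so it is an approximation of the idea and not the package's code.

```julia
# Minimal bidirectional dictionary: kv maps keys to values, vk maps back.
struct MiniBiDict{K,V}
    kv::Dict{K,V}
    vk::Dict{V,K}
    MiniBiDict{K,V}() where {K,V} = new{K,V}(Dict{K,V}(), Dict{V,K}())
end

function Base.setindex!(bd::MiniBiDict, v, k)
    if haskey(bd.vk, v)            # v already belongs to a key: drop that pair
        delete!(bd.kv, bd.vk[v])
        delete!(bd.vk, v)
    end
    if haskey(bd.kv, k)            # k already has a value: drop that pair
        delete!(bd.vk, bd.kv[k])
        delete!(bd.kv, k)
    end
    bd.kv[k] = v
    bd.vk[v] = k
    return bd
end

Base.length(bd::MiniBiDict) = length(bd.kv)

bd = MiniBiDict{Symbol,Int}()
bd[:a] = 1
bd[:b] = 1                         # reusing value 1 evicts the (:a, 1) pair
println(length(bd))                # 1
```

Evicting the stale pair on both sides is what keeps the two dictionaries exact inverses of each other after every assignment.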