Skip to content

Commit

Permalink
More fully quote generated code for static IR
Browse files Browse the repository at this point in the history
  • Loading branch information
marcoct committed Feb 21, 2019
1 parent 8eadda5 commit 1c7da58
Show file tree
Hide file tree
Showing 11 changed files with 315 additions and 302 deletions.
16 changes: 8 additions & 8 deletions Manifest.toml
Expand Up @@ -27,9 +27,9 @@ version = "0.2.0"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "ec61a16eed883ad0cfa002d7489b3ce6d039bb9a"
git-tree-sha1 = "49269e311ffe11ac5b334681d212329002a9832a"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "1.4.0"
version = "1.5.1"

[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
Expand All @@ -47,15 +47,15 @@ uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"

[[DiffResults]]
deps = ["Compat", "StaticArrays"]
git-tree-sha1 = "db8acf46717b13d6c48deb7a12007c7f85a70cf7"
git-tree-sha1 = "34a4a1e8be7bc99bc9c611b895b5baf37a80584c"
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "0.0.3"
version = "0.0.4"

[[DiffRules]]
deps = ["Random", "Test"]
git-tree-sha1 = "c49ec69428ffea0c1d1bbdc63d1a70f5df5860ad"
git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "0.0.7"
version = "0.0.10"

[[Distributed]]
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
Expand Down Expand Up @@ -220,9 +220,9 @@ version = "0.27.0"

[[StatsFuns]]
deps = ["Rmath", "SpecialFunctions", "Test"]
git-tree-sha1 = "d14bb7b03defd2deaa5675646f6783089e0556f0"
git-tree-sha1 = "b3a4e86aa13c732b8a8c0ba0c3d3264f55e6bb3e"
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
version = "0.7.0"
version = "0.8.0"

[[SuiteSparse]]
deps = ["Libdl", "LinearAlgebra", "SparseArrays"]
Expand Down
156 changes: 78 additions & 78 deletions docs/src/ref/modeling.md
Expand Up @@ -323,107 +323,107 @@ Like distributions, generative functions indicate which of their arguments suppo
It is an error if a tracked value is passed as an argument of a generative function, when differentiation is not supported by the generative function for that argument.
If a generative function `gen_fn` has `accepts_output_grad(gen_fn) = true`, then the return value of the generative function call will be tracked and will propagate further through the caller `@gen` function's computation.

## Differencing code

`@gen` functions may include blocks of *differencing code* annotated with the `@diff` keyword.
Code that is annotated with `@diff` is only executed during one of the [Trace update methods](@ref).
During a trace update operation, `@diff` code is simply inserted inline into the body of the generative function.
Therefore, `@diff` code can read from the state of the non-diff code.
However, the flow of information is one-directional: diff` code is not permitted to affect the state of the regular code.
## Static Modeling Language

`@diff` code is used to compute the retdiff value for the update (see [Retdiff](@ref)) and the argdiff values for calls to generative function calls (see [Argdiff](@ref)).
To compute these values, the `@diff` code has access to special keywords:
The *static modeling language* is a restricted variant of the built-in modeling language.
Models written in the static modeling language can result in better inference performance (more inference operations per second and less memory consumption), than the full built-in modeling language, especially for models used with iterative inference algorithms like Markov chain Monte Carlo.

`@argdiff`, which returns the argdiff that was passed to the update method for the generative function.

`@choicediff`, which returns a value of one of the following types that indicates whether the random choice changed or not:
```@docs
NewChoiceDiff
NoChoiceDiff
PrevChoiceDiff
A function is identified as using the static modeling language by adding the `static` annotation to the function.
For example:
```julia
@gen (static) function foo(prob::Float64)
z1 = @trace(bernoulli(prob), :a)
z2 = @trace(bernoulli(prob), :b)
z3 = z1 || z2
z4 = !z3
return z4
end
```
After running this code, `foo` is a Julia value whose type is a subtype of `StaticIRGenerativeFunction`, which is a subtype of [`GenerativeFunction`](@ref).

### Static computation graph
Using the `static` annotation instructs Gen to statically construct a directed acyclic graph for the computation represented by the body of the function.
For the function `foo` above, the static graph looks like:
```@raw html
<div style="text-align:center">
<img src="../../images/static_graph.png" alt="example static computation graph" width="50%"/>
</div>
```
In this graph, oval nodes represent random choices, square nodes represent Julia computations, and diamond nodes represent arguments.
The light blue shaded node is the return value of the function.
Having access to the static graph allows Gen to generate specialized code for [Trace update operations](@ref) that skips unecessary parts of the computation.
Specifically, when applying an update operation, a the graph is analyzed, and each value in the graph identified as having possibly changed, or not.
Nodes in the graph do not need to be re-executed if none of their input values could have possibly changed.
Also, even if some inputs to a generative function node may have changed, knowledge that some of the inputs have not changed often allows the generative function being called to more efficiently perform its update operation.
This is the case for functions produced by [Generative Function Combinators](@ref).

You can plot the graph for a function with the `static` annotation if you have PyCall installed, and a Python environment that contains the [graphviz](https://pypi.org/project/graphviz/) Python package, using, e.g.:
```julia
using PyCall
@pyimport graphviz
using Gen: draw_graph
draw_graph(foo, graphviz, "test")
```
This will produce a file `test.pdf` in the current working directory containing the rendered graph.

`@calldiff`, which returns a value of one of the following types that provides information about the change in return value from the function:
```@docs
NewCallDiff
NoCallDiff
UnknownCallDiff
CustomCallDiff
### Restrictions

In order to be able to construct the static graph, Gen restricts the permitted syntax that can be used in functions annotated with `static`.
In particular, each statement in the body must be one of the following:

- A pure functional Julia expression on the right-hand side, and a symbol on the left-hand side, e.g.:

```julia
z4 = !z3
```

To set a retdiff value, the `@diff` code uses the `@retdiff` keyword.
- A `@trace` expression on the right-hand side, and a symbol on the left-hand side, e.g.:

**Example.**
In the function below, if the argument is false and the argument did not change, then there is no change to the return value.
If the argument did not change, and :a and :b did not change, then there is no change to the return value.
Otherwise, return an [`DefaultRetDiff`](@ref) value.
```julia
@gen function foo(val::Bool)
val = val && @trace(bernoulli(0.3), :a)
val = val && @trace(bernoulli(0.4), :b)
@diff begin
argdiff = @argdiff()
if argdiff == noargdiff
if !val || (isnodiff(@choicediff(:a)) && isnodiff(@choicediff(:b)))
@retdiff(noretdiff)
else
@retdiff(defaultretdiff)
end
else
@retdiff(defaultretdiff)
end
end
return val
end
z2 = @trace(bernoulli(prob), :b)
```
The trace statement must use a literal Julia symbol for the first component in the address. Unlike the full built-in modeling-language, the address is not optional.

## Static DSL
- A `return` statement, with a literal Julia symbol on the right-hand side, e.g.:

The *Static DSL* supports a subset of the built-in modeling language.
A static DSL function is identified by adding the `static` annotation to the function.
For example:
```julia
@gen (static) function foo(prob::Float64)
z1 = @trace(bernoulli(prob), :a)
z2 = @trace(bernoulli(prob), :b)
z3 = z1 || z2
return z3
end
return z4
```

After running this code, `foo` is a Julia value whose type is a subtype of `StaticIRGenerativeFunction`, which is a subtype of `GenerativeFunction`.
The functions must also satisfy the following rules:

The static DSL permits a subset of the syntax of the built-in modeling language.
In particular, each statement must be one of the following forms:
- `@trace` expressions cannot appear anywhere in the function body except for as the outer-most expression on the right-hand side of a statement.

- `<symbol> = <julia-expr>`
- Each literal symbol used in the left-hand side of a statement must be unique (e.g. you cannot re-assign to a variable).

- `<symbol> = @trace(<dist|gen-fn>(..),<symbol> [ => ..])`
- Julia closures and list comprehensions are not allowed.

- `@trace(<dist|gen-fn>(..),<symbol> [ => ..])`
- For composite addresses (e.g. `:a => 2 => :c`) the first component of the address must be a literal symbol, and there may only be one statement in the function body that uses this symbol for the first component of its address.

- `return <symbol>`
- Julia control flow constructs (e.g. `if`, `for`, `while`) cannot be used as top-level statements in the function body. Control flow should be implemented inside Julia functions that are called, generative functions that are called such as generative functions produced using [Generative Function Combinators](@ref).

Currently, trainable parameters are not supported in static DSL functions.
NOTE: Currently, trainable parameters are not supported in static DSL functions.

### Loading generated functions
Before a function with a static annotation can be used, the [`load_generated_functions`](@ref) method must be called:
```@docs
load_generated_functions
```
Typically, one call to this function, at the top level of a script, separates the definition of generative functions from the execution of inference code, e.g.:
```julia
using Gen: load_generated_functions

Note that the `@trace` keyword may only appear in at the top-level of the right-hand-side expresssion.
Also, addresses used with the `@trace` keyword must be a literal Julia symbol (e.g. `:a`). If multi-part addresses are used, the first component in the multi-part address must be a literal Julia symbol (e.g. `:a => i` is valid).
# define generative functions and inference code
..

Also, symbols used on the left-hand-side of assignment statements must be unique (this is called 'static single assignment' (SSA) form) (this is called 'static single-assignment' (SSA) form).
# allow static generative functions defined above to be used
load_generated_functions()

**Loading generated functions.**
Before a static DSL function can be invoked at runtime, `Gen.load_generated_functions()` method must be called.
Typically, this call immediately preceeds the execution of the inference algorithm.
# run inference code
..
```

**Performance tips.**
For better performance, annotate the left-hand side of random choices with the type.
### Performance tips
For better performance when the arguments are simple data types like `Float64`, annotate the arguments with the concrete type.
This permits a more optimized trace data structure to be generated for the generative function.
For example:
```julia
@gen (static) function foo(prob::Float64)
z1::Bool = @trace(bernoulli(prob), :a)
z2::Bool = @trace(bernoulli(prob), :b)
z3 = z1 || z2
return z3
end
```
83 changes: 46 additions & 37 deletions src/Gen.jl
@@ -1,57 +1,66 @@
#__precompile__(false)

module Gen

const generated_functions = []
function load_generated_functions()
for function_defn in generated_functions
Core.eval(Main, function_defn)
end
const generated_functions = []

"""
load_generated_functions()
Permit use of generative functions written in the static modeling language up to this point.
"""
function load_generated_functions()
for function_defn in generated_functions
Core.eval(Main, function_defn)
end
export generated_functions
end

export load_generated_functions

# built-in extensions to the reverse mode AD
include("backprop.jl")
# built-in extensions to the reverse mode AD
include("backprop.jl")

# addresses and address selections
include("address.jl")
# addresses and address selections
include("address.jl")

# abstract and built-in concrete choice map data types
include("choice_map.jl")
# abstract and built-in concrete choice map data types
include("choice_map.jl")

# a homogeneous trie data type (not for use as choice map)
include("trie.jl")
# a homogeneous trie data type (not for use as choice map)
include("trie.jl")

# generative function interface
include("gen_fn_interface.jl")
# generative function interface
include("gen_fn_interface.jl")

# built-in data types for arg-diff and ret-diff values
include("diff.jl")
# built-in data types for arg-diff and ret-diff values
include("diff.jl")

# built-in probability disributions
include("modeling_library/modeling_library.jl")
# built-in probability disributions
include("modeling_library/modeling_library.jl")

# utilities for parsing
include("dsl_common.jl")
# utilities for parsing
include("dsl_common.jl")

# optimization of trainable parameters
include("optimization.jl")
# optimization of trainable parameters
include("optimization.jl")

# dynamic embedded generative function
include("dynamic/dynamic.jl")
# dynamic embedded generative function
include("dynamic/dynamic.jl")

# static IR generative function
include("static_ir/static_ir.jl")
# static IR generative function
include("static_ir/static_ir.jl")

# DSLs for defining dynamic embedded and static IR generative functions
# 'Dynamic DSL' and 'Static DSL'
include("dsl/dsl.jl")
# DSLs for defining dynamic embedded and static IR generative functions
# 'Dynamic DSL' and 'Static DSL'
include("dsl/dsl.jl")

# injective function DSL (not currently documented)
include("injective.jl")
# injective function DSL (not currently documented)
include("injective.jl")

# selection DSL (not currently documented)
include("selection.jl")
# selection DSL (not currently documented)
include("selection.jl")

# inference and learning library
include("inference/inference.jl")
# inference and learning library
include("inference/inference.jl")

end # module Gen

0 comments on commit 1c7da58

Please sign in to comment.