diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..012e5565 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,57 @@ +name: CI +on: + pull_request: + push: + schedule: + - cron: '44 9 16 * *' # run the cron job one time per month +jobs: + test: + name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + version: + - '1.5' + - '1.6' + - '1.7' + os: + - ubuntu-latest + arch: + - x64 + steps: + - uses: actions/checkout@v2 + - uses: julia-actions/setup-julia@v1 + env: + PYTHON: "" + with: + version: ${{ matrix.version }} + arch: ${{ matrix.arch }} + - uses: actions/cache@v1 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - uses: julia-actions/julia-buildpkg@latest + - uses: julia-actions/julia-runtest@latest + docs: + name: Documentation + runs-on: ubuntu-latest + needs: test + steps: + - uses: actions/checkout@v2 + - uses: julia-actions/setup-julia@latest + with: + version: '1.7' + - name: Install dependencies + run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()' + - name: Build and deploy + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # If authenticating with GitHub Actions token + DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # If authenticating with SSH deploy key + run: julia --project=docs/ docs/make.jl \ No newline at end of file diff --git a/README.md b/README.md index a9f8e29d..dbd2a7d8 100644 --- a/README.md +++ b/README.md @@ -1,100 +1,21 @@ # GraphPPL -GraphPPL.jl is a probabilistic programming language focused on probabilistic graphical models. 
+| **Documentation** | **Build Status** | +|:-------------------------------------------------------------------------:|:--------------------------------:| +| [![][docs-stable-img]][docs-stable-url] [![][docs-dev-img]][docs-dev-url] | [![DOI][ci-img]][ci-url] | -# Inference Backend - -GraphPPL.jl does not export any Bayesian inference backend. It provides a simple DSL parser and model generation helpers. To run inference on -generated models user needs to have a Bayesian inference backend with GraphPPL.jl support (e.g. [ReactiveMP.jl](https://github.com/biaslab/ReactiveMP.jl)). - -# Examples - -## Coin flip +[docs-dev-img]: https://img.shields.io/badge/docs-dev-blue.svg +[docs-dev-url]: https://biaslab.github.io/GraphPPL.jl/dev -```julia -@model function coin_model() - a = datavar(Float64) - b = datavar(Float64) - y = datavar(Float64) - - θ ~ Beta(a, b) - y ~ Bernoulli(θ) - - return y, a, b, θ -end -``` +[docs-stable-img]: https://img.shields.io/badge/docs-stable-blue.svg +[docs-stable-url]: https://biaslab.github.io/GraphPPL.jl/stable -## State Space Model +[ci-img]: https://github.com/biaslab/GraphPPL.jl/actions/workflows/ci.yml/badge.svg?branch=master +[ci-url]: https://github.com/biaslab/GraphPPL.jl/actions -```julia -@model function ssm(n, θ, x0, Q::ConstVariable, P::ConstVariable) - - x = randomvar(n) - y = datavar(Vector{Float64}, n) - - x_prior ~ MvNormalMeanCovariance(mean(x0), cov(x0)) - - x_prev = x_prior - - A = constvar([ cos(θ) -sin(θ); sin(θ) cos(θ) ]) - - for i in 1:n - x[i] ~ MvNormalMeanCovariance(A * x_prev, Q) - y[i] ~ MvNormalMeanCovariance(x[i], P) - - x_prev = x[i] - end - - return x, y -end -``` +GraphPPL.jl is a probabilistic programming language focused on probabilistic graphical models. This repository is aimed at advanced users; please refer to the [ReactiveMP.jl](https://github.com/biaslab/ReactiveMP.jl) repository for more comprehensive and self-contained documentation and usage examples.
-## Hidden Markov Model - -```julia -@model [ default_factorisation = MeanField() ] function transition_model(n) - - A ~ MatrixDirichlet(ones(3, 3)) - B ~ MatrixDirichlet([ 10.0 1.0 1.0; 1.0 10.0 1.0; 1.0 1.0 10.0 ]) - - s_0 ~ Categorical(fill(1.0 / 3.0, 3)) - - s = randomvar(n) - x = datavar(Vector{Float64}, n) - - s_prev = s_0 - - for t in 1:n - s[t] ~ Transition(s_prev, A) where { q = q(out, in)q(a) } - x[t] ~ Transition(s[t], B) - s_prev = s[t] - end - - return s, x, A, B -end -``` - -## Gaussian Mixture Model +# Inference Backend -```julia -@model [ default_factorisation = MeanField() ] function gaussian_mixture_model(n) - - s ~ Beta(1.0, 1.0) - - m1 ~ NormalMeanVariance(-2.0, 1e3) - w1 ~ GammaShapeRate(0.01, 0.01) - - m2 ~ NormalMeanVariance(2.0, 1e3) - w2 ~ GammaShapeRate(0.01, 0.01) - - z = randomvar(n) - y = datavar(Float64, n) - - for i in 1:n - z[i] ~ Bernoulli(s) - y[i] ~ NormalMixture(z[i], (m1, m2), (w1, w2)) - end - - return s, m1, w1, m2, w2, z, y -end -``` \ No newline at end of file +GraphPPL.jl does not export any Bayesian inference backend. It provides a simple DSL parser, model generation, constraints specification and meta specification helpers. To run inference on +generated models user needs to have a Bayesian inference backend with GraphPPL.jl support (e.g. [ReactiveMP.jl](https://github.com/biaslab/ReactiveMP.jl)). \ No newline at end of file diff --git a/docs/Project.toml b/docs/Project.toml index 104fa4a6..1f877127 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,3 +1,6 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" GraphPPL = "b3f8163a-e979-4e85-b43e-1f63d8c8b42c" + +[compat] +Documenter = "0.27.7" diff --git a/docs/src/index.md b/docs/src/index.md index b4970b50..b58763c2 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -2,11 +2,21 @@ Welcome to the documentation for GraphPPL.jl. 
+Useful links: + +- [`ReactiveMP.jl` documentation](https://biaslab.github.io/ReactiveMP.jl/stable/) +- [User guide: Model specification](@ref user-guide-model-specification) +- [User guide: Constraints specification](@ref user-guide-constraints-specification) +- [User guide: Meta specification](@ref user-guide-meta-specification) + + + ## Table of Contents ```@contents Pages = [ - "user-guide.md" + "user-guide.md", + "transformation-steps.md" ] Depth = 2 ``` diff --git a/docs/src/transformation-steps.md b/docs/src/transformation-steps.md index 61035f0f..7d24a536 100644 --- a/docs/src/transformation-steps.md +++ b/docs/src/transformation-steps.md @@ -1,4 +1,4 @@ -# Transformation steps +# Model specification transformation steps for the ReactiveMP.jl backend ## Step 1: Normalizarion of `~` operator node arguments @@ -14,11 +14,7 @@ is translated to lhs ~ Node(..., var"#anonymous" ~ f(...), ...) ``` -The only one exception is reference expression of the form `x[f(i)]` which are left untouched. - -This step is recursive from top to bottom. - -This step forces model to create an anonymous node for any inner function call within `~` operator expression. In some cases backend can (and will) optimize this inner anonymous nodes into just function calls. E.g. following example won't create any additional nodes in the model +The only one exception is reference expression of the form `x[f(i)]` which are left untouched. This step forces model to create an anonymous node for any inner function call within `~` operator expression. In some cases ReactiveMP.jl backend can (and will) optimize this inner anonymous nodes into just function calls. E.g. following example won't create any additional nodes in the model ```julia precision = 1.0 @@ -32,56 +28,121 @@ noise ~ NormalMeanVariance(noise_mean, 1.0 / precision) # Since 1.0 and precisio Any expression of the form -``` -datavar(args...) +```julia +y = datavar(args...) # empty options here +# or +y = datavar(args...) 
where { options... } ``` is translated to ``` -datavar(var"#model", ensure_type(args[1]), args[2:end]...) +y = datavar(var"#model", options, :y, ensure_type(args[1]), args[2:end]...) ``` -where `var"#model"` references to an anonymous model variable, `ensure_type` function ensures that the first argument is a valid type object, rest of the arguments are left untouched. +where `var"#model"` refers to a hidden model variable, `ensure_type` function ensures that the first argument is a valid type object, rest of the arguments are left untouched. -This step is recursive from top to bottom. +The list of possible options: +- `subject`: specifies a subject that will be used to pass data variable related information, see more info in `Rocket.jl` documentation. +- `allow_missing`: boolean flag that controls whether it is possible to pass `missing` data or not ### `randomvar()` transformation Any expression of the form -``` -randomvar(args...) +```julia +x = randomvar(args...) # empty options here +# or +x = randomvar(args...) where { options... } ``` is translated to ``` -randomvar(var"#model", args...) +x = randomvar(var"#model", options, :x, args...) ``` -where `var"#model"` references to an anonymous model variable, arguments are left untouched. +where `var"#model"` references to an anonymous model variable, arguments are left untouched. -This step is recursive from top to bottom. +The list of possible options (see ReactiveMP.jl documentation for more info about these options): +- `pipeline` +- `prod_constraint` +- `prod_strategy` +- `marginal_form_constraint` +- `marginal_form_check_strategy` +- `messages_form_constraint` +- `messages_form_check_strategy` ### `constvar()` transformation Any expression of the form -``` -constvar(args...) +```julia +c = constvar(args...) # constvar's do not support any extra options flags ``` is translated to ``` -constvar(var"#model", args...) +c = constvar(var"#model", :c, args...)
``` where `var"#model"` references to an anonymous model variable, arguments are left untouched. -This step is recursive from top to bottom. +## Step 3: Tilde pass + +### 3.0 Node reference pass + +All expressions of the form + +```julia +variable ~ Node(args...) +``` + +are translated to + +```julia +node, variable ~ Node(args...) +``` + +### 3.1 Node options pass + +All expressions of the form + +```julia +node, variable ~ Node(args...) where { options... } +``` + +are translated to + +```julia +node, variable ~ Node(args...; options...) +``` + +### 3.2 Functional relations pass + +All expressions of the form + +```julia +node, variable ~ Node(args...; options...) +``` + +represent a valid functional dependency between `variable` and `args...`. There are 2 options for further modification of this expression: + +1. If `variable` has been created before with the help of `datavar()` or `randomvar()` functions the previous expression is translated to: + +```julia +node = make_node(var"#model", options, variable, args...) +``` + +2. If `variable` has not been created before the expression is translated to: + +```julia +node = make_node(var"#model", options, AutoVar(:variable), args...) +``` + +that internally creates a new variable in the model. -### `~` operator transformation +## Step 4: Final pass -WIP \ No newline at end of file +During the final pass `GraphPPL.jl` injects the `activate!` call to the `var"#model"` before any `return ...` call (and also at the very end) \ No newline at end of file diff --git a/docs/src/user-guide.md b/docs/src/user-guide.md index 77131d89..3fe75522 100644 --- a/docs/src/user-guide.md +++ b/docs/src/user-guide.md @@ -2,30 +2,34 @@ Probabilistic models incorporate elements of randomness to describe an event or phenomenon by using random variables and probability theory. A probabilistic model can be represented visually by using probabilistic graphical models (PGMs).
A factor graph is a type of PGM that is well suited to cast inference tasks in terms of graphical manipulations. -`GraphPPL.jl` is a Julia package presenting a model specification language for probabilistic models. +`GraphPPL.jl` is a Julia package presenting a model specification language for probabilistic models. -## Model specification +!!! note + `GraphPPL.jl` does not work without extra "backend" package. Currently the only one available "backend" package is `ReactiveMP.jl`. + +## [Model specification](@id user-guide-model-specification) -The `GraphPPL.jl` package exports the `generate_model_expression` function for model specification. -`ReactiveMP.jl` package later than imports `generate_model_expression` and reexports it as a `@model` macro for model specification. -This `@model` macro accepts two arguments: model options and the model specification itself in a form of regular Julia function. For example: +The `GraphPPL.jl` package exports the `@model` macro for model specification. This `@model` macro accepts two arguments: model options and the model specification itself in a form of regular Julia function. For example: ```julia -@model [ option1 = ..., option2 = ... ] function model_name(model_arguments...) +@model [ option1 = ..., option2 = ... ] function model_name(model_arguments...; model_keyword_arguments...) # model specification here return ... end ``` -Model options are optional and may be omitted: +Model options, `model_arguments` and `model_keyword_arguments` are optional and may be omitted: ```julia -@model function model_name(model_arguments...) +@model function model_name() # model specification here return ... end ``` +!!! note + `options`, `constraints` and `meta` keyword arguments are reserved and cannot be used in `model_keyword_arguments`. + The `@model` macro returns a regular Julia function (in this example `model_name()`) which can be executed as usual. 
It returns a reference to a model object itself and a tuple of a user specified return variables, e.g: ```julia @@ -40,7 +44,7 @@ end model, (x, y) = my_model(model_arguments...) ``` -It is also important to note that any model should return something, such as variables or nodes. If a model doesn't return anything then an error will be raised during runtime. +It is not necessary to return anything from the model, in that case `GraphPPL.jl` will automatically inject `return nothing` to the end of the model function. ## A full example before diving in @@ -94,6 +98,9 @@ Additionally you can specify an extra `::ConstVariable` type for some of the mod end ``` +!!! note + `::ConstVariable` annotation does not play role in Julia's multiple dispatch. `GraphPPL.jl` removes this annotation and replaces it with `::Any`. + ### Data variables It is important to have a mechanism to pass data values to the model. You can create data inputs with `datavar()` function. As a first argument it accepts a type specification and optional dimensionality (as additional arguments or as a tuple). @@ -107,6 +114,8 @@ y = datavar(Float64, n, m) # Returns a matrix of `y_i_j` data input objects with y = datavar(Float64, (n, m)) # It is also possible to use a tuple for dimensionality ``` +`datavar()` call supports `where { options... }` block for extra options specification. Read `ReactiveMP.jl` documentation to know more about possible creation options. + ### Random variables There are several ways to create random variables. The first one is an explicit call to `randomvar()` function. By default it doesn't accept any argument, creates a single random variable in the model and returns it. It is also possible to pass dimensionality arguments to `randomvar()` function in the same way as for the `datavar()` function. 
@@ -120,6 +129,8 @@ x = randomvar(n, m) # Returns a matrix of random variables with size `(n, m)` x = randomvar((n, m)) # It is also possible to use a tuple for dimensionality ``` +In the same way as the `datavar()` function, `randomvar()` supports a `where { options... }` block for extra options. Read `ReactiveMP.jl` documentation to know more about possible creation options. + The second way to create a random variable is to create a node with the `~` operator. If the random variable has not yet been created before this call, it will be created automatically during the creation of the node. Read more about the `~` operator below. ## Node creation @@ -176,7 +187,7 @@ Example: y ~ NormalMeanVariance(y_mean, y_var) where { q = q(y_mean)q(y_var)q(y) } # mean-field factorisation over q ``` -A list of all available options is presented below: +A list of some of the available options specific to `ReactiveMP.jl` is presented below. For the full list we refer the reader to the `ReactiveMP.jl` documentation. #### Factorisation constraint option @@ -243,10 +254,166 @@ Is is possible to pass any extra metadata to a factor node with the `meta` optio z ~ f(x, y) where { meta = ... } ``` -#### Portal option +For more information about possible node creation options we refer the reader to the `ReactiveMP.jl` documentation. + +## [Constraints specification](@id user-guide-constraints-specification) + +`GraphPPL.jl` exports `@constraints` macro for the extra constraints specification that can be used during the inference step in `ReactiveMP.jl` package. + +### General syntax + +`@constraints` macro accepts both regular Julia functions and just simple blocks. In the first case it returns a function that returns constraints and in the second case it returns constraints directly.
+ +```julia +myconstraints = @constraints begin + q(x) :: PointMass + q(x, y) = q(x)q(y) +end +``` + +or + +```julia +@constraints function make_constraints(flag) + q(x) :: PointMass + if flag + q(x, y) = q(x)q(y) + end +end + +myconstraints = make_constraints(true) +``` + +### Marginal and messages form constraints -To assign a factor node's local portal for all outbound messages the user may use a `portal` option: +To specify marginal or messages form constraints `@constraints` macro uses `::` operator (in a similar way to how Julia uses it for type specification) + +The following constraint + +```julia +@constraints begin + q(x) :: PointMass +end +``` + +indicates that the resulting marginal of the variable (or array of variables) named `x` must be approximated with a `PointMass` object. To set messages form constraint `@constraints` macro uses `μ(...)` instead of `q(...)`: + +```julia +@constraints begin + q(x) :: PointMass + μ(x) :: SampleList + # it is possible to assign different form constraints on the same variable + # both for the marginal and for the messages +end +``` + +`@constraints` macro understands "stacked" form constraints. For example the following form constraint + +```julia +@constraints begin + q(x) :: SampleList(1000, LeftProposal()) :: PointMass +end +``` + +indicates that the resulting posterior may first be approximated with a `SampleList` and in addition the result of this approximation should be approximated as a `PointMass`. +For more information about form constraints we refer the reader to the `ReactiveMP.jl` documentation. + + +### Factorisation constraints on posterior distribution `q()` + +`@model` macro specifies generative model `p(s, y)` where `s` is a set of random variables and `y` is a set of observations. In a nutshell the goal of probabilistic programming is to find `p(s|y)`. `p(s|y)` during the inference procedure can be approximated with another `q(s)` using e.g. KL divergence.
By default there are no extra factorisation constraints on `q(s)` and the result is `q(s) = p(s|y)`. However, inference may not be tractable for every model without extra factorisation constraints. To circumvent this, `GraphPPL.jl` and `ReactiveMP.jl` accept optional factorisation constraints specification syntax: + +For example: ```julia -y ~ NormalMeanVariance(m, v) where { portal = LoggerPortal() } # Log all outbound messages with `LoggerPortal` portal +@constraints begin + q(x, y) = q(x)q(y) +end ``` + +specifies a so-called mean-field assumption on variables `x` and `y` in the model. Furthermore, if `x` is an array of variables in our model we may induce extra mean-field assumption on `x` in the following way. + +```julia +@constraints begin + q(x, y) = q(x)q(y) + q(x) = q(x[begin])..q(x[end]) +end +``` + +These constraints specify a mean-field assumption between variables `x` and `y` (either single variable or collection of variables) and additionally specify mean-field assumption on variables `x_i`. + +!!! note + `@constraints` macro does not support matrix-based collections of variables. E.g. it is not possible to write `q(x[begin, begin])..q(x[end, end])` + +It is possible to write more complex factorisation constraints, for example: + +```julia +@constraints begin + q(x, y) = q(x[begin], y[begin])..q(x[end], y[end]) +end +``` + +Specifies a mean-field assumption between collection of variables named `x` and `y` only for variables with different indices. Another example is + +```julia +@constraints function make_constraints(k) + q(x) = q(x[begin:k])q(x[k+1:end]) +end +``` + +In this example we specify a mean-field assumption between a set of variables `x[begin:k]` and `x[k+1:end]`. + +To create a model with extra constraints user may use optional `constraints` keyword argument for the model function: + +```julia +@model function my_model(arguments...) + ... +end + +constraints = @constraints begin + ...
+end + +model, (x, y) = my_model(arguments..., constraints = constraints) +``` + +For more information about factorisation constraints we refer the reader to the `ReactiveMP.jl` documentation. + +## [Meta specification](@id user-guide-meta-specification) + +Some nodes in `ReactiveMP.jl` accept optional meta structure that may be used to change or customise the inference procedure. As an example `GCV` node accepts the approximation method that will be used to approximate non-conjugate relationships between variables in this node. `GraphPPL.jl` exports `@meta` macro to specify node-specific meta information. For example: + +```julia +meta = @meta begin + GCV(x, k, w) <- GCVMetadata(GaussHermiteCubature(20)) +end +``` + +indicates that for every `GCV` node in the model that has `x`, `k` and `w` as connected variables the `GCVMetadata(GaussHermiteCubature(20))` meta object should be used. + +`@meta` accepts function expression in the same way as `@constraints` macro, e.g: + + +```julia +@meta function make_meta(n) + GCV(x, k, w) <- GCVMetadata(GaussHermiteCubature(n)) +end + +meta = make_meta(20) +``` + +To create a model with extra meta options user may use optional `meta` keyword argument for the model function: + +```julia +@model function my_model(arguments...) + ... +end + +meta = @meta begin + ... +end + +model, (x, y) = my_model(arguments..., meta = meta) +``` + +For more information about the meta specification we refer the reader to the `ReactiveMP.jl` documentation. diff --git a/src/backends/reactivemp.jl b/src/backends/reactivemp.jl index 0d5c7c00..c61b7b3d 100644 --- a/src/backends/reactivemp.jl +++ b/src/backends/reactivemp.jl @@ -8,16 +8,16 @@ function write_argument_guard(::ReactiveMPBackend, argument::Symbol) return :(@assert !($argument isa ReactiveMP.AbstractVariable) "It is not allowed to pass AbstractVariable objects to a model definition arguments.
ConstVariables should be passed as their raw values.") end -function write_randomvar_expression(::ReactiveMPBackend, model, varexp, arguments, kwarguments) - return :($varexp = ReactiveMP.randomvar($model, $(GraphPPL.fquote(varexp)), $(arguments...); $(kwarguments...))) +function write_randomvar_expression(::ReactiveMPBackend, model, varexp, options, arguments) + return :($varexp = ReactiveMP.randomvar($model, $options, $(GraphPPL.fquote(varexp)), $(arguments...))) end -function write_datavar_expression(::ReactiveMPBackend, model, varexpr, type, arguments, kwarguments) - return :($varexpr = ReactiveMP.datavar($model, $(GraphPPL.fquote(varexpr)), ReactiveMP.PointMass{ GraphPPL.ensure_type($(type)) }, $(arguments...); $(kwarguments...))) +function write_datavar_expression(::ReactiveMPBackend, model, varexpr, options, type, arguments) + return :($varexpr = ReactiveMP.datavar($model, $options, $(GraphPPL.fquote(varexpr)), ReactiveMP.PointMass{ GraphPPL.ensure_type($(type)) }, $(arguments...))) end -function write_constvar_expression(::ReactiveMPBackend, model, varexpr, arguments, kwarguments) - return :($varexpr = ReactiveMP.constvar($model, $(GraphPPL.fquote(varexpr)), $(arguments...); $(kwarguments...))) +function write_constvar_expression(::ReactiveMPBackend, model, varexpr, arguments) + return :($varexpr = ReactiveMP.constvar($model, $(GraphPPL.fquote(varexpr)), $(arguments...))) end function write_as_variable(::ReactiveMPBackend, model, varexpr) @@ -25,42 +25,57 @@ function write_as_variable(::ReactiveMPBackend, model, varexpr) end function write_make_node_expression(::ReactiveMPBackend, model, fform, variables, options, nodeexpr, varexpr) - return :($nodeexpr = ReactiveMP.make_node($model, $fform, $varexpr, $(variables...); $(options...))) + return :($nodeexpr = ReactiveMP.make_node($model, $options, $fform, $varexpr, $(variables...))) end function write_autovar_make_node_expression(::ReactiveMPBackend, model, fform, variables, options, nodeexpr, varexpr, 
autovarid) - return :(($nodeexpr, $varexpr) = ReactiveMP.make_node($model, $fform, ReactiveMP.AutoVar($(GraphPPL.fquote(autovarid))), $(variables...); $(options...))) + return :(($nodeexpr, $varexpr) = ReactiveMP.make_node($model, $options, $fform, ReactiveMP.AutoVar($(GraphPPL.fquote(autovarid))), $(variables...))) end -function write_node_options(::ReactiveMPBackend, fform, variables, options) - return map(options) do option +function write_node_options(::ReactiveMPBackend, model, fform, variables, options) + is_factorisation_option_present = false + is_meta_option_present = false + is_pipeline_option_present = false + factorisation_option = :(nothing) + meta_option = :(nothing) + pipeline_option = :(nothing) + + foreach(options) do option # Factorisation constraint option if @capture(option, q = fconstraint_) - return write_fconstraint_option(fform, variables, fconstraint) + !is_factorisation_option_present || error("Factorisation constraint option $(option) for $(fform) has been redefined.") + is_factorisation_option_present = true + factorisation_option = write_fconstraint_option(fform, variables, fconstraint) elseif @capture(option, meta = fmeta_) - return write_meta_option(fform, fmeta) + !is_meta_option_present || error("Meta specification option $(option) for $(fform) has been redefined.") + is_meta_option_present = true + meta_option = write_meta_option(fform, fmeta) elseif @capture(option, pipeline = fpipeline_) - return write_pipeline_option(fform, fpipeline) + !is_pipeline_option_present || error("Pipeline specification option $(option) for $(fform) has been redefined.") + is_pipeline_option_present = true + pipeline_option = write_pipeline_option(fform, fpipeline) + else + error("Unknown option '$option' for '$fform' node") end - - error("Unknown option '$option' for '$fform' node") end + + return :(ReactiveMP.FactorNodeCreationOptions($factorisation_option, $meta_option, $pipeline_option)) end # Meta helper functions function write_meta_option(fform, 
fmeta) - return :(meta = $fmeta) + return :($fmeta) end # Pipeline helper functions function write_pipeline_option(fform, fpipeline) if @capture(fpipeline, +(stages__)) - return :(pipeline = +($(map(stage -> write_pipeline_stage(fform, stage), stages)...))) + return :(+($(map(stage -> write_pipeline_stage(fform, stage), stages)...))) else - return :(pipeline = $(write_pipeline_stage(fform, fpipeline))) + return :($(write_pipeline_stage(fform, fpipeline))) end end @@ -84,7 +99,7 @@ function write_pipeline_stage(fform, stage) indices = Expr(:tuple, map(s -> :(ReactiveMP.interface_get_index(Val{ $(GraphPPL.fquote(fform)) }, Val{ $(GraphPPL.fquote(first(s))) })), specs)...) initials = Expr(:tuple, map(s -> :($(last(s))), specs)...) - return :(RequireInboundFunctionalDependencies($indices, $initials)) + return :(ReactiveMP.RequireInboundFunctionalDependencies($indices, $initials)) else return stage end @@ -131,11 +146,11 @@ function write_fconstraint_option(form, variables, fconstraint) factorisation = Expr(:tuple, map(f -> Expr(:tuple, f...), indexed)...) errorstr = """Invalid factorisation constraint: ($fconstraint). Arguments are not unique, check node's interface names and model specification variable names.""" - return :(factorisation = GraphPPL.check_uniqueness($factorisation) ? GraphPPL.sorted_factorisation($factorisation) : error($errorstr)) + return :(GraphPPL.check_uniqueness($factorisation) ? 
GraphPPL.sorted_factorisation($factorisation) : error($errorstr)) elseif @capture(fconstraint, MeanField()) - return :(factorisation = MeanField()) + return :(ReactiveMP.MeanField()) elseif @capture(fconstraint, FullFactorisation()) - return :(factorisation = FullFactorisation()) + return :(ReactiveMP.FullFactorisation()) else error("Invalid factorisation constraint: $fconstraint") end @@ -144,24 +159,104 @@ end ## function write_randomvar_options(::ReactiveMPBackend, variable, options) - return map(options) do option - @capture(option, name_Symbol = value_) || error("Invalid variable options specification: $option. Should be in a form of 'name = value'") - return option + is_pipeline_option_present = false + is_prod_constraint_option_present = false + is_prod_strategy_option_present = false + is_marginal_form_constraint_option_present = false + is_marginal_form_check_strategy_option_present = false + is_messages_form_constraint_option_present = false + is_messages_form_check_strategy_option_present = false + + pipeline_option = :(nothing) + prod_constraint_option = :(nothing) + prod_strategy_option = :(nothing) + marginal_form_constraint_option = :(nothing) + marginal_form_check_strategy_option = :(nothing) + messages_form_constraint_option = :(nothing) + messages_form_check_strategy_option = :(nothing) + + foreach(options) do option + if @capture(option, pipeline = value_) + !is_pipeline_option_present || error("`pipeline` option $(option) for random variable $(variable) has been redefined.") + is_pipeline_option_present = true + pipeline_option = value + elseif @capture(option, $(:(prod_constraint)) = value_) + !is_prod_constraint_option_present || error("`prod_constraint` option $(option) for random variable $(variable) has been redefined.") + is_prod_constraint_option_present = true + prod_constraint_option = value + elseif @capture(option, $(:(prod_strategy)) = value_) + !is_prod_strategy_option_present || error("`prod_strategy` option $(option) for random 
variable $(variable) has been redefined.") + is_prod_strategy_option_present = true + prod_strategy_option = value + elseif @capture(option, $(:(marginal_form_constraint)) = value_) + !is_marginal_form_constraint_option_present || error("`marginal_form_constraint` option $(option) for random variable $(variable) has been redefined.") + is_marginal_form_constraint_option_present = true + marginal_form_constraint_option = value + elseif @capture(option, $(:(form_constraint)) = value_) # backward compatibility + @warn "`form_constraint` option is deprecated. Use `marginal_form_constraint` option for variable $(variable) instead." + !is_marginal_form_constraint_option_present || error("`marginal_form_constraint` option $(option) for random variable $(variable) has been redefined.") + is_marginal_form_constraint_option_present = true + marginal_form_constraint_option = value + elseif @capture(option, $(:(marginal_form_check_strategy)) = value_) + !is_marginal_form_check_strategy_option_present || error("`marginal_form_check_strategy` option $(option) for random variable $(variable) has been redefined.") + is_marginal_form_check_strategy_option_present = true + marginal_form_check_strategy_option = value + elseif @capture(option, $(:(messages_form_constraint)) = value_) + !is_messages_form_constraint_option_present || error("`messages_form_constraint` option $(option) for random variable $(variable) has been redefined.") + is_messages_form_constraint_option_present = true + messages_form_constraint_option = value + elseif @capture(option, $(:(messages_form_check_strategy)) = value_) + !is_messages_form_check_strategy_option_present || error("`messages_form_check_strategy` option $(option) for random variable $(variable) has been redefined.") + is_messages_form_check_strategy_option_present = true + messages_form_check_strategy_option = value + else + error("Unknown option '$option' for randomv variable '$variable'.") + end end + + return 
:(ReactiveMP.RandomVariableCreationOptions( + $pipeline_option, + nothing, # it does not make a lot of sense to override `proxy_variables` option + $prod_constraint_option, + $prod_strategy_option, + $marginal_form_constraint_option, + $marginal_form_check_strategy_option, + $messages_form_constraint_option, + $messages_form_check_strategy_option + )) end -function write_constvar_options(::ReactiveMPBackend, variable, options) - return map(options) do option - @capture(option, name_Symbol = value_) || error("Invalid variable options specification: $option. Should be in a form of 'name = value'") - return option +function write_datavar_options(::ReactiveMPBackend, variable, type, options) + is_subject_option_present = false + is_allow_missing_option_present = false + + # default options + subject_option = :(nothing) + allow_missing_option = :(Val(false)) + + foreach(options) do option + if @capture(option, subject = value_) + !is_subject_option_present || error("`subject` option $(option) for data variable $(variable) has been redefined.") + is_subject_option_present = true + subject_option = value + elseif @capture(option, $(:(allow_missing)) = value_) + !is_allow_missing_option_present || error("`allow_missing` option $(option) for data variable $(variable) has been redefined.") + is_allow_missing_option_present = true + allow_missing_option = :(Val($value)) + else + error("Unknown option '$option' for data variable '$variable'.") + end end + + return :(ReactiveMP.DataVariableCreationOptions(ReactiveMP.PointMass{ GraphPPL.ensure_type($type) }, $subject_option, $allow_missing_option)) end -function write_datavar_options(::ReactiveMPBackend, variable, options) - return map(options) do option - @capture(option, name_Symbol = value_) || error("Invalid variable options specification: $option. 
Should be in a form of 'name = value'") - return option - end +function write_default_model_constraints(::ReactiveMPBackend) + return :(ReactiveMP.UnspecifiedConstraints()) +end + +function write_default_model_meta(::ReactiveMPBackend) + return :(ReactiveMP.UnspecifiedMeta()) end # Constraints specification language @@ -204,8 +299,12 @@ function write_factorisation_functional_index(::ReactiveMPBackend, repr, fn) return :(ReactiveMP.FunctionalIndex{$(QuoteNode(repr))}($fn)) end -function write_form_constraint_specification(::ReactiveMPBackend, T, args, kwargs) - return :(ReactiveMP.FormConstraintsSpecification($T, $args, $kwargs)) +function write_form_constraint_specification_entry(::ReactiveMPBackend, T, args, kwargs) + return :(ReactiveMP.make_form_constraint($T, $args...; $kwargs...)) +end + +function write_form_constraint_specification(::ReactiveMPBackend, specification) + return :(ReactiveMP.FormConstraintSpecification($specification)) end ## Meta specification language diff --git a/src/constraints.jl b/src/constraints.jl index 9c109a32..b98ae100 100644 --- a/src/constraints.jl +++ b/src/constraints.jl @@ -46,7 +46,12 @@ function write_factorisation_splitted_range end function write_factorisation_functional_index end """ - write_form_constraint_specification(backend, T, args, kwargs) + write_form_constraint_specification_entry(backend, T, args, kwargs) +""" +function write_form_constraint_specification_entry end + +""" + write_form_constraint_specification(backend, specification) """ function write_form_constraint_specification end @@ -115,7 +120,7 @@ function parse_form_constraint(backend, expr) end end - return write_form_constraint_specification(backend, T, args, kwargs) + return write_form_constraint_specification_entry(backend, T, args, kwargs) end ## @@ -156,20 +161,24 @@ function generate_constraints_expression(backend, constraints_specification) cs_body = prewalk(cs_body) do expression if iscall(expression, :(::)) if @capture(expression.args[2], 
q(formsym_Symbol)) - specs = map((e) -> parse_form_constraint(backend, e), view(expression.args, 3:lastindex(expression.args))) + specs = map((e) -> parse_form_constraint(backend, e), view(expression.args, 3:lastindex(expression.args))) + form = write_form_constraint_specification(backend, :(+($(specs... )))) + errstr = "Marginal form constraint q($(formsym)) has been redefined." return quote if haskey($marginals_form_constraints_symbol, $(QuoteNode(formsym))) - error("Marginal form constraint q($(formsym)) has been redefined.") + error($errstr) end - $marginals_form_constraints_symbol = (; $marginals_form_constraints_symbol..., $formsym = ($(specs... ),)) + $marginals_form_constraints_symbol = (; $marginals_form_constraints_symbol..., $formsym = $form) end elseif @capture(expression.args[2], μ(formsym_Symbol)) - specs = map((e) -> parse_form_constraint(backend, e), view(expression.args, 3:lastindex(expression.args))) + specs = map((e) -> parse_form_constraint(backend, e), view(expression.args, 3:lastindex(expression.args))) + form = write_form_constraint_specification(backend, :(+($(specs... )))) + errstr = "Messages form constraint μ($(formsym)) has been redefined." return quote if haskey($messages_form_constraints_symbol, $(QuoteNode(formsym))) - error("Messages form constraint μ($(formsym)) has been redefined.") + error($errstr) end - $messages_form_constraints_symbol = (; $messages_form_constraints_symbol..., $formsym = ($(specs... ),)) + $messages_form_constraints_symbol = (; $messages_form_constraints_symbol..., $formsym = $form) end else error("Invalid form factorisation constraint. 
$(expression.args[2]) has to be in the form of q(varname) for marginal form constraint or μ(varname) for messages form constraint.") diff --git a/src/model.jl b/src/model.jl index 1bf11704..c7c012cb 100644 --- a/src/model.jl +++ b/src/model.jl @@ -101,17 +101,17 @@ argument_write_default_value(arg, default) = Expr(:kw, arg, default) function write_argument_guard end """ - write_randomvar_expression(backend, model, varexpr, arguments, kwarguments) + write_randomvar_expression(backend, model, varexpr, options, arguments) """ function write_randomvar_expression end """ - write_datavar_expression(backend, model, varexpr, type, arguments, kwarguments) + write_datavar_expression(backend, model, varexpr, options, type, arguments) """ function write_datavar_expression end """ - write_constvar_expression(backend, model, varexpr, arguments, kwarguments) + write_constvar_expression(backend, model, varexpr, arguments) """ function write_constvar_expression end @@ -131,7 +131,7 @@ function write_make_node_expression end function write_autovar_make_node_expression end """ - write_node_options(backend, fform, variables, options) + write_node_options(backend, model, fform, variables, options) """ function write_node_options end @@ -141,14 +141,19 @@ function write_node_options end function write_randomvar_options end """ - write_constvar_options(backend, variable, options) + write_datavar_options(backend, variable, type, options) """ -function write_constvar_options end +function write_datavar_options end """ - write_datavar_options(backend, variable, options) + write_default_model_constraints(backend) """ -function write_datavar_options end +function write_default_model_constraints end + +""" + write_default_model_meta(backend) +""" +function write_default_model_meta end macro model(model_specification) return esc(:(@model [] $model_specification)) @@ -167,7 +172,10 @@ function generate_model_expression(backend, model_options, model_specification) return (name, value) end - 
ms_options = :(NamedTuple{ ($(tuple(map(first, ms_options)...))) }((($(tuple(map(last, ms_options)...)...)),))) + ms_constraints = write_default_model_constraints(backend) + ms_meta = write_default_model_meta(backend) + ms_options = :(NamedTuple{ ($(tuple(map(first, ms_options)...))) }((($(tuple(map(last, ms_options)...)...)),))) + @capture(model_specification, (function ms_name_(ms_args__; ms_kwargs__) ms_body_ end) | (function ms_name_(ms_args__) ms_body_ end)) || error("Model specification language requires full function definition") @@ -206,7 +214,7 @@ function generate_model_expression(backend, model_options, model_specification) end ms_args_const_init_block = map(ms_args_const_ids) do ms_arg_const_id - return write_constvar_expression(backend, model, first(ms_arg_const_id), [ last(ms_arg_const_id) ], []) + return write_constvar_expression(backend, model, first(ms_arg_const_id), [ last(ms_arg_const_id) ]) end # Step 0: Check that all inputs are not AbstractVariables @@ -231,17 +239,17 @@ function generate_model_expression(backend, model_options, model_specification) varexpr = @capture(varexpr, (nodeid_, varid_)) ? 
varexpr : :(($(gensym(:nnode)), $varexpr)) return :($varexpr ~ $(fform)($((normalize_tilde_arguments(arguments))...); $(options...))) elseif @capture(expression, varexpr_ = randomvar(arguments__) where { options__ }) - return :($varexpr = randomvar($(arguments...); $(write_randomvar_options(backend, varexpr, options)...))) + return :($varexpr = randomvar($(arguments...); $(options...))) elseif @capture(expression, varexpr_ = datavar(arguments__) where { options__ }) - return :($varexpr = datavar($(arguments...); $(write_datavar_options(backend, varexpr, options)...))) + return :($varexpr = datavar($(arguments...); $(options...))) elseif @capture(expression, varexpr_ = constvar(arguments__) where { options__ }) - return :($varexpr = constvar($(arguments...); $(write_constvar_options(backend, varexpr, options)...))) + return error("Error in the expression $(expression). `constvar()` call does not support `where { }` syntax.") elseif @capture(expression, varexpr_ = randomvar(arguments__)) return :($varexpr = randomvar($(arguments...); )) elseif @capture(expression, varexpr_ = datavar(arguments__)) return :($varexpr = datavar($(arguments...); )) elseif @capture(expression, varexpr_ = constvar(arguments__)) - return :($varexpr = constvar($(arguments...); )) + return :($varexpr = constvar($(arguments...))) else return expression end @@ -264,28 +272,31 @@ function generate_model_expression(backend, model_options, model_specification) # Step 2: Main pass ms_body = postwalk(ms_body) do expression # Step 2.1 Convert datavar calls - if @capture(expression, varexpr_ = datavar(arguments__; kwarguments__)) + if @capture(expression, varexpr_ = datavar(arguments__; options__)) @assert varexpr ∉ varids "Invalid model specification: '$varexpr' id is duplicated" - @assert length(arguments) >= 1 "datavar() call requires type specification as a first argument" + @assert length(arguments) >= 1 "Invalid datavar() creation. datavar(::Type{T}, [ dims... 
]) requires type specification as a first argument, but the expression `$(expression)` has no type argument." push!(varids, varexpr) type_argument = arguments[1] tail_arguments = arguments[2:end] + dvoptions = write_datavar_options(backend, varexpr, type_argument, options) - return write_datavar_expression(backend, model, varexpr, type_argument, tail_arguments, kwarguments) + return write_datavar_expression(backend, model, varexpr, dvoptions, type_argument, tail_arguments) # Step 2.2 Convert randomvar calls - elseif @capture(expression, varexpr_ = randomvar(arguments__; kwarguments__)) + elseif @capture(expression, varexpr_ = randomvar(arguments__; options__)) @assert varexpr ∉ varids "Invalid model specification: '$varexpr' id is duplicated" push!(varids, varexpr) - return write_randomvar_expression(backend, model, varexpr, arguments, kwarguments) + rvoptions = write_randomvar_options(backend, varexpr, options) + + return write_randomvar_expression(backend, model, varexpr, rvoptions, arguments) # Step 2.3 Conver constvar calls - elseif @capture(expression, varexpr_ = constvar(arguments__; kwarguments__)) + elseif @capture(expression, varexpr_ = constvar(arguments__)) @assert varexpr ∉ varids "Invalid model specification: '$varexpr' id is duplicated" push!(varids, varexpr) - return write_constvar_expression(backend, model, varexpr, arguments, kwarguments) + return write_constvar_expression(backend, model, varexpr, arguments) # Step 2.2 Convert tilde expressions elseif @capture(expression, (nodeexpr_, varexpr_) ~ fform_(arguments__; kwarguments__)) # println(expression) @@ -296,7 +307,7 @@ function generate_model_expression(backend, model_options, model_specification) end variables = map((argexpr) -> write_as_variable(backend, model, argexpr), arguments) - options = write_node_options(backend, fform, [ varexpr, arguments... ], kwarguments) + options = write_node_options(backend, model, fform, [ varexpr, arguments... 
], kwarguments) if short_id ∈ varids return write_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr) @@ -313,19 +324,23 @@ function generate_model_expression(backend, model_options, model_specification) final_pass_exceptions = (x) -> @capture(x, (some_ -> body_) | (function some_(args__) body_ end) | (some_(args__) = body_)) final_pass_target = (x) -> @capture(x, return ret_) + ms_body = quote + $ms_body + return nothing + end + ms_body = conditioned_walk(final_pass_exceptions, final_pass_target, ms_body) do expression @capture(expression, return ret_) ? quote activate!($model); return $model, ($ret) end : expression end res = quote - function $ms_name($(ms_args...); $(ms_kwargs...), options = $(ms_options)) + function $ms_name($(ms_args...); $(ms_kwargs...), constraints = $(ms_constraints), meta = $(ms_meta), options = $(ms_options)) $(ms_args_checks...) options = merge($(ms_options), options) - $model = Model(options) + $model = Model(constraints, meta, options) $(ms_args_const_init_block...) $ms_body - error("'return' statement is missing") end end diff --git a/test/constraints.jl b/test/constraints.jl deleted file mode 100644 index 04495fed..00000000 --- a/test/constraints.jl +++ /dev/null @@ -1,6 +0,0 @@ -module ConstraintsTests - -using Test -using GraphPPL - -end diff --git a/test/runtests.jl b/test/runtests.jl index 10235e05..d69e76b5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,8 @@ doctest(GraphPPL) @test length(Test.detect_ambiguities(GraphPPL)) == 0 end + include("utils.jl") + end end \ No newline at end of file