diff --git a/.gitignore b/.gitignore index 224059d0..8f6b5414 100644 --- a/.gitignore +++ b/.gitignore @@ -46,3 +46,4 @@ Coverage.ipynb **/.DS_Store examples/*Compiled +statprof diff --git a/Project.toml b/Project.toml index 2088272a..74ca4c8f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "GraphPPL" uuid = "b3f8163a-e979-4e85-b43e-1f63d8c8b42c" authors = ["Dmitry Bagaev "] -version = "1.0.5" +version = "1.1.0" [deps] MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" diff --git a/docs/make.jl b/docs/make.jl index 0a2e7383..2a4ddbcd 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -6,7 +6,8 @@ makedocs( sitename = "GraphPPL.jl", pages = [ "Home" => "index.md", - "User guide" => "user-guide.md" + "User guide" => "user-guide.md", + "Utils" => "utils.md" ], format = Documenter.HTML( prettyurls = get(ENV, "CI", nothing) == "true" diff --git a/docs/src/utils.md b/docs/src/utils.md new file mode 100644 index 00000000..dcdd6595 --- /dev/null +++ b/docs/src/utils.md @@ -0,0 +1,7 @@ +# Utils + +```@docs +GraphPPL.ishead +GraphPPL.isblock +GraphPPL.iscall +``` \ No newline at end of file diff --git a/src/GraphPPL.jl b/src/GraphPPL.jl index 36eea4a0..37ad0990 100644 --- a/src/GraphPPL.jl +++ b/src/GraphPPL.jl @@ -1,342 +1,14 @@ module GraphPPL -export @model - -import MacroTools -import MacroTools: @capture, postwalk, prewalk, walk - -function conditioned_walk(f, condition_skip, condition_apply, x) - walk(x, x -> condition_skip(x) ? x : condition_apply(x) ? f(x) : conditioned_walk(f, condition_skip, condition_apply, x), identity) -end - -""" - fquote(expr) - -This function forces `Expr` or `Symbol` to be quoted. -""" -fquote(expr::Symbol) = Expr(:quote, expr) -fquote(expr::Int) = expr -fquote(expr::Expr) = expr - -""" - ensure_type -""" -ensure_type(x::Type) = x -ensure_type(x) = error("Valid type object was expected but '$x' has been found") - -is_kwargs_expression(x) = false -is_kwargs_expression(x::Expr) = x.head === :parameters - -""" - parse_varexpr(varexpr) - -This function parses variable id and returns a tuple of 3 different representations of the same variable -1. Original expression -2. Short variable identificator (used in variables lookup table) -3. Full variable identificator (used in model as a variable id) -""" -function parse_varexpr(varexpr::Symbol) - varexpr = varexpr - short_id = varexpr - full_id = varexpr - return varexpr, short_id, full_id -end - -function parse_varexpr(varexpr::Expr) - - # TODO: It might be handy to have this feature in the future for e.g. interacting with UnPack.jl package - # TODO: For now however we fallback to a more informative error message since it is not obvious how to parse such expressions yet - @capture(varexpr, (tupled_ids__, )) && - error("Multiple variable declarations, definitions and assigments are forbidden within @model macro. Try to split $(varexpr) into several independent statements.") - - @capture(varexpr, id_[idx__]) || - error("Variable identificator can be in form of a single symbol (x ~ ...) or indexing expression (x[i] ~ ...)") - - varexpr = varexpr - short_id = id - full_id = Expr(:call, :Symbol, fquote(id), Expr(:quote, :_), Expr(:quote, Symbol(join(idx, :_)))) - - return varexpr, short_id, full_id -end - -""" - normalize_tilde_arguments(args) - -This function 'normalizes' every argument of a tilde expression making every inner function call to be a tilde expression as well. -It forces MSL to create anonymous node for any non-linear variable transformation or deterministic relationships. MSL does not check (and cannot in general) -if some inner function call leads to a constant expression or not (e.g. `Normal(0.0, sqrt(10.0))`). Backend API should decide whenever to create additional anonymous nodes -for constant non-linear transformation expressions or not by analyzing input arguments. -""" -function normalize_tilde_arguments(args) - return map(args) do arg - if @capture(arg, id_[idx_]) - return :($(__normalize_arg(id))[$idx]) - else - return __normalize_arg(arg) - end - end -end - -function __normalize_arg(arg) - if @capture(arg, (f_(v__) where { options__ }) | (f_(v__))) - if f === :(|>) - @assert length(v) === 2 "Unsupported pipe syntax in model specification: $(arg)" - f = v[2] - v = [ v[1] ] - end - nvarexpr = gensym(:nvar) - nnodeexpr = gensym(:nnode) - options = options !== nothing ? options : [] - v = normalize_tilde_arguments(v) - return :(($nnodeexpr, $nvarexpr) ~ $f($(v...); $(options...)); $nvarexpr) - else - return arg - end -end - -argument_write_default_value(arg, default::Nothing) = arg -argument_write_default_value(arg, default) = Expr(:kw, arg, default) - - -""" - write_argument_guard(backend, argument) -""" -function write_argument_guard end - -""" - write_randomvar_expression(backend, model, varexpr, arguments, kwarguments) -""" -function write_randomvar_expression end - -""" - write_datavar_expression(backend, model, varexpr, type, arguments, kwarguments) -""" -function write_datavar_expression end - -""" - write_constvar_expression(backend, model, varexpr, arguments, kwarguments) -""" -function write_constvar_expression end - -""" - write_as_variable(backend, model, varexpr) -""" -function write_as_variable end - -""" - write_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr) -""" -function write_make_node_expression end - -""" - write_autovar_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr, autovarid) -""" -function write_autovar_make_node_expression end - -""" - write_node_options(backend, fform, variables, options) -""" -function write_node_options end - -""" - write_randomvar_options(backend, variable, options) -""" -function write_randomvar_options end - -""" - write_constvar_options(backend, variable, options) -""" -function write_constvar_options end - -""" - write_datavar_options(backend, variable, options) -""" -function write_datavar_options end +using MacroTools include("backends/reactivemp.jl") __get_current_backend() = ReactiveMPBackend() -macro model(model_specification) - return esc(:(@model [] $model_specification)) -end - -macro model(model_options, model_specification) - return GraphPPL.generate_model_expression(__get_current_backend(), model_options, model_specification) -end - -function generate_model_expression(backend, model_options, model_specification) - @capture(model_options, [ ms_options__ ]) || - error("Model specification options should be in a form of [ option1 = ..., option2 = ... ]") - - ms_options = map(ms_options) do option - (@capture(option, name_ = value_) && name isa Symbol) || error("Invalid option specification: $(option). Expected: 'option_name = option_value'.") - return (name, value) - end - - ms_options = :(NamedTuple{ ($(tuple(map(first, ms_options)...))) }((($(tuple(map(last, ms_options)...)...)),))) - - @capture(model_specification, (function ms_name_(ms_args__; ms_kwargs__) ms_body_ end) | (function ms_name_(ms_args__) ms_body_ end)) || - error("Model specification language requires full function definition") - - model = gensym(:model) - - ms_args_ids = Vector{Symbol}() - ms_args_guard_ids = Vector{Symbol}() - ms_args_const_ids = Vector{Tuple{Symbol, Symbol}}() - - ms_arg_expression_converter = (ms_arg) -> begin - if @capture(ms_arg, arg_::ConstVariable = smth_) || @capture(ms_arg, arg_::ConstVariable) - # rc_arg = gensym(:constvar) - push!(ms_args_const_ids, (arg, arg)) # backward compatibility for old behaviour with gensym - push!(ms_args_guard_ids, arg) - push!(ms_args_ids, arg) - return argument_write_default_value(arg, smth) - elseif @capture(ms_arg, arg_::T_ = smth_) || @capture(ms_arg, arg_::T_) - push!(ms_args_guard_ids, arg) - push!(ms_args_ids, arg) - return argument_write_default_value(:($(arg)::$(T)), smth) - elseif @capture(ms_arg, arg_Symbol = smth_) || @capture(ms_arg, arg_Symbol) - push!(ms_args_guard_ids, arg) - push!(ms_args_ids, arg) - return argument_write_default_value(arg, smth) - else - error("Invalid argument specification: $(ms_arg)") - end - end - - ms_args = ms_args === nothing ? [] : map(ms_arg_expression_converter, ms_args) - ms_kwargs = ms_kwargs === nothing ? [] : map(ms_arg_expression_converter, ms_kwargs) - - if length(Set(ms_args_ids)) !== length(ms_args_ids) - error("There are duplicates in argument specification list: $(ms_args_ids)") - end - - ms_args_const_init_block = map(ms_args_const_ids) do ms_arg_const_id - return write_constvar_expression(backend, model, first(ms_arg_const_id), [ last(ms_arg_const_id) ], []) - end - - # Step 0: Check that all inputs are not AbstractVariables - # It is highly recommended not to create AbstractVariables outside of the model creation macro - # Doing so can lead to undefined behaviour - ms_args_checks = map((ms_arg) -> write_argument_guard(backend, ms_arg), ms_args_guard_ids) - - # Step 1: Probabilistic arguments normalisation - ms_body = prewalk(ms_body) do expression - if @capture(expression, (varexpr_ ~ fform_(arguments__) where { options__ }) | (varexpr_ ~ fform_(arguments__))) - options = options === nothing ? [] : options - - # Filter out keywords arguments to options array - arguments = filter(arguments) do arg - ifparameters = arg isa Expr && arg.head === :parameters - if ifparameters - foreach(a -> push!(options, a), arg.args) - end - return !ifparameters - end - - varexpr = @capture(varexpr, (nodeid_, varid_)) ? varexpr : :(($(gensym(:nnode)), $varexpr)) - return :($varexpr ~ $(fform)($((normalize_tilde_arguments(arguments))...); $(options...))) - elseif @capture(expression, varexpr_ = randomvar(arguments__) where { options__ }) - return :($varexpr = randomvar($(arguments...); $(write_randomvar_options(backend, varexpr, options)...))) - elseif @capture(expression, varexpr_ = datavar(arguments__) where { options__ }) - return :($varexpr = datavar($(arguments...); $(write_datavar_options(backend, varexpr, options)...))) - elseif @capture(expression, varexpr_ = constvar(arguments__) where { options__ }) - return :($varexpr = constvar($(arguments...); $(write_constvar_options(backend, varexpr, options)...))) - elseif @capture(expression, varexpr_ = randomvar(arguments__)) - return :($varexpr = randomvar($(arguments...); )) - elseif @capture(expression, varexpr_ = datavar(arguments__)) - return :($varexpr = datavar($(arguments...); )) - elseif @capture(expression, varexpr_ = constvar(arguments__)) - return :($varexpr = constvar($(arguments...); )) - else - return expression - end - end - - bannedids = Set{Symbol}() - - ms_body = postwalk(ms_body) do expression - if @capture(expression, lhs_ = rhs_) - if !(@capture(rhs, datavar(args__))) && !(@capture(rhs, randomvar(args__))) && !(@capture(rhs, constvar(args__))) - varexpr, short_id, full_id = parse_varexpr(lhs) - push!(bannedids, short_id) - end - end - return expression - end - - varids = Set{Symbol}(ms_args_ids) - - # Step 2: Main pass - ms_body = postwalk(ms_body) do expression - # Step 2.1 Convert datavar calls - if @capture(expression, varexpr_ = datavar(arguments__; kwarguments__)) - @assert varexpr ∉ varids "Invalid model specification: '$varexpr' id is duplicated" - @assert length(arguments) >= 1 "datavar() call requires type specification as a first argument" - - push!(varids, varexpr) - - type_argument = arguments[1] - tail_arguments = arguments[2:end] - - return write_datavar_expression(backend, model, varexpr, type_argument, tail_arguments, kwarguments) - # Step 2.2 Convert randomvar calls - elseif @capture(expression, varexpr_ = randomvar(arguments__; kwarguments__)) - @assert varexpr ∉ varids "Invalid model specification: '$varexpr' id is duplicated" - push!(varids, varexpr) - - return write_randomvar_expression(backend, model, varexpr, arguments, kwarguments) - # Step 2.3 Conver constvar calls - elseif @capture(expression, varexpr_ = constvar(arguments__; kwarguments__)) - @assert varexpr ∉ varids "Invalid model specification: '$varexpr' id is duplicated" - push!(varids, varexpr) - - return write_constvar_expression(backend, model, varexpr, arguments, kwarguments) - # Step 2.2 Convert tilde expressions - elseif @capture(expression, (nodeexpr_, varexpr_) ~ fform_(arguments__; kwarguments__)) - # println(expression) - varexpr, short_id, full_id = parse_varexpr(varexpr) - - if short_id ∈ bannedids - error("Invalid name '$(short_id)' for new random variable. '$(short_id)' was already initialized with '=' operator before.") - end - - variables = map((argexpr) -> write_as_variable(backend, model, argexpr), arguments) - options = write_node_options(backend, fform, [ varexpr, arguments... ], kwarguments) - - if short_id ∈ varids - return write_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr) - else - push!(varids, short_id) - return write_autovar_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr, full_id) - end - else - return expression - end - end - - # Step 3: Final pass - final_pass_exceptions = (x) -> @capture(x, (some_ -> body_) | (function some_(args__) body_ end) | (some_(args__) = body_)) - final_pass_target = (x) -> @capture(x, return ret_) - - ms_body = conditioned_walk(final_pass_exceptions, final_pass_target, ms_body) do expression - @capture(expression, return ret_) ? quote activate!($model); return $model, ($ret) end : expression - end - - res = quote - - function $ms_name($(ms_args...); $(ms_kwargs...), options = $(ms_options)) - $(ms_args_checks...) - options = merge($(ms_options), options) - $model = Model(options) - $(ms_args_const_init_block...) - $ms_body - error("'return' statement is missing") - end - end - - return esc(res) -end +include("utils.jl") +include("model.jl") +include("constraints.jl") +include("meta.jl") end # module diff --git a/src/backends/reactivemp.jl b/src/backends/reactivemp.jl index 434f45a7..0d5c7c00 100644 --- a/src/backends/reactivemp.jl +++ b/src/backends/reactivemp.jl @@ -162,4 +162,58 @@ function write_datavar_options(::ReactiveMPBackend, variable, options) @capture(option, name_Symbol = value_) || error("Invalid variable options specification: $option. Should be in a form of 'name = value'") return option end +end + +# Constraints specification language + +## Factorisations constraints specification language + +function write_constraints_specification(::ReactiveMPBackend, factorisation, marginalsform, messagesform) + return :(ReactiveMP.ConstraintsSpecification($factorisation, $marginalsform, $messagesform)) +end + +function write_factorisation_constraint(::ReactiveMPBackend, names, entries) + return :(ReactiveMP.FactorisationConstraintsSpecification($names, $entries)) +end + +function write_factorisation_constraint_entry(::ReactiveMPBackend, names, entries) + return :(ReactiveMP.FactorisationConstraintsEntry($names, $entries)) +end + +function write_init_factorisation_not_defined(::ReactiveMPBackend, spec, name) + return :($spec = ReactiveMP.FactorisationSpecificationNotDefinedYet{$(QuoteNode(name))}()) +end + +function write_check_factorisation_is_not_defined(::ReactiveMPBackend, spec) + return :($spec isa ReactiveMP.FactorisationSpecificationNotDefinedYet) +end + +function write_factorisation_split(::ReactiveMPBackend, left, right) + return :(ReactiveMP.factorisation_split($left, $right)) +end + +function write_factorisation_combined_range(::ReactiveMPBackend, left, right) + return :(ReactiveMP.CombinedRange($left, $right)) +end + +function write_factorisation_splitted_range(::ReactiveMPBackend, left, right) + return :(ReactiveMP.SplittedRange($left, $right)) +end + +function write_factorisation_functional_index(::ReactiveMPBackend, repr, fn) + return :(ReactiveMP.FunctionalIndex{$(QuoteNode(repr))}($fn)) +end + +function write_form_constraint_specification(::ReactiveMPBackend, T, args, kwargs) + return :(ReactiveMP.FormConstraintsSpecification($T, $args, $kwargs)) +end + +## Meta specification language + +function write_meta_specification(::ReactiveMPBackend, entries) + return :(ReactiveMP.MetaSpecification($entries)) +end + +function write_meta_specification_entry(::ReactiveMPBackend, F, N, meta) + return :(ReactiveMP.MetaSpecificationEntry(Val($F), Val($N), $meta)) end \ No newline at end of file diff --git a/src/constraints.jl b/src/constraints.jl new file mode 100644 index 00000000..9c109a32 --- /dev/null +++ b/src/constraints.jl @@ -0,0 +1,316 @@ +export @constraints + +""" + write_constraints_specification(backend, factorisation, marginalsform, messagesform) +""" +function write_constraints_specification end + +""" + write_factorisation_constraint(backend, names, entries) +""" +function write_factorisation_constraint end + +""" + write_factorisation_constraint_entry(backend, names, entries) +""" +function write_factorisation_constraint_entry end + +""" + write_init_factorisation_not_defined(backend, spec, name) +""" +function write_init_factorisation_not_defined end + +""" + write_check_factorisation_is_not_defined(backend, spec) +""" +function write_check_factorisation_is_not_defined end + +""" + write_factorisation_split(backend, left, right) +""" +function write_factorisation_split end + +""" + write_factorisation_combined_range(backend, left, right) +""" +function write_factorisation_combined_range end + +""" + write_factorisation_splitted_range(backend, left, right) +""" +function write_factorisation_splitted_range end + +""" + write_factorisation_functional_index(backend, repr, fn) +""" +function write_factorisation_functional_index end + +""" + write_form_constraint_specification(backend, T, args, kwargs) +""" +function write_form_constraint_specification end + +macro constraints(constraints_specification) + return generate_constraints_expression(__get_current_backend(), constraints_specification) +end + +## Factorisation constraints + +struct FactorisationConstraintLHSInfo + name :: String + hash :: UInt + varname :: Symbol +end + +## + +## Form constraints + +function flatten_functional_form_constraint_specification(expr) + return flatten_functional_form_constraint_specification!(expr, Expr(:(call), :(::))) +end + +function flatten_functional_form_constraint_specification!(symbol::Symbol, toplevel::Expr) + push!(toplevel.args, symbol) + return toplevel +end + +function flatten_functional_form_constraint_specification!(expr::Expr, toplevel::Expr) + if ishead(expr, :(::)) && ishead(expr.args[1], :(::)) + flatten_functional_form_constraint_specification!(expr.args[1], toplevel) + flatten_functional_form_constraint_specification!(expr.args[2], toplevel) + elseif ishead(expr, :(::)) + push!(toplevel.args, expr.args[1]) + push!(toplevel.args, expr.args[2]) + else + push!(toplevel.args, expr) + end + return toplevel +end + +function parse_form_constraint(backend, expr) + T, args, kwargs = if expr isa Symbol + expr, :(()), :((;)) + else + if @capture(expr, f_(args__; kwargs__)) + f, :(($(args...), )), :((; $(kwargs...), )) + elseif @capture(expr, f_(args__)) + + as = [] + ks = [] + + for arg in args + if ishead(arg, :kw) + push!(ks, arg) + else + push!(as, arg) + end + end + + f, :(($(as...), )), :((; $(ks...), )) + elseif @capture(expr, f_()) + f, :(()), :((;)) + else + error("Unssuported form constraints call specification in the expression `$(expr)`") + end + end + + return write_form_constraint_specification(backend, T, args, kwargs) +end + +## + +function generate_constraints_expression(backend, constraints_specification) + + if isblock(constraints_specification) + generatedfname = gensym(:constraints) + generatedfbody = :(function $(generatedfname)() $constraints_specification end) + return :($(generate_constraints_expression(backend, generatedfbody))()) + end + + @capture(constraints_specification, (function cs_name_(cs_args__; cs_kwargs__) cs_body_ end) | (function cs_name_(cs_args__) cs_body_ end)) || + error("Constraints specification language requires full function definition") + + cs_args = cs_args === nothing ? [] : cs_args + cs_kwargs = cs_kwargs === nothing ? [] : cs_kwargs + + lhs_dict = Dict{UInt, FactorisationConstraintLHSInfo}() + + marginals_form_constraints_symbol = gensym(:marginals_form_constraint) + marginals_form_constraints_symbol_init = :($marginals_form_constraints_symbol = (;)) + + messages_form_constraints_symbol = gensym(:messages_form_constraint) + messages_form_constraints_symbol_init = :($messages_form_constraints_symbol = (;)) + + factorisation_constraints_symbol = gensym(:factorisation_constraint) + factorisation_constraints_symbol_init = :($factorisation_constraints_symbol = ()) + + # First we modify form constraints related statements + cs_body = prewalk(cs_body) do expression + if ishead(expression, :(::)) + return flatten_functional_form_constraint_specification(expression) + end + return expression + end + + cs_body = prewalk(cs_body) do expression + if iscall(expression, :(::)) + if @capture(expression.args[2], q(formsym_Symbol)) + specs = map((e) -> parse_form_constraint(backend, e), view(expression.args, 3:lastindex(expression.args))) + return quote + if haskey($marginals_form_constraints_symbol, $(QuoteNode(formsym))) + error("Marginal form constraint q($(formsym)) has been redefined.") + end + $marginals_form_constraints_symbol = (; $marginals_form_constraints_symbol..., $formsym = ($(specs... ),)) + end + elseif @capture(expression.args[2], μ(formsym_Symbol)) + specs = map((e) -> parse_form_constraint(backend, e), view(expression.args, 3:lastindex(expression.args))) + return quote + if haskey($messages_form_constraints_symbol, $(QuoteNode(formsym))) + error("Messages form constraint μ($(formsym)) has been redefined.") + end + $messages_form_constraints_symbol = (; $messages_form_constraints_symbol..., $formsym = ($(specs... ),)) + end + else + error("Invalid form factorisation constraint. $(expression.args[2]) has to be in the form of q(varname) for marginal form constraint or μ(varname) for messages form constraint.") + end + end + return expression + end + + # Second we modify factorisation constraints related statements + # First we record all lhs expression's hash ids and create unique variable names for them + # q(x, y) = q(x)q(y) -> hash(q(x, y)) + # We do allow multiple definitions in case of if statements, but we do check later overwrites, which are not allowed + cs_body = postwalk(cs_body) do expression + # We also do a simple sanity check right now, names should be an array of Symbols only + if @capture(expression, lhs_ = rhs_) && @capture(lhs, q(names__)) + + (length(names) !== 0 && all(name -> name isa Symbol, names)) || + error("""Error in factorisation constraints specification $(lhs_name) = ...\nLeft hand side of the equality expression should have only variable identifiers.""") + + # We replace '..' in RHS expression with `write_factorisation_split` + rhs = postwalk(rhs) do rexpr + if @capture(rexpr, a_ .. b_) + return write_factorisation_split(backend, a, b) + end + return rexpr + end + + lhs_names = Set{Symbol}(names) + rhs_names = Set{Symbol}() + + # We do a simple check to be sure that LHS and RHS has the exact same set of names + # We also check here that all indices are either a simple Symbol or an indexing expression here + rhs = postwalk(MacroTools.prettify(rhs, alias = false)) do entry + if @capture(entry, q(indices__)) + for index in indices + if index isa Symbol + (index ∉ rhs_names) || error("RHS of the $(expression) expression used $(index) without indexing twice, which is not allowed. Try to decompose factorisation constraint expression into several subexpression.") + push!(rhs_names, index) + (index ∉ lhs_names) && error("LHS of the $(expression) expression does not have $(index) variable, but is used in RHS.") + elseif isref(index) + push!(rhs_names, first(index.args)) + (first(index.args) ∉ lhs_names) && error("LHS of the $(expression) expression does not have $(first(index.args)) variable, but is used in RHS.") + else + error("Cannot parse expression $(index) in the RHS $(rhs) expression. Index expression should be either a single variable symbol or an indexing expression.") + end + end + end + return entry + end + + (lhs_names == rhs_names) || error("LHS and RHS of the $(expression) expression has different set of variables.") + + lhs_hash = hash(lhs) + lhs_info = if haskey(lhs_dict, lhs_hash) + lhs_dict[ lhs_hash ] + else + lhs_name = string("q(", join(names, ", "), ")") + lhs_varname = gensym(lhs_name) + lhs_info = FactorisationConstraintLHSInfo(lhs_name, lhs_hash, lhs_varname) + lhs_dict[lhs_hash] = lhs_info + end + + lhs_name = lhs_info.name + lhs_varname = lhs_info.varname + + new_factorisation_specification = write_factorisation_constraint(backend, :(Val(($(map(QuoteNode, names)...),))), :(Val($(rhs)))) + check_is_not_defined = write_check_factorisation_is_not_defined(backend, lhs_varname) + + result = quote + $(check_is_not_defined) || error("Factorisation constraints specification $($lhs_name) = ... has been redefined.") + $(lhs_varname) = $(new_factorisation_specification) + $factorisation_constraints_symbol = ($factorisation_constraints_symbol..., $(lhs_varname)) + end + + return result + end + return expression + end + + # This block write initial variables for factorisation specification + cs_lhs_init_block = map(collect(lhs_dict)) do pair + lhs_info = last(pair) + lhs_name = lhs_info.name + lhs_varname = lhs_info.varname + lhs_symbol = Symbol(lhs_name) + return write_init_factorisation_not_defined(backend, lhs_varname, lhs_symbol) + end + + cs_body = prewalk(cs_body) do expression + if @capture(expression, q(args__)) + rhs_prod_names = Symbol[] + rhs_prod_entries_args = map(args) do arg + if arg isa Symbol + push!(rhs_prod_names, arg) + return :(nothing) + elseif isref(arg) + (length(arg.args) === 2) || error("Indexing expression $(expression) is too difficult to parse and is not supported (yet?).") + push!(rhs_prod_names, first(arg.args)) + + index = last(arg.args) + + # First we replace all `begin` and `end` with `firstindex` and `lastindex` functions + index = postwalk(index) do iexpr + if iexpr isa Symbol && iexpr === :begin + return write_factorisation_functional_index(backend, :begin, :firstindex) + elseif iexpr isa Symbol && iexpr === :end + return write_factorisation_functional_index(backend, :end, :lastindex) + else + return iexpr + end + end + + if @capture(index, a_:b_) + return write_factorisation_combined_range(backend, a, b) + else + return index + end + else + error("Cannot parse expression $(index) in the RHS $(rhs) expression. Index expression should be either a single variable symbol or an indexing expression.") + end + end + + entry = write_factorisation_constraint_entry(backend, :(Val(($(map(QuoteNode, rhs_prod_names)...), ))), :(Val(($(rhs_prod_entries_args...), )))) + + return :(($entry, )) + end + return expression + end + + return_specification = write_constraints_specification(backend, factorisation_constraints_symbol, marginals_form_constraints_symbol, messages_form_constraints_symbol) + + res = quote + function $cs_name($(cs_args...); $(cs_kwargs...)) + $(marginals_form_constraints_symbol_init) + $(messages_form_constraints_symbol_init) + $(factorisation_constraints_symbol_init) + $(cs_lhs_init_block...) + $(cs_body) + $(return_specification) + end + end + + return esc(res) +end \ No newline at end of file diff --git a/src/meta.jl b/src/meta.jl new file mode 100644 index 00000000..6ef562d9 --- /dev/null +++ b/src/meta.jl @@ -0,0 +1,91 @@ +export @meta + +""" + write_meta_specification(backend, entries) +""" +function write_meta_specification end + +""" + write_meta_specification_entry(backend, F, N, meta) +""" +function write_meta_specification_entry end + +macro meta(meta_specification) + return generate_meta_expression(__get_current_backend(), meta_specification) +end + +struct MetaSpecificationLHSInfo + hash :: UInt + checkname :: Symbol +end + +function generate_meta_expression(backend, meta_specification) + + if isblock(meta_specification) + generatedfname = gensym(:constraints) + generatedfbody = :(function $(generatedfname)() $meta_specification end) + return :($(generate_meta_expression(backend, generatedfbody))()) + end + + @capture(meta_specification, (function cs_name_(cs_args__; cs_kwargs__) cs_body_ end) | (function cs_name_(cs_args__) cs_body_ end)) || + error("Meta specification language requires full function definition") + + cs_args = cs_args === nothing ? [] : cs_args + cs_kwargs = cs_kwargs === nothing ? [] : cs_kwargs + + lhs_dict = Dict{UInt, MetaSpecificationLHSInfo}() + + meta_spec_symbol = gensym(:meta) + meta_spec_symbol_init = :($meta_spec_symbol = ()) + + cs_body = postwalk(cs_body) do expression + if @capture(expression, f_(args__) -> meta_) + + if !issymbol(f) || any(a -> !issymbol(a), args) + error("Invalid meta specification $(expression)") + end + + lhs = :($f($(args...))) + lhs_hash = hash(lhs) + lhs_info = if haskey(lhs_dict, lhs_hash) + lhs_dict[ lhs_hash ] + else + lhs_checkname = gensym(f) + lhs_info = MetaSpecificationLHSInfo(lhs_hash, lhs_checkname) + lhs_dict[lhs_hash] = lhs_info + end + + lhs_checkname = lhs_info.checkname + error_msg = "Meta specification $lhs has been redefined" + meta_entry = write_meta_specification_entry(backend, QuoteNode(f), :(($(map(QuoteNode, args)...), )), meta) + + return quote + ($lhs_checkname) && error($error_msg) + $meta_spec_symbol = ($meta_spec_symbol..., $meta_entry) + $lhs_checkname = true + end + end + return expression + end + + lhs_checknames_init = map(collect(pairs(lhs_dict))) do pair + lhs_info = last(pair) + lhs_checkname = lhs_info.checkname + return quote + $lhs_checkname = false + end + end + + ret_meta_specification = write_meta_specification(backend, meta_spec_symbol) + + res = quote + function $cs_name($(cs_args...); $(cs_kwargs...)) + $meta_spec_symbol_init + $(lhs_checknames_init...) + $cs_body + $ret_meta_specification + end + end + + return esc(res) +end \ No newline at end of file diff --git a/src/model.jl b/src/model.jl new file mode 100644 index 00000000..1bf11704 --- /dev/null +++ b/src/model.jl @@ -0,0 +1,333 @@ +export @model + +import MacroTools: @capture, postwalk, prewalk, walk + +function conditioned_walk(f, condition_skip, condition_apply, x) + walk(x, x -> condition_skip(x) ? x : condition_apply(x) ? f(x) : conditioned_walk(f, condition_skip, condition_apply, x), identity) +end + +""" + fquote(expr) + +This function forces `Expr` or `Symbol` to be quoted. +""" +fquote(expr::Symbol) = Expr(:quote, expr) +fquote(expr::Int) = expr +fquote(expr::Expr) = expr + +""" + ensure_type +""" +ensure_type(x::Type) = x +ensure_type(x) = error("Valid type object was expected but '$x' has been found") + +is_kwargs_expression(x) = false +is_kwargs_expression(x::Expr) = x.head === :parameters + +""" + parse_varexpr(varexpr) + +This function parses variable id and returns a tuple of 3 different representations of the same variable +1. Original expression +2. Short variable identificator (used in variables lookup table) +3. Full variable identificator (used in model as a variable id) +""" +function parse_varexpr(varexpr::Symbol) + varexpr = varexpr + short_id = varexpr + full_id = varexpr + return varexpr, short_id, full_id +end + +function parse_varexpr(varexpr::Expr) + + # TODO: It might be handy to have this feature in the future for e.g. interacting with UnPack.jl package + # TODO: For now however we fallback to a more informative error message since it is not obvious how to parse such expressions yet + @capture(varexpr, (tupled_ids__, )) && + error("Multiple variable declarations, definitions and assigments are forbidden within @model macro. Try to split $(varexpr) into several independent statements.") + + @capture(varexpr, id_[idx__]) || + error("Variable identificator can be in form of a single symbol (x ~ ...) or indexing expression (x[i] ~ ...)") + + varexpr = varexpr + short_id = id + full_id = Expr(:call, :Symbol, fquote(id), Expr(:quote, :_), Expr(:quote, Symbol(join(idx, :_)))) + + return varexpr, short_id, full_id +end + +""" + normalize_tilde_arguments(args) + +This function 'normalizes' every argument of a tilde expression making every inner function call to be a tilde expression as well. +It forces MSL to create anonymous node for any non-linear variable transformation or deterministic relationships. MSL does not check (and cannot in general) +if some inner function call leads to a constant expression or not (e.g. `Normal(0.0, sqrt(10.0))`). Backend API should decide whenever to create additional anonymous nodes +for constant non-linear transformation expressions or not by analyzing input arguments. +""" +function normalize_tilde_arguments(args) + return map(args) do arg + if @capture(arg, id_[idx_]) + return :($(__normalize_arg(id))[$idx]) + else + return __normalize_arg(arg) + end + end +end + +function __normalize_arg(arg) + if @capture(arg, (f_(v__) where { options__ }) | (f_(v__))) + if f === :(|>) + @assert length(v) === 2 "Unsupported pipe syntax in model specification: $(arg)" + f = v[2] + v = [ v[1] ] + end + nvarexpr = gensym(:nvar) + nnodeexpr = gensym(:nnode) + options = options !== nothing ? options : [] + v = normalize_tilde_arguments(v) + return :(($nnodeexpr, $nvarexpr) ~ $f($(v...); $(options...)); $nvarexpr) + else + return arg + end +end + +argument_write_default_value(arg, default::Nothing) = arg +argument_write_default_value(arg, default) = Expr(:kw, arg, default) + + +""" + write_argument_guard(backend, argument) +""" +function write_argument_guard end + +""" + write_randomvar_expression(backend, model, varexpr, arguments, kwarguments) +""" +function write_randomvar_expression end + +""" + write_datavar_expression(backend, model, varexpr, type, arguments, kwarguments) +""" +function write_datavar_expression end + +""" + write_constvar_expression(backend, model, varexpr, arguments, kwarguments) +""" +function write_constvar_expression end + +""" + write_as_variable(backend, model, varexpr) +""" +function write_as_variable end + +""" + write_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr) +""" +function write_make_node_expression end + +""" + write_autovar_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr, autovarid) +""" +function write_autovar_make_node_expression end + +""" + write_node_options(backend, fform, variables, options) +""" +function write_node_options end + +""" + write_randomvar_options(backend, variable, options) +""" +function write_randomvar_options end + +""" + write_constvar_options(backend, variable, options) +""" +function write_constvar_options end + +""" + write_datavar_options(backend, variable, options) +""" +function write_datavar_options end + +macro model(model_specification) + return esc(:(@model [] $model_specification)) +end + +macro model(model_options, model_specification) + return GraphPPL.generate_model_expression(__get_current_backend(), model_options, model_specification) +end + +function generate_model_expression(backend, model_options, model_specification) + @capture(model_options, [ ms_options__ ]) || + error("Model specification options should be in a form of [ option1 = ..., option2 = ... ]") + + ms_options = map(ms_options) do option + (@capture(option, name_ = value_) && name isa Symbol) || error("Invalid option specification: $(option). Expected: 'option_name = option_value'.") + return (name, value) + end + + ms_options = :(NamedTuple{ ($(tuple(map(first, ms_options)...))) }((($(tuple(map(last, ms_options)...)...)),))) + + @capture(model_specification, (function ms_name_(ms_args__; ms_kwargs__) ms_body_ end) | (function ms_name_(ms_args__) ms_body_ end)) || + error("Model specification language requires full function definition") + + model = gensym(:model) + + ms_args_ids = Vector{Symbol}() + ms_args_guard_ids = Vector{Symbol}() + ms_args_const_ids = Vector{Tuple{Symbol, Symbol}}() + + ms_arg_expression_converter = (ms_arg) -> begin + if @capture(ms_arg, arg_::ConstVariable = smth_) || @capture(ms_arg, arg_::ConstVariable) + # rc_arg = gensym(:constvar) + push!(ms_args_const_ids, (arg, arg)) # backward compatibility for old behaviour with gensym + push!(ms_args_guard_ids, arg) + push!(ms_args_ids, arg) + return argument_write_default_value(arg, smth) + elseif @capture(ms_arg, arg_::T_ = smth_) || @capture(ms_arg, arg_::T_) + push!(ms_args_guard_ids, arg) + push!(ms_args_ids, arg) + return argument_write_default_value(:($(arg)::$(T)), smth) + elseif @capture(ms_arg, arg_Symbol = smth_) || @capture(ms_arg, arg_Symbol) + push!(ms_args_guard_ids, arg) + push!(ms_args_ids, arg) + return argument_write_default_value(arg, smth) + else + error("Invalid argument specification: $(ms_arg)") + end + end + + ms_args = ms_args === nothing ? [] : map(ms_arg_expression_converter, ms_args) + ms_kwargs = ms_kwargs === nothing ? [] : map(ms_arg_expression_converter, ms_kwargs) + + if length(Set(ms_args_ids)) !== length(ms_args_ids) + error("There are duplicates in argument specification list: $(ms_args_ids)") + end + + ms_args_const_init_block = map(ms_args_const_ids) do ms_arg_const_id + return write_constvar_expression(backend, model, first(ms_arg_const_id), [ last(ms_arg_const_id) ], []) + end + + # Step 0: Check that all inputs are not AbstractVariables + # It is highly recommended not to create AbstractVariables outside of the model creation macro + # Doing so can lead to undefined behaviour + ms_args_checks = map((ms_arg) -> write_argument_guard(backend, ms_arg), ms_args_guard_ids) + + # Step 1: Probabilistic arguments normalisation + ms_body = prewalk(ms_body) do expression + if @capture(expression, (varexpr_ ~ fform_(arguments__) where { options__ }) | (varexpr_ ~ fform_(arguments__))) + options = options === nothing ? [] : options + + # Filter out keywords arguments to options array + arguments = filter(arguments) do arg + ifparameters = arg isa Expr && arg.head === :parameters + if ifparameters + foreach(a -> push!(options, a), arg.args) + end + return !ifparameters + end + + varexpr = @capture(varexpr, (nodeid_, varid_)) ? varexpr : :(($(gensym(:nnode)), $varexpr)) + return :($varexpr ~ $(fform)($((normalize_tilde_arguments(arguments))...); $(options...))) + elseif @capture(expression, varexpr_ = randomvar(arguments__) where { options__ }) + return :($varexpr = randomvar($(arguments...); $(write_randomvar_options(backend, varexpr, options)...))) + elseif @capture(expression, varexpr_ = datavar(arguments__) where { options__ }) + return :($varexpr = datavar($(arguments...); $(write_datavar_options(backend, varexpr, options)...))) + elseif @capture(expression, varexpr_ = constvar(arguments__) where { options__ }) + return :($varexpr = constvar($(arguments...); $(write_constvar_options(backend, varexpr, options)...))) + elseif @capture(expression, varexpr_ = randomvar(arguments__)) + return :($varexpr = randomvar($(arguments...); )) + elseif @capture(expression, varexpr_ = datavar(arguments__)) + return :($varexpr = datavar($(arguments...); )) + elseif @capture(expression, varexpr_ = constvar(arguments__)) + return :($varexpr = constvar($(arguments...); )) + else + return expression + end + end + + bannedids = Set{Symbol}() + + ms_body = postwalk(ms_body) do expression + if @capture(expression, lhs_ = rhs_) + if !(@capture(rhs, datavar(args__))) && !(@capture(rhs, randomvar(args__))) && !(@capture(rhs, constvar(args__))) + varexpr, short_id, full_id = parse_varexpr(lhs) + push!(bannedids, short_id) + end + end + return expression + end + + varids = Set{Symbol}(ms_args_ids) + + # Step 2: Main pass + ms_body = postwalk(ms_body) do expression + # Step 2.1 Convert datavar calls + if @capture(expression, varexpr_ = datavar(arguments__; kwarguments__)) + @assert varexpr ∉ varids "Invalid model specification: '$varexpr' id is duplicated" + @assert length(arguments) >= 1 "datavar() call requires type specification as a first argument" + + push!(varids, varexpr) + + type_argument = arguments[1] + tail_arguments = arguments[2:end] + + return write_datavar_expression(backend, model, varexpr, type_argument, tail_arguments, kwarguments) + # Step 2.2 Convert randomvar calls + elseif @capture(expression, varexpr_ = randomvar(arguments__; kwarguments__)) + @assert varexpr ∉ varids "Invalid model specification: '$varexpr' id is duplicated" + push!(varids, varexpr) + + return write_randomvar_expression(backend, model, varexpr, arguments, kwarguments) + # Step 2.3 Conver constvar calls + elseif @capture(expression, varexpr_ = constvar(arguments__; kwarguments__)) + @assert varexpr ∉ varids "Invalid model specification: '$varexpr' id is duplicated" + push!(varids, varexpr) + + return write_constvar_expression(backend, model, varexpr, arguments, kwarguments) + # Step 2.2 Convert tilde expressions + elseif @capture(expression, (nodeexpr_, varexpr_) ~ fform_(arguments__; kwarguments__)) + # println(expression) + varexpr, short_id, full_id = parse_varexpr(varexpr) + + if short_id ∈ bannedids + error("Invalid name '$(short_id)' for new random variable. '$(short_id)' was already initialized with '=' operator before.") + end + + variables = map((argexpr) -> write_as_variable(backend, model, argexpr), arguments) + options = write_node_options(backend, fform, [ varexpr, arguments... ], kwarguments) + + if short_id ∈ varids + return write_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr) + else + push!(varids, short_id) + return write_autovar_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr, full_id) + end + else + return expression + end + end + + # Step 3: Final pass + final_pass_exceptions = (x) -> @capture(x, (some_ -> body_) | (function some_(args__) body_ end) | (some_(args__) = body_)) + final_pass_target = (x) -> @capture(x, return ret_) + + ms_body = conditioned_walk(final_pass_exceptions, final_pass_target, ms_body) do expression + @capture(expression, return ret_) ? quote activate!($model); return $model, ($ret) end : expression + end + + res = quote + + function $ms_name($(ms_args...); $(ms_kwargs...), options = $(ms_options)) + $(ms_args_checks...) + options = merge($(ms_options), options) + $model = Model(options) + $(ms_args_const_init_block...) + $ms_body + error("'return' statement is missing") + end + end + + return esc(res) +end \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl new file mode 100644 index 00000000..58ccc881 --- /dev/null +++ b/src/utils.jl @@ -0,0 +1,42 @@ + +issymbol(::Symbol) = true +issymbol(any) = false + +isexpr(expr::Expr) = true +isexpr(expr) = false + +""" + ishead(expr, head) + +Checks if `expr` has head set to `head`. Returns false if expr is not a valid Julia `Expr` object. +""" +ishead(expr, head) = isexpr(expr) && expr.head === head + +""" + isblock(expr) + +Shorthand for `ishead(expr, :block)` + +See also: [`ishead`](@ref) +""" +isblock(expr) = ishead(expr, :block) + +""" + iscall(expr) + iscall(expr, fsym) + +Shorthand for `ishead(expr, :call)` and arguments length check. If an extra `fsym` argument specified function checks if `fsym` function being called. + +See also: [`ishead`](@ref) +""" +iscall(expr) = ishead(expr, :call) && length(expr.args) >= 1 +iscall(expr, fsym) = iscall(expr) && first(expr.args) === fsym + +""" + isref(expr) + +Shorthand for `ishead(expr, :ref)`. + +See also: [`ishead`](@ref) +""" +isref(expr) = ishead(expr, :ref) \ No newline at end of file diff --git a/test/constraints.jl b/test/constraints.jl new file mode 100644 index 00000000..04495fed --- /dev/null +++ b/test/constraints.jl @@ -0,0 +1,6 @@ +module ConstraintsTests + +using Test +using GraphPPL + +end diff --git a/test/utils.jl b/test/utils.jl new file mode 100644 index 00000000..c4f97bda --- /dev/null +++ b/test/utils.jl @@ -0,0 +1,57 @@ +module UtilsTests + +using Test +using GraphPPL + +@testset "issymbol tests" begin + import GraphPPL: issymbol + + @test issymbol(:(f(1))) === false + @test issymbol(:(f(1))) === false + @test issymbol(:(if true 1 else 2 end)) === false + @test issymbol(:hello) === true + @test issymbol(:a) === true + @test issymbol(123) === false +end + +@testset "isexpr tests" begin + import GraphPPL: isexpr + + @test isexpr(:(f(1))) === true + @test isexpr(:(f(1))) === true + @test isexpr(:(if true 1 else 2 end)) === true + @test isexpr(:hello) === false + @test isexpr(123) === false +end + +@testset "ishead tests" begin + import GraphPPL: ishead + + @test ishead(:(f(1)), :call) === true + @test ishead(:(f(1)), :if) === false + @test ishead(:(begin end), :if) === false + @test ishead(:(if true 1 else 2 end), :if) === true + @test ishead(:(begin end), :block) === true +end + +@testset "isblock tests" begin + import GraphPPL: isblock + + @test isblock(:(f(1))) === false + @test isblock(:(if true 1 else 2 end)) === false + @test isblock(:(begin end)) === true +end + + +@testset "iscall tests" begin + import GraphPPL: iscall + + @test iscall(:(f(1))) === true + @test iscall(:(f(1)), :f) === true + @test iscall(:(f(1)), :g) === false + @test iscall(:(if true 1 else 2 end)) === false + @test iscall(:(begin end)) === false + +end + +end