Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
76 commits
Select commit Hold shift + click to select a range
e8fb37d
init(): constraints spec
bvdmitri Jan 26, 2022
8dfcaaf
feat(): init constraints
bvdmitri Jan 26, 2022
15d1c36
feat(): factorisation constraints WIP
bvdmitri Jan 26, 2022
d4e0171
2prev
bvdmitri Jan 26, 2022
09abc09
feat(): allow single block constraints specification
bvdmitri Jan 27, 2022
b549ff3
feat(): transform x[begin] and x[end]
bvdmitri Jan 27, 2022
f804ca6
feat(): validate factorisation constraints specification
bvdmitri Jan 27, 2022
c9765e5
feat(): misc changes
bvdmitri Jan 27, 2022
a7c94d1
feat(): better indexing with begin/end
bvdmitri Jan 27, 2022
e6c94df
add playground notebook
bvdmitri Jan 28, 2022
7b1e318
update
Jan 29, 2022
7f5d08a
update playground
bvdmitri Feb 9, 2022
21b428b
update
bvdmitri Feb 9, 2022
0d65362
update
bvdmitri Feb 9, 2022
2f984c3
update
bvdmitri Feb 9, 2022
c8bbe0f
2prev
bvdmitri Feb 9, 2022
ac1b93d
2prev
bvdmitri Feb 9, 2022
c4e245d
2prev
bvdmitri Feb 9, 2022
fc7150f
update
bvdmitri Feb 10, 2022
92a7ed8
update()
bvdmitri Feb 10, 2022
7b0a5ee
Support for split operator
bvdmitri Feb 10, 2022
cbfae3b
Better errors
bvdmitri Feb 10, 2022
73ee1d8
some progress
bvdmitri Feb 10, 2022
8afe542
working?
bvdmitri Feb 11, 2022
6cb56ba
update
bvdmitri Feb 11, 2022
0c7ab13
update
bvdmitri Feb 11, 2022
af9d01c
fast and memory efficient implementation
bvdmitri Feb 11, 2022
8a03bab
update
bvdmitri Feb 11, 2022
747275e
update
bvdmitri Feb 12, 2022
2d78e84
code refactoring
bvdmitri Feb 14, 2022
0e4d53e
clusters intersection check
bvdmitri Feb 14, 2022
6a2cac1
more error checks
bvdmitri Feb 14, 2022
ef483c3
better naming
bvdmitri Feb 14, 2022
4d86599
update
bvdmitri Feb 14, 2022
a80f010
remove nested spec feature
Feb 15, 2022
3653907
2prev
Feb 15, 2022
1fe52b1
2prev
Feb 15, 2022
5bd95ce
extra checks
bvdmitri Feb 16, 2022
bd721a7
update
bvdmitri Feb 16, 2022
00de58d
support for proxy variables
bvdmitri Feb 16, 2022
6ba8168
fixes for error printing
bvdmitri Feb 16, 2022
440751c
disallow multiple proxies
bvdmitri Feb 16, 2022
f2e31af
2prev
bvdmitri Feb 16, 2022
1be50da
add support for complex functional expressions within indexing
bvdmitri Feb 18, 2022
18f143e
feat(): some error checks
bvdmitri Feb 18, 2022
55624a4
feat(): better split with indices
bvdmitri Feb 18, 2022
a664ab8
feat(): update
bvdmitri Feb 18, 2022
52b695f
2prev
bvdmitri Feb 20, 2022
75b3ccf
feat(): indices resolution function
bvdmitri Feb 21, 2022
0fc6936
2prev
bvdmitri Feb 21, 2022
750f9b3
feat(): Filter out indexed/ranged external variables
bvdmitri Feb 21, 2022
09281ff
2prev
bvdmitri Feb 21, 2022
d509fb8
feat(): Add internal names check
bvdmitri Feb 21, 2022
6bb9ade
feat(): works, but slow
bvdmitri Feb 21, 2022
e74e693
fix
bvdmitri Feb 21, 2022
000e886
feat(): fix splitted range
bvdmitri Feb 21, 2022
7e2f2a3
bug fixes
bvdmitri Feb 21, 2022
41fba6a
update
bvdmitri Feb 21, 2022
4db8520
fix(): fix performance issues
bvdmitri Feb 21, 2022
eb31e1f
fix(): bug fix with integers and unit range
bvdmitri Feb 21, 2022
ec3d979
feat(): extra checks
bvdmitri Feb 21, 2022
9c8d3c2
fix(): 2prev
bvdmitri Feb 21, 2022
2928a2f
update
bvdmitri Feb 21, 2022
a372957
update(): support for splitted range
bvdmitri Feb 22, 2022
f032e67
fix for factorisation split in constraints language macro
bvdmitri Feb 22, 2022
73691a6
finalisation
bvdmitri Feb 22, 2022
4676b67
copy playground to reactive
bvdmitri Feb 22, 2022
65cd20f
2prev
bvdmitri Feb 22, 2022
56ccc2f
more examples and tests
bvdmitri Feb 22, 2022
3315cf8
more examples and tests
bvdmitri Feb 22, 2022
6cba27a
feat(): factorisation constraints in main code
bvdmitri Feb 23, 2022
9d6cd74
remove playground
bvdmitri Feb 23, 2022
0cc9a9d
feat(): marginals and messages form constraints
bvdmitri Mar 1, 2022
e607544
feat(): meta specification language
bvdmitri Mar 2, 2022
baa954c
wip(): change operator for meta spec
bvdmitri Mar 2, 2022
683a8b5
update(): Bump version to 1.1.0
bvdmitri Mar 3, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,4 @@ Coverage.ipynb
**/.DS_Store

examples/*Compiled
statprof
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "GraphPPL"
uuid = "b3f8163a-e979-4e85-b43e-1f63d8c8b42c"
authors = ["Dmitry Bagaev <bvdmitri@gmail.com>"]
version = "1.0.5"
version = "1.1.0"

[deps]
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Expand Down
3 changes: 2 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ makedocs(
sitename = "GraphPPL.jl",
pages = [
"Home" => "index.md",
"User guide" => "user-guide.md"
"User guide" => "user-guide.md",
"Utils" => "utils.md"
],
format = Documenter.HTML(
prettyurls = get(ENV, "CI", nothing) == "true"
Expand Down
7 changes: 7 additions & 0 deletions docs/src/utils.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Utils

```@docs
GraphPPL.ishead
GraphPPL.isblock
GraphPPL.iscall
```
338 changes: 5 additions & 333 deletions src/GraphPPL.jl
Original file line number Diff line number Diff line change
@@ -1,342 +1,14 @@
module GraphPPL

export @model

import MacroTools
import MacroTools: @capture, postwalk, prewalk, walk

function conditioned_walk(f, condition_skip, condition_apply, x)
walk(x, x -> condition_skip(x) ? x : condition_apply(x) ? f(x) : conditioned_walk(f, condition_skip, condition_apply, x), identity)
end

"""
fquote(expr)

This function forces `Expr` or `Symbol` to be quoted.
"""
fquote(expr::Symbol) = Expr(:quote, expr)
fquote(expr::Int) = expr
fquote(expr::Expr) = expr

"""
ensure_type
"""
ensure_type(x::Type) = x
ensure_type(x) = error("Valid type object was expected but '$x' has been found")

is_kwargs_expression(x) = false
is_kwargs_expression(x::Expr) = x.head === :parameters

"""
parse_varexpr(varexpr)

This function parses variable id and returns a tuple of 3 different representations of the same variable
1. Original expression
2. Short variable identificator (used in variables lookup table)
3. Full variable identificator (used in model as a variable id)
"""
function parse_varexpr(varexpr::Symbol)
varexpr = varexpr
short_id = varexpr
full_id = varexpr
return varexpr, short_id, full_id
end

function parse_varexpr(varexpr::Expr)

# TODO: It might be handy to have this feature in the future for e.g. interacting with UnPack.jl package
# TODO: For now however we fallback to a more informative error message since it is not obvious how to parse such expressions yet
@capture(varexpr, (tupled_ids__, )) &&
error("Multiple variable declarations, definitions and assigments are forbidden within @model macro. Try to split $(varexpr) into several independent statements.")

@capture(varexpr, id_[idx__]) ||
error("Variable identificator can be in form of a single symbol (x ~ ...) or indexing expression (x[i] ~ ...)")

varexpr = varexpr
short_id = id
full_id = Expr(:call, :Symbol, fquote(id), Expr(:quote, :_), Expr(:quote, Symbol(join(idx, :_))))

return varexpr, short_id, full_id
end

"""
normalize_tilde_arguments(args)

This function 'normalizes' every argument of a tilde expression making every inner function call to be a tilde expression as well.
It forces MSL to create anonymous node for any non-linear variable transformation or deterministic relationships. MSL does not check (and cannot in general)
if some inner function call leads to a constant expression or not (e.g. `Normal(0.0, sqrt(10.0))`). Backend API should decide whenever to create additional anonymous nodes
for constant non-linear transformation expressions or not by analyzing input arguments.
"""
function normalize_tilde_arguments(args)
return map(args) do arg
if @capture(arg, id_[idx_])
return :($(__normalize_arg(id))[$idx])
else
return __normalize_arg(arg)
end
end
end

function __normalize_arg(arg)
if @capture(arg, (f_(v__) where { options__ }) | (f_(v__)))
if f === :(|>)
@assert length(v) === 2 "Unsupported pipe syntax in model specification: $(arg)"
f = v[2]
v = [ v[1] ]
end
nvarexpr = gensym(:nvar)
nnodeexpr = gensym(:nnode)
options = options !== nothing ? options : []
v = normalize_tilde_arguments(v)
return :(($nnodeexpr, $nvarexpr) ~ $f($(v...); $(options...)); $nvarexpr)
else
return arg
end
end

argument_write_default_value(arg, default::Nothing) = arg
argument_write_default_value(arg, default) = Expr(:kw, arg, default)


"""
write_argument_guard(backend, argument)
"""
function write_argument_guard end

"""
write_randomvar_expression(backend, model, varexpr, arguments, kwarguments)
"""
function write_randomvar_expression end

"""
write_datavar_expression(backend, model, varexpr, type, arguments, kwarguments)
"""
function write_datavar_expression end

"""
write_constvar_expression(backend, model, varexpr, arguments, kwarguments)
"""
function write_constvar_expression end

"""
write_as_variable(backend, model, varexpr)
"""
function write_as_variable end

"""
write_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr)
"""
function write_make_node_expression end

"""
write_autovar_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr, autovarid)
"""
function write_autovar_make_node_expression end

"""
write_node_options(backend, fform, variables, options)
"""
function write_node_options end

"""
write_randomvar_options(backend, variable, options)
"""
function write_randomvar_options end

"""
write_constvar_options(backend, variable, options)
"""
function write_constvar_options end

"""
write_datavar_options(backend, variable, options)
"""
function write_datavar_options end
using MacroTools

include("backends/reactivemp.jl")

__get_current_backend() = ReactiveMPBackend()

macro model(model_specification)
return esc(:(@model [] $model_specification))
end

macro model(model_options, model_specification)
return GraphPPL.generate_model_expression(__get_current_backend(), model_options, model_specification)
end

function generate_model_expression(backend, model_options, model_specification)
@capture(model_options, [ ms_options__ ]) ||
error("Model specification options should be in a form of [ option1 = ..., option2 = ... ]")

ms_options = map(ms_options) do option
(@capture(option, name_ = value_) && name isa Symbol) || error("Invalid option specification: $(option). Expected: 'option_name = option_value'.")
return (name, value)
end

ms_options = :(NamedTuple{ ($(tuple(map(first, ms_options)...))) }((($(tuple(map(last, ms_options)...)...)),)))

@capture(model_specification, (function ms_name_(ms_args__; ms_kwargs__) ms_body_ end) | (function ms_name_(ms_args__) ms_body_ end)) ||
error("Model specification language requires full function definition")

model = gensym(:model)

ms_args_ids = Vector{Symbol}()
ms_args_guard_ids = Vector{Symbol}()
ms_args_const_ids = Vector{Tuple{Symbol, Symbol}}()

ms_arg_expression_converter = (ms_arg) -> begin
if @capture(ms_arg, arg_::ConstVariable = smth_) || @capture(ms_arg, arg_::ConstVariable)
# rc_arg = gensym(:constvar)
push!(ms_args_const_ids, (arg, arg)) # backward compatibility for old behaviour with gensym
push!(ms_args_guard_ids, arg)
push!(ms_args_ids, arg)
return argument_write_default_value(arg, smth)
elseif @capture(ms_arg, arg_::T_ = smth_) || @capture(ms_arg, arg_::T_)
push!(ms_args_guard_ids, arg)
push!(ms_args_ids, arg)
return argument_write_default_value(:($(arg)::$(T)), smth)
elseif @capture(ms_arg, arg_Symbol = smth_) || @capture(ms_arg, arg_Symbol)
push!(ms_args_guard_ids, arg)
push!(ms_args_ids, arg)
return argument_write_default_value(arg, smth)
else
error("Invalid argument specification: $(ms_arg)")
end
end

ms_args = ms_args === nothing ? [] : map(ms_arg_expression_converter, ms_args)
ms_kwargs = ms_kwargs === nothing ? [] : map(ms_arg_expression_converter, ms_kwargs)

if length(Set(ms_args_ids)) !== length(ms_args_ids)
error("There are duplicates in argument specification list: $(ms_args_ids)")
end

ms_args_const_init_block = map(ms_args_const_ids) do ms_arg_const_id
return write_constvar_expression(backend, model, first(ms_arg_const_id), [ last(ms_arg_const_id) ], [])
end

# Step 0: Check that all inputs are not AbstractVariables
# It is highly recommended not to create AbstractVariables outside of the model creation macro
# Doing so can lead to undefined behaviour
ms_args_checks = map((ms_arg) -> write_argument_guard(backend, ms_arg), ms_args_guard_ids)

# Step 1: Probabilistic arguments normalisation
ms_body = prewalk(ms_body) do expression
if @capture(expression, (varexpr_ ~ fform_(arguments__) where { options__ }) | (varexpr_ ~ fform_(arguments__)))
options = options === nothing ? [] : options

# Filter out keywords arguments to options array
arguments = filter(arguments) do arg
ifparameters = arg isa Expr && arg.head === :parameters
if ifparameters
foreach(a -> push!(options, a), arg.args)
end
return !ifparameters
end

varexpr = @capture(varexpr, (nodeid_, varid_)) ? varexpr : :(($(gensym(:nnode)), $varexpr))
return :($varexpr ~ $(fform)($((normalize_tilde_arguments(arguments))...); $(options...)))
elseif @capture(expression, varexpr_ = randomvar(arguments__) where { options__ })
return :($varexpr = randomvar($(arguments...); $(write_randomvar_options(backend, varexpr, options)...)))
elseif @capture(expression, varexpr_ = datavar(arguments__) where { options__ })
return :($varexpr = datavar($(arguments...); $(write_datavar_options(backend, varexpr, options)...)))
elseif @capture(expression, varexpr_ = constvar(arguments__) where { options__ })
return :($varexpr = constvar($(arguments...); $(write_constvar_options(backend, varexpr, options)...)))
elseif @capture(expression, varexpr_ = randomvar(arguments__))
return :($varexpr = randomvar($(arguments...); ))
elseif @capture(expression, varexpr_ = datavar(arguments__))
return :($varexpr = datavar($(arguments...); ))
elseif @capture(expression, varexpr_ = constvar(arguments__))
return :($varexpr = constvar($(arguments...); ))
else
return expression
end
end

bannedids = Set{Symbol}()

ms_body = postwalk(ms_body) do expression
if @capture(expression, lhs_ = rhs_)
if !(@capture(rhs, datavar(args__))) && !(@capture(rhs, randomvar(args__))) && !(@capture(rhs, constvar(args__)))
varexpr, short_id, full_id = parse_varexpr(lhs)
push!(bannedids, short_id)
end
end
return expression
end

varids = Set{Symbol}(ms_args_ids)

# Step 2: Main pass
ms_body = postwalk(ms_body) do expression
# Step 2.1 Convert datavar calls
if @capture(expression, varexpr_ = datavar(arguments__; kwarguments__))
@assert varexpr ∉ varids "Invalid model specification: '$varexpr' id is duplicated"
@assert length(arguments) >= 1 "datavar() call requires type specification as a first argument"

push!(varids, varexpr)

type_argument = arguments[1]
tail_arguments = arguments[2:end]

return write_datavar_expression(backend, model, varexpr, type_argument, tail_arguments, kwarguments)
# Step 2.2 Convert randomvar calls
elseif @capture(expression, varexpr_ = randomvar(arguments__; kwarguments__))
@assert varexpr ∉ varids "Invalid model specification: '$varexpr' id is duplicated"
push!(varids, varexpr)

return write_randomvar_expression(backend, model, varexpr, arguments, kwarguments)
# Step 2.3 Conver constvar calls
elseif @capture(expression, varexpr_ = constvar(arguments__; kwarguments__))
@assert varexpr ∉ varids "Invalid model specification: '$varexpr' id is duplicated"
push!(varids, varexpr)

return write_constvar_expression(backend, model, varexpr, arguments, kwarguments)
# Step 2.2 Convert tilde expressions
elseif @capture(expression, (nodeexpr_, varexpr_) ~ fform_(arguments__; kwarguments__))
# println(expression)
varexpr, short_id, full_id = parse_varexpr(varexpr)

if short_id ∈ bannedids
error("Invalid name '$(short_id)' for new random variable. '$(short_id)' was already initialized with '=' operator before.")
end

variables = map((argexpr) -> write_as_variable(backend, model, argexpr), arguments)
options = write_node_options(backend, fform, [ varexpr, arguments... ], kwarguments)

if short_id ∈ varids
return write_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr)
else
push!(varids, short_id)
return write_autovar_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr, full_id)
end
else
return expression
end
end

# Step 3: Final pass
final_pass_exceptions = (x) -> @capture(x, (some_ -> body_) | (function some_(args__) body_ end) | (some_(args__) = body_))
final_pass_target = (x) -> @capture(x, return ret_)

ms_body = conditioned_walk(final_pass_exceptions, final_pass_target, ms_body) do expression
@capture(expression, return ret_) ? quote activate!($model); return $model, ($ret) end : expression
end

res = quote

function $ms_name($(ms_args...); $(ms_kwargs...), options = $(ms_options))
$(ms_args_checks...)
options = merge($(ms_options), options)
$model = Model(options)
$(ms_args_const_init_block...)
$ms_body
error("'return' statement is missing")
end
end

return esc(res)
end
include("utils.jl")
include("model.jl")
include("constraints.jl")
include("meta.jl")

end # module
Loading