Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 7 additions & 17 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -279,17 +279,12 @@ function generate_model_expression(backend, model_options, model_specification)
end
return expression
end

varids = Set{Symbol}(ms_args_ids)

# Step 2: Main pass
ms_body = postwalk(ms_body) do expression
# Step 2.1 Convert datavar calls
if @capture(expression, varexpr_ = datavar(arguments__; options__))
@assert varexpr ∉ varids "Invalid model specification: '$varexpr' id is duplicated"
@assert length(arguments) >= 1 "The expression `$expression` is incorrect. datavar(::Type, [ dims... ]) requires `Type` as a first argument."

push!(varids, varexpr)

type_argument = arguments[1]
tail_arguments = arguments[2:end]
Expand All @@ -298,34 +293,29 @@ function generate_model_expression(backend, model_options, model_specification)
return write_datavar_expression(backend, model, varexpr, dvoptions, type_argument, tail_arguments)
# Step 2.2 Convert randomvar calls
elseif @capture(expression, varexpr_ = randomvar(arguments__; options__))
@assert varexpr ∉ varids "Invalid model specification: '$varexpr' id is duplicated"
push!(varids, varexpr)

rvoptions = write_randomvar_options(backend, varexpr, options)

return write_randomvar_expression(backend, model, varexpr, rvoptions, arguments)
# Step 2.3 Conver constvar calls
elseif @capture(expression, varexpr_ = constvar(arguments__))
@assert varexpr ∉ varids "Invalid model specification: '$varexpr' id is duplicated"
push!(varids, varexpr)

return write_constvar_expression(backend, model, varexpr, arguments)
# Step 2.2 Convert tilde expressions
elseif @capture(expression, (nodeexpr_, varexpr_) ~ fform_(arguments__; kwarguments__))
# println(expression)

varexpr, short_id, full_id = parse_varexpr(varexpr)

if short_id ∈ bannedids
error("Invalid name '$(short_id)' for new random variable. '$(short_id)' was already initialized with '=' operator before.")
error("Invalid name '$(short_id)' for new random variable. '$(short_id)' has been already initialized with '=' operator.")
end

variables = map((argexpr) -> write_as_variable(backend, model, argexpr), arguments)
options = write_node_options(backend, model, fform, [ varexpr, arguments... ], kwarguments)

if short_id ∈ varids

# Indexed variables like `y[1]` cannot be created on the fly and should be pre-initialised with `y = randomvar(n)`
# Single variables like `y` can be created on the fly with the `AutoVar` marker
# In the second case if variable `y` has been initialised before `AutoVar` should simply return it
if isref(varexpr)
return write_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr)
else
push!(varids, short_id)
return write_autovar_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr, full_id)
end
else
Expand Down
9 changes: 9 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,15 @@ See also: [`ishead`](@ref)
"""
isref(expr) = ishead(expr, :ref)

"""
getref(expr)

Returns ref indices from `expr` in a form of a tuple.

See als: [`isref`](@ref)
"""
getref(expr) = isref(expr) ? (view(expr.args, 2:lastindex(expr.args))...,) : ()

"""
ensure_type(x)

Expand Down