Description
I keep getting confused about how I can and cannot condition variables when it comes to IndexLenses, so I wrote a little snippet that makes a table. The idea is to write the model
@model function f_vec()
x ~ MvNormal(fill(mean, 2), I*std^2)
return x
end
using different ~ expressions like x[1] ~ Normal(mean, std) and x[1:2] .~ Normal(mean, std), then condition them with different VarNames like condition(m, (; x=[0.0, 0.0])) and condition(m, Dict(@varname(x[1]) => 0.0, @varname(x[2]) => 0.0)), and see which combinations actually apply the conditioning.
Snippet:
module Conds
using DynamicPPL, Distributions, LinearAlgebra, PrettyTables
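# Prior centered on 100 with a tiny std, so an unconditioned draw (≈ 100)
# is easy to tell apart from the conditioned value 0.0.
mean = 100.0
std = 1e-6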
mv_dist = MvNormal(fill(mean, 2), I*std^2)
uv_dist = Normal(mean, std)
@model function f_vec()
x ~ mv_dist
return x
end
@model function f_range()
x = Vector{Float64}(undef, 2)
x[1:2] ~ mv_dist
return x
end
@model function f_colon()
x = Vector{Float64}(undef, 2)
x[:] ~ mv_dist
return x
end
@model function f_inds()
x = Vector{Float64}(undef, 2)
x[1] ~ uv_dist
x[2] ~ uv_dist
return x
end
@model function f_vec_dot()
x = Vector{Float64}(undef, 2)
x .~ uv_dist
return x
end
@model function f_range_dot()
x = Vector{Float64}(undef, 2)
x[1:2] .~ uv_dist
return x
end
@model function f_colon_dot()
x = Vector{Float64}(undef, 2)
x[:] .~ uv_dist
return x
end
# (row label, model) pairs: each model draws a two-element x through a different ~ form
models = [
("x ~", f_vec()),
("x[1:2] ~", f_range()),
("x[:] ~", f_colon()),
("x[1] ~", f_inds()),
("x .~", f_vec_dot()),
("x[1:2] .~", f_range_dot()),
("x[:] .~", f_colon_dot()),
]
# (column label, function that conditions a model) pairs
conditionings = [
("x = [0.0, 0.0]", m -> condition(m, (; x=[0.0, 0.0]))),
("x[1] = 0.0, x[2] = 0.0", m -> condition(m, Dict(@varname(x[1]) => 0.0, @varname(x[2]) => 0.0))),
("x[1:2] = [0.0, 0.0]", m -> condition(m, Dict(@varname(x[1:2]) => [0.0, 0.0]))),
("x[:] = [0.0, 0.0]", m -> condition(m, Dict(@varname(x[:]) => [0.0, 0.0]))),
]
# data[i, j] is true iff conditioning j actually fixed x in model i
data = Array{Bool,2}(undef, length(models), length(conditionings))
for (i_model, (model_name, model)) in enumerate(models)
for (i_cond, (cond_name, cond_func)) in enumerate(conditionings)
conditioned_model = cond_func(model)
retval = conditioned_model()
# Conditioned values are 0.0, unconditioned draws are ≈ 100.0
success = all(abs.(retval) .< 1e-10)
data[i_model, i_cond] = success
full_name = "$(model_name) with $(cond_name)"
println("$full_name => $retval")
@show retval .≈ 0.0
end
end
PrettyTables.pretty_table(
data;
column_labels = [cond[1] for cond in conditionings],
row_labels = [model[1] for model in models],
title = "Which ~ expression (rows) supports which conditioning (columns)"
)
end
Table:
Which ~ expression (rows) supports which conditioning (columns)
┌───────────┬────────────────┬────────────────────────┬─────────────────────┬───────────────────┐
│ │ x = [0.0, 0.0] │ x[1] = 0.0, x[2] = 0.0 │ x[1:2] = [0.0, 0.0] │ x[:] = [0.0, 0.0] │
├───────────┼────────────────┼────────────────────────┼─────────────────────┼───────────────────┤
│ x ~ │ true │ false │ false │ false │
│ x[1:2] ~ │ true │ false │ true │ false │
│ x[:] ~ │ true │ false │ false │ false │
│ x[1] ~ │ true │ true │ false │ false │
│ x .~ │ true │ true │ false │ false │
│ x[1:2] .~ │ true │ false │ true │ false │
│ x[:] .~ │ true │ false │ false │ false │
└───────────┴────────────────┴────────────────────────┴─────────────────────┴───────────────────┘
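It's not great. Conditioning with x = [0.0, 0.0] is the only one that works reliably. Conditioning x[1] and x[2] individually only works for the x[1] ~ and x .~ models, conditioning x[1:2] only works for the x[1:2] ~ and x[1:2] .~ models, and x[:] = [0.0, 0.0] never works.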
There are various reasons why things are the way they are, but I'm wondering what we want to do about this. Some options:
1. Ban ranges in VarNames, so no x[1:2] ~ allowed.
2. Ban Colon in VarNames, so no x[:] ~ allowed.
3. Ban both ranges and Colon.
4. Write much more complicated logic for checking when we have a conditioned value and when we don't.
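One way to picture option 4 is to reduce every index to a concrete set of positions before comparing anything. The sketch below is a hypothetical, DynamicPPL-independent illustration: conditioned values live in a plain Dict keyed by Julia indices (Int, UnitRange or Colon) rather than VarNames, and `concretize` and `lookup` are made-up names, not existing DynamicPPL functions. It assumes a one-dimensional variable whose length n is known.

```julia
# Hypothetical sketch of option 4 (not DynamicPPL's API): conditioned values
# live in a Dict keyed by plain Julia indices instead of VarNames, and every
# index is expanded to concrete positions of a length-`n` vector before use.

# Expand an index into the list of positions it covers.
concretize(i::Int, n::Int) = [i]
concretize(r::AbstractUnitRange{<:Integer}, n::Int) = collect(r)
concretize(::Colon, n::Int) = collect(1:n)  # needs `n`: a Colon has no length of its own

# Return the conditioned value(s) for `x[idx]`, or `nothing` if any position
# requested by `idx` has not been conditioned.
function lookup(conditioned::AbstractDict, idx, n::Int)
    known = Dict{Int,Float64}()
    for (key, val) in conditioned
        positions = concretize(key, n)
        vals = val isa AbstractVector ? val : [val]
        length(positions) == length(vals) ||
            throw(DimensionMismatch("index $key covers $(length(positions)) positions"))
        for (p, v) in zip(positions, vals)
            known[p] = v
        end
    end
    wanted = concretize(idx, n)
    all(p -> haskey(known, p), wanted) || return nothing
    vals = [known[p] for p in wanted]
    return idx isa Int ? only(vals) : vals
end

# Conditioning on x[1] and x[2] separately answers queries for x[1:2] and x[:].
by_element = Dict{Any,Any}(1 => 0.0, 2 => 0.0)
lookup(by_element, 1:2, 2)      # [0.0, 0.0]
lookup(by_element, Colon(), 2)  # [0.0, 0.0]

# Conditioning on x[:] answers a query for a single element.
by_colon = Dict{Any,Any}(Colon() => [0.0, 0.0])
lookup(by_colon, 1, 2)          # 0.0

# Partial conditioning is reported as "not conditioned".
lookup(Dict{Any,Any}(1 => 0.0), 1:2, 2)  # nothing
```

In this toy version every cell of the table would be true, because all forms are compared as positions. It ignores overlapping keys with conflicting values, which a real implementation would have to reject or resolve, and the Colon case only works because `n` is passed in explicitly.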
Banning ranges feels like a loss of functionality that some users probably enjoy, but I'm less sure that banning Colons would be. I think they only come about when you have multidimensional variables and want to do things like x[:,i] ~ multivariate_dist, which I presume is quite rare. It's also easy to replace a : with a range.
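To make the point about replacing a : with a range concrete: on any dimension, `:` selects exactly the same elements as an explicit range over that dimension. The check below is plain Julia on an ordinary matrix, with no DynamicPPL involved; the model-level rewrite appears only in comments and is an assumption about how users would adapt their code.

```julia
# In a model, a colon-indexed statement such as
#     x[:, i] ~ multivariate_dist
# could be rewritten with an explicit range, e.g.
#     x[1:size(x, 1), i] ~ multivariate_dist
# (model code shown as comments only; it is not executed here).

x = reshape(collect(1.0:6.0), 3, 2)   # a 3×2 matrix

for i in axes(x, 2)
    # The colon and the explicit ranges pick out the same column.
    @assert x[:, i] == x[1:size(x, 1), i]
    # `axes(x, 1)` also covers arrays whose indices do not start at 1.
    @assert x[:, i] == x[axes(x, 1), i]
end
```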
How hard option 4 is depends on how we end up implementing VarNamedTuples. I think handling ranges properly shouldn't be too difficult, but colons could be a major pain. Colons are also the reason for some code complexity and confusion, like the fact that getvalue and hasvalue take a Distribution as an argument.
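Why colons are harder can be seen without any DynamicPPL machinery: a range such as 1:3 names concrete positions on its own, whereas : only does so once the size of the indexed dimension is known. A minimal sketch, assuming a hypothetical `resolve_indices` helper (not part of DynamicPPL) that takes the shape from the caller; supplying that shape is presumably the role the Distribution argument plays for getvalue and hasvalue.

```julia
# Hypothetical helper (not DynamicPPL API): replace every Colon in an index
# tuple by the full range of the matching dimension. Without `dims`, there is
# no way to tell that x[:, 2] and x[1:3, 2] refer to the same elements.
resolve(::Colon, d::Int) = 1:d
resolve(i, ::Int) = i
resolve_indices(inds::Tuple, dims::Tuple) = map(resolve, inds, dims)

dims = (3, 2)                                   # e.g. a shape implied by a distribution
resolve_indices((Colon(), 2), dims)             # (1:3, 2)
resolve_indices((Colon(), 2), dims) == (1:3, 2) # true: same concrete index after resolution
resolve_indices((Colon(), 2), (4, 2))           # (1:4, 2): same `:`, different meaning

# The resolved indices select the same elements as the original colon.
x = reshape(collect(1.0:6.0), 3, 2)
x[resolve_indices((Colon(), 2), size(x))...] == x[:, 2]  # true
```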
I'm currently leaning towards banning Colons and fixing the rest with option 4, with the goal of having the whole table show true. Opinions welcome.
I've written this issue as being about conditioning, but I presume fix works exactly the same way, and some of the same issues come up internally in places like VarInfo and FastLDF.