
conditioning and IndexLenses #1148

@mhauru


I keep getting confused about how I can and cannot condition variables when it comes to IndexLenses, so I wrote a little snippet that makes a table. The idea is to write the model

@model function f_vec()
    x ~ MvNormal(fill(mean, 2), I*std^2)
    return x
end

using different ~ expressions like x[1] ~ Normal(mean, std) and x[1:2] .~ Normal(mean, std), and then condition them with different VarNames like condition(m, (; x=[0.0, 0.0])) and condition(m, Dict(@varname(x[1]) => 0.0, @varname(x[2]) => 0.0)), and see which combinations actually apply the conditioning.

Snippet:

module Conds

using DynamicPPL, Distributions, LinearAlgebra, PrettyTables

mean = 100.0
std = 1e-6
mv_dist = MvNormal(fill(mean, 2), I*std^2)
uv_dist = Normal(mean, std)

@model function f_vec()
    x ~ mv_dist
    return x
end

@model function f_range()
    x = Vector{Float64}(undef, 2)
    x[1:2] ~ mv_dist
    return x
end

@model function f_colon()
    x = Vector{Float64}(undef, 2)
    x[:] ~ mv_dist
    return x
end

@model function f_inds()
    x = Vector{Float64}(undef, 2)
    x[1] ~ uv_dist
    x[2] ~ uv_dist
    return x
end

@model function f_vec_dot()
    x = Vector{Float64}(undef, 2)
    x .~ uv_dist
    return x
end

@model function f_range_dot()
    x = Vector{Float64}(undef, 2)
    x[1:2] .~ uv_dist
    return x
end

@model function f_colon_dot()
    x = Vector{Float64}(undef, 2)
    x[:] .~ uv_dist
    return x
end

models = [
    ("x ~", f_vec()),
    ("x[1:2] ~", f_range()),
    ("x[:] ~", f_colon()),
    ("x[1] ~", f_inds()),
    ("x .~", f_vec_dot()),
    ("x[1:2] .~", f_range_dot()),
    ("x[:] .~", f_colon_dot()),
]
conditionings = [
    ("x = [0.0, 0.0]", m -> condition(m, (; x=[0.0, 0.0]))),
    ("x[1] = 0.0, x[2] = 0.0", m -> condition(m, Dict(@varname(x[1]) => 0.0, @varname(x[2]) => 0.0))),
    ("x[1:2] = [0.0, 0.0]", m -> condition(m, Dict(@varname(x[1:2]) => [0.0, 0.0]))),
    ("x[:] = [0.0, 0.0]", m -> condition(m, Dict(@varname(x[:]) => [0.0, 0.0]))),
]

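# For each (model, conditioning) pair, run the conditioned model once. If the
# conditioning was applied, x equals [0.0, 0.0]; otherwise x is sampled from
# the prior and is ≈ mean (100.0), so a tight tolerance tells the two apart.
data = Array{Bool,2}(undef, length(models), length(conditionings))
for (i_model, (model_name, model)) in enumerate(models)
    for (i_cond, (cond_name, cond_func)) in enumerate(conditionings)
        conditioned_model = cond_func(model)
        retval = conditioned_model()
        success = all(abs.(retval) .< 1e-10)
        data[i_model, i_cond] = success
        full_name = "$(model_name) with $(cond_name)"
        println("$full_name => $retval")
        @show retval .≈ 0.0
    end
end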

PrettyTables.pretty_table(
    data;
    column_labels = [cond[1] for cond in conditionings],
    row_labels = [model[1] for model in models],
    title = "Which ~ expression (rows) supports which conditioning (columns)"
)

end

Table:

                 Which ~ expression (rows) supports which conditioning (columns)
┌───────────┬────────────────┬────────────────────────┬─────────────────────┬───────────────────┐
│           │ x = [0.0, 0.0] │ x[1] = 0.0, x[2] = 0.0 │ x[1:2] = [0.0, 0.0] │ x[:] = [0.0, 0.0] │
├───────────┼────────────────┼────────────────────────┼─────────────────────┼───────────────────┤
│       x ~ │           true │                  false │               false │             false │
│  x[1:2] ~ │           true │                  false │                true │             false │
│    x[:] ~ │           true │                  false │               false │             false │
│    x[1] ~ │           true │                   true │               false │             false │
│      x .~ │           true │                   true │               false │             false │
│ x[1:2] .~ │           true │                  false │                true │             false │
│   x[:] .~ │           true │                  false │               false │             false │
└───────────┴────────────────┴────────────────────────┴─────────────────────┴───────────────────┘

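It's not great. Conditioning with x = [0.0, 0.0] is the only form that works reliably. Conditioning on x[1] and x[2] individually works only for x[1] ~ and x .~, conditioning on x[1:2] works only when the ~ statement itself uses x[1:2], and x[:] = [0.0, 0.0] never works.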

There are various reasons why things are the way they are, but I'm wondering what we want to do about this. Some options:

  1. Ban ranges in VarNames, so no x[1:2] ~ allowed.
  2. Ban Colon in VarNames, so no x[:] ~ allowed.
  3. Ban both ranges and Colon.
  4. Write much more complicated logic for checking when we have a conditioned value and when we don't.

Banning ranges feels like a loss of functionality that some users probably enjoy, but I'm less sure that banning Colons would be. I think Colons only come up when you have multidimensional variables and want to do things like x[:,i] ~ multivariate_dist, which I presume is quite rare. It's also easy to replace a : with a range.
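For instance, a column-wise model over a 2×n matrix that might otherwise be written with `x[:, i] ~ col_dist` can spell the rows out as an explicit range instead. This is a minimal sketch: `f_matrix` and `col_dist` are illustrative names, and it assumes the number of rows (2) is known when the model is written.

```julia
using DynamicPPL, Distributions, LinearAlgebra

# Illustrative 2-dimensional distribution for each column of x.
const col_dist = MvNormal(zeros(2), I)

@model function f_matrix(n)
    x = Matrix{Float64}(undef, 2, n)
    for i in 1:n
        # Explicit range instead of `x[:, i] ~ col_dist`.
        x[1:2, i] ~ col_dist
    end
    return x
end

# Calling the model samples from the prior and returns x.
x = f_matrix(3)()
```

The trade-off is that the row count is now written twice (in the matrix allocation and in the range), but the resulting VarNames only contain concrete indices.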

How hard option 4 is depends on how we end up implementing VarNamedTuples. I think handling ranges properly shouldn't be too difficult, but colons could be a major pain. Colons are also the source of some code complexity and confusion, like the fact that getvalue and hasvalue take a Distribution as an argument.
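To make the range part of option 4 concrete, here is a minimal sketch in plain Julia, restricted to one-dimensional vectors. The names `positions` and `lookup` are hypothetical and are not DynamicPPL or AbstractPPL API. The sketch flattens every conditioned entry into per-element values, so that x[1] = 0.0, x[2] = 0.0 can answer a query for x[1:2] and vice versa. It also shows where a Colon forces extra information in: `:` cannot be turned into concrete positions without the length of the vector, which has to come from somewhere else (for example the distribution on the right-hand side of `~`).

```julia
# Hypothetical sketch, not DynamicPPL/AbstractPPL API: 1-D indices only.

# Concrete element positions selected by an index into a vector.
positions(i::Integer, n) = i:i
positions(r::AbstractUnitRange{<:Integer}, n) = r
# A Colon selects "everything", which only has concrete positions once the
# length `n` of the vector is known.
positions(::Colon, n::Integer) = 1:n
positions(::Colon, ::Nothing) =
    error("cannot resolve `:` without knowing the length of the vector")

"""
    lookup(conds, query; n=nothing)

`conds` is a collection of `index => value` pairs, each meaning `x[index] = value`.
Return the value they imply for `x[query]`, or `nothing` if any element of
`x[query]` is unconditioned. `n` is the length of `x`, needed only for `:`.
"""
function lookup(conds, query; n=nothing)
    known = Dict{Int,Float64}()
    # Flatten every conditioned block into element-wise values.
    for (idx, val) in conds
        pos = positions(idx, n)
        vals = idx isa Integer ? [val] : val
        length(vals) == length(pos) || error("value length does not match index $idx")
        for (p, v) in zip(pos, vals)
            known[p] = v
        end
    end
    # The query is covered only if every element it touches is known.
    qpos = positions(query, n)
    all(p -> haskey(known, p), qpos) || return nothing
    return query isa Integer ? known[query] : [known[p] for p in qpos]
end

# lookup([1 => 0.0, 2 => 0.0], 1:2)          # → [0.0, 0.0]
# lookup([(1:2) => [0.0, 0.0]], 2)           # → 0.0
# lookup([1 => 0.0], 1:2)                    # → nothing
# lookup([Colon() => [0.0, 0.0]], 1; n = 2)  # → 0.0
```

Integers and ranges resolve on their own, while every `:` needs `n` passed in. A real implementation would also have to handle multidimensional indices, where `x[:, i]` needs the size of each dimension.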

I'm currently leaning towards banning Colons and fixing the rest with option 4, with the goal of having the whole table show true. Opinions welcome.

I've written this issue as being about conditioning, but I presume fix behaves exactly the same way, and some of the same issues come up internally in places like VarInfo and FastLDF.
