stackoverflow when running inference with Vector{Any} observations. #10

Open
albertpod opened this issue Jan 15, 2024 · 1 comment

@albertpod

I am getting a cryptic StackOverflowError when running inference with Vector{Any} observations.

@model function example_bug()
    y = datavar(Vector{Any})
    x ~ MvNormalMeanCovariance(zeros(2), diageye(2))
    y ~ MvNormalMeanCovariance(x, diageye(2))
end

result = infer(model=example_bug(), data=(y = Any[1, 2.0], ))

Running inference yields:

ERROR: StackOverflowError:
Stacktrace:
     [1] mean(itr::PointMass{Vector{Any}})
       @ Statistics ~/.julia/juliaup/julia-1.10.0+0.x64.apple.darwin14/share/julia/stdlib/v1.10/Statistics/src/Statistics.jl:44
     [2] mean(fn::typeof(identity), distribution::PointMass{Vector{Any}})
       @ BayesBase ~/.julia/packages/BayesBase/ZObVB/src/densities/pointmass.jl:26
--- the last 2 lines are repeated 39990 more times ---
 [79983] mean(itr::PointMass{Vector{Any}})
       @ Statistics ~/.julia/juliaup/julia-1.10.0+0.x64.apple.darwin14/share/julia/stdlib/v1.10/Statistics/src/Statistics.jl:44
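
Reading the trace, the two alternating frames suggest a mutual recursion: the generic `mean(itr)` in Statistics forwards to `mean(identity, itr)`, and the `mean(::typeof(identity), ::PointMass)` method forwards back to `mean(distribution)`. The following is a minimal sketch of that pattern, not BayesBase's actual code. `ToyPointMass` and `overflows` are hypothetical names, and the assumption that a specialised `mean` exists only for real-valued points is inferred from the discussion rather than confirmed from the source.

```julia
using Statistics

# Hypothetical stand-in for a point-mass distribution; not BayesBase's actual type.
struct ToyPointMass{P}
    point::P
end

# Mirrors the second frame in the trace: mean(identity, d) forwards to mean(d).
Statistics.mean(::typeof(identity), d::ToyPointMass) = mean(d)

# Assumption: a specialised mean exists only for real-valued points.
Statistics.mean(d::ToyPointMass{<:Real}) = d.point
Statistics.mean(d::ToyPointMass{<:AbstractVector{<:Real}}) = d.point

# Returns true when calling mean on d ends in a StackOverflowError.
function overflows(d)
    try
        mean(d)
        return false
    catch err
        return err isa StackOverflowError
    end
end

# A concretely typed point hits the specialised method.
ok = mean(ToyPointMass([1.0, 2.0]))

# A Vector{Any} point matches no specialised method, so the generic
# Statistics.mean(itr) and the identity-forwarding method call each other.
overflowed = overflows(ToyPointMass(Any[1, 2.0]))
```

In this sketch, `Vector{Any}` is not a subtype of `AbstractVector{<:Real}`, so dispatch never reaches a method that terminates the recursion.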

To work around this error, we need to change both the datavar type and the data input to use a concrete element type, as follows. The workaround is fine, but the error should be handled better.

@model function example_bug()
    y = datavar(Vector{Float64})
    x ~ MvNormalMeanCovariance(zeros(2), diageye(2))
    y ~ MvNormalMeanCovariance(x, diageye(2))
end

result = infer(model=example_bug(), data=(y = [1, 2.0], ))
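
When the observations arrive as a Vector{Any} (for example, after parsing), one way to apply the workaround above is to convert them to a concrete element type before passing them on. This is a small sketch using only Base Julia; the variable names are illustrative, and it assumes every element is convertible to Float64.

```julia
# Observations with an abstract element type, as in the failing example.
y_any = Any[1, 2.0]

# Convert to a concretely typed vector; this throws if an element
# cannot be converted to Float64.
y_concrete = convert(Vector{Float64}, y_any)

# y_concrete can then be passed as data=(y = y_concrete, )
# together with datavar(Vector{Float64}).
```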
@bvdmitri

bvdmitri commented Apr 3, 2024

I just looked into this problem (sorry, I completely forgot about it), thinking it would be easy to fix, but it's not that simple. Basically, we can sort out mean, but the cov/std/pdf/logpdf methods are trickier. Fixing only mean is not really a solution, because I suppose the rules will call cov as well. The problem is that these methods need one(T) and zero(T) to be defined, which isn't the case when T = Any. So, in the code, we have explicitly restricted T to be a number type (Real), but that causes a confusing error, because no method exists when T is not a subtype of Real. Maybe we should just prevent constructing a PointMass if T isn't a real number. What do you think?
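
The last suggestion could look roughly like the sketch below: reject unsupported element types at construction time, so the user sees a clear error instead of a stack overflow later. `GuardedPointMass` and `supported_point` are hypothetical names used only for illustration; this is not BayesBase's definition of PointMass, and the set of supported point types is an assumption.

```julia
# Hypothetical predicate: which points have a real-valued element type?
supported_point(::Real) = true
supported_point(::AbstractArray{<:Real}) = true
supported_point(::Any) = false

# Hypothetical stand-in for a point-mass type that validates its input.
struct GuardedPointMass{P}
    point::P

    function GuardedPointMass(point::P) where {P}
        supported_point(point) || throw(ArgumentError(
            "point must be a Real or an array of Reals, got $(P); " *
            "convert the data to a concrete numeric type first"))
        return new{P}(point)
    end
end

# zero(T) has no method for T = Any, which is why cov/std-style methods
# cannot be written generically for PointMass{Vector{Any}}.
zero_any_fails = try
    zero(Any)
    false
catch err
    err isa MethodError
end

# Construction succeeds for a concrete numeric vector...
good = GuardedPointMass([1.0, 2.0])

# ...and fails early, with a readable message, for Vector{Any}.
rejected = try
    GuardedPointMass(Any[1, 2.0])
    false
catch err
    err isa ArgumentError
end
```

The trade-off is that such a guard also rejects inputs like `Any[1, 2.0]` whose elements happen to be numeric; an alternative would be to attempt a conversion first and throw only if that fails.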
