
Difficulty sampling a model with truncated normal Likelihood #1722

Open
dlakelan opened this issue Oct 27, 2021 · 6 comments

@dlakelan
Contributor

I've been having problems sampling models that use truncated normal distributions.

This minimal working example from the discourse.julialang.org discussion shows that the initial vector passed in seems not to be used. The first printout shows not the provided initial values but values with a very small standard deviation of the errors, so the sampler immediately runs into numerical issues, reporting zero probability at the initial point.

https://discourse.julialang.org/t/making-turing-fast-with-large-numbers-of-parameters/69072/99?u=dlakelan

using Pkg
Pkg.activate(".")
using Turing, DataFrames, DataFramesMeta, LazyArrays, Distributions, DistributionsAD
using ReverseDiff, Memoization

## Every few hours a random staff member comes and gets a random
## patient to bring them outside to a garden through a door that has a
## scale, sometimes using a wheelchair, sometimes not. Knowing the
## total weight of the two people and the wheelchair, plus some errors
## (from the scale measurements), infer the individual weights of all
## individuals and the weight of the wheelchair.

nstaff = 100
npat = 100
staffids = collect(1:nstaff)
patientids = collect(1:npat)
staffweights = rand(Normal(150,30),length(staffids))
patientweights = rand(Normal(150,30),length(patientids))
wheelchairwt = 15
nobs = 300

data = DataFrame(staff=rand(staffids,nobs),patient=rand(patientids,nobs))
data.usewch = rand(0:1,nobs)
data.totweights = [staffweights[data.staff[i]] + patientweights[data.patient[i]] for i in 1:nrow(data)] .+ data.usewch .* wheelchairwt .+ rand(Normal(0.0,20.0),nrow(data))


Turing.setadbackend(:reversediff)
Turing.setrdcache(true)
Turing.emptyrdcache()



@model function estweights(nstaff,staffid,npatients,patientid,usewch,totweight)
    wcwt ~ Gamma(20.0,15.0/19)
    staffweights ~ filldist(Normal(150,30),nstaff)
    patientweights ~ filldist(Normal(150,30),npatients)
    
    totweight ~ MvNormal(view(staffweights,staffid) .+ view(patientweights,patientid) .+ usewch .* wcwt,20.0)
end



@model function estweights2(nstaff,staffid,npatients,patientid,usewch,totweight)
    wcwt ~ Gamma(20.0,15.0/19)
    staffweights ~ filldist(truncated(Normal(150,30),90.0,Inf),nstaff)
    patientweights ~ filldist(truncated(Normal(150,30),90.0,Inf),npatients)
    
    totweight ~ arraydist([Gamma(15,(staffweights[staffid[i]] + patientweights[patientid[i]] + usewch[i] * wcwt)/14) for i in 1:length(totweight)])
end


@model function estweights3(nstaff,staffid,npatients,patientid,usewch,totweight)
    wcwt ~ Gamma(20.0,15.0/19)
    measerr ~ Gamma(10.0,20.0/9)
    staffweights ~ filldist(truncated(Normal(150,30),90.0,Inf),nstaff)
    patientweights ~ filldist(truncated(Normal(150,30),90.0,Inf),npatients)
    
    totweight ~ arraydist([truncated(Normal(staffweights[staffid[i]] + patientweights[patientid[i]] + usewch[i] * wcwt, measerr),0.0,Inf) for i in 1:length(totweight)])
end

function truncatenormal(a,b)::UnivariateDistribution
    truncated(Normal(a,b),0.0,Inf)
end


@model function estweights3lazy(nstaff,staffid,npatients,patientid,usewch,totweight)

    wcwt ~ Gamma(20.0,15.0/19)
    measerr ~ Gamma(10.0,20.0/9)
    staffweights ~ filldist(truncated(Normal(150,30),90.0,Inf),nstaff)
    patientweights ~ filldist(truncated(Normal(150,30),90.0,Inf),npatients)
    theta = LazyArray(@~ view(staffweights,staffid) .+ view(patientweights,patientid) .+ usewch .* wcwt)
    println("""Evaluating model... 
wcwt: $wcwt
measerr: $measerr
exstaffweights: $(staffweights[1:10])
expatweights: $(patientweights[1:10])
""")
    
    totweight ~ arraydist(LazyArray(@~ truncatenormal.(theta,measerr)))
end



@model function estweights4(nstaff,staffid,npatients,patientid,usewch,totweight)
    wcwt ~ Gamma(20.0,15.0/19)
    measerr ~ Gamma(10.0,20.0/9)
    staffweights ~ filldist(truncated(Normal(150,30),90.0,Inf),nstaff)
    patientweights ~ filldist(truncated(Normal(150,30),90.0,Inf),npatients)
    means = view(staffweights,staffid) .+ view(patientweights,patientid) .+ usewch .* wcwt
    totweight .~ Gamma.(12,means./11)
end




@model function estweightslazygamma(nstaff,staffid,npatients,patientid,usewch,totweight)
    wcwt ~ Gamma(20.0,15.0/19)
    measerr ~ Gamma(10.0,20.0/9)
    staffweights ~ filldist(truncated(Normal(150,30),90.0,Inf),nstaff)
    patientweights ~ filldist(truncated(Normal(150,30),90.0,Inf),npatients)
    theta = LazyArray(@~ view(staffweights,staffid) .+ view(patientweights,patientid) .+ usewch .* wcwt)
    totweight ~ arraydist(LazyArray(@~ Gamma.(15, theta ./ 14)))
end




# ch1 = sample(estweights(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)


# ch2 = sample(estweights2(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)

# ch3 = sample(estweights3(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)

ch3l = sample(estweights3lazy(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75,init_ϵ=.002),1000;
              init_theta = vcat([15.0,20.0],staffweights,patientweights))

# ch4 = sample(estweights4(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)


#ch5 = sample(estweightslazygamma(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)

When running this version, the initial printout does not show the values passed in via init_theta.

@devmotion
Member

I think this issue is relevant here: #1588

Most importantly, it seems you use the keyword argument init_theta. However, as also discussed in the linked issue, initial parameter values have to be specified with init_params.
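
For the MWE above, that call would look roughly like this (a sketch; the model, sampler settings, and initial vector are the same as in the original sample call, only the keyword changes):

ch3l = sample(estweights3lazy(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),
              NUTS(500,.75,init_ϵ=.002),1000;
              init_params = vcat([15.0,20.0],staffweights,patientweights))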

@dlakelan
Contributor Author

Aha. Well, that's disappointing; I've been running pre-optimizations and, following the docs, using those optimizations as the starting points...
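
For reference, that pre-optimization workflow with the corrected keyword would look roughly like this (a sketch, assuming Turing's Optim.jl-based mode estimation; the map_estimate.values.array accessor follows the example in the docs):

using Optim  # enables Turing's mode estimation interface

model = estweights3lazy(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights)
map_estimate = optimize(model, MAP())  # pre-optimization: find the MAP point

# use the MAP point as the starting point for NUTS
chain = sample(model, NUTS(500,.75), 1000; init_params = map_estimate.values.array)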

OK, so if I use init_params it does seem to evaluate at the initial vector, but then it complains as follows: first it prints a few plain reals, then suddenly switches to TrackedReals, and then complains about non-finite values. Specifically, it apparently thinks the parameter vector has gone non-finite, and so has the log(p), even though the printed values of the parameters I was checking are reasonable.

Evaluating model... 
wcwt: 0.37882044075836946
measerr: 1.2209697179924754
exstaffweights: [144.59651665320027, 139.26250118653655, 150.90239119517722, 171.0364350885989, 128.94644940913446, 116.41230791013996, 153.99758061107676, 158.27233225657514, 123.29445199120855, 130.95171040790126]
expatweights: [153.84183383480226, 117.09050452304707, 182.46653358442495, 144.2100653423684, 167.11682736617547, 121.57197068693307, 138.87488688049697, 167.95514160373503, 140.8483579031441, 168.57633784427617]

Evaluating model... 
wcwt: 15.0
measerr: 20.0
exstaffweights: [128.5120210513865, 188.86912870512919, 166.7190188565341, 144.91910704233473, 158.5787954759802, 193.11692512992863, 133.37311167397962, 111.20409894257753, 122.8398842784177, 137.08070752897862]
expatweights: [121.38295137055916, 118.59580388554596, 212.68126054123962, 157.15016895391787, 133.42273212231717, 189.0564931535782, 132.53391016542082, 158.18721536473618, 139.25611445597923, 143.00152296009045]

Evaluating model... 
wcwt: 15.0
measerr: 19.999999999999996
exstaffweights: [128.5120210513865, 188.86912870512919, 166.7190188565341, 144.9191070423347, 158.5787954759802, 193.11692512992857, 133.37311167397962, 111.20409894257753, 122.83988427841771, 137.08070752897862]
expatweights: [121.38295137055916, 118.59580388554596, 212.68126054123962, 157.15016895391787, 133.42273212231717, 189.05649315357823, 132.53391016542082, 158.1872153647362, 139.25611445597923, 143.00152296009045]

Evaluating model... 
wcwt: TrackedReal<ETu>(15.0, 0.0, D0C, ---)
measerr: TrackedReal<D5S>(19.999999999999996, 0.0, D0C, ---)
exstaffweights: ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}[TrackedReal<7Gz>(128.5120210513865, 0.0, D0C, 1, Cer), TrackedReal<1bK>(188.86912870512919, 0.0, D0C, 2, Cer), TrackedReal<DWX>(166.7190188565341, 0.0, D0C, 3, Cer), TrackedReal<Bs3>(144.9191070423347, 0.0, D0C, 4, Cer), TrackedReal<71A>(158.5787954759802, 0.0, D0C, 5, Cer), TrackedReal<FR7>(193.11692512992857, 0.0, D0C, 6, Cer), TrackedReal<1Af>(133.37311167397962, 0.0, D0C, 7, Cer), TrackedReal<LFq>(111.20409894257753, 0.0, D0C, 8, Cer), TrackedReal<3FK>(122.83988427841771, 0.0, D0C, 9, Cer), TrackedReal<FNM>(137.08070752897862, 0.0, D0C, 10, Cer)]
expatweights: ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}[TrackedReal<40n>(121.38295137055916, 0.0, D0C, 1, 7Ev), TrackedReal<66Y>(118.59580388554596, 0.0, D0C, 2, 7Ev), TrackedReal<C3o>(212.68126054123962, 0.0, D0C, 3, 7Ev), TrackedReal<FFa>(157.15016895391787, 0.0, D0C, 4, 7Ev), TrackedReal<IFo>(133.42273212231717, 0.0, D0C, 5, 7Ev), TrackedReal<1Dd>(189.05649315357823, 0.0, D0C, 6, 7Ev), TrackedReal<BJt>(132.53391016542082, 0.0, D0C, 7, 7Ev), TrackedReal<7kv>(158.1872153647362, 0.0, D0C, 8, 7Ev), TrackedReal<Kbf>(139.25611445597923, 0.0, D0C, 9, 7Ev), TrackedReal<PAj>(143.00152296009045, 0.0, D0C, 10, 7Ev)]

┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (false, true, false, true)
└ @ AdvancedHMC ~/.julia/packages/AdvancedHMC/HQHnm/src/hamiltonian.jl:47
┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (false, true, false, true)
└ @ AdvancedHMC ~/.julia/packages/AdvancedHMC/HQHnm/src/hamiltonian.jl:47
┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (false, false, false, false)
└ @ AdvancedHMC ~/.julia/packages/AdvancedHMC/HQHnm/src/hamiltonian.jl:47
┌ Warning: Incorrect ϵ = NaN; ϵ_previous = 0.2 is used instead.
└ @ AdvancedHMC.Adaptation ~/.julia/packages/AdvancedHMC/HQHnm/src/adaptation/stepsize.jl:125
┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (false, true, false, true)
└ @ AdvancedHMC ~/.julia/packages/AdvancedHMC/HQHnm/src/hamiltonian.jl:47
┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (false, false, false, false)
└ @ AdvancedHMC ~/.julia/packages/AdvancedHMC/HQHnm/src/hamiltonian.jl:47
┌ Warning: Incorrect ϵ = NaN; ϵ_previous = 0.2 is used instead.
└ @ AdvancedHMC.Adaptation ~/.julia/packages/AdvancedHMC/HQHnm/src/adaptation/stepsize.jl:125
┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (false, true, false, true)
└ @ AdvancedHMC ~/.julia/packages/AdvancedHMC/HQHnm/src/hamiltonian.jl:47

@dlakelan changed the title from "Initial vector seemingly not used?" to "Difficulty sampling a model with truncated normal Likelihood" on Oct 28, 2021
@noamsgl

noamsgl commented Jun 1, 2023

I'm getting similar errors, but with the tutorial from https://turing.ml/dev/tutorials/10-bayesian-differential-equations/

🎈 sde_lv.jl — Pluto.jl.pdf

@torfjelde
Member

A few points on this issue:

  1. When using samplers such as NUTS, which have an initial "adaptation phase", the initial point is the first point used when adaptation starts, but once we start sampling proper we might no longer be at that initial point.
  2. The reason you're getting complaints about θ is not necessarily that the parameter itself is infinite; it could also be that the gradient is infinite (the current message does not specify, which is unfortunate 😕). One way to check is sketched below.
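
To tell whether it is the log density or its gradient that goes non-finite at a given point, you can evaluate both directly, outside the sampler. A minimal sketch, assuming a recent Turing/DynamicPPL with the LogDensityProblems.jl integration (the wrapper works in the model's constrained space, so the point must lie inside the truncated supports):

using DynamicPPL, LogDensityProblems, LogDensityProblemsAD, ForwardDiff

# wrap the model from the MWE above as a log-density problem
model = estweights3lazy(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights)
ℓ = DynamicPPL.LogDensityFunction(model)
∇ℓ = LogDensityProblemsAD.ADgradient(:ForwardDiff, ℓ)

# the same initial vector that was passed to `sample`
# (ordering follows the model: wcwt, measerr, staffweights, patientweights)
θ0 = vcat([15.0, 20.0], staffweights, patientweights)

lp, grad = LogDensityProblems.logdensity_and_gradient(∇ℓ, θ0)
@show isfinite(lp) all(isfinite, grad)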

@noamsgl

noamsgl commented Jun 3, 2023

Thanks for the response.

  1. I copied and pasted the exact tutorial from the Turing.jl website, using all of the latest packages. How come we are getting this difference?

  2. How can I get a more informative error message, or go about debugging this in general?

@torfjelde
Member

Sorry, this went under my radar!

I copied and pasted the exact tutorial from the Turing.jl website, using all of the latest packages. How come we are getting this difference?

Okay, that's weird. Will have a look.

How can I get a more informative error message, or go about debugging this in general?

Pfft.. This is a bit difficult without touching internals. But you're right, we should have a good way of debugging these things. Let me think and I'll get back to you!
