Skip to content

Commit

Permalink
Use OrderedDict for pointwise_loglikelihoods instead of Dict (#475)
Browse files Browse the repository at this point in the history
What the title says!
  • Loading branch information
torfjelde committed May 15, 2023
1 parent 0ffa0e5 commit 3e1204d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 20 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.22.3"
version = "0.22.4"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
44 changes: 25 additions & 19 deletions src/loglikelihoods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ struct PointwiseLikelihoodContext{A,Ctx} <: AbstractContext
end

function PointwiseLikelihoodContext(
likelihoods=Dict{VarName,Vector{Float64}}(),
likelihoods=OrderedDict{VarName,Vector{Float64}}(),
context::AbstractContext=LikelihoodContext(),
)
return PointwiseLikelihoodContext{typeof(likelihoods),typeof(context)}(
Expand All @@ -20,7 +20,7 @@ function setchildcontext(context::PointwiseLikelihoodContext, child)
end

function Base.push!(
context::PointwiseLikelihoodContext{Dict{VarName,Vector{Float64}}},
context::PointwiseLikelihoodContext{<:AbstractDict{VarName,Vector{Float64}}},
vn::VarName,
logp::Real,
)
Expand All @@ -30,13 +30,15 @@ function Base.push!(
end

function Base.push!(
context::PointwiseLikelihoodContext{Dict{VarName,Float64}}, vn::VarName, logp::Real
context::PointwiseLikelihoodContext{<:AbstractDict{VarName,Float64}},
vn::VarName,
logp::Real,
)
return context.loglikelihoods[vn] = logp
end

function Base.push!(
context::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}},
context::PointwiseLikelihoodContext{<:AbstractDict{String,Vector{Float64}}},
vn::VarName,
logp::Real,
)
Expand All @@ -46,13 +48,15 @@ function Base.push!(
end

function Base.push!(
context::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::VarName, logp::Real
context::PointwiseLikelihoodContext{<:AbstractDict{String,Float64}},
vn::VarName,
logp::Real,
)
return context.loglikelihoods[string(vn)] = logp
end

function Base.push!(
context::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}},
context::PointwiseLikelihoodContext{<:AbstractDict{String,Vector{Float64}}},
vn::String,
logp::Real,
)
Expand All @@ -62,7 +66,9 @@ function Base.push!(
end

function Base.push!(
context::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::String, logp::Real
context::PointwiseLikelihoodContext{<:AbstractDict{String,Float64}},
vn::String,
logp::Real,
)
return context.loglikelihoods[vn] = logp
end
Expand Down Expand Up @@ -126,11 +132,11 @@ end
"""
pointwise_loglikelihoods(model::Model, chain::Chains, keytype = String)
Runs `model` on each sample in `chain` returning a `Dict{String, Matrix{Float64}}`
Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}`
with keys corresponding to symbols of the observations, and values being matrices
of shape `(num_chains, num_samples)`.
`keytype` specifies what the type of the keys used in the returned `Dict` are.
`keytype` specifies what the type of the keys used in the returned `OrderedDict` are.
Currently, only `String` and `VarName` are supported.
# Notes
Expand Down Expand Up @@ -179,25 +185,25 @@ julia> model = demo(randn(3), randn());
julia> chain = sample(model, MH(), 10);
julia> pointwise_loglikelihoods(model, chain)
Dict{String,Array{Float64,2}} with 4 entries:
"xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
OrderedDict{String,Array{Float64,2}} with 4 entries:
"xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
"xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
"xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
"y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
julia> pointwise_loglikelihoods(model, chain, String)
Dict{String,Array{Float64,2}} with 4 entries:
"xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
OrderedDict{String,Array{Float64,2}} with 4 entries:
"xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
"xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
"xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
"y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
julia> pointwise_loglikelihoods(model, chain, VarName)
Dict{VarName,Array{Float64,2}} with 4 entries:
xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
y => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
OrderedDict{VarName,Array{Float64,2}} with 4 entries:
xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
xs[3] => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
y => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
```
## Broadcasting
Expand All @@ -224,7 +230,7 @@ julia> ℓ = pointwise_loglikelihoods(m, VarInfo(m)); first.((ℓ[@varname(x[1])
function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) where {T}
# Get the data by executing the model once
vi = VarInfo(model)
context = PointwiseLikelihoodContext(Dict{T,Vector{Float64}}())
context = PointwiseLikelihoodContext(OrderedDict{T,Vector{Float64}}())

iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
for (sample_idx, chain_idx) in iters
Expand All @@ -237,15 +243,15 @@ function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String)

niters = size(chain, 1)
nchains = size(chain, 3)
loglikelihoods = Dict(
loglikelihoods = OrderedDict(
varname => reshape(logliks, niters, nchains) for
(varname, logliks) in context.loglikelihoods
)
return loglikelihoods
end

function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo)
context = PointwiseLikelihoodContext(Dict{VarName,Vector{Float64}}())
context = PointwiseLikelihoodContext(OrderedDict{VarName,Vector{Float64}}())
model(varinfo, context)
return context.loglikelihoods
end

0 comments on commit 3e1204d

Please sign in to comment.