Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ The only flag other than `"del"` that `Metadata` ever used was `"trans"`. Thus t
The `resume_from=chn` keyword argument to `sample` has been removed; please use `initial_state=DynamicPPL.loadstate(chn)` instead.
`loadstate` is exported from DynamicPPL.

### Change of default keytype of `pointwise_logdensities`

The functions `pointwise_prior_logdensities`, `pointwise_logdensities`, and `pointwise_loglikelihoods` return dictionaries for which the keys are model variables, and the key type is either `VarName` or `String`. This release changes the default from `String` to `VarName`.

**Other changes**

### `predict(model, chain; include_all)`
Expand Down
30 changes: 15 additions & 15 deletions src/pointwise_logdensities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,13 @@ end
::Val{whichlogprob}=Val(:both),
)

Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}`
with keys corresponding to symbols of the variables, and values being matrices
of shape `(num_chains, num_samples)`.
Runs `model` on each sample in `chain` returning a `OrderedDict{VarName, Matrix{Float64}}`
with keys being model variables and values being matrices of shape
`(num_chains, num_samples)`.

`keytype` specifies what the type of the keys used in the returned `OrderedDict` are.
Currently, only `String` and `VarName` are supported. `whichlogprob` specifies
which log-probabilities to compute. It can be `:both`, `:prior`, or
Currently, only `String` and `VarName` are supported, with `VarName` being the default.
`whichlogprob` specifies which log-probabilities to compute. It can be `:both`, `:prior`, or
`:likelihood`.

See also: [`pointwise_loglikelihoods`](@ref), [`pointwise_loglikelihoods`](@ref).
Expand Down Expand Up @@ -177,13 +177,13 @@ julia> # A chain with 3 iterations.
);

julia> pointwise_logdensities(model, chain)
OrderedDict{String, Matrix{Float64}} with 6 entries:
"s" => [-0.802775; -1.38222; -2.09861;;]
"m" => [-8.91894; -7.51551; -7.46824;;]
"xs[1]" => [-5.41894; -5.26551; -5.63491;;]
"xs[2]" => [-2.91894; -3.51551; -4.13491;;]
"xs[3]" => [-1.41894; -2.26551; -2.96824;;]
"y" => [-0.918939; -1.51551; -2.13491;;]
OrderedDict{VarName, Matrix{Float64}} with 6 entries:
s => [-0.802775; -1.38222; -2.09861;;]
m => [-8.91894; -7.51551; -7.46824;;]
xs[1] => [-5.41894; -5.26551; -5.63491;;]
xs[2] => [-2.91894; -3.51551; -4.13491;;]
xs[3] => [-1.41894; -2.26551; -2.96824;;]
y => [-0.918939; -1.51551; -2.13491;;]

julia> pointwise_logdensities(model, chain, String)
OrderedDict{String, Matrix{Float64}} with 6 entries:
Expand Down Expand Up @@ -225,7 +225,7 @@ julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])],
```
"""
function pointwise_logdensities(
model::Model, chain, ::Type{KeyType}=String, ::Val{whichlogprob}=Val(:both)
model::Model, chain, ::Type{KeyType}=VarName, ::Val{whichlogprob}=Val(:both)
) where {KeyType,whichlogprob}
# Get the data by executing the model once
vi = VarInfo(model)
Expand Down Expand Up @@ -283,7 +283,7 @@ including the likelihood terms.

See also: [`pointwise_logdensities`](@ref), [`pointwise_prior_logdensities`](@ref).
"""
function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) where {T}
function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=VarName) where {T}
return pointwise_logdensities(model, chain, T, Val(:likelihood))
end

Expand All @@ -301,7 +301,7 @@ including the prior terms.
See also: [`pointwise_logdensities`](@ref), [`pointwise_loglikelihoods`](@ref).
"""
function pointwise_prior_logdensities(
model::Model, chain, keytype::Type{T}=String
model::Model, chain, keytype::Type{T}=VarName
) where {T}
return pointwise_logdensities(model, chain, T, Val(:prior))
end
Expand Down
10 changes: 5 additions & 5 deletions test/pointwise_logdensities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,11 @@ end
loglikelihoods_pointwise = pointwise_loglikelihoods(model, chain)

# Check that they contain the correct variables.
@test all(string(vn) in keys(logjoints_pointwise) for vn in vns)
@test all(string(vn) in keys(logpriors_pointwise) for vn in vns)
@test !any(Base.Fix2(startswith, "x"), keys(logpriors_pointwise))
@test !any(string(vn) in keys(loglikelihoods_pointwise) for vn in vns)
@test all(Base.Fix2(startswith, "x"), keys(loglikelihoods_pointwise))
@test all(vn in keys(logjoints_pointwise) for vn in vns)
@test all(vn in keys(logpriors_pointwise) for vn in vns)
@test !any(Base.Fix1(subsumes, @varname(x)), keys(logpriors_pointwise))
@test !any(vn in keys(loglikelihoods_pointwise) for vn in vns)
@test all(Base.Fix1(subsumes, @varname(x)), keys(loglikelihoods_pointwise))

# Get the sum of the logjoints for each of the iterations.
logjoints = [
Expand Down
Loading