Skip to content

Commit

Permalink
Merge 8024f24 into 8f90f83
Browse files Browse the repository at this point in the history
  • Loading branch information
sethaxen committed Jul 30, 2023
2 parents 8f90f83 + 8024f24 commit 49186e7
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
EllipticalSliceSampling = "cad2338a-1db2-11e9-3401-43bc07c9ede2"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
InferenceObjects = "b5cf5a8d-e756-4ee3-b014-01d49d192c00"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Expand Down Expand Up @@ -53,6 +54,7 @@ DynamicHMC = "3.4"
DynamicPPL = "0.23"
EllipticalSliceSampling = "0.5, 1"
ForwardDiff = "0.10.3"
InferenceObjects = "0.2"
Libtask = "0.7, 0.8"
LogDensityProblems = "2"
LogDensityProblemsAD = "1.4"
Expand Down
85 changes: 85 additions & 0 deletions src/inference/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import LogDensityProblems
import LogDensityProblemsAD
import Random
import MCMCChains
using InferenceObjects: InferenceObjects
import StatsBase: predict

export InferenceAlgorithm,
Expand Down Expand Up @@ -70,6 +71,15 @@ export InferenceAlgorithm,
isgibbscomponent,
externalsampler

const turing_inferencedata_key_map = (
hamiltonian_energy = :energy,
hamiltonian_energy_error = :energy_error,
is_adapt = :tune,
max_hamiltonian_energy_error = :max_energy_error,
nom_step_size = :step_size_nom,
numerical_error = :diverging,
)

#######################
# Sampler abstraction #
#######################
Expand Down Expand Up @@ -450,6 +460,81 @@ end

DynamicPPL.loadstate(chain::MCMCChains.Chains) = chain.info[:samplerstate]

# Default InferenceObjects constructor
# This is type piracy!
function AbstractMCMC.bundle_samples(
ts::Vector,
model::AbstractModel,
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior},
state,
chain_type::Type{InferenceObjects.InferenceData};
group = spl isa SampleFromPrior ? :prior : :posterior,
save_state = false,
stats = missing,
dims=(;),
coords=(;),
kwargs...
)
sample = map(t -> map(v -> length(v[1]) == 1 ? v[1][1] : v[1], getparams(t)), ts)
sample_stats = map(_rename_sample_stats metadata, ts)

# Set up the info tuple.
attrs = OrderedDict{String,Any}()
if save_state
attrs["model"] = model
attrs["sampler"] = spl
attrs["samplerstate"] = state
end

# Merge in the timing info, if available
if !ismissing(stats)
attrs["start_time"] = stats.start
attrs["stop_time"] = stats.stop
end

# Get the average or final log evidence, if it exists.
le = getlogevidence(ts, spl, state)
if !ismissing(le)
attrs["log_evidence"] = le
end

# identify if this is posterior or prior
sample_stats_group = group === :prior ? :sample_stats_prior : :sample_stats

# InferenceData construction.
idata = InferenceObjects.convert_to_inference_data(
[sample];
group=group,
sample_stats_group => [sample_stats],
attrs=attrs,
dims=dims,
coords=coords,
)
return idata
end

function AbstractMCMC.chainsstack(c::AbstractVector{<:InferenceObjects.InferenceData})
nchains = length(c)
nchains == 1 && return c[1]
groups = map(keys(first(c))) do k
k => AbstractMCMC.chainsstack(map(idata -> idata[k], c))
end
return InferenceObjects.InferenceData(; groups...)
end
function AbstractMCMC.chainsstack(c::AbstractVector{<:InferenceObjects.Dataset})
nchains = length(c)
nchains == 1 && return c[1]
# TODO: gather our metadata into vectors instead of replacing
group = cat(c...; dims=:chain)
# give each chain a different index
return InferenceObjects.DimensionalData.set(group, :chain => Base.OneTo(nchains))
end

function _rename_sample_stats(stats::NamedTuple)
new_keys = map(k -> get(turing_inferencedata_key_map, k, k), keys(stats))
return NamedTuple{new_keys}(values(stats))
end

#######################################
# Concrete algorithm implementations. #
#######################################
Expand Down

0 comments on commit 49186e7

Please sign in to comment.