Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
devmotion and github-actions[bot] committed Dec 30, 2022
1 parent fb4a785 commit 17531c2
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions src/sample.jl
Expand Up @@ -12,8 +12,12 @@ function setprogress!(progress::Bool)
return progress
end

function StatsBase.sample(model_or_logdensity, sampler::AbstractSampler, N_or_isdone; kwargs...)
return StatsBase.sample(Random.default_rng(), model_or_logdensity, sampler, N_or_isdone; kwargs...)
function StatsBase.sample(
model_or_logdensity, sampler::AbstractSampler, N_or_isdone; kwargs...
)
return StatsBase.sample(
Random.default_rng(), model_or_logdensity, sampler, N_or_isdone; kwargs...
)
end

"""
Expand Down Expand Up @@ -72,11 +76,7 @@ Wrap the `logdensity` function in a [`LogDensityModel`](@ref), and call `sample`
The `logdensity` function has to support the [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface.
"""
function StatsBase.sample(
rng::Random.AbstractRNG,
logdensity,
sampler::AbstractSampler,
N_or_isdone;
kwargs...,
rng::Random.AbstractRNG, logdensity, sampler::AbstractSampler, N_or_isdone; kwargs...
)
return StatsBase.sample(rng, _model(logdensity), sampler, N_or_isdone; kwargs...)
end
Expand Down Expand Up @@ -145,10 +145,11 @@ function StatsBase.sample(
nchains::Integer;
kwargs...,
)
return StatsBase.sample(rng, _model(logdensity), sampler, parallel, N, nchains; kwargs...)
return StatsBase.sample(
rng, _model(logdensity), sampler, parallel, N, nchains; kwargs...
)
end


# Default implementations of regular and parallel sampling.

function mcmcsample(
Expand Down Expand Up @@ -593,7 +594,11 @@ tighten_eltype(x::Vector{Any}) = map(identity, x)

function _model(logdensity)
if LogDensityProblems.capabilities(logdensity) === nothing
throw(ArgumentError("the log density function does not support the LogDensityProblems.jl interface. Please implement the interface or provide a model of type `AbstractMCMC.AbstractModel`"))
throw(
ArgumentError(
"the log density function does not support the LogDensityProblems.jl interface. Please implement the interface or provide a model of type `AbstractMCMC.AbstractModel`",
),
)
end
return LogDensityModel(logdensity)
end
Expand Down

0 comments on commit 17531c2

Please sign in to comment.