Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compiler 3.0 #965

Merged
merged 67 commits into from
Dec 9, 2019
Merged
Show file tree
Hide file tree
Changes from 59 commits
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
7d5451c
VarName refactor
mohamed82008 Nov 17, 2019
e4365ef
minor cleanup
mohamed82008 Nov 17, 2019
f848eed
new compiler proof of concept
mohamed82008 Nov 18, 2019
f596e29
update model macro docstring
mohamed82008 Nov 18, 2019
e9d2a56
remove compiler2.jl and make it compiler.jl
mohamed82008 Nov 18, 2019
f46d530
replace @vi() with the varinfo input to the model
mohamed82008 Nov 18, 2019
f162905
user-input random variable names
mohamed82008 Nov 19, 2019
0b73b74
minor cleanup
mohamed82008 Nov 19, 2019
a73ed67
introduce the concept of a context
mohamed82008 Nov 19, 2019
c1b6803
minor cleanup
mohamed82008 Nov 19, 2019
59b8e9a
BatchContext to scale the log likelihood
mohamed82008 Nov 19, 2019
db597cd
minor cleanup
mohamed82008 Nov 19, 2019
9e7ad76
skipping logpdf when in LikelihoodContext
mohamed82008 Nov 19, 2019
6594ea4
partial missing data bug fix
mohamed82008 Nov 19, 2019
8a00ed3
fix autodiff for partial missing data
mohamed82008 Nov 19, 2019
84583b3
fix Kai's and Tor's comments
mohamed82008 Nov 20, 2019
48158ab
NoDist of a NamedDist is a NamedDist
mohamed82008 Nov 20, 2019
277e11c
rename assume_or_observe to tilde & add dot_tilde
mohamed82008 Nov 20, 2019
645dce2
inference tilde cleanup
mohamed82008 Nov 20, 2019
d8a1101
remove vector _tilde
mohamed82008 Nov 20, 2019
fd25d2b
some more cleanup
mohamed82008 Nov 20, 2019
2d8b860
remove repeated comment
mohamed82008 Nov 20, 2019
daff9c6
add FillArrays to deps
mohamed82008 Nov 20, 2019
873ecf1
change vec ~ tests to .~
mohamed82008 Nov 20, 2019
6e99ab0
fix compiler tests
mohamed82008 Nov 20, 2019
488b991
add logpdf macro for use in the model macro
mohamed82008 Nov 20, 2019
8cb9240
fix bugs
mohamed82008 Nov 21, 2019
4b8b50c
avoid evaluating the LHS or RHS twice
mohamed82008 Nov 21, 2019
4fb20c5
compiler and varinfo docs
mohamed82008 Nov 21, 2019
f82398c
remove spaces in varname indexing
mohamed82008 Nov 21, 2019
5e21a13
type stability fix
mohamed82008 Nov 22, 2019
6238adf
minor fix and docs update
mohamed82008 Nov 22, 2019
40c1e8d
remove unused arg in dot_tilde observe method
mohamed82008 Nov 22, 2019
6e8958e
shorten the docstring of the `@model` macro
mohamed82008 Nov 22, 2019
9a5572b
fix tests
mohamed82008 Nov 22, 2019
3ce0956
try dropping support for .~ on Julia 1.0 only
mohamed82008 Nov 22, 2019
e19d414
fix #760
mohamed82008 Nov 22, 2019
d57c594
use @. in compiler docs
mohamed82008 Nov 22, 2019
eadcef0
support @. and conditionally support .~ when valid
mohamed82008 Nov 22, 2019
0909b3b
minor cleanup
mohamed82008 Nov 22, 2019
09f6b1f
make ambiguity error say @. or .~
mohamed82008 Nov 22, 2019
ef2a05a
test @. by default and conditionally test .~
mohamed82008 Nov 22, 2019
7135b8f
support broadcasting ~ with mismatched array sizes
mohamed82008 Nov 23, 2019
f1f87b1
add .~ test for mismatched array sizes
mohamed82008 Nov 23, 2019
78338e1
test for throwing when input is missing
mohamed82008 Nov 23, 2019
0ed0d8b
add more tests
mohamed82008 Nov 23, 2019
13a83a9
Merge branch 'master' into mt/compiler3.0
mohamed82008 Nov 27, 2019
1342e04
fix merge
mohamed82008 Nov 27, 2019
40332e7
workaround string(:) == "Colon" in Julia 1.2
mohamed82008 Nov 27, 2019
ea5a5cc
fix the colon thing for real
mohamed82008 Nov 27, 2019
fc5f558
minor test fix
mohamed82008 Nov 27, 2019
6f5206f
add `@sampler()` to access the sampler in model
mohamed82008 Nov 27, 2019
f6721fd
increase sample size and lower atol in mh test
mohamed82008 Nov 28, 2019
22dab76
remove FillArrays dep and reorganize a bit
mohamed82008 Dec 2, 2019
ad72d1e
Merge branch 'master' into mt/compiler3.0
cpfiffer Dec 2, 2019
1f0daca
Remove docstring spacing.
cpfiffer Dec 2, 2019
2212df1
Interface -> AbstractMCMC in dynamichmc
mohamed82008 Dec 2, 2019
b1984cf
BatchContext -> MiniBatchContext & ctx docstrings
mohamed82008 Dec 2, 2019
f7ed50b
Core.tilde vs Inference.tilde in compiler docs
mohamed82008 Dec 2, 2019
61c7bb5
ismissing to === missing
mohamed82008 Dec 7, 2019
8829f52
style issues and cleanup
mohamed82008 Dec 7, 2019
313f5ff
is_number_or_array_type -> isa FloatOrArrayType
mohamed82008 Dec 7, 2019
a9d03b0
get_matching_type changes
mohamed82008 Dec 8, 2019
0086db8
Merge branch 'master' into mt/compiler3.0
mohamed82008 Dec 8, 2019
ce92abb
remove SpecialFunctions compat block
mohamed82008 Dec 8, 2019
490ecb3
update the model internals docs
mohamed82008 Dec 8, 2019
fe46fc3
update the guide.md file
mohamed82008 Dec 8, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/src/using-turing/advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,15 @@ mf(vi, sampler, model) = begin
# Assume s has an InverseGamma distribution.
s, lp = Turing.assume(sampler,
InverseGamma(2, 3),
Turing.VarName([:c_s, :s], ""), vi)
Turing.@varname(s), vi)

# Add the lp to the accumulated logp.
vi.logp += lp

# Assume m has a Normal distribution.
m, lp = Turing.assume(sampler,
Normal(0, sqrt(s)),
Turing.VarName([:c_m, :m], ""), vi)
Turing.@varname(m), vi)

# Add the lp to the accumulated logp.
vi.logp += lp
Expand Down
226 changes: 226 additions & 0 deletions docs/src/using-turing/compiler.md

Large diffs are not rendered by default.

91 changes: 58 additions & 33 deletions src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,39 +32,34 @@ const CACHERANGES = 0b01

const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_TURING", "0")))

"""
struct Model{pvars, dvars, F, TData, TDefaults}
f::F
data::TData
defaults::TDefaults
end
# Random probability measures.
include("stdlib/distributions.jl")
include("stdlib/RandomMeasures.jl")

A `Model` struct with parameter variables `pvars`, data variables `dvars`, inner
function `f`, `data::NamedTuple` and `defaults::NamedTuple`.
"""
struct Model{pvars,
dvars,
F,
TData,
TDefaults
} <: AbstractModel
struct Model{F, Targs <: NamedTuple}
f::F
data::TData
defaults::TDefaults
end
function Model{pvars, dvars}(f::F, data::TD, defaults::TDefaults) where {pvars, dvars, F, TD, TDefaults}
return Model{pvars, dvars, F, TD, TDefaults}(f, data, defaults)
args::Targs
end
get_pvars(m::Model{params}) where {params} = Tuple(params.types)
get_dvars(m::Model{params, data}) where {params, data} = Tuple(data.types)
get_defaults(m::Model) = m.defaults
@generated function in_pvars(::Val{sym}, ::Model{params}) where {sym, params}
return sym in params.types ? :(true) : :(false)
end
@generated function in_dvars(::Val{sym}, ::Model{params, data}) where {sym, params, data}
return sym in data.types ? :(true) : :(false)

A `Model` struct with arguments `args` and inner function `f`.
"""
struct Model{F, Targs <: NamedTuple} <: AbstractModel
f::F
args::Targs
end
(model::Model)(vi) = model(vi, SampleFromPrior())
(model::Model)(vi, spl) = model(vi, spl, DefaultContext())
(model::Model)(args...; kwargs...) = model.f(args..., model; kwargs...)
getmissing(model::Model) = _getmissing(model.args)
@generated function _getmissing(args::NamedTuple{names, ttuple}) where {names, ttuple}
minds = filter(1:length(names)) do i
ttuple.types[i] == Missing
end
mnames = names[minds]
return :(Val{$mnames}())
end

function runmodel! end
function getspace end

Expand Down Expand Up @@ -114,6 +109,36 @@ Sampler(alg) = Sampler(alg, Selector())
Sampler(alg, model::Model) = Sampler(alg, model, Selector())
Sampler(alg, model::Model, s::Selector) = Sampler(alg, model, s)

abstract type AbstractContext end
mohamed82008 marked this conversation as resolved.
Show resolved Hide resolved

"""
struct DefaultContext <: AbstractContext end

The `DefaultContext` is used by default to compute log the joint probability of the data and parameters when running the model.
mohamed82008 marked this conversation as resolved.
Show resolved Hide resolved
"""
struct DefaultContext <: AbstractContext end

"""
struct LikelihoodContext <: AbstractContext end

The `LikelihoodContext` enables the computation of the log likelihood of the data when running the model.
mohamed82008 marked this conversation as resolved.
Show resolved Hide resolved
"""
struct LikelihoodContext <: AbstractContext end

"""
struct MiniBatchContext{Tctx, T} <: AbstractContext
ctx::Tctx
loglike_scalar::T
end

The `MiniBatchContext` enables the computation of `log(prior) + s * log(likelihood of a batch)` when running the model, where `s` is the `loglike_scalar` field, typically equal to `the number of data points / batch size`. This is useful in batch-based stochastic gradient descent algorithms to be optimizing `log(prior) + log(likelihood of all the data points)` in the expectation.
mohamed82008 marked this conversation as resolved.
Show resolved Hide resolved
"""
struct MiniBatchContext{Tctx, T} <: AbstractContext
ctx::Tctx
loglike_scalar::T
end
MiniBatchContext(ctx = DefaultContext(); batch_size, npoints) = MiniBatchContext(ctx, npoints/batch_size)
mohamed82008 marked this conversation as resolved.
Show resolved Hide resolved

include("utilities/Utilities.jl")
using .Utilities
include("core/Core.jl")
Expand All @@ -140,17 +165,16 @@ using .Variational
include("contrib/inference/dynamichmc.jl")
end

# Random probability measures.
include("stdlib/distributions.jl")
include("stdlib/RandomMeasures.jl")

###########
# Exports #
###########

# Turing essentials - modelling macros and inference algorithms
export @model, # modelling
@VarName,
@varname,
@varinfo,
@logpdf,
@sampler,

MH, # classic sampling
Gibbs,
Expand Down Expand Up @@ -190,6 +214,7 @@ export @model, # modelling
BinomialLogit,
VecBinomialLogit,
OrderedLogistic,
LogPoisson
LogPoisson,
NamedDist

end
13 changes: 0 additions & 13 deletions src/contrib/inference/AdvancedSMCExtensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,6 @@ function Sampler(alg::PMMH, model::Model, s::Selector)
space = union(space, sub_alg.space)
end

# Sanity check for space
mohamed82008 marked this conversation as resolved.
Show resolved Hide resolved
if !isempty(space)
@assert issubset(Set(get_pvars(model)), space) "[$alg_str] symbols specified to samplers ($space)" * "
doesn't cover the model parameters ($(Set(get_pvars(model))))"

if Set(get_pvars(model)) != space
warn(
"[$alg_str] extra parameters specified by samplers" *
"don't exist in model: $(setdiff(space, Set(get_pvars(model))))"
)
end
end

info[:old_likelihood_estimate] = -Inf # Force to accept first proposal
info[:old_prior_prob] = 0.0
info[:samplers] = samplers
Expand Down
2 changes: 1 addition & 1 deletion src/contrib/inference/dynamichmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ function Sampler(
end

# Disable the callback for DynamicHMC, since it has it's own progress meter.
function Turing.Interface.init_callback(
function AbstractMCMC.init_callback(
rng::AbstractRNG,
model::Model,
s::Sampler{<:DynamicNUTS},
Expand Down
7 changes: 5 additions & 2 deletions src/core/Core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ include("container.jl")
include("ad.jl")

export @model,
@VarName,
@varname,
generate_observe,
translate_tilde!,
get_vars,
Expand Down Expand Up @@ -51,6 +51,9 @@ export @model,
setchunksize,
verifygrad,
gradient_logp_forward,
gradient_logp_reverse
gradient_logp_reverse,
@varinfo,
@logpdf,
@sampler

end # module
77 changes: 25 additions & 52 deletions src/core/RandomVariables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module RandomVariables

using ...Turing: Turing, CACHERESET, CACHEIDCS, CACHERANGES, Model,
AbstractSampler, Sampler, SampleFromPrior, SampleFromUniform,
Selector, getspace
Selector, getspace, AbstractContext, DefaultContext
using ...Utilities: vectorize, reconstruct, reconstruct!
using Bijectors: SimplexDistribution, link, invlink
using Distributions
Expand Down Expand Up @@ -54,30 +54,25 @@ export VarName,
"""
```
struct VarName{sym}
csym :: Symbol
indexing :: String
counter :: Int
end
```

A variable identifier. Every variable has a symbol `sym`, indices `indexing`, and
internal fields: `csym` and `counter`. The Julia variable in the model corresponding to
`sym` can refer to a single value or to a hierarchical array structure of univariate,
multivariate or matrix variables. `indexing` stores the indices that can access the
random variable from the Julia variable.
A variable identifier. Every variable has a symbol `sym` and `indices `indexing`.
The Julia variable in the model corresponding to `sym` can refer to a single value or
to a hierarchical array structure of univariate, multivariate or matrix variables. `indexing` stores the indices that can access the random variable from the Julia
variable.

Examples:

- `x[1] ~ Normal()` will generate a `VarName` with `sym == :x` and `indexing == "[1]"`.
- `x[:,1] ~ MvNormal(zeros(2))` will generate a `VarName` with `sym == :x` and
`indexing == "[Colon(), 1]"`.
`indexing == "[Colon(),1]"`.
- `x[:,1][2] ~ Normal()` will generate a `VarName` with `sym == :x` and
`indexing == "[Colon(), 1][2]"`.
`indexing == "[Colon(),1][2]"`.
"""
struct VarName{sym}
csym :: Symbol # symbol generated in compilation time
indexing :: String # indexing
counter :: Int # counter of same {csym, uid}
mohamed82008 marked this conversation as resolved.
Show resolved Hide resolved
indexing::String # indexing
mohamed82008 marked this conversation as resolved.
Show resolved Hide resolved
end

abstract type AbstractVarInfo end
Expand Down Expand Up @@ -180,9 +175,10 @@ const TypedVarInfo = VarInfo{<:NamedTuple}

function VarInfo(model::Model)
vi = VarInfo()
model(vi, SampleFromUniform())
model(vi)
return TypedVarInfo(vi)
end
(model::Model)() = model(Turing.VarInfo(), SampleFromPrior())

function VarInfo(old_vi::UntypedVarInfo, spl, x::AbstractVector)
new_vi = deepcopy(old_vi)
Expand Down Expand Up @@ -591,63 +587,39 @@ end
# VarName

"""
`VarName(csym, sym, indexing, counter)`
`VarName{sym}(csym::Symbol, indexing::String)`
`VarName(sym, indexing)`
`VarName{sym}(indexing::String)`

Constructs a new instance of `VarName{sym}`
"""
VarName(csym, sym, indexing, counter) = VarName{sym}(csym, indexing, counter)
function VarName(csym::Symbol, sym::Symbol, indexing::String)
# TODO: update this method when implementing the sanity check
return VarName{sym}(csym, indexing, 1)
end
function VarName{sym}(csym::Symbol, indexing::String) where {sym}
# TODO: update this method when implementing the sanity check
return VarName{sym}(csym, indexing, 1)
end

"""
`VarName(syms::Vector{Symbol}, indexing::String)`

Constructs a new instance of `VarName{syms[2]}`
"""
function VarName(syms::Vector{Symbol}, indexing::String) where {sym}
# TODO: update this method when implementing the sanity check
return VarName{syms[2]}(syms[1], indexing, 1)
end
VarName(sym, indexing) = VarName{sym}(indexing)

"""
`VarName(vn::VarName, indexing::String)`

Returns a copy of `vn` with a new index `indexing`.
"""
function VarName(vn::VarName, indexing::String)
return VarName(vn.csym, vn.sym, indexing, vn.counter)
function VarName(vn::VarName{sym}, indexing::String) where {sym}
return VarName{sym}(indexing)
end

function getproperty(vn::VarName{sym}, f::Symbol) where {sym}
return f === :sym ? sym : getfield(vn, f)
end

# NOTE: VarName should only be constructed by VarInfo internally due to the nature of the counter field.

"""
`uid(vn::VarName)`

Returns a unique tuple identifier for `vn`.
"""
uid(vn::VarName) = (vn.csym, vn.sym, vn.indexing, vn.counter)
uid(vn::VarName) = (vn.sym, vn.indexing)

hash(vn::VarName) = hash(uid(vn))

==(x::VarName, y::VarName) = hash(uid(x)) == hash(uid(y))

function string(vn::VarName; all = true)
if all
return "{$(vn.csym),$(vn.sym)$(vn.indexing)}:$(vn.counter)"
else
return "$(vn.sym)$(vn.indexing)"
end
function string(vn::VarName)
return "$(vn.sym)$(vn.indexing)"
end
function string(vns::Vector{<:VarName})
return replace(string(map(vn -> string(vn), vns)), "String" => "")
Expand All @@ -658,7 +630,7 @@ end

Returns a `Symbol` represenation of the variable identifier `VarName`.
"""
Symbol(vn::VarName) = Symbol(string(vn, all=false)) # simplified symbol
Symbol(vn::VarName) = Symbol(string(vn)) # simplified symbol

"""
`in(vn::VarName, space::Set)`
Expand All @@ -670,7 +642,7 @@ function in(vn::VarName, space::Tuple)::Bool
return true
else
# String representation of `vn`
vn_str = string(vn, all=false)
vn_str = string(vn)
return _in(vn_str, space)
end
end
Expand Down Expand Up @@ -699,21 +671,22 @@ function has_eval_num(spl::T) where T<:AbstractSampler
end

"""
`runmodel!(model::Model, vi::AbstractVarInfo, spl::AbstractSampler)`
`runmodel!(model::Model, vi::AbstractVarInfo, spl::AbstractSampler, ctx::AbstractContext)`

Samples from `model` using the sampler `spl` storing the sample and log joint
probability in `vi`.
"""
function runmodel!(
model::Model,
vi::AbstractVarInfo,
spl::T = SampleFromPrior()
spl::T = SampleFromPrior(),
ctx::AbstractContext = DefaultContext()
) where T<:AbstractSampler
setlogp!(vi, 0)
if has_eval_num(spl)
spl.state.eval_num += 1
end
model(vi, spl)
model(vi, spl, ctx)
return vi
end

Expand Down Expand Up @@ -1106,7 +1079,7 @@ end
length(names) === 0 && return :(NamedTuple())
expr = Expr(:tuple)
map(names) do f
push!(expr.args, Expr(:(=), f, :(getindex.(Ref(vi), metadata.$f.vns), string.(metadata.$f.vns, all=false))))
push!(expr.args, Expr(:(=), f, :(getindex.(Ref(vi), metadata.$f.vns), string.(metadata.$f.vns))))
end
return expr
end
Expand Down
Loading