Skip to content

Commit

Permalink
Document distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
marcoct committed Aug 31, 2018
1 parent 79744ee commit 84428da
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 10 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Expand Up @@ -5,3 +5,5 @@
*.swp
*.aux
*.log
docs/build/
docs/site/
10 changes: 10 additions & 0 deletions docs/make.jl
@@ -0,0 +1,10 @@
using Documenter, Gen

makedocs(
format = :html,
sitename = "Gen",
pages = [
"index.md"
]
)

132 changes: 122 additions & 10 deletions src/distribution.jl
@@ -1,41 +1,60 @@
import Distributions
using SpecialFunctions: lgamma, lbeta


#########################
# abstract distribution #
#########################

abstract type Distribution{T} end

function random end
function logpdf end
get_return_type(::Distribution{T}) where {T} = T

export Distribution
export random
export logpdf


#########
# dirac #
#########

struct Dirac <: Distribution{Float64} end

"""
dirac(val::Real)
Deterministically return a `Float64` that is equal to the given value.
`logpdf(dirac, x, y)` returns `0` if `x == y` and `-Inf` otherwise.
"""
const dirac = Dirac()

function logpdf(::Dirac, x::Real, y::Real)
x == y ? 0. : -Inf
end
logpdf(::Dirac, x::Real, y::Real) = (x == y ? 0. : -Inf)

function random(::Dirac, y::Real)
y
end
random(::Dirac, y::Real) = Float64(y)

(::Dirac)(y) = random(Dirac(), y)

has_output_grad(::Dirac) = false
has_argument_grads(::Dirac) = (false,)
get_static_argument_types(::Dirac) = [Float64]

export dirac


#############
# bernoulli #
#############

struct Bernoulli <: Distribution{Bool} end

"""
bernoulli(prob_true::Real)
Samples a `Bool` value which is true with given probability
"""
const bernoulli = Bernoulli()

function logpdf(::Bernoulli, x::Bool, prob::Real)
Expand Down Expand Up @@ -63,6 +82,11 @@ export bernoulli

struct Normal <: Distribution{Float64} end

"""
normal(mu::Real, std::Real)
Samples a `Float64` value from a normal distribution.
"""
const normal = Normal()

function logpdf(::Normal, x::Real, mu::Real, std::Real)
Expand All @@ -80,9 +104,7 @@ function logpdf_grad(::Normal, x::Real, mu::Real, std::Real)
(deriv_x, deriv_mu, deriv_sigma)
end

function random(::Normal, mu::Real, std::Real)
mu + std * randn()
end
random(::Normal, mu::Real, std::Real) = mu + std * randn()

(::Normal)(mu, std) = random(Normal(), mu, std)

Expand All @@ -99,6 +121,11 @@ export normal

struct Gamma <: Distribution{Float64} end

"""
gamma(shape::Real, scale::Real)
Sample a `Float64` from a gamma distribution.
"""
const gamma = Gamma()

function logpdf(::Gamma, x::Real, shape::Real, scale::Real)
Expand All @@ -109,12 +136,19 @@ function logpdf(::Gamma, x::Real, shape::Real, scale::Real)
end
end

function logpdf_grad(::Gamma, x::Real, shape::Real, scale::Real)
error("Not Implemented")
(nothing, nothing, nothing)
end

function random(::Gamma, shape::Real, scale::Real)
rand(Distributions.Gamma(shape, scale))
end

(::Gamma)(shape, scale) = random(Gamma(), shape, scale)

has_output_grad(::Gamma) = false
has_argument_grads(::Gamma) = (false, false)
get_static_argument_types(::Gamma) = [Float64, Float64]

export gamma
Expand All @@ -126,6 +160,11 @@ export gamma

struct InverseGamma <: Distribution{Float64} end

"""
inv_gamma(shape::Real, scale::Real)
Sample a `Float64` from a inverse gamma distribution.
"""
const inv_gamma = InverseGamma()

function logpdf(::InverseGamma, x::Real, shape::Real, scale::Real)
Expand All @@ -136,56 +175,91 @@ function logpdf(::InverseGamma, x::Real, shape::Real, scale::Real)
end
end

function logpdf_grad(::InverseGamma, x::Real, shape::Real, scale::Real)
error("Not Implemented")
(nothing, nothing, nothing)
end


function random(::InverseGamma, shape::Real, scale::Real)
rand(Distributions.InverseGamma(shape, scale))
end

(::InverseGamma)(shape, scale) = random(InverseGamma(), shape, scale)

has_output_grad(::InverseGamma) = false
has_argument_grads(::InverseGamma) = (false, false)
get_static_argument_types(::InverseGamma) = [Float64, Float64]

export inv_gamma


########
# beta #
########

struct Beta <: Distribution{Float64} end

"""
beta(alpha::Real, beta::Real)
Sample a `Float64` from a beta distribution.
"""
const beta = Beta()

function logpdf(::Beta, x::Real, alpha::Real, beta::Real)
(alpha - 1) * log(x) + (beta - 1) * log1p(-x) - lbeta(alpha, beta)
end

function logpdf_grad(::Beta, x::Real, alpha::Real, beta::Real)
error("Not Implemented")
(nothing, nothing, nothing)
end

function random(::Beta, alpha::Real, beta::Real)
rand(Distributions.Beta(alpha, beta))
end

(::Beta)(alpha, beta) = random(Beta(), alpha, beta)

has_output_grad(::Beta) = false
has_argument_grads(::Beta) = (false, false)
get_static_argument_types(::Beta) = [Float64, Float64]

export beta


###############
# categorical #
###############

struct Categorical <: Distribution{Int} end

"""
categorical(probs::AbstractArray{U, 1}) where {U <: Real}
Given a vector of probabilities `probs` where `sum(probs) = 1`, sample an `Int` `i` from the set {1, 2, .., `length(probs)`} with probability `probs[i]`.
"""
const categorical = Categorical()

function logpdf(::Categorical, x::Int, probs::AbstractArray{U,1}) where {U <: Real}
log(probs[x])
end

function logpdf_grad(::Beta, x::Int, probs::AbstractArray{U,1}) where {U <: Real}
grad = zeros(length(probs))
grad[x] = 1.0
(nothing, grad)
end

function random(::Categorical, probs::AbstractArray{U,1}) where {U <: Real}
rand(Distributions.Categorical(probs))
end

(::Categorical)(probs) = random(Categorical(), probs)

has_output_grad(::Categorical) = false
has_argument_grads(::Categorical) = (true,)
get_static_argument_types(::Categorical) = [Vector{Float64}]

export categorical
Expand All @@ -197,20 +271,31 @@ export categorical

struct UniformDiscrete <: Distribution{Int} end

"""
uniform_discrete(low::Integer, high::Integer)
Sample an `Int` from the uniform distribution on the set {low, low + 1, ..., high-1, high}.
"""
const uniform_discrete = UniformDiscrete()

function logpdf(::UniformDiscrete, x::Int, low::Integer, high::Integer)
d = Distributions.DiscreteUniform(low, high)
Distributions.logpdf(d, x)
end

function logpdf_grad(::UniformDiscrete, x::Int, lower::Integer, high::Integer)
(nothing, nothing, nothing)
end

function random(::UniformDiscrete, low::Integer, high::Integer)
rand(Distributions.DiscreteUniform(low, high))
end

(::UniformDiscrete)(low, high) = random(UniformDiscrete(), low, high)

get_static_argument_types(::UniformDiscrete) = [Float64, Float64]
has_output_grad(::UniformDiscrete) = false
has_argument_grads(::UniformDiscrete) = (false, false)
get_static_argument_types(::UniformDiscrete) = [Int, Int]

export uniform_discrete

Expand All @@ -222,40 +307,67 @@ export uniform_discrete
struct UniformContinuous <: Distribution{Float64} end

const uniform_continuous = UniformContinuous()

"""
uniform(low::Real, high::Real)
Sample a `Float64` from the uniform distribution on the interval [low, high].
"""
const uniform = uniform_continuous

function logpdf(::UniformContinuous, x::Real, low::Real, high::Real)
(x >= low && x <= high) ? -log(high-low) : -Inf
end

function logpdf_grad(::UniformContinuous, x::Real, low::Real, high::Real)
inv_diff = 1. / (high-low)
(0., inv_diff, -inv_diff)
end

function random(::UniformContinuous, low::Real, high::Real)
rand() * (high - low) + low
end

(::UniformContinuous)(low, high) = random(UniformContinuous(), low, high)

has_output_grad(::UniformContinuous) = true
has_argument_grads(::UniformContinuous) = (true, true)
get_static_argument_types(::UniformContinuous) = [Float64, Float64]

export uniform_continuous, uniform


###########
# poisson #
###########

struct Poisson <: Distribution{Int} end

"""
poisson(lambda::Real)
Sample an `Int` from the Poisson distribution with rate `lambda`.
"""
const poisson = Poisson()

function logpdf(::Poisson, x::Integer, lambda::Real)
x * log(lambda) - lambda - lgamma(x+1)
end

function logpdf_grad(::Poisson, x::Integer, lambda::Real)
error("Not implemented")
(nothing, nothing)
end


function random(::Poisson, lambda::Real)
rand(Distributions.Poisson(lambda))
end

(::Poisson)(lambda) = random(Poisson(), lambda)

has_output_grad(::Poisson) = false
has_argument_grads(::Poisson) = (false,)
get_static_argument_types(::Poisson) = [Float64]

export poisson

0 comments on commit 84428da

Please sign in to comment.