Skip to content

Commit

Permalink
Taking bijectors autodiff seriously (#77)
Browse files Browse the repository at this point in the history
* squash commit

* rebase fixes

* respond to Tor's comments and some simplifications

* simplifications and test fixes on Julia 1.0

* use isapprox in tests

* typo

* enable coveralls

* AD tests and some test fixes

* many test fixes

* fix type stability issue

* fix Dirichlet logabsdetjac gradient

* remove failed tests

* remove failed TuringMvNormal tests

* bump DAD compat version

* fix PD distributions
  • Loading branch information
mohamed82008 committed Mar 9, 2020
1 parent 612200e commit b0aaa98
Show file tree
Hide file tree
Showing 32 changed files with 2,490 additions and 388 deletions.
29 changes: 29 additions & 0 deletions .github/workflows/ForwardDiff_Tracker.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
name: ForwardDiff and Tracker tests

on:
push:
branches:
- master
pull_request:
types: [opened, synchronize, reopened]

jobs:
test:
runs-on: ${{ matrix.os }}
strategy:
matrix:
julia-version: [1.0, 1.3]
julia-arch: [x64, x86]
os: [ubuntu-latest, macOS-latest]
exclude:
- os: macOS-latest
julia-arch: x86

steps:
- uses: actions/checkout@v1.0.0
- uses: julia-actions/setup-julia@latest
with:
version: ${{ matrix.julia-version }}
- uses: julia-actions/julia-runtest@master
env:
STAGE: ForwardDiff_Tracker
29 changes: 29 additions & 0 deletions .github/workflows/Zygote.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
name: Zygote tests

on:
push:
branches:
- master
pull_request:
types: [opened, synchronize, reopened]

jobs:
test:
runs-on: ${{ matrix.os }}
strategy:
matrix:
julia-version: [1.0, 1.3]
julia-arch: [x64, x86]
os: [ubuntu-latest, macOS-latest]
exclude:
- os: macOS-latest
julia-arch: x86

steps:
- uses: actions/checkout@v1.0.0
- uses: julia-actions/setup-julia@latest
with:
version: ${{ matrix.julia-version }}
- uses: julia-actions/julia-runtest@master
env:
STAGE: Zygote
5 changes: 5 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ matrix:
notifications:
email: false

script:
- if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
- julia --check-bounds=yes -e 'using Pkg;
Pkg.test("Bijectors"; coverage=true)'

after_success:
- if [[ $TRAVIS_JULIA_VERSION = 1.3 ]] && [[ $TRAVIS_OS_NAME = linux ]]; then
julia -e 'using Pkg; cd(Pkg.dir("Bijectors")); Pkg.add("Coverage"); using Coverage; Coveralls.submit(process_folder())'
Expand Down
12 changes: 10 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ version = "0.5.3"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
Expand All @@ -18,21 +19,28 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[compat]
ArgCheck = "1, 2.0"
Combinatorics = "0.7"
Compat = "3.0"
Distributions = "0.21.11, 0.22"
DistributionsAD = "0.4.2"
ForwardDiff = "0.10.3"
MappedArrays = "0.2.2"
NNlib = "0.6"
Reexport = "0.2"
Requires = "1"
Requires = "0.5, 1"
Roots = "0.8.4"
StatsFuns = "0.8, 0.9.3"
Tracker = "0.2.3"
Zygote = "0.4.7"
julia = "1"

[extras]
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["ForwardDiff", "Test", "Tracker"]
test = ["ForwardDiff", "Test", "Tracker", "DistributionsAD", "Zygote", "Combinatorics"]
43 changes: 24 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

[![Build Status](https://travis-ci.org/TuringLang/Bijectors.jl.svg?branch=master)](https://travis-ci.org/TuringLang/Bijectors.jl)
[![Build status](https://ci.appveyor.com/api/projects/status/mvfs8eio2cscwk1m?svg=true)](https://ci.appveyor.com/project/TuringLang/bijectors-jl)
[![Coverage Status](https://coveralls.io/repos/github/TuringLang/Bijectors.jl/badge.svg?branch=master)](https://coveralls.io/github/TuringLang/Bijectors.jl?branch=master)


This package implements a set of functions for transforming constrained random variables (e.g. simplexes, intervals) to Euclidean space. The 3 main functions implemented in this package are the `link`, `invlink` and `logpdf_with_trans` for a number of distributions. The distributions supported are:
Expand Down Expand Up @@ -126,7 +127,7 @@ What about `invlink`?

```julia
julia> b⁻¹ = inv(b)
Inversed{Logit{Float64},0}(Logit{Float64}(0.0, 1.0))
Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0))

julia> b⁻¹(y)
0.3688868996596376
Expand All @@ -135,10 +136,10 @@ julia> b⁻¹(y) == invlink(dist, y)
true
```

Pretty neat, huh? `Inversed{Logit}` is also a `Bijector` where we've defined `(ib::Inversed{<:Logit})(y)` as the inverse transformation of `(b::Logit)(x)`. Note that it's not always the case that `inv(b) isa Inversed`, e.g. the inverse of `Exp` is simply `Log` so `inv(Exp()) isa Log` is true.
Pretty neat, huh? `Inverse{Logit}` is also a `Bijector` where we've defined `(ib::Inverse{<:Logit})(y)` as the inverse transformation of `(b::Logit)(x)`. Note that it's not always the case that `inv(b) isa Inverse`, e.g. the inverse of `Exp` is simply `Log` so `inv(Exp()) isa Log` is true.

#### Dimensionality
One more thing. See the `0` in `Inversed{Logit{Float64}, 0}`? It represents the *dimensionality* of the bijector, in the same sense as for an `AbstractArray` with the exception of `0` which means it expects 0-dim input and output, i.e. `<:Real`. This can also be accessed through `dimension(b)`:
One more thing. See the `0` in `Inverse{Logit{Float64}, 0}`? It represents the *dimensionality* of the bijector, in the same sense as for an `AbstractArray` with the exception of `0` which means it expects 0-dim input and output, i.e. `<:Real`. This can also be accessed through `dimension(b)`:

```julia
julia> Bijectors.dimension(b)
Expand All @@ -155,7 +156,7 @@ Also, we can _compose_ bijectors:

```julia
julia> id_y = (b b⁻¹)
Composed{Tuple{Inversed{Logit{Float64},0},Logit{Float64}},0}((Inversed{Logit{Float64},0}(Logit{Float64}(0.0, 1.0)), Logit{Float64}(0.0, 1.0)))
Composed{Tuple{Inverse{Logit{Float64},0},Logit{Float64}},0}((Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0)), Logit{Float64}(0.0, 1.0)))

julia> id_y(y) y
true
Expand All @@ -165,7 +166,7 @@ And since `Composed isa Bijector`:

```julia
julia> id_x = inv(id_y)
Composed{Tuple{Inversed{Logit{Float64},0},Logit{Float64}},0}((Inversed{Logit{Float64},0}(Logit{Float64}(0.0, 1.0)), Logit{Float64}(0.0, 1.0)))
Composed{Tuple{Inverse{Logit{Float64},0},Logit{Float64}},0}((Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0)), Logit{Float64}(0.0, 1.0)))

julia> id_x(x) x
true
Expand Down Expand Up @@ -264,12 +265,12 @@ julia> b = bijector(dist) # (0, 1) → ℝ
Logit{Float64}(0.0, 1.0)

julia> b⁻¹ = inv(b) # ℝ → (0, 1)
Inversed{Logit{Float64},0}(Logit{Float64}(0.0, 1.0))
Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0))

julia> td = transformed(Normal(), b⁻¹) # x ∼ 𝓝(0, 1) then b(x) ∈ (0, 1)
TransformedDistribution{Normal{Float64},Inversed{Logit{Float64},0},Univariate}(
TransformedDistribution{Normal{Float64},Inverse{Logit{Float64},0},Univariate}(
dist: Normal{Float64}=0.0, σ=1.0)
transform: Inversed{Logit{Float64},0}(Logit{Float64}(0.0, 1.0))
transform: Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0))
)


Expand Down Expand Up @@ -338,10 +339,10 @@ julia> # Construct the transform
(Logit{Float64}(0.0, 1.0), Log{0}(), SimplexBijector{true}())

julia> ibs = inv.(bs) # invert, so we get unconstrained-to-constrained
(Inversed{Logit{Float64},0}(Logit{Float64}(0.0, 1.0)), Exp{0}(), Inversed{SimplexBijector{true},1}(SimplexBijector{true}()))
(Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0)), Exp{0}(), Inverse{SimplexBijector{true},1}(SimplexBijector{true}()))

julia> sb = Stacked(ibs, ranges) # => Stacked <: Bijector
Stacked{Tuple{Inversed{Logit{Float64},0},Exp{0},Inversed{SimplexBijector{true},1}},3}((Inversed{Logit{Float64},0}(Logit{Float64}(0.0, 1.0)), Exp{0}(), Inversed{SimplexBijector{true},1}(SimplexBijector{true}())), (1:1, 2:2, 3:4))
Stacked{Tuple{Inverse{Logit{Float64},0},Exp{0},Inverse{SimplexBijector{true},1}},3}((Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0)), Exp{0}(), Inverse{SimplexBijector{true},1}(SimplexBijector{true}())), (1:1, 2:2, 3:4))

julia> # Mean-field normal with unconstrained-to-constrained stacked bijector
td = transformed(d, sb);
Expand Down Expand Up @@ -416,10 +417,10 @@ julia> d = MvNormal(zeros(2), ones(2));
julia> ibs = inv.(bijector.((InverseGamma(2, 3), Beta())));

julia> sb = stack(ibs...) # == Stacked(ibs) == Stacked(ibs, [i:i for i = 1:length(ibs)]
Stacked{Tuple{Exp{0},Inversed{Logit{Float64},0}},2}((Exp{0}(), Inversed{Logit{Float64},0}(Logit{Float64}(0.0, 1.0))), (1:1, 2:2))
Stacked{Tuple{Exp{0},Inverse{Logit{Float64},0}},2}((Exp{0}(), Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0))), (1:1, 2:2))

julia> b = sb PlanarLayer(2)
Composed{Tuple{PlanarLayer{Array{Float64,2},Array{Float64,1}},Stacked{Tuple{Exp{0},Inversed{Logit{Float64},0}},2}},1}((PlanarLayer{Array{Float64,2},Array{Float64,1}}([1.49138; 0.367563], [-0.886205; 0.684565], [-1.59058]), Stacked{Tuple{Exp{0},Inversed{Logit{Float64},0}},2}((Exp{0}(), Inversed{Logit{Float64},0}(Logit{Float64}(0.0, 1.0))), (1:1, 2:2))))
Composed{Tuple{PlanarLayer{Array{Float64,2},Array{Float64,1}},Stacked{Tuple{Exp{0},Inverse{Logit{Float64},0}},2}},1}((PlanarLayer{Array{Float64,2},Array{Float64,1}}([1.49138; 0.367563], [-0.886205; 0.684565], [-1.59058]), Stacked{Tuple{Exp{0},Inverse{Logit{Float64},0}},2}((Exp{0}(), Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0))), (1:1, 2:2))))

julia> td = transformed(d, b);

Expand Down Expand Up @@ -514,7 +515,7 @@ import Bijectors: logabsdetjac

struct Identity{N} <: Bijector{N} end
(::Identity)(x) = x # transform itself, "forward"
(::Inversed{<: Identity})(y) = y # inverse tramsform, "backward"
(::Inverse{<: Identity})(y) = y # inverse tramsform, "backward"

# see the proper implementation for `logabsdetjac` in general
logabsdetjac(::Identity{0}, y::Real) = zero(eltype(y)) # ∂ₓid(x) = ∂ₓ x = 1 → log(abs(1)) = log(1) = 0
Expand All @@ -530,10 +531,14 @@ struct Logit{T<:Real} <: Bijector{0}
b::T
end

(b::Logit)(x) = @. logit((x - b.a) / (b.b - b.a))
(ib::Inversed{<:Logit})(y) = @. (ib.orig.b - ib.orig.a) * logistic(y) + ib.orig.a # `orig` contains the `Bijector` which was inverted
(b::Logit)(x::Real) = logit((x - b.a) / (b.b - b.a))
(b::Logit)(x) = mapvcat(b, x)
# `orig` contains the `Bijector` which was inverted
(ib::Inverse{<:Logit})(y::Real) = (ib.orig.b - ib.orig.a) * logistic(y) + ib.orig.a
(ib::Inverse{<:Logit})(y) = mapvcat(ib, y)

logabsdetjac(b::Logit, x) = @. - log((x - b.a) * (b.b - x) / (b.b - b.a))
logabsdetjac(b::Logit, x::Real) = - log((x - b.a) * (b.b - x) / (b.b - b.a))
logabsdetjac(b::Logit, x) = mapvcat(logabsdetjac, x)
```

(Batch computation is not fully supported by all bijectors yet (see issue #35), but is actively worked on. In the particular case of `Logit` there's only one thing that makes sense, which is elementwise application. Therefore we've added `@.` to the implementation above, thus this works for any `AbstractArray{<:Real}`.)
Expand Down Expand Up @@ -603,7 +608,7 @@ end
ADLogit(a::T, b::T) where {T<:Real} = ADLogit{T, ADBackend()}(a, b)

(b::ADLogit)(x) = @. logit((x - b.a) / (b.b - b.a))
(ib::Inversed{<:ADLogit{<:Real}})(y) = @. (ib.orig.b - ib.orig.a) * logistic(y) + ib.orig.a
(ib::Inverse{<:ADLogit{<:Real}})(y) = @. (ib.orig.b - ib.orig.a) * logistic(y) + ib.orig.a
```
No implementation of `logabsdetjac`, but:
Expand Down Expand Up @@ -697,7 +702,7 @@ If anything is lacking or not clear in docstrings, feel free to open an issue or
The following are the bijectors available:
- Abstract:
- `Bijector`: super-type of all bijectors.
- `ADBijector{AD} <: Bijector`: subtypes of this only require the user to implement `(b::UserBijector)(x)` and `(ib::Inversed{<:UserBijector})(y)`. Automatic differentation will be used to compute the `jacobian(b, x)` and thus `logabsdetjac(b, x).
- `ADBijector{AD} <: Bijector`: subtypes of this only require the user to implement `(b::UserBijector)(x)` and `(ib::Inverse{<:UserBijector})(y)`. Automatic differentation will be used to compute the `jacobian(b, x)` and thus `logabsdetjac(b, x).
- Concrete:
- `Composed`: represents a composition of bijectors.
- `Stacked`: stacks univariate and multivariate bijectors
Expand All @@ -718,7 +723,7 @@ The distribution interface consists of:
#### Methods
The following methods are implemented by all subtypes of `Bijector`, this also includes bijectors such as `Composed`.
- `(b::Bijector)(x)`: implements the transform of the `Bijector`
- `inv(b::Bijector)`: returns the inverse of `b`, i.e. `ib::Bijector` s.t. `(ib ∘ b)(x) ≈ x`. In most cases this is `Inversed{<:Bijector}`.
- `inv(b::Bijector)`: returns the inverse of `b`, i.e. `ib::Bijector` s.t. `(ib ∘ b)(x) ≈ x`. In most cases this is `Inverse{<:Bijector}`.
- `logabsdetjac(b::Bijector, x)`: computes log(abs(det(jacobian(b, x)))).
- `forward(b::Bijector, x)`: returns named tuple `(rv=b(x), logabsdetjac=logabsdetjac(b, x))` in the most efficient manner.
- `∘`, `composel`, `composer`: convenient and type-safe constructors for `Composed`. `composel(bs...)` composes s.t. the resulting composition is evaluated left-to-right, while `composer(bs...)` is evaluated right-to-left. `∘` is right-to-left, as excepted from standard mathematical notation.
Expand Down
3 changes: 3 additions & 0 deletions appveyor.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
environment:
matrix:
- julia_version: 1
- julia_version: 1.1
- julia_version: 1.2
- julia_version: 1.3
- julia_version: nightly

platform:
Expand Down

0 comments on commit b0aaa98

Please sign in to comment.