Skip to content

Commit

Permalink
New interface (#27)
Browse files Browse the repository at this point in the history
* initial implementation of new interface

* fixed the testing

* now adheres to the style guide

* cant have type templates on next line...

* made more similar to the interface suggested by @xukai92

* fixed typo in error message

* added composition of Bijectors

* improved composing compositions and added overloading of circ

* composition with its own inverse now results in Identity

* made a constant IdentityBijector

* added Logit bijector and specialization for Beta distribution

* fixed a typo in transform of inv(Logit)

* fixed typo

* added jacobian function for bijectors, with AD implementations

* added jacobian using AD for inverses of ADBijector

* added simple AD-types and removed dep on Turing

* Add norm flows

* minor changes

* fix bugs

* fix more bugs

* fix bugs and follow style guide

* add tests on logabsdetjacob

* fix spaces

* defaults to using Tuple but allow any containers

* fixing style-issues

* add iterative norm for planar flows

* fix minor bug

* adhere to stylecode, add radial inverse, remove tracker dependency, restructure code

* fix radius bug

* fix param dependency

* fix inv() for radial flow; follow style guidelines

* minor change to test

* minor fix

* add ref to paper for each equation

* fix forward and remove redundant

* remove update_u_hat!() requirement

* update tests and ifx bug

* update tests

* implement Bijector call. We can now transform using BijectorName(x)

* Add inv and rand functions for composed bijectors

* Add direct calls for norm flows

* added docstrings and moved away from transform to callables

* Fix the two remaining issues in torfjelde#1 (#1)

* add dep and update deps

* update test and fix the randomness

* replace transpose(x) with x'

* use tab to align comments

* replace transpose(x) with x'

* unify transform code for RadialLayer

* unify transform code for PlanarLayer

* improve code style and add comments

* adapted flows to new interface

* removed random Revise import

* introduced TransformedDistribution and MatrixTransformed as subtype

* fixed typo

* added proper implementation of bijector(d) for UnitDistribution

* added Scale and Shift together with bijector for Truncated

* removed redundant line in logabsdetjac

* fixed a typo and added some more tests

* removed some unecessary commented code

* addded some comments

* added SimplexBijector and forward(flow, ::Vector) now returns vector

* removed left-over commented out code

* updated Manifest

* fixed issue with _transform for RadialLayer when using Tracked

* forgot a sqrt in previous commit

* removed now redundant hack for Dirichlet

* fixed stackoverflow on forward(::RadialLayer, x::AbstractArray) due to typo

* added recursive implementation for forward(cb::Composed, x)

* support for batch computation using forward(b::Bijector, x, logjac)

* edited some comments

* fixed a typo in RadialLayer forward

* replaced forward(b, x, logjac) with fused logjac instead

* initializing recursive forward call using first result

* fixed logabsdetjac of Shift for batch, though it is ambiguous imo

* added test for transform and inverse for univariate transformed

* MvLogMvNormal no longer uses DistributionBijector

* captialized comments

* added my name to a comment

* fixed a typo leading to Kolmogorov being treated as unit-contrained

* added SingularJacobianException to logabsdetjac of ADBijector

* made changes to adhere to style-guide

* mode style-changes

* added a more useful forward(td::TransformedDistribution)

* replaced constant variables log_b and exp_b with Log and Exp

* compose replaced by composel and composer

* fixed sign typo and added AD verification

* added special logpdf_forward for dirichlet in forward

* added docstring to forward(d::Distribution)

* change variable name in forward(d::Distribution)

* increment version number to 0.4.0
  • Loading branch information
torfjelde authored and yebai committed Aug 29, 2019
1 parent 2a7e423 commit 6ea2c37
Show file tree
Hide file tree
Showing 8 changed files with 1,394 additions and 30 deletions.
130 changes: 112 additions & 18 deletions Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
[[Adapt]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "82dab828020b872fa9efd3abec1152b075bc7cbf"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "1.0.0"

[[Arpack]]
deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "Random", "SparseArrays", "Test"]
git-tree-sha1 = "1ce1ce9984683f0b6a587d5bdbc688ecb480096f"
deps = ["BinaryProvider", "Libdl", "LinearAlgebra"]
git-tree-sha1 = "07a2c077bdd4b6d23a40342a8a108e2ee5e58ab6"
uuid = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97"
version = "0.3.0"
version = "0.3.1"

[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
Expand All @@ -14,22 +20,45 @@ uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
version = "0.8.10"

[[BinaryProvider]]
deps = ["Libdl", "Pkg", "SHA", "Test"]
git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
deps = ["Libdl", "Logging", "SHA"]
git-tree-sha1 = "c7361ce8a2129f20b0e05a89f7070820cfed6648"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.3"
version = "0.5.6"

[[CSTParser]]
deps = ["Tokenize"]
git-tree-sha1 = "c69698c3d4a7255bc1b4bc2afc09f59db910243b"
uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
version = "0.6.2"

[[CommonSubexpressions]]
deps = ["Test"]
git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0"
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
version = "0.2.0"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "2.1.0"

[[Crayons]]
deps = ["Test"]
git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523"
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
version = "4.0.0"

[[DataAPI]]
git-tree-sha1 = "8903f0219d3472543fc4b2f5ebaf675a07f817c0"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.0.1"

[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038"
deps = ["InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "0809951a1774dc724da22d26e4289bbaab77809a"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.15.0"
version = "0.17.0"

[[Dates]]
deps = ["Printf"]
Expand All @@ -39,6 +68,18 @@ uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"

[[DiffResults]]
deps = ["Compat", "StaticArrays"]
git-tree-sha1 = "34a4a1e8be7bc99bc9c611b895b5baf37a80584c"
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "0.0.4"

[[DiffRules]]
deps = ["Random", "Test"]
git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "0.0.10"

[[Distributed]]
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand All @@ -49,6 +90,12 @@ git-tree-sha1 = "022e6610c320b6e19b454502d759c672580abe00"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
version = "0.18.0"

[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.3"

[[InteractiveUtils]]
deps = ["LinearAlgebra", "Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Expand All @@ -66,6 +113,12 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[MacroTools]]
deps = ["CSTParser", "Compat", "DataStructures", "Test", "Tokenize"]
git-tree-sha1 = "d6e9dedb8c92c3465575442da456aec15a89ff76"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.5.1"

[[MappedArrays]]
deps = ["Test"]
git-tree-sha1 = "923441c5ac942b60bd3a842d5377d96646bcbf46"
Expand All @@ -77,14 +130,26 @@ deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[Missings]]
deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"]
git-tree-sha1 = "d1d2585677f2bd93a97cfeb8faa7a0de0f982042"
deps = ["SparseArrays", "Test"]
git-tree-sha1 = "f0719736664b4358aa9ec173077d4285775f8007"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "0.4.0"
version = "0.4.1"

[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[NNlib]]
deps = ["Libdl", "LinearAlgebra", "Requires", "Statistics", "TimerOutputs"]
git-tree-sha1 = "0c667371391fc6bb31f7f12f96a56a17098b3de8"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.6.0"

[[NaNMath]]
deps = ["Compat"]
git-tree-sha1 = "ce3b85e484a5d4c71dd5316215069311135fa9f2"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "0.3.2"

[[OrderedCollections]]
deps = ["Random", "Serialization", "Test"]
git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
Expand All @@ -93,9 +158,9 @@ version = "1.1.0"

[[PDMats]]
deps = ["Arpack", "LinearAlgebra", "SparseArrays", "SuiteSparse", "Test"]
git-tree-sha1 = "b6c91fc0ab970c0563cbbe69af18d741a49ce551"
git-tree-sha1 = "9d6a9b3e19634612fb1edcafc4b1d75242b24bde"
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
version = "0.9.6"
version = "0.9.9"

[[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
Expand Down Expand Up @@ -137,6 +202,12 @@ git-tree-sha1 = "9a6c758cdf73036c3239b0afbea790def1dabff9"
uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
version = "0.5.0"

[[Roots]]
deps = ["Printf"]
git-tree-sha1 = "9cc4b586c71f9aea25312b94be8c195f119b0ec3"
uuid = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
version = "0.8.3"

[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

Expand Down Expand Up @@ -166,15 +237,21 @@ git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "0.7.2"

[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "db23bbf50064c582b6f2b9b043c8e7e98ea8c0c6"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "0.11.0"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[[StatsBase]]
deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"]
git-tree-sha1 = "435707791dc85a67d98d671c1c3fcf1b20b00f94"
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
git-tree-sha1 = "c53e809e63fe5cf5de13632090bc3520649c9950"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.29.0"
version = "0.32.0"

[[StatsFuns]]
deps = ["Rmath", "SpecialFunctions", "Test"]
Expand All @@ -183,13 +260,30 @@ uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
version = "0.8.0"

[[SuiteSparse]]
deps = ["Libdl", "LinearAlgebra", "SparseArrays"]
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"

[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[TimerOutputs]]
deps = ["Crayons", "Printf", "Test", "Unicode"]
git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.0"

[[Tokenize]]
git-tree-sha1 = "dfcdbbfb2d0370716c815cbd6f8a364efb6f42cf"
uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624"
version = "0.5.6"

[[Tracker]]
deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"]
git-tree-sha1 = "327342fec6e09f68ced0c2dc5731ed475e4b696b"
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
version = "0.2.2"

[[URIParser]]
deps = ["Test", "Unicode"]
git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69"
Expand Down
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.3.2"
version = "0.4.0"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[extras]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand Down
46 changes: 35 additions & 11 deletions src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using Reexport, Requires
using StatsFuns
using LinearAlgebra
using MappedArrays
using Roots

export TransformDistribution,
PositiveDistribution,
Expand All @@ -13,7 +14,26 @@ export TransformDistribution,
PDMatDistribution,
link,
invlink,
logpdf_with_trans
logpdf_with_trans,
transform,
forward,
logabsdetjac,
logabsdetjacinv,
Bijector,
ADBijector,
Inversed,
Composed,
compose,
Identity,
DistributionBijector,
bijector,
transformed,
UnivariateTransformed,
MultivariateTransformed,
logpdf_with_jac,
logpdf_forward,
PlanarLayer,
RadialLayer

const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_BIJECTORS", "0")))

Expand Down Expand Up @@ -162,8 +182,8 @@ function _clamp(x::T, dist::SimplexDistribution) where T
end

function link(
d::SimplexDistribution,
x::AbstractVector{T},
d::SimplexDistribution,
x::AbstractVector{T},
::Type{Val{proj}} = Val{true}
) where {T<:Real, proj}
y, K = similar(x), length(x)
Expand Down Expand Up @@ -191,8 +211,8 @@ end

# Vectorised implementation of the above.
function link(
d::SimplexDistribution,
X::AbstractMatrix{T},
d::SimplexDistribution,
X::AbstractMatrix{T},
::Type{Val{proj}} = Val{true}
) where {T<:Real, proj}
Y, K, N = similar(X), size(X, 1), size(X, 2)
Expand All @@ -219,8 +239,8 @@ function link(
end

function invlink(
d::SimplexDistribution,
y::AbstractVector{T},
d::SimplexDistribution,
y::AbstractVector{T},
::Type{Val{proj}} = Val{true}
) where {T<:Real, proj}
x, K = similar(y), length(y)
Expand All @@ -245,8 +265,8 @@ end

# Vectorised implementation of the above.
function invlink(
d::SimplexDistribution,
Y::AbstractMatrix{T},
d::SimplexDistribution,
Y::AbstractMatrix{T},
::Type{Val{proj}} = Val{true}
) where {T<:Real, proj}
X, K, N = similar(Y), size(Y, 1), size(Y, 2)
Expand Down Expand Up @@ -335,8 +355,8 @@ function invlink(d::PDMatDistribution, Y::AbstractMatrix{<:Real})
end

function logpdf_with_trans(
d::PDMatDistribution,
X::AbstractMatrix{<:Real},
d::PDMatDistribution,
X::AbstractMatrix{<:Real},
transform::Bool
)
T = eltype(X)
Expand Down Expand Up @@ -424,4 +444,8 @@ function logpdf_with_trans(
return map(x -> logpdf_with_trans(d, x, transform), X)
end

include("interface.jl")

include("norm_flows.jl")

end # module

0 comments on commit 6ea2c37

Please sign in to comment.