Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New interface #27

Merged
merged 86 commits into from
Aug 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
86 commits
Select commit Hold shift + click to select a range
839c882
initial implementation of new interface
torfjelde Jul 8, 2019
65fce15
fixed the testing
torfjelde Jul 8, 2019
0be22a6
now adheres to the style guide
torfjelde Jul 8, 2019
16d83aa
cant have type templates on next line...
torfjelde Jul 8, 2019
4c90e1e
made more similar to the interface suggested by @xukai92
torfjelde Jul 10, 2019
c182b1f
fixed typo in error message
torfjelde Jul 10, 2019
52b5deb
added composition of Bijectors
torfjelde Jul 10, 2019
88c681d
improved composing compositions and added overloading of circ
torfjelde Jul 11, 2019
b413972
composition with its own inverse now results in Identity
torfjelde Jul 11, 2019
aa65723
made a constant IdentityBijector
torfjelde Jul 11, 2019
08f5e1a
added Logit bijector and specialization for Beta distribution
torfjelde Jul 11, 2019
2e71597
fixed a typo in transform of inv(Logit)
torfjelde Jul 11, 2019
625917e
fixed typo
torfjelde Jul 12, 2019
d043ef4
added jacobian function for bijectors, with AD implementations
torfjelde Jul 13, 2019
1940748
added jacobian using AD for inverses of ADBijector
torfjelde Jul 13, 2019
15f1af2
added simple AD-types and removed dep on Turing
torfjelde Jul 17, 2019
9758bcd
Add norm flows
sharanry Jul 23, 2019
b8c22db
minor changes
sharanry Jul 23, 2019
22b1936
fix bugs
sharanry Jul 23, 2019
5719e5d
fix more bugs
sharanry Jul 23, 2019
8104a28
fix bugs and follow style guide
sharanry Jul 25, 2019
494075d
add tests on logabsdetjacob
sharanry Jul 25, 2019
59aeceb
fix spaces
sharanry Jul 25, 2019
0d7a70e
defaults to using Tuple but allow any containers
torfjelde Jul 26, 2019
a3bf614
fixing style-issues
torfjelde Jul 27, 2019
b0a01e9
add iterative norm for planar flows
sharanry Jul 31, 2019
96c2c7f
fix minor bug
sharanry Jul 31, 2019
2109d6d
adhere to stylecode, add radial inverse, remove tracker dependency, r…
sharanry Aug 1, 2019
b727f5e
fix radius bug
sharanry Aug 7, 2019
6c1a5a0
fix param dependency
sharanry Aug 7, 2019
0c026d4
fix inv() for radial flow; follow style guidelines
sharanry Aug 13, 2019
d3ff9a6
minor change to test
sharanry Aug 13, 2019
8ca489a
minor fix
sharanry Aug 14, 2019
fb2399f
add ref to paper for each equation
sharanry Aug 14, 2019
45a6713
fix forward and remove redundant
sharanry Aug 14, 2019
3f86c99
remove update_u_hat!() requirement
sharanry Aug 15, 2019
d882406
update tests and ifx bug
sharanry Aug 17, 2019
bac200f
update tests
sharanry Aug 17, 2019
f276e57
implement Bijector call. We can now transform using BijectorName(x)
sharanry Aug 17, 2019
23a0501
Add inv and rand functions for composed bijectors
sharanry Aug 17, 2019
dc73717
Add direct calls for norm flows
sharanry Aug 18, 2019
d754b74
added docstrings and moved away from transform to callables
torfjelde Aug 19, 2019
4a7b260
Fix the two remaining issues in https://github.com/torfjelde/Bijector…
xukai92 Aug 19, 2019
b22ef2a
Merge branch 'tor/interface' into norm_flow
torfjelde Aug 22, 2019
d903d18
Merge pull request #1 from sharanry/norm_flow
torfjelde Aug 22, 2019
1a712b2
adapted flows to new interface
torfjelde Aug 22, 2019
60e783a
Merge branch 'master' into tor/interface
torfjelde Aug 22, 2019
1c4b77d
removed random Revise import
torfjelde Aug 22, 2019
f321c82
introduced TransformedDistribution and MatrixTransformed as subtype
torfjelde Aug 22, 2019
71fcca5
fixed typo
torfjelde Aug 22, 2019
a4ffc54
added proper implementation of bijector(d) for UnitDistribution
torfjelde Aug 22, 2019
a8c2d8f
added Scale and Shift together with bijector for Truncated
torfjelde Aug 22, 2019
5f8b4a0
removed redundant line in logabsdetjac
torfjelde Aug 22, 2019
3054ad2
fixed a typo and added some more tests
torfjelde Aug 22, 2019
883a8ec
removed some unecessary commented code
torfjelde Aug 22, 2019
5484f6d
addded some comments
torfjelde Aug 22, 2019
3565a18
added SimplexBijector and forward(flow, ::Vector) now returns vector
torfjelde Aug 22, 2019
7cf3f6c
removed left-over commented out code
torfjelde Aug 24, 2019
6cb2f61
updated Manifest
torfjelde Aug 24, 2019
3ee2584
fixed issue with _transform for RadialLayer when using Tracked
torfjelde Aug 24, 2019
9397098
forgot a sqrt in previous commit
torfjelde Aug 24, 2019
2bf33d5
removed now redundant hack for Dirichlet
torfjelde Aug 24, 2019
0db9988
fixed stackoverflow on forward(::RadialLayer, x::AbstractArray) due t…
torfjelde Aug 24, 2019
0d5b78c
added recursive implementation for forward(cb::Composed, x)
torfjelde Aug 24, 2019
a6d4c36
support for batch computation using forward(b::Bijector, x, logjac)
torfjelde Aug 24, 2019
e2f26df
edited some comments
torfjelde Aug 24, 2019
4638dd2
fixed a typo in RadialLayer forward
torfjelde Aug 24, 2019
24ce4cc
replaced forward(b, x, logjac) with fused logjac instead
torfjelde Aug 24, 2019
7c157f9
initializing recursive forward call using first result
torfjelde Aug 24, 2019
aa0f8b1
fixed logabsdetjac of Shift for batch, though it is ambiguous imo
torfjelde Aug 25, 2019
28bd710
added test for transform and inverse for univariate transformed
torfjelde Aug 26, 2019
afce05c
MvLogMvNormal no longer uses DistributionBijector
torfjelde Aug 26, 2019
4044155
captialized comments
torfjelde Aug 26, 2019
eb4bbc8
added my name to a comment
torfjelde Aug 26, 2019
a568f86
fixed a typo leading to Kolmogorov being treated as unit-contrained
torfjelde Aug 26, 2019
0d5a3be
added SingularJacobianException to logabsdetjac of ADBijector
torfjelde Aug 27, 2019
323e61a
made changes to adhere to style-guide
torfjelde Aug 27, 2019
04f5afa
mode style-changes
torfjelde Aug 27, 2019
68355c7
added a more useful forward(td::TransformedDistribution)
torfjelde Aug 27, 2019
a2336d5
replaced constant variables log_b and exp_b with Log and Exp
torfjelde Aug 27, 2019
d71b87e
compose replaced by composel and composer
torfjelde Aug 28, 2019
d6856d1
fixed sign typo and added AD verification
torfjelde Aug 28, 2019
a851abe
added special logpdf_forward for dirichlet in forward
torfjelde Aug 29, 2019
ea6f1ee
added docstring to forward(d::Distribution)
torfjelde Aug 29, 2019
fc2351a
change variable name in forward(d::Distribution)
torfjelde Aug 29, 2019
b49aac5
increment version number to 0.4.0
torfjelde Aug 29, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
130 changes: 112 additions & 18 deletions Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
[[Adapt]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "82dab828020b872fa9efd3abec1152b075bc7cbf"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "1.0.0"

[[Arpack]]
deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "Random", "SparseArrays", "Test"]
git-tree-sha1 = "1ce1ce9984683f0b6a587d5bdbc688ecb480096f"
deps = ["BinaryProvider", "Libdl", "LinearAlgebra"]
git-tree-sha1 = "07a2c077bdd4b6d23a40342a8a108e2ee5e58ab6"
uuid = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97"
version = "0.3.0"
version = "0.3.1"

[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
Expand All @@ -14,22 +20,45 @@ uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
version = "0.8.10"

[[BinaryProvider]]
deps = ["Libdl", "Pkg", "SHA", "Test"]
git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
deps = ["Libdl", "Logging", "SHA"]
git-tree-sha1 = "c7361ce8a2129f20b0e05a89f7070820cfed6648"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.3"
version = "0.5.6"

[[CSTParser]]
deps = ["Tokenize"]
git-tree-sha1 = "c69698c3d4a7255bc1b4bc2afc09f59db910243b"
uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
version = "0.6.2"

[[CommonSubexpressions]]
deps = ["Test"]
git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0"
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
version = "0.2.0"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "2.1.0"

[[Crayons]]
deps = ["Test"]
git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523"
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
version = "4.0.0"

[[DataAPI]]
git-tree-sha1 = "8903f0219d3472543fc4b2f5ebaf675a07f817c0"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.0.1"

[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038"
deps = ["InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "0809951a1774dc724da22d26e4289bbaab77809a"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.15.0"
version = "0.17.0"

[[Dates]]
deps = ["Printf"]
Expand All @@ -39,6 +68,18 @@ uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"

[[DiffResults]]
deps = ["Compat", "StaticArrays"]
git-tree-sha1 = "34a4a1e8be7bc99bc9c611b895b5baf37a80584c"
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "0.0.4"

[[DiffRules]]
deps = ["Random", "Test"]
git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "0.0.10"

[[Distributed]]
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand All @@ -49,6 +90,12 @@ git-tree-sha1 = "022e6610c320b6e19b454502d759c672580abe00"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
version = "0.18.0"

[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.3"

[[InteractiveUtils]]
deps = ["LinearAlgebra", "Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Expand All @@ -66,6 +113,12 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[MacroTools]]
deps = ["CSTParser", "Compat", "DataStructures", "Test", "Tokenize"]
git-tree-sha1 = "d6e9dedb8c92c3465575442da456aec15a89ff76"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.5.1"

[[MappedArrays]]
deps = ["Test"]
git-tree-sha1 = "923441c5ac942b60bd3a842d5377d96646bcbf46"
Expand All @@ -77,14 +130,26 @@ deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[Missings]]
deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"]
git-tree-sha1 = "d1d2585677f2bd93a97cfeb8faa7a0de0f982042"
deps = ["SparseArrays", "Test"]
git-tree-sha1 = "f0719736664b4358aa9ec173077d4285775f8007"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "0.4.0"
version = "0.4.1"

[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[NNlib]]
deps = ["Libdl", "LinearAlgebra", "Requires", "Statistics", "TimerOutputs"]
git-tree-sha1 = "0c667371391fc6bb31f7f12f96a56a17098b3de8"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.6.0"

[[NaNMath]]
deps = ["Compat"]
git-tree-sha1 = "ce3b85e484a5d4c71dd5316215069311135fa9f2"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "0.3.2"

[[OrderedCollections]]
deps = ["Random", "Serialization", "Test"]
git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
Expand All @@ -93,9 +158,9 @@ version = "1.1.0"

[[PDMats]]
deps = ["Arpack", "LinearAlgebra", "SparseArrays", "SuiteSparse", "Test"]
git-tree-sha1 = "b6c91fc0ab970c0563cbbe69af18d741a49ce551"
git-tree-sha1 = "9d6a9b3e19634612fb1edcafc4b1d75242b24bde"
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
version = "0.9.6"
version = "0.9.9"

[[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
Expand Down Expand Up @@ -137,6 +202,12 @@ git-tree-sha1 = "9a6c758cdf73036c3239b0afbea790def1dabff9"
uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
version = "0.5.0"

[[Roots]]
deps = ["Printf"]
git-tree-sha1 = "9cc4b586c71f9aea25312b94be8c195f119b0ec3"
uuid = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
version = "0.8.3"

[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

Expand Down Expand Up @@ -166,15 +237,21 @@ git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "0.7.2"

[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "db23bbf50064c582b6f2b9b043c8e7e98ea8c0c6"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "0.11.0"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[[StatsBase]]
deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"]
git-tree-sha1 = "435707791dc85a67d98d671c1c3fcf1b20b00f94"
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
git-tree-sha1 = "c53e809e63fe5cf5de13632090bc3520649c9950"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.29.0"
version = "0.32.0"

[[StatsFuns]]
deps = ["Rmath", "SpecialFunctions", "Test"]
Expand All @@ -183,13 +260,30 @@ uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
version = "0.8.0"

[[SuiteSparse]]
deps = ["Libdl", "LinearAlgebra", "SparseArrays"]
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"

[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[TimerOutputs]]
deps = ["Crayons", "Printf", "Test", "Unicode"]
git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.0"

[[Tokenize]]
git-tree-sha1 = "dfcdbbfb2d0370716c815cbd6f8a364efb6f42cf"
uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624"
version = "0.5.6"

[[Tracker]]
deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"]
git-tree-sha1 = "327342fec6e09f68ced0c2dc5731ed475e4b696b"
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
version = "0.2.2"

[[URIParser]]
deps = ["Test", "Unicode"]
git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69"
Expand Down
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.3.2"
version = "0.4.0"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[extras]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand Down
46 changes: 35 additions & 11 deletions src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using Reexport, Requires
using StatsFuns
using LinearAlgebra
using MappedArrays
using Roots

export TransformDistribution,
PositiveDistribution,
Expand All @@ -13,7 +14,26 @@ export TransformDistribution,
PDMatDistribution,
link,
invlink,
logpdf_with_trans
logpdf_with_trans,
transform,
forward,
logabsdetjac,
logabsdetjacinv,
Bijector,
ADBijector,
Inversed,
Composed,
compose,
Identity,
DistributionBijector,
bijector,
transformed,
UnivariateTransformed,
MultivariateTransformed,
logpdf_with_jac,
logpdf_forward,
PlanarLayer,
RadialLayer

const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_BIJECTORS", "0")))

Expand Down Expand Up @@ -162,8 +182,8 @@ function _clamp(x::T, dist::SimplexDistribution) where T
end

function link(
d::SimplexDistribution,
x::AbstractVector{T},
d::SimplexDistribution,
x::AbstractVector{T},
::Type{Val{proj}} = Val{true}
) where {T<:Real, proj}
y, K = similar(x), length(x)
Expand Down Expand Up @@ -191,8 +211,8 @@ end

# Vectorised implementation of the above.
function link(
d::SimplexDistribution,
X::AbstractMatrix{T},
d::SimplexDistribution,
X::AbstractMatrix{T},
::Type{Val{proj}} = Val{true}
) where {T<:Real, proj}
Y, K, N = similar(X), size(X, 1), size(X, 2)
Expand All @@ -219,8 +239,8 @@ function link(
end

function invlink(
d::SimplexDistribution,
y::AbstractVector{T},
d::SimplexDistribution,
y::AbstractVector{T},
::Type{Val{proj}} = Val{true}
) where {T<:Real, proj}
x, K = similar(y), length(y)
Expand All @@ -245,8 +265,8 @@ end

# Vectorised implementation of the above.
function invlink(
d::SimplexDistribution,
Y::AbstractMatrix{T},
d::SimplexDistribution,
Y::AbstractMatrix{T},
::Type{Val{proj}} = Val{true}
) where {T<:Real, proj}
X, K, N = similar(Y), size(Y, 1), size(Y, 2)
Expand Down Expand Up @@ -335,8 +355,8 @@ function invlink(d::PDMatDistribution, Y::AbstractMatrix{<:Real})
end

function logpdf_with_trans(
d::PDMatDistribution,
X::AbstractMatrix{<:Real},
d::PDMatDistribution,
X::AbstractMatrix{<:Real},
transform::Bool
)
T = eltype(X)
Expand Down Expand Up @@ -424,4 +444,8 @@ function logpdf_with_trans(
return map(x -> logpdf_with_trans(d, x, transform), X)
end

include("interface.jl")

include("norm_flows.jl")

end # module