Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rational Quadratic Spline #80

Merged
merged 83 commits into from
Jan 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
83 commits
Select commit Hold shift + click to select a range
55e1e2a
added the Stacked bijector
torfjelde Aug 30, 2019
d82c9da
a couple of style fixes
torfjelde Aug 30, 2019
d557b3d
added size assertion to Stacked and better testing
torfjelde Aug 30, 2019
e168ce0
added Stacked test to tests for norm flows
torfjelde Aug 30, 2019
ea5b57c
export TransformedDistribution
torfjelde Sep 2, 2019
fae5dbf
added som useful implementations for TransformedDistribution
torfjelde Sep 2, 2019
fee6aad
fixed composer which had a bug leftover from previous PR
torfjelde Sep 6, 2019
af38dc2
removed vectorization in logabsdetjac to fail loadly rather than sile…
torfjelde Sep 6, 2019
33758bd
added dimension of expected input to Bijector type
torfjelde Sep 12, 2019
ad54c16
composition now fails upon dimension mismatch
torfjelde Sep 13, 2019
0ddf71d
removed dots in logabsdetjac accumulates so as to not fail silently
torfjelde Sep 13, 2019
32ed5c4
updated Identity to work with batches
torfjelde Sep 13, 2019
29562af
batch-specialization for Log and Exp
torfjelde Sep 13, 2019
40ed811
batch-specialization for Scale and Shift
torfjelde Sep 13, 2019
669e72d
added dimension method for bijectors
torfjelde Sep 13, 2019
f08dc51
added tests for batch computation for a bunch of bijectors
torfjelde Sep 13, 2019
e26c893
removed a false comment
torfjelde Sep 13, 2019
589e564
removed some left-over stuff from a failed merge
torfjelde Sep 13, 2019
a16f05a
added some convenient constructors for SimplexBijector
torfjelde Sep 16, 2019
09d1bd6
fixed bug from previous commit and added tests for Stacked
torfjelde Sep 16, 2019
b48c013
Merge branch 'master' into torfjelde/stacked
torfjelde Sep 16, 2019
353e80c
fixed some left-over stuff from a merge
torfjelde Sep 16, 2019
98860a5
added comments convincing myself that entropy is invariant
torfjelde Sep 16, 2019
6e7a174
removed redundant comment
torfjelde Sep 16, 2019
b63fb29
fixed tests which had be messed up in the merge
torfjelde Sep 16, 2019
edaed9a
added message for size-discrepancy in Stacked and tests
torfjelde Sep 16, 2019
976ddf6
remvoed reexport of StatsBase
torfjelde Sep 16, 2019
dd6ec2a
fixed a couple of typos
torfjelde Sep 16, 2019
c38f203
checking something with CI
torfjelde Sep 16, 2019
f5d68eb
more CI test stuff
torfjelde Sep 16, 2019
898ae50
fixed the test which failed on CI
torfjelde Sep 16, 2019
ddc8cd7
okay now I fixed it
torfjelde Sep 16, 2019
8b4eecb
okay NOW i fixed the test
torfjelde Sep 16, 2019
3be1af3
added asserts to evaluation of SimplexBijector to catch length-1
torfjelde Sep 16, 2019
6e4f3c8
made Stacked constructor more strict as it should be
torfjelde Sep 16, 2019
8a925b2
added messages to assertions for SimplexBijector
torfjelde Sep 16, 2019
ea2edde
updated Manifest
torfjelde Sep 16, 2019
a5af0fc
updated README to include info about Stacked
torfjelde Sep 16, 2019
7ff118d
added Stacked to reference-section
torfjelde Sep 16, 2019
2707443
removed useless line
torfjelde Sep 16, 2019
5a7a76a
added forward-specialization for Stacked
torfjelde Sep 16, 2019
64070ea
removed something by accident last commit
torfjelde Sep 16, 2019
9318f95
fixed mixed Stacked and added tests
torfjelde Sep 16, 2019
89c16cd
removed no-longer-needed TODO and added more docstring
torfjelde Sep 16, 2019
fd4eec1
fixed comment
torfjelde Sep 16, 2019
2f8c1a0
removed redundant whitespace
torfjelde Sep 16, 2019
597bb96
added assert-check to array-impl of Stacked
torfjelde Sep 16, 2019
6daf7aa
removed redundant type and fixed typo
torfjelde Sep 17, 2019
fce177b
removed unused consts
torfjelde Sep 17, 2019
170b15e
made the forward(b::Stacked, ...) generated slightly nicer
torfjelde Sep 17, 2019
0f06735
bump to 0.4.1
torfjelde Sep 18, 2019
a26a479
Merge branch 'master' into torfjelde/stacked
torfjelde Sep 18, 2019
a85da36
fixed size-stuff for the flows
torfjelde Sep 19, 2019
7f56b0d
Merge branch 'master' into tor/batch-support
torfjelde Sep 19, 2019
07c0d87
added more tests for batch computation
torfjelde Sep 19, 2019
b458dab
updated SimplexBijector
torfjelde Sep 19, 2019
a47516d
added more and better testing for batch support
torfjelde Sep 19, 2019
4b7421b
Merge branch 'torfjelde/stacked' into tor/batch-support
torfjelde Sep 19, 2019
3182f0d
added dimensionality to Stacked and tests
torfjelde Sep 19, 2019
b89423f
replaced vcat with stack to avoid unnecessary confusion
torfjelde Sep 19, 2019
5442287
Merge branch 'master' into tor/batch-support
torfjelde Sep 19, 2019
7e7c5af
dropped redundant dimensionality for TransformedDistribution
torfjelde Sep 19, 2019
dd5770c
improved support for Tracker.jl
Sep 25, 2019
ff56f75
fixed Shift inverse and added matrix-scaling of vectors
torfjelde Oct 1, 2019
64ff961
initial work on RationalQuadraticSpline
torfjelde Oct 19, 2019
3f42035
added implementations of for 1D case rather than just 0D
torfjelde Oct 22, 2019
e359203
initial RQS
torfjelde Feb 11, 2020
7c45245
Merge branch 'master' into tor/rational-quadratic-spline
torfjelde Sep 11, 2020
b622331
cleaned up RQS significantly
torfjelde Sep 11, 2020
97b0681
now includes RQS
torfjelde Sep 11, 2020
286ee81
removed some commented out code
torfjelde Sep 11, 2020
6d439a9
added a proper docstring for RQS
torfjelde Sep 11, 2020
be7b676
style changes
torfjelde Sep 11, 2020
6b6bafd
added tests for RQS
torfjelde Sep 11, 2020
6cf574c
added RQS tests to runtests
torfjelde Sep 11, 2020
70718df
added proper testing to RQS and utils for Bijectors
torfjelde Sep 11, 2020
5ad76f8
small changes to batch-implementations
torfjelde Sep 11, 2020
2efb15d
uhhmm, I had been working on the wrong branch all along...
torfjelde Sep 11, 2020
771bcd8
added AD-based tests for logabsdetjac in test_bijector
torfjelde Sep 12, 2020
5d0ce8f
fixed type instability in return-type for RQS
torfjelde Sep 12, 2020
160f44b
added a test for RQS to interface.jl but should move away from this
torfjelde Sep 12, 2020
818cf2d
Merge branch 'master' into tor/rqs-master
torfjelde Sep 28, 2020
e301948
commas are difficult okay
torfjelde Sep 28, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
387 changes: 387 additions & 0 deletions src/bijectors/rational_quadratic_spline.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,387 @@
using NNlib

"""
RationalQuadraticSpline{T, 0} <: Bijector{0}
RationalQuadraticSpline{T, 1} <: Bijector{1}

Implementation of the Rational Quadratic Spline flow [1].

- Outside of the interval `[minimum(widths), maximum(widths)]`, this mapping is given
by the identity map.
- Inside the interval it's given by a monotonic spline (i.e. monotonic polynomials
connected at intermediate points) with endpoints fixed so as to continuously transform
into the identity map.

For the sake of efficiency, there are separate implementations for 0-dimensional and
1-dimensional inputs.

# Notes
There are two constructors for `RationalQuadraticSpline`:
- `RationalQuadraticSpline(widths, heights, derivatives)`: it is assumed that `widths`,
`heights`, and `derivatives` satisfy the constraints that makes this a valid bijector, i.e.
- `widths`: monotonically increasing and `length(widths) == K`,
- `heights`: monotonically increasing and `length(heights) == K`,
- `derivatives`: non-negative and `derivatives[1] == derivatives[end] == 1`.
- `RationalQuadraticSpline(widths, heights, derivatives, B)`: other than than the lengths,
no assumptions are made on parameters. Therefore we will transform the parameters s.t.:
- `widths_new` ∈ [-B, B]ᴷ⁺¹, where `K == length(widths)`,
- `heights_new` ∈ [-B, B]ᴷ⁺¹, where `K == length(heights)`,
- `derivatives_new` ∈ (0, ∞)ᴷ⁺¹ with `derivatives_new[1] == derivates_new[end] == 1`,
where `(K - 1) == length(derivatives)`.

# Examples
## Univariate
```julia-repl
julia> using Bijectors: RationalQuadraticSpline

julia> K = 3; B = 2;

julia> # Monotonic spline on '[-B, B]' with `K` intermediate knots/"connection points".
b = RationalQuadraticSpline(randn(K), randn(K), randn(K - 1), B);

julia> b(0.5) # inside of `[-B, B]` → transformed
1.412300607463467

julia> b(5.) # outside of `[-B, B]` → not transformed
5.0
```
Or we can use the constructor with the parameters correctly constrained:
```julia-repl
julia> b = RationalQuadraticSpline(b.widths, b.heights, b.derivatives);

julia> b(0.5) # inside of `[-B, B]` → transformed
1.412300607463467
```
## Multivariate
```julia-repl
julia> d = 2; K = 3; B = 2;

julia> b = RationalQuadraticSpline(randn(d, K), randn(d, K), randn(d, K - 1), B);

julia> b([-1., 1.])
2-element Array{Float64,1}:
-1.2568224171342797
0.5537259740554675

julia> b([-5., 5.])
2-element Array{Float64,1}:
-5.0
5.0

julia> b([-1., 5.])
2-element Array{Float64,1}:
-1.2568224171342797
5.0
```

# References
[1] Durkan, C., Bekasov, A., Murray, I., & Papamakarios, G., Neural Spline Flows, CoRR, arXiv:1906.04032 [stat.ML], (2019).
"""
struct RationalQuadraticSpline{T, N} <: Bijector{N}
widths::T # K widths
heights::T # K heights
derivatives::T # K derivatives, with endpoints being ones

function RationalQuadraticSpline(
widths::T,
heights::T,
derivatives::T
) where {T<:AbstractVector}
# TODO: add a `NoArgCheck` type and argument so we can circumvent if we want
@assert length(widths) == length(heights) == length(derivatives)
@assert all(derivatives .> 0) "derivatives need to be positive"

return new{T, 0}(widths, heights, derivatives)
end

function RationalQuadraticSpline(
widths::T,
heights::T,
derivatives::T
) where {T<:AbstractMatrix}
@assert size(widths, 2) == size(heights, 2) == size(derivatives, 2)
@assert all(derivatives .> 0) "derivatives need to be positive"
return new{T, 1}(widths, heights, derivatives)
end
end

function RationalQuadraticSpline(
widths::A,
heights::A,
derivatives::A,
B::T2
) where {T1, T2, A <: AbstractVector{T1}}
# Using `NNLlinb.softax` instead of `StatsFuns.softmax` (which does inplace operations)
return RationalQuadraticSpline(
(cumsum(vcat([zero(T1)], NNlib.softmax(widths))) .- 0.5) * 2 * B,
(cumsum(vcat([zero(T1)], NNlib.softmax(heights))) .- 0.5) * 2 * B,
vcat([one(T1)], softplus.(derivatives), [one(T1)])
)
end

function RationalQuadraticSpline(
widths::A,
heights::A,
derivatives::A,
B::T2
) where {T1, T2, A <: AbstractMatrix{T1}}
ws = hcat(zeros(T1, size(widths, 1)), NNlib.softmax(widths; dims = 2))
hs = hcat(zeros(T1, size(widths, 1)), NNlib.softmax(heights; dims = 2))
ds = hcat(ones(T1, size(widths, 1)), softplus.(derivatives), ones(T1, size(widths, 1)))

return RationalQuadraticSpline(
(2 * B) .* (cumsum(ws; dims = 2) .- 0.5),
(2 * B) .* (cumsum(hs; dims = 2) .- 0.5),
ds
)
end

##########################
### Forward evaluation ###
##########################
function rqs_univariate(widths, heights, derivatives, x::Real)
T = promote_type(eltype(widths), eltype(heights), eltype(derivatives), eltype(x))

# We're working on [-B, B] and `widths[end]` is `B`
if (x ≤ -widths[end]) || (x ≥ widths[end])
return one(T) * x
end

K = length(widths)

# Find which bin `x` is in; subtract 1 because `searchsortedfirst` returns idx of ≥ not ≤
k = searchsortedfirst(widths, x) - 1

# Width
# If k == 0 then we should put it in the bin `[-B, widths[1]]`
w_k = (k == 0) ? -widths[end] : widths[k]
w = widths[k + 1] - w_k

# Slope
h_k = (k == 0) ? -heights[end] : heights[k]
Δy = heights[k + 1] - h_k

s = Δy / w
ξ = (x - w_k) / w

# Derivatives at knot-points
# Note that we have (K - 1) knot-points, not K
d_k = (k == 0) ? one(T) : derivatives[k]
d_kplus1 = (k == K - 1) ? one(T) : derivatives[k + 1]

# Eq. (14)
numerator = Δy * (s * ξ^2 + d_k * ξ * (1 - ξ))
denominator = s + (d_kplus1 + d_k - 2s) * ξ * (1 - ξ)
g = h_k + numerator / denominator

return g
end


# univariate
function (b::RationalQuadraticSpline{<:AbstractVector, 0})(x::Real)
return rqs_univariate(b.widths, b.heights, b.derivatives, x)
end
(b::RationalQuadraticSpline{<:AbstractVector, 0})(x::AbstractVector) = b.(x)

# multivariate
function (b::RationalQuadraticSpline{<:AbstractMatrix, 1})(x::AbstractVector)
return [rqs_univariate(b.widths[i, :], b.heights[i, :], b.derivatives[i, :], x[i]) for i = 1:length(x)]
end
function (b::RationalQuadraticSpline{<:AbstractMatrix, 1})(x::AbstractMatrix)
return eachcolmaphcat(b, x)
end

##########################
### Inverse evaluation ###
##########################
function rqs_univariate_inverse(widths, heights, derivatives, y::Real)
T = promote_type(eltype(widths), eltype(heights), eltype(derivatives), eltype(y))

if (y ≤ -heights[end]) || (y ≥ heights[end])
return one(T) * y
end

K = length(widths)
k = searchsortedfirst(heights, y) - 1

# Width
w_k = (k == 0) ? -widths[end] : widths[k]
w = widths[k + 1] - w_k

# Slope
h_k = (k == 0) ? -heights[end] : heights[k]
Δy = heights[k + 1] - h_k

# Recurring quantities
s = Δy / w
d_k = (k == 0) ? one(T) : derivatives[k]
d_kplus1 = (k == K - 1) ? one(T) : derivatives[k + 1]
ds = d_kplus1 + d_k - 2 * s

# Eq. (25)
a1 = Δy * (s - d_k) + (y - h_k) * ds
# Eq. (26)
a2 = Δy * d_k - (y - h_k) * ds
# Eq. (27)
a3 = - s * (y - h_k)

# Eq. (24). There's a mistake in the paper; says `x` but should be `ξ`
numerator = - 2 * a3
denominator = (a2 + sqrt(a2^2 - 4 * a1 * a3))
ξ = numerator / denominator

return ξ * w + w_k
end

function (ib::Inverse{<:RationalQuadraticSpline, 0})(y::Real)
return rqs_univariate_inverse(ib.orig.widths, ib.orig.heights, ib.orig.derivatives, y)
end
(ib::Inverse{<:RationalQuadraticSpline, 0})(y::AbstractVector) = ib.(y)

function (ib::Inverse{<:RationalQuadraticSpline, 1})(y::AbstractVector)
b = ib.orig
return [rqs_univariate_inverse(b.widths[i, :], b.heights[i, :], b.derivatives[i, :], y[i]) for i = 1:length(y)]
end
function (ib::Inverse{<:RationalQuadraticSpline, 1})(y::AbstractMatrix)
return eachcolmaphcat(ib, y)
end

######################
### `logabsdetjac` ###
######################
function rqs_logabsdetjac(widths, heights, derivatives, x::Real)
T = promote_type(eltype(widths), eltype(heights), eltype(derivatives), eltype(y))
K = length(widths) - 1

# Find which bin `x` is in
k = searchsortedfirst(widths, x) - 1

if k > K || k == 0
return zero(T) * x
end

# Width
w = widths[k + 1] - widths[k]

# Slope
Δy = heights[k + 1] - heights[k]

# Recurring quantities
s = Δy / w
ξ = (x - widths[k]) / w

numerator = s^2 * (derivatives[k + 1] * ξ^2
+ 2 * s * ξ * (1 - ξ)
+ derivatives[k] * (1 - ξ)^2)
denominator = s + (derivatives[k + 1] + derivatives[k] - 2 * s) * ξ * (1 - ξ)

return log(numerator) - 2 * log(denominator)
end

function rqs_logabsdetjac(
widths::AbstractVector,
heights::AbstractVector,
derivatives::AbstractVector,
x::Real
)
T = promote_type(eltype(widths), eltype(heights), eltype(derivatives), eltype(x))

if (x ≤ -widths[end]) || (x ≥ widths[end])
return zero(T) * x
end

K = length(widths)
k = searchsortedfirst(widths, x) - 1

# Width
w_k = (k == 0) ? -widths[end] : widths[k]
w = widths[k + 1] - w_k

# Slope
h_k = (k == 0) ? -heights[end] : heights[k]
Δy = heights[k + 1] - h_k

# Recurring quantities
s = Δy / w
ξ = (x - w_k) / w

d_k = (k == 0) ? one(T) : derivatives[k]
d_kplus1 = (k == K - 1) ? one(T) : derivatives[k + 1]

numerator = s^2 * (d_kplus1 * ξ^2 + 2 * s * ξ * (1 - ξ) + d_k * (1 - ξ)^2)
denominator = s + (d_kplus1 + d_k - 2 * s) * ξ * (1 - ξ)

return log(numerator) - 2 * log(denominator)
end

function logabsdetjac(b::RationalQuadraticSpline{<:AbstractVector, 0}, x::Real)
return rqs_logabsdetjac(b.widths, b.heights, b.derivatives, x)
end
function logabsdetjac(b::RationalQuadraticSpline{<:AbstractVector, 0}, x::AbstractVector)
return logabsdetjac.(b, x)
end
function logabsdetjac(b::RationalQuadraticSpline{<:AbstractMatrix, 1}, x::AbstractVector)
return sum([
rqs_logabsdetjac(b.widths[i, :], b.heights[i, :], b.derivatives[i, :], x[i])
for i = 1:length(x)
])
end
function logabsdetjac(b::RationalQuadraticSpline{<:AbstractMatrix, 1}, x::AbstractMatrix)
return mapvcat(x -> logabsdetjac(b, x), eachcol(x))
end

#################
### `forward` ###
#################

# TODO: implement this for `x::AbstractVector` and similarily for 1-dimensional `b`,
# and possibly inverses too?
function rqs_forward(
widths::AbstractVector,
heights::AbstractVector,
derivatives::AbstractVector,
x::Real
)
T = promote_type(eltype(widths), eltype(heights), eltype(derivatives), eltype(x))

if (x ≤ -widths[end]) || (x ≥ widths[end])
return (rv = one(T) * x, logabsdetjac = zero(T) * x)
end

# Find which bin `x` is in
K = length(widths)
k = searchsortedfirst(widths, x) - 1

# Width
w_k = (k == 0) ? -widths[end] : widths[k]
w = widths[k + 1] - w_k

# Slope
h_k = (k == 0) ? -heights[end] : heights[k]
Δy = heights[k + 1] - h_k

# Recurring quantities
s = Δy / w
ξ = (x - w_k) / w

d_k = (k == 0) ? one(T) : derivatives[k]
d_kplus1 = (k == K - 1) ? one(T) : derivatives[k + 1]

# Re-used for both `logjac` and `y`
denominator = s + (d_kplus1 + d_k - 2 * s) * ξ * (1 - ξ)

# logjac
numerator_jl = s^2 * (d_kplus1 * ξ^2 + 2 * s * ξ * (1 - ξ) + d_k * (1 - ξ)^2)
logjac = log(numerator_jl) - 2 * log(denominator)

# y
numerator_y = Δy * (s * ξ^2 + d_k * ξ * (1 - ξ))
y = h_k + numerator_y / denominator

return (rv = y, logabsdetjac = logjac)
end

function forward(b::RationalQuadraticSpline{<:AbstractVector, 0}, x::Real)
return rqs_forward(b.widths, b.heights, b.derivatives, x)
end
1 change: 1 addition & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ include("bijectors/radial_layer.jl")
include("bijectors/leaky_relu.jl")
include("bijectors/coupling.jl")
include("bijectors/normalise.jl")
include("bijectors/rational_quadratic_spline.jl")

##################
# Other includes #
Expand Down