-
Notifications
You must be signed in to change notification settings - Fork 32
/
scale.jl
84 lines (73 loc) · 3.81 KB
/
scale.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
    Scale{T, N} <: Bijector{N}

Bijector that multiplies its input by the stored factor `a`.

`T` is the type of `a`. `N` is forwarded to `Bijector{N}` and is the value the
`logabsdetjac` method for `Scale` dispatches on (as `Val(N)`).
"""
struct Scale{T, N} <: Bijector{N}
    a::T
end
# Outer constructors. The keyword `dim` carries the second type parameter of
# `Scale` as a `Val`; it defaults to `Val(0)` for a real `a` and to `Val(N)`
# for an `N`-dimensional array `a`.
function Scale(a::T; dim::Val{D} = Val(0)) where {T<:Real, D}
    return Scale{T, D}(a)
end
function Scale(a::A; dim::Val{D} = Val(N)) where {T, D, N, A<:AbstractArray{T, N}}
    return Scale{A, D}(a)
end
# Forward transform: multiply `x` by the stored factor `b.a`.
# Generic fallback: elementwise product.
function (b::Scale)(x)
    return b.a .* x
end
# Real factor applied to an array: elementwise product.
function (b::Scale{<:Real})(x::AbstractArray)
    return b.a .* x
end
# Matrix factor: matrix product `a * x` rather than an elementwise product.
function (b::Scale{<:AbstractMatrix})(x::AbstractArray)
    return b.a * x
end
# Vector factor with `N == 2` applied to a real matrix: elementwise (broadcast) product.
function (b::Scale{<:AbstractVector{<:Real}, 2})(x::AbstractMatrix{<:Real})
    return b.a .* x
end
# Inverse bijector: a `Scale` built from the inverse of the stored factor,
# keeping the same `D` type parameter.
function inv(b::Scale{T, D}) where {T, D}
    inverse_factor = inv(b.a)
    return Scale(inverse_factor; dim = Val(D))
end
# Vector factor: invert each entry separately.
function inv(b::Scale{<:AbstractVector, D}) where {D}
    inverse_factor = inv.(b.a)
    return Scale(inverse_factor; dim = Val(D))
end
# `logabsdetjac` delegates to `_logabsdetjac_scale` so that custom adjoints can be implemented for it.
# Forward the factor `b.a`, the input `x` and `Val(N)` to `_logabsdetjac_scale`,
# which has one method per (factor type, input type, `N`) combination.
function logabsdetjac(b::Scale{T, N}, x) where {T, N}
    return _logabsdetjac_scale(b.a, x, Val(N))
end
# Primal implementations of the log-abs-det-Jacobian of scaling by `a`.
# The `Val` argument selects between the `Val{0}` and `Val{1}` variants.

# Real factor, real input: log|a|.
function _logabsdetjac_scale(a::Real, x::Real, ::Val{0})
    return log(abs(a))
end
# Real factor, vector input under `Val{0}`: one log|a| per entry of `x`.
function _logabsdetjac_scale(a::Real, x::AbstractVector, ::Val{0})
    logscale = log(abs(a))
    return fill(logscale, length(x))
end
# Real factor, vector input under `Val{1}`: a single value, log|a| * length(x).
function _logabsdetjac_scale(a::Real, x::AbstractVector, ::Val{1})
    logscale = log(abs(a))
    return logscale * length(x)
end
# Real factor, matrix input under `Val{1}`: one value per column of `x`.
function _logabsdetjac_scale(a::Real, x::AbstractMatrix, ::Val{1})
    per_column = log(abs(a)) * size(x, 1)
    return fill(per_column, size(x, 2))
end
# Vector factor, vector input: sum of log|aᵢ|.
function _logabsdetjac_scale(a::AbstractVector, x::AbstractVector, ::Val{1})
    return sum(log.(abs.(a)))
end
# Vector factor, matrix input: the same sum repeated once per column of `x`.
function _logabsdetjac_scale(a::AbstractVector, x::AbstractMatrix, ::Val{1})
    total = sum(log.(abs.(a)))
    return fill(total, size(x, 2))
end
# Matrix factor, vector input: log|det(a)|.
function _logabsdetjac_scale(a::AbstractMatrix, x::AbstractVector, ::Val{1})
    return log(abs(det(a)))
end
# Matrix factor, matrix input: log|det(a)| once per column, via `ones` with the eltype of `x`.
function _logabsdetjac_scale(a::AbstractMatrix, x::AbstractMatrix{T}, ::Val{1}) where {T}
    logdet_a = log(abs(det(a)))
    return logdet_a * ones(T, size(x, 2))
end
# Adjoints for 0-dim and 1-dim `Scale` using `Real`
# Tracked real factor with a real input: register the call with `track` so the
# matching `@grad` definition is used; `x` is passed on with tracking stripped via `data`.
function _logabsdetjac_scale(a::TrackedReal, x::Real, ::Val{0})
    x_plain = data(x)
    return track(_logabsdetjac_scale, a, x_plain, Val(0))
end
# Adjoint for real factor / real input: d/da log|a| = 1 / a.
# Only `a` receives a gradient; `x` and the `Val` get `nothing`.
@grad function _logabsdetjac_scale(a::Real, x::Real, ::Val{0})
    primal = _logabsdetjac_scale(data(a), data(x), Val(0))
    pullback(Δ) = (inv(data(a)) .* Δ, nothing, nothing)
    return primal, pullback
end
# Need to treat `AbstractVector` and `AbstractMatrix` separately due to ambiguity errors
# Tracked real factor with a vector input (`Val{0}`): register the call with
# `track`; `x` is passed on with tracking stripped via `data`.
function _logabsdetjac_scale(a::TrackedReal, x::AbstractVector, ::Val{0})
    x_plain = data(x)
    return track(_logabsdetjac_scale, a, x_plain, Val(0))
end
# Adjoint for real factor / vector input (`Val{0}`).
# The primal is `fill(log|a|, length(x))`, so every output entry has derivative
# 1 / a with respect to `a`; the pullback contracts that vector with `Δ`.
@grad function _logabsdetjac_scale(a::Real, x::AbstractVector, ::Val{0})
    da = data(a)
    J = fill(inv(da), length(x))
    pullback(Δ) = (transpose(J) * Δ, nothing, nothing)
    return _logabsdetjac_scale(da, data(x), Val(0)), pullback
end
# Tracked real factor with a matrix input.
#
# This pair dispatches on `Val{1}`: the Jacobian used below, `size(x, 1) / a`,
# is the derivative of the `(a::Real, x::AbstractMatrix, ::Val{1})` primal,
# `fill(log|a| * size(x, 1), size(x, 2))`. Dispatching on `Val{0}` (as before)
# made the `@grad` body call a `(Real, AbstractMatrix, Val{0})` primal that is
# not defined, so that path could only end in a `MethodError`.
function _logabsdetjac_scale(a::TrackedReal, x::AbstractMatrix, ::Val{1})
    return track(_logabsdetjac_scale, a, data(x), Val(1))
end
@grad function _logabsdetjac_scale(a::Real, x::AbstractMatrix, ::Val{1})
    da = data(a)
    # One entry per column of `x`: ∂/∂a (log|a| * size(x, 1)) = size(x, 1) / a.
    J = fill(size(x, 1) / da, size(x, 2))
    # Only `a` receives a gradient; `x` and the `Val` get `nothing`.
    return _logabsdetjac_scale(da, data(x), Val(1)), Δ -> (transpose(J) * Δ, nothing, nothing)
end
# adjoints for 1-dim and 2-dim `Scale` using `AbstractVector`
# Tracked vector factor with a vector input: register the call with `track`;
# `x` is passed on with tracking stripped via `data`.
function _logabsdetjac_scale(a::TrackedVector, x::AbstractVector, ::Val{1})
    x_plain = data(x)
    return track(_logabsdetjac_scale, a, x_plain, Val(1))
end
# Adjoint for vector factor / vector input.
@grad function _logabsdetjac_scale(a::TrackedVector, x::AbstractVector, ::Val{1})
    # The primal is ∑ⱼ log|aⱼ|. Differentiating with respect to aᵢ removes every
    # term except j = i, leaving ∂ᵢ log|aᵢ| = 1 / aᵢ.
    da = data(a)
    inv_a = inv.(da)
    pullback(Δ) = (inv_a .* Δ, nothing, nothing)
    return _logabsdetjac_scale(da, data(x), Val(1)), pullback
end
# Tracked vector factor with a matrix input: register the call with `track`;
# `x` is passed on with tracking stripped via `data`.
function _logabsdetjac_scale(a::TrackedVector, x::AbstractMatrix, ::Val{1})
    x_plain = data(x)
    return track(_logabsdetjac_scale, a, x_plain, Val(1))
end
# Adjoint for vector factor / matrix input.
# The primal repeats ∑ᵢ log|aᵢ| once per column of `x`, so the transposed
# Jacobian has 1 / aᵢ in row i of every column.
@grad function _logabsdetjac_scale(a::TrackedVector, x::AbstractMatrix, ::Val{1})
    da = data(a)
    ncols = size(x, 2)
    Jᵀ = repeat(inv.(da), 1, ncols)
    pullback(Δ) = (Jᵀ * Δ, nothing, nothing)
    return _logabsdetjac_scale(da, data(x), Val(1)), pullback
end
# TODO: implement analytical gradient for scaling a vector using a matrix
# function _logabsdetjac_scale(a::TrackedMatrix, x::AbstractVector, ::Val{1})
# track(_logabsdetjac_scale, a, data(x), Val{1})
# end
# @grad function _logabsdetjac_scale(a::TrackedMatrix, x::AbstractVector, ::Val{1})
# throw
# end