-
Notifications
You must be signed in to change notification settings - Fork 32
/
shift.jl
30 lines (23 loc) · 1.33 KB
/
shift.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#################
# Shift & Scale #
#################
"""
    Shift{T, N} <: Bijector{N}

Bijector that translates its input by a fixed offset `a` (a scalar or an array).
`T` is the type of the offset and `N` is the dimensionality parameter forwarded
to `Bijector{N}`.
"""
struct Shift{T, N} <: Bijector{N}
    a::T  # the offset added to every input
end
@functor(Shift)
# Two shifts of the same dimensionality `N` are equal when their offsets are equal.
# Shifts with different `N` are different types, so they hit Base's generic `===`
# fallback and compare unequal.
Base.:(==)(b1::Shift{<:Any, N}, b2::Shift{<:Any, N}) where {N} = b1.a == b2.a

# `isequal` and `hash` must agree: `isequal(x, y)` has to imply `hash(x) == hash(y)`,
# or `Dict`/`Set` lookups of equal shifts silently fail. Comparing the offsets with
# `isequal` (rather than falling back to `==`) keeps this consistent with `hash(b.a)`,
# including for `NaN` and signed zeros. `N` is mixed into the hash because equal
# shifts necessarily share the same `N`.
function Base.isequal(b1::Shift{<:Any, N}, b2::Shift{<:Any, N}) where {N}
    return isequal(b1.a, b2.a)
end
function Base.hash(b::Shift{<:Any, N}, h::UInt) where {N}
    return hash(b.a, hash(N, hash(:Shift, h)))
end
"""
    Shift(a; dim = Val(ndims(a)))

Construct a `Shift` bijector with offset `a`. The dimensionality parameter `N`
is taken from `dim`, which defaults to `Val(ndims(a))` (i.e. `Val(0)` for a scalar).
"""
function Shift(a::Union{Real,AbstractArray}; dim::Val{D} = Val(ndims(a))) where {D}
    offset_type = typeof(a)
    return Shift{offset_type, D}(a)
end
# Same offset, but with the dimensionality parameter raised from `N` to `N + 1`.
function up1(b::Shift{T, N}) where {T, N}
    raised_dim = N + 1
    return Shift{T, raised_dim}(b.a)
end

# Forward transform: broadcast-add the stored offset to the input.
function (b::Shift)(x)
    offset = b.a
    return offset .+ x
end

# A 2-D shift applied to a collection of matrices shifts each matrix separately.
function (b::Shift{<:Any, 2})(x::AbstractArray{<:AbstractMatrix})
    return map(b, x)
end

# The inverse of shifting by `a` is shifting by `-a` (same `T` and `N`).
function inv(b::Shift{T, N}) where {T, N}
    negated = -b.a
    return Shift{T, N}(negated)
end
# Log-abs-det-Jacobian of a shift: forwards the offset, the input and the bijector's
# dimensionality (lifted to `Val(N)`) to `_logabsdetjac_shift`, which picks the
# output shape by dispatch.
# FIXME: implement custom adjoint to ensure we don't get tracking
function logabsdetjac(b::Shift{<:Any, N}, x) where {N}
    dim = Val(N)
    return _logabsdetjac_shift(b.a, x, dim)
end
# Log-absolute-determinant of the Jacobian of `x -> a .+ x`.
#
# A shift is a pure translation, so the result is always zero. These methods only
# choose the *shape* of that zero from `x` and the dimensionality `Val(N)`:
# a single zero for one input, or one zero per input when `x` holds a batch.
# The value of the offset `a` never matters. Its type is constrained only so that
# unsupported combinations still raise a `MethodError`.

# N = 0: a single scalar input.
_logabsdetjac_shift(a::Real, x::Real, ::Val{0}) = zero(eltype(x))

# N = 0: a vector is a batch of scalars -> one zero per element.
function _logabsdetjac_shift(a::Real, x::AbstractVector{T}, ::Val{0}) where {T<:Real}
    return zeros(T, length(x))
end

# N = 1: a single vector input.
function _logabsdetjac_shift(
    a::Union{Real,AbstractVector}, x::AbstractVector{T}, ::Val{1}
) where {T<:Real}
    return zero(T)
end

# N = 1: a matrix input -> one zero per column.
function _logabsdetjac_shift(
    a::Union{Real,AbstractVector}, x::AbstractMatrix{T}, ::Val{1}
) where {T<:Real}
    return zeros(T, size(x, 2))
end

# N = 2: a single matrix input. `a` may be a matrix, which is what
# `Shift(a::AbstractMatrix)` produces (its default dimensionality is `ndims(a) == 2`).
# Accepting only `Real`/`AbstractVector` here made `logabsdetjac` throw for such shifts.
function _logabsdetjac_shift(
    a::Union{Real,AbstractVecOrMat}, x::AbstractMatrix{T}, ::Val{2}
) where {T<:Real}
    return zero(T)
end

# N = 2: an array of matrices -> zeros shaped like the container.
function _logabsdetjac_shift(
    a::Union{Real,AbstractVecOrMat}, x::AbstractArray{<:AbstractMatrix{T}}, ::Val{2}
) where {T<:Real}
    return zeros(T, size(x))
end