-
Notifications
You must be signed in to change notification settings - Fork 32
/
leaky_relu.jl
95 lines (77 loc) · 2.78 KB
/
leaky_relu.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""
LeakyReLU{T, N}(α::T) <: Bijector{N}
Defines the invertible mapping
x ↦ x if x ≥ 0 else αx
where α > 0.
"""
struct LeakyReLU{T, N} <: Bijector{N}
α::T
end
@functor LeakyReLU
# Scalar slope: unless `dim` is given, the dimensionality parameter `N` is 0.
function LeakyReLU(α::T; dim::Val{N}=Val(0)) where {T<:Real,N}
    return LeakyReLU{T,N}(α)
end

# Array-valued slope: unless `dim` is given, the dimensionality parameter `N` is
# `D`, the number of dimensions of `α`.
function LeakyReLU(α::T; dim::Val{N}=Val(D)) where {D,T<:AbstractArray{<:Real,D},N}
    return LeakyReLU{T,N}(α)
end
# Return a `LeakyReLU` with the same slope `b.α` whose dimensionality parameter is
# `N + 1` instead of `N`.
function up1(b::LeakyReLU{T,N}) where {T,N}
    return LeakyReLU{T,N + 1}(b.α)
end
# (N=0) Univariate case
#
# Apply the leaky ReLU to a scalar: `x` is scaled by `b.α` when `x < 0` and is
# returned unchanged otherwise.
function (b::LeakyReLU{<:Any,0})(x::Real)
    mask = x < zero(x)
    # Form `b.α * x` first and let the `Bool` mask select the term. `false` acts as
    # a strong zero in Julia, so the deselected term is exactly zero even when it
    # is infinite. Writing `mask * b.α * x` instead evaluates `(false * b.α) * x`,
    # i.e. `0 * Inf`, which yields `NaN` for `x = Inf`.
    return mask * (b.α * x) + !mask * x
end
# Apply the univariate transformation to each entry of a vector of reals.
function (b::LeakyReLU{<:Any,0})(x::AbstractVector{<:Real})
    return map(b, x)
end
# The inverse is again a `LeakyReLU` with the same `N`; its slope is the
# (broadcasted) reciprocal of `b.α`.
function Base.inv(b::LeakyReLU{<:Any,N}) where {N}
    reciprocal = inv.(b.α)
    return LeakyReLU{typeof(reciprocal),N}(reciprocal)
end
# Log of the absolute value of the Jacobian for a scalar input. The local slope
# is `b.α` when `x < 0` and `one(x)` otherwise.
function logabsdetjac(b::LeakyReLU{<:Any,0}, x::Real)
    isneg = x < zero(x)
    slope = isneg * b.α + (1 - isneg) * one(x)
    return log(abs(slope))
end
# Batched univariate case: one log-abs-det-Jacobian value per entry of `x`.
function logabsdetjac(b::LeakyReLU{<:Real,0}, x::AbstractVector{<:Real})
    return map(xi -> logabsdetjac(b, xi), x)
end
# `forward` is written out by hand so that the slope is computed once and shared
# between the transformed value and the log-abs-det-Jacobian. This leads to
# faster sampling when using `rand` on a `TransformedDistribution` that makes
# use of `LeakyReLU`.
function forward(b::LeakyReLU{<:Any,0}, x::Real)
    isneg = x < zero(x)
    slope = isneg * b.α + !isneg * one(x)
    y = slope * x
    return (rv=y, logabsdetjac=log(abs(slope)))
end
# Batched version
#
# Returns a named tuple with the transformed vector `rv` and the per-entry
# log-abs-det-Jacobian `logabsdetjac`, both computed from the same slope vector.
function forward(b::LeakyReLU{<:Any,0}, x::AbstractVector)
    J = let T = eltype(x), z = zero(T), o = one(T)
        # The second mask is the exact complement of the first, so every entry
        # gets a non-zero slope. With the pair `x < z` / `x > z`, an entry equal
        # to zero matched neither mask, giving slope 0 and `log(0) = -Inf`,
        # unlike the scalar method, which uses slope 1 there.
        @. (x < z) * b.α + !(x < z) * o
    end
    return (rv=J .* x, logabsdetjac=log.(abs.(J)))
end
# (N=1) Multivariate case
#
# Apply the leaky ReLU to every entry of a vector or matrix: entries with
# `x < 0` are scaled by `b.α`, all other entries are returned unchanged.
function (b::LeakyReLU{<:Any,1})(x::AbstractVecOrMat)
    return let z = zero(eltype(x))
        # `b.α * x` is grouped so that the `Bool` mask multiplies the finished
        # product: `false` is a strong zero, whereas `(false * b.α) * x` is
        # `0 * Inf = NaN` for an entry equal to `Inf`. The second mask is the
        # complement of the first, mirroring `mask` / `!mask` in the scalar
        # method.
        @. (x < z) * (b.α * x) + !(x < z) * x
    end
end
# Log-abs-det-Jacobian for the multivariate case. For a vector input a single
# value is returned; for a matrix input the result is a vector with one value
# per column.
function logabsdetjac(b::LeakyReLU{<:Any,1}, x::AbstractVecOrMat)
    # Is really diagonal of jacobian
    J = let T = eltype(x), z = zero(T), o = one(T)
        # Complementary masks: an entry equal to zero gets slope 1, as in the
        # scalar method. With `x < z` / `x > z` such an entry matched neither
        # mask and produced slope 0, hence `log(0) = -Inf`.
        @. (x < z) * b.α + !(x < z) * o
    end
    logJ = log.(abs.(J))
    if x isa AbstractVector
        return sum(logJ)
    else
        # `AbstractVecOrMat` leaves only the matrix case here.
        return vec(sum(logJ; dims=1)) # sum along column
    end
end
# We implement `forward` by hand since we can re-use the computation of
# the Jacobian of the transformation. This will lead to faster sampling
# when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`.
#
# Returns a named tuple `(rv, logabsdetjac)`. For a vector input `logabsdetjac`
# is a single value; for a matrix input it is a vector with one value per column.
function forward(b::LeakyReLU{<:Any,1}, x::AbstractVecOrMat)
    # Is really diagonal of jacobian
    J = let T = eltype(x), z = zero(T), o = one(T)
        # Complementary masks: an entry equal to zero gets slope 1, as in the
        # scalar method. With `x < z` / `x > z` such an entry matched neither
        # mask and produced slope 0, hence `log(0) = -Inf`.
        @. (x < z) * b.α + !(x < z) * o
    end
    logJ = log.(abs.(J))
    logjac = if x isa AbstractVector
        sum(logJ)
    else
        # `AbstractVecOrMat` leaves only the matrix case here.
        vec(sum(logJ; dims=1)) # sum along column
    end
    y = J .* x
    return (rv=y, logabsdetjac=logjac)
end