-
Notifications
You must be signed in to change notification settings - Fork 32
/
normalise.jl
97 lines (83 loc) · 2.88 KB
/
normalise.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
using Statistics: mean
# Code adapted from Flux.jl
# Ref: https://github.com/FluxML/Flux.jl/blob/master/src/layers/normalise.jl#L93-L177
# License: https://github.com/FluxML/Flux.jl/blob/master/LICENSE.md
# Training-mode flag. Hard-coded to `false` here; AD/backends may overload
# this (as Flux does) so the training branch of `forward` is taken during
# gradient computation.
function istraining()
    return false
end
# Invertible batch normalisation: a channel-wise affine bijector
# y = exp(logs) .* (x .- m) ./ sqrt.(v .+ eps) .+ b  (see `forward` below).
# Mutable because the moving statistics `m` and `v` are updated in place
# during training.
mutable struct InvertibleBatchNorm{T1,T2,T3} <: Bijector{1}
b :: T1 # bias
logs :: T1 # log-scale (scale = exp.(logs), so logs == 0 means scale == 1)
m :: T2 # moving mean
v :: T2 # moving variance
eps :: T3 # small constant added to the variance for numerical stability
mtm :: T3 # momentum for the moving-statistics update
end
# Register the struct with Functors.jl so its fields can be mapped over /
# collected as parameters (e.g. by Flux-style training utilities).
@functor InvertibleBatchNorm
# Two batch-norm bijectors are equal iff all of their fields are equal.
function Base.:(==)(b1::InvertibleBatchNorm, b2::InvertibleBatchNorm)
    compared = (:b, :logs, :m, :v, :eps, :mtm)
    return all(getfield(b1, f) == getfield(b2, f) for f in compared)
end
# Construct an `InvertibleBatchNorm` for `chs` channels with identity
# initialisation: zero bias, unit scale (logs == 0), zero moving mean and
# unit moving variance.
function InvertibleBatchNorm(
    chs::Int;
    eps::T=1f-5,
    mtm::T=1f-1,
) where {T<:AbstractFloat}
    bias = zeros(T, chs)
    logscale = zeros(T, chs)  # exp.(0) == 1, i.e. unit scale
    moving_mean = zeros(T, chs)
    moving_var = ones(T, chs)
    return InvertibleBatchNorm(bias, logscale, moving_mean, moving_var, eps, mtm)
end
"""
    forward(bn::InvertibleBatchNorm, x)

Apply batch normalisation to `x`, whose second-to-last dimension is the
channel dimension and must match `length(bn.b)`.

In training mode (`istraining() == true`) the batch mean/variance are used
and the moving statistics `bn.m` / `bn.v` are updated in place; otherwise
the stored moving statistics are used.

Returns a named tuple `(rv, logabsdetjac)` where `rv` is the normalised
output and `logabsdetjac` is a vector with one (identical) entry per batch
element (the last dimension of `x`).
"""
function forward(bn::InvertibleBatchNorm, x)
    dims = ndims(x)
    size(x, dims - 1) == length(bn.b) ||
        error("InvertibleBatchNorm expected $(length(bn.b)) channels, got $(size(x, dims - 1))")
    channels = size(x, dims - 1)
    # Shape with the channel count in position dims - 1 and 1 elsewhere,
    # so per-channel parameters broadcast against `x`.
    as = ntuple(i -> i == ndims(x) - 1 ? size(x, i) : 1, dims)
    logs = reshape(bn.logs, as...)
    s = exp.(logs)
    b = reshape(bn.b, as...)
    if istraining()
        # Number of elements per channel.
        n = div(prod(size(x)), channels)
        # Axes to reduce along: every axis except the channel axis (dims - 1).
        # (Renamed from `axes` to avoid shadowing `Base.axes`.)
        reduce_dims = [1:dims-2; dims]
        m = mean(x, dims = reduce_dims)
        v = sum((x .- m) .^ 2, dims = reduce_dims) ./ n
        # Exponential moving average of the statistics; the variance update
        # applies Bessel's correction via the n / (n - 1) factor.
        mtm = bn.mtm
        T = eltype(bn.m)
        bn.m = (1 - mtm) .* bn.m .+ mtm .* T.(reshape(m, :))
        bn.v = (1 - mtm) .* bn.v .+ (mtm * n / (n - 1)) .* T.(reshape(v, :))
    else
        m = reshape(bn.m, as...)
        v = reshape(bn.v, as...)
    end
    rv = s .* (x .- m) ./ sqrt.(v .+ bn.eps) .+ b
    # log|det J| = sum(logs) - sum(log(v + eps)) / 2, identical for each
    # batch element, hence `fill` over the batch dimension.
    logabsdetjac = fill(sum(logs - log.(v .+ bn.eps) / 2), size(x, dims))
    return (rv=rv, logabsdetjac=logabsdetjac)
end
# Convenience accessors delegating to `forward`.
function logabsdetjac(bn::InvertibleBatchNorm, x)
    return forward(bn, x).logabsdetjac
end

function (bn::InvertibleBatchNorm)(x)
    return forward(bn, x).rv
end
# Invert the normalisation: recover `x` from `y` using the stored moving
# statistics. Only valid in test mode, since the training-mode batch
# statistics are not recoverable from `y` alone.
function forward(invbn::Inverse{<:InvertibleBatchNorm}, y)
    @assert !istraining() "`forward(::Inverse{InvertibleBatchNorm})` is only available in test mode."
    nd = ndims(y)
    shape = ntuple(d -> d == ndims(y) - 1 ? size(y, d) : 1, nd)
    bn = inv(invbn)
    scale = reshape(exp.(bn.logs), shape...)
    shift = reshape(bn.b, shape...)
    mu = reshape(bn.m, shape...)
    var = reshape(bn.v, shape...)
    # Undo shift and scale, then de-normalise with the moving statistics.
    x = (y .- shift) ./ scale .* sqrt.(var .+ bn.eps) .+ mu
    return (rv=x, logabsdetjac=-logabsdetjac(bn, x))
end
# Calling the inverse bijector returns only the reconstructed input.
function (bn::Inverse{<:InvertibleBatchNorm})(y)
    return forward(bn, y).rv
end
# Compact display, e.g. `InvertibleBatchNorm(3)` for a 3-channel layer.
function Base.show(io::IO, l::InvertibleBatchNorm)
    channel_dims = join(size(l.b), ", ")
    print(io, "InvertibleBatchNorm(", channel_dims, ")")
end