Skip to content

Commit

Permalink
fix outputsize on LayerNorm
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Oct 5, 2022
1 parent ac34df9 commit 604f2b4
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions src/outputsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,30 @@ outputsize(m::AbstractVector, input::Tuple...; padbatch=false) = outputsize(Chai

## bypass statistics in normalization layers

for layer in (:LayerNorm, :BatchNorm, :InstanceNorm, :GroupNorm)
@eval (l::$layer)(x::AbstractArray{Nil}) = x
# for layer in (:LayerNorm, :BatchNorm, :InstanceNorm, :GroupNorm)
# @eval (l::$layer)(x::AbstractArray{Nil}) = x
# end
for layer in (:BatchNorm, :InstanceNorm, :GroupNorm)
@eval function (l::$layer)(x::AbstractArray{Nil})
l.chs == size(x, ndims(x)-1) || throw(DimensionMismatch(
string($layer, " expected ", l.chs, " channels, but got ", _channelsize(x))))
x
end
end

_channelsize(x::AbstractArray) = size(x, ndims(x)-1)
_channelsize(x::AbstractVector) = size(x, 1)

function (l::LayerNorm)(x::AbstractArray{Nil,N}) where N
l.affine || return x
n = length(l.size)
l.size[1:min(n,N)] == size(x)[1:min(n,N)] || throw(DimensionMismatch(
string("LayerNorm expected size of input starting with ", l.size, ", but got size(x) == ", size(x))))
if n <= N
return x
else
return similar(x, l.size)
end
end

## fixes for layers that don't work out of the box
Expand Down

0 comments on commit 604f2b4

Please sign in to comment.