In [1]:
using Yota;
using MLDatasets;
using NNlib;
using Statistics;
using Distributions;
using Functors;
using Optimisers;
using MLUtils: DataLoader;
using OneHotArrays: onehotbatch
using Metrics;
using TimerOutputs;
using Flux: BatchNorm, kaiming_uniform, nfan;

In [2]:
"""
    Conv2D
2D convolution layer

The weight tensor and the bias vector share the element type `T`.

# Fields
- `w::AbstractArray`: 4-D weight tensor for convolution kernels, of size (k1, k2, in_channels, out_channels)
- `b::AbstractVector`: Bias for output channels
- `use_bias::Bool`: Whether the bias is added in the forward pass
"""
mutable struct Conv2D{T}
    w::AbstractArray{T, 4}
    b::AbstractVector{T}
    use_bias::Bool
end

# Only `w` and `b` are registered as functor children; `use_bias` is plain configuration
@functor Conv2D (w, b)


"""
    Conv2D(kernel_size, in_channels, out_channels; bias=false)

Constructor for Conv2D layer
Weights are initialized using Kaiming uniform.

# Arguments
- `kernel_size::Tuple{Int, Int}`: 2-tuple of int for convolution kernel size
- `in_channels::Int`: Number of input channels to the convolution layer
- `out_channels::Int`: Number of output channels from the convolution layer

# Keywords
- `bias::Bool`: Whether to have a bias during convolution, by default false.
  When true, the bias is drawn uniformly from `[-bound, bound]` with
  `bound = √3 * √2 / √fan_in`; when false, a zero vector is stored and the
  layer is flagged to skip it.

# Returns
- `Conv2D`: Initialized Conv2D struct whose bias has the same element type as its weights
"""
function Conv2D(kernel_size::Tuple{Int, Int}, in_channels::Int, out_channels::Int;
    bias::Bool=false)
    w_size = (kernel_size..., in_channels, out_channels)
    w = kaiming_uniform(w_size...)
    # `Conv2D{T}` requires weights and bias to share one element type
    T = eltype(w)

    if bias
        # Init bias with fan_in from weights. Use gain = √2 for ReLU
        fan_in, _ = nfan(w_size)
        bound = √3 * √2 / √fan_in
        # Sampling from `Uniform` yields Float64 values, so convert them to the
        # weight element type before building the layer
        b = convert(Vector{T}, rand(Uniform(-bound, bound), out_channels))
    else
        b = zeros(T, out_channels)
    end

    return Conv2D(w, b, bias)
end


"""
    (layer::Conv2D)(x; stride=1, pad=0, dilation=1)

Apply the layer's 2D convolution to `x`, adding the bias when the layer uses one

# Arguments
- `x::AbstractArray`: 4D input image tensor of shape (width, height, channels, batch size)

# Keywords
- `stride::Int`: convolution stride, by default 1
- `pad::Int`: amount of padding on each side, by default 0
- `dilation::Int`: convolution dilation, by default 1

# Returns
- `y::AbstractArray`: Output 4D image tensor
"""
function (self::Conv2D)(x::AbstractArray; stride::Int=1, pad::Int=0, dilation::Int=1)
    out = conv(x, self.w; stride, pad, dilation)

    # Without a bias the raw convolution result is the output
    self.use_bias || return out

    # Shape the bias as (1, 1, channels, 1) so it broadcasts over width, height and batch
    n_out = size(out, 3)
    return out .+ reshape(self.b, (1, 1, n_out, 1))
end

Conv2D

In [3]:
"""
    ResNetLayer
ResNetV2 Layer
See "Identity Mappings in Deep Residual Networks", https://arxiv.org/pdf/1603.05027.pdf

# Fields
- `conv1::Conv2D`: First convolution layer
- `conv2::Conv2D`: Second convolution layer
- `bn1::BatchNorm`: First batch norm layer
- `bn2::BatchNorm`: Second batch norm layer
- `f::Function`: Activation function
- `in_channels::Int`: Number of input channels to the layer
- `channels::Int`: Number of channels used throughout the layer
- `stride::Int`: Stride to apply to first convolution layer
"""
mutable struct ResNetLayer
    conv1::Conv2D
    conv2::Conv2D
    bn1::BatchNorm
    bn2::BatchNorm
    f::Function
    in_channels::Int
    channels::Int
    stride::Int
end

# Only the convolution and batch norm sublayers are registered as functor children;
# the activation, channel counts and stride are plain configuration
@functor ResNetLayer (conv1, conv2, bn1, bn2)

"""
    residual_identity(layer, x)
Identity function for computing identity connection after a strided convolution
Downsamples `x` by a factor of `layer.stride` and adds extra channels filled with zeros

The result keeps the element type of `x`, so a `Float32` input is not promoted
to `Float64` by the zero padding.

# Arguments
- `layer::ResNetLayer`: ResNetLayer struct that has information about stride and channels
- `x::AbstractArray`: Input tensor of shape (width, height, channels, batch size)

# Returns
- `x_id::AbstractArray`: Downsampled identity tensor with `layer.channels` channels

# Throws
- `AssertionError`: if `layer.stride > 1` and width or height is not divisible by it
- `ErrorException`: if `layer.in_channels > layer.channels`
"""
function residual_identity(layer::ResNetLayer, x::AbstractArray{T, 4}) where {T<:Number}
    (w, h) = size(x)
    stride = layer.stride
    if stride > 1
        @assert ((w % stride == 0) & (h % stride == 0)) "Spatial dimensions are not divisible by `stride`"

        # Strided downsample: keep every `stride`-th pixel along width and height.
        # Range indexing already allocates a new array, so no extra copy is needed.
        x_id = x[begin:stride:end, begin:stride:end, :, :]
    else
        x_id = x
    end

    channels = layer.channels
    in_channels = layer.in_channels
    if in_channels < channels
        # Zero padding on extra channels, built with the input's element type
        (w, h, _, b) = size(x_id)
        pad = zeros(T, w, h, channels - in_channels, b)
        x_id = cat(x_id, pad; dims=3)
    elseif in_channels > channels
        error("in_channels > out_channels not supported")
    end
    return x_id
end

"""
    ResNetLayer(in_channels, channels; stride=1, f=relu)
Build a ResNetLayer from two bias-free 3x3 convolutions and two batch norm layers

# Arguments
- `in_channels::Int`: Number of input channels to the layer
- `channels::Int`: Number of channels used throughout the layer

# Keywords
- `stride::Int`: Stride to apply in first convolution layer, by default 1
- `f::Function`: Activation function to apply, by default relu

# Returns
- `ResNetLayer`: Initialized ResNetLayer
"""
function ResNetLayer(in_channels::Int, channels::Int; stride=1, f=relu)
    kernel = (3, 3)

    # Convolutions: the first maps in_channels -> channels, the second keeps channels
    conv1 = Conv2D(kernel, in_channels, channels; bias=false)
    conv2 = Conv2D(kernel, channels, channels; bias=false)

    # Each batch norm matches the channel count of the tensor it receives
    bn1 = BatchNorm(in_channels)
    bn2 = BatchNorm(channels)

    return ResNetLayer(conv1, conv2, bn1, bn2, f, in_channels, channels, stride)
end


"""
    (layer::ResNetLayer)(x)

Forward pass of the ResNetLayer
Runs batch norm -> activation -> convolution twice and adds the shortcut branch
The layer's stride is applied in the first convolution for downsampling

# Arguments
- `x::AbstractArray`: 4D input image tensor of shape (width, height, channels, batch size)

# Returns
- `y::AbstractArray`: 4D output image tensor
"""
function (self::ResNetLayer)(x::AbstractArray)
    # Shortcut branch, brought to the shape of the residual branch
    shortcut = residual_identity(self, x)

    # First stage; pad=1 keeps the spatial size for a (3x3) kernel before striding
    act1 = self.f(self.bn1(x))
    mid = self.conv1(act1; pad=1, stride=self.stride)

    # Second stage, always at stride 1
    act2 = self.f(self.bn2(mid))
    residual = self.conv2(act2; pad=1)

    return residual + shortcut
end

ResNetLayer

In [4]:
# Build a layer that maps 3 input channels to 10 channels with a stride of 2
l = ResNetLayer(3, 10; stride=2);

In [5]:
# Smoke test: push a random Float32 batch of shape (64, 64, 3, 2) through the layer
x = randn(Float32, (64, 64, 3, 2));
y = l(x);
# Inspect the output shape
size(y)

(32, 32, 10, 2)