# Imports

In [28]:
using Yota;
using MLDatasets;
using NNlib;
using Statistics;
using Distributions;
using Functors;
using Optimisers;
using MLUtils: DataLoader;
using OneHotArrays: onehotbatch
using Knet:conv4
using Metrics;
using TimerOutputs;
using Flux: BatchNorm, kaiming_uniform, nfan;
using Functors

# Model creation
using NNlib;
using Flux: BatchNorm, Chain, GlobalMeanPool, kaiming_uniform, nfan;
using Statistics;
using Distributions;
using Functors;


# Conv 2D

In [20]:
mutable struct Conv2D{T}
    w::AbstractArray{T, 4}
    b::AbstractVector{T}
    use_bias::Bool
end

@functor Conv2D (w, b)

In [21]:
function Conv2D(kernel_size::Tuple{Int, Int}, in_channels::Int, out_channels::Int;
    bias::Bool=false)
    w_size = (kernel_size..., in_channels, out_channels)
    w = kaiming_uniform(w_size...)
    (fan_in, fan_out) = nfan(w_size)
    
    if bias
        # Init bias with fan_in from weights. Use gain = √2 for ReLU
        bound = √3 * √2 / √fan_in
        rng = Uniform(-bound, bound)
        b = rand(rng, out_channels, Float32)
    else
        b = zeros(Float32, out_channels)
    end

    return Conv2D(w, b, bias)
end

function (self::Conv2D)(x::AbstractArray; stride::Int=1, pad::Int=0, dilation::Int=1)
    y = conv4(self.w, x; stride=stride, padding=pad, dilation=dilation)
    if self.use_bias
        # Bias is applied channel-wise
        (w, h, c, b) = size(y)
        bias = reshape(self.b, (1, 1, c, 1))
        y = y .+ bias
    end
    return y
end
     

# ResNetLayer

In [22]:
mutable struct ResNetLayer
    conv1::Conv2D
    conv2::Conv2D
    bn1::BatchNorm
    bn2::BatchNorm
    f::Function
    in_channels::Int
    channels::Int
    stride::Int
end

@functor ResNetLayer (conv1, conv2, bn1, bn2)

function residual_identity(layer::ResNetLayer, x::AbstractArray{T, 4}) where {T<:Number}
    (w, h, c, b) = size(x)
    stride = layer.stride
    if stride > 1
        @assert ((w % stride == 0) & (h % stride == 0)) "Spatial dimensions are not divisible by `stride`"
    
        # Strided downsample
        x_id = copy(x[begin:2:end, begin:2:end, :, :])
    else
        x_id = x
    end

    channels = layer.channels
    in_channels = layer.in_channels
    if in_channels < channels
        # Zero padding on extra channels
        (w, h, c, b) = size(x_id)
        pad = zeros(w, h, channels - in_channels, b)
        x_id = cat(x_id, pad; dims=3)
    elseif in_channels > channels
        error("in_channels > out_channels not supported")
    end
    return x_id
end

function ResNetLayer(in_channels::Int, channels::Int; stride=1, f=relu)
    bn1 = BatchNorm(in_channels)
    conv1 = Conv2D((3, 3), in_channels, channels, bias=false)
    bn2 = BatchNorm(channels)
    conv2 = Conv2D((3, 3), channels, channels, bias=false)

    return ResNetLayer(conv1, conv2, bn1, bn2, f, in_channels, channels, stride)
end


function (self::ResNetLayer)(x::AbstractArray)
    identity = residual_identity(self, x)
    z = self.bn1(x)
    z = self.f(z)
    z = self.conv1(z; pad=1, stride=self.stride)  # pad=1 will keep same size with (3x3) kernel
    z = self.bn2(z)
    z = self.f(z)
    z = self.conv2(z; pad=1)

    y = z + identity
    return y
end

# Testing ResNetLayer

In [23]:

l = ResNetLayer(3, 10; stride=2);
x = randn(Float32, (32, 32, 3, 4));
y = l(x);
size(y)

(16, 16, 10, 4)

# Linear Layer

In [24]:
mutable struct Linear
    W::AbstractMatrix{T} where T
    b::AbstractVector{T} where T
end

@functor Linear

# Init
function Linear(in_features::Int, out_features::Int)
    k_sqrt = sqrt(1 / in_features)
    d = Uniform(-k_sqrt, k_sqrt)
    return Linear(rand(d, out_features, in_features), rand(d, out_features))
end
Linear(in_out::Pair{Int, Int}) = Linear(in_out[1], in_out[2])

function Base.show(io::IO, l::Linear)
    o, i = size(l.W)
    print(io, "Linear(o)")
end

# Forward
(l::Linear)(x::AbstractArray) where T = l.W * x .+ l.b




In [38]:
# ResNet Architecture

mutable struct ResNet20Model
    input_conv::Conv2D
    resnet_blocks::Chain
    pool::GlobalMeanPool
    linear::Linear
end

@functor ResNet20Model

function ResNet20Model(in_channels::Int, num_classes::Int)
    resnet_blocks = Chain(
        block_1 = ResNetLayer(16, 16),
        block_2 = ResNetLayer(16, 16),
        block_3 = ResNetLayer(16, 16),
        block_4 = ResNetLayer(16, 32; stride=2),
        block_5 = ResNetLayer(32, 32),
        block_6 = ResNetLayer(32, 32),
        block_7 = ResNetLayer(32, 64; stride=2),
        block_8 = ResNetLayer(64, 64),
        block_9 = ResNetLayer(64, 64)
    )
    return ResNet20Model(
        Conv2D((3, 3), in_channels, 16, bias=false),
        resnet_blocks,
        GlobalMeanPool(),
        Linear(64, num_classes)
    )
end

ResNet20Model

In [39]:
function (self::ResNet20Model)(x::AbstractArray)
    z = self.input_conv(x)
    z = self.resnet_blocks(z)
    z = self.pool(z)
    z = dropdims(z, dims=(1, 2))
    y = self.linear(z)
    return y
end
     

In [40]:

# Testing ResNet20 model
# Expected output: (10, 4)
m = ResNet20Model(3, 10);
inputs = randn(Float32, (32, 32, 3, 4))
outputs = m(inputs);
size(outputs)
     

LoadError: AssertionError: Spatial dimensions are not divisible by `stride`