# Imports

In [1]:
using MLDatasets: CIFAR10
using MLDataUtils
using Knet, IterTools
using Dictionaries
using TimerOutputs
using JSON
using Printf
using Knet:minibatch
using Knet:minimize
using Knet: Param
using Knet: dir, accuracy, progress, sgd, gc, Data, nll, relu, conv4
using Flatten
using Flux.Data;
using Flux, Statistics
using Statistics: mean, var
using Functors



# Processing Data/Batch Processing

In [2]:
# This loads the CIFAR-10 Dataset for training, validation, and evaluation
xtrn,ytrn = CIFAR10.traindata(Float32, 1:45000)
xval,yval =  CIFAR10.traindata(Float32, 45001:50000)
xtst,ytst = CIFAR10.testdata(Float32)
println.(summary.((xtrn,ytrn,xval, yval, xtst,ytst)));

32×32×3×45000 Array{Float32, 4}
45000-element Vector{Int64}
32×32×3×5000 Array{Float32, 4}
5000-element Vector{Int64}
32×32×3×10000 Array{Float32, 4}
10000-element Vector{Int64}


In [3]:
train_loader = DataLoader((xtrn, ytrn), batchsize=256)
val_loader = DataLoader((xval, yval), batchsize = 256)
test_loader = DataLoader((xtst, ytst), batchsize = 256)

DataLoader{Tuple{Array{Float32, 4}, Vector{Int64}}, Random._GLOBAL_RNG, Val{nothing}}(([0.61960787 0.59607846 … 0.23921569 0.21176471; 0.62352943 0.5921569 … 0.19215687 0.21960784; … ; 0.49411765 0.49019608 … 0.11372549 0.13333334; 0.45490196 0.46666667 … 0.078431375 0.08235294;;; 0.4392157 0.4392157 … 0.45490196 0.41960785; 0.43529412 0.43137255 … 0.4 0.4117647; … ; 0.35686275 0.35686275 … 0.32156864 0.32941177; 0.33333334 0.34509805 … 0.2509804 0.2627451;;; 0.19215687 0.2 … 0.65882355 0.627451; 0.18431373 0.15686275 … 0.5803922 0.58431375; … ; 0.14117648 0.1254902 … 0.49411765 0.5058824; 0.12941177 0.13333334 … 0.41960785 0.43137255;;;; 0.92156863 0.93333334 … 0.32156864 0.33333334; 0.90588236 0.92156863 … 0.18039216 0.24313726; … ; 0.9137255 0.9254902 … 0.7254902 0.7058824; 0.9098039 0.92156863 … 0.73333335 0.7294118;;; 0.92156863 0.93333334 … 0.3764706 0.39607844; 0.90588236 0.92156863 … 0.22352941 0.29411766; … ; 0.9137255 0.9254902 … 0.78431374 0.7647059; 0.9098039 0.92156863 … 0

# Define Struct ResNetLayer

In [4]:
mutable struct ResNetLayer
    conv1::Flux.Conv
    conv2::Flux.Conv
    bn1::BatchNorm
    bn2::BatchNorm
    activation_function::Function
    in_channels::Int
    channels::Int
    stride::Int 
end 

@functor ResNetLayer (conv1, conv2, bn1, bn2)

In [5]:
# Constructor
function ResNetLayer(in_channels::Int, channels::Int, activation_function = relu, stride = 1)
    bn1 = BatchNorm(in_channels)
    conv1 = Flux.Conv((3,3), in_channels => channels, activation_function; stride = stride)
    bn2 = BatchNorm(channels)
    conv2 = Flux.Conv((3,3), channels => channels, activation_function; stride = stride)
    return ResNetLayer(conv1, conv2, bn1, bn2, activation_function, in_channels, channels, stride)
end

ResNetLayer

# Define Residual Identity

In [6]:
function residual_identity(layer::ResNetLayer, x::AbstractArray{T, 4}) where {T<:Number}
    (w, h, c, b) = size(x)
    stride = layer.stride
    if stride > 1
        @assert ((w % stride == 0) & (h % stride == 0)) "Spatial dimensions are not divisible by `stride`"
    
        # Strided downsample
        x_id = copy(x[begin:2:end, begin:2:end, :, :])
    else
        x_id = x
    end

    channels = layer.channels
    in_channels = layer.in_channels
    if in_channels < channels
        # Zero padding on extra channels
        (w, h, c, b) = size(x_id)
        pad = zeros(w, h, channels - in_channels, b)
        x_id = cat(x_id, pad; dims=3)
    elseif in_channels > channels
        error("in_channels > out_channels not supported")
    end
    return x_id
end

residual_identity (generic function with 1 method)

# Forward Function

In [7]:
function (self::ResNetLayer)(x::AbstractArray)
    identity = residual_identity(self, x)
    z = self.bn1(x)
    z = self.activation_function(z)
    z = self.conv1(z)
    z = self.bn2(z)
    z = self.activation_function(z)
    z = self.conv2(z)
    y = z + identity 
    return y
end 

In [8]:
# Example
l = ResNetLayer(3, 10, stride = 2)
x = randn(Float32, (64, 64, 3, 2))
y = l(x)
size(y)

LoadError: MethodError: no method matching ResNetLayer(::Int64, ::Int64; stride=2)
[0mClosest candidates are:
[0m  ResNetLayer(::Int64, ::Int64) at In[5]:2[91m got unsupported keyword argument "stride"[39m
[0m  ResNetLayer(::Int64, ::Int64, [91m::Any[39m) at In[5]:2[91m got unsupported keyword argument "stride"[39m
[0m  ResNetLayer(::Int64, ::Int64, [91m::Any[39m, [91m::Any[39m) at In[5]:2[91m got unsupported keyword argument "stride"[39m
[0m  ...