# Imports

In [1]:
using MLDatasets: CIFAR10
using MLDataUtils
using Knet, IterTools
using Dictionaries
using TimerOutputs
using JSON
using Printf
using Knet:minibatch
using Knet:minimize
using Knet: Param
using Knet: dir, accuracy, progress, sgd, gc, Data, nll, relu, conv4
using Flatten
using Flux.Data;
using Flux, Statistics
using Statistics: mean, var
using Functors



# Processing Data/Batch Processing

In [2]:
# This loads the CIFAR-10 Dataset for training, validation, and evaluation
xtrn,ytrn = CIFAR10.traindata(Float32, 1:45000)
xval,yval =  CIFAR10.traindata(Float32, 45001:50000)
xtst,ytst = CIFAR10.testdata(Float32)
println.(summary.((xtrn,ytrn,xval, yval, xtst,ytst)));

32×32×3×45000 Array{Float32, 4}
45000-element Vector{Int64}
32×32×3×5000 Array{Float32, 4}
5000-element Vector{Int64}
32×32×3×10000 Array{Float32, 4}
10000-element Vector{Int64}


In [3]:
train_loader = DataLoader((xtrn, ytrn), batchsize=256)
val_loader = DataLoader((xval, yval), batchsize = 256)
test_loader = DataLoader((xtst, ytst), batchsize = 256)

DataLoader{Tuple{Array{Float32, 4}, Vector{Int64}}, Random._GLOBAL_RNG, Val{nothing}}(([0.61960787 0.59607846 … 0.23921569 0.21176471; 0.62352943 0.5921569 … 0.19215687 0.21960784; … ; 0.49411765 0.49019608 … 0.11372549 0.13333334; 0.45490196 0.46666667 … 0.078431375 0.08235294;;; 0.4392157 0.4392157 … 0.45490196 0.41960785; 0.43529412 0.43137255 … 0.4 0.4117647; … ; 0.35686275 0.35686275 … 0.32156864 0.32941177; 0.33333334 0.34509805 … 0.2509804 0.2627451;;; 0.19215687 0.2 … 0.65882355 0.627451; 0.18431373 0.15686275 … 0.5803922 0.58431375; … ; 0.14117648 0.1254902 … 0.49411765 0.5058824; 0.12941177 0.13333334 … 0.41960785 0.43137255;;;; 0.92156863 0.93333334 … 0.32156864 0.33333334; 0.90588236 0.92156863 … 0.18039216 0.24313726; … ; 0.9137255 0.9254902 … 0.7254902 0.7058824; 0.9098039 0.92156863 … 0.73333335 0.7294118;;; 0.92156863 0.93333334 … 0.3764706 0.39607844; 0.90588236 0.92156863 … 0.22352941 0.29411766; … ; 0.9137255 0.9254902 … 0.78431374 0.7647059; 0.9098039 0.92156863 … 0

# Define Struct ResNetLayer

In [4]:
mutable struct ResNetLayer
    conv1::Flux.Conv
    conv2::Flux.Conv
    bn1::BatchNorm
    bn2::BatchNorm
    activation_function::Function
    in_channels::Int
    channels::Int
    stride::Int 
end 

@functor ResNetLayer (conv1, conv2, bn1, bn2)

In [5]:
# Constructor
function ResNetLayer(in_channels::Int, channels::Int; activation_function = relu, stride = 1)
    bn1 = BatchNorm(in_channels)
    conv1 = Flux.Conv((3,3), in_channels => channels, activation_function; stride = stride)
    bn2 = BatchNorm(channels)
    conv2 = Flux.Conv((3,3), channels => channels, activation_function; stride = stride)
    return ResNetLayer(conv1, conv2, bn1, bn2, activation_function, in_channels, channels, stride)
end

ResNetLayer

# Define Residual Identity

In [6]:
function residual_identity(layer::ResNetLayer, x::AbstractArray{T, 4}) where {T<:Number}
    (w, h, c, b) = size(x)
    stride = layer.stride
    if stride > 1
        @assert ((w % stride == 0) & (h % stride == 0)) "Spatial dimensions are not divisible by `stride`"
    
        # Strided downsample
        x_id = copy(x[begin:2:end, begin:2:end, :, :])
    else
        x_id = x
    end

    channels = layer.channels
    in_channels = layer.in_channels
    if in_channels < channels
        # Zero padding on extra channels
        (w, h, c, b) = size(x_id)
        pad = zeros(w, h, channels - in_channels, b)
        x_id = cat(x_id, pad; dims=3)
    elseif in_channels > channels
        error("in_channels > out_channels not supported")
    end
    return x_id
end

residual_identity (generic function with 1 method)

# Forward Function

In [10]:
function (self::ResNetLayer)(x::AbstractArray)
    identity = residual_identity(self, x)
    z = self.bn1(x)
    z = self.activation_function(z)
    z = self.conv1(z)
    z = self.bn2(z)
    z = self.activation_function(z)
    z = self.conv2(z)
    y = z + identity 
    return y
end 

In [11]:
# Example
l = ResNetLayer(3, 10; stride = 2, pad = 0)


ResNetLayer(Conv((3, 3), 3 => 10, relu, stride=2), Conv((3, 3), 10 => 10, relu, stride=2), BatchNorm(3), BatchNorm(10), Knet.Ops20.relu, 3, 10, 2, 0)

In [15]:
x = randn(Float32, (64, 64, 3, 2))
x


64×64×3×2 Array{Float32, 4}:
[:, :, 1, 1] =
  1.93723    -0.0038152    1.00776     …  -0.268737    0.772344     0.439853
 -1.11342     0.888429    -0.394805       -0.983046   -1.55638      2.38319
 -1.01276    -0.189184     0.00920959      0.082171   -0.282        0.785619
 -0.392059   -0.184873    -1.81964        -0.0519429   0.106289    -0.736538
  0.19757     0.467828     0.448932       -0.36157    -0.879078    -0.65364
 -0.972052   -0.00179285  -1.44282     …  -0.123626   -0.00435947  -0.531083
 -0.612597   -0.171173     0.14255         0.692491    0.548423     0.683022
  1.25817    -0.170542     0.845562       -0.803024   -0.623968    -2.29548
 -0.266937    1.96865     -1.98532         0.816746    0.717069    -1.49578
 -0.0521157   0.142611     1.59346        -2.30004    -1.42817     -1.13307
  2.12283     1.47045     -1.58077     …  -0.352557   -0.600046    -0.338219
  1.05403    -0.0555598   -1.05525        -0.127607    0.857893    -1.35023
 -0.856175   -0.982227    -0.988386   

In [16]:
y = l(x)

LoadError: MethodError: no method matching Array{Float32, 4}(::Int64)
[0mClosest candidates are:
[0m  Array{T, N}([91m::Union{Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}} where T, Union{Base.LogicalIndex{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, Base.ReinterpretArray{T, N, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s14"}, var"#s14"}} where var"#s14"<:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}, Base.ReshapedArray{T, N, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s15"}, var"#s15"}}, SubArray{<:Any, <:Any, var"#s15"}, var"#s15"}} where var"#s15"<:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}, SubArray{T, N, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s16"}, var"#s16"}}, Base.ReshapedArray{<:Any, <:Any, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s16"}, var"#s16"}}, SubArray{<:Any, <:Any, var"#s16"}, var"#s16"}}, var"#s16"}} where var"#s16"<:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}, LinearAlgebra.Adjoint{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.Diagonal{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.LowerTriangular{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.Symmetric{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.Transpose{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.Tridiagonal{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.UnitLowerTriangular{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.UnitUpperTriangular{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.UpperTriangular{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, PermutedDimsArray{T, N, <:Any, <:Any, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}} where {T, N}}[39m) where {T, N} at C:\Users\Yash\.julia\packages\NNlibCUDA\kCpTE\src\batchedadjtrans.jl:16
[0m  Array{T, N}([91m::BitArray{N}[39m) where {T, N} at bitarray.jl:495
[0m  Array{T, N}([91m::FillArrays.Zeros{V, N}[39m) where {T, V, N} at C:\Users\Yash\.julia\packages\FillArrays\Slipo\src\FillArrays.jl:441
[0m  ...