In [2]:
# Data processing
using MLDatasets;
using MLUtils: DataLoader;
using MLDataPattern;
using ImageCore;
using Augmentor;
using ImageFiltering;
using MappedArrays;
using Random;
using Flux: DataLoader;
# using OneHotArrays: onehotbatch

# Model creation
using NNlib;
using Flux: BatchNorm, Chain, GlobalMeanPool, kaiming_uniform, nfan;
using Statistics;
using Distributions;
using Functors;

# Training
# using Yota;
using Zygote;
using Optimisers;
using Metrics;
using TimerOutputs;

# Model Primitives

## Conv3x3

In [3]:
"""
    Conv3x3
2D convolution layer with 3x3 kernel for ResNet

# Fields
- `w::AbstractArray`: 4-D weight tensor for convolution kernels, of size (k1, k2, in_channels, out_channels)
- `b::AbstractVector`: Bias for output channels
- `use_bias::Bool`: Whether to apply bias
- `stride::Int`: Stride to apply in convolution
- `padding::Int`: Padding to apply in convolution
"""
mutable struct Conv3x3{T}
    w::AbstractArray{T, 4}
    b::AbstractVector{T}
    use_bias::Bool
    stride::Int
    padding::Int
end

@functor Conv3x3 (w, b)


"""
    Conv3x3(in_channels, out_channels; kwargs...)

Constructor for Conv3x3 layer
Weights are initialized using Kaiming uniform.

# Arguments
- `in_channels::Int`: Number of input channels to the convolution layer
- `out_channels::Int`: Number of output channels from the convolution layer

# Keywords
- `bias::Bool`: Whether to have a bias during convolution, by default false
- `stride::Int`: Stride to apply in convolution, by default 1
- `padding::Int`: Padding to apply in convolution, by default 1 ("same" padding)

# Returns
- `Conv3x3`: Initialized Conv3x3 struct
"""
function Conv3x3(in_channels::Int, out_channels::Int; bias::Bool=false, stride::Int=1, padding::Int=1)
    w_size = (3, 3, in_channels, out_channels)
    w = kaiming_uniform(w_size...)
    (fan_in, fan_out) = nfan(w_size)
    
    if bias
        # Init bias with fan_in from weights. Use gain = √2 for ReLU
        bound = √3 * √2 / √fan_in
        rng = Uniform(-bound, bound)
        b = rand(rng, out_channels, Float32)
    else
        b = zeros(Float32, out_channels)
    end

    return Conv3x3(w, b, bias, stride, padding)
end


"""
    Conv3x3(x; kwargs...)

Perform 2D convolution using initialized layer

# Arguments
- `x::AbstractArray`: 4D input image tensor of shape (width, height, channels, batch size)

# Returns
- `y::AbstractArray`: Output 4D image tensor
"""
function (self::Conv3x3)(x::AbstractArray)
    y = conv(x, self.w; stride=self.stride, pad=self.padding)
    if self.use_bias
        # Bias is applied channel-wise
        (w, h, c, b) = size(y)
        bias = reshape(self.b, (1, 1, c, 1))
        y = y .+ bias
    end
    return y
end

Conv3x3

In [4]:
# Testing Conv3x3
# Expected output: (16, 16, 10, 4)

inputs = randn(Float32, 32, 32, 3, 4)
c = Conv3x3(3, 10; stride=2)

outputs = c(inputs);
size(outputs)

(16, 16, 10, 4)

## ResNet Layer

In [5]:
"""
    ResNetLayer
ResNetV2 Layer
See "Identity Mappings in Deep Residual Networks", https://arxiv.org/pdf/1603.05027.pdf

# Fields
- `conv1::Conv3x3`: First convolution layer
- `conv2::Conv3x3`: Second convolution layer
- `bn1::BatchNorm`: First batch norm layer
- `bn2::BatchNorm`: Second batchnorm layer
- `f::Function`: Activation function
- `in_channels::Int`: Number of input channels to the layer
- `channels::Int`: Number of channels used throughout the layer
- `stride::Int`: Stride to apply to first convolution layer
"""
mutable struct ResNetLayer
    conv1::Conv3x3
    conv2::Conv3x3
    bn1::BatchNorm
    bn2::BatchNorm
    f::Function
    in_channels::Int
    channels::Int
    stride::Int
end

@functor ResNetLayer (conv1, conv2, bn1, bn2)

"""
    residual_identity(ResNetLayer, x)
Identity function for computing identity connection after a strided convolution
Downsamples `x` by a factor of `stride` and adds extra channels filled with zeros

# Arguments
- `layer::ResNetLayer`: ResNetLayer struct that has information about stride and channels
- `x::AbstractArray`: Input tensor of shape (width, height, channels, batch size)

# Returns
- `x_id::AbstractArray`: Downsampled identity tensor
"""
function residual_identity(layer::ResNetLayer, x::AbstractArray{T, 4}) where {T<:Number}
    (w, h, c, b) = size(x)
    stride = layer.stride
    if stride > 1
        @assert ((w % stride == 0) & (h % stride == 0)) "Spatial dimensions are not divisible by `stride`"
    
        # Strided downsample
        inds = CartesianIndices((1:stride:w, 1:stride:h))
        x_id = copy(x[inds, :, :])
    else
        x_id = x
    end

    channels = layer.channels
    in_channels = layer.in_channels
    if in_channels < channels
        # Zero padding on extra channels
        (w, h, c, b) = size(x_id)
        pad = zeros(T, w, h, channels - in_channels, b)
        x_id = cat(x_id, pad; dims=3)
    elseif in_channels > channels
        error("in_channels > out_channels not supported")
    end
    return x_id
end

"""
    ResNetLayer(in_channels, channels; kwargs...)
Constructor for ResNetLayer

# Arguments
- `in_channels::Int`: Number of input channels to the layer
- `channels::Int`: Number of channels used throughout the layer

# Keywords
- `stride::Int`: Stride to apply in first convolution layer
- `f::Function`: Activation function to apply, by default relu

# Returns
- `ResNetLayer`: Initialized ResNetLayer
"""
function ResNetLayer(in_channels::Int, channels::Int; stride=1, f=relu)
    bn1 = BatchNorm(in_channels)
    conv1 = Conv3x3(in_channels, channels; bias=false, stride=stride)
    bn2 = BatchNorm(channels)
    conv2 = Conv3x3(channels, channels; bias=false)

    return ResNetLayer(conv1, conv2, bn1, bn2, f, in_channels, channels, stride)
end


"""
    ResNetLayer(x)

Forward function for ResNetLayer
Applies stride in first convolution for downsampling

# Arguments
- `x::AbstractArray`: 4D input image tensor of shape (width, height, channels, batch size)

# Returns
- `y::AbstractArray`: 4D output image tensor
"""
function (self::ResNetLayer)(x::AbstractArray)
    identity = residual_identity(self, x)
    z = self.bn1(x)
    z = self.f(z)
    z = self.conv1(z)
    z = self.bn2(z)
    z = self.f(z)
    z = self.conv2(z)

    y = z + identity
    return y
end

ResNetLayer

In [6]:
# Testing ResNetLayer
# Expected output: (16, 61, 10, 4)

l = ResNetLayer(3, 10; stride=2);
x = randn(Float32, (32, 32, 3, 4));
y = l(x);
size(y)

(16, 16, 10, 4)

## Linear Layer

In [7]:
mutable struct Linear
    W::AbstractMatrix{T} where T
    b::AbstractVector{T} where T
end

@functor Linear

# Init
function Linear(in_features::Int, out_features::Int)
    k_sqrt = sqrt(1 / in_features)
    d = Uniform(-k_sqrt, k_sqrt)
    return Linear(rand(d, out_features, in_features), rand(d, out_features))
end
Linear(in_out::Pair{Int, Int}) = Linear(in_out[1], in_out[2])

function Base.show(io::IO, l::Linear)
    o, i = size(l.W)
    print(io, "Linear($i=>$o)")
end

# Forward
(l::Linear)(x::AbstractArray) where T = l.W * x .+ l.b


## ResNet20Model

In [8]:
"""
    ResNet20
ResNetV2-20 Model
See section Section 4.2 of "Deep Residual Learning for Image Recognition" (https://arxiv.org/pdf/1512.03385.pdf) 
for architecture details.

# Fields
- `input_conv::Conv3x3`: First layer convolution
- `resnet_blocks::Chain`: Chain of ResNet V2 blocks
- `pool::GlobalMeanPool`: Global average pooling
- `linear::Linear`: Linear classifier layer
"""
mutable struct ResNet20
    input_conv::Conv3x3
    resnet_blocks::Chain
    pool::GlobalMeanPool
    linear::Linear
end

@functor ResNet20

"""
    ResNet20(in_channels, num_classes)
Constructor for ResNet20 model

# Arguments
- `in_channels::Int`: Number of input channels to the model
- `num_classes::Int`: Number of classes to output

# Returns
- `ResNet20`: Initialized model
"""
function ResNet20(in_channels::Int, num_classes::Int)
    resnet_blocks = Chain(
        block_1 = ResNetLayer(16, 16),
        block_2 = ResNetLayer(16, 16),
        block_3 = ResNetLayer(16, 16),
        block_4 = ResNetLayer(16, 32; stride=2),
        block_5 = ResNetLayer(32, 32),
        block_6 = ResNetLayer(32, 32),
        block_7 = ResNetLayer(32, 64; stride=2),
        block_8 = ResNetLayer(64, 64),
        block_9 = ResNetLayer(64, 64)
    )
    return ResNet20(
        Conv3x3(in_channels, 16; bias=false),
        resnet_blocks,
        GlobalMeanPool(),
        Linear(64, num_classes)
    )
end

"""
    ResNet20(x)
Forward function for ResNet20 model

# Arguments
- `x::AbstractArray`: 4D input image tensor of shape (width, height, channels, batch size)

# Returns
- `y::AbstractArray`: 2D output tensor of shape (num classes, batch size)
"""
function (self::ResNet20)(x::AbstractArray)
    z = self.input_conv(x)
    z = self.resnet_blocks(z)
    z = self.pool(z)
    z = dropdims(z, dims=(1, 2))
    y = self.linear(z)
    return y
end

ResNet20

In [9]:
# Testing ResNet20 model
# Expected output: (10, 4)
m = ResNet20(3, 10);
inputs = randn(Float32, (32, 32, 3, 4))
outputs = m(inputs);
size(outputs)

(10, 4)

# Data Pre-processing
- Inputs: batches of (32 x 32) RGB images
    - Tensor size (32, 32, 3, N) in WHCN dimensions
    - Values between [0, 1]
- For all data: ImageNet normalization
    - Subtract means [0.485, 0.456, 0.406]
    - Divide by standard deviations [0.229, 0.224, 0.225]
- Augment training data only:
    - Permute to CWHN (3, 32, 32, N)
    - Convert to RGB image for Augmentor.jl package to process (32, 32, N)
    - 4 pixel padding on each side (40, 40, N)
    - Random horizontal flip
    - (32 x 32) crop from augmented image (32, 32, N)
    - Convert to tensors (3, 32, 32, N)
    - Permute to WHCN (32, 32, 3, N)
- Batch and shuffle data

In [10]:
train_data = MLDatasets.CIFAR10(Tx=Float32, split=:train)
test_data = MLDatasets.CIFAR10(Tx=Float32, split=:test)

dataset CIFAR10:
  metadata  =>    Dict{String, Any} with 2 entries
  split     =>    :test
  features  =>    32×32×3×10000 Array{Float32, 4}
  targets   =>    10000-element Vector{Int64}

In [11]:
train_x = train_data.features;
train_y = train_data.targets;

test_x = test_data.features;
test_y = test_data.targets;
size(train_x), size(test_x)  # Data is in shape WHCB

((32, 32, 3, 50000), (32, 32, 3, 10000))

In [12]:
# Train-test split
# Copied from https://github.com/JuliaML/MLUtils.jl/blob/v0.2.11/src/splitobs.jl#L65
# obsview doesn't work with this data, so use getobs instead

import MLDataPattern.splitobs;

function splitobs(data; at, shuffle::Bool=false)
    if shuffle
        data = shuffleobs(data)
    end
    n = numobs(data)
    return map(idx -> MLDataPattern.getobs(data, idx), splitobs(n, at))
end

splitobs (generic function with 11 methods)

In [13]:
train, val = splitobs((train_x, train_y), at=0.9, shuffle=true);

train_x, train_y = train;
val_x, val_y = val;

size(train_x), size(val_x)

((32, 32, 3, 45000), (32, 32, 3, 5000))

In [14]:
# Normalize all the data

means = reshape([0.485, 0.465, 0.406], (1, 1, 3, 1))
stdevs = reshape([0.229, 0.224, 0.225], (1, 1, 3, 1))
normalize(x) = (x .- means) ./ stdevs

train_x = normalize(train_x);
val_x = normalize(val_x);
test_x = normalize(test_x);

In [15]:
# Notebook testing: Use less data
train_x, train_y = MLDatasets.getobs((train_x, train_y), 1:500);

val_x, val_y = MLDatasets.getobs((val_x, val_y), 1:50);

test_x, test_y = MLDatasets.getobs((test_x, test_y), 1:50);

## Data augmentation pipeline with Augmentor.jl
By default, batch is the last dimension.

In [16]:
# Pad the training data for further augmentation
train_x_padded = padarray(train_x, Fill(0, (4, 4, 0, 0)));  
size(train_x_padded)  # Should be (40, 40, 3, 50000)

(40, 40, 3, 500)

In [17]:
pl = PermuteDims((3, 1, 2)) |> CombineChannels(RGB) |> Either(FlipX(), NoOp()) |> RCropSize(32, 32) |> SplitChannels() |> PermuteDims((2, 3, 1))

6-step Augmentor.ImmutablePipeline:
 1.) Permute dimension order to (3, 1, 2)
 2.) Combine color channels into colorant RGB
 3.) Either: (50%) Flip the X axis. (50%) No operation.
 4.) Crop random window with size (32, 32)
 5.) Split colorant into its color channels
 6.) Permute dimension order to (2, 3, 1)

In [18]:
# Create an output array for augmented images
outbatch(X) = Array{Float32}(undef, (32, 32, 3, nobs(X)))

outbatch (generic function with 1 method)

In [19]:
# Function that takes a batch (images and targets) and augments the images
augmentbatch((X, y)) = (augmentbatch!(outbatch(X), X, pl), y)

augmentbatch (generic function with 1 method)

In [20]:
# Shuffled and batched dataset of augmented images
train_batch_size = 16

train_batches = mappedarray(augmentbatch, batchview(shuffleobs((train_x_padded, train_y)), size=train_batch_size));

└ @ MLDataPattern /Users/trevoryu/.julia/packages/MLDataPattern/2yPuO/src/dataview.jl:205


In [21]:
# Test and Validation data
test_batch_size = 32

val_loader = DataLoader((val_x, val_y), shuffle=true, batchsize=test_batch_size);
test_loader = DataLoader((test_x, test_y), shuffle=true, batchsize=test_batch_size);

# Training setup

## Sparse Cross Entropy function
Extend logit cross entropy to work efficiently with sparse labels (i.e. integer indices and not one-hot labels)

In [22]:
"""
    sparse_logit_cross_entropy(logits, labels)

Efficient computation of cross entropy loss with model logits and integer indices as labels.
Integer indices are from [0,  N-1], where N is the number of classes
Similar to TensorFlow SparseCategoricalCrossEntropy

# Arguments
- `logits::AbstractArray`: 2D model logits tensor of shape (classes, batch size)
- `labels::AbstractArray`: 1D integer label indices of shape (batch size,)

# Returns
- `loss::Float32`: Cross entropy loss
"""
# function sparse_logit_cross_entropy(logits, labels)
#     log_probs = logsoftmax(logits);
#     # Select indices of labels for loss
#     log_probs = map((x, i) -> x[i + 1], eachslice(log_probs; dims=2), labels);
#     loss = -mean(log_probs);
#     return loss
# end

function sparse_logit_cross_entropy(logits, labels)
    log_probs = logsoftmax(logits);
    inds = CartesianIndex.(labels .+ 1, axes(log_probs, 2));
    # Select indices of labels for loss
    log_probs = log_probs[inds];
    loss = -mean(log_probs);
    return loss
end

sparse_logit_cross_entropy (generic function with 1 method)

In [23]:
# Create model with 3 input channels and 10 classes
model = ResNet20(3, 10);

In [24]:
# Setup AdamW optimizer
β = (0.9, 0.999);
decay = 1e-4;
state = Optimisers.setup(Optimisers.Adam(1e-3, β, decay), model);

In [25]:
# Create objective function to optimize
function loss_function(model::ResNet20, x::AbstractArray, y::AbstractArray)
    ŷ = model(x)
    loss = sparse_logit_cross_entropy(ŷ, y)
    return loss
end

loss_function (generic function with 1 method)

In [26]:
(x, y) = first(train_batches);

In [27]:
# loss, g = grad(loss_function, model, x, y);

In [28]:
mutable struct ResNet5
    input_conv::Conv3x3
    resnet_block::ResNetLayer
    pool::GlobalMeanPool
    linear::Linear
end

@functor ResNet5

function ResNet5(in_channels::Int, num_classes::Int)
    return ResNet5(
        Conv3x3(in_channels, 16; bias=false),
        ResNetLayer(16, 16),
        GlobalMeanPool(),
        Linear(16, num_classes)
    )
end

function (self::ResNet5)(x::AbstractArray)
    z = self.input_conv(x)
    z = self.resnet_block(z)
    z = self.pool(z)
    z = dropdims(z, dims=(1, 2))
    y = self.linear(z)
    return y
end


function loss_function(model::ResNet5, x::AbstractArray, y::AbstractArray)
    ŷ = model(x)
    loss = sparse_logit_cross_entropy(ŷ, y)
    return loss
end

loss_function (generic function with 2 methods)

In [31]:
# Yota is unable to compute gradients through the ResNet for some reason, maybe due to residual connections?
# loss, g = grad(loss_function, model, x, y)
model = ResNet5(3, 10);

loss, g = Zygote.gradient(loss_function, model, x, y);

In [38]:
g

32×32×3×16 Array{Float32, 4}:
[:, :, 1, 1] =
  3.02599f-5   7.07849f-5   7.94292f-5   …  6.81889f-5   4.79578f-5
  1.98297f-5   1.765f-5     3.06098f-5      5.32691f-5   1.61127f-5
  2.41994f-5  -6.55991f-6   2.34102f-5      5.13781f-5   3.09528f-5
 -1.93668f-5   1.05649f-5   3.49399f-5      6.30373f-5   4.09379f-5
  1.10981f-5  -9.92099f-6  -1.78859f-5      4.04262f-5   2.8198f-5
  2.52191f-5  -7.15537f-6  -1.00885f-5   …  4.56546f-5   2.18946f-5
 -2.06225f-6   6.24625f-6   4.01926f-7      5.5654f-5    1.91564f-5
  6.59202f-5   5.37844f-5   4.03756f-5      2.97683f-5   2.18327f-5
  5.54791f-5  -1.35512f-5  -5.49614f-6      6.7599f-5    2.30759f-5
  1.27065f-5   4.77054f-5   3.88314f-5      6.97576f-5   1.83329f-5
 -1.4598f-5    6.17554f-5   6.25489f-5   …  3.87517f-5  -1.14439f-5
  7.20194f-6   4.24674f-5   6.32658f-5      4.78324f-5  -1.47931f-5
  3.15358f-6  -2.03658f-6   7.06585f-5      4.49833f-5  -2.59674f-5
  ⋮                                      ⋱  ⋮           
  5.2263f-5   -

# Evaluation Function

In [61]:
function evaluate(model, test_loader)
    preds = []
    targets = []
    for (x, y) in test_loader
        # Get model predictions
        # Note argmax of nd-array gives CartesianIndex
        # Need to grab the first element of each CartesianIndex to get the true index
        logits = model(x)
        ŷ = map(i -> i[1], argmax(logits, dims=1))
        append!(preds, ŷ)

        # Get true labels
        append!(targets, y)
    end
    accuracy = sum(preds .== targets) / length(targets)
    return accuracy
end

evaluate (generic function with 1 method)

# Training Loop

In [62]:
# Setup timing output
const to = TimerOutput()

[0m[1m ────────────────────────────────────────────────────────────────────[22m
[0m[1m                   [22m         Time                    Allocations      
                   ───────────────────────   ────────────────────────
 Tot / % measured:      358ms /   0.0%           44.0MiB /   0.0%    

 Section   ncalls     time    %tot     avg     alloc    %tot      avg
 ────────────────────────────────────────────────────────────────────
[0m[1m ────────────────────────────────────────────────────────────────────[22m

In [63]:
last_loss = 0;
@timeit to "total_training_time" begin
    for epoch in 1:10
        timing_name = epoch > 1 ? "average_epoch_training_time" : "train_jit"

        # Create lazily evaluated augmented training data
        train_batches = mappedarray(augmentbatch, batchview(shuffleobs((train_x_padded, train_y)), size=train_batch_size));

        @timeit to timing_name begin
            losses = []
            for (x, y) in train_batches
                # loss_function does forward pass
                # Yota.jl grad function computes model parameter gradients in g[2]
                loss, g = grad(loss_function, model, x, y)
                
                # Optimiser updates parameters
                Optimisers.update!(state, model, g[2])
                push!(losses, loss)
            end
            last_loss = mean(losses)
            @info("epoch (mean(losses))")
        end
        # timing_name = epoch > 1 ? "average_inference_time" : "eval_jit"
        # @timeit to timing_name begin
        #     acc = evaluate(model, test_loader)
        #     @info("epoch (acc)")
        # end
    end
end

LoadError: InterruptException: