# Imports

In [1]:

using Yota;
using MLDatasets;
using NNlib;
using Statistics;
using Distributions;
using Functors;
using Optimisers;
using MLUtils: DataLoader;
using OneHotArrays: onehotbatch
using Knet:Knet,conv4, adam
using Knet: dir, accuracy, progress, sgd, gc
using Metrics;
using TimerOutputs;
using Flux: BatchNorm, kaiming_uniform, nfan;
using Functors

# Model creation
using NNlib;
using Flux: BatchNorm, Chain, GlobalMeanPool, kaiming_uniform, nfan;
using Statistics;
using Distributions;
using Functors;

# Data processing
using MLDatasets;
using MLUtils: DataLoader;
using MLDataPattern;
using ImageCore;
using Augmentor;
using ImageFiltering;
using MappedArrays;
using Random;
using Flux: DataLoader;
# using OneHotArrays: onehotbatch


# Training
# using Yota;
using Zygote;
using Optimisers;
using Metrics;
using TimerOutputs;



# Issue when running this
#using Knet: Knet, dir, accuracy, progress, sgd, gc, Data, nll, relu



# Conv 2D

In [2]:
mutable struct Conv2D{T}
    w::AbstractArray{T, 4}
    b::AbstractVector{T}
    use_bias::Bool
    padding::Int 
end

@functor Conv2D (w, b)

In [3]:
function Conv2D(kernel_size::Tuple{Int, Int}, in_channels::Int, out_channels::Int;
    bias::Bool=false, padding::Int=1)
    w_size = (kernel_size..., in_channels, out_channels)
    w = kaiming_uniform(w_size...)
    (fan_in, fan_out) = nfan(w_size)
    
    if bias
        # Init bias with fan_in from weights. Use gain = √2 for ReLU
        bound = √3 * √2 / √fan_in
        rng = Uniform(-bound, bound)
        b = rand(rng, out_channels, Float32)
    else
        b = zeros(Float32, out_channels)
    end

    return Conv2D(w, b, bias, padding)
end

function (self::Conv2D)(x::AbstractArray; stride::Int=1, pad::Int=0, dilation::Int=1)
    y = conv4(self.w, x; stride=stride, padding=self.padding, dilation=dilation)
    if self.use_bias
        # Bias is applied channel-wise
        (w, h, c, b) = size(y)
        bias = reshape(self.b, (1, 1, c, 1))
        y = y .+ bias
    end
    return y
end
     

# ResNetLayer

In [4]:
mutable struct ResNetLayer
    conv1::Conv2D
    conv2::Conv2D
    bn1::BatchNorm
    bn2::BatchNorm
    f::Function
    in_channels::Int
    channels::Int
    stride::Int
end

@functor ResNetLayer (conv1, conv2, bn1, bn2)

function residual_identity(layer::ResNetLayer, x::AbstractArray{T, 4}) where {T<:Number}
    (w, h, c, b) = size(x)
    stride = layer.stride
    if stride > 1
        @assert ((w % stride == 0) & (h % stride == 0)) "Spatial dimensions are not divisible by `stride`"
    
        # Strided downsample
        x_id = copy(x[begin:2:end, begin:2:end, :, :])
    else
        x_id = x
    end

    channels = layer.channels
    in_channels = layer.in_channels
    if in_channels < channels
        # Zero padding on extra channels
        (w, h, c, b) = size(x_id)
        pad = zeros(w, h, channels - in_channels, b)
        x_id = cat(x_id, pad; dims=3)
    elseif in_channels > channels
        error("in_channels > out_channels not supported")
    end
    return x_id
end

function ResNetLayer(in_channels::Int, channels::Int; stride=1, f=relu)
    bn1 = BatchNorm(in_channels)
    conv1 = Conv2D((3, 3), in_channels, channels, bias=false)
    bn2 = BatchNorm(channels)
    conv2 = Conv2D((3, 3), channels, channels, bias=false)

    return ResNetLayer(conv1, conv2, bn1, bn2, f, in_channels, channels, stride)
end


function (self::ResNetLayer)(x::AbstractArray)
    identity = residual_identity(self, x)
    z = self.bn1(x)
    z = self.f(z)
    z = self.conv1(z; pad=1, stride=self.stride)  # pad=1 will keep same size with (3x3) kernel
    z = self.bn2(z)
    z = self.f(z)
    z = self.conv2(z; pad=1)

    y = z + identity
    return y
end

# Testing ResNetLayer

In [5]:

l = ResNetLayer(3, 10; stride=2);
x = randn(Float32, (32, 32, 3, 4));
y = l(x);
size(y)

(16, 16, 10, 4)

# Linear Layer

In [6]:
mutable struct Linear
    W::AbstractMatrix{T} where T
    b::AbstractVector{T} where T
end

@functor Linear

# Init
function Linear(in_features::Int, out_features::Int)
    k_sqrt = sqrt(1 / in_features)
    d = Uniform(-k_sqrt, k_sqrt)
    return Linear(rand(d, out_features, in_features), rand(d, out_features))
end
Linear(in_out::Pair{Int, Int}) = Linear(in_out[1], in_out[2])

function Base.show(io::IO, l::Linear)
    o, i = size(l.W)
    print(io, "Linear(o)")
end

# Forward
(l::Linear)(x::AbstractArray) where T = l.W * x .+ l.b




# Defining a Chain Layer

In [7]:
# Define a chain of layers and a loss function:
struct Chain1; layers; end
(c::Chain1)(x) = (for l in c.layers; x = l(x); end; x)
(c::Chain1)(x,y) = nll(c(x),y)

In [8]:
# ResNet Architecture

mutable struct ResNet20Model
    input_conv::Conv2D
    resnet_blocks::Chain1
    pool::GlobalMeanPool
    linear::Linear
end

@functor ResNet20Model

function ResNet20Model(in_channels::Int, num_classes::Int)
    resnet_blocks = Chain1((
        block_1 = ResNetLayer(16, 16),
        block_2 = ResNetLayer(16, 16),
        block_3 = ResNetLayer(16, 16),
        block_4 = ResNetLayer(16, 32; stride=2),
        block_5 = ResNetLayer(32, 32),
        block_6 = ResNetLayer(32, 32),
        block_7 = ResNetLayer(32, 64; stride=2),
        block_8 = ResNetLayer(64, 64),
        block_9 = ResNetLayer(64, 64)
    ))
    return ResNet20Model(
        Conv2D((3, 3), in_channels, 16, bias=false),
        resnet_blocks,
        GlobalMeanPool(),
        Linear(64, num_classes)
    )
end

ResNet20Model

In [9]:
function (self::ResNet20Model)(x::AbstractArray)
    z = self.input_conv(x)
    z = self.resnet_blocks(z)
    z = self.pool(z)
    z = dropdims(z, dims=(1, 2))
    y = self.linear(z)
    return y
end
     

In [10]:

# Testing ResNet20 model
# Expected output: (10, 4)
m = ResNet20Model(3, 10);
inputs = randn(Float32, (32, 32, 3, 4))
outputs = m(inputs);
size(outputs)
     

│   yT = Float64
│   T1 = Float64
│   T2 = Float32
└ @ NNlib C:\Users\Yash\.julia\packages\NNlib\0QnJJ\src\conv.jl:285


(10, 4)

# Data Preprocessing 

In [11]:
# This loads the CIFAR-10 Dataset for training, validation, and evaluation
xtrn,ytrn = CIFAR10.traindata(Float32, 1:45000)
xval,yval =  CIFAR10.traindata(Float32, 45001:50000)
xtst,ytst = CIFAR10.testdata(Float32)
println.(summary.((xtrn,ytrn,xval, yval, xtst,ytst)));

32×32×3×45000 Array{Float32, 4}
45000-element Vector{Int64}
32×32×3×5000 Array{Float32, 4}
5000-element Vector{Int64}
32×32×3×10000 Array{Float32, 4}
10000-element Vector{Int64}


In [12]:
# Normalize all the data

means = reshape([0.485, 0.465, 0.406], (1, 1, 3, 1))
stdevs = reshape([0.229, 0.224, 0.225], (1, 1, 3, 1))
normalize(x) = (x .- means) ./ stdevs

train_x = normalize(xtrn);
val_x = normalize(xval);
test_x = normalize(xtst);

In [13]:

# Train-test split
# Copied from https://github.com/JuliaML/MLUtils.jl/blob/v0.2.11/src/splitobs.jl#L65
# obsview doesn't work with this data, so use getobs instead

import MLDataPattern.splitobs;

function splitobs(data; at, shuffle::Bool=false)
    if shuffle
        data = shuffleobs(data)
    end
    n = numobs(data)
    return map(idx -> MLDataPattern.getobs(data, idx), splitobs(n, at))
end

splitobs (generic function with 11 methods)

In [14]:

# Notebook testing: Use less data
train_x, train_y = MLDatasets.getobs((train_x, ytrn), 1:500);

val_x, val_y = MLDatasets.getobs((val_x, yval), 1:50);

test_x, test_y = MLDatasets.getobs((test_x, ytst), 1:50);

In [15]:

# Pad the training data for further augmentation
train_x_padded = padarray(train_x, Fill(0, (4, 4, 0, 0)));  
size(train_x_padded)  # Should be (40, 40, 3, 50000)

(40, 40, 3, 500)

In [16]:
pl = PermuteDims((3, 1, 2)) |> CombineChannels(RGB) |> Either(FlipX(), NoOp()) |> RCropSize(32, 32) |> SplitChannels() |> PermuteDims((2, 3, 1))

6-step Augmentor.ImmutablePipeline:
 1.) Permute dimension order to (3, 1, 2)
 2.) Combine color channels into colorant RGB
 3.) Either: (50%) Flip the X axis. (50%) No operation.
 4.) Crop random window with size (32, 32)
 5.) Split colorant into its color channels
 6.) Permute dimension order to (2, 3, 1)

In [17]:
# Create an output array for augmented images
outbatch(X) = Array{Float32}(undef, (32, 32, 3, nobs(X)))

outbatch (generic function with 1 method)

In [18]:
# Function that takes a batch (images and targets) and augments the images
augmentbatch((X, y)) = (augmentbatch!(outbatch(X), X, pl), y)

augmentbatch (generic function with 1 method)

In [19]:

# Shuffled and batched dataset of augmented images
train_batch_size = 16

train_batches = mappedarray(augmentbatch, batchview(shuffleobs((train_x_padded, train_y)), size=train_batch_size));
     

└ @ MLDataPattern C:\Users\Yash\.julia\packages\MLDataPattern\KlSmO\src\dataview.jl:205


In [20]:
# Test and Validation data
test_batch_size = 32

val_loader = DataLoader((val_x, val_y), shuffle=true, batchsize=test_batch_size);
test_loader = DataLoader((test_x, test_y), shuffle=true, batchsize=test_batch_size);

# Training setup

In [21]:
#Sparse Cross Entropy function

In [22]:

"""
    sparse_logit_cross_entropy(logits, labels)

Efficient computation of cross entropy loss with model logits and integer indices as labels.
Integer indices are from [0,  N-1], where N is the number of classes
Similar to TensorFlow SparseCategoricalCrossEntropy

# Arguments
- `logits::AbstractArray`: 2D model logits tensor of shape (classes, batch size)
- `labels::AbstractArray`: 1D integer label indices of shape (batch size,)

# Returns
- `loss::Float32`: Cross entropy loss
"""
# function sparse_logit_cross_entropy(logits, labels)
#     log_probs = logsoftmax(logits);
#     # Select indices of labels for loss
#     log_probs = map((x, i) -> x[i + 1], eachslice(log_probs; dims=2), labels);
#     loss = -mean(log_probs);
#     return loss
# end

function sparse_logit_cross_entropy(logits, labels)
    log_probs = logsoftmax(logits);
    inds = CartesianIndex.(labels .+ 1, axes(log_probs, 2));
    # Select indices of labels for loss
    log_probs = log_probs[inds];
    loss = -mean(log_probs);
    return loss
end


sparse_logit_cross_entropy (generic function with 1 method)

In [23]:

# Create model with 3 input channels and 10 classes
model = ResNet20Model(3, 10);

In [24]:
# Setup AdamW optimizer
β = (0.9, 0.999);
decay = 1e-4;
state = Optimisers.setup(Optimisers.Adam(1e-3, β, decay), model);

In [25]:

(x, y) = first(train_batches);

In [26]:
# loss, g = grad(loss_function, model, x, y);

In [27]:
mutable struct ResNet5
    input_conv::Conv2D
    resnet_block::ResNetLayer
    pool::GlobalMeanPool
    linear::Linear
end

@functor ResNet5

function ResNet5(in_channels::Int, num_classes::Int)
    return ResNet5(
        Conv2D((3, 3), in_channels, 16, bias=false),
        ResNetLayer(16, 16),
        GlobalMeanPool(),
        Linear(16, num_classes)
    )
end

function (self::ResNet5)(x::AbstractArray)
    z = self.input_conv(x)
    z = self.resnet_block(z)
    z = self.pool(z)
    z = dropdims(z, dims=(1, 2))
    y = self.linear(z)
    return y
end


# function loss_function(model::ResNet5, x::AbstractArray, y::AbstractArray)
#     ŷ = model(x)
#     loss = sparse_logit_cross_entropy(ŷ, y)
#     return loss
# end

In [28]:

# Yota is unable to compute gradients through the ResNet for some reason, maybe due to residual connections?
# loss, g = grad(loss_function, model, x, y)
model = ResNet5(3, 10);

# loss, g = Zygote.gradient(loss_function, model, x, y);

In [29]:
# g

In [30]:
model

ResNet5(Conv2D{Float32}([-0.3269322 -0.09079589 -0.30220258; 0.29980195 -0.35697645 -0.43193826; 0.41894063 -0.27608046 -0.35023037;;; -0.11048572 -0.18733561 -0.048941202; -0.25535512 0.41386655 0.23646444; 0.1226552 0.19139434 -0.3201441;;; -0.2881651 0.4041223 -0.11729951; 0.28896266 0.124178275 0.10890088; -0.07136462 0.37597623 0.2907424;;;; 0.0116342725 -0.06491064 0.03947901; 0.36589766 -0.31363672 0.32354057; -0.101177834 0.22076249 0.26570976;;; -0.22781743 -0.16796216 0.079579934; 0.43243396 -0.18935399 0.3949348; -0.3725451 -0.06775151 0.21907443;;; -0.05270441 -0.43405735 -0.44125763; -0.47045088 -0.30292767 0.014733751; -0.04850591 -0.2133474 0.2412362;;;; 0.2572607 0.18735757 -0.33566207; -0.03157889 -0.04323261 0.1315869; -0.16356815 -0.23604983 0.051579874;;; -0.3262316 0.40397793 0.07843399; 0.17368728 0.31032175 -0.2273731; -0.20191403 0.11151084 0.33216488;;; -0.34083363 -0.46381113 0.055753145; -0.44104743 0.31393462 0.2622986; 0.13619547 -0.12979876 -0.043511562;;;

# Training Knet

In [31]:


loss(test_x, test_y) = nll(model(test_x), test_y)
evalcb = () -> (loss(test_x, test_y)) #function that will be called to get the loss 
const to = TimerOutput() # creating a TimerOutput, keeps track of everything


@timeit to "Train Total" begin
    for epoch in 1:10
        train_epoch = epoch > 1 ? "train_epoch" : "train_ji"
        @timeit to train_epoch begin
            progress!(adam(model, train_batches; lr = 1e-3))
        end
        
        evaluation = epoch > 1 ? "evaluation" : "eval_jit"
        @timeit to evaluation begin
            accuracy(model, train_batches)
        end 
        
    end 
end 

LoadError: UndefVarError: progress! not defined

# Evaluation Function

In [32]:

# function evaluate(model, test_loader)
#     preds = []
#     targets = []
#     for (x, y) in test_loader
#         # Get model predictions
#         # Note argmax of nd-array gives CartesianIndex
#         # Need to grab the first element of each CartesianIndex to get the true index
#         logits = model(x)
#         ŷ = map(i -> i[1], argmax(logits, dims=1))
#         append!(preds, ŷ)

#         # Get true labels
#         append!(targets, y)
#     end
#     accuracy = sum(preds .== targets) / length(targets)
#     return accuracy
# end

# Training Loop

In [33]:

# # Setup timing output
# const to = TimerOutput()

In [34]:
# # last_loss = 0;
# # @timeit to "total_training_time" begin
#     for epoch in 1:10
#         timing_name = epoch > 1 ? "average_epoch_training_time" : "train_jit"

#         # Create lazily evaluated augmented training data
#         train_batches = mappedarray(augmentbatch, batchview(shuffleobs((train_x_padded, train_y)), size=train_batch_size));

#         @timeit to timing_name begin
#             losses = []
#             for (x, y) in train_batches
#                 # loss_function does forward pass
#                 # Yota.jl grad function computes model parameter gradients in g[2]
#                 loss, g = grad(loss_function, model, x, y)
                
#                 # Optimiser updates parameters
#                 Optimisers.update!(state, model, g[2])
#                 push!(losses, loss)
#             end
#             last_loss = mean(losses)
#             @info("epoch (mean(losses))")
#         end
#         # timing_name = epoch > 1 ? "average_inference_time" : "eval_jit"
#         # @timeit to timing_name begin
#         #     acc = evaluate(model, test_loader)
#         #     @info("epoch (acc)")
#         # end
#     end
# end