# Forward pass and gradient calculation

### Forward functions

In [4]:
# Conv 2d, assumes square input, filter and stride = 1, no bias
function conv2d_layer(input, filters)
    f_dim, in_channels, out_channels = size(filters)[2:4]
    i_dim = size(input)[2]
    out_dim = i_dim - f_dim + 1

    output = zeros(out_dim, out_dim, out_channels)
    
    for n in 1:out_channels
        for c in 1:in_channels
            conv2d!(@view(output[:, :, n]), input[:, :, c], filters[:,:, c, n], f_dim, out_dim)
        end
    end

    return output
end


function conv2d!(output, input, filter, f_dim, out_dim)
    for i in 1:out_dim
        for j in 1:out_dim
            output[i, j] += sum(input[i:i+f_dim-1, j:j+f_dim-1] .* filter)
        end
    end
end

conv2d! (generic function with 1 method)

In [5]:
# relu
function relu(input)
    return max.(0, input)
end

relu (generic function with 1 method)

In [6]:
# maxpool2d, assumes kernel_size == stride

function maxpool2d(input, kernel_size)
    dim_i, n_filters = size(input)[2:3]
    out_dim = floor(Int, dim_i / kernel_size)
    output = Array{Float32, 3}(undef, out_dim, out_dim, n_filters)
    indices = zero(input)

    for n in 1:n_filters
        for i in 1:out_dim
            for j in 1:out_dim
                input_fragment = input[(i-1)*kernel_size+1:i*kernel_size, (j-1)*kernel_size+1:j*kernel_size, n]
                max_x, max_y = Tuple(argmax(input_fragment))
                max_x += (i-1)*kernel_size
                max_y += (j-1)*kernel_size
                indices[max_x, max_y, n] = 1
                output[i, j, n] = input[max_x, max_y, n]
            end
        end
    end
    return output, indices
end

maxpool2d (generic function with 1 method)

In [7]:
# flatten
function flatten(input)
    return reshape(input, prod(size(input)), 1)
end

flatten (generic function with 1 method)

In [8]:
# Fully connected (linear)
function linear(input, weights, bias)
        # output = Array{Float32, 2}(undef, size(weights)[1], size(input)[2])
    return weights * input + bias
end

linear (generic function with 1 method)

In [9]:
# log softmax using LSE trick
function log_softmax(input)
    c = maximum(input)
    return input .- (c + log(sum(exp.(input .- c))))
end

log_softmax (generic function with 1 method)

In [10]:
# nll_loss, format equivalent to torch.nn.NLLLoss
function nll_loss(y, y_true)
    for i in 1:size(y)[1]
        if y_true[i] == 1.0
            return -y[i]
        end
    end
end

nll_loss (generic function with 1 method)

### Gradient calculations

In [11]:
# nll_loss grad
function nll_loss_grad(y_true)
    grad = zero(y_true)
    grad[y_true .== 1.0] .= -1.0
    return grad
end

nll_loss_grad (generic function with 1 method)

In [12]:
# log_softmax grad
function log_softmax_grad(input, grad)
    grad = exp.(input) / sum(exp.(input)) .+ grad
    return grad
end

log_softmax_grad (generic function with 1 method)

In [13]:
# linear grad
function linear_grad(weights, bias, input, grad)
    weights_grad = grad * input'
    bias_grad = grad
    input_grad = weights' * grad
    return input_grad, weights_grad, bias_grad
end

linear_grad (generic function with 1 method)

In [14]:
# flatten grad
function flatten_grad(original_input_shape, grad)
    return reshape(grad, original_input_shape)
end

flatten_grad (generic function with 1 method)

In [15]:
# maxpool2d grad
function maxpool2d_grad(indices, kernel_size, grad)
    output = copy(indices)
    dim_i, n_filters = size(indices)[2:3]
    idx_range = floor(Int, dim_i / kernel_size)

    for n in 1:n_filters
        for i in 1:idx_range
            for j in 1:idx_range
                output[(i-1)*kernel_size+1:i*kernel_size, (j-1)*kernel_size+1:j*kernel_size, n] .*= grad[i, j, n]
            end
        end
    end
    return output
end

maxpool2d_grad (generic function with 1 method)

In [16]:
# relu grad
function relu_grad(input, grad)
    negative_mask = input .<= 0.0
    grad[negative_mask] .= 0.0
    return grad
end

relu_grad (generic function with 1 method)

In [17]:
# conv2d grad
function conv2d_layer_grad(input, filters, grad)
    i_dim = size(input)[2]
    grad_dim = size(grad)[2]
    f_dim, in_channels, out_channels = size(filters)[2:4]
    
    padded_dim = size(input)[2] + f_dim - 1

    weights_grad = zeros(f_dim, f_dim, in_channels, out_channels)
    input_grad = zeros(i_dim, i_dim, in_channels)

    for n in out_channels
        padded_grad = zeros(padded_dim, padded_dim)
        padded_grad[f_dim: end - f_dim + 1, f_dim: end - f_dim + 1] = grad[:, :, n]

        for c in in_channels
            conv2d!(@view(input_grad[:, :, c]), padded_grad, reverse(filters[:, :, c, n]), f_dim, i_dim) # reverse(filters) (conv) padded(grad)
            conv2d!(@view(weights_grad[:, :, c, n]), input[:, :, c], grad[:, :, n], grad_dim, f_dim) # input (conv) grad
        end
    end
    
    return input_grad, weights_grad
end

conv2d_layer_grad (generic function with 1 method)

In [18]:
input = ones(28,28,1)
input[2,:,1] = input[2,:,1] * 5
filters1 = ones(3,3, 1, 32) .* 2 # filters format: (kernel_dim_1, kernel_dim_2, in_channel, out_channel) 
filters2 = ones(3,3, 32, 64) .* 2; # filters format: (kernel_dim_1, kernel_dim_2, in_channel, out_channel) 
linear_weights = ones(10, 1600) .* 0.5 # weights format: (out_dim, in_dim)
linear_bias = ones(10) .* 2;

label = Array{Float32, 2}(undef, 10, 1)
label *= 0
label[5] = 1

1

In [25]:
# forward pass of network
x1 = conv2d_layer(input, filters1)
x2 = relu(x1)
x3, indices_mp1 = maxpool2d(x1, 2)
x4 = conv2d_layer(x3, filters2)
x5 = relu(x4)
x6, indices_mp2 = maxpool2d(x5, 2)
x7 = flatten(x6)
x8 = linear(x7, linear_weights, linear_bias)
preds = log_softmax(x8)
loss = nll_loss(preds, label)

forward (generic function with 1 method)

In [27]:
using BenchmarkTools
@benchmark forward(input, filters1, filters2, linear_weights, linear_bias)

BenchmarkTools.Trial: 101 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m25.129 ms[22m[39m … [35m83.533 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m3.88% … 4.47%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m44.275 ms              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m3.72%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m49.795 ms[22m[39m ± [32m19.193 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m3.42% ± 1.70%

  [39m█[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [34m [39m[39m [39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▅[39m [39m [39m [39m [39m [39m [39m 
  [39m█[39m█[39m▆[39m▃[39m▁[39m▁[39m

In [22]:
size(input), size(x1), size(x2), size(x3), size(x4), size(x5), size(x6), size(x7), size(x8)

((28, 28, 1), (26, 26, 32), (26, 26, 32), (13, 13, 32), (11, 11, 64), (11, 11, 64), (5, 5, 64), (1600, 1), (10, 1))

In [18]:
# backward pass of network
ig, wg, bg = linear_grad(linear_weights, linear_bias, x7, log_softmax_grad(preds, nll_loss_grad(label)))
ig = flatten_grad(size(x6), ig)
ig = maxpool2d_grad(indices_mp2, 2, ig)
ig = relu_grad(x4, ig)
ig, fg = conv2d_layer_grad(x3, filters2, ig)
ig = maxpool2d_grad(indices_mp1, 2, ig)
ig = relu_grad(x1, ig)
ig, fg = conv2d_layer_grad(input, filters1, ig)

([1.1102230246251565e-16 1.1102230246251565e-16 … 0.0 0.0; 1.1102230246251565e-16 1.1102230246251565e-16 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;;], [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0;;;; 0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0;;;; 0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0;;;; … ;;;; 0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0;;;; 0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0;;;; 1.2490009027033011e-14 1.2490009027033011e-14 1.2490009027033011e-14; 1.582067810090848e-14 1.582067810090848e-14 1.582067810090848e-14; 1.2490009027033011e-14 1.2490009027033011e-14 1.2490009027033011e-14])

In [6]:
include("nn.jl")

model_dims = [
    (28, 28, 1), 
    (26, 26, 32), 
    (26, 26, 32), 
    (13, 13, 32), 
    (11, 11, 64), 
    (11, 11, 64), 
    (5, 5, 64), 
    (1600, 1), 
    (10, 1)
]

model = NNGraph([
    Convolution2D(28, 3, 1, 32, model_dims[1]),
    RELU(model_dims[2]),
    MaxPool2D(2, model_dims[3]),
    Convolution2D(13, 3, 32, 64, model_dims[4]),
    RELU(model_dims[5]),
    MaxPool2D(2, model_dims[6]),
    Flatten(model_dims[7]),
    Linear(1600, 10, model_dims[8]),
    LogSoftmax(model_dims[9])
])

Computational graph:
(
Convolution2D(:input_size, 28)(:kernel_size, 3)(:in_filters, 1)(:out_filters, 32)(:input_shape, (28, 28, 1))
         ↓         
RELU(:input_shape, (26, 26, 32))
         ↓         
MaxPool2D(:kernel_size, 2)(:input_shape, (26, 26, 32))
         ↓         
Convolution2D(:input_size, 13)(:kernel_size, 3)(:in_filters, 32)(:out_filters, 64)(:input_shape, (13, 13, 32))
         ↓         
RELU(:input_shape, (11, 11, 64))
         ↓         
MaxPool2D(:kernel_size, 2)(:input_shape, (11, 11, 64))
         ↓         
Flatten(:input_shape, (5, 5, 64))
         ↓         
Linear(:input_neurons, 1600)(:output_neurons, 10)(:input_shape, (1600, 1))
         ↓         
LogSoftmax(:input_shape, (10, 1))
)

In [7]:
forward(Float32.(ones(28,28,1)), model)

Forward

Convolution2D

(:input_size, 28)(:kernel_size, 3)(:in_filters, 1)(:out_filters, 32)

(:input_shape, (28, 28, 1))


ForwardRELU(:input_shape, (26, 26, 32))


ForwardMaxPool2D(:kernel_size, 2)(:input_shape, (26, 26, 32))
ForwardConvolution2D(:input_size, 13)(:kernel_size, 3)(:in_filters, 32)(:out_filters, 64)(:input_shape, (13, 13, 32))
ForwardRELU(:input_shape, (11, 11, 64))
ForwardMaxPool2D(:kernel_size, 2)(:input_shape, (11, 11, 64))
Forward

Flatten(:input_shape, (5, 5, 64))
ForwardLinear(:input_neurons, 1600)(:output_neurons, 10

)(:input_shape, (1600, 1))


ForwardLogSoftmax(:input_shape, (10, 1))


28×28×1 Array{Float32, 3}:
[:, :, 1] =
 1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  …  1.0  1.0  1.0  1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0     1.0  1.0  1.0  1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0     1.0  1.0  1.0  1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0     1.0  1.0  1.0  1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0     1.0  1.0  1.0  1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  …  1.0  1.0  1.0  1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0     1.0  1.0  1.0  1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0     1.0  1.0  1.0  1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0     1.0  1.0  1.0  1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0     1.0  1.0  1.0  1.0  1.0  1.0  1.0
 ⋮                        ⋮              ⋱                      ⋮         
 1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0     1.0  1.0  1.0  1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0  1.0  1.