In [1]:
"""
    im2col(x, k)

Unfold the length-`k` sliding windows taken along the first dimension of `x` into rows.

For every window start `i` in `1:(size(x, 1) - k + 1)` the slab `x[i:(i + k - 1), :, b]`
is flattened in column-major order into row `i` of the result, i.e.
`B[i, (c - 1) * k + j, b] == x[i + j - 1, c, b]`.

# Arguments
- `x`: array indexed as `x[position, column, slice]`; windows slide over `position`.
- `k`: window length along the first dimension.

# Returns
A freshly allocated array of size `(size(x, 1) - k + 1, size(x, 2) * k, size(x, 3))`
with the same element type as `x`.

# Throws
- `ArgumentError` if `k > size(x, 1) + 1`, i.e. the number of windows would be negative.
"""
@inline function im2col(x, k)
    steps = size(x, 1) - k + 1
    steps >= 0 || throw(ArgumentError(
        "window length k = $k is too large for a first dimension of size $(size(x, 1))"))
    ncols, nslices = size(x, 2), size(x, 3)
    # `similar` keeps the element type of `x` instead of forcing Float32.
    B = similar(x, steps, ncols * k, nslices)
    # Seen as (steps, k, ncols, nslices), offset `j` inside every window is just a
    # shifted slab of `x`, so `k` strided copies fill the whole output.
    windows = reshape(B, steps, k, ncols, nslices)
    @views for j in 1:k
        windows[:, j, :, :] .= x[j:(j + steps - 1), :, :]
    end
    return B
end

"""
    convolution(x::Array{Float32,3}, m::Array{Float32,3})

Slide the kernel bank `m` along the first dimension of `x` (stride 1, no padding, kernel
not flipped) and return the responses:
`y[i, o, b] == sum(x[i + j - 1, c, b] * m[j, c, o] for j in axes(m, 1), c in axes(m, 2))`.

# Arguments
- `x`: input indexed as `x[position, channel, sample]`.
- `m`: kernels indexed as `m[offset, channel, filter]`; `size(m, 2)` must equal
  `size(x, 2)`.

# Returns
An `Array{Float32,3}` of size `(size(x, 1) - size(m, 1) + 1, size(m, 3), size(x, 3))`.

# Throws
- `DimensionMismatch` if `x` and `m` disagree on the number of channels.
- `ArgumentError` (from `im2col` / array allocation) if the kernel is longer than the
  input allows.
"""
function convolution(x::Array{Float32,3}, m::Array{Float32,3})
    # Every kernel spans all input channels; a mismatch would otherwise only surface
    # later as an obscure error inside the matrix product or the `.=` assignment.
    size(x, 2) == size(m, 2) || throw(DimensionMismatch(
        "x has $(size(x, 2)) channels but the kernel expects $(size(m, 2))"))
    width, filters = size(m, 1), size(m, 3)
    # Flatten each filter into one column; the row order matches the columns of `data`.
    kernel = reshape(m, :, filters)
    data = im2col(x, width)
    x_new = Array{Float32,3}(undef, size(data, 1), filters, size(x, 3))
    @views for i in axes(x, 3)
        x_new[:, :, i] .= data[:, :, i] * kernel
    end
    return x_new
end

convolution (generic function with 1 method)

In [None]:
"""
    conv_backward(x::Array{Float32,3}, m::Array{Float32,3}, g::Array{Float32,3})

Backward pass matching `convolution(x, m)`.

# Arguments
- `x`: the input that was passed to `convolution`.
- `m`: the kernels that were passed to `convolution`.
- `g`: gradient with respect to the output of `convolution(x, m)`; must have that
  output's size.

# Returns
A tuple `(dx, dm)` where `dx` has the size of `x` and `dm` has the size of `m`.
`dm` is accumulated over all samples along the third dimension.

# Throws
- `DimensionMismatch` if the channel counts of `x` and `m` differ or if `g` does not
  have the size of the forward output.
"""
function conv_backward(x::Array{Float32,3}, m::Array{Float32,3}, g::Array{Float32,3})
    width, channels, filters = size(m)
    size(x, 2) == channels || throw(DimensionMismatch(
        "x has $(size(x, 2)) channels but the kernel expects $channels"))

    x_cols = im2col(x, width)
    expected = (size(x_cols, 1), filters, size(x, 3))
    size(g) == expected || throw(DimensionMismatch(
        "g has size $(size(g)) but the forward output has size $expected"))

    dx = zeros(Float32, size(x))
    dm = zeros(Float32, size(m))

    # Kernel gradient: unfolded input (transposed) times the output gradient, summed
    # over samples. Views avoid copying a slice per iteration.
    @views for i in axes(x, 3)
        dm .+= reshape(x_cols[:, :, i]' * g[:, :, i], size(m))
    end

    # Input gradient: run the flipped kernel, with the channel and filter axes swapped,
    # over the zero-padded output gradient.
    flipped = reshape(permutedims(reverse(m, dims=1), (1, 3, 2)), :, channels)

    pad = width - 1
    g_padded = zeros(Float32, size(g, 1) + 2 * pad, filters, size(g, 3))
    g_padded[(pad + 1):(end - pad), :, :] .= g

    g_cols = im2col(g_padded, width)

    @views for i in axes(x, 3)
        dx[:, :, i] .= g_cols[:, :, i] * flipped
    end

    return dx, dm
end


conv_backward (generic function with 1 method)

In [53]:
# Smoke test: push a random input and kernel bank through the forward and backward passes.
x = randn(Float32, 130, 50, 64)
m = randn(Float32, 3, 50, 8)
y = convolution(x, m)
# All-ones upstream gradient with the same shape as the forward output.
g = fill(1.0f0, size(y))
conv_backward(x, m, g)

(Float32[0.6828363 -1.8169167 … 7.193335 2.723155; -1.2692046 4.5215263 … 5.887371 3.599069; … ; -2.016179 4.6485705 … 2.452697 -3.1847637; -0.064138174 -1.6898729 … 3.7586615 -4.0606775;;; 0.6828363 -1.8169167 … 7.193335 2.723155; -1.2692046 4.5215263 … 5.887371 3.599069; … ; -2.016179 4.6485705 … 2.452697 -3.1847637; -0.064138174 -1.6898729 … 3.7586615 -4.0606775;;; 0.6828363 -1.8169167 … 7.193335 2.723155; -1.2692046 4.5215263 … 5.887371 3.599069; … ; -2.016179 4.6485705 … 2.452697 -3.1847637; -0.064138174 -1.6898729 … 3.7586615 -4.0606775;;; … ;;; 0.6828363 -1.8169167 … 7.193335 2.723155; -1.2692046 4.5215263 … 5.887371 3.599069; … ; -2.016179 4.6485705 … 2.452697 -3.1847637; -0.064138174 -1.6898729 … 3.7586615 -4.0606775;;; 0.6828363 -1.8169167 … 7.193335 2.723155; -1.2692046 4.5215263 … 5.887371 3.599069; … ; -2.016179 4.6485705 … 2.452697 -3.1847637; -0.064138174 -1.6898729 … 3.7586615 -4.0606775;;; 0.6828363 -1.8169167 … 7.193335 2.723155; -1.2692046 4.5215263 … 5.887371 3.5990

# Comparison

In [54]:
using Flux
# Fresh random data for the comparison against Flux.
x = randn(Float32, 130, 50, 64)
m = randn(Float32, 3, 50, 8)
# Copy of the kernel reversed along its first axis; the forward comparison below feeds
# this to `convolution` and the unreversed `m` to `Flux.conv`.
rm = m[end:-1:1, :, :]

3×50×8 Array{Float32, 3}:
[:, :, 1] =
 -0.336207   0.512132  0.509357  -0.314065  …  0.294828  0.177997  -0.934498
  0.780016   1.14444   1.20562    1.1002       0.669616  1.03673    0.358969
  0.956336  -1.40987   0.687899   0.564863     0.650158  1.53122    0.375653

[:, :, 2] =
 -0.85027  0.372803   1.49067     …   1.69816      0.155651  -0.278529
 -1.07316  1.20112   -0.00990165      0.00522546  -0.235013   2.4
  1.02603  0.414474  -0.324279       -0.276929     0.194984  -1.80166

[:, :, 3] =
 -0.129836   1.4631    -0.668341  …   0.814141   1.12531   -1.42563
  0.137544  -0.380926  -0.511547     -0.611895  -0.461121  -0.15402
  1.47481   -1.03576    0.84139      -0.437292   0.660377  -0.0290537

[:, :, 4] =
  0.994073  -0.196833  -0.538353   …  1.0953    -0.631303   -0.654845
 -2.34453   -1.00694   -0.0389345     0.687816  -0.0135215   1.09241
  0.687373  -0.162389   1.3842        0.825871  -0.118808   -0.632476

[:, :, 5] =
  0.626835  0.123943  -2.23434   -0.28159   …  0.152754  

In [55]:
# Forward pass: the hand-written convolution on the reversed kernel next to Flux's
# reference result on the original kernel.
ym = convolution(x, rm)
yf = Flux.conv(x, m)

128×8×64 Array{Float32, 3}:
[:, :, 1] =
  15.4798     7.60887    -2.61251  …    9.69109    5.4144    -6.1916
  -3.54631    8.59294    13.2808        2.39363    6.57848  -17.3753
   1.7673    -8.39201     2.66384      17.0165     3.26916   -4.92219
   4.73263    7.66879    -2.65865       8.07935   22.0543    -1.88381
  -5.67883   -1.10338    10.1742        8.92518    7.32549   -1.27197
   5.26214   15.9064     -4.46369  …    9.77787   25.1153     7.26483
   2.09132   -6.17384    17.3509      -15.1559     8.25431    1.29603
  -3.17935   16.7812     -6.53453      -8.67801   11.0204   -24.4837
 -15.8797     9.10154    -7.94404       1.06537   -4.70584   -2.82584
  -2.10805  -17.6144     -3.95012     -39.2542     7.1037    19.6598
   ⋮                               ⋱    ⋮                   
  19.0827    -0.223634  -13.9065       15.7633    -6.38631    6.00417
  13.9364     2.88063   -10.0233   …    1.40874   -2.17492    3.02151
  -3.76031   -6.01944     1.31771      -2.50889    1.98685    1

In [56]:
# The two forward results should agree up to floating-point tolerance.
ym ≈ yf

true

In [59]:
# Backward pass: hand-written gradients vs. Flux on a Conv layer holding the same weights.
g = ones(Float32, size(ym))
# NOTE: despite its name, `dm` holds the whole `(dx, dm)` tuple returned by conv_backward.
dm = conv_backward(x, rm, g)

conv = Conv((3,), 50 => 8, stride=1, pad=0)
conv.weight .= m

# Gradient with respect to the input (∂L/∂x)
grad_input = gradient(x -> sum(conv(x)), x)[1]

# Gradient with respect to the weights and bias (∂L/∂W, ∂L/∂b).
# Differentiate with respect to the layer itself (explicit form) rather than through the
# deprecated implicit `Flux.params`; the result is a NamedTuple mirroring the layer's
# fields.
loss_fn() = sum(conv(x))
grads = gradient(layer -> sum(layer(x)), conv)[1]
grad_weights = grads.weight
grad_bias = grads.bias

│ Please see the docs for new explicit form.
│   caller = top-level scope at jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X10sZmlsZQ==.jl:13
└ @ Core c:\Users\Szymon\Desktop\AWiD\MyMlp\src\notebooks\jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X10sZmlsZQ==.jl:13


8-element Vector{Float32}:
 8192.0
 8192.0
 8192.0
 8192.0
 8192.0
 8192.0
 8192.0
 8192.0

In [58]:
# The input gradient should match Flux's directly.
println(dm[1] ≈ grad_input)
# The weight gradient was computed for the reversed kernel, so reverse it back first.
reverse(dm[2]; dims=1) ≈ grad_weights

true


true

In [64]:
# Scratch check: broadcast a 1×4 row across a 4×4×2 array, then reduce over dimension 2.
zzz = zeros(Float64, 4, 4, 2)
bbb = [1 2 3 4]
yyy = broadcast(+, zzz, bbb)
sum(yyy; dims=2)

4×1×2 Array{Float64, 3}:
[:, :, 1] =
 10.0
 10.0
 10.0
 10.0

[:, :, 2] =
 10.0
 10.0
 10.0
 10.0