In [72]:
using Random
using NNlib: conv

Random.seed!(1234);

function conv_nowe(I, K)
    H, W, C, N = size(I) 
    HH, WW, C, F = size(K)
    H_R = 1 + H - HH
    W_R = 1 + W - WW
    out = zeros(H_R, W_R, F, N)
    for n=1:N
        for depth=1:F
            @views for r=1:H_R
                for c=1:W_R
                    out[r, c, depth, n] = sum(I[r:r+HH-1, c:c+WW-1, :, n] .* K[:, :, :, depth]) 
                end
            end
        end
    end
    return out
end




function conv_noweP(I, K)
    #I = reshape(I, 28 , 28, 1, 1)
    H, W, C, N = size(I) 
    HH, WW, C, F = size(K)
    H_R = 1 + H - HH
    W_R = 1 + W - WW
    out = zeros(H_R, W_R, F, N)
    for r=1:H_R
        @views for c=1:W_R
            r_field = I[r:r+HH-1, c:c+WW-1, :, :]
            r_field_flat = reshape(r_field, HH*WW*C, N)
            K_flat = reshape(K, HH*WW*C, F)
            out[r, c, :, :] = sum(K_flat .* r_field_flat, dims = 1)
        end
    end
    return out
end

function create_kernels_nowe(kernel_height, kernel_width, n_input, n_output)
    # Inicjalizacja Xaviera
    squid = sqrt(6 / (n_input + n_output * (kernel_width * kernel_height)))
    random_vals = randn(kernel_height, kernel_width, n_input, n_output) * squid
    return random_vals
end

function create_kernels(n_input, n_output, kernel_width, kernel_height)
    # Inicjalizacja Xaviera
    squid = sqrt(6 / (n_input + n_output * (kernel_width * kernel_height)))
    random_vals = randn(n_output, n_input, kernel_width, kernel_height) * squid
    return random_vals
end


create_kernels (generic function with 1 method)

In [84]:
function maxPool(x, kernel_size)
    H, W, C, N = size(x)
    K_H = kernel_size[1]
    K_W = kernel_size[2]
    W_2 = fld(W - K_W, K_W) + 1
    H_2 = fld(H - K_H ,K_H) + 1
    out = zeros(H_2, W_2, C, N)
    for n=1:N
        for c=1:C
            for h=1:H_2
                @views for w=1:W_2
                    out[h, w, c, n] = maximum(x[K_H*(h-1)+1:K_H*h, K_W*(w-1)+1:K_W*w, c, n])
                end
            end
        end
    end
    return out
end

function maxPoolP(x)
        h, w, c, n = size(x)
        output = zeros(h ÷ 2, w ÷ 2, c, n)
        indices = CartesianIndex{3}[]
        for i = 1:c
            for j = 1:h÷2
                for k = 1:w÷2
                    val, ids = findmax(@view x[2*j-1:2*j, 2*k-1:2*k, i, :])
                    output[j, k, i, 1] = val

                    idx, idy = ids[1] + 2 * j - 1 - 1, ids[2] + 2 * k - 1 - 1
                    push!(indices, CartesianIndex(idx, idy, i))
                end
            end
        end
        output
    end

maxPoolP (generic function with 2 methods)

In [81]:
test_n = [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.21568628 0.53333336 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.6745098 0.99215686 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.07058824 0.8862745 0.99215686 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.19215687 0.07058824 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.67058825 0.99215686 0.99215686 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.11764706 0.93333334 0.85882354 0.3137255 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.09019608 0.85882354 0.99215686 0.83137256 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.14117648 0.99215686 0.99215686 0.6117647 0.05490196 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.25882354 0.99215686 0.99215686 0.5294118 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.36862746 0.99215686 0.99215686 0.41960785 0.003921569 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.09411765 0.8352941 0.99215686 0.99215686 0.5176471 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.6039216 0.99215686 0.99215686 0.99215686 0.6039216 0.54509807 0.043137256 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.44705883 0.99215686 0.99215686 0.95686275 0.0627451 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.011764706 0.6666667 0.99215686 0.99215686 0.99215686 0.99215686 0.99215686 0.74509805 0.13725491 0.0 0.0 0.0 0.0 0.0 0.15294118 0.8666667 0.99215686 0.99215686 0.52156866 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.07058824 0.99215686 0.99215686 0.99215686 0.8039216 0.3529412 0.74509805 0.99215686 0.94509804 0.31764707 0.0 0.0 0.0 0.0 0.5803922 0.99215686 0.99215686 0.7647059 0.043137256 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.07058824 0.99215686 0.99215686 0.7764706 0.043137256 0.0 0.007843138 0.27450982 0.88235295 0.9411765 0.1764706 0.0 0.0 0.18039216 0.8980392 0.99215686 0.99215686 0.3137255 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.07058824 0.99215686 0.99215686 0.7137255 0.0 0.0 0.0 0.0 0.627451 0.99215686 0.7294118 0.0627451 0.0 0.50980395 0.99215686 0.99215686 0.7764706 0.03529412 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.49411765 0.99215686 0.99215686 0.96862745 0.16862746 0.0 0.0 0.0 0.42352942 0.99215686 0.99215686 0.3647059 0.0 0.7176471 0.99215686 0.99215686 0.31764707 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.53333336 0.99215686 0.9843137 0.94509804 0.6039216 0.0 0.0 0.0 0.003921569 0.46666667 0.99215686 0.9882353 0.9764706 0.99215686 0.99215686 0.7882353 0.007843138 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.6862745 0.88235295 0.3647059 0.0 0.0 0.0 0.0 0.0 0.0 0.09803922 0.5882353 0.99215686 0.99215686 0.99215686 0.98039216 0.30588236 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.101960786 0.6745098 0.32156864 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.105882354 0.73333335 0.9764706 0.8117647 0.7137255 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.6509804 0.99215686 0.32156864 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.2509804 0.007843138 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 1.0 0.9490196 0.21960784 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.96862745 0.7647059 0.15294118 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.49803922 0.2509804 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0;;;;];


In [87]:
maxPool(conv_nowe(test_n, jj), (2,2)) == maxPoolP(conv_noweP(test_n, jj))

true

In [82]:
maxPoolP(conv_noweP(test_n, jj))

13×13×6×1 Array{Float64, 4}:
[:, :, 1, 1] =
 0.0  0.0        0.0        0.0        …  0.0       0.0        0.0       0.0
 0.0  0.0        0.0        0.0           0.0       0.332604   0.678849  0.0
 0.0  0.0        0.0937963  0.0608841     0.0       0.590355   1.18715   0.0
 0.0  0.0        0.560531   1.09198       0.127743  1.20398    1.17727   0.0
 0.0  0.0        0.753945   1.35979       0.587854  1.27881    1.02838   0.0
 0.0  0.0346391  1.11479    1.31711    …  1.0884    1.25224    0.702202  0.0
 0.0  0.0350223  1.40595    1.15527       1.25169   0.925199   0.139422  0.0
 0.0  0.264831   1.37321    1.25344       1.2836    0.60479    0.0       0.0
 0.0  0.320902   0.999215   0.694321      0.920607  0.0415122  0.0       0.0
 0.0  0.493644   0.908478   0.196702      0.619243  0.0        0.0       0.0
 0.0  0.458886   1.16948    0.107824   …  0.0       0.0        0.0       0.0
 0.0  0.0        0.571639   0.0388663     0.0       0.0        0.0       0.0
 0.0  0.0        0.0        0.0 

In [59]:
jj = create_kernels_nowe(3, 3, 1, 6);
println(sum(conv_nowe(test_n, jj)) == sum(conv_noweP(test_n, jj)))

true


In [67]:
@benchmark conv_nowe(test_n, jj)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m270.800 μs[22m[39m … [35m  5.889 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 0.00%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m307.200 μs               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m369.794 μs[22m[39m ± [32m176.409 μs[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m1.73% ± 6.09%

  [39m▆[39m█[39m▇[34m▇[39m[39m▅[39m▅[39m▄[39m▄[39m▃[32m▃[39m[39m▃[39m▃[39m▃[39m▃[39m▂[39m▂[39m▂[39m▂[39m▂[39m▂[39m▃[39m▂[39m▂[39m▂[39m▂[39m▂[39m▂[39m▁[39m▁[39m▁[39m▁[39m [39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▂
  [39m█[39m█[39m█[3

In [71]:
@benchmark conv(test_n, jj)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m 9.000 μs[22m[39m … [35m  4.929 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 0.00% … 97.53%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m25.100 μs               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m 0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m29.700 μs[22m[39m ± [32m142.014 μs[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m18.38% ±  4.11%

  [39m [39m▁[39m▆[39m▂[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▇[34m█[39m[39m▅[39m▂[39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▁[39m█[39m█[39m█

In [70]:
import NNlib: conv

In [30]:
size(test_n)

(28, 28, 1, 1)