In [1]:
using Flux
using Symbolics
using SphericalHarmonics
using CUDA
using NNlib
using OneHotArrays

using TensorCast

# Disable slow GPU indexing
CUDA.allowscalar(false)

using BenchmarkTools
using ProgressMeter
using Test

include("utils.jl")
include("Spherical.jl");

In [2]:
include("alt_parallel.jl")
include("TFNLayers.jl")

## Testing Passenger

In [3]:
# Sizes of the dummy feature tensor shared by the Passenger tests below.
n_points = 4
n_samples = 1000

# ℓi, ℓf, ℓos follow the `(ℓi, ℓf) => ℓos` naming used with CLayer further down
# (presumably input order, filter order and output orders — confirm in TFNLayers.jl).
ℓi = 2
ℓf = 1
ℓos = [1]
total_outs = map(ℓo -> 2ℓo + 1, ℓos)

# All-ones Float32 features of size (n_points, n_samples, 2ℓi + 1), passed through `gpu`.
V = gpu(ones(Float32, n_points, n_samples, 2ℓi + 1));

In [34]:
# Tests rely on the previous cell being run

# Exercises the three call forms of a `Passenger`: extra data passed as one
# tuple argument, everything packed into a single tuple, and extra data
# splatted. Each form is checked to return `V` unchanged in slot 1 and
# collections of lengths (2, 1, 0) in slot 2. Uses the global `V`.
@testset "Testing Passenger" begin
    a = ([V], [V, V], Vector{typeof(V)}(undef, 0))

    p_test = Passenger(triv_connect, SILayer(1 => 2), SILayer(2 => 1), identity) |> gpu

    # Checking that it recognises a tuple in the second argument
    result = p_test(V, a)
    @test result[1] == V
    @test result[2] .|> length == (2, 1, 0)

    # Checking that it unpacks a tuple
    result_tuple = p_test((V, a))
    @test result_tuple[1] == V
    @test result_tuple[2] .|> length == (2, 1, 0)

    # Testing that it works with a splatted second argument.
    # These assertions must inspect `result_splat`; previously they re-checked
    # `result_tuple`, so the splatted call was never actually verified.
    result_splat = p_test(V, a...)
    @test result_splat[1] == V
    @test result_splat[2] .|> length == (2, 1, 0)
end;

# With `identity` in every branch, `Parallel(triv_connect, ...)` is expected to
# return its inputs untouched, whether they are passed separately or as one tuple.
@testset "Testing triv_connect" begin
    # Three inputs mixing empty, one-element and two-element channel lists
    x1 = ([V], [V, V], typeof(V)[])
    x2 = (typeof(V)[], [V], [V])
    x3 = ([V], [V], [V])
    expected = (x1, x2, x3)

    # Testing a function that does nothing
    passthrough = gpu(Parallel(triv_connect, identity, identity, identity))

    # Inputs given as separate arguments
    @test passthrough(x1, x2, x3) == expected

    # Testing that it can unpack a tuple
    @test passthrough(expected) == expected
end;

[0m[1mTest Summary:         | [22m[32m[1mPass  [22m[39m[36m[1mTotal  [22m[39m[0m[1mTime[22m
Testing ParallelInert | [32m   6  [39m[36m    6  [39m[0m0.2s
[0m[1mTest Summary:        | [22m[32m[1mPass  [22m[39m[36m[1mTotal  [22m[39m[0m[1mTime[22m
Testing triv_connect | [32m   2  [39m[36m    2  [39m[0m0.0s


## Testing ParallelPassenger

In [23]:
# Sizes and orders for the ParallelPassenger experiments.
n_points = 4
n_samples = 1000

ℓi = 2
ℓf = 1
ℓos = [1]
total_outs = map(ℓo -> 2ℓo + 1, ℓos)

# All-ones Float32 features of size (n_points, n_samples, 2ℓi + 1), passed through `gpu`.
V = gpu(ones(Float32, n_points, n_samples, 2ℓi + 1));

# Four evenly spaced values from 0 to 3.5, materialised as a Vector.
centers = collect(range(0f0, 3.5f0; length=4))

# Random points -> pairwise_rs -> cart_to_sph; only the final result goes through `gpu`.
xs = rand(Float32, n_points, 3, n_samples)
xss = pairwise_rs(xs)
rss = gpu(cart_to_sph(xss));

In [40]:
# Inputs for the tuple-of-channels ParallelPassenger test (ℓi = 0 this time).
n_points = 4
n_samples = 1000

ℓi = 0
ℓf = 1
ℓos = [1]

# Random points run through pairwise_rs and cart_to_sph, then `gpu`.
rss = gpu(cart_to_sph(pairwise_rs(rand(Float32, n_points, 3, n_samples))))
V = gpu(ones(Float32, n_points, n_samples, 2ℓi + 1))

# Checks ParallelPassenger when the feature channels are supplied as a tuple,
# for layers given as varargs or as one tuple, and for splatted or packed input.
# Uses the globals `rss`, `V`, `centers`, `ℓi`, `ℓf` and `ℓos`.
@testset "Testing ParallelPassenger on tuple of feature array channels" begin
    # Channels are now a tuple: a = (rss, [V, V]) does not have the same effect
    channels = (V, V)
    a = (rss, channels)

    # This is what is applied within each ℓi
    c1 = CLayer((ℓi, ℓf) => ℓos, centers; ℓ_max=1)
    c2 = CLayer((ℓi, ℓf) => ℓos, centers; ℓ_max=1)

    # Layers passed as separate arguments, input passed as separate arguments
    p = gpu(ParallelPassenger(tuple_connect, c1, c2))
    out = p(rss, channels)
    @test out[1] == []
    @test length(out[2]) == length(channels)

    # Layers passed as one tuple, input passed as separate arguments
    p_alt = gpu(ParallelPassenger(tuple_connect, (c1, c2)))
    out_alt = p_alt(rss, channels)
    @test out_alt[1] == []
    @test length(out_alt[2]) == length(channels)

    # Layers passed as one tuple, input passed as a single tuple
    out_tuple = p_alt(a)
    @test out_tuple[1] == []
    @test length(out_tuple[2]) == length(channels)
end;

[0m[1mTest Summary:                                                 | [22m[32m[1mPass  [22m[39m[36m[1mTotal  [22m[39m[0m[1mTime[22m
Testing ParallelPassenger on vector of feature array channels | [32m   6  [39m[36m    6  [39m[0m3.6s


## Testing E3ConvLayer

In [6]:
include("alt_parallel.jl")
include("TFNLayers.jl")

In [7]:
# Inputs and layer for the E3ConvLayer tests.
n_points = 4
n_samples = 100

ℓi = 1
ℓf = 1
ℓos = [0, 1]

centers = collect(range(0f0, 3.5f0; length=4))
rss = gpu(cart_to_sph(pairwise_rs(rand(Float32, n_points, 3, n_samples))))
V = gpu(ones(Float32, n_points, n_samples, 2ℓi + 1))

# Feature slots: an empty channel list first, then a single channel holding V.
input = (rss, (typeof(V)[], [V]))
# Number of channels in each slot, as a Vector.
n_cs = collect(map(length, input[2]))

# No pairings for the empty slot; one `(ℓi, ℓf) => ℓos` pairing for the second.
e3_test = gpu(E3ConvLayer(n_cs, [[], [(ℓi, ℓf) => ℓos]], centers))

E3ConvLayer{ParallelPassenger{typeof(tuple_connect), Tuple{var"#152#155", ParallelPassenger{typeof(tuple_connect), Tuple{ParallelPassenger{typeof(tuple_connect), Tuple{CLayer}}}}}}, Vector{Vector{Any}}}(ParallelPassenger(tuple_connect, #152, ParallelPassenger(tuple_connect, ParallelPassenger(tuple_connect, CLayer(FLayer{RLayer{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Vector{Float32}, Float32}, Vector{Function}}(RLayer{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Vector{Float32}, Float32}(Float32[0.7390269, 0.9731275, -0.36122704, 0.6229573], Float32[0.0, 1.1666666, 2.3333333, 3.5], 0.875f0), Function[var"#159#160"(), var"#161#162"(), var"#163#164"()], 1), Dict{Tuple{Int64, Int64}, CuArray{Float32}}((0, 0) => [0.0 0.0 0.57735026; 0.0 -0.57735026 0.0; 0.57735026 0.0 0.0], (1, -1) => [0.0 -0.70710677 0.0; 0.70710677 0.0 0.0; 0.0 0.0 0.0], (1, 1) => [0.0 0.0 0.0; 0.0 0.0 -0.70710677; 0.0 0.70710677 0.0], (1, 0) => [0.0 0.0 -0.70710677; 0.0 0.0 0.0; 0.70710677 0.0 0.0]), 1, 1, [0, 1], Dic

In [59]:
e3_test.Cs.layers

(var"#316#319"(Core.Box(1)), ParallelPassenger(tuple_connect, ParallelPassenger(tuple_connect, CLayer(FLayer{RLayer{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Vector{Float32}, Float32}, Vector{Function}}(RLayer{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Vector{Float32}, Float32}(Float32[-0.5436605, 1.0303081, 0.5202171, 0.9678404], Float32[0.0, 1.1666666, 2.3333333, 3.5], 0.875f0), Function[var"#323#324"(), var"#325#326"(), var"#327#328"()], 1), Dict{Tuple{Int64, Int64}, CuArray{Float32}}((0, 0) => [0.0 0.0 0.57735026; 0.0 -0.57735026 0.0; 0.57735026 0.0 0.0], (1, -1) => [0.0 -0.70710677 0.0; 0.70710677 0.0 0.0; 0.0 0.0 0.0], (1, 1) => [0.0 0.0 0.0; 0.0 0.0 -0.70710677; 0.0 0.70710677 0.0], (1, 0) => [0.0 0.0 -0.70710677; 0.0 0.0 0.0; 0.70710677 0.0 0.0]), 1, 1, [0, 1], Dict(0 => 1, 1 => 2), 1))))

In [4]:
# Probes the two branches stored in `e3_test.Cs.layers` one at a time.
# Uses the globals `e3_test`, `input`, `rss` and `V`.
@testset "Testing individual layers" begin
    branches = e3_test.Cs.layers
    id = branches[1]
    rest = branches[2]

    # If n_c == 0, should do nothing, but remain consistent with other layers
    @test id(input[1], input[2][1]) == ([], [])

    # Innermost layer of the second branch, applied directly to one channel
    conv = rest.layers[1].layers[1]
    rest_out = conv(rss, V)

    # Testing there are ℓ = 0, 1 present
    @test length(rest_out) == 2

    # Test below assumes ℓo = 0 and ℓo = 1 output
    # Testing for a tuple of vectors
    @test rest_out isa Tuple{Vector{typeof(V)}, Vector{typeof(V)}}
end

[0m[1mTest Summary:             | [22m[32m[1mPass  [22m[39m[36m[1mTotal  [22m[39m[0m[1m Time[22m
Testing individual layers | [32m   3  [39m[36m    3  [39m

Test.DefaultTestSet("Testing individual layers", Any[], 3, false, false, true, 1.673529099025e9, 1.673529116791e9)

[0m17.8s


In [17]:
# Run the full layer and show the first array of the first output entry.
first(first(e3_test(input)))

4×100×1 CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}:
[:, :, 1] =
  0.743109  -0.347093  -0.540495  …  -0.751455  -0.119234   0.0219167
 -1.04766   -0.537744  -1.11924      -1.24692   -0.963858   0.0136303
 -1.65777    0.390556  -0.591949      1.25063    0.431188  -0.569
  0.79816   -0.669887   1.08752      -0.416425  -0.512264  -0.630715

In [12]:
# Second branch of the E3 layer, and the first layer nested inside it.
rest = e3_test.Cs.layers[2]
c_int = first(rest.layers)

ParallelPassenger(tuple_connect, CLayer(FLayer{RLayer{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Vector{Float32}, Float32}, Vector{Function}}(RLayer{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Vector{Float32}, Float32}(Float32[0.7390269, 0.9731275, -0.36122704, 0.6229573], Float32[0.0, 1.1666666, 2.3333333, 3.5], 0.875f0), Function[var"#159#160"(), var"#161#162"(), var"#163#164"()], 1), Dict{Tuple{Int64, Int64}, CuArray{Float32}}((0, 0) => [0.0 0.0 0.57735026; 0.0 -0.57735026 0.0; 0.57735026 0.0 0.0], (1, -1) => [0.0 -0.70710677 0.0; 0.70710677 0.0 0.0; 0.0 0.0 0.0], (1, 1) => [0.0 0.0 0.0; 0.0 0.0 -0.70710677; 0.0 0.70710677 0.0], (1, 0) => [0.0 0.0 -0.70710677; 0.0 0.0 0.0; 0.70710677 0.0 0.0]), 1, 1, [0, 1], Dict(0 => 1, 1 => 2), 1))

In [15]:
# Channel list from the second feature slot, plus the types involved.
single_V_vec = input[2][2]
println(typeof(single_V_vec))
println(typeof(c_int))

# Apply each layer of `c_int` by hand to its matching channel.
outthing = map(c_int.layers, Tuple(single_V_vec)) do f, x
    f(rss, x)
end
println(length(outthing[1])) # Should be a tuple of two singleton lists of Vs

tuple_connect(outthing...)[2] # Second entry of the combined result (displayed as the cell output)
#c_int(rss, single_V_vec) # Currently not working properly

Vector{CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}}
ParallelPassenger{typeof(tuple_connect), Tuple{CLayer}}
2


1-element Vector{CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}}:
 [0.71489376 -1.0920382 … 0.49052998 0.2931327; -0.85966456 -0.36603582 … -1.4648001 -0.30815986; -1.4722619 -0.12876242 … 0.4759482 -0.31176573; 0.1912238 0.16102761 … -0.92748654 -1.0990156;;; -0.15063535 -1.6710885 … 0.71131104 0.7803197; 0.6748457 0.117825046 … -1.4526098 -1.0612049; -0.2672376 -0.08193672 … 0.6688654 1.2509469; -0.2569728 1.6352 … 0.07243338 -0.9700614;;; -0.8655291 -0.5790503 … 0.22078106 0.48718697; 1.5345103 0.48386088 … 0.012190431 -0.7530451; 1.2050242 0.046825707 … 0.19291717 1.5627127; -0.4481966 1.4741724 … 0.9999199 0.1289542]

In [103]:
# Show the raw output of the first channel layer.
first(outthing)

(CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}[[1.3108501 0.6637541 … 0.09197149 -0.13893569; -0.15591794 -0.6770607 … 0.75574577 0.06987971; -1.0228273 0.61436653 … 0.088679954 0.39364493; 0.12298012 -0.34597456 … -0.6813121 -0.06950386;;;]], CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}[[1.2527649 0.79706514 … 0.1340313 -0.0129568875; -0.47993696 -0.45472288 … 0.30775726 -0.42566007; -0.6563413 0.3466434 … 0.38431752 0.32349795; 0.19592753 -0.3765713 … -0.51369166 0.42753315;;; 0.2003344 0.25427675 … 0.44121498 0.5977051; -0.44201413 0.04989609 … -1.0102649 -0.99718624; 0.22540557 -0.08119163 … 0.60818213 -0.3242905; 0.016274126 -0.22298117 … -0.039132148 0.7237715;;; -1.0524306 -0.5427884 … 0.30718368 0.61066204; 0.03792283 0.504619 … -1.318022 -0.57152617; 0.8817469 -0.42783502 … 0.22386461 -0.6477884; -0.1796534 0.15359013 … 0.47455955 0.29623836]])

## Debugging E3ConvLayer

In [16]:
# Rebuilds by hand the nested ParallelPassenger structure, to debug E3ConvLayer.
# One entry of `pairingss` per input order; each entry lists `ℓf => ℓos` pairings.
pairingss = [[0 => [0], 1 => [1]]]

# For every (ℓi, n_c, pairings) triple and every `ℓf => ℓos` pairing, build one
# ParallelPassenger wrapping `n_c` independent CLayers (one per channel).
#
# Written as a comprehension so that no temporaries are assigned at top level:
# in a notebook, assigning `p` or `channel_convs` inside a top-level `for` loop
# overwrites any global of the same name (soft scope).
# The element type stays `Any`, matching the original `[]` container.
#
# NOTE(review): `ℓis` is not defined in any visible cell — confirm it is set
# (e.g. by an earlier cell or an included file) before running this.
in_channel = Any[
    ParallelPassenger(
        tuple_connect,
        Tuple(CLayer((ℓi, ℓf) => ℓos, centers) for _ in 1:n_c),
    )
    for (ℓi, n_c, pairings) in zip(ℓis, n_cs, pairingss)
    for (ℓf, ℓos) in pairings
]

c_test = ParallelPassenger(tuple_connect, in_channel) |> gpu

#E3ConvLayer(n_cs, pairingss, centers) |> gpu

ParallelPassenger(tuple_connect, [ParallelPassenger(tuple_connect, CLayer(FLayer{RLayer{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Vector{Float32}, Float32}, Vector{var"#169#170"}}(RLayer{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Vector{Float32}, Float32}(Float32[-0.4090891, 0.78547686, -0.21283774, 0.26003954], Float32[0.0, 1.1666666, 2.3333333, 3.5], 0.875f0), [var"#169#170"()], 0), Dict{Tuple{Int64, Int64}, CuArray{Float32}}((0, 0) => [1.0;;]), 0, 0, [0], Dict(0 => 1), 0)), ParallelPassenger(tuple_connect, CLayer(FLayer{RLayer{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Vector{Float32}, Float32}, Vector{Function}}(RLayer{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Vector{Float32}, Float32}(Float32[-0.691431, -0.6132594, 1.021875, -0.51551455], Float32[0.0, 1.1666666, 2.3333333, 3.5], 0.875f0), Function[var"#171#172"(), var"#173#174"(), var"#175#176"()], 1), Dict{Tuple{Int64, Int64}, CuArray{Float32}}((1, -1) => [1.0 0.0 0.0], (1, 1) => [0.0 0.0 1.0], (1, 0) => [0.0 1.0 0.0]), 0, 

In [13]:
include("alt_parallel.jl")

In [39]:
# `rss` together with a one-element list holding V.
a = (rss, [V])

# Two CLayers, `(0, 0) => [0]` and `(0, 1) => [1]`, joined by tuple_connect.
p = gpu(
    ParallelPassenger(
        tuple_connect,
        CLayer((0, 0) => [0], centers; ℓ_max=1),
        CLayer((0, 1) => [1], centers; ℓ_max=1),
    ),
)

out = p(a...)
#z1, z2 = out[1], out[2]
first(out)

1-element Vector{CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}}:
 [-6.1665177 -6.6280146 … -6.544309 -6.3328896; -6.2592435 -6.2963605 … -6.5402913 -6.3172874; -5.966563 -6.3983607 … -6.4238453 -6.616637; -6.477437 -6.4340253 … -6.006059 -6.05404;;;]

In [77]:
#p_classic = Parallel(tuple_connect, CLayer((0, 0) => [0], centers), CLayer((0, 1) => [1], centers)) |> gpu

(CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}[[-4.307273 -4.008114 … -4.895726 -4.431291; -4.063853 -4.0651155 … -4.7924385 -4.7163777; -4.0703936 -3.892105 … -4.4321647 -4.6796117; -4.0994596 -4.202361 … -4.723231 -4.412351;;;]],)