In [19]:
using Flux
using Symbolics
using SphericalHarmonics
using CUDA
using NNlib
using OneHotArrays
using MLUtils

using Rotations
using TensorCast

# Disable slow GPU indexing
CUDA.allowscalar(false)

using BenchmarkTools
using ProgressMeter

include("utils.jl")
include("Spherical.jl")
include("alt_parallel.jl")

In [29]:
include("TFNLayers.jl")
include("alt_parallel.jl")

## Testing Basics of Network

In [5]:
# testing R
centers = range(0f0, 3.5f0; length=4) |> collect
r_test = RLayer(centers) |> gpu

RLayer{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Vector{Float32}, Float32}(Float32[-0.48342064, -0.5924541, -0.46601388, -0.4184361], Float32[0.0, 1.1666666, 2.3333333, 3.5], 0.875f0)

In [87]:
n_samples = 1000
n_points = 4
ℓi, ℓf, ℓos = 2, 1, [1]
total_outs = [2ℓo + 1 for ℓo in ℓos]

centers = range(0f0, 3.5f0; length=4) |> collect
c_test = CLayer((ℓi, ℓf) => ℓos, centers) |> gpu

xs = rand(Float32, (n_points, 3, n_samples))
xss = pairwise_rs(xs)
rss = cart_to_sph(xss) |> gpu

V = ones(Float32, (n_points, n_samples, 2ℓi+1)) |> gpu
#V = ones(Float32, (1, 4))
outs_fake = rand(Float32, (n_points, n_samples, total_outs[1])) |> gpu

C_out = c_test(rss, V)

(CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}[], CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}[[-0.36121196 -0.021606624 … 0.20287748 0.41855606; 0.036829777 0.3516178 … -0.5751697 -0.262716; 0.041203 0.24422988 … -0.07556881 -0.031145327; 0.08022238 -0.7771977 … 0.24490415 -0.32765147;;; -0.34861088 0.035238333 … 0.08414009 0.3777876; 0.06861416 0.43490127 … -0.46269113 -0.11485022; -0.07405516 0.116280064 … -0.15479514 -0.1404848; 0.119697616 -0.8207739 … 0.29899177 -0.35680676;;; -0.3369673 0.08685716 … 0.04711111 0.26186392; 0.09838154 0.47218075 … -0.26507103 0.06836632; -0.2230749 -0.029416293 … -0.2969591 -0.24748111; 0.25870395 -0.7325784 … 0.31196216 -0.28570586]])

In [18]:
optim = Flux.setup(Flux.Adam(0.1), c_test)

losses = []
@showprogress for epoch in 1:400
    loss, grads = Flux.withgradient(c_test) do f
        # Evaluate model and loss inside gradient context:
        y_hat = f(rss, V)
        Flux.mse(y_hat[2][1], outs_fake)
    end
    Flux.update!(optim, c_test, grads[1])
    push!(losses, loss)
end

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:43[39m


In [119]:
a = (V, ([V], [V, V], []))

p_test = Chain(
            # Self interaction
            Parallel(triv_connect, identity, Parallel(triv_connect, SILayer(1 => 2), SILayer(2 => 3), identity)),
            # Nonlinearity
            Parallel(triv_connect, NLLayer(2), NLLayer(3), identity)
 ) |> gpu

#ou = p_test(a)[1][1]
optim = Flux.setup(Flux.Adam(0.1), p_test)

losses = []
    @showprogress for epoch in 1:400
    loss, grads = Flux.withgradient(p_test) do f
        # Evaluate model and loss inside gradient context:
        y_hat = f(a)[1][1]
        Flux.mse(y_hat, V)
    end;
    Flux.update!(optim, p_test, grads[1])
    push!(losses, loss)
end


#nl_test = NLLayer(1) |> gpu
#nl_test([V]) |> typeof

#combined = Tuple(vcat(x, y) for (x, y) in zip($a, $b))
#reduce((x, y) -> (x..., y...), (a, b, c)) |> length

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:01[39m


In [84]:
a = (V, ([V], [V, V], []))

p_passenger = Chain(
            # Self interaction
            Parallel(triv_connect, identity, Parallel(triv_connect, SILayer(1 => 2), SILayer(2 => 3), identity))
) |> gpu


p_passenger(a)

([1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0;;; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0;;; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0;;; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0;;; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0], (CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}[[0.58756787 0.58756787 … 0.58756787 0.58756787; 0.58756787 0.58756787 … 0.58756787 0.58756787; 0.58756787 0.58756787 … 0.58756787 0.58756787; 0.58756787 0.58756787 … 0.58756787 0.58756787;;; 0.58756787 0.58756787 … 0.58756787 0.58756787; 0.58756787 0.58756787 … 0.58756787 0.58756787; 0.58756787 0.58756787 … 0.58756787 0.58756787; 0.58756787 0.58756787 … 0.58756787 0.58756787;;; 0.58756787 0.58756787 … 0.58756787 0.58756787; 0.58756787 0.58756787 … 0.58756787 0.58756787; 0.58756787 0.58756787 … 0.58756787 0.58756787; 0.58756787 0.58756787 … 0.58756787 0.587567

In [29]:
Vector{eltype([V])}(undef, 0)

CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}[]

In [41]:
a = ([V], [V, V], Vector{typeof(V)}(undef, 0))
b = (Vector{typeof(V)}(undef, 0), [V], [V])
c = ([V], [V], [V])

#@btime Tuple(vcat(x, y) for (x, y) in zip($a, $b)) # Pretty fast

function tuple_connect(xa, ya)
    Tuple(vcat(x, y) for (x, y) in zip(xa, ya))
end

answer = reduce(tuple_connect, (a, b, c))

answer[2]

4-element Vector{CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}}:
 [1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0;;; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0;;; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0;;; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0;;; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0]
 [1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0;;; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0;;; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0;;; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0;;; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0]
 [1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0;;; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0;;; 

In [29]:
p = Parallel(hcat, Dense(1 => 4), Dense(1 => 4)) |> gpu
d_test = Dense(1 => 4) |> gpu

p_compat = (x -> p(x...))

test_mat = ones(Float32, (1, 3)) |> gpu

c_test = Chain(p_compat, Dense(4 => 1)) |> gpu

p([test_mat, test_mat]...)

4×6 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
  0.298705   0.298705   0.298705  -0.331017   -0.331017   -0.331017
  0.192269   0.192269   0.192269  -0.0233604  -0.0233604  -0.0233604
 -0.479108  -0.479108  -0.479108  -0.0682962  -0.0682962  -0.0682962
  1.02654    1.02654    1.02654    0.583121    0.583121    0.583121

In [70]:
ℓi = 0
ℓ_max = 1
in_channel = []
n_c = 2
for (ℓf, ℓos) in [1 => [1]]
    channel_convs = Tuple(CLayer((ℓi, ℓf) => ℓos, centers; ℓ_max = ℓ_max) for c in 1:n_c)

    p = Parallel(tuple_connect, channel_convs)
    push!(in_channel, Chain(x -> Tuple(x), p))
    #push!(in_channel, p)
end

#in_channels[1]
p_out = Parallel(tuple_connect, in_channel...) |> gpu

Parallel(
  tuple_connect,
  Chain(
    var"#534#536"(),
    Parallel(
      tuple_connect,
      CLayer(
        FLayer(
          RLayer{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Vector{Float32}, Float32}(Float32[-0.6948973, 0.8630991, -0.07906914, -0.94141114], Float32[0.0, 1.1666666, 2.3333333, 3.5], 0.875f0),  [90m# 4 parameters[39m
        ),
      ),
      CLayer(
        FLayer(
          RLayer{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Vector{Float32}, Float32}(Float32[-0.38486096, 0.39335984, 0.77282935, -0.76927626], Float32[0.0, 1.1666666, 2.3333333, 3.5], 0.875f0),  [90m# 4 parameters[39m
        ),
      ),
    ),
  ),
) 

In [20]:
rss_vec = [rss]
V_vec = [V]

rss_tup = (rss,)
V_tup = (V,)

@btime vcat($rss_vec, $V_vec)
@btime (rss_tup..., V_tup)
@btime (rss_vec..., V_vec...)

  18.337 ns (1 allocation: 64 bytes)
  79.959 ns (3 allocations: 64 bytes)
  48.077 ns (1 allocation: 32 bytes)


([0.0 0.43549958 1.0521978 0.56605774; 0.43549958 0.0 1.3980144 0.9712511; 1.0521978 1.3980144 0.0 0.8892459; 0.56605774 0.9712511 0.8892459 0.0;;; 0.0 2.3328195 1.0250881 0.37029645; 0.8087733 0.0 0.92014986 0.5493338; 2.1165047 2.2214427 0.0 1.5915178; 2.7712963 2.592259 1.5500748 0.0;;; 0.0 1.4255625 -0.7652147 -2.1712863; -1.7160301 0.0 -0.99782175 -1.8945957; 2.376378 2.143771 0.0 2.605678; 0.97030634 1.2469969 -0.53591454 0.0;;;; 0.0 0.7547935 0.6804147 0.20409416; 0.7547935 0.0 0.28264362 0.59508324; 0.6804147 0.28264362 0.0 0.5821993; 0.20409416 0.59508324 0.5821993 0.0;;; 0.0 1.5146486 1.5841296 1.1584036; 1.6269442 0.0 1.753775 1.504465; 1.557463 1.3878177 0.0 1.4140692; 1.9831891 1.6371276 1.7275236 0.0;;; 0.0 0.7300529 0.35341373 1.2074509; -2.4115398 0.0 -1.2909464 -2.5567398; -2.788179 1.8506463 0.0 -3.0358489; -1.9341418 0.584853 0.10574378 0.0;;;; 0.0 0.23809826 0.75272274 0.4063107; 0.23809826 0.0 0.8707703 0.52639395; 0.75272274 0.8707703 0.0 0.3507919; 0.4063107 0.52

### Testing Increasing Dimensionality

In [3]:
n_samples = 6
n_points = 4

xs = rand(Float32, (n_points, 3, n_samples))
xss = pairwise_rs(xs)
rss = cart_to_sph(xss)

ℓ = 2
Ys = generate_Yℓms(ℓ)
yss = rand(Float32, (n_points, n_points, n_samples, 2ℓ+1))

rss_gpu = rss |> gpu
yss_gpu = yss |> gpu

centers = range(0f0, 3.5f0; length=4) |> collect
f_test = FLayer(Ys, centers) |> gpu

f_test(rss_gpu)

4×4×6×5 CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}:
[:, :, 1, 1] =
 -0.0          0.00754981  -0.0499345  -0.116065
  0.00754981  -0.0         -0.213974   -0.329704
 -0.0499345   -0.213974    -0.0         0.545569
 -0.116065    -0.329704     0.545569   -0.0

[:, :, 2, 1] =
 -0.0        0.232802   0.160161    0.262934
  0.232802  -0.0       -0.229806    0.456423
  0.160161  -0.229806  -0.0         0.00174865
  0.262934   0.456423   0.0017487  -0.0

[:, :, 3, 1] =
 -0.0        -0.370857   0.226216    -0.0809055
 -0.370857   -0.0       -0.503037    -0.22966
  0.226216   -0.503037  -0.0          0.00735487
 -0.0809054  -0.22966    0.00735485  -0.0

[:, :, 4, 1] =
 -0.0       -0.11373     0.145963   0.258541
 -0.11373   -0.0         0.113866  -0.0867647
  0.145963   0.113866   -0.0        0.141403
  0.258541  -0.0867647   0.141403  -0.0

[:, :, 5, 1] =
 -0.0        0.166011  -0.113239   0.15408
  0.166011  -0.0        0.425438  -0.265191
 -0.113239   0.425438  -0.0        0.450372
  0.15408

In [4]:
optim = Flux.setup(Flux.Adam(0.01), f_test)

# Testing gradient
losses = []
@showprogress for epoch in 1:400
    loss, grads = Flux.withgradient(f_test) do f
        # Evaluate model and loss inside gradient context:
        y_hat = f(rss_gpu)
        Flux.mse(y_hat, yss_gpu)
    end
    Flux.update!(optim, f_test, grads[1])
    push!(losses, loss)
end

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:36[39m


## Testing Composition of Blocks

In [56]:
centers = range(0f0, 3.5f0; length=4) |> collect

n_samples = 100
n_points = 4

rss = rand(Float32, (n_points, 3, n_samples)) |> pairwise_rs |> cart_to_sph |> gpu
V0 = ones(Float32, (n_points, n_samples, 1)) |> gpu
Vrand = rand(Float32, (n_points, n_samples, 1)) |> gpu;

In [57]:
# Try implementing first parts of TFN Process
p_chain = Chain(
            Passenger(triv_connect, SILayer(1 => 4)),
            E3ConvLayer([4], [[(0, 0) => [0], (0, 1) => [1]]], centers),
            Passenger(triv_connect, SILayer(4 => 4), SILayer(4 => 4)),
            Passenger(triv_connect, NLLayer(4), NLLayer(4)),
            E3ConvLayer([4, 4], [[(0, 0) => [0], (0, 1) => [1]], [(1, 0) => [1], (1, 1) => [0, 1]]], centers),
            Passenger(triv_connect, SILayer(8 => 4), SILayer(12 => 4)),
            Passenger(triv_connect, NLLayer(4), NLLayer(4)),
            E3ConvLayer([4, 4], [[(0, 0) => [0]], [(1, 1) => [0]]], centers),
            Passenger(triv_connect, SILayer(8 => 4)),
            Passenger(triv_connect, NLLayer(4)),
) |> gpu;

In [58]:
input = (rss, ([V0],))
input2 = (rss, ([Vrand],))

p_chain(input)

([0.0 1.1618803 0.09173866 0.6294662; 1.1618803 0.0 1.1901708 0.7764151; 0.09173866 1.1901708 0.0 0.6634445; 0.6294662 0.7764151 0.6634445 0.0;;; 0.0 0.8811243 1.8064804 0.21447161; 2.2604682 0.0 2.2642 1.7315067; 1.3351122 0.87739277 0.0 0.28615603; 2.9271212 1.4100859 2.8554366 0.0;;; 0.0 -0.14347813 -1.8804499 0.08385577; 2.9981146 0.0 3.0943747 2.9587073; 1.2611427 -0.04721811 0.0 0.53937507; -3.0577369 -0.1828853 -2.6022177 0.0;;;; 0.0 0.8503873 0.5912107 1.0756193; 0.8503873 0.0 0.5026764 0.808194; 0.5912107 0.5026764 0.0 0.5323414; 1.0756193 0.808194 0.5323414 0.0;;; 0.0 2.745924 2.3553429 2.0825462; 0.39566872 0.0 0.75236124 1.2459434; 0.7862499 2.3892314 0.0 1.7770909; 1.0590465 1.8956492 1.3645017 0.0;;; 0.0 2.8689046 1.9412713 2.0071938; -0.272688 0.0 1.0725368 1.6764148; -1.2003213 -2.0690558 0.0 2.0601158; -1.1343988 -1.4651778 -1.0814768 0.0;;;; 0.0 0.52027565 0.52100843 0.45821744; 0.52027565 0.0 0.67533183 0.6109644; 0.52100843 0.67533183 0.0 0.12161232; 0.45821744 0.61

## Trying Shape Classification

In [30]:
# Define Tetris Shapes
tetris = [[0 0 0; 0 0 1; 1 0 0; 1 1 0],  # chiral_shape_1
          [0 0 0; 0 0 1; 1 0 0; 1 -1 0], # chiral_shape_2
          [0 0 0; 1 0 0; 0 1 0; 1 1 0],  # square
          [0 0 0; 0 0 1; 0 0 2; 0 0 3],  # line
          [0 0 0; 0 0 1; 0 1 0; 1 0 0],  # corner
          [0 0 0; 0 0 1; 0 0 2; 0 1 0],  # T
          [0 0 0; 0 0 1; 0 0 2; 0 1 1],  # zigzag
          [0 0 0; 1 0 0; 1 1 0; 2 1 0]]  # L

tetris = convert.(Array{Float32, 2}, tetris)
tetris_batched = batch(tetris)
onehot_tetris = onehotbatch(1:length(tetris) |> collect, 1:length(tetris)) |> gpu
Vones = ones(Float32, (size(tetris_batched, 1), size(tetris_batched, 3), 1)) |> gpu;

In [31]:
sph_tetris = tetris_batched |> pairwise_rs |> cart_to_sph |> gpu;
centers = range(0f0, 3.5f0; length=4) |> collect

classifier = Chain(
            Passenger(triv_connect, SILayer(1 => 4)),
            E3ConvLayer([4], [[(0, 0) => [0], (0, 1) => [1]]], centers),
            Passenger(triv_connect, SILayer(4 => 4), SILayer(4 => 4)),
            Passenger(triv_connect, NLLayer(4), NLLayer(4)),
            E3ConvLayer([4, 4], [[(0, 0) => [0], (0, 1) => [1]], [(1, 0) => [1], (1, 1) => [0, 1]]], centers),
            Passenger(triv_connect, SILayer(8 => 4), SILayer(12 => 4)),
            Passenger(triv_connect, NLLayer(4), NLLayer(4)),
            E3ConvLayer([4, 4], [[(0, 0) => [0]], [(1, 1) => [0]]], centers),
            Passenger(triv_connect, SILayer(8 => 4)),
            Passenger(triv_connect, NLLayer(4)),
            PLayer(),
            Dense(4 => 8)
) |> gpu;

In [32]:
classifier((sph_tetris, ([Vones],)))

8×8 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
  0.0752553    0.0751255    0.0116361   …   0.0880927   0.0915387   0.0746019
  0.108796     0.108508     0.0222652       0.139159    0.132958    0.111135
  0.131753     0.131424     0.0253054       0.13741     0.151866    0.121334
  0.00618636   0.00623168  -0.00100601      0.0234849   0.0149541   0.0144971
  0.0307177    0.0306595    0.00672953      0.0646946   0.0473721   0.0433133
 -0.0859591   -0.0856564   -0.0253435   …  -0.0618561  -0.0921236  -0.0667383
 -0.0809994   -0.08076     -0.0193537      -0.0519535  -0.0842926  -0.0603588
  0.146391     0.146048     0.0281436       0.128569    0.163379    0.12515

In [33]:
optim = Flux.setup(Flux.Adam(0.01), classifier)

# Testing gradient
losses = []
@showprogress for epoch in 1:600
    loss, grads = Flux.withgradient(classifier) do c
        # Evaluate model and loss inside gradient context:
        y_hat = c((sph_tetris, ([Vones],)))
        Flux.logitcrossentropy(y_hat, onehot_tetris)
    end
    Flux.update!(optim, classifier, grads[1])
    push!(losses, loss)
end

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:01:58[39m


In [46]:
NNlib.softmax(classifier((sph_tetris, ([Vones],))))

8×8 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
 0.999727     0.000243681  4.11271f-5   …  6.32147f-30  1.08455f-36
 0.000273181  0.999726     1.21361f-13     2.64894f-10  1.13664f-17
 4.73231f-11  1.83895f-17  0.999942        0.0          0.0
 0.0          0.0          0.0             0.0          1.22686f-23
 4.57479f-11  1.34753f-12  1.73029f-5      0.0          0.0
 0.0          0.0          0.0          …  2.6604f-22   6.72408f-10
 1.46379f-18  3.07302f-5   8.95438f-26     0.999981     7.45317f-7
 1.57f-43     1.78257f-39  0.0             1.91691f-5   0.999999

In [37]:
rotated_tetris = [(rand(RotMatrix{3, Float32}) * tet')' for tet in tetris]
rotated_tetris_batched = batch(rotated_tetris)
sph_rotated = rotated_tetris_batched |> pairwise_rs |> cart_to_sph |> gpu;

In [47]:
final = NNlib.softmax(classifier((sph_rotated, ([Vones],))))

8×8 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
 0.999727     0.000243681  4.1128f-5    …  6.3208f-30   1.08438f-36
 0.000273179  0.999726     1.21365f-13     2.64878f-10  1.13657f-17
 4.73249f-11  1.83901f-17  0.999942        0.0          0.0
 0.0          0.0          0.0             0.0          1.22683f-23
 4.57514f-11  1.3476f-12   1.73037f-5      0.0          0.0
 0.0          0.0          0.0          …  2.66045f-22  6.72405f-10
 1.46374f-18  3.07297f-5   8.95442f-26     0.999981     7.45306f-7
 1.57f-43     1.78243f-39  0.0             1.91688f-5   0.999999

In [44]:
OneHotArrays.onecold(final |> cpu)

8-element Vector{Int64}:
 1
 2
 3
 4
 5
 6
 7
 8