In [1]:
using Flux
using Symbolics
using SphericalHarmonics
using CUDA
using NNlib

using TensorCast

# Disallow scalar indexing of GPU arrays: any accidental element-by-element
# access to a CuArray now throws an error instead of silently running as a
# very slow per-element transfer, so non-vectorized GPU code is caught early.
CUDA.allowscalar(false)

using BenchmarkTools
using ProgressMeter

# Project-local helpers.
# NOTE(review): presumably these provide `pairwise_rs`, `cart_to_sph` and
# `generate_Yℓms`, which are used below but not defined here — confirm.
include("utils.jl")
include("Spherical.jl");

### Defining Data

In [4]:
# Define Tetris Shapes
#
# Each entry is a 4×3 Int matrix: one row per unit cube, three integer
# lattice coordinates per row. The two chiral shapes are mirror images of
# each other: they differ only in the sign of the middle coordinate of the
# last row, and every other row has that coordinate equal to 0.
#
# For the three "bar + one cube" shapes below, the bar is the first three
# rows (along the third coordinate). The label depends on where the fourth
# cube attaches: at an end of the bar -> L, at its middle -> T.
tetris = [[0 0 0; 0 0 1; 1 0 0; 1 1 0],  # chiral_shape_1
          [0 0 0; 0 0 1; 1 0 0; 1 -1 0], # chiral_shape_2 (mirror of chiral_shape_1)
          [0 0 0; 1 0 0; 0 1 0; 1 1 0],  # square
          [0 0 0; 0 0 1; 0 0 2; 0 0 3],  # line
          [0 0 0; 0 0 1; 0 1 0; 1 0 0],  # corner
          [0 0 0; 0 0 1; 0 0 2; 0 1 0],  # L  (extra cube at an end of the bar)
          [0 0 0; 0 0 1; 0 0 2; 0 1 1],  # T  (extra cube at the middle of the bar)
          [0 0 0; 1 0 0; 1 1 0; 2 1 0]]  # zigzag (S-shape)

4×3 Matrix{Int64}:
 0  0  0
 0  0  1
 1  0  0
 1  1  0

### Testing Network Definition

In [2]:
include("TFNLayers.jl")

In [9]:
# Toy problem size: `n_samples` random clouds of `n_points` points each.
n_samples = 100
n_points = 4

# Rotation orders passed to CLayer.
# NOTE(review): presumably input order, filter order, and the list of
# requested output orders respectively — confirm against TFNLayers.jl.
ℓi = 0
ℓf = 1
ℓos = [1]

# Four evenly spaced centers on [0, 3.5], presumably for the radial part of
# the filter network (TODO confirm). The layer is moved to the GPU.
centers = collect(range(0f0, 3.5f0; length=4))
c_test = gpu(CLayer(ℓi, ℓf, ℓos, centers))

# Random point clouds -> pairwise quantities -> spherical form, then to GPU.
# The RNG is consumed here after CLayer's init and before `outs_fake`.
xs = rand(Float32, (n_points, 3, n_samples))
xss = pairwise_rs(xs)
rss = gpu(cart_to_sph(xss))

# Constant input features with a leading dimension of 1
# (presumably 2ℓi + 1 = 1 component for ℓi = 0 — confirm).
V = gpu(ones(Float32, (1, n_points, n_samples)))

# Random targets; only used to check that the layer can be trained.
outs_fake = gpu(rand(Float32, (n_points, n_samples)))

CLayer(FLayer(Chain(Dense(1 => 5, relu), Dense(5 => 5, relu), Dense(5 => 1, relu)), Function[var"#46#47"(), var"#48#49"(), var"#50#51"()], 1), Dict{Tuple{Int64, Int64}, CuArray{Float32}}((1, -1) => [1.0 0.0 0.0], (1, 1) => [0.0 0.0 1.0], (1, 0) => [0.0 1.0 0.0]), 0, 1, [(1, -1), (1, 0), (1, 1)])

In [14]:
# Fit `c_test` to the random targets `outs_fake` with Adam. The loop exercises
# the forward pass, the gradient computation and the optimizer update on the
# GPU, and records the loss of every epoch in `losses`.
n_epochs = 400
optim = Flux.setup(Flux.Adam(0.01), c_test)

# Concretely typed buffer: `[]` would be a Vector{Any}, which defeats type
# inference for every later use of `losses`.
losses = Float32[]
sizehint!(losses, n_epochs)
@showprogress for epoch in 1:n_epochs
    loss, grads = Flux.withgradient(c_test) do f
        # Evaluate model and loss inside gradient context:
        y_hat = f(rss, V)
        Flux.mse(y_hat, outs_fake)
    end
    # `grads` has one entry per argument of `withgradient`; [1] is for c_test.
    Flux.update!(optim, c_test, grads[1])
    push!(losses, loss)
end

Progress: 100%|█████████████████████████████████████████| Time: 0:00:00


### Testing Higher Output Dimensionality (ℓ = 2 Filter, 2ℓ+1 = 5 Components)

In [3]:
# Small batch: `n_samples` random clouds of `n_points` points.
n_samples = 6
n_points = 4

# Inputs are built on the CPU first; GPU copies are made further down.
xs = rand(Float32, (n_points, 3, n_samples))
xss = pairwise_rs(xs)
rss = cart_to_sph(xss)

# Order-ℓ spherical-harmonic filter; its output has 2ℓ + 1 components
# along the last axis.
ℓ = 2
Ys = generate_Yℓms(ℓ)

# Random regression target with the same shape as the filter output:
# (n_points, n_points, n_samples, 2ℓ + 1).
yss = rand(Float32, (n_points, n_points, n_samples, 2ℓ + 1))

rss_gpu = gpu(rss)
yss_gpu = gpu(yss)

# The layer is constructed after the random data above, so the RNG stream
# is consumed in the same order as before.
centers = collect(range(0f0, 3.5f0; length=4))
f_test = gpu(FLayer(Ys, centers))

# Forward pass on the GPU inputs.
f_test(rss_gpu)

4×4×6×5 CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}:
[:, :, 1, 1] =
  0.0           0.000515819   0.00211304  -0.00108708
  0.000515819   0.0          -0.00021092  -0.00257853
  0.00211304   -0.00021092    0.0          9.30606f-5
 -0.00108708   -0.00257853    9.30605f-5   0.0

[:, :, 2, 1] =
 0.0           0.00104856    0.000574914   0.000276161
 0.00104856    0.0          -0.00492658   -0.000712044
 0.000574913  -0.00492658    0.0          -0.0056656
 0.000276162  -0.000712044  -0.0056656     0.0

[:, :, 3, 1] =
  0.0          -0.000731042  -0.000739708  -0.00184645
 -0.000731042   0.0           0.00113411   -0.0053443
 -0.000739708   0.00113411    0.0          -0.00436006
 -0.00184645   -0.0053443    -0.00436006    0.0

[:, :, 4, 1] =
 0.0         0.00311083  0.00172662  0.00642716
 0.00311083  0.0         0.00120751  0.00209888
 0.00172662  0.00120751  0.0         0.00445498
 0.00642716  0.00209888  0.00445498  0.0

[:, :, 5, 1] =
  0.0          -9.49261f-5   0.000554043   0.00077583

In [5]:
# Testing gradient: fit `f_test` to the random target `yss_gpu` with Adam and
# record the loss of every epoch in `losses`.
n_epochs = 400
optim = Flux.setup(Flux.Adam(0.01), f_test)

# Concretely typed buffer: `[]` would be a Vector{Any}, which defeats type
# inference for every later use of `losses`.
losses = Float32[]
sizehint!(losses, n_epochs)
@showprogress for epoch in 1:n_epochs
    loss, grads = Flux.withgradient(f_test) do f
        # Evaluate model and loss inside gradient context:
        y_hat = f(rss_gpu)
        Flux.mse(y_hat, yss_gpu)
    end
    # `grads` has one entry per argument of `withgradient`; [1] is for f_test.
    Flux.update!(optim, f_test, grads[1])
    push!(losses, loss)
end

Progress: 100%|█████████████████████████████████████████| Time: 0:00:01
