In [33]:
using Flux
using Symbolics
using SphericalHarmonics
using CUDA

using BenchmarkTools

### Defining Data

In [4]:
# Define Tetris Shapes.
# Each shape is a 4×3 integer matrix: one row per cube, columns are (x, y, z).
tetris = [[0 0 0; 0 0 1; 1 0 0; 1 1 0],  # chiral_shape_1
          [0 0 0; 0 0 1; 1 0 0; 1 -1 0], # chiral_shape_2 (mirror image of chiral_shape_1)
          [0 0 0; 1 0 0; 0 1 0; 1 1 0],  # square
          [0 0 0; 0 0 1; 0 0 2; 0 0 3],  # line
          [0 0 0; 0 0 1; 0 1 0; 1 0 0],  # corner
          [0 0 0; 0 0 1; 0 0 2; 0 1 0],  # L (extra cube beside an end of the 3-line)
          [0 0 0; 0 0 1; 0 0 2; 0 1 1],  # T (extra cube beside the middle of the 3-line)
          [0 0 0; 1 0 0; 1 1 0; 2 1 0]]  # zigzag

4×3 Matrix{Int64}:
 0  0  0
 0  0  1
 1  0  0
 1  1  0

### Testing Network

In [21]:
"""
    cart_to_sph(xs)

Convert Cartesian points into spherical coordinates.

`xs` holds one point per row with columns `[x y z]`. The result has the same
number of rows and columns `[r θ ϕ]`: radius `r`, polar angle
`θ = acos(z / r) ∈ [0, π]` and azimuth `ϕ = atan(y, x) ∈ [-π, π]`.

At the origin (`r == 0`) the polar angle is undefined; `θ = 0` is returned
there instead of `NaN` (and `ϕ = atan(0, 0) = 0`).
"""
function cart_to_sph(xs)
    # Views avoid copying each column every time it is referenced.
    x = @view xs[:, 1]
    y = @view xs[:, 2]
    z = @view xs[:, 3]

    rs = @. √(x^2 + y^2 + z^2)
    θs = _polar_angle.(z, rs)
    ϕs = @. atan(y, x)

    return [rs θs ϕs]
end

# Polar angle of a point with height `z` at radius `r`.
# The origin gets 0 instead of acos(0/0) = NaN, and z/r is clamped into [-1, 1]
# because rounding can push it marginally outside, where `acos` throws a DomainError.
_polar_angle(z, r) = iszero(r) ? zero(r) : acos(clamp(z / r, -1, 1))

4×5 CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
  0.0         0.0         -0.026137   3.95768f-9    0.0452706
 -3.95768f-9  3.95768f-9  -0.026137  -1.72996f-16  -0.0452706
 -3.95768f-9  3.95768f-9  -0.026137  -1.72996f-16  -0.0452706
 -1.1873f-8   1.1873f-8   -0.078411  -5.18987f-16  -0.135812

In [104]:
"""
    FLayer{C<:Chain}

Filter layer: a radial network `R` evaluated on point radii, multiplied by the
spherical-harmonic functions `Ys` of a single angular momentum `ℓf`
(expected order m = -ℓf, …, ℓf) evaluated on the point angles.

The chain type is a type parameter so the forward pass specializes on the
concrete `Chain`; a bare `R::Chain` field would be abstract because `Chain`
is itself parametric.
"""
struct FLayer{C<:Chain}
    R::C                  # Radial NN
    Ys::Vector{Function}  # SH functions for this ℓf
    ℓf::Int               # Filter angular momentum # TODO Consider removing
end

# Keep `FLayer(R, Ys, ℓf)` accepting any vector / integer and converting them,
# as the previous non-parametric definition's default constructor did.
FLayer(R::Chain, Ys::AbstractVector, ℓf::Integer) = FLayer{typeof(R)}(R, Ys, ℓf)

"""
    FLayer(Ys; hidden=5)

Build a filter layer from the spherical-harmonic functions `Ys` of one angular
momentum ℓf (a vector of callables `Y(θ, ϕ)`, length `2ℓf + 1`).

The radial part is a `Chain` of `Dense` layers `1 → hidden → hidden → 1`,
all with `relu` activations.

Throws `ArgumentError` when `length(Ys)` is not odd (including an empty `Ys`).
"""
function FLayer(Ys::AbstractVector; hidden::Integer = 5)
    n = length(Ys)
    isodd(n) || throw(ArgumentError(
        "expected 2ℓf + 1 spherical-harmonic functions (an odd count), got $n"))

    # NOTE(review): the final `relu` forces the radial output to be ≥ 0 and can
    # leave the layer with all-zero output; an identity output activation may be intended.
    R = Chain(
        Dense(1 => hidden, relu),
        Dense(hidden => hidden, relu),
        Dense(hidden => 1, relu),
    )

    ℓf = (n - 1) ÷ 2
    # Explicit conversion: a vector holding a single function (ℓf = 0) has a concrete
    # element type, which the old `Vector{Function}` signature rejected.
    return FLayer(R, Vector{Function}(Ys), ℓf)
end

# Forward pass of the filter layer.
# `rr` holds one point per row in spherical coordinates [r θ ϕ] (as from `cart_to_sph`).
# The result has one row per point and one column per harmonic in `F.Ys`;
# entry (i, k) is R(rᵢ) * Ys[k](θᵢ, ϕᵢ).
function (F::FLayer)(rr)
    radii = rr[:, 1]'        # 1 × N: a single input feature per point, batch along columns
    radial = F.R(radii)      # 1 × N radial weights

    θs = rr[:, 2]
    ϕs = rr[:, 3]
    harmonics = reduce(hcat, [Y.(θs, ϕs) for Y in F.Ys])

    # Transpose the radial row to a column so it scales every harmonic of its point.
    return radial' .* harmonics
end

# Expose only the radial network `R` as a child to Functors/Flux: it is what gets
# trained and moved by `gpu`; `Ys` and `ℓf` are carried along unchanged.
Flux.@functor FLayer (R,)

In [None]:
# Symbolic angle variables the harmonics are expressed in.
@variables θ, ϕ
# Symbolic real spherical harmonics for every mode up to lmax = 2.
Ys_sym = computeYlm(θ, ϕ; lmax=2, SHType=SphericalHarmonics.RealHarmonics())

# Compile each simplified symbolic harmonic into a callable f(θ, ϕ), keyed by its mode.
Ys = Dict(mode => eval(build_function(simplify(Ys_sym[mode]), θ, ϕ))
          for mode in collect(Ys_sym.modes[1]))

# Harmonics of a single order ℓ, in the order m = -ℓ, …, ℓ.
ℓ = 2
Ys_ℓ = [Ys[(ℓ, m)] for m in -ℓ:ℓ]

In [106]:
xs1::Matrix{Float32} = [1 0 0; 0 1 0; 0 1 0; 0 3 0]
xs_sph = cart_to_sph(xs1)
ys_rand = rand(Float32, (4, 5))

xs_gpu = xs_sph |> gpu
ys_gpu = ys_rand |> gpu

f_test = FLayer(Ys_ℓ) |> gpu
optim = Flux.setup(Flux.Adam(0.01), f_test)

# Smoke test: gradients flow through the layer and one optimiser step runs.
# The targets are arbitrary regression values, so use mean-squared error.
# `crossentropy` expects probability columns; this layer's output can be negative
# (spherical harmonics change sign), which makes its log term NaN.
loss, grads = Flux.withgradient(f_test) do m
    # Evaluate model and loss inside gradient context:
    y_hat = m(xs_gpu)
    Flux.mse(y_hat, ys_gpu)
end
isfinite(loss) || error("loss is not finite: $loss")
Flux.update!(optim, f_test, grads[1])

((R = (layers = ((weight = [32mLeaf(Adam{Float64}(0.01, (0.9, 0.999), 1.0e-8), [39m(Float32[0.0; -0.365752; … ; 0.0; -0.127473;;], Float32[0.0; 0.0133774; … ; 0.0; 0.00162493;;], (0.81, 0.998001))[32m)[39m, bias = [32mLeaf(Adam{Float64}(0.01, (0.9, 0.999), 1.0e-8), [39m(Float32[0.0, -0.311517, 0.270187, 0.0, -0.108571], Float32[0.0, 0.0097043, 0.00730011, 0.0, 0.00117876], (0.81, 0.998001))[32m)[39m, σ = ()), (weight = [32mLeaf(Adam{Float64}(0.01, (0.9, 0.999), 1.0e-8), [39m(Float32[0.0 -0.230803 … 0.0 -0.361674; 0.0 0.107301 … 0.0 0.168143; … ; 0.0 -0.109133 … 0.0 -0.171014; 0.0 -0.068439 … 0.0 -0.107246], Float32[0.0 0.00532698 … 0.0 0.0130808; 0.0 0.00115134 … 0.0 0.00282721; … ; 0.0 0.001191 … 0.0 0.00292459; 0.0 0.00046839 … 0.0 0.00115017], (0.81, 0.998001))[32m)[39m, bias = [32mLeaf(Adam{Float64}(0.01, (0.9, 0.999), 1.0e-8), [39m(Float32[-0.405279, 0.188415, 0.0, -0.191633, -0.120176], Float32[0.0164251, 0.00355003, 0.0, 0.00367231, 0.00144423], (0.81, 0.998001))[

In [None]:
"""
    CLayer

Convolution layer (work in progress) combining a trainable filter with input
features of a single angular momentum.

Fields:
- `F`: the trainable filter layer.
- `CG_mats`: coupling matrices (presumably Clebsch–Gordan coefficients), keyed by
  output `(ℓo, mo)`.
- `ℓi`: input angular momentum.
- `ℓf`: filter angular momentum (TODO: possibly redundant with `F.ℓf`).
- `ℓms`: the `(ℓo, mo)` pairs, in the order outputs are produced.
"""
struct CLayer
    F::FLayer
    CG_mats::Dict{Tuple{Int, Int}, Matrix}
    ℓi::Int
    ℓf::Int
    ℓms::Vector{Tuple{Int, Int}}
end

# Constructor
"""
    CLayer(F::FLayer, CG_mats::AbstractDict, ℓi::Integer)

Build a `CLayer` from the filter `F` and the input angular momentum `ℓi`.
The filter angular momentum is taken from `F.ℓf`. Outputs are ordered as
`(ℓo, mo)` for `ℓo = |ℓi - ℓf|, …, ℓi + ℓf` and, within each ℓo,
`mo = -ℓo, …, ℓo`. `CG_mats` must hold a matrix for every such `(ℓo, mo)`.

Throws `ArgumentError` if `ℓi` is negative or a required `CG_mats` entry is missing.
"""
function CLayer(F::FLayer, CG_mats::AbstractDict, ℓi::Integer)
    ℓi >= 0 || throw(ArgumentError("input angular momentum ℓi must be ≥ 0, got $ℓi"))
    ℓi_int = Int(ℓi)
    ℓf = F.ℓf
    ℓms = _clayer_output_ℓms(ℓi_int, ℓf)

    missing_keys = filter(ℓm -> !haskey(CG_mats, ℓm), ℓms)
    isempty(missing_keys) ||
        throw(ArgumentError("CG_mats lacks matrices for (ℓo, mo) = $(missing_keys)"))

    return CLayer(F, CG_mats, ℓi_int, ℓf, ℓms)
end

# All (ℓo, mo) output pairs from coupling ℓi with ℓf, with ℓo ascending and mo ascending within it.
_clayer_output_ℓms(ℓi::Int, ℓf::Int) =
    [(ℓo, mo) for ℓo in abs(ℓi - ℓf):(ℓi + ℓf) for mo in -ℓo:ℓo]

# Forward pass of the convolution layer.
# `rr` holds one point per row in spherical coordinates [r θ ϕ]; `V` holds the input
# features per point. Returns a 1 × length(C.ℓms) row whose k-th entry is
# sum(C.CG_mats[C.ℓms[k]] .* (F_out' * V)), where F_out = C.F(rr).
function (C::CLayer)(rr, V)
    F_out = C.F(rr)   # one row per point, one column per filter harmonic
    # NOTE(review): assumes V has one row per point and 2ℓi + 1 columns, so FV has
    # the same (2ℓf + 1) × (2ℓi + 1) shape as each CG matrix — confirm.
    FV = F_out' * V

    # NOTE(review): `CG_mats` is not a functor child, so `gpu` leaves it on the CPU;
    # combining it with a GPU `FV` here will likely fail — consider moving it explicitly.
    # TODO Use symmetry in CG
    coeffs = [sum(C.CG_mats[ℓm] .* FV) for ℓm in C.ℓms]

    # permutedims keeps the 1 × n row shape even for a single output,
    # where reduce(hcat, …) would return a bare scalar.
    return permutedims(coeffs)
end

# Only the filter `F` is exposed as a trainable child; CG matrices and ℓ bookkeeping stay fixed.
Flux.@functor CLayer (F,)