In [1]:
using Flux
using Symbolics
using SphericalHarmonics
using CUDA
using NNlib

using TensorCast

# Catching slow GPU errors
CUDA.allowscalar(false)

using BenchmarkTools
using ProgressMeter

include("utils.jl")
include("Spherical.jl");

### Defining Data

In [4]:
# Define Tetris Shapes
tetris = [[0 0 0; 0 0 1; 1 0 0; 1 1 0],  # chiral_shape_1
          [0 0 0; 0 0 1; 1 0 0; 1 -1 0], # chiral_shape_2
          [0 0 0; 1 0 0; 0 1 0; 1 1 0],  # square
          [0 0 0; 0 0 1; 0 0 2; 0 0 3],  # line
          [0 0 0; 0 0 1; 0 1 0; 1 0 0],  # corner
          [0 0 0; 0 0 1; 0 0 2; 0 1 0],  # T
          [0 0 0; 0 0 1; 0 0 2; 0 1 1],  # zigzag
          [0 0 0; 1 0 0; 1 1 0; 2 1 0]]  # L

4×3 Matrix{Int64}:
 0  0  0
 0  0  1
 1  0  0
 1  1  0

### Testing Network Definition

In [2]:
"""
One simple ℝ ≥ 0 -> ℝ function broadcasted across every elements of the array.
Function is a linear combination of basis functions ∑ᵢ aᵢ rbfᵢ(r), with learned weightings aᵢ.
"""
struct RLayer
    as::Vector{Float32}
    # TODO Maybe add another NN to copy TFN
    
    centers::Vector{Float32}
    γ::Float32
end

function RLayer(centers; init=Flux.glorot_uniform)
    n_basis = length(centers)
    as = init(n_basis)

    γ = (centers[end] - centers[1]) / n_basis
    RLayer(as, centers, γ)
end

function (R::RLayer)(radials)
    reduce(+, [a * @.(exp(- R.γ * (radials - c)^2)) for (a, c) in zip(R.as, R.centers)])
end

Flux.@functor RLayer (as,)

In [3]:
struct FLayer
    #R::Chain # Radial NN
    R::RLayer

    Ys::Vector{Function} # SH functions for this ℓf
    ℓf::Int # Filter angular momentum # TODO Consider removing
end

function FLayer(Ys::Vector{Function}, centers::Vector{Float32})
    # Will later allow for custom spec
    R = RLayer(centers)

    ℓf = (length(Ys) - 1) ÷ 2
    FLayer(R, Ys, ℓf)
end

# Dimension needs to be made one bigger
function (F::FLayer)(rr)
    # Apply R to the input radii
    rr_radials = rr[:, :, 1, :] # TODO Possibly change this back to view

    R_out = F.R(rr_radials)

    # Multiply by SH components
    θs = @view rr[:,:,2,:]
    ϕs = @view rr[:,:,3,:]
    Y_out = Flux.batch([Y.(θs, ϕs) for Y in F.Ys])
    
    R_out .* Y_out
end

Flux.@functor FLayer (R,)

#### Defining Convolution

In [4]:
struct CLayer
    F::FLayer # Trainable NN
    CG_mats::Dict{Tuple{Int, Int}, CuArray{Float32}} # Dictionary of matrices, keyed by (ℓo, mo)

    ℓi::Int # Input ℓ
    ℓf::Int # Filter ℓ
    ℓms::Vector{Tuple{Int, Int}} # Specifying output order
end

# Constructor
function CLayer(ℓi::Int, ℓf::Int, ℓos::Vector{Int}, centers::Vector{Float32})
    @assert ℓos ⊆ abs(ℓi - ℓf):(ℓi + ℓf) "Output `ℓo` not compatible with filter `ℓf` and input `ℓi`."

    Ys = generate_Yℓms(ℓf)
    F_NN = FLayer(Ys, centers)

    # Not going to choose every one
    ℓms::Vector{Tuple{Int, Int}} = []
    CG_mats::Dict{Tuple{Int, Int}, CuArray{Float32}} = Dict()
    for ℓo in ℓos
        for mo in -ℓo:ℓo
            push!(ℓms, (ℓo, mo))
            CG_mat = zeros(Float32, (2ℓi + 1, 2ℓf + 1))
            for (i_i, mi) in enumerate(-ℓi:ℓi)
                for (i_f, mf) in enumerate(-ℓf:ℓf)
                    # TODO Check that ordering of f and i is correct
                    # Currently giving zero
                    CG_mat[i_i, i_f] = cg(ℓi, mi, ℓf, mf, ℓo, mo)
                end
            end
            CG_mats[(ℓo, mo)] = cu(CG_mat)
        end
    end
    
    CLayer(F_NN, CG_mats, ℓi, ℓf, ℓms)
end

# Forward pass
"""
V indexed by [mi, b]
"""

#=
function (C::CLayer)(rr, V)
    n_points, _, _, n_samples = size(rr)

    F_out = permutedims(C.F(rr), (1, 4, 2, 3))
    F_reshape = reshape(F_out, n_points, 2C.ℓf + 1, :)

    V_repeat = repeat(V, inner=(1, 1, n_points))
    F_tilde = batched_mul(V_repeat, F_reshape)

    # TODO Make this general
    # For now just assume one CG_mat
    CG_mat = CuArray(C.CG_mats[C.ℓms[1]])

    L_tilde = CG_mat .* F_tilde

    # Steps using a lot of memory atm
    L = sum(L_tilde, dims=(1, 2))#[1,1,:]
    reshape(L, n_points, n_samples)
end
=#


# Trying to use tensorcast
function (C::CLayer)(rr, V)
    F_out = C.F(rr)
    
    # Using Einstein summation convention for brevity
    # Speed seems comparable
    @reduce F_tilde[mi, mf, a, γ] := sum(b) V[mi, b, γ] * F_out[b, a, γ, mf]

    # TODO Make this general
    # For now just assume one CG_mat
    CG_mat = C.CG_mats[C.ℓms[1]]

    L_tilde = CG_mat .* F_tilde

    @reduce L[a, γ] := sum(mi, mf) L_tilde[mi, mf, a, γ]
end


Flux.@functor CLayer (F,)

In [5]:
ℓi, ℓf, ℓos = 0, 1, [1]

centers = range(0f0, 3.5f0; length=4) |> collect

c_test = CLayer(ℓi, ℓf, ℓos, centers) |> gpu

xs = rand(Float32, (4, 3, 100))
xss = pairwise_rs(xs)
rss = cart_to_sph(xss) |> gpu

V = ones(Float32, (1, 4, 100)) |> gpu
#V = ones(Float32, (1, 4))

C_out = c_test(rss, V)

4×10000 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
  0.0586906   0.148469    0.23939   …  -0.0931654   0.119467   0.259697
 -0.12375    -0.694081    0.318815      0.172496   -0.12617   -0.245736
 -0.154179    0.510324   -0.722801      0.0979667   0.302776  -0.532552
  0.219239    0.0352885   0.164596     -0.177297   -0.296073   0.518592

#### Testing Increasing Dimensionality

In [10]:
xs = rand(Float32, (4, 3, 6))
xss = pairwise_rs(xs)
rss = cart_to_sph(xss)

sh = size(rss)
ℓ = 2
Ys = generate_Yℓms(ℓ)
yss = rand(Float32, (sh[1], sh[2], sh[4], 2ℓ+1))

rss_gpu = rss |> gpu
yss_gpu = yss |> gpu

centers = range(0f0, 3.5f0; length=4) |> collect
f_test = FLayer(Ys, centers) |> gpu

FLayer(RLayer(Float32[0.11861226, -0.5310653, 0.11397445, 0.89927113], Float32[0.0, 1.1666666, 2.3333333, 3.5], 0.875f0), Function[var"#24#25"(), var"#26#27"(), var"#28#29"(), var"#30#31"(), var"#32#33"()], 2)

In [11]:
optim = Flux.setup(Flux.Adam(0.01), f_test)

# Testing gradient
losses = []
@showprogress for epoch in 1:400
    loss, grads = Flux.withgradient(f_test) do f
        # Evaluate model and loss inside gradient context:
        y_hat = f(rss_gpu)
        Flux.mse(y_hat, yss_gpu)
    end
    Flux.update!(optim, f_test, grads[1])
    push!(losses, loss)
end

LoadError: MethodError: [0mCannot `convert` an object of type [92mBool[39m[0m to an object of type [91mVector{Float32}[39m
[0mClosest candidates are:
[0m  convert(::Type{T}, [91m::Union{Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}} where T, Union{Base.LogicalIndex{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, Base.ReinterpretArray{T, N, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s14"}, var"#s14"}} where var"#s14"<:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}, Base.ReshapedArray{T, N, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s15"}, var"#s15"}}, SubArray{<:Any, <:Any, var"#s15"}, var"#s15"}} where var"#s15"<:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}, SubArray{T, N, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s16"}, var"#s16"}}, Base.ReshapedArray{<:Any, <:Any, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s16"}, var"#s16"}}, SubArray{<:Any, <:Any, var"#s16"}, var"#s16"}}, var"#s16"}} where var"#s16"<:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}, LinearAlgebra.Adjoint{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.Diagonal{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.LowerTriangular{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.Symmetric{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.Transpose{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.Tridiagonal{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.UnitLowerTriangular{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.UnitUpperTriangular{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, LinearAlgebra.UpperTriangular{T, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}, PermutedDimsArray{T, N, <:Any, <:Any, <:Union{NNlib.BatchedAdjoint{T, <:CuArray{T}}, NNlib.BatchedTranspose{T, <:CuArray{T}}}}} where {T, N}}[39m) where T<:Array at C:\Users\xboxe\.julia\packages\NNlibCUDA\kCpTE\src\batchedadjtrans.jl:15
[0m  convert(::Type{Vector{T}}, [91m::AbstractAlgebra.Perm{T}[39m) where T at C:\Users\xboxe\.julia\packages\AbstractAlgebra\Oe2Uj\src\generic\PermGroups.jl:49
[0m  convert(::Type{Array{T, N}}, [91m::StaticArraysCore.SizedArray{S, T, N, N, Array{T, N}}[39m) where {S, T, N} at C:\Users\xboxe\.julia\packages\StaticArrays\jA1zK\src\SizedArray.jl:88
[0m  ...

In [140]:
R = Chain(
        Dense(1 => 5, relu),
        Dense(5 => 1, relu)
    )

vector = [0.1, 0.2, 0.3]
R(reshape(vector, (1, 3)))

1×3 Matrix{Float64}:
 0.0595705  0.119141  0.178712

In [286]:
@variables θ::Real, ϕ::Real
Ys_sym = computeYlm(θ, ϕ; lmax=2, SHType=SphericalHarmonics.RealHarmonics())

# Create dictionary of functions
Ys = Dict(key => (eval ∘ Base.remove_linenums! ∘ build_function)(Ys_sym[key] |> simplify, θ, ϕ)
            for key in Ys_sym.modes[1] |> collect)


Ys_bodged = Dict(key => convert_expr_to_F32(build_function(Ys_sym[key] |> simplify, θ, ϕ))
            for key in Ys_sym.modes[1] |> collect)

ℓ = 2
Ys_ℓ = [Ys_bodged[(ℓ, m)] for m in -ℓ:ℓ];