In [1]:
using Flux
using Symbolics
using SphericalHarmonics
using CUDA
using NNlib
using OneHotArrays
using MLUtils

using Rotations
using TensorCast
using Test

# Disable slow GPU indexing and turn warnings into errors
CUDA.allowscalar(false)

using BenchmarkTools
using ProgressMeter

include("utils.jl")
include("Spherical.jl")
include("alt_parallel.jl")

In [2]:
include("TFNLayers.jl")

## Trying Shape Classification

In [3]:
# Eight tetris-style shapes used as a toy 3-D classification dataset.
# Each shape is a 4×3 matrix: one row per cube, the three columns are its
# Cartesian coordinates (rotated later via `R * shape'`, so columns = axes).
tetris = [
    [0 0 0; 0 0 1; 1 0 0; 1 1 0],   # chiral_shape_1
    [0 0 0; 0 0 1; 1 0 0; 1 -1 0],  # chiral_shape_2 (mirror image of chiral_shape_1)
    [0 0 0; 1 0 0; 0 1 0; 1 1 0],   # square
    [0 0 0; 0 0 1; 0 0 2; 0 0 3],   # line
    [0 0 0; 0 0 1; 0 1 0; 1 0 0],   # corner
    [0 0 0; 0 0 1; 0 0 2; 0 1 0],   # T
    [0 0 0; 0 0 1; 0 0 2; 0 1 1],   # zigzag
    [0 0 0; 1 0 0; 1 1 0; 2 1 0],   # L
]

# The literals above are Int matrices; the model works in Float32.
tetris = [Matrix{Float32}(shape) for shape in tetris]

# MLUtils.batch stacks the shapes along a new trailing dimension,
# giving (points per shape, 3, number of shapes).
tetris_batched = batch(tetris)

# Each shape is its own class: shape i carries label i.
onehot_tetris = gpu(onehotbatch(collect(1:length(tetris)), 1:length(tetris)))

# All-ones input features of size (points per shape, number of shapes, 1).
Vones = gpu(ones(Float32, size(tetris_batched, 1), size(tetris_batched, 3), 1));

In [4]:
# Pairwise relative positions of the points, converted to spherical coordinates
# (helpers from the included utility files), then moved to the GPU.
sph_tetris = gpu(cart_to_sph(pairwise_rs(tetris_batched)));

# Four evenly spaced Float32 values on [0, 3.5], passed to every E3ConvLayer
# (presumably radial-basis centres — confirm in TFNLayers.jl).
centers = collect(range(0f0, 3.5f0; length=4))

# NOTE(review): SIWrapper / NLWrapper / E3ConvLayer / PLayer are defined in
# TFNLayers.jl, which is not visible here. The `(a, b) => [c...]` entries look
# like (input degree, filter degree) => output degrees — confirm there.
classifier = gpu(Chain(
    # Stage 1: lift the single all-ones input channel to 4 channels, then convolve.
    SIWrapper([1 => 4]),
    E3ConvLayer([4], [[(0, 0) => [0], (0, 1) => [1]]], centers),

    # Stage 2: two feature streams of 4 channels each.
    SIWrapper([4 => 4, 4 => 4]),
    NLWrapper([4, 4]),
    E3ConvLayer([4, 4],
                [[(0, 0) => [0], (0, 1) => [1]],
                 [(1, 0) => [1], (1, 1) => [0, 1]]],
                centers),

    # Stage 3: reduce the two streams to a single 4-channel output stream.
    SIWrapper([8 => 4, 12 => 4]),
    NLWrapper([4, 4]),
    E3ConvLayer([4, 4], [[(0, 0) => [0]], [(1, 1) => [0]]], centers),

    # Head: final mixing, pooling, and one logit per shape class.
    SIWrapper([8 => 4]),
    NLWrapper([4]),
    PLayer(),
    Dense(4 => length(tetris)),
));

In [5]:
# Adam optimiser state (learning rate 0.01) for all trainable parameters of `classifier`.
optim = Flux.setup(Flux.Adam(0.01), classifier)

# Train the classifier: every epoch is one full-batch step over all shapes.
# `losses` is concretely typed: an untyped `[]` would be a `Vector{Any}`.
losses = Float32[]
@showprogress for epoch in 1:600
    loss, grads = Flux.withgradient(classifier) do c
        # Evaluate model and loss inside the gradient context so both are differentiated.
        y_hat = c((sph_tetris, ([Vones],)))
        Flux.logitcrossentropy(y_hat, onehot_tetris)
    end
    # grads[1] is the gradient with respect to the model (the only argument above).
    Flux.update!(optim, classifier, grads[1])
    push!(losses, loss)
end

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:02:26[39m


In [6]:
# Apply an independent random 3-D rotation to each shape. A fixed-seed RNG makes
# the rotation-invariance check below reproducible: if it ever fails, rerunning
# the notebook yields the same rotations for debugging.
using Random: MersenneTwister
rotation_rng = MersenneTwister(1234)

# Shapes store points as rows, so rotate the transposed (3×4) matrix and
# transpose back to keep the (points, 3) layout.
rotated_tetris = [(rand(rotation_rng, RotMatrix{3, Float32}) * tet')' for tet in tetris]
rotated_tetris_batched = batch(rotated_tetris)
sph_rotated = rotated_tetris_batched |> pairwise_rs |> cart_to_sph |> gpu;

In [7]:
# Class probabilities for the original shapes and for their rotated copies.
final_unrotated = classifier((sph_tetris, ([Vones],))) |> NNlib.softmax
final = classifier((sph_rotated, ([Vones],))) |> NNlib.softmax

# Rotation-invariance check: every shape must keep its predicted class after
# being rotated (predictions are compared on the CPU).
@test OneHotArrays.onecold(cpu(final_unrotated)) == OneHotArrays.onecold(cpu(final))

[32m[1mTest Passed[22m[39m