In [2]:
import logging

import torch
from torch_cluster import radius_graph
from torch_geometric.data import Data, DataLoader
from torch_scatter import scatter

from e3nn import o3
from e3nn.nn import FullyConnectedNet, Gate
from e3nn.o3 import FullyConnectedTensorProduct
from e3nn.math import soft_one_hot_linspace
from e3nn.util.test import assert_equivariant

import functools
import matplotlib.pyplot as plt
import numpy as np

# Disable future warnings.
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

ModuleNotFoundError: No module named 'torch_geometric'

In [42]:
# Cube-midpoint coordinates of the two chiral tetracubes ("screws").
# Layout: (num_shapes=2, cubes_per_shape=4, xyz=3).
# Index 0 is the right screw and index 1 is the left screw.
# The four midpoints of each shape already average to the origin.
tetracubes = torch.tensor(
    [
        [  # Right screw.
            [-0.50, 0.25, -0.25],
            [-0.50, 0.25, 0.75],
            [0.50, 0.25, -0.25],
            [0.50, -0.75, -0.25],
        ],
        [  # Left screw.
            [-0.75, 0.50, -0.25],
            [0.25, -0.50, 0.75],
            [0.25, 0.50, -0.25],
            [0.25, -0.50, -0.25],
        ],
    ]
)

In [49]:
def generate_datasets(num_train=1000, num_valid=100, noise_scale=0.05, shapes=None, generator=None):
    """Build noisy train/validation sets by sampling prototype shapes uniformly.

    Every sample is one prototype, chosen uniformly with replacement, plus i.i.d.
    Gaussian noise on every coordinate. The label of a sample is the index of
    its prototype in ``shapes``.

    Args:
        num_train: Number of training samples (0 is allowed).
        num_valid: Number of validation samples (0 is allowed).
        noise_scale: Standard deviation of the Gaussian noise.
        shapes: Prototype tensor of shape ``(num_shapes, ...)``. Defaults to the
            notebook-level ``tetracubes``.
        generator: Optional ``torch.Generator`` for reproducible datasets. With
            ``None`` the global RNG is used, in the same draw order as before, so
            ``torch.manual_seed`` still reproduces earlier runs.

    Returns:
        ``(train_data, valid_data)``. Each is a dict with:
        - ``'shapes'``: shape ``(num, *shapes.shape[1:])``;
        - ``'labels'``: int64, shape ``(num,)``.
    """
    if shapes is None:
        shapes = tetracubes
    num_shapes = shapes.shape[0]
    labels = torch.arange(num_shapes)
    weights = torch.ones(num_shapes)

    def _sample(num):
        """Draw ``num`` prototype indices uniformly, with replacement."""
        if num == 0:
            # torch.multinomial rejects num_samples=0, so return an empty index instead.
            return torch.empty(0, dtype=torch.long)
        return torch.multinomial(weights, num_samples=num, replacement=True, generator=generator)

    def _noisy(choice):
        """Gather the chosen prototypes and add Gaussian noise."""
        # Advanced indexing returns a copy, so the prototypes are never modified.
        picked = shapes[choice]
        noise = torch.randn(picked.shape, generator=generator, dtype=picked.dtype, device=picked.device)
        return picked + noise_scale * noise

    # Draw all indices before any noise, matching the original RNG consumption order.
    train_choice = _sample(num_train)
    valid_choice = _sample(num_valid)

    train_data = dict(shapes=_noisy(train_choice), labels=labels[train_choice])
    valid_data = dict(shapes=_noisy(valid_choice), labels=labels[valid_choice])
    return train_data, valid_data

In [50]:
from typing import Callable, Optional, Protocol, Tuple, Union
import jaxtyping

# Short aliases for jaxtyping-style annotations, e.g. Float[Array, '... 3'].
Array, Float = torch.Tensor, jaxtyping.Float

def basis(
    r: torch.Tensor,
    max_degree: int,
    num: int,
    # radial_fn: Callable[[Float[Array, '...'], int], Float[Array, '... num']],
):
    r"""Angular part of ``e3x.nn.basis``, computed with e3nn's spherical harmonics.

    Each input vector is normalised to unit length. It is then expanded in real
    spherical harmonics of degrees ``0..max_degree`` with ``"component"``
    normalisation.

    Args:
        r: Tensor of shape ``(..., 3)`` holding Cartesian vectors.
        max_degree: Highest spherical-harmonic degree ``l`` (inclusive).
        num: Number of radial basis functions. Currently unused: the radial part
            of ``e3x.nn.basis`` (``radial_fn``) has not been ported yet (TODO).
            The result therefore has no trailing ``num`` axis.

    Returns:
        Tensor of shape ``(..., (max_degree + 1) ** 2)`` with the same leading
        shape as ``r``. Zero-length vectors are divided by 1 instead of 0, so
        they do not produce NaNs.

    Raises:
        ValueError: If the last dimension of ``r`` is not 3.
    """
    if r.shape[-1] != 3:
        raise ValueError(
            f"expected r with a trailing dimension of size 3, got shape {tuple(r.shape)}"
        )

    # Safe norm: scale each vector by its largest |component| before squaring.
    # This avoids overflow/underflow. The clamp keeps the scale strictly
    # positive, so zero vectors get norm 0 rather than NaN.
    scale = r.abs().amax(dim=-1, keepdim=True).clamp_min(torch.finfo(r.dtype).tiny)
    norm = scale * torch.linalg.vector_norm(r / scale, dim=-1, keepdim=True)

    # Unit vectors. Zero-length vectors are divided by 1 and stay zero.
    u = r / torch.where(norm > 0, norm, torch.ones_like(norm))

    # u is already normalised, so e3nn must not renormalise it.
    # TODO: port radial_fn(norm, num) and combine the radial part with these harmonics.
    return o3.spherical_harmonics(
        list(range(max_degree + 1)), u, normalize=False, normalization="component"
    )

In [None]:
class TensorDense(torch.nn.Module):
    r"""Equivariant "tensor dense" layer built from e3nn (PyTorch) primitives.

    Modelled on the e3x / e3nn-jax ``TensorDense`` this notebook was ported from:

    1. two independent equivariant linear maps project the input onto
       ``features`` copies of every irrep it contains;
    2. a channel-wise (``"uuu"``), weighted tensor product couples the two
       projections, keeping only output degrees ``l <= max_degree``;
    3. a final equivariant linear map mixes the products into the output irreps.

    Unlike e3nn-jax ``IrrepsArray``, plain torch tensors carry no irreps
    metadata, so the input irreps must be passed as ``irreps_in``.

    Args:
        features: Multiplicity (number of channels) used for every irrep.
        max_degree: Largest degree ``l`` kept from the tensor product.
        irreps_out: Per-channel output irreps, e.g. ``"0e + 0o"``. Every irrep is
            emitted ``features`` times. ``None`` keeps ``features`` copies of every
            irrep the tensor product produces.
        use_gaunt: Gaunt tensor products are not available in e3nn; must be False.
        irreps_in: Irreps of the last axis of the input tensor (required).

    Raises:
        NotImplementedError: If ``use_gaunt`` is true.
        ValueError: If ``irreps_in`` is missing, or if no tensor-product path has
            ``l <= max_degree``.
    """

    def __init__(
        self,
        features: int,
        max_degree: int,
        irreps_out=None,
        use_gaunt: bool = False,
        irreps_in=None,
    ):
        super().__init__()
        if use_gaunt:
            raise NotImplementedError("Gaunt tensor products are not supported by e3nn.")
        if irreps_in is None:
            raise ValueError(
                "irreps_in is required: torch tensors do not carry irreps metadata."
            )

        self.features = features
        self.max_degree = max_degree
        self.use_gaunt = use_gaunt
        self.irreps_in = o3.Irreps(irreps_in)

        # 1. Two independent projections (x1, x2) onto `features` copies of each input irrep.
        irreps_hidden = o3.Irreps(
            [(features, ir) for ir in self._distinct_irreps(self.irreps_in)]
        )
        self.linear1 = o3.Linear(self.irreps_in, irreps_hidden)
        self.linear2 = o3.Linear(self.irreps_in, irreps_hidden)

        # 2. One weighted channel-wise path per (ir1, ir2, ir_out) with ir_out.l <= max_degree.
        tp_irreps, instructions = [], []
        for i1, (_, ir1) in enumerate(irreps_hidden):
            for i2, (_, ir2) in enumerate(irreps_hidden):
                for ir_out in ir1 * ir2:
                    if ir_out.l <= max_degree:
                        instructions.append((i1, i2, len(tp_irreps), "uuu", True))
                        tp_irreps.append((features, ir_out))
        if not instructions:
            raise ValueError(
                f"no tensor-product path of {irreps_hidden} has degree <= {max_degree}"
            )
        self.tensor_product = o3.TensorProduct(
            irreps_hidden, irreps_hidden, tp_irreps, instructions
        )

        # 3. Output irreps: `features` channels of each requested (or produced) irrep.
        if irreps_out is None:
            requested = [
                (features, ir)
                for ir in self._distinct_irreps(self.tensor_product.irreps_out)
            ]
        else:
            requested = [(features * mul, ir) for mul, ir in o3.Irreps(irreps_out)]
        self.irreps_out = o3.Irreps(requested).sort().irreps.simplify()
        self.linear_out = o3.Linear(self.tensor_product.irreps_out, self.irreps_out)

    @staticmethod
    def _distinct_irreps(irreps):
        """Return the irreps of ``irreps`` without multiplicities, in first-seen order."""
        distinct = []
        for _, ir in irreps:
            if ir not in distinct:
                distinct.append(ir)
        return distinct

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape ``(..., irreps_in.dim)`` to ``(..., irreps_out.dim)``."""
        x1 = self.linear1(x)
        x2 = self.linear2(x)
        return self.linear_out(self.tensor_product(x1, x2))


class E3NNModel(torch.nn.Module):
    """Tetracube classifier with three stages.

    1. Spherical-harmonic features.
    2. Two ``TensorDense`` layers.
    3. A linear readout to logits.

    The class attributes hold the default hyper-parameters.

    Args:
        num_classes: Number of output logits. Defaults to ``tetracubes.shape[0]``.
    """

    features = 8
    max_degree = 3
    use_gaunt: bool = False

    def __init__(self, num_classes: Optional[int] = None):
        super().__init__()
        if num_classes is None:
            num_classes = tetracubes.shape[0]

        # Build the layers once here, not inside forward, so their weights are
        # registered parameters that persist and train across calls.
        # Assumes basis() returns harmonics of degrees 0..max_degree in e3nn's layout.
        irreps_sh = o3.Irreps.spherical_harmonics(self.max_degree)

        # No pseudoscalar features yet.
        self.tensor_dense1 = TensorDense(
            features=self.features,
            max_degree=self.max_degree,
            irreps_out=None,
            use_gaunt=self.use_gaunt,
            irreps_in=irreps_sh,
        )
        # Pseudoscalar (0o) features arise in this tensor product (e.g. 1e x 1o -> 0o).
        self.tensor_dense2 = TensorDense(
            features=self.features,
            max_degree=self.max_degree,
            irreps_out="0e + 0o",
            use_gaunt=self.use_gaunt,
            irreps_in=self.tensor_dense1.irreps_out,
        )
        self.readout = torch.nn.Linear(self.tensor_dense2.irreps_out.dim, num_classes)

    def forward(self, shapes: torch.Tensor) -> torch.Tensor:
        """Return logits ``(..., num_classes)`` for ``shapes`` of shape ``(..., 4, 3)``."""
        # 1. Center shapes at the origin (translational invariance).
        #    Done out of place, so the caller's tensor is not modified.
        shapes = shapes - shapes.mean(dim=-2, keepdim=True)

        # 2. Expand cube midpoints in spherical harmonics and average over the cubes.
        x = basis(shapes, num=self.features, max_degree=self.max_degree)
        # Restore (..., cubes, sh_dim) in case basis() flattened the leading axes.
        x = x.reshape(*shapes.shape[:-1], x.shape[-1]).mean(dim=-2)

        # 3. Equivariant feature transformations.
        x = self.tensor_dense1(x)
        x = self.tensor_dense2(x)

        # 4. Predict logits with an ordinary dense layer.
        return self.readout(x)