# Are High-Degree Representations Really Unnecessary in Equivariant Graph Neural Networks?

*Background*: 
As symmetric graphs, the five regular polyhedra are invariant under rotations by certain angles. Interestingly, we prove theoretically that any equivariant GNN on these symmetric graphs degenerates to a zero function if the degree of its representations is fixed to 1.

*Experiment*: 
In this notebook, we evaluate equivariant layers on their ability to distinguish the orientation of an $L$-fold symmetric structure in 3D space. An [$L$-fold symmetric structure](https://en.wikipedia.org/wiki/Rotational_symmetry) does not change when rotated by an angle $\frac{2\pi}{L}$ around a point (in 2D) or an axis (in 3D).

In [None]:
import numpy as np
import random
import math
import torch
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_undirected
import e3nn
from functools import partial

# Report the library versions in use so results can be tied to an environment.
print(f"PyTorch version {torch.__version__}")
print(f"PyG version {torch_geometric.__version__}")
print(f"e3nn version {e3nn.__version__}")

In [None]:
from experiments.utils.plot_utils import plot_2d, plot_3d
from experiments.utils.train_utils import run_experiment
from models.schnet import SchNetModel
from models.egnn import EGNNModel
from models.gvpgnn import GVPGNNModel
from models.tfn import TFNModel
from models.mace import MACEModel
from models.hegnn import HEGNNModel

In [None]:
# Pick the compute device: CUDA when PyTorch reports it available, else CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

In [None]:
def create_rotsym_envs(fold=3):
    """Build a two-graph dataset of ``fold``-fold symmetric star structures.

    Each graph has a central node at the origin (index 0) connected by
    undirected edges to ``fold`` spoke nodes (indices ``1..fold``). The first
    spoke sits at ``(1, 0, 0)``; the remaining spokes are obtained by applying
    ``e3nn.o3.matrix_z`` with angles ``2*pi/fold * k`` for ``k = 1..fold-1``.
    All nodes carry atom type 0.

    Args:
        fold: Number of spokes, i.e. the order of the rotational symmetry.
            Must be an integer >= 1.

    Returns:
        A list ``[data0, data1]`` of ``torch_geometric.data.Data`` objects
        with fields ``atoms``, ``edge_index``, ``pos`` and ``y``:

        * ``data0`` (``y == 0``) holds the structure as constructed above.
        * ``data1`` (``y == 1``) holds the same structure with its positions
          right-multiplied by a matrix drawn from ``e3nn.o3.rand_matrix()``.
          The draw is not seeded here, so it differs between calls unless the
          caller fixes the RNG state beforehand.

    Raises:
        ValueError: If ``fold`` is smaller than 1. Without this check the
            function would silently emit a graph whose ``pos`` has more rows
            than ``atoms`` has entries.
    """
    if fold < 1:
        raise ValueError(f"fold must be a positive integer, got {fold!r}")

    # Node features and connectivity are identical for both environments.
    atoms = torch.tensor([0] * (fold + 1), dtype=torch.long)
    directed_edges = torch.tensor(
        [[0] * fold, list(range(1, fold + 1))], dtype=torch.long
    )
    # Symmetrise once instead of once per environment.
    edge_index = to_undirected(directed_edges)

    # Environment 0: centre plus spokes generated from the first spoke.
    first_spoke = torch.tensor([1.0, 0.0, 0.0])
    pos = [
        torch.zeros(3),  # origin
        first_spoke,
    ]
    for count in range(1, fold):
        angle = torch.tensor([2 * math.pi / fold * count])
        # matrix_z returns one matrix per angle; drop the leading batch dim.
        R = e3nn.o3.matrix_z(angle).squeeze(0)
        pos.append(first_spoke @ R.T)
    pos = torch.stack(pos)

    dataset = [
        Data(atoms=atoms, edge_index=edge_index, pos=pos, y=torch.tensor([0])),
    ]

    # Environment 1: same structure, positions transformed by a random matrix.
    rotated_pos = pos @ e3nn.o3.rand_matrix()
    dataset.append(
        Data(atoms=atoms, edge_index=edge_index, pos=rotated_pos, y=torch.tensor([1]))
    )

    return dataset

In [None]:
# Build the two-environment dataset for the chosen symmetry order and
# visualise each environment.
fold = 2
dataset = create_rotsym_envs(fold)
for data in dataset:
    plot_3d(data, lim=1)

# All three loaders wrap the same dataset, one graph per batch. Only the
# training loader shuffles; validation and test iterate in a fixed order.
eval_loader_kwargs = dict(batch_size=1, shuffle=False)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(dataset, **eval_loader_kwargs)
test_loader = DataLoader(dataset, **eval_loader_kwargs)

In [None]:
# Experiment configuration: which architecture to build and its settings.
# `max_ell` is forwarded to the TFN, MACE and HEGNN constructors;
# `correlation` is forwarded to MACE only.
model_name = "hegnn"
correlation = 2
max_ell = 10

# Constructors keyed by model name, with model-specific options pre-bound.
model_factories = {
    "schnet": SchNetModel,
    "egnn": partial(EGNNModel, equivariant_pred=True),
    "gvp": partial(GVPGNNModel, equivariant_pred=True),
    "tfn": partial(TFNModel, max_ell=max_ell, equivariant_pred=True),
    "mace": partial(
        MACEModel,
        max_ell=max_ell,
        correlation=correlation,
        equivariant_pred=True,
    ),
    "hegnn": partial(
        HEGNNModel,
        max_ell=max_ell,
        all_ell=False,
        equivariant_pred=True,
    ),
}

# Instantiate the selected model: one layer, 1-dim input, 2 output classes.
build_model = model_factories[model_name]
model = build_model(num_layers=1, in_dim=1, out_dim=2)

# Run the experiment on the loaders defined earlier in the notebook.
best_val_acc, test_acc, train_time = run_experiment(
    model,
    dataloader,
    val_loader,
    test_loader,
    n_epochs=100,
    n_times=2,
    device=device,
    verbose=False,
)