# Are High-Degree Representations Really Unnecessary in Equivariant Graph Neural Networks?

*Background*: 
As symmetric graphs, the five regular polyhedra are invariant under rotations by certain angles. Interestingly, we proved theoretically that any equivariant GNN on these symmetric graphs degenerates to a zero function if the degree of its representation is fixed to 1.

*Experiment*: 
In this notebook, we evaluate equivariant layers on their ability to distinguish the orientations of the five regular polyhedra.

In [None]:
import numpy as np
import torch
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_undirected
import e3nn
from functools import partial

# Report the versions of the core libraries so that runs can be reproduced.
for template, module in (
    ("PyTorch version {}", torch),
    ("PyG version {}", torch_geometric),
    ("e3nn version {}", e3nn),
):
    print(template.format(module.__version__))

In [2]:
from experiments.utils.plot_utils import plot_2d, plot_3d
from experiments.utils.train_utils import run_experiment
from models.schnet import SchNetModel
from models.egnn import EGNNModel
from models.gvpgnn import GVPGNNModel
from models.tfn import TFNModel
from models.mace import MACEModel
from models.hegnn import HEGNNModel

In [None]:
# Select the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

In [4]:
from experiments.data.reg_poly_coords import pos_dict, ver_num_dict

def get_edge(pos):
    """Return every unordered vertex pair of a point set as an edge list.

    Args:
        pos: Tensor whose first dimension indexes the vertices. Only
            ``pos.size(0)`` is read; the coordinate values are not used.

    Returns:
        A CPU ``torch.long`` tensor of shape ``(2, n * (n - 1) // 2)`` where
        ``n = pos.size(0)``. Column ``k`` holds one pair ``(i, j)`` with
        ``i < j``; pairs are ordered by ``i`` and then by ``j``. With fewer
        than two vertices the result has shape ``(2, 0)``.

    Raises:
        AssertionError: If ``pos`` is not a ``torch.Tensor``.
    """
    # Kept as an assertion so the exception type seen by callers is unchanged.
    assert isinstance(pos, torch.Tensor), "pos must be a torch.Tensor"
    num_node = pos.size(0)
    # The strict upper triangle (offset=1) of an n x n grid enumerates exactly
    # the pairs i < j in row-major order, which is the order the nested Python
    # loops used to produce, without building Python lists element by element.
    return torch.triu_indices(
        num_node, num_node, offset=1, dtype=torch.long, device="cpu"
    )

# create environments
def create_envs(face_num=20):
    """Build a two-sample dataset from one polyhedron in two orientations.

    Args:
        face_num: Key used to look up the vertex count in ``ver_num_dict``
            and the vertex coordinates in ``pos_dict``.

    Returns:
        A list ``[env0, env1]`` of ``Data`` objects. Both share the same
        all-zero ``atoms`` tensor and the same connectivity (every vertex
        pair from ``get_edge``, made undirected). ``env0`` keeps the
        coordinates from ``pos_dict`` and has label 0; ``env1`` uses those
        coordinates right-multiplied by a matrix drawn from
        ``e3nn.o3.rand_matrix()`` and has label 1.
    """
    atoms = torch.zeros(ver_num_dict[face_num], dtype=torch.long)
    reference_pos = pos_dict[face_num]
    pair_index = get_edge(reference_pos)

    def _as_graph(coords, label):
        # Wrap one orientation as a graph with a symmetrised edge list.
        return Data(
            atoms=atoms,
            edge_index=to_undirected(pair_index),
            pos=coords,
            y=torch.LongTensor([label]),
        )

    # Environment 0: coordinates exactly as stored, label 0.
    reference_env = _as_graph(reference_pos, 0)

    # Environment 1: the same vertices after applying a sampled matrix, label 1.
    rotated_env = _as_graph(reference_pos @ e3nn.o3.rand_matrix(), 1)

    return [reference_env, rotated_env]

In [None]:
# Create the dataset: the chosen polyhedron in the two orientations returned by
# create_envs (labels 0 and 1), and plot each sample.
face_num = 4    # Only select from 4, 6, 8, 12, 20
dataset = create_envs(face_num)
for data in dataset:
    plot_3d(data, lim=2)

# Create dataloaders
# All three loaders wrap the same two-sample dataset. The training loader
# yields one shuffled sample per batch; the validation and test loaders yield
# both samples in a single, unshuffled batch.
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(dataset, batch_size=2, shuffle=False)
test_loader = DataLoader(dataset, batch_size=2, shuffle=False)

In [None]:
# Set parameters
# Key into the constructor table below; must be one of
# "schnet", "egnn", "gvp", "tfn", "mace", "hegnn".
model_name = "gvp"

# `correlation` is forwarded only to MACEModel; `max_ell` is forwarded to
# TFNModel, MACEModel and HEGNNModel. Neither affects the other entries.
correlation = 2
max_ell = 5

# Table of model constructors with their model-specific keyword arguments
# pre-bound. The entry named by `model_name` is selected and instantiated with
# the shared arguments (one layer, in_dim=1, out_dim=2).
# NOTE(review): out_dim=2 presumably corresponds to the two labels (0 and 1)
# produced by create_envs — confirm against the model classes.
model = {
    "schnet": SchNetModel,
    "egnn": partial(EGNNModel, equivariant_pred=True),
    "gvp": partial(GVPGNNModel, equivariant_pred=True),
    "tfn": partial(TFNModel, max_ell=max_ell, equivariant_pred=True),
    "mace": partial(MACEModel, max_ell=max_ell, correlation=correlation, equivariant_pred=True),
    "hegnn": partial(HEGNNModel, max_ell=max_ell, all_ell=False, equivariant_pred=True),
}[model_name](num_layers=1, in_dim=1, out_dim=2)

# Run the experiment on the selected model with the three loaders defined
# earlier and unpack the three returned values.
# NOTE(review): n_epochs and n_times are presumably the epochs per run and the
# number of repeated runs — confirm in experiments.utils.train_utils.
best_val_acc, test_acc, train_time = run_experiment(
    model, 
    dataloader,
    val_loader, 
    test_loader,
    n_epochs=100,
    n_times=10,
    device=device,
    verbose=False
)