In [1]:
from torch_geometric.nn import MessagePassing
from torch_geometric.transforms import Distance
import torch_geometric as tg
import torch_geometric.data.batch as tg_batch

from torch_cluster import radius_graph


import torch
import torch
import torch.nn as nn

import atom3dutils

from e3nn import o3
from e3nn.o3 import Irreps
from e3nn.nn import Gate

from scipy.spatial.transform import Rotation
import unittest
import importlib  

# The package directory is named "gvp-pytorch"; a hyphen cannot appear in an
# `import` statement, so the package and its submodules are loaded by name.
gvp = importlib.import_module("gvp-pytorch.gvp")
# Bind the submodules as attributes so they are reachable as gvp.models / gvp.data.
gvp.models = importlib.import_module("gvp-pytorch.gvp.models")
gvp.data = importlib.import_module("gvp-pytorch.gvp.data")

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# REPLACE WITH SIMPLE IMPORT

In [2]:
def balanced_irreps(hidden_features, lmax):
    """Build irreps that spread ``hidden_features`` evenly over the orders 0..lmax.

    Every order l receives roughly the same share of the feature budget. Whatever
    is left unused after integer rounding is filled with extra scalar (0e) channels.
    """
    # Feature budget available to each spherical-harmonic order.
    budget_per_order = int(hidden_features / (lmax + 1))

    # An order-l irrep occupies 2l + 1 components, so fewer copies of it fit.
    terms = [
        f"{int(budget_per_order / (2 * l + 1))}x{ir!s}"
        for l, (_, ir) in enumerate(Irreps.spherical_harmonics(lmax))
    ]
    result = Irreps("+".join(terms))

    # Top up with trivial irreps when rounding left part of the budget unused.
    missing = hidden_features - result.dim
    if missing <= 0:
        return result
    return (Irreps(f"{missing}x0e") + result).simplify()

def compute_gate_irreps(irreps_out):
    """Split ``irreps_out`` into the pieces needed for a gated non-linearity.

    Returns the tuple ``(irreps_scalars, irreps_gated, irreps_gates)``: the l = 0
    entries, the l > 0 entries, and ``0e`` gate scalars whose multiplicities
    mirror those of the l > 0 entries (merged into a single entry by ``simplify``).
    """
    scalar_entries, gated_entries = [], []
    for mul, ir in irreps_out:
        # Order-0 entries are plain scalars; every higher order has to be gated.
        bucket = scalar_entries if ir.l == 0 else gated_entries
        bucket.append((mul, ir))

    irreps_gated = Irreps(gated_entries)
    gate_entries = [(mul, "0e") for mul, _ in irreps_gated]
    return Irreps(scalar_entries), irreps_gated, Irreps(gate_entries).simplify()

class Convolution(nn.Module):
    """Tensor-product convolution whose weights come from a radial network.

    Input features (``irreps_in1``) are combined with edge attributes
    (``irreps_in2``) by a fully connected tensor product into ``irreps_out``.
    The tensor product holds no weights of its own; a ``RadialNet`` produces
    them from the distance passed to ``forward``.
    """

    def __init__(self, irreps_in1, irreps_in2, irreps_out):
        super().__init__()
        self.irreps_in1, self.irreps_in2, self.irreps_out = irreps_in1, irreps_in2, irreps_out

        # Weights are neither stored in the tensor product nor shared between rows:
        # they are supplied on every call.
        tp_options = dict(
            irrep_normalization="component",
            path_normalization="element",
            internal_weights=False,
            shared_weights=False,
        )
        self.tp = o3.FullyConnectedTensorProduct(
            irreps_in1, irreps_in2, irreps_out, **tp_options
        )

        # The radial network emits exactly as many values as the tensor product needs.
        self.radial_net = RadialNet(self.tp.weight_numel)

    def forward(self, x, rel_pos_sh, distance):
        """Apply the tensor product with distance-dependent weights.

        x:          features, shape [E, irreps_in1.dim]
        rel_pos_sh: edge attributes, shape [E, irreps_in2.dim]
        distance:   edge lengths, shape [E, 1]
        """
        return self.tp(x, rel_pos_sh, self.radial_net(distance))

class RadialNet(nn.Module):
    """Small MLP that maps an edge length to tensor-product weights.

    The distance is first expanded in a Bessel radial basis and then passed
    through ``Linear -> SiLU -> Linear``.

    Args:
        num_weights: number of values produced per edge.
        num_basis: number of Bessel basis functions (default 10).
        cutoff: cutoff radius handed to the Bessel basis (default 4).
        hidden_features: width of the hidden linear layer (default 16).
        min_dist: lower bound applied to the distances before the basis
            expansion, so that exactly-zero distances (e.g. self-loops) are
            not fed to the basis.
    """

    def __init__(self, num_weights, num_basis=10, cutoff=4.0, hidden_features=16, min_dist=1e-8):
        super().__init__()

        self.min_dist = min_dist
        basis = tg.nn.models.dimenet.BesselBasisLayer(num_basis, cutoff=cutoff)

        self.net = nn.Sequential(basis,
                                 nn.Linear(num_basis, hidden_features),
                                 nn.SiLU(),
                                 nn.Linear(hidden_features, num_weights))

    def forward(self, dist):
        """Return weights of shape [E, num_weights] for distances of shape [E, 1]."""
        # NOTE(review): at dist == 0 the Bessel envelope appears to evaluate 1/0 * sin(0)
        # and return NaN; confirm against the installed torch_geometric version.
        # Clamping only alters distances below `min_dist`.
        dist = torch.clamp(dist.squeeze(-1), min=self.min_dist)
        return self.net(dist)


class ConvLayerSE3(tg.nn.MessagePassing):
    """Message-passing layer: tensor-product convolution, sum aggregation, optional gate.

    Args:
        irreps_in1: irreps of the node features.
        irreps_in2: irreps of the edge attributes.
        irreps_out: irreps of the layer output.
        activation: when True, a gated non-linearity is applied after aggregation
            (SiLU on the scalars, sigmoid gates on the higher-order irreps);
            when False the aggregated messages are returned unchanged.
    """

    def __init__(self, irreps_in1, irreps_in2, irreps_out, activation=True):
        super().__init__(aggr="add")

        self.irreps_in1 = irreps_in1
        self.irreps_in2 = irreps_in2
        self.irreps_out = irreps_out

        if activation:
            irreps_scalars, irreps_gated, irreps_gates = compute_gate_irreps(irreps_out)
            # One activation per irreps entry: the lists may be empty when the output
            # has no scalars or no higher-order irreps.
            self.gate = Gate(
                irreps_scalars, [nn.SiLU() for _ in irreps_scalars],
                irreps_gates, [nn.Sigmoid() for _ in irreps_gates],
                irreps_gated,
            )
            # The convolution must emit exactly the layout the gate consumes
            # (scalars, gate scalars and gated irreps).
            conv_irreps_out = self.gate.irreps_in
        else:
            self.gate = nn.Identity()
            # Without a gate there is nothing to consume extra gate scalars, so the
            # convolution produces `irreps_out` directly.
            conv_irreps_out = irreps_out

        self.conv = Convolution(irreps_in1, irreps_in2, conv_irreps_out)

    def forward(self, edge_index, x, rel_pos_sh, dist):
        """Run one round of message passing.

        edge_index: graph connectivity passed to ``propagate``
        x:          node features
        rel_pos_sh: per-edge attributes fed to the convolution
        dist:       per-edge distances fed to the convolution
        """
        x = self.propagate(edge_index, x=x, rel_pos_sh=rel_pos_sh, dist=dist)
        x = self.gate(x)
        return x

    def message(self, x_j, rel_pos_sh, dist):
        """Compute the message of every edge from the neighbour features ``x_j``."""
        return self.conv(x_j, rel_pos_sh, dist)

class ConvModel(nn.Module):
    """Stack of ``ConvLayerSE3`` layers followed by mean pooling over each graph.

    Args:
        irreps_in: irreps of the embedded node features.
        irreps_hidden: irreps used between the layers.
        irreps_edge: irreps used for the spherical-harmonic edge attributes.
        irreps_out: irreps of the last layer's output.
        depth: requested number of layers. A first and a last layer are always
            created, so at least two layers are built even when ``depth < 2``.
        max_z: number of rows of the atom-type embedding table.
    """

    def __init__(self, irreps_in, irreps_hidden, irreps_edge, irreps_out, depth, max_z:int=atom3dutils._NUM_ATOM_TYPES):
        super().__init__()

        self.irreps_in = irreps_in
        self.irreps_hidden = irreps_hidden
        self.irreps_edge = irreps_edge
        self.irreps_out = irreps_out

        # Learned feature vector of width irreps_in.dim for every atom type.
        self.embedder = nn.Embedding(max_z, irreps_in.dim)

        self.layers = nn.ModuleList()
        self.layers.append(ConvLayerSE3(irreps_in, irreps_edge, irreps_hidden))
        for _ in range(depth - 2):
            self.layers.append(ConvLayerSE3(irreps_hidden, irreps_edge, irreps_hidden))
        # The last layer is linear: no gate is applied to the final output.
        self.layers.append(ConvLayerSE3(irreps_hidden, irreps_edge, irreps_out, activation=False))

    def forward(self, graph):
        """Return one pooled output per graph.

        ``graph`` must provide ``edge_index``, ``z`` (atom-type indices),
        ``pos`` (node positions) and ``batch`` (graph index of every node).
        """
        edge_index = graph.edge_index
        z = graph.z
        pos = graph.pos
        batch = graph.batch

        # Prepare quantities for convolutional layers
        # Index of source and target node
        src, tgt = edge_index[0], edge_index[1]
        # Vector pointing from the source node to the target node
        rel_pos = pos[tgt] - pos[src]
        # That vector in Spherical Harmonics
        rel_pos_sh = o3.spherical_harmonics(self.irreps_edge, rel_pos, normalize=True)
        # The norm of that vector
        dist = torch.linalg.vector_norm(rel_pos, dim=-1, keepdims=True)

        # Embed atom types
        x = self.embedder(z)

        # Convolve nodes
        for layer in self.layers:
            x = layer(edge_index, x, rel_pos_sh, dist)

        # Drop the trailing axis when the output is 1-dimensional
        x = x.squeeze(-1)

        # Global pooling of node features
        x = tg.nn.global_mean_pool(x, batch)
        return x

# Testing Equivariance: Robustness in Protein Models with GVP

Equivariance describes how a model's output changes when its input is transformed. A model is said to be *invariant* when its output is unaffected by the transformation, and *equivariant* when its output transforms in the same, predictable way as the input. For example, in computer vision tasks an invariant classifier produces the same prediction regardless of the translation, rotation, or scaling applied to the input image, whereas an equivariant segmentation model produces a mask that is translated or rotated together with the image. Verifying and quantifying equivariance to rotations is crucial to establish the reliability and robustness of models that analyze protein structures.

In this demo, we investigate the equivariance to rotations of the GVP and our implementation of the steerable MLP. Our objective is to evaluate the robustness and consistency of the models' output when subjected to rotations.

To do so, we first create `300` random nodes, with each node having `100` scalar features and `16` vector features. The edges have `32` scalar features and `1` vector feature. We generate these node and edge features randomly. Lastly, we define the edge index, which records for every edge its source node and its target node.

In [3]:
n_nodes, n_edges = 300, 10000
node_dim = (100, 16)  # (scalar channels, vector channels) per node
edge_dim = (32, 1)    # (scalar channels, vector channels) per edge

# Random node/edge features and random connectivity, all on `device`
nodes = gvp.randn(n_nodes, node_dim, device=device)
edges = gvp.randn(n_edges, edge_dim, device=device)
edge_index = torch.randint(0, n_nodes, (2, n_edges), device=device)

# Report the shape of every tensor that was generated
for _label, _tensor in (
    ("Node scalar features", nodes[0]),
    ("Node vector features", nodes[1]),
    ("Edge scalar features", edges[0]),
    ("Edge vector features", edges[1]),
    ("Edge index", edge_index),
):
    print(_label, _tensor.shape)

Node scalar features torch.Size([300, 100])
Node vector features torch.Size([300, 16, 3])
Edge scalar features torch.Size([10000, 32])
Edge vector features torch.Size([10000, 1, 3])
Edge index torch.Size([2, 10000])


Now, let's enhance the clarity of the tests regarding our desired objectives. Our aim is to ensure that the scalar output features of the model remain unchanged under rotation, as scalars possess inherent rotation invariance. Additionally, we expect the vector features to exhibit rotation equivariance, implying that the model's output, after rotating the original vector input, should be identical to rotating the output obtained by passing the original vector input through the model. Or also in another representation:

$$\begin{align}
\text{Invariance:}   && \text{model}(\text{rotation}(\text{scalars})) &= \text{model}(\text{scalars})         \\
\text{Equivariance:} && \text{model}(\text{rotation}(\text{vectors})) &= \text{rotation}(\text{model}(\text{vectors})).
\end{align}$$

In [4]:
def test_equivariance_GVP(model, nodes, edges):
    
    random = torch.as_tensor(Rotation.random().as_matrix(), 
                             dtype=torch.float32, device=device)
    
    with torch.no_grad():
    
        out_s, out_v = model(nodes, edges)
        n_v_rot, e_v_rot = nodes[1] @ random, edges[1] @ random
        out_v_rot = out_v @ random
        out_s_prime, out_v_prime = model((nodes[0], n_v_rot), (edges[0], e_v_rot))
        
        assert torch.allclose(out_s, out_s_prime, atol=1e-5, rtol=1e-4)
        assert torch.allclose(out_v_rot, out_v_prime, atol=1e-5, rtol=1e-4)

Then, we define a class that has tests for the GVP, GVP with vector gating and GVP convolutional layer.

In [5]:
class EquivarianceTestGVP(unittest.TestCase):
    """Rotation-equivariance checks for the GVP building blocks.

    Covers the plain GVP, the GVP with vector gating, and the GVP convolutional
    layer with vector gating. Each test wraps the module in a function with the
    ``(h_V, h_E)`` signature expected by ``test_equivariance_GVP``.
    """

    def test_gvp(self):
        layer = gvp.GVP(node_dim, node_dim).to(device).eval()

        def run(h_V, h_E):
            # A bare GVP only consumes node features; the edge features are ignored.
            return layer(h_V)

        test_equivariance_GVP(run, nodes, edges)

    def test_gvp_vector_gate(self):
        layer = gvp.GVP(node_dim, node_dim, vector_gate=True).to(device).eval()

        def run(h_V, h_E):
            # As above, the edge features are not used.
            return layer(h_V)

        test_equivariance_GVP(run, nodes, edges)

    def test_gvp_conv_layer_vector_gate(self):
        layer = gvp.GVPConvLayer(node_dim, edge_dim, vector_gate=True).to(device).eval()

        def run(h_V, h_E):
            # The convolutional layer additionally needs the connectivity.
            return layer(h_V, edge_index, h_E, autoregressive_x=h_V)

        test_equivariance_GVP(run, nodes, edges)

In [6]:
# Run every TestCase defined in this namespace. `argv` is overridden so that
# unittest does not parse the host process's own command-line arguments, and
# `exit=False` stops it from calling sys.exit() once the run has finished.
if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

...
----------------------------------------------------------------------
Ran 3 tests in 0.122s

OK


We can conclude from this test that the scalar features are indeed invariant to rotation and the vector features are equivariant to rotation.

# Testing Equivariance: Robustness in Protein Models with steerable MLP

Now we can perform the same tests using the steerable MLP. The only thing we will need to change is how the input is handled. Since the steerable MLP works with irreducible representations, we will first define these for the nodes, edges and output.

In [7]:
# Node features: 10 scalars (0e) and 10 even-parity vectors (1e)
irreps_node = Irreps("10x0e+10x1e")
# Edge attributes: 32 scalars (0e) and 1 odd-parity vector (1o)
irreps_edge = Irreps("32x0e+1x1o")
# Layer output: 20 scalars (0e) and 10 even-parity vectors (1e)
irreps_out = Irreps("20x0e + 10x1e")

# Width of a node feature vector
dim_emb = irreps_node.dim

for _label, _value in (
    ("Input irreps", irreps_node),
    ("Edge irreps", irreps_edge),
    ("Output irreps", irreps_out),
    ("Dim embedding irreps:", dim_emb),
):
    print(_label, _value)

Input irreps 10x0e+10x1e
Edge irreps 32x0e+1x1o
Output irreps 20x0e+10x1e
Dim embedding irreps: 40


Create a random graph using random positions and edges

In [8]:
# Positions of the nodes
pos = torch.randn(size=(n_nodes, 3))

# Node embedding
x = torch.randn(size=(n_nodes, dim_emb))

# Random connectivity (which nodes are connected does not matter for this test).
# Every target is drawn as "source + non-zero offset", so no edge is a self-loop:
# a self-loop would have a zero-length relative position. The indices live on the
# same device as `pos` and `x`, which are indexed with them below.
edge_src = torch.randint(0, n_nodes, (n_edges,), device=pos.device)
edge_offset = torch.randint(1, n_nodes, (n_edges,), device=pos.device)
edge_index = torch.stack([edge_src, (edge_src + edge_offset) % n_nodes])

# All nodes belong to one graph (index 0) -> a single protein structure
batch = torch.zeros(n_nodes, dtype=torch.long, device=pos.device)

# Edge features: pos[source] - pos[target], i.e. the vector pointing from the
# target node to the source node
rel_pos = pos[edge_index[0]] - pos[edge_index[1]]

# Edge features in Spherical Harmonics
rel_pos_sh = o3.spherical_harmonics(irreps_edge, rel_pos, normalize=True)

# Norm of the edge features
dist = torch.linalg.vector_norm(rel_pos, dim=-1, keepdims=True)

The input is defined, so let's define the model: a convolutional layer that is equivariant to 3D rotations.

In [9]:
# A single gated convolution layer mapping the node irreps to the output irreps
model = ConvLayerSE3(irreps_node, irreps_edge, irreps_out)

# Reference output on the un-rotated graph
out = model(edge_index, x, rel_pos_sh, dist)

Generate a random 3D rotation matrix and get the right representations for the input and output irreducible representations.

In [10]:
# Random 3x3 rotation and its representation on the input and output irreps
rot = o3.rand_matrix()
D_in, D_out = (irreps.D_from_matrix(rot) for irreps in (irreps_node, irreps_out))

We can rotate the output of the convolutional layer.

In [11]:
# Rotate after: apply the output representation to the reference output
out_rot_after = torch.matmul(out, D_out.T)

Or rotate the input of the convolutional layer first and then pass the rotated input through the layer.

In [12]:
# Rotate before: rotate the positions, then rebuild the edge attributes from them
pos_rot = torch.matmul(pos, rot.T)
rel_pos_rot = pos_rot[edge_index[0]] - pos_rot[edge_index[1]]
rel_pos_sh_rot = o3.spherical_harmonics(irreps_edge, rel_pos_rot, normalize=True)

# Node features are rotated with the input representation; `dist` is reused
# unchanged from the un-rotated graph.
out_rot_before = model(edge_index, torch.matmul(x, D_in.T), rel_pos_sh_rot, dist)

In [13]:
# Rotating the output and rotating the input must give the same result.
assert torch.allclose(
    out_rot_after,
    out_rot_before,
    rtol=1e-4,
    atol=1e-4,
    equal_nan=True,  # NaN entries at the same position are counted as equal
)

From this, we can conclude that the model is equivariant to rotation.