In [None]:
import importlib
import unittest

import numpy as np
import torch
from scipy.spatial.transform import Rotation
from torch import nn

gvp = importlib.import_module("gvp-pytorch.gvp")
gvp.models = importlib.import_module("gvp-pytorch.gvp.models")
gvp.data = importlib.import_module("gvp-pytorch.gvp.data")

# Fix every RNG the tests draw from so that results (and failures) are
# reproducible: torch (presumably used by `gvp.randn` and GVP weight
# initialisation) and NumPy's global RandomState, which scipy's
# `Rotation.random()` falls back to when no random_state is passed.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Testing Equivariance: Robustness in Protein Models

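Equivariance describes how a model's output responds to particular transformations of its input. A model is *equivariant* to a transformation when transforming the input transforms the output in the same, predictable way. It is *invariant* when the output does not change at all. For example, in computer vision, a translation-equivariant segmentation model shifts its predicted mask when the input image is shifted. A translation-invariant classifier, by contrast, predicts the same label regardless of where the object appears in the image. Verifying and quantifying equivariance to rotations is crucial to establish the reliability and robustness of models that analyze protein structures, because a protein's properties do not depend on how its coordinates happen to be oriented in space.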

In this demo, we investigate the equivariance to rotations of the GVP and our implementation of the steerable MLP. Our objective is to evaluate the robustness and consistency of the models' output when subjected to rotations.

To do so, we first create `300` random nodes, each with `100` scalar features and `16` vector features. We then define a test class that checks the GVP, the GVP with vector gating, and our steerable MLP implementation.

In [104]:
# Shared random input for every test: `n_nodes` nodes, each carrying
# node_dim[0] scalar features and node_dim[1] vector features, stored as a
# (scalars, vectors) tuple -- nodes[0] / nodes[1] in the helpers below.
n_nodes = 300
node_dim = (100, 16)
nodes = gvp.randn(n_nodes, node_dim, device=device)

class EquivarianceTest(unittest.TestCase):

    """
    This is a class that tests whether the GVP with and without vector gating is 
    equivariant to rotation.
    
    First, the test_equivariance function is called where all the
    features are rotated using the same rotation matrix.

    Second, the test_equivariance_per_feature is called, where every
    feature is rotated where each feature is rotated with a different
    rotation matrix.
    """

    def test_gvp(self):
        model = gvp.GVP(node_dim, node_dim).to(device).eval()
        model_fn = lambda h_V: model(h_V)
        test_equivariance(model_fn, nodes)
        
        
    def test_gvp_vector_gate(self):
        model = gvp.GVP(node_dim, node_dim, vector_gate=True).to(device).eval()
        model_fn = lambda h_V: model(h_V)
        test_equivariance(model_fn, nodes)

    def test_steerable_MLP(self):
        return
        model # =  TODO
        model_fn = lambda h_V: model(h_V)
        test_equivariance(model_fn, nodes)

    # I think this does not hold; test doesn't pass.
    def test_gvp_per_feature(self):
        model = gvp.GVP(node_dim, node_dim).to(device).eval()
        model_fn = lambda h_V: model(h_V)
        test_equivariance_per_feature(model_fn, nodes)
    def test_gvp_vector_gate_per_feature(self):
        model = gvp.GVP(node_dim, node_dim, vector_gate=True).to(device).eval()
        model_fn = lambda h_V: model(h_V)
        test_equivariance_per_feature(model_fn, nodes)

Let us now state precisely what the tests check. Scalar output features should not change when the input vectors are rotated (rotation invariance), since scalars carry no orientation. Vector output features should rotate along with the input (rotation equivariance): running the model on rotated input vectors must give the same result as rotating the vector outputs obtained from the original input. Write the model as $(s_{\text{out}}, V_{\text{out}}) = \text{model}(s, V)$, where each row of $V$ is a 3-D vector and a rotation is right-multiplication by a rotation matrix $R$. The two properties then read:

$$\begin{align}
\text{Invariance:}   && s_{\text{out}}(s, VR) &= s_{\text{out}}(s, V)         \\
\text{Equivariance:} && V_{\text{out}}(s, VR) &= V_{\text{out}}(s, V)\,R.
\end{align}$$

In [105]:
def test_equivariance(model, nodes):
    
    random = torch.as_tensor(Rotation.random().as_matrix(), 
                             dtype=torch.float32, device=device)
    
    with torch.no_grad():
    
        out_s, out_v = model(nodes)
        n_v_rot = nodes[1] @ random
        out_v_rot = out_v @ random
        out_s_prime, out_v_prime = model((nodes[0], n_v_rot))
        
        assert torch.allclose(out_s, out_s_prime, atol=1e-5, rtol=1e-4)
        assert torch.allclose(out_v_rot, out_v_prime, atol=1e-5, rtol=1e-4)

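Next, we rotate a single vector feature at a time, each with its own random rotation, and check whether the outputs respond equivariantly. This check fails for both GVP variants (see the test output below). This is expected: rotating features independently is not a symmetry of the GVP, most likely because the GVP mixes vector features with each other before computing norms.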

In [106]:
def test_equivariance_per_feature(model, nodes):

    with torch.no_grad():
        for feature in range(nodes[1].shape[1]):
            random = torch.as_tensor(Rotation.random().as_matrix(), 
                             dtype=torch.float32, device=device)
    
            out_s, out_v = model(nodes)

            # rotate a single input vector feature
            n_v_rot = nodes[1][:, feature, :] @ random
            # rotate a single output vector feature
            out_v_rot = out_v[:, feature, :] @ random

            # concatenate the features back together
            full_n_v_rot = torch.cat((nodes[1][:, :feature, :], n_v_rot.unsqueeze(1), nodes[1][:, feature+1:, :]), 1)
            full_out_v_rot = torch.cat((out_v[:, :feature, :], out_v_rot.unsqueeze(1), out_v[:, feature+1:, :]), 1)

            out_s_prime, out_v_prime = model((nodes[0], full_n_v_rot))
            
            assert torch.allclose(out_s, out_s_prime, atol=1e-5, rtol=1e-4)
            assert torch.allclose(full_out_v_rot, out_v_prime, atol=1e-5, rtol=1e-4)

In [107]:
if __name__ == '__main__':
    # unittest treats argv[0] as the program name, so passing a placeholder
    # keeps the kernel's own sys.argv from being parsed as test options;
    # exit=False stops unittest from calling sys.exit() when it finishes,
    # which is what an interactive session like this notebook needs.
    notebook_argv = ['first-arg-is-ignored']
    unittest.main(argv=notebook_argv, exit=False)

.F.F.
FAIL: test_gvp_per_feature (__main__.EquivarianceTest.test_gvp_per_feature)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/tmp/ipykernel_44923/1428190658.py", line 40, in test_gvp_per_feature
    test_equivariance_per_feature(model_fn, nodes)
  File "/tmp/ipykernel_44923/2527260079.py", line 21, in test_equivariance_per_feature
    assert torch.allclose(out_s, out_s_prime, atol=1e-5, rtol=1e-4)
AssertionError

FAIL: test_gvp_vector_gate_per_feature (__main__.EquivarianceTest.test_gvp_vector_gate_per_feature)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/tmp/ipykernel_44923/1428190658.py", line 44, in test_gvp_vector_gate_per_feature
    test_equivariance_per_feature(model_fn, nodes)
  File "/tmp/ipykernel_44923/2527260079.py", line 21, in test_equivariance_per_feature
    assert torch.allclose(out_s, out_s_prime, atol=1e-5, rtol=1e-4)
AssertionE