In [47]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [48]:
import torch
from equiformer_pytorch import Equiformer


def get_toy_model():
    """Build the small `Equiformer` used by the toy experiments in this notebook.

    All hyperparameters are hard-coded; the model is constructed with
    `num_neighbors=1` and `reduce_dim_out=False`.

    Returns:
        A freshly initialised `Equiformer` instance.
    """
    # Collect the fixed hyperparameters in one place, then hand them over as
    # keyword arguments.
    config = dict(
        dim_in=(20, 3),
        input_degrees=2,
        dim=(4, 4, 2),
        dim_head=(4, 4, 4),
        heads=(2, 2, 2),
        num_degrees=3,
        reduce_dim_out=False,
        num_neighbors=1,
    )
    return Equiformer(**config)


def get_toy_input(coors_mask=torch.tensor([0, 0, 0]).float()):
    """Create one toy batch: 5 real points followed by 2 padding points.

    The real points lie on the z-axis at (0, 0, z) for z = 0..4. Both padding
    points receive the coordinates given by `coors_mask`. Every point
    (padding included) gets random features.

    Args:
        coors_mask: tensor of 3 coordinates used for both padding points.
            It is detached before being used.

    Returns:
        Tuple `(feats, coors, mask)` where
        - `feats` is a dict with key 0 -> `randn` tensor of shape (1, 7, 20, 1)
          and key 1 -> `randn` tensor of shape (1, 7, 3, 1),
        - `coors` is a float tensor of shape (1, 7, 3),
        - `mask` is a bool tensor of shape (1, 7), True for the 5 real points
          and False for the 2 padding points.
    """
    num_real, num_pad = 5, 2
    num_points = num_real + num_pad

    # Random features; key 0 is drawn before key 1 so the RNG stream is
    # consumed in a fixed order.
    feats = {}
    feats[0] = torch.randn(1, num_points, 20, 1)
    feats[1] = torch.randn(1, num_points, 3, 1)

    # Real points first, then the padding points, wrapped in a batch of one.
    real_points = [[0, 0, z] for z in range(num_real)]
    pad_points = [coors_mask.detach() for _ in range(num_pad)]
    coors = torch.tensor([real_points + pad_points]).float()

    mask = torch.tensor([[1] * num_real + [0] * num_pad]).bool()
    return feats, coors, mask


def train_toy(model, feats, coors, mask):
    """Run two plain-SGD steps on `model`, printing diagnostics at each step.

    On every step the model is called as `model(feats, coors, mask)`, the
    masked `type0` output and the loss are printed, and one optimizer update
    is applied. The loss is a dummy scalar: the norm of the sum of the masked
    `type0` output.

    Args:
        model: callable module exposing `.parameters()`; its output must have
            a `type0` attribute that can be indexed with `mask`.
        feats: features forwarded to the model unchanged.
        coors: coordinates forwarded to the model unchanged.
        mask: boolean mask forwarded to the model and used to select the
            entries of `type0` that contribute to the loss.

    Returns:
        None. The model parameters are updated in place.
    """
    learning_rate = 0.001
    num_steps = 2

    sgd = torch.optim.SGD(model.parameters(), lr=learning_rate)
    for _step in range(num_steps):
        sgd.zero_grad()

        prediction = model(feats, coors, mask)
        print('out.type0[mask]', prediction.type0[mask])

        # Dummy loss through type-0 features.
        dummy_loss = prediction.type0[mask].sum().norm()
        print('loss', dummy_loss)
        print()

        dummy_loss.backward()
        sgd.step()

In [49]:
# 1
# With the default all-zero `coors_mask`, the padding points coincide with node 0.
# Below, the expected 1-NN graph is not built: node 0 gets a dummy (masked-out)
# edge to padding point 5 instead of an edge to its closest real node, node 1.
model = get_toy_model()
feats, coors, mask = get_toy_input()
train_toy(model, feats, coors, mask)

EdgeInfo(neighbor_indices=tensor([[[5],
         [0],
         [1],
         [2],
         [3],
         [0],
         [0]]]), neighbor_mask=tensor([[[False],
         [ True],
         [ True],
         [ True],
         [ True],
         [False],
         [False]]]), edges=None)
out.type0[mask] tensor([[-0.5886, -0.5910, -0.7438,  1.6586],
        [-1.2187,  1.0213,  0.3353, -1.1659],
        [-1.6366,  0.5513, -0.6860, -0.7395],
        [ 0.3319,  0.4592,  0.0487,  1.9175],
        [-1.2321,  1.1638,  0.2361, -1.0353]], grad_fn=<IndexBackward0>)
loss tensor(1.9140, grad_fn=<NormBackward1>)

EdgeInfo(neighbor_indices=tensor([[[5],
         [0],
         [1],
         [2],
         [3],
         [0],
         [0]]]), neighbor_mask=tensor([[[False],
         [ True],
         [ True],
         [ True],
         [ True],
         [False],
         [False]]]), edges=None)
out.type0[mask] tensor([[-0.4579, -0.4649, -0.6095,  1.7904],
        [-1.3384,  1.0126,  0.4569, -0.9824],
        [

In [50]:
# 2
# The expected workaround is to set the padded coordinates to inf, so that padding
# points are never selected as the nearest neighbor of any real node.
# However, this then produces NaN outputs in the second training step.
model = get_toy_model()
feats, coors, mask = get_toy_input(coors_mask=torch.tensor([float('inf'), float('inf'), float('inf')]).float())
train_toy(model, feats, coors, mask)

EdgeInfo(neighbor_indices=tensor([[[1],
         [0],
         [1],
         [2],
         [3],
         [0],
         [0]]]), neighbor_mask=tensor([[[ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [False],
         [False]]]), edges=None)
out.type0[mask] tensor([[-0.3113, -0.7774, -1.2577,  1.3103],
        [ 0.7112,  0.8714,  1.3977, -0.8839],
        [-1.6713, -1.0266, -0.2549, -0.2964],
        [ 0.8971, -1.1418, -1.3702,  0.1183],
        [ 0.7703,  0.1469,  1.8398, -0.0075]], grad_fn=<IndexBackward0>)
loss tensor(0.9361, grad_fn=<NormBackward1>)

EdgeInfo(neighbor_indices=tensor([[[1],
         [0],
         [1],
         [2],
         [3],
         [0],
         [0]]]), neighbor_mask=tensor([[[ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [False],
         [False]]]), edges=None)
out.type0[mask] tensor([[nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [na

In [51]:
# 3
# Same setup as case 2 (inf padding coordinates), but run under autograd anomaly
# detection so that a backward function returning NaN raises an error.
with torch.autograd.set_detect_anomaly(True):
    model = get_toy_model()
    feats, coors, mask = get_toy_input(coors_mask=torch.tensor([float('inf'), float('inf'), float('inf')]).float())
    train_toy(model, feats, coors, mask)

EdgeInfo(neighbor_indices=tensor([[[1],
         [0],
         [1],
         [2],
         [3],
         [0],
         [0]]]), neighbor_mask=tensor([[[ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [False],
         [False]]]), edges=None)
out.type0[mask] tensor([[-0.0844,  1.4477, -1.3534, -0.2554],
        [-1.4856, -0.7888,  1.0603, -0.2158],
        [-0.0585,  0.4015, -1.6796, -1.0071],
        [ 0.9682,  1.1221, -0.7594,  1.1076],
        [-0.3963,  0.5135, -1.0182,  1.5945]], grad_fn=<IndexBackward0>)
loss tensor(0.8872, grad_fn=<NormBackward1>)



  File "/Users/anton/miniconda3/envs/ppiformer_m1/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/Users/anton/miniconda3/envs/ppiformer_m1/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/Users/anton/miniconda3/envs/ppiformer_m1/lib/python3.10/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/Users/anton/miniconda3/envs/ppiformer_m1/lib/python3.10/site-packages/traitlets/config/application.py", line 992, in launch_instance
    app.start()
  File "/Users/anton/miniconda3/envs/ppiformer_m1/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 711, in start
    self.io_loop.start()
  File "/Users/anton/miniconda3/envs/ppiformer_m1/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 215, in start
    self.asyncio_loop.run_forever()
  File "/Users/anton/miniconda3/envs/ppiformer_m1/lib/python3.10/asyncio/base_events.py", line

RuntimeError: Function 'MulBackward0' returned nan values in its 1th output.