In [10]:
import torch
from equiformer_pytorch import Equiformer

# Init model
model = Equiformer(
    dim_in=(20, 3),        # input channels per degree: 20 for type-0, 3 for type-1
    input_degrees=2,       # the inputs carry degrees 0 and 1
    dim=(4, 4, 2),         # channels per degree (0, 1, 2); type-0 output has dim[0] = 4
    dim_head=(4, 4, 4),
    heads=(2, 2, 2),
    num_degrees=3,
    reduce_dim_out=False,  # keep per-point, per-channel outputs (see shapes below)
)

# Toy input: 1 batch of 7 points.
# - The first 5 points are real and lie along the z-axis at z = 0..4.
# - The last 2 are padding: coordinates set to `coors_mask` and excluded via `mask`.
#
# Fiber features are a dict keyed by degree. Each entry has shape
# (batch, points, channels, 2 * degree + 1):
# - type-0 features are scalars, so the last dim is 1;
# - type-1 features are 3-vectors, so the last dim is 3.
coors_mask = torch.zeros(3)
feats = {
    0: torch.randn(1, 7, 20, 1),
    # Last dim must be 2*1+1 = 3. A trailing 1 here would be broadcast rather
    # than rejected, silently turning the vector input into a scalar per channel.
    1: torch.randn(1, 7, 3, 3),
}
real_points = torch.tensor([[0, 0, z] for z in range(5)]).float()        # (5, 3)
coors = torch.cat([real_points, coors_mask.expand(2, 3)]).unsqueeze(0)  # (1, 7, 3)
mask = torch.tensor([[True] * 5 + [False] * 2])                        # (1, 7), bool

# Forward pass
out = model(feats, coors, mask)
out.type0  # invariant type-0 features    - (1, 7, 4)
# out.type1  # equivariant type-1 features - expected (1, 7, 4, 3)

tensor([[[-0.7798, -1.7470,  0.4664, -0.3501],
         [-0.9648, -1.1258, -0.4793, -1.2538],
         [ 1.2645, -0.4801,  1.2091,  0.8419],
         [-1.0893,  0.4631, -0.9827, -1.2779],
         [-0.6973, -1.8270, -0.3994, -0.1276],
         [-0.5780, -1.0073, -0.0454,  1.6276],
         [ 0.6491, -1.2229, -0.5996, -1.3129]]],
       grad_fn=<ReshapeAliasBackward0>)