Skip to content

Are multiplicities > 1 in irreps_in2 supported in uvw TensorProduct? #144

@mjrs33

Description

@mjrs33

Is support for irreps_in2 multiplicity > 1 in uvw mode (FullyConnectedTensorProduct) currently incomplete, or am I overlooking a usage detail?
Any guidance (or a confirmation that this is a bug) would be greatly appreciated.

import e3nn.o3 as o3
import openequivariance as oeq
import torch

def create_tps(irreps_in1, irreps_in2, irreps_out):
    """Build an e3nn and an OpenEquivariance tensor product with shared inputs.

    One "uvw" instruction is emitted for every (in1, in2, out) index triple
    whose output irrep is contained in the product of the two input irreps.
    Both libraries are configured from the same keyword arguments, with
    external (non-internal), non-shared weights.

    Args:
        irreps_in1: Irreps of the first operand.
        irreps_in2: Irreps of the second operand.
        irreps_out: Irreps of the result.

    Returns:
        A tuple ``(tp_e3nn, tp_oeq, x1, x2, w)``: the two tensor-product
        objects, a random sample for each operand, and a random weight
        tensor with ``tp_e3nn.weight_numel`` columns, all created on CUDA.
    """
    device = "cuda"

    instructions = []
    for idx_1, (_, irrep_1) in enumerate(irreps_in1):
        for idx_2, (_, irrep_2) in enumerate(irreps_in2):
            for idx_out, (_, irrep_out) in enumerate(irreps_out):
                # The product is re-evaluated for every candidate output on
                # purpose: it may be a one-shot iterator, so it is not cached.
                if irrep_out in irrep_1 * irrep_2:
                    instructions.append(
                        (idx_1, idx_2, idx_out, "uvw", True, 1.0)
                    )

    kwargs = dict(
        irreps_in1=irreps_in1,
        irreps_in2=irreps_in2,
        irreps_out=irreps_out,
        instructions=instructions,
        internal_weights=False,
        shared_weights=False,
        path_normalization="element",
        irrep_normalization="component",
    )

    # Construct both implementations before drawing any random numbers so the
    # draws below happen in a fixed order: x1, then x2, then w.
    tp_e3nn = o3.TensorProduct(**kwargs).cuda()
    tp_oeq = oeq.TensorProduct(oeq.TPProblem(**kwargs))

    x1 = irreps_in1.randn(1, -1, device=device)
    x2 = irreps_in2.randn(1, -1, device=device)
    w = torch.randn(1, tp_e3nn.weight_numel, device=device)

    return (tp_e3nn, tp_oeq, x1, x2, w)


# multiplicity=1
# Baseline case: irreps_in2 is a single 0e channel. The RNG is seeded once
# here, before any inputs are drawn.
torch.manual_seed(0)
irreps_in1 = o3.Irreps("8x0e+4x1e")
irreps_in2 = o3.Irreps("1x0e")
irreps_out = o3.Irreps("8x0e+4x1e")
tp_e3nn, tp_oeq, x1, x2, w = create_tps(irreps_in1, irreps_in2, irreps_out)
# Feed identical inputs and weights to both implementations.
out_e3nn = tp_e3nn(x1, x2, w)
out_oeq = tp_oeq(x1, x2, w)
# Largest absolute element-wise difference between the two outputs.
print(torch.abs(out_e3nn - out_oeq).max().item())
# -> 5.960464477539063e-08

# multiplicity=2
# Same setup except irreps_in2 now has multiplicity 2. The RNG is not
# re-seeded, so these inputs continue from the state left by the first case.
irreps_in1 = o3.Irreps("8x0e+4x1e")
irreps_in2 = o3.Irreps("2x0e")
irreps_out = o3.Irreps("8x0e+4x1e")
tp_e3nn, tp_oeq, x1, x2, w = create_tps(irreps_in1, irreps_in2, irreps_out)
out_e3nn = tp_e3nn(x1, x2, w)
out_oeq = tp_oeq(x1, x2, w)
# Reported discrepancy: the two outputs no longer agree.
print(torch.abs(out_e3nn - out_oeq).max().item())
# -> 0.9126971960067749

Thanks in advance for your help!

Metadata

Metadata

Assignees

No one assigned

    Labels

    enhancement: New feature or request

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions