For example, consider:

```python
n = 10
with we.namespace('init'):
    A = we.random.randn(n, n)
    B = we.random.randn(n, n)
with we.namespace('einsum'):
    X = we.einsum('ik,jl->ij', A, B)  # 10_000 FLOPs (n**4).
```
By contrast, torch needs only ~300 FLOPs, because it effectively lowers the einsum to:
```python
X = we.einsum('i,j->ij', A.sum(axis=1), B.sum(axis=1))
```
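As a rough tally for n = 10, counting one FLOP per element touched (a simplification: a row sum of length n is strictly n − 1 additions), the lowered form costs about 3n² versus n⁴ for the direct contraction:

$$
\underbrace{n^2}_{\texttt{A.sum(axis=1)}} + \underbrace{n^2}_{\texttt{B.sum(axis=1)}} + \underbrace{n^2}_{\text{outer product}} = 3n^2 = 300
\qquad\text{vs.}\qquad n^4 = 10{,}000
$$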
To account for this, einsum should run a preprocessing step that finds every index label that is summed over (absent from the output) and appears in only a single input tensor. Duplicate arguments count as distinct tensors for this purpose: in `einsum('ik,jk->ij', A, A)`, `k` appears in two operands and is not eligible. The FLOP model should remove those indices at a cost of numel(tensor), corresponding to summing that tensor over those indices on its own. FLOP counting then proceeds as usual on the reduced einsum.
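A minimal sketch of the proposed preprocessing, assuming explicit `->` subscripts without ellipsis. All function names here (`parse_subscripts`, `naive_flops`, `einsum_flops_with_presum`) are hypothetical, and `naive_flops` is only a placeholder for "FLOP counting as usual" (one FLOP per point of the joint index space), not the library's actual model. It also assumes a label repeated inside one operand (e.g. `'ii'`) should be left alone, since summing a diagonal is not a plain axis sum.

```python
import math


def parse_subscripts(subscripts):
    """Split an explicit-output spec like 'ik,jl->ij' into (['ik', 'jl'], 'ij').

    Assumption: explicit '->' form only; ellipsis ('...') is not handled.
    """
    lhs, output = subscripts.replace(" ", "").split("->")
    return lhs.split(","), output


def naive_flops(inputs, output, sizes):
    """Hypothetical stand-in for the existing FLOP model: one FLOP per point
    of the joint index space (product of the sizes of all labels)."""
    labels = set(output).union(*inputs)
    return math.prod(sizes[c] for c in labels)


def einsum_flops_with_presum(subscripts, *shapes):
    """Charge numel(operand) to pre-sum labels that are summed over and appear
    in exactly one operand, then run naive_flops on the reduced einsum."""
    inputs, output = parse_subscripts(subscripts)

    # Map each label to its dimension size.
    sizes = {}
    for sub, shape in zip(inputs, shapes):
        for label, dim in zip(sub, shape):
            sizes[label] = dim

    # Number of operands each label appears in. Operands are counted by
    # position, so passing the same tensor twice counts as two operands.
    operand_count = {}
    for sub in inputs:
        for label in set(sub):
            operand_count[label] = operand_count.get(label, 0) + 1

    presum_cost = 0
    reduced = []
    for sub, shape in zip(inputs, shapes):
        # Eligible: not in the output, in only this operand, and not repeated
        # within it (a repeated label like 'ii' is a diagonal, not an axis sum).
        lone = {c for c in sub
                if c not in output and operand_count[c] == 1 and sub.count(c) == 1}
        if lone:
            presum_cost += math.prod(shape)  # one pass over the operand
            sub = "".join(c for c in sub if c not in lone)
        reduced.append(sub)

    return presum_cost + naive_flops(reduced, output, sizes)


# 'ik,jl->ij' with 10x10 operands: 100 + 100 to pre-sum, then 100 for 'i,j->ij'.
print(einsum_flops_with_presum("ik,jl->ij", (10, 10), (10, 10)))  # 300
```

With the same shapes, `einsum_flops_with_presum("ik,jk->ij", (10, 10), (10, 10))` leaves `k` alone (it appears in two operands) and returns 1000. One quirk of the placeholder baseline: a single-operand reduction such as `'ij->i'` comes out as 100 + 10 = 110, because the leftover `'i->i'` is still charged; a real implementation would likely treat that leftover as free.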