
Deal with within-tensor summed indices #55

@wilswu99


For example, consider

n = 10
with we.namespace('init'):
    A = we.random.randn(n, n)
    B = we.random.randn(n, n)
with we.namespace('einsum'):
    X = we.einsum('ik,jl->ij', A, B)          # Currently counted as n**4 = 10_000 flops.

On the other hand, torch needs only ~300 flops (about n^2 for each of the two row sums, plus n^2 for the outer product), since the einsum is effectively lowered to

X = we.einsum('i,j->ij', A.sum(axis=1), B.sum(axis=1))
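As a sanity check on the two numbers above, here is a minimal pure-Python sketch that evaluates both forms and tallies their cost. It does not use `we` or torch, and it assumes the counting convention implied by the figures in this issue (one flop per element visited); the variable names are illustrative only.

```python
import random

random.seed(0)
n = 10
A = [[random.gauss(0, 1) for _ in range(n)] for _ in range(n)]
B = [[random.gauss(0, 1) for _ in range(n)] for _ in range(n)]

# Naive evaluation of 'ik,jl->ij': one multiply-accumulate per (i, j, k, l).
naive = [[0.0] * n for _ in range(n)]
naive_flops = 0
for i in range(n):
    for j in range(n):
        for k in range(n):
            for l in range(n):
                naive[i][j] += A[i][k] * B[j][l]
                naive_flops += 1

# Lowered evaluation: sum each input over its private index first,
# then take the outer product of the two resulting vectors.
a = [sum(row) for row in A]   # A.sum(axis=1), visits n*n elements
b = [sum(row) for row in B]   # B.sum(axis=1), visits n*n elements
lowered = [[a[i] * b[j] for j in range(n)] for i in range(n)]
lowered_flops = n * n + n * n + n * n

# Largest elementwise difference between the two results.
max_err = max(
    abs(naive[i][j] - lowered[i][j]) for i in range(n) for j in range(n)
)
print(naive_flops, lowered_flops, max_err)
```

The two results agree up to floating-point rounding, while the tallies come out at n^4 for the naive form and 3 n^2 for the lowered one.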

To account for this, einsum should run a preprocessing step that looks for index labels that are summed over (absent from the output) and appear in only a single input tensor, counting duplicate arguments as distinct tensors for this purpose. The FLOP model should remove those indices at the cost of numel(tensor), corresponding to summing that tensor over the index on its own. FLOP counting should then proceed as usual on the reduced expression.
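The preprocessing step described above could look roughly like the following sketch. The function name `presum_single_tensor_indices` and its return shape are hypothetical, not part of the existing code. It assumes an explicit `->` in the subscripts, no ellipsis, and no label repeated within one operand. It also assumes that all private indices of one tensor are summed in a single pass, so each affected tensor is charged numel(tensor) once.

```python
from math import prod


def presum_single_tensor_indices(subscripts, shapes):
    """Hypothetical helper: strip summed indices private to one operand.

    Returns (reduced_subscripts, reduced_shapes, presum_flops).
    """
    inputs, output = subscripts.split('->')
    terms = inputs.split(',')

    # Count how many operands mention each label. Duplicate arguments are
    # separate entries in `terms`, so they count as distinct tensors.
    counts = {}
    for term in terms:
        for label in set(term):
            counts[label] = counts.get(label, 0) + 1

    new_terms, new_shapes, flops = [], [], 0
    for term, shape in zip(terms, shapes):
        # Labels that are summed over and appear in this operand only.
        drop = {c for c in term if c not in output and counts[c] == 1}
        if drop:
            # Assumption: one pass over the tensor removes all of them.
            flops += prod(shape)
        kept = [(c, d) for c, d in zip(term, shape) if c not in drop]
        new_terms.append(''.join(c for c, _ in kept))
        new_shapes.append(tuple(d for _, d in kept))

    return ','.join(new_terms) + '->' + output, new_shapes, flops


reduced, reduced_shapes, presum_flops = presum_single_tensor_indices(
    'ik,jl->ij', [(10, 10), (10, 10)]
)
print(reduced, reduced_shapes, presum_flops)
```

For the example in this issue, the sketch rewrites the expression to `i,j->ij` at a pre-summation cost of 200; the remaining outer product then accounts for the other 100 under the usual counting. A shared contraction index such as `j` in `ij,jk->ik` is left untouched.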

Metadata


    Labels

    enhancement (New feature or request)
    priority:p2 (Nice-to-have, scheduled)
    topic:flop-accounting (FLOP counting, budget deduction, cost models, and accounting policy)
