In [1]:
import numpy as np
import pytblis
import torch
import time

In [2]:
# Extent of each of the four axes of the benchmark tensor.
nao = 96
# Deterministic test data: cos(k + 0.2) for k = 0 .. nao**4 - 1, viewed as a
# rank-4 array with every axis of length nao.
e = np.cos(0.2 + np.arange(nao**4)).reshape((nao,) * 4)
# The same data as a torch tensor, used by the PyTorch benchmark.
e_torch = torch.asarray(e)

In [3]:
def fp(arr):
    """Return a scalar fingerprint of a NumPy array.

    The array is flattened and dotted with the weights
    cos(0), cos(1), ..., cos(arr.size - 1), so two results can be compared
    through a single number.
    """
    weights = np.cos(np.arange(arr.size))
    flat = arr.reshape(-1)
    return weights @ flat

In [4]:
def fp_torch(arr):
    """Return a scalar fingerprint of a torch tensor.

    The tensor is flattened and dotted with the weights
    cos(0), cos(1), ..., cos(arr.numel() - 1), mirroring the NumPy ``fp``
    helper.

    Args:
        arr: tensor of any shape. Non-float64 inputs are promoted so the dot
            product is carried out in at least double precision.

    Returns:
        A 0-dimensional tensor holding the fingerprint.
    """
    # Common dtype for both operands: float64 for real/integer input and the
    # promoted complex dtype for complex input. For float64 input this is a no-op.
    work_dtype = torch.promote_types(arr.dtype, torch.double)
    # Build the weights on the same device as the input so the dot product
    # does not fail with a device mismatch.
    weights = torch.cos(torch.arange(arr.numel(), dtype=torch.double, device=arr.device))
    return weights.to(work_dtype) @ arr.reshape(-1).to(work_dtype)

In [5]:
# Einsum subscripts to benchmark. The note above each entry gives the kind of
# contraction and its annotated cost in terms of the axis length n.
subscripts_list = [
    # naive gemm case, 2 * n^6
    "abxy, xycd -> abcd",
    # naive gemm case, 2 * n^5
    "axyz, xyzb -> ab",
    # naive syrk case, n^5
    "axyz, bxyz -> ab",
    # comp gemm case, 2 * n^5
    "axyz, ybzx -> ab",
    # batch gemm case, 2 * n^5
    "axby, yacx -> abc",
    # complicate case, 2 * n^4
    "xpay, aybx -> ab",
]

In [6]:
print("NumPy einsum")
# Number of timed repetitions for each entry of subscripts_list, in order.
repeat_list = [5, 20, 20, 20, 1, 1]
for subscripts, nrepeat in zip(subscripts_list, repeat_list):
    print(f"Subscripts: {subscripts}")
    # perf_counter() is a monotonic, high-resolution clock intended for
    # measuring intervals; time.time() can jump if the system clock is adjusted.
    t = time.perf_counter()
    for _ in range(nrepeat):
        v = np.einsum(subscripts, e, e, optimize=True)
    elapsed = (time.perf_counter() - t) / nrepeat
    print(f"elapsed time: {elapsed:12.6f} sec (avg of {nrepeat:2d} repeats)")
    # Fingerprint of the last result, for comparison between backends.
    print(f"fingerprint : {fp(v):20.12f}")

NumPy einsum
Subscripts: abxy, xycd -> abcd
elapsed time:     2.132740 sec (avg of  5 repeats)
fingerprint : -19471467.265266474336
Subscripts: axyz, xyzb -> ab
elapsed time:     0.063124 sec (avg of 20 repeats)
fingerprint :      48.288443230390
Subscripts: axyz, bxyz -> ab
elapsed time:     0.293240 sec (avg of 20 repeats)
fingerprint : -217920.505845849111
Subscripts: axyz, ybzx -> ab
elapsed time:     0.207656 sec (avg of 20 repeats)
fingerprint :       2.131216642236
Subscripts: axby, yacx -> abc
elapsed time:    29.650491 sec (avg of  1 repeats)
fingerprint :    -134.741201125226
Subscripts: xpay, aybx -> ab
elapsed time:    33.931122 sec (avg of  1 repeats)
fingerprint :       4.640285999007


In [7]:
print("PyTBLIS einsum")
# Number of timed repetitions for each entry of subscripts_list, in order.
repeat_list = [5, 20, 20, 20, 1, 1]
for subscripts, nrepeat in zip(subscripts_list, repeat_list):
    print(f"Subscripts: {subscripts}")
    # perf_counter() is a monotonic, high-resolution clock intended for
    # measuring intervals; time.time() can jump if the system clock is adjusted.
    t = time.perf_counter()
    for _ in range(nrepeat):
        v = pytblis.einsum(subscripts, e, e, optimize="greedy")
    elapsed = (time.perf_counter() - t) / nrepeat
    print(f"elapsed time: {elapsed:12.6f} sec (avg of {nrepeat:2d} repeats)")
    # Fingerprint of the last result, for comparison between backends.
    print(f"fingerprint : {fp(v):20.12f}")

PyTBLIS einsum
Subscripts: abxy, xycd -> abcd
elapsed time:     1.958193 sec (avg of  5 repeats)
fingerprint : -19471467.265266474336
Subscripts: axyz, xyzb -> ab
elapsed time:     0.142870 sec (avg of 20 repeats)
fingerprint :      48.288443230387
Subscripts: axyz, bxyz -> ab
elapsed time:     0.116316 sec (avg of 20 repeats)
fingerprint : -217920.505846078857
Subscripts: axyz, ybzx -> ab
elapsed time:     0.142035 sec (avg of 20 repeats)
fingerprint :       2.131216642223
Subscripts: axby, yacx -> abc
elapsed time:    29.598574 sec (avg of  1 repeats)
fingerprint :    -134.741201125226
Subscripts: xpay, aybx -> ab
elapsed time:    33.830630 sec (avg of  1 repeats)
fingerprint :       4.640285999007


In [8]:
print("PyTorch einsum")
# Number of timed repetitions for each entry of subscripts_list, in order.
repeat_list = [5, 20, 20, 20, 20, 20]
for subscripts, nrepeat in zip(subscripts_list, repeat_list):
    print(f"Subscripts: {subscripts}")
    # perf_counter() is a monotonic, high-resolution clock intended for
    # measuring intervals; time.time() can jump if the system clock is adjusted.
    t = time.perf_counter()
    for _ in range(nrepeat):
        v = torch.einsum(subscripts, e_torch, e_torch)
    elapsed = (time.perf_counter() - t) / nrepeat
    print(f"elapsed time: {elapsed:12.6f} sec (avg of {nrepeat:2d} repeats)")
    # Fingerprint of the last result, for comparison between backends.
    print(f"fingerprint : {fp_torch(v):20.12f}")

PyTorch einsum
Subscripts: abxy, xycd -> abcd
elapsed time:     1.981685 sec (avg of  5 repeats)
fingerprint : -19471467.265266474336
Subscripts: axyz, xyzb -> ab
elapsed time:     0.063420 sec (avg of 20 repeats)
fingerprint :      48.288443230390
Subscripts: axyz, bxyz -> ab
elapsed time:     0.037417 sec (avg of 20 repeats)
fingerprint : -217920.505845895852
Subscripts: axyz, ybzx -> ab
elapsed time:     0.211103 sec (avg of 20 repeats)
fingerprint :       2.131216642236
Subscripts: axby, yacx -> abc
elapsed time:     0.179182 sec (avg of 20 repeats)
fingerprint :    -134.741201125241
Subscripts: xpay, aybx -> ab
elapsed time:     0.106856 sec (avg of 20 repeats)
fingerprint :       4.640285999005
