In [1]:
import numpy as np
import pytblis
import torch
import time

In [2]:
# Problem size: `e` is the leading nao^4 corner of an nfull^4 tensor.
nfull = 128
nao = 96

# e_full[i, j, k, l] = cos(flat_index + 0.2), filled in C order.
# Built in a single float64 buffer and updated in place, so peak memory is one
# nfull**4 array (~2 GiB) rather than an int64 temporary plus float64 temporaries.
# Values are identical to cos(np.arange(nfull**4) + 0.2): every index < 2**53 is
# exactly representable in float64.
e_full = np.arange(nfull**4, dtype=np.float64)
e_full += 0.2
np.cos(e_full, out=e_full)
e_full = e_full.reshape(nfull, nfull, nfull, nfull)

# Basic slicing returns a view, so `e` is non-contiguous (strides of the nfull^4
# parent) -- presumably intentional, to benchmark strided operands.
e = e_full[:nao, :nao, :nao, :nao]
# torch.asarray reuses the NumPy memory when possible instead of copying.
e_torch = torch.asarray(e)

In [3]:
def fp(arr):
    """Scalar fingerprint of a NumPy array.

    Flattens ``arr`` in C order and takes its dot product with the fixed
    weight vector cos(0), cos(1), ..., cos(arr.size - 1).
    """
    weights = np.cos(np.arange(arr.size))
    flat = arr.reshape(-1)
    return weights @ flat

In [4]:
def fp_torch(arr):
    """Scalar fingerprint of a torch tensor, matching ``fp`` for NumPy arrays.

    Flattens ``arr`` and takes its dot product with the weights
    cos(0), cos(1), ..., cos(arr.numel() - 1), computed in float64.

    Returns a 0-dim tensor. Its dtype is the promotion of float64 with
    ``arr.dtype``, so it is float64 for real inputs.
    """
    flat = arr.reshape(-1)
    # Build the weights on the same device as the input.
    weights = torch.cos(torch.arange(flat.numel(), dtype=torch.double, device=flat.device))
    # torch's `@` does not type-promote (double @ float32 raises), so promote
    # both operands explicitly, the way NumPy does implicitly in `fp`.
    common = torch.promote_types(weights.dtype, flat.dtype)
    return weights.to(common) @ flat.to(common)

In [5]:
# Contraction patterns to benchmark, each applied as einsum(subscripts, e, e).
# The comment above each entry gives its contraction type and approximate flop
# count, taking every index extent to be n.
subscripts_list = [
    # plain GEMM: contracted indices x, y already adjacent, 2 * n^6
    "abxy, xycd -> abcd",
    # plain GEMM: contracted indices x, y, z in matching order, 2 * n^5
    "axyz, xyzb -> ab",
    # SYRK-like (A @ A^T, symmetric result), n^5
    "axyz, bxyz -> ab",
    # GEMM after permuting the contracted indices of the second operand, 2 * n^5
    "axyz, ybzx -> ab",
    # batched GEMM over the shared index a, 2 * n^5
    "axby, yacx -> abc",
    # mixed case: shared index a, p summed in one operand only, x and y contracted, 2 * n^4
    "xpay, aybx -> ab",
]

In [6]:
# Time np.einsum on each pattern in subscripts_list and print a fingerprint of
# the result so the libraries can be cross-checked.
print("NumPy einsum")
# Repeats per pattern; the last two patterns are very slow with NumPy, so they run once.
repeat_list = [5, 20, 20, 20, 1, 1]
# zip() would silently drop patterns if the two lists ever got out of sync.
assert len(repeat_list) == len(subscripts_list), "need one repeat count per subscripts entry"
for subscripts, nrepeat in zip(subscripts_list, repeat_list):
    print(f"Subscripts: {subscripts}")
    # perf_counter is monotonic and high-resolution. time.time() follows the
    # adjustable wall clock, which makes it unsuitable for measuring intervals.
    # No warm-up run: first-call overhead is included in the average.
    t = time.perf_counter()
    for _ in range(nrepeat):
        # optimize=True lets NumPy choose a contraction path (greedy).
        v = np.einsum(subscripts, e, e, optimize=True)
    print(f"elapsed time: {(time.perf_counter() - t) / nrepeat:12.6f} sec (avg of {nrepeat:2d} repeats)")
    print(f"fingerprint : {fp(v):20.12f}")
    print(f"fingerprint : {fp(v):20.12f}")

NumPy einsum
Subscripts: abxy, xycd -> abcd
elapsed time:     7.306426 sec (avg of  5 repeats)
fingerprint :     -20.188290390819
Subscripts: axyz, xyzb -> ab
elapsed time:     0.776790 sec (avg of 20 repeats)
fingerprint :      20.343405116707
Subscripts: axyz, bxyz -> ab
elapsed time:     0.455486 sec (avg of 20 repeats)
fingerprint : -200211.721311474335
Subscripts: axyz, ybzx -> ab
elapsed time:     0.725011 sec (avg of 20 repeats)
fingerprint :       0.274707781823
Subscripts: axby, yacx -> abc
elapsed time:    27.076382 sec (avg of  1 repeats)
fingerprint :       0.466623082298
Subscripts: xpay, aybx -> ab
elapsed time:   248.522274 sec (avg of  1 repeats)
fingerprint :       0.134542958876


In [7]:
# Time pytblis.einsum on each pattern in subscripts_list and print a fingerprint
# of the result for comparison with the NumPy run.
print("PyTBLIS einsum")
# Same repeat counts as the NumPy cell, so the two timings are directly comparable.
repeat_list = [5, 20, 20, 20, 1, 1]
# zip() would silently drop patterns if the two lists ever got out of sync.
assert len(repeat_list) == len(subscripts_list), "need one repeat count per subscripts entry"
for subscripts, nrepeat in zip(subscripts_list, repeat_list):
    print(f"Subscripts: {subscripts}")
    # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
    # No warm-up run: first-call overhead is included in the average.
    t = time.perf_counter()
    for _ in range(nrepeat):
        # optimize="greedy" presumably mirrors NumPy's optimize=True path choice.
        v = pytblis.einsum(subscripts, e, e, optimize="greedy")
    print(f"elapsed time: {(time.perf_counter() - t) / nrepeat:12.6f} sec (avg of {nrepeat:2d} repeats)")
    print(f"fingerprint : {fp(v):20.12f}")
    print(f"fingerprint : {fp(v):20.12f}")

PyTBLIS einsum
Subscripts: abxy, xycd -> abcd
elapsed time:     1.951329 sec (avg of  5 repeats)
fingerprint :     -20.188290390824
Subscripts: axyz, xyzb -> ab
elapsed time:     0.141577 sec (avg of 20 repeats)
fingerprint :      20.343405116705
Subscripts: axyz, bxyz -> ab
elapsed time:     0.114764 sec (avg of 20 repeats)
fingerprint : -200211.721311338129
Subscripts: axyz, ybzx -> ab
elapsed time:     0.139257 sec (avg of 20 repeats)
fingerprint :       0.274707781825
Subscripts: axby, yacx -> abc
elapsed time:    27.224720 sec (avg of  1 repeats)
fingerprint :       0.466623082298
Subscripts: xpay, aybx -> ab
elapsed time:   249.175256 sec (avg of  1 repeats)
fingerprint :       0.134542958876


In [8]:
# Time torch.einsum on each pattern in subscripts_list (using the tensor that
# shares e's data) and print a fingerprint of the result for comparison.
print("PyTorch einsum")
# Repeats per pattern; torch is fast on all patterns, so every one runs 20x after the first.
repeat_list = [5, 20, 20, 20, 20, 20]
# zip() would silently drop patterns if the two lists ever got out of sync.
assert len(repeat_list) == len(subscripts_list), "need one repeat count per subscripts entry"
for subscripts, nrepeat in zip(subscripts_list, repeat_list):
    print(f"Subscripts: {subscripts}")
    # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
    # No warm-up run: first-call overhead is included in the average.
    t = time.perf_counter()
    for _ in range(nrepeat):
        v = torch.einsum(subscripts, e_torch, e_torch)
    print(f"elapsed time: {(time.perf_counter() - t) / nrepeat:12.6f} sec (avg of {nrepeat:2d} repeats)")
    print(f"fingerprint : {fp_torch(v):20.12f}")

PyTorch einsum
Subscripts: abxy, xycd -> abcd
elapsed time:     2.105942 sec (avg of  5 repeats)
fingerprint :     -20.188290390819
Subscripts: axyz, xyzb -> ab
elapsed time:     0.204239 sec (avg of 20 repeats)
fingerprint :      20.343405116707
Subscripts: axyz, bxyz -> ab
elapsed time:     0.211401 sec (avg of 20 repeats)
fingerprint : -200211.721311473870
Subscripts: axyz, ybzx -> ab
elapsed time:     0.406712 sec (avg of 20 repeats)
fingerprint :       0.274707781823
Subscripts: axby, yacx -> abc
elapsed time:     0.263642 sec (avg of 20 repeats)
fingerprint :       0.466623082310
Subscripts: xpay, aybx -> ab
elapsed time:     0.147302 sec (avg of 20 repeats)
fingerprint :       0.134542958873
