In [None]:
import numpy as np
import time

from opt_einsum import contract
import opt_einsum as oe

# Set the number of BLAS threads.
# Each backend below is optional: when its Python module cannot be imported,
# the ImportError is ignored and that backend's thread count is left untouched.
# Presumably the goal is single-threaded, comparable timings in the cells
# that follow -- TODO confirm.

try:
    import mkl
    mkl.set_num_threads(1)  # restrict MKL to one thread
except ImportError:
    pass
try:
    # NOTE(review): `blas` does not look like a commonly installed package;
    # on most setups this branch is likely a silent no-op -- verify which
    # backend it is meant to target.
    import blas
    blas.set_num_threads(1)
except ImportError:
    pass



In [None]:
dim = 10

# Seeded generator so that a fresh "Restart & Run All" reproduces exactly the
# same I and C (the original drew from the unseeded global RNG).
rng = np.random.default_rng(0)
I = rng.random((dim, dim, dim, dim))  # rank-4 tensor, uniform on [0, 1)
C = rng.random((dim, dim))            # square matrix applied to each axis of I

def naive(I, C):
    """Apply ``C`` to every axis of the rank-4 array ``I`` in a single einsum.

    Evaluates
    ``out[p, q, r, s] = sum_{i,j,k,l} C[p, i] C[q, j] I[i, j, k, l] C[r, k] C[s, l]``
    as one five-operand ``np.einsum`` call with no ``optimize`` argument.

    Parameters
    ----------
    I : array indexed as ``I[i, j, k, l]``.
    C : matrix indexed as ``C[p, i]``; used for all four axes.

    Returns
    -------
    The transformed array indexed as ``out[p, q, r, s]``.
    """
    # All eight indices are handled by one call: N^8 scaling.
    subscripts = 'pi,qj,ijkl,rk,sl->pqrs'
    operands = (C, C, I, C, C)
    return np.einsum(subscripts, *operands)

def optimized(I, C):
    """Apply ``C`` to every axis of ``I``, one axis per einsum call.

    Produces the same ``out[p, q, r, s]`` as the single five-operand
    contraction ``'pi,qj,ijkl,rk,sl->pqrs'``, but as four successive
    matrix-times-tensor steps, each contracting one index of the running
    intermediate with ``C``.

    Parameters
    ----------
    I : array indexed as ``I[i, j, k, l]``.
    C : matrix indexed as ``C[p, i]``; used for all four axes.

    Returns
    -------
    The transformed array indexed as ``out[p, q, r, s]``.
    """
    # Each step sums over a single index: N^5 scaling.
    steps = (
        'pi,ijkl->pjkl',
        'qj,pjkl->pqkl',
        'rk,pqkl->pqrl',
        'sl,pqrl->pqrs',
    )
    partial = I
    for spec in steps:
        partial = np.einsum(spec, C, partial)
    return partial

In [None]:
# Sanity check: the single-call and the stepwise contraction should agree to
# within floating-point tolerance on the current I and C (displays True/False).
np.allclose(naive(I, C), optimized(I, C))

In [None]:
# Time both implementations on the I and C currently in scope
# (dim = 10 when the cells above were run in order).
%timeit naive(I, C)
%timeit optimized(I, C)

In [None]:
# Larger problem size for the remaining benchmarks; rebinds dim, I and C.
dim = 30
# Seeded generator so that re-running the notebook reproduces the same inputs
# (the original drew from the unseeded global RNG).
rng = np.random.default_rng(0)
I = rng.random((dim, dim, dim, dim))  # rank-4 tensor, uniform on [0, 1)
C = rng.random((dim, dim))            # square matrix applied to each axis of I

In [None]:
# Time only the stepwise version on the current I and C (dim = 30 when the
# setup cell just above was run). naive() is not timed at this size,
# presumably because its N^8 cost is impractical -- confirm.
%timeit optimized(I, C)

In [None]:
# Same five-operand subscripts as naive(), evaluated through opt_einsum's
# contract, which selects a pairwise contraction order itself.
%timeit contract('pi,qj,ijkl,rk,sl->pqrs', C, C, I, C, C)

In [None]:
# Ask opt_einsum which contraction order it would use for this expression.
# contract_path returns a pair; element 1 is the printable report describing
# the chosen path (element 0 is the path itself).
path_info = oe.contract_path('pi,qj,ijkl,rk,sl->pqrs', C, C, I, C, C)
print(path_info[1])

In [None]:
# Same contraction with NumPy alone: optimize=True lets np.einsum pick a
# contraction order instead of evaluating all indices in one pass.
%timeit np.einsum('pi,qj,ijkl,rk,sl->pqrs', C, C, I, C, C, optimize=True)

In [None]:
# No index is summed in 'pj,qj->qpj' (j is kept in the output), so the result
# is out[q, p, j] = C[p, j] * C[q, j]. Timed here with np.einsum, optimize=True.
%timeit np.einsum('pj,qj->qpj', C, C, optimize=True)

In [None]:
# The same sum-free expression, out[q, p, j] = C[p, j] * C[q, j], evaluated
# through opt_einsum's contract for comparison with np.einsum.
%timeit contract('pj,qj->qpj', C, C)