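# Tensor contractions

This notebook evaluates the contraction $D_{ilm} = \sum_{j,k} A_{ijk} B_{jl} C_{km}$ in two ways — by hand with reshapes and matrix products, and with `np.einsum` — and then compares how the contraction order affects run time.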

In [None]:
import numpy as np

# Seed the global RNG so A, B and C are identical on every
# Restart & Run All; the operands are otherwise different on each run.
np.random.seed(0)

# Index dimensions: one extent per index label used in the contractions
Di = 30
Dj = 50
Dk = 20
Dl = 50
Dm = 20

# Random operands with entries drawn uniformly from [0, 1).
# The trailing comment on each line names the axes in order.
A = np.random.rand(Di, Dj, Dk) # A_ijk
B = np.random.rand(Dj, Dl) # B_jl
C = np.random.rand(Dk, Dm) # C_km

In [49]:
# Evaluate D_ilm = sum_jk A_ijk B_jl C_km using only matrix products.
# Naming convention: `_` separates tensor axes; index letters written
# together (e.g. `ij`) are fused into a single axis.

# Step 1 - contract k: fuse (i, j) into the row axis so A is a matrix.
A_ij_k = A.reshape(Di * Dj, Dk)
AC_ij_m = np.matmul(A_ij_k, C)

# Step 2 - move j to the last axis so a matmul can contract it.
AC_i_j_m = AC_ij_m.reshape(Di, Dj, Dm)
AC_i_m_j = AC_i_j_m.transpose(0, 2, 1)
AC_im_j = AC_i_m_j.reshape(Di * Dm, Dj)

# Step 3 - contract j with B.
ACB_im_l = np.matmul(AC_im_j, B)

# Step 4 - unfuse (i, m) and swap the last two axes to get the order i, l, m.
ACB_i_m_l = ACB_im_l.reshape(Di, Dm, Dl)
ACB_i_l_m = ACB_i_m_l.transpose(0, 2, 1)

## Einsum

In [31]:
# Same contraction in a single call: j and k appear in the inputs but not
# after `->`, so they are summed over; the output axes are ordered i, l, m.
# optimize=False is einsum's default, written out to make clear that no
# contraction-order optimization is applied here.
D = np.einsum('ijk,jl,km->ilm', A, B, C, optimize=False) # D_ilm

In [54]:
# Cross-check: the einsum result and the reshape/matmul result must agree
# element-wise up to floating-point tolerance.
np.allclose(D, ACB_i_l_m)

True

In [32]:
# Axes follow the output indices i, l, m, so the expected shape is (Di, Dl, Dm).
D.shape

(30, 50, 20)

## Contraction order

In [33]:
# Ask NumPy for a contraction order. 'greedy' is einsum_path's default
# strategy; it is spelled out because it is a heuristic, not an exhaustive
# search (that would be optimize='optimal').
path_info = np.einsum_path('ijk,jl,km->ilm', A, B, C, optimize='greedy')
# Element 0 is the path list, in the form accepted by einsum's `optimize=`.
print(path_info[0])

['einsum_path', (0, 1), (0, 1)]


In [34]:
# Element 1 is the human-readable report: scaling, FLOP estimates and the
# pairwise contraction performed at each step.
print(path_info[1])

  Complete contraction:  ijk,jl,km->ilm
         Naive scaling:  5
     Optimized scaling:  4
      Naive FLOP count:  9.000e+07
  Optimized FLOP count:  4.200e+06
   Theoretical speedup:  21.429
  Largest intermediate:  3.000e+04 elements
--------------------------------------------------------------------------
scaling                  current                                remaining
--------------------------------------------------------------------------
   4                 jl,ijk->kil                              km,kil->ilm
   4                 kil,km->ilm                                 ilm->ilm


In [41]:
%%timeit
# Let einsum choose the contraction order itself.
D = np.einsum('ijk,jl,km->ilm', A, B, C, optimize=True) # D_ilm

378 µs ± 14.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [42]:
%%timeit
# Baseline: no path optimization. The einsum_path report lists this naive
# evaluation at scaling 5.
D = np.einsum('ijk,jl,km->ilm', A, B, C, optimize=False) # D_ilm

27.6 ms ± 446 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [47]:
%%timeit
# Explicit path: contract operands 1 and 2 first (B_jl and C_km, which share
# no index), then contract that intermediate with A.
D = np.einsum('ijk,jl,km->ilm', A, B, C, optimize=['einsum_path', (1, 2), (0, 1)]) # D_ilm

2.54 ms ± 166 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [44]:
%%timeit
# Explicit path: the one printed by np.einsum_path earlier - A with B first
# (sum over j), then the intermediate with C (sum over k).
D = np.einsum('ijk,jl,km->ilm', A, B, C, optimize=['einsum_path', (0, 1), (0, 1)]) # D_ilm

326 µs ± 33.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [45]:
%%timeit
# Explicit path: A with C first (sum over k), then the intermediate with B
# (sum over j) - the same order as the reshape/matmul cell.
D = np.einsum('ijk,jl,km->ilm', A, B, C, optimize=['einsum_path', (0, 2), (0, 1)]) # D_ilm

132 µs ± 6.71 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
