In [1]:
import numpy as np

In [2]:
A = np.random.randn(5,5)

# Regular NumPy
trace_np = np.trace(A)

# Einsum: repeating the index `i` on one operand selects the diagonal, and the
# empty output spec (`->`) sums over it, i.e. trace(A) = sum_i A[i, i].
trace_einsum = np.einsum('ii->', A)

# Verify: both routes sum the same diagonal entries, so the difference should be
# ~0. Exact equality is not guaranteed (summation order may differ), so the
# check uses an explicit tolerance and fails loudly instead of just printing.
norm_diff = abs(trace_np - trace_einsum)
np.testing.assert_allclose(trace_einsum, trace_np, rtol=1e-12, atol=1e-12)

print(trace_np)
print(trace_einsum)
print(norm_diff)

1.0361742880385103
1.0361742880385103
0.0


In [3]:
B = np.random.randn(5,5)

# Regular NumPy: `@` is the matrix-multiplication operator (same as A.dot(B)
# for 2-D arrays). `A` is the 5x5 matrix created in the trace cell above.
prod_np = A @ B

# Einsum: C[i, k] = sum_j A[i, j] * B[j, k]; `j` is summed because it does not
# appear in the output spec `ik`.
prod_einsum = np.einsum('ij,jk->ik', A, B)

# Verify: the two routes may accumulate in a different order, giving tiny
# rounding differences (~1e-15), so compare with a tolerance, not exact equality.
norm_diff = np.linalg.norm(prod_np - prod_einsum)
np.testing.assert_allclose(prod_einsum, prod_np, rtol=1e-12, atol=1e-12)

print(prod_np)
print(prod_einsum)
print(norm_diff)

[[-3.98582064 -0.94987931 -0.84522122 -0.58280709  0.48655849]
 [ 0.56713562 -1.32069926  1.91545728 -1.81560285  0.29872532]
 [-1.71476543 -1.90220469 -1.39670824 -0.45867032 -0.91910422]
 [ 0.75016497  1.80379298  0.27329141  0.31357598  2.08360387]
 [ 1.72551868 -2.03131784  1.97057922 -0.68903473 -0.55410569]]
[[-3.98582064 -0.94987931 -0.84522122 -0.58280709  0.48655849]
 [ 0.56713562 -1.32069926  1.91545728 -1.81560285  0.29872532]
 [-1.71476543 -1.90220469 -1.39670824 -0.45867032 -0.91910422]
 [ 0.75016497  1.80379298  0.27329141  0.31357598  2.08360387]
 [ 1.72551868 -2.03131784  1.97057922 -0.68903473 -0.55410569]]
4.2276033262255756e-16


In [4]:
batch1 = np.random.randn(3,4,5)
batch2 = np.random.randn(3,5,6)

# Regular NumPy: matmul treats the leading axis as a batch dimension and
# multiplies each (4, 5) @ (5, 6) pair, giving shape (3, 4, 6).
# (`@` is equivalent to np.matmul.)
batch_np = batch1 @ batch2

# Einsum: `b` appears in the output spec, so it is carried through as a batch
# index; only `j` is summed.
batch_einsum = np.einsum('bij,bjk->bik', batch1, batch2)

# Verify: expect only floating-point rounding differences (~1e-15), so compare
# with a tolerance and fail loudly on a real mismatch.
norm_diff = np.linalg.norm(batch_np - batch_einsum)
assert batch_np.shape == batch_einsum.shape == (3, 4, 6)
np.testing.assert_allclose(batch_einsum, batch_np, rtol=1e-12, atol=1e-12)

print(batch_np)
print(batch_einsum)
print(norm_diff)

[[[-0.56895156 -0.10516422  0.32718977  0.91717256  0.77202744
    1.1536184 ]
  [ 0.94679558 -1.83349494  1.35862612 -0.26167095 -2.21943071
    0.20437275]
  [-0.68599927 -2.38114899  1.55575597  0.34635143 -2.73968047
   -1.32964612]
  [-0.0579348   1.5116446  -0.77325283 -1.35146001 -1.47826642
   -1.19820021]]

 [[ 4.9122211  -4.00827769 -4.03581326  0.56210129  1.39107106
    3.7434432 ]
  [-2.72244196 -1.64699612  0.56072511 -2.32055152 -0.24988922
   -2.6179719 ]
  [-1.34554518  2.17034127 -0.03425734  1.32582759  2.95885236
    0.93024537]
  [-1.2616288  -0.39783015  0.66623274 -1.72565935 -0.44679932
   -1.9858827 ]]

 [[ 0.70398556  0.35965482 -3.07520221  0.43791175  1.27952663
   -0.30026536]
  [ 0.79929243 -0.83137013 -0.78932876 -0.62716008 -0.05026737
    0.06927686]
  [ 1.38133111  3.0886856  -4.61148358  2.4399062   2.17153111
    0.82026129]
  [-0.47047089 -0.35172397 -1.82165062 -3.57443889 -1.15742414
    0.03692628]]]
[[[-0.56895156 -0.10516422  0.32718977  0.9171