Given two square matrices $A, B$ of the same size $n \times n$,

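the trace of their product can be computed without forming $AB$:

$$\operatorname{tr}(AB) = \sum_{i,j} A_{ij} B_{ji} = \sum_{i,j} \left(A \circ B^\top\right)_{ij}$$

(note the transpose $B^\top$; the identity also holds for non-square $A \in \mathbb{R}^{m \times n}$, $B \in \mathbb{R}^{n \times m}$)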

where $\circ$ denotes the element-wise (Hadamard) product.

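Computing $AB$ first and then taking its diagonal is slower. The product costs $O(n^3)$ operations and fills in all $n^2 - n$ off-diagonal entries, none of which are needed for $\operatorname{tr}(AB)$. The element-wise form needs only $O(n^2)$ operations.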

In [8]:
import numpy as np
import torch
import time

In [9]:
def NpTr2S(A, B):
    """Return tr(A @ B) without forming the matrix product (NumPy).

    Uses the identity tr(AB) = sum_{i,j} A[i, j] * B[j, i] = sum(A * B.T),
    which needs O(m*n) multiplications instead of the O(m*n*m) of A @ B.

    Parameters
    ----------
    A : np.ndarray of shape (m, n)
    B : np.ndarray of shape (n, m)

    Returns
    -------
    NumPy scalar holding the trace of A @ B.

    Raises
    ------
    ValueError
        If A and B are not 2-D arrays with A.shape == B.T.shape. Without
        this check NumPy broadcasting would silently return a meaningless
        number, e.g. for A of shape (1, n) and B of shape (n, n).
    """
    if A.ndim != 2 or B.ndim != 2 or A.shape[0] != B.shape[1] or A.shape[1] != B.shape[0]:
        raise ValueError(
            f"NpTr2S expects A of shape (m, n) and B of shape (n, m); "
            f"got {tuple(A.shape)} and {tuple(B.shape)}"
        )
    return (A * B.T).sum()

def TorchTr2S(A, B):
    """Return tr(A @ B) without forming the matrix product (PyTorch).

    Same identity as the NumPy variant: tr(AB) = sum(A * B^T), computed
    with O(m*n) multiplications instead of the O(m*n*m) of A.mm(B).

    Parameters
    ----------
    A : torch.Tensor of shape (m, n)
    B : torch.Tensor of shape (n, m)

    Returns
    -------
    0-dim torch.Tensor holding the trace of A @ B.

    Raises
    ------
    ValueError
        If A and B are not 2-D tensors with A.shape == B.t().shape. Without
        this check broadcasting would silently return a meaningless value,
        e.g. for A of shape (1, n) and B of shape (n, n).
    """
    # Check dims before calling B.t(), which itself rejects tensors with > 2 dims.
    if A.dim() != 2 or B.dim() != 2 or A.shape[0] != B.shape[1] or A.shape[1] != B.shape[0]:
        raise ValueError(
            f"TorchTr2S expects A of shape (m, n) and B of shape (n, m); "
            f"got {tuple(A.shape)} and {tuple(B.shape)}"
        )
    return (A * B.t()).sum()

In [10]:
# Benchmark tr(AB) in float32: diag of the full product vs. the element-wise identity.
# Seeded so that Restart & Run All reproduces the same matrices and traces.
rng = np.random.default_rng(0)
A = rng.normal(0, 1, (1000, 1000)).astype(np.float32)
B = rng.normal(0, 1, (1000, 1000)).astype(np.float32)

# Time only the computation; print() is kept outside the timed region so
# console I/O does not inflate the measurement. perf_counter is the
# high-resolution clock meant for measuring short durations.
start = time.perf_counter()
tr_full = np.diag(np.dot(A, B)).sum()
t_full = time.perf_counter() - start

start = time.perf_counter()
tr_fast = NpTr2S(A, B)
t_fast = time.perf_counter() - start

print(tr_full, t_full, sep=", ")
print(tr_fast, t_fast, sep=", ")

# Both methods must agree up to float32 rounding (summation order differs).
assert np.allclose(tr_full, tr_fast, rtol=1e-4, atol=1e-1)

-986.8219, 0.008379936218261719
-986.82184, 0.002413034439086914


In [11]:
# Same comparison in float64, cast from the matrices of the previous cell so
# both precisions operate on identical values.
A = A.astype(np.float64)
B = B.astype(np.float64)

# Time only the computation; printing happens after both measurements.
start = time.perf_counter()
tr_full = np.diag(np.dot(A, B)).sum()
t_full = time.perf_counter() - start

start = time.perf_counter()
tr_fast = NpTr2S(A, B)
t_fast = time.perf_counter() - start

print(tr_full, t_full, sep=", ")
print(tr_fast, t_fast, sep=", ")

# In float64 the two results should agree to near machine precision.
assert np.isclose(tr_full, tr_fast)

-986.821826018991, 0.02141118049621582
-986.8218260189899, 0.005032777786254883


In [12]:
# Same comparison in PyTorch (float32). torch.from_numpy wraps the NumPy array
# directly; PyTorch recommends it (or torch.tensor) over the legacy
# torch.Tensor(...) constructor. astype() makes a fresh array, so no memory
# is shared with the NumPy matrices above.
A = torch.from_numpy(A.astype(np.float32))
B = torch.from_numpy(B.astype(np.float32))

# Time only the computation; printing happens after both measurements.
start = time.perf_counter()
tr_full = torch.diag(A.mm(B)).sum()
t_full = time.perf_counter() - start

start = time.perf_counter()
tr_fast = TorchTr2S(A, B)
t_fast = time.perf_counter() - start

print(tr_full, t_full, sep=", ")
print(tr_fast, t_fast, sep=", ")

# Both methods must agree up to float32 rounding (summation order differs).
assert torch.allclose(tr_full, tr_fast, rtol=1e-4, atol=1e-1)

tensor(-986.8220), 0.013262033462524414
tensor(-986.8221), 0.0034589767456054688


In [13]:
# Same comparison in PyTorch float64. .double() is the idiomatic cast and
# keeps the tensor on its current device (unlike .type(torch.DoubleTensor),
# which always targets the CPU tensor type).
A = A.double()
B = B.double()

# Time only the computation; printing happens after both measurements.
start = time.perf_counter()
tr_full = torch.diag(A.mm(B)).sum()
t_full = time.perf_counter() - start

start = time.perf_counter()
tr_fast = TorchTr2S(A, B)
t_fast = time.perf_counter() - start

print(tr_full, t_full, sep=", ")
print(tr_fast, t_fast, sep=", ")

# In float64 the two results should agree to near machine precision.
assert torch.allclose(tr_full, tr_fast)

tensor(-986.8218, dtype=torch.float64), 0.026175975799560547
tensor(-986.8218, dtype=torch.float64), 0.006474971771240234
