Given two square matrices $A, B$

trace of $AB$ $= tr(AB) = \sum_{i,j} A_{ij} B_{ji} =$ sum of all entries of $(A \circ B^T)$

where $\circ$ is the element-wise (Hadamard) product

Computing the full product $AB$ is slower because it also computes all the off-diagonal elements, which are unnecessary when finding $tr(AB)$.

In [1]:
import numpy as np
import torch
import time

In [2]:
def NpTr2S(A, B):
    """Sum of the element-wise product of ``A`` with the transpose of ``B``.

    For 2-D arrays where ``A.shape == B.T.shape`` this equals
    ``trace(A @ B)``, obtained without forming the full matrix product.

    Args:
        A: array-like object supporting ``*`` with ``B.T``.
        B: array-like object exposing a ``.T`` transpose attribute.

    Returns:
        The result of calling ``.sum()`` on ``A * B.T``.
    """
    # Entry (i, j) of the product is A[i, j] * B[j, i]; summing every entry
    # yields sum_i (A @ B)[i, i].
    hadamard = A * B.T
    total = hadamard.sum()
    return total

def TorchTr2S(A, B):
    """Sum of the element-wise product of ``A`` with the transpose of ``B``.

    For 2-D tensors where ``A.shape == B.t().shape`` this equals
    ``trace(A.mm(B))``, obtained without forming the full matrix product.

    Args:
        A: tensor supporting ``*`` with ``B.t()``.
        B: tensor exposing a ``.t()`` transpose method.

    Returns:
        The result of calling ``.sum()`` on ``A * B.t()``.
    """
    # Entry (i, j) of the product is A[i, j] * B[j, i]; summing every entry
    # yields sum_i (A.mm(B))[i, i].
    B_transposed = B.t()
    hadamard = A * B_transposed
    return hadamard.sum()

In [3]:
# Benchmark three ways of computing tr(AB) on float32 NumPy arrays.
A = np.random.normal(0, 1, (5000, 5000)).astype(np.float32)
B = np.random.normal(0, 1, (5000, 5000)).astype(np.float32)

# Timing notes: time.perf_counter() is the monotonic, high-resolution clock
# meant for measuring elapsed intervals (time.time() is wall-clock time and
# can be adjusted by the system). The clock is stopped before printing so
# that console I/O is not counted in the measurement.

# 1) Full matrix product, extract the diagonal, then sum it.
start = time.perf_counter()
result = np.diag(np.dot(A,B)).sum()
elapsed = time.perf_counter() - start
print(result, end=", ")
print(elapsed)

# 2) Full matrix product, then np.trace.
start = time.perf_counter()
result = np.trace(np.dot(A,B))
elapsed = time.perf_counter() - start
print(result, end=", ")
print(elapsed)

# 3) Element-wise product with the transpose, then sum (no matrix product).
start = time.perf_counter()
result = NpTr2S(A, B)
elapsed = time.perf_counter() - start
print(result, end=", ")
print(elapsed)

1576.7007, 1.2647738456726074
1576.7007, 1.2203540802001953
1576.7035, 0.33798670768737793


In [4]:
# Repeat the NumPy benchmark after converting both operands to float64.
A = A.astype(np.float64)
B = B.astype(np.float64)

# Timing notes: time.perf_counter() is the monotonic, high-resolution clock
# meant for measuring elapsed intervals (time.time() is wall-clock time and
# can be adjusted by the system). The clock is stopped before printing so
# that console I/O is not counted in the measurement.

# 1) Full matrix product, extract the diagonal, then sum it.
start = time.perf_counter()
result = np.diag(np.dot(A,B)).sum()
elapsed = time.perf_counter() - start
print(result, end=", ")
print(elapsed)

# 2) Full matrix product, then np.trace.
start = time.perf_counter()
result = np.trace(np.dot(A,B))
elapsed = time.perf_counter() - start
print(result, end=", ")
print(elapsed)

# 3) Element-wise product with the transpose, then sum (no matrix product).
start = time.perf_counter()
result = NpTr2S(A, B)
elapsed = time.perf_counter() - start
print(result, end=", ")
print(elapsed)

1576.7017899704574, 2.578411817550659
1576.7017899704574, 2.4070417881011963
1576.701789970456, 0.44674205780029297


In [5]:
# Repeat the benchmark with PyTorch tensors built from float32 copies of the
# NumPy operands.
A = torch.Tensor(A.astype(np.float32))
B = torch.Tensor(B.astype(np.float32))

# Timing notes: time.perf_counter() is the monotonic, high-resolution clock
# meant for measuring elapsed intervals (time.time() is wall-clock time and
# can be adjusted by the system). The clock is stopped before printing so
# that tensor formatting and console I/O are not counted in the measurement.

# 1) Full matrix product, extract the diagonal, then sum it.
start = time.perf_counter()
result = torch.diag(A.mm(B)).sum()
elapsed = time.perf_counter() - start
print(result, end=", ")
print(elapsed)

# 2) Full matrix product, then torch.trace.
start = time.perf_counter()
result = torch.trace(A.mm(B))
elapsed = time.perf_counter() - start
print(result, end=", ")
print(elapsed)

# 3) Element-wise product with the transpose, then sum (no matrix product).
start = time.perf_counter()
result = TorchTr2S(A, B)
elapsed = time.perf_counter() - start
print(result, end=", ")
print(elapsed)

tensor(1576.7025), 1.1117980480194092
tensor(1576.7015), 1.1355969905853271
tensor(1576.7007), 0.3173179626464844


In [6]:
# Repeat the PyTorch benchmark after converting both tensors to
# torch.DoubleTensor.
A = A.type(torch.DoubleTensor)
B = B.type(torch.DoubleTensor)

# Timing notes: time.perf_counter() is the monotonic, high-resolution clock
# meant for measuring elapsed intervals (time.time() is wall-clock time and
# can be adjusted by the system). The clock is stopped before printing so
# that tensor formatting and console I/O are not counted in the measurement.

# 1) Full matrix product, extract the diagonal, then sum it.
start = time.perf_counter()
result = torch.diag(A.mm(B)).sum()
elapsed = time.perf_counter() - start
print(result, end=", ")
print(elapsed)

# 2) Full matrix product, then torch.trace.
start = time.perf_counter()
result = torch.trace(A.mm(B))
elapsed = time.perf_counter() - start
print(result, end=", ")
print(elapsed)

# 3) Element-wise product with the transpose, then sum (no matrix product).
start = time.perf_counter()
result = TorchTr2S(A, B)
elapsed = time.perf_counter() - start
print(result, end=", ")
print(elapsed)

tensor(1576.7018, dtype=torch.float64), 2.519334077835083
tensor(1576.7018, dtype=torch.float64), 2.448906183242798
tensor(1576.7018, dtype=torch.float64), 0.4453146457672119
