Given two square matrices $A$ and $B$ of the same size,

$tr(AB) = \sum_{i,j} A_{ij} B_{ji} =$ sum of $(A \circ B^T)$ $= flat(A) \bullet flat(B^T)$

where $\circ$ is the element-wise (Hadamard) product, $\bullet$ is the inner product, and $flat(X)$ treats $X$ as a vector by concatenating its rows.

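Computing $AB$ first is slower because it also computes every off-diagonal element, none of which is needed for $tr(AB)$. For $n \times n$ matrices the full product costs $O(n^3)$ operations, while the element-wise methods cost only $O(n^2)$.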

In torch, the $flat$ method is faster than the $(A \circ B^T)$ method, as the timing results below show.

In [80]:
import numpy as np
import torch
import time

In [81]:
def NpTr2S(A, B, flat = True):
    """Return tr(A @ B) without forming the full matrix product.

    Uses the identity tr(AB) = sum_ij A[i, j] * B[j, i], which needs only
    m*n multiplications instead of the m*n*m needed to build ``A @ B``.

    Parameters
    ----------
    A : numpy.ndarray
        2-D array of shape (m, n).
    B : numpy.ndarray
        2-D array of shape (n, m), so that ``A @ B`` is defined and square.
    flat : bool, default True
        If true, compute the trace as the inner product of the flattened
        ``A.T`` and the flattened ``B``. Otherwise, sum the element-wise
        product ``A * B.T``.

    Returns
    -------
    numpy scalar
        The trace of ``A @ B``.

    Raises
    ------
    ValueError
        If ``A`` or ``B`` is not 2-D, or their shapes are not (m, n) and
        (n, m).
    """
    # Without this check, the flat path silently dots any two arrays with
    # equal element counts (e.g. (1, 4) and (2, 2)). The element-wise path
    # can silently broadcast (e.g. (3, 3) and (1, 3)). Both return a number
    # that is not tr(AB).
    if A.ndim != 2 or B.ndim != 2 or A.shape != B.shape[::-1]:
        raise ValueError(
            "tr(AB) needs A of shape (m, n) and B of shape (n, m); "
            f"got {A.shape} and {B.shape}"
        )
    if flat:
        # Flattened row-major, A.T holds A[i, j] at index j*m + i, which is
        # exactly where flattened B holds B[j, i]. Their dot product is the
        # trace.
        return np.dot(A.T.reshape(-1), B.reshape(-1))
    # A * B.T has entries A[i, j] * B[j, i]; their total is the trace.
    return (A*B.T).sum()

def TorchTr2S(A, B, flat = True):
    """Return tr(A @ B) for torch tensors without forming the full product.

    Uses the identity tr(AB) = sum_ij A[i, j] * B[j, i], which needs only
    m*n multiplications instead of the m*n*m needed to build ``A.mm(B)``.

    Parameters
    ----------
    A : torch.Tensor
        2-D tensor of shape (m, n). It need not be contiguous.
    B : torch.Tensor
        2-D tensor of shape (n, m), so that ``A.mm(B)`` is defined and square.
    flat : bool, default True
        If true, compute the trace as a (1, m*n) x (m*n, 1) matrix product of
        the flattened ``A`` and the flattened ``B.t()``. Otherwise, sum the
        element-wise product ``A * B.t()``.

    Returns
    -------
    torch.Tensor
        0-dimensional tensor holding the trace of ``A @ B``.

    Raises
    ------
    ValueError
        If ``A`` or ``B`` is not 2-D, or their shapes are not (m, n) and
        (n, m).
    """
    a_shape, b_shape = tuple(A.shape), tuple(B.shape)
    # Without this check, mismatched shapes with equal element counts
    # (e.g. (1, 4) and (2, 2)) silently produce a number that is not tr(AB).
    if len(a_shape) != 2 or len(b_shape) != 2 or a_shape != b_shape[::-1]:
        raise ValueError(
            "tr(AB) needs A of shape (m, n) and B of shape (n, m); "
            f"got {a_shape} and {b_shape}"
        )
    if flat:
        # reshape (unlike view) also accepts non-contiguous tensors such as
        # X.t(). It returns a view when possible and copies otherwise, so for
        # contiguous inputs it matches the previous .view/.contiguous() code.
        return A.reshape(1, -1).mm(B.t().reshape(-1, 1)).view(())
    # A * B.t() has entries A[i, j] * B[j, i]; their total is the trace.
    return (A*B.t()).sum()

In [82]:
# Benchmark tr(AB) on 5000x5000 float32 NumPy matrices. The first two methods
# build the full product A @ B; the last two use NpTr2S, which does not.
# time.perf_counter is a monotonic, high-resolution clock intended for
# measuring durations (time.time is wall-clock and can jump). Each result is
# computed before printing, so printing is excluded from the measured time.
A = np.random.normal(0, 1, (5000, 5000)).astype(np.float32)
B = np.random.normal(0, 1, (5000, 5000)).astype(np.float32)

candidates = [
    lambda: np.diag(np.dot(A, B)).sum(),   # full product, sum of diagonal
    lambda: np.trace(np.dot(A, B)),        # full product, np.trace
    lambda: NpTr2S(A, B, False),           # sum of A * B.T
    lambda: NpTr2S(A, B),                  # flat(A.T) . flat(B)
]
for run in candidates:
    start = time.perf_counter()
    result = run()
    elapsed = time.perf_counter() - start
    print(result, elapsed, sep=", ")

3982.4336, 1.0010411739349365
3982.4336, 1.055237054824829
3982.4277, 0.3329920768737793
3982.4775, 0.3156619071960449


In [83]:
# Repeat the NumPy benchmark in float64, upcasting the matrices from the
# previous cell. Only the computation is timed, with the monotonic,
# high-resolution time.perf_counter; printing happens after the measurement.
A = A.astype(np.float64)
B = B.astype(np.float64)

candidates = [
    lambda: np.diag(np.dot(A, B)).sum(),   # full product, sum of diagonal
    lambda: np.trace(np.dot(A, B)),        # full product, np.trace
    lambda: NpTr2S(A, B, False),           # sum of A * B.T
    lambda: NpTr2S(A, B),                  # flat(A.T) . flat(B)
]
for run in candidates:
    start = time.perf_counter()
    result = run()
    elapsed = time.perf_counter() - start
    print(result, elapsed, sep=", ")

3982.4341851562767, 2.116687297821045
3982.4341851562767, 1.9839439392089844
3982.4341851562667, 0.44022393226623535
3982.434185156503, 0.3969540596008301


In [84]:
# Repeat the benchmark with float32 torch tensors built from the NumPy arrays
# of the previous cell. torch.from_numpy yields CPU tensors, so no device
# synchronization is needed around the timer. Only the computation is timed,
# with the monotonic, high-resolution time.perf_counter.
A = torch.from_numpy(A.astype(np.float32))
B = torch.from_numpy(B.astype(np.float32))

candidates = [
    lambda: torch.diag(A.mm(B)).sum(),     # full product, sum of diagonal
    lambda: torch.trace(A.mm(B)),          # full product, torch.trace
    lambda: TorchTr2S(A, B, False),        # sum of A * B.t()
    lambda: TorchTr2S(A, B),               # flat(A) . flat(B.t())
]
for run in candidates:
    start = time.perf_counter()
    result = run()
    elapsed = time.perf_counter() - start
    print(result, elapsed, sep=", ")

tensor(3982.4358), 0.9481821060180664
tensor(3982.4336), 0.9221210479736328
tensor(3982.4377), 0.2957921028137207
tensor(3982.4360), 0.09590291976928711


In [85]:
# Repeat the torch benchmark in float64 (Tensor.double() keeps the tensors on
# the CPU, like the legacy .type(torch.DoubleTensor)). Only the computation
# is timed, with the monotonic, high-resolution time.perf_counter.
A = A.double()
B = B.double()

candidates = [
    lambda: torch.diag(A.mm(B)).sum(),     # full product, sum of diagonal
    lambda: torch.trace(A.mm(B)),          # full product, torch.trace
    lambda: TorchTr2S(A, B, False),        # sum of A * B.t()
    lambda: TorchTr2S(A, B),               # flat(A) . flat(B.t())
]
for run in candidates:
    start = time.perf_counter()
    result = run()
    elapsed = time.perf_counter() - start
    print(result, elapsed, sep=", ")

tensor(3982.4342, dtype=torch.float64), 1.986386775970459
tensor(3982.4342, dtype=torch.float64), 2.104691982269287
tensor(3982.4342, dtype=torch.float64), 0.5125560760498047
tensor(3982.4342, dtype=torch.float64), 0.21176981925964355
