In [1]:
import numpy as np
import torch
import os, sys, time

In [2]:
def forloopdists(X, T):
    D = np.zeros((X.shape[0], T.shape[0]))

    for i in range(X.shape[0]):
        for j in range(T.shape[0]):
            D[i, j] = np.linalg.norm(X[i, :] - T[j, :])**2
    return D

In [76]:
def numpydists2(X, T):
    X = np.expand_dims(X, axis=1)
    T = np.expand_dims(T, axis=0)
    
    X2 = np.einsum("ijk, ijk -> ij", X, X)
    T2 = np.einsum("ijk, ijk -> ij", T, T)
    XT = np.einsum("ijk, ijk -> ij", X, T)
    TX = np.einsum("ijk, ijk -> ij", T, X)
    
    return X2 + T2 - XT - TX

def numpydists(X, T):
    X = np.expand_dims(X, axis=1)
    T = np.expand_dims(T, axis=0)
    D = np.sum((X - T)**2, axis=2)
    return D

In [85]:
def pytorchdists2(X, T, device=None):
    X = torch.tensor(X)
    T = torch.tensor(T)
    X = torch.unsqueeze(X, dim=1)
    T = torch.unsqueeze(T, dim=0)
    
    X2 = torch.einsum("ijk, ijk -> ij", X, X)
    T2 = torch.einsum("ijk, ijk -> ij", T, T)
    XT = torch.einsum("ijk, ijk -> ij", X, T)
    TX = torch.einsum("ijk, ijk -> ij", T, X)
    
    D = X2 + T2 - XT - TX
    return D

def pytorchdists(X, T, device=None):
    X = torch.tensor(X)
    T = torch.tensor(T)
    
    X = torch.unsqueeze(X, dim=1)
    T = torch.unsqueeze(T, dim=0)
    D = torch.sum((X - T)**2, dim=2)
    return D

In [77]:
X = np.random.normal(size=(500, 30)) #5000 instead of 250k for forloopdists
T = np.random.normal(size=(50, 30))

D = numpydists2(X, T)
D3 = numpydists(X, T)

In [80]:
print(D - D3)

[[ 2.13162821e-14  1.42108547e-14  0.00000000e+00 ...  1.42108547e-14
   0.00000000e+00  0.00000000e+00]
 [ 7.10542736e-15 -1.42108547e-14  7.10542736e-15 ... -7.10542736e-15
   7.10542736e-15  1.42108547e-14]
 [ 0.00000000e+00 -7.10542736e-15 -7.10542736e-15 ...  0.00000000e+00
   0.00000000e+00 -2.13162821e-14]
 ...
 [ 0.00000000e+00  7.10542736e-15  1.42108547e-14 ...  0.00000000e+00
   0.00000000e+00 -7.10542736e-15]
 [ 7.10542736e-15  0.00000000e+00  0.00000000e+00 ...  3.55271368e-15
   1.42108547e-14  1.42108547e-14]
 [ 0.00000000e+00  1.42108547e-14  0.00000000e+00 ... -7.10542736e-15
  -2.84217094e-14 -7.10542736e-15]]


## The Test

In [83]:
def timing():
    X = np.random.normal(size=(5000, 300)) #5000 instead of 250k for forloopdists
    T = np.random.normal(size=(500, 300))

    #if X.shape[0] * T.shape[0] * X.shape[1] < 1E9:
    #    since = time.time()
    #    dists0 = forloopdists(X, T)
    #    time_elapsed = float(time.time()) - float(since)
    #    print('For complete in {:.3f}s'.format( time_elapsed ))

    since = time.time()
    dists1 = pytorchdists(X, T)
    time_elapsed = float(time.time()) - float(since)
    print('Torch complete in {:.3f}s'.format( time_elapsed ))
    
    since = time.time()
    dists12 = pytorchdists2(X, T)
    time_elapsed = float(time.time()) - float(since)
    print('Torch 2 complete in {:.3f}s'.format( time_elapsed ))

    since = time.time()
    dists2 = numpydists(X, T)
    time_elapsed = float(time.time()) - float(since)
    print('Numpy complete in {:.3f}s'.format( time_elapsed ))
    
    since = time.time()
    dists22 = numpydists2(X, T)
    time_elapsed = float(time.time()) - float(since)
    print('Numpy 2 complete in {:.3f}s'.format( time_elapsed ))

In [84]:
timing()

Torch complete in 7.144s
Numpy complete in 4.362s
Numpy 2 complete in 0.738s


In [12]:
print(np.log10(5000 * 500 * 300))
print(np.log10(250000 * 500 * 300))

8.8750612633917
10.574031267727719
