In [106]:
import numpy as np
import time
import polars as pl


def dist(A, B, GramMatrix):
    return (
        np.diag(np.matmul(
            np.matmul(A, GramMatrix), 
            A.T)) +
        np.diag(np.matmul(
            np.matmul(B.T, GramMatrix), 
            B)).reshape(-1, 1) - 
        2 * (np.matmul(
            np.matmul(B.T, GramMatrix), 
            A.T))) 

In [107]:
A = np.random.rand(3, 4)
np.dot(A.T, A)

array([[0.68349891, 0.73592707, 1.04745364, 0.4957842 ],
       [0.73592707, 0.85606979, 1.14926101, 0.56760494],
       [1.04745364, 1.14926101, 1.90322017, 1.05942348],
       [0.4957842 , 0.56760494, 1.05942348, 0.66330192]])

In [108]:
def performance_test(n: int, k: int, m: int, iterations: int):
    A = np.random.rand(m, n)
    B = np.random.rand(n, k)
    
    # Gram matrix must be symmetric
    GramMatrix = np.random.rand(n, n)
    GramMatrix = (GramMatrix + GramMatrix.T) * 0.5
    
    assert A.shape[0] == m and B.shape[1] == k
    
    naive_time = 0
    new_time = 0
    for it in range(iterations):
        result1 = np.zeros((k, m))
        
        point0 = time.time()
        
        for j in range(m):
            for i in range(k):
                result1[i, j] = np.dot(np.dot(A[j] - B[:,i].T, GramMatrix), A[j].T - B[:,i]).item()
                
        point1 = time.time()   
          
        result2 = dist(A, B, GramMatrix)
        
        point2 = time.time()
        
        assert np.allclose(result1, result2)
        
        naive_time += (point1 - point0)
        new_time += (point2 - point1)
   
    return  naive_time, new_time, iterations
        

In [109]:
print(performance_test(3, 4, 6, 10000))

(1.2779417037963867, 0.23721766471862793, 10000)
