In [1]:
import numpy as np
import matplotlib.pyplot as plt
import time
import os

# Use plt style default
plt.style.use('default')

# 1) Preparando Terreno

To generate random vectors uniformly distributed on the surface of an $n$-dimensional sphere with radius 1, follow these steps:

1. **Generate Gaussian random variables**:

    First, generate a vector $\mathbf{v}\in \mathbb{R}^n$ where each component $v_i$ is sampled from a standard normal distribution, i.e., $v_i\sim\mathcal{N}(0,1)$. This gives you random points in $n$-dimensional space that are not uniformly distributed on the sphere yet.

2. **Normalize the vector**:

    To project the point onto the unit sphere, normalize the vector $\mathbf{v}$ so that its magnitude is 1. This is done by dividing $\mathbf{v}$ by its Euclidean norm:

    $$\mathbf{u} = \frac{\mathbf{v}}{\|\mathbf{v}\|}$$

    where $\|\mathbf{v}\|$ is the Euclidean norm of $\mathbf{v}$, calculated as:

    $$\|\mathbf{v}\| = \left(\sum_{i = 1}^n v_i^2\right)^{1/2}$$

In [2]:
def random_unit_vectors(num_vectors, n):
    # Gera uma matriz de vetores com distribuição normal (gaussiana)
    vectors = np.random.normal(0, 1, (num_vectors, n))
    norm = np.linalg.norm(vectors, axis = 1)
    for i in np.where(norm != 1.0)[0]:
        vectors[i] = vectors[i] / np.linalg.norm(vectors[i])
    vectors = vectors.astype(np.float32)
    norm = np.linalg.norm(vectors, axis = 1)
    for i in np.where(norm != 1.0)[0]:
        vectors[i] = vectors[i] / np.linalg.norm(vectors[i])
    
    return vectors

# Exemplo: Gerar 5 vetores aleatórios no espaço 3D e garantir precisão float32
num_vectors = 10000
n = 3
np.random.seed(42)
random_vectors = random_unit_vectors(num_vectors, n)

In [3]:
%matplotlib qt
# Plot points in random_vectors on 3d sphere
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot(random_vectors[:,0], random_vectors[:,1], random_vectors[:,2], 'o', ms=1)
# Plot sphere with radius 1
u = np.linspace(0, 2 * np.pi, 100)
v = np.linspace(0, np.pi, 100)
x = np.outer(np.cos(u), np.sin(v))
y = np.outer(np.sin(u), np.sin(v))
z = np.outer(np.ones(np.size(u)), np.cos(v))
# Include edges
ax.plot_surface(x, y, z, color='b', alpha=0.1, edgecolor='k')
# Adjust aspect ratio
ax.set_box_aspect([1,1,1])
plt.tight_layout()

# 2) Preparando busca

In [4]:
def L2(query, features):
    # Compute L2 distance between query and features
    return np.sqrt(np.sum((query - features)**2, axis=1))

def knn(query, k, features, distFunc = None, tie = 0):
    """
    Realiza uma busca k-NN (k-Nearest Neighbors) utilizando uma função de distância personalizada.

    Parameters
    ----------
    query : np.ndarray
        O vetor de consulta para o qual os vizinhos mais próximos serão encontrados.
    k : int
        O número de vizinhos mais próximos a serem retornados.
    features : np.ndarray
        Um array 2D onde cada linha é um vetor de características.
    distFunc : callable, optional
        A função de distância a ser utilizada. Se None, a função L2 (distância Euclidiana) será utilizada (default é None).
    tie : int, optional
        Se 1, verifica se há mais características com a mesma distância que o k-ésimo vizinho mais próximo e as inclui no resultado (default é 0).

    Returns
    -------
    list of tuple
        Uma lista de tuplas onde cada tupla contém o índice e a distância do vizinho mais próximo no array de características.
        A lista é ordenada pela distância em ordem crescente.

    Notes
    -----
    Em caso de empate (distâncias iguais) e se `tie` for 1, todos os vizinhos com a mesma distância que o k-ésimo vizinho mais próximo são incluídos no resultado.
    """
    if distFunc is None:
        distFunc = L2
    
    # Vetor com distâncias
    distances = distFunc(query, features)
    k_nearest = []
     
    for i in range(k):
        min_idx = np.argmin(distances)
        k_nearest.append((min_idx, distances[min_idx]))
        distances[min_idx] = np.inf
        
    # Check if there are more features with the same distance
    if tie:
        min_idx = np.argmin(distances)
        while distances[min_idx] == k_nearest[-1][1]:
            k_nearest.append((min_idx, distances[min_idx]))
            distances[min_idx] = np.inf
            min_idx = np.argmin(distances)
    
    return k_nearest

In [5]:
np.random.seed(233)
list_of_queries = random_unit_vectors(10, 3)
list_of_queries

array([[ 0.6733262 ,  0.14008501, -0.72595316],
       [-0.37680808, -0.17516552, -0.9095783 ],
       [-0.600396  , -0.77817523,  0.18430394],
       [ 0.6694828 , -0.70914644, -0.22114256],
       [-0.44115913,  0.03943314,  0.89656216],
       [-0.02321262, -0.14116941,  0.9897133 ],
       [ 0.655004  ,  0.63452375, -0.41030395],
       [ 0.2351463 , -0.87342113,  0.426429  ],
       [ 0.47943065, -0.7857913 ,  0.39074072],
       [-0.6201924 ,  0.69483495, -0.3640959 ]], dtype=float32)

In [6]:
# Perform k-NN search for each queryk = 64
k = 32
results = []
times = []

for query in list_of_queries:
    
    t0 = time.time()
    k_nearest = knn(query, k, random_vectors, tie=0)
    t1 = time.time()
    results.append(k_nearest)
    
    
    dt = (t1-t0)*1000
    print(f"Query {len(results)}: {dt:.3f} ms")
    times.append(dt)

# Mean and standard deviation of time
median_time = np.median(times)
total_time = np.sum(times)
print(f"Total time: {total_time/1000:.2f}, Median time: {median_time:.3f} ms")

Query 1: 0.997 ms
Query 2: 1.001 ms
Query 3: 0.998 ms
Query 4: 1.000 ms
Query 5: 0.998 ms
Query 6: 0.000 ms
Query 7: 1.002 ms
Query 8: 0.998 ms
Query 9: 1.003 ms
Query 10: 0.000 ms
Total time: 0.01, Median time: 0.998 ms


In [7]:
%matplotlib qt
# Plot points in random_vectors on 3d sphere
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot(random_vectors[:,0], random_vectors[:,1], random_vectors[:,2], 'o', ms=1)
# Plot queries and results
for j in range(len(list_of_queries)):
    ax.plot(list_of_queries[j][0], list_of_queries[j][1], list_of_queries[j][2], 'o', ms=5, color='r')
    for i in range(k):
        idx = results[j][i][0]
        ax.plot(random_vectors[idx,0], random_vectors[idx,1], random_vectors[idx,2], 'o', color='g', ms=2)

# ax.plot(list_of_queries[0][0], list_of_queries[0][1], list_of_queries[0][2], 'o', ms=5, color='r')
# for i in range(k):
#     idx = results[0][i][0]
#     ax.plot(random_vectors[idx,0], random_vectors[idx,1], random_vectors[idx,2], 'o', ms=3, color='g')

# Plot sphere with radius 1
u = np.linspace(0, 2 * np.pi, 100)
v = np.linspace(0, np.pi, 100)
x = np.outer(np.cos(u), np.sin(v))
y = np.outer(np.sin(u), np.sin(v))
z = np.outer(np.ones(np.size(u)), np.cos(v))
# Include edges
ax.plot_surface(x, y, z, color='b', alpha=0.1, edgecolor='k')
# Adjust aspect ratio
ax.set_box_aspect([1,1,1])
plt.tight_layout()

# 3) Testes busca

In [37]:
# Gerar vetores
num_vectors = 100000
n = 4
np.random.seed(42)
random_vectors = random_unit_vectors(num_vectors, n)

# Gerar queries
np.random.seed(235)
list_of_queries = random_unit_vectors(80, n)
random_vectors.shape, list_of_queries.shape

((100000, 4), (80, 4))

In [38]:
# Save 1 million vectors to disk
if not os.path.exists(f'datasets/random_unit_vectors_{n}d.npy'):
    np.save(f'datasets/random_unit_vectors_{n}d.npy', random_vectors)

# Save queries
if not os.path.exists(f'datasets/random_unit_vectors_{n}d_queries.npy'):
    np.save(f'datasets/random_unit_vectors_{n}d_queries.npy', list_of_queries)

In [40]:
# Perform k-NN search for each queryk = 64
k = 32
results = []
times = []

for query in list_of_queries:
    
    t0 = time.time()
    k_nearest = knn(query, k, random_vectors, tie=0)
    results.append(k_nearest)
    t1 = time.time()
    
    dt = (t1-t0)*1000
    print(f"Query {len(results)}: {dt:.3f} ms")
    times.append(dt)

# Mean and standard deviation of time
median_time = np.median(times)
total_time = np.sum(times)
print(f"Total time: {total_time:.2f} ms, Median time: {median_time:.3f} ms")

Query 1: 7.998 ms
Query 2: 6.002 ms
Query 3: 5.996 ms
Query 4: 9.221 ms
Query 5: 10.998 ms
Query 6: 6.999 ms
Query 7: 7.002 ms
Query 8: 7.993 ms
Query 9: 14.004 ms
Query 10: 10.641 ms
Query 11: 10.000 ms
Query 12: 12.999 ms
Query 13: 8.999 ms
Query 14: 17.004 ms
Query 15: 7.996 ms
Query 16: 12.001 ms
Query 17: 13.034 ms
Query 18: 6.999 ms
Query 19: 9.001 ms
Query 20: 5.005 ms
Query 21: 5.999 ms
Query 22: 6.999 ms
Query 23: 12.452 ms
Query 24: 7.756 ms
Query 25: 17.003 ms
Query 26: 11.224 ms
Query 27: 4.997 ms
Query 28: 4.996 ms
Query 29: 6.006 ms
Query 30: 5.196 ms
Query 31: 3.999 ms
Query 32: 8.161 ms
Query 33: 5.171 ms
Query 34: 4.003 ms
Query 35: 7.999 ms
Query 36: 6.999 ms
Query 37: 7.367 ms
Query 38: 5.251 ms
Query 39: 4.000 ms
Query 40: 8.183 ms
Query 41: 6.180 ms
Query 42: 3.999 ms
Query 43: 7.001 ms
Query 44: 6.007 ms
Query 45: 5.291 ms
Query 46: 6.197 ms
Query 47: 4.999 ms
Query 48: 5.001 ms
Query 49: 6.160 ms
Query 50: 5.157 ms
Query 51: 4.000 ms
Query 52: 7.167 ms
Query 53: 