In [None]:
import numpy as np
import pandas as pd
from sklearn.neighbors import KDTree
from annoy import AnnoyIndex
from scipy.sparse import csr_matrix

In [None]:
def _gini(array):
    """Calculate the Gini coefficient of a numpy array.
    """
    
    array = array.flatten().astype(float)
    if np.amin(array) < 0:
        array -= np.amin(array)
    array += 0.0000001
    array = np.sort(array)
    index = np.arange(1, array.shape[0]+1)
    n = array.shape[0]
    array_sum = np.sum(array)
    return (2 * np.dot(index, array) - (n + 1) * array_sum) / (n * array_sum)

In [None]:
def _knn(X_ref,
         X_query=None,
         k=20,
         leaf_size=40,
         metric='euclidean',
         annoy_n_trees=10,
         annoy_min_dim=8192):
    """Calculate K nearest neighbors for each row, using KDTree or Annoy.

    A KDTree is used when the number of features is below ``annoy_min_dim``;
    otherwise an Annoy index (approximate search) is built instead.

    Parameters
    ----------
    X_ref : 2-D array of shape (n_ref, n_features)
        Reference points in which neighbors are searched.
    X_query : 2-D array of shape (n_query, n_features), optional
        Points to find neighbors for. When omitted, ``X_ref`` is queried
        against itself.
    k : int, default 20
        Number of neighbors to retrieve per query point.
    leaf_size : int, default 40
        Passed to ``KDTree``; ignored by the Annoy branch.
    metric : str, default 'euclidean'
        Distance metric name, passed to whichever backend is selected. It
        must be a name that backend supports.
    annoy_n_trees : int, default 10
        Number of trees built by Annoy; ignored by the KDTree branch.
    annoy_min_dim : int, default 8192
        Feature count at or above which Annoy is used instead of KDTree.

    Returns
    -------
    mat_conn_ref_query : scipy.sparse matrix of shape (n_ref, n_query)
        Entry (i, j) is 1 when reference point i was returned as a neighbor
        of query point j.
    mat_dist_ref_query : scipy.sparse matrix of shape (n_ref, n_query)
        Same sparsity pattern, holding the reported distance instead of 1.

    Raises
    ------
    ValueError
        If ``k`` is larger than the number of reference points.
    """
    if X_query is None:
        # Neither backend writes to its input, so the reference data can be
        # reused directly instead of duplicating it in memory.
        X_query = X_ref
    n_ref = X_ref.shape[0]
    dim = X_ref.shape[1]

    # Fail early and consistently for both backends: KDTree rejects this case
    # itself, while Annoy would silently return fewer than k neighbors and
    # produce ragged result arrays.
    if k > n_ref:
        raise ValueError(
            "k (%d) must be less than or equal to the number of reference "
            "points (%d)." % (k, n_ref))

    if dim < annoy_min_dim:
        tree = KDTree(X_ref, leaf_size=leaf_size, metric=metric)
        nn_dist, nn_idx = tree.query(X_query, k=k, return_distance=True)
    else:
        annoy_index = AnnoyIndex(dim, metric)
        # Annoy has no bulk-insert call here: every reference row is added
        # individually under its row number, then the forest is built.
        for i in range(n_ref):
            annoy_index.add_item(i, X_ref[i])
        annoy_index.build(annoy_n_trees)
        # Each hit is an (indices, distances) pair for one query row.
        hits = [annoy_index.get_nns_by_vector(x, k, include_distances=True)
                for x in X_query]
        nn_idx = np.array([indices for indices, _ in hits])
        nn_dist = np.array([distances for _, distances in hits])

    # COO-style triplets: one (query row, reference column) pair per neighbor.
    rows = np.repeat(np.arange(nn_idx.shape[0]), nn_idx.shape[1])
    cols = nn_idx.flatten()
    ones = np.repeat(1, len(rows))
    query_by_ref = (X_query.shape[0], n_ref)

    # Built as (n_query, n_ref) and transposed so that rows index reference
    # points and columns index query points.
    mat_conn_ref_query = csr_matrix((ones, (rows, cols)), shape=query_by_ref).T
    mat_dist_ref_query = csr_matrix(
        (nn_dist.flatten(), (rows, cols)), shape=query_by_ref).T
    return mat_conn_ref_query, mat_dist_ref_query