# Creator: Ivan Bardarov <br> (University of Strathclyde, March 2019)
## This module is used to retrieve similar vectors from the multidimensional space

In [1]:
import nmslib
from os.path import exists
from os import makedirs

def get_index_filename(cat):
    """
    Build the path of the index file that belongs to a category.

    As a side effect the ``../indexes`` directory is created if it does
    not exist yet, so the returned path can be written to right away.

    Parameters
    ----------
    cat : int
        The id of the category whose index file path is requested

    Returns
    -------
    str
        the path to the index file, ``../indexes/<cat>.bin``
    """
    index_dir = "../indexes"
    # exist_ok=True makes repeated calls harmless.
    makedirs(index_dir, exist_ok=True)
    return f"{index_dir}/{cat}.bin"

def init_index_for_data(data):
    """
    Initialise an nmslib index in the angular distance space and fill it
    with the given vectors.

    The index is only populated here; ``createIndex`` is not called, so
    the caller still has to build the index or load a saved one.

    Parameters
    ----------
    data : ndarray
        the array of vectors that are going to be indexed

    Returns
    -------
    index
        the populated (not yet built) index for the data
    """
    # Only the space is specified; every other setting uses nmslib's default.
    populated = nmslib.init(space='angulardist')
    populated.addDataPointBatch(data)
    return populated

def create_index(cat, data):
    """
    Build a brand new index for the data and persist it to disk.

    The vectors are added to a fresh index, the index is built and the
    result is saved to the index file of the category.

    Parameters
    ----------
    cat : int
        The id of the category the index belongs to; it determines the
        file the index is saved to
    data : ndarray
        the array of vectors that are going to be indexed

    Returns
    -------
    index
        the built index for the data
    """
    built = init_index_for_data(data)
    built.createIndex()
    # Persist the index so later calls can load it instead of rebuilding.
    save_index(built, cat)
    return built

def save_index(index, cat):
    """
    Write the index to the index file of the given category.

    Parameters
    ----------
    index : index
        The index to be saved
    cat : int
        The id of the category for which we want to save the index
    """
    target_path = get_index_filename(cat)
    index.saveIndex(target_path)
    
def load_index(cat, data):
    """
    Load a previously saved index from the index file of a category.

    Parameters
    ----------
    cat : int
        The id of the category whose saved index should be loaded
    data : ndarray
        the array of vectors the saved index was built for

    Returns
    -------
    index
        the loaded index for the data
    """
    # The data points are added before loading; presumably the saved file
    # holds only the index structure and not the vectors themselves --
    # NOTE(review): confirm against the nmslib documentation.
    restored = init_index_for_data(data)
    source_path = get_index_filename(cat)
    restored.loadIndex(source_path)
    return restored

def get_index(cat, data):
    """
    Return the index for a category, loading it from disk when an index
    file already exists and creating (and saving) it otherwise.

    Parameters
    ----------
    cat : int
        The id of the category for which we want the index
    data : ndarray
        the array of vectors that are going to be indexed

    Returns
    -------
    index
        the index for the data
    """
    if exists(get_index_filename(cat)):
        # An index was saved earlier for this category: reuse it.
        return load_index(cat, data)
    return create_index(cat, data)

def get_knn(index, vec, k):
    """
    Query the index for the k nearest neighbours of a vector.

    Parameters
    ----------
    index: index
        the index for the data
    vec: ndarray
        the vector for which we need to retrieve similar items
    k: int
        the number of similar items needed

    Returns
    -------
    object
        whatever ``index.knnQuery`` returns, passed through unchanged
        (for nmslib indexes this should be a pair of arrays: the ids of
        the neighbours and their distances -- confirm against the nmslib
        documentation)
    """
    neighbours = index.knnQuery(vec, k=k)
    return neighbours