In [1]:
import faiss
import json

### This example only loads a small portion of the embeddings

In [2]:
# JSON file that maps each qnode to its id inside the index
qnode_to_index_file = 'qnodes_to_id.json'
# Index file on disk; it is read with faiss.read_index
output_index = 'wikidata_index.idx'

In [3]:
class FAISS_Index(object):
    ''' Nearest-neighbor lookup over a FAISS index, addressed by qnode.

    Holds an index read from disk together with a qnode -> index id map (and
    its inverse), so queries and results are expressed as qnodes instead of
    raw index ids.
    '''

    def __init__(self, qnode_to_index_file, output_index, k, efSearch=400, nprobe=4):
        ''' Read the index and the qnode map from disk.

        Args:
            qnode_to_index_file: Path to a JSON object of {qnode: index id}.
            output_index: Path of the index file given to faiss.read_index.
            k: Number of results requested from the index for every query.
            efSearch: Value assigned to the quantizer's hnsw.efSearch.
            nprobe: Value assigned to the index's nprobe.
        '''
        self._index = faiss.read_index(output_index)
        try:
            # Set the parameters. Not every index exposes these attributes;
            # when the assignment fails a notice is printed and loading goes on.
            faiss.downcast_index(self._index.quantizer).hnsw.efSearch = efSearch
            self._index.nprobe = nprobe
        except Exception:
            # Deliberately best effort, but unlike a bare `except:` this does
            # not trap KeyboardInterrupt or SystemExit.
            print('Cannot set parameters for this index')

        # Load the entity to index map
        with open(qnode_to_index_file) as fd:
            self._qnode_to_index = json.load(fd)
        # Inverse map, used to translate search results back to qnodes
        self._index_to_qnode = {idx: qnode for qnode, idx in self._qnode_to_index.items()}

        self._k = k

    def get_neighbors(self, qnode, get_scores=False):
        ''' Find the neighbors for the given qnode

        The stored vector of `qnode` is reconstructed from the index and used
        as the query.

        Args:
            qnode: Qnode to look up; must be a key of the qnode map.
            get_scores: When true, return (qnode, score) tuples instead of
                bare qnodes.

        Returns:
            A list of at most k neighbors, in the order the index returned
            them: qnodes, or (qnode, score) tuples when get_scores is set.

        Raises:
            KeyError: If `qnode`, or an id returned by the index, is missing
                from the qnode map.
        '''
        query = self._index.reconstruct(self._qnode_to_index[qnode]).reshape(1, -1)
        scores, ids = self._index.search(query, self._k)

        # An id of -1 is a placeholder, not a hit. Filter ids and scores as
        # pairs so every score stays attached to its own neighbor even if a
        # placeholder is not at the end of the row.
        hits = [(self._index_to_qnode[idx], score)
                for idx, score in zip(ids[0], scores[0]) if idx != -1]

        if get_scores:
            return hits
        return [neighbor for neighbor, _ in hits]

    @property
    def k(self):
        ''' Number of results requested per query. '''
        return self._k

    @property
    def index(self):
        ''' The underlying index object returned by faiss.read_index. '''
        return self._index

In [4]:
# Build the wrapper; every query will ask the index for 5 results
index = FAISS_Index(qnode_to_index_file, output_index, k=5)

### Search

In [5]:
# Qnode used as the query in the lookups below
query_qnode = 'Q7319603'

In [6]:
# Time a single lookup that returns only the neighbor qnodes
%time index.get_neighbors(query_qnode)

CPU times: user 844 ms, sys: 35.6 ms, total: 879 ms
Wall time: 26.2 ms


['Q7319603', 'Q11302945', 'Q7381229', 'Q5224589', 'Q5339902']

In [7]:
# Same lookup, this time returning (qnode, score) pairs
%time index.get_neighbors(query_qnode, get_scores=True)

CPU times: user 687 ms, sys: 31.2 ms, total: 718 ms
Wall time: 19.8 ms


[('Q7319603', 0.0),
 ('Q11302945', 1.6781642),
 ('Q7381229', 1.7304077),
 ('Q5224589', 1.9623686),
 ('Q5339902', 2.3353627)]

### Without using the class

In [8]:
# Read the index directly. NOTE: this rebinds `index`, which until now held
# the FAISS_Index wrapper created earlier in the notebook.
index = faiss.read_index(output_index)
try:
    # Set the parameters. Not every index exposes these attributes; when the
    # assignment fails a notice is printed and the notebook carries on.
    faiss.downcast_index(index.quantizer).hnsw.efSearch = 400
    index.nprobe = 4
except Exception:
    # Deliberately best effort, but unlike a bare `except:` this does not trap
    # KeyboardInterrupt or SystemExit.
    print('Cannot set parameters for this index')

# Load the entity to index map
with open(qnode_to_index_file) as fd:
    qnode_to_index = json.load(fd)
# Inverse map, used to translate search results back to qnodes
index_to_qnode = {v: k for k, v in qnode_to_index.items()}

In [9]:
def get_neighbors(qnode, get_scores=False, k=5):
    ''' Find the neighbors for the given qnode

    Uses the module-level `index`, `qnode_to_index` and `index_to_qnode`.
    The stored vector of `qnode` is reconstructed from the index and used as
    the query.

    Args:
        qnode: Qnode to look up; must be a key of `qnode_to_index`.
        get_scores: When true, return (qnode, score) tuples instead of bare
            qnodes.
        k: Number of results requested from the index.

    Returns:
        A list of at most k neighbors, in the order the index returned them:
        qnodes, or (qnode, score) tuples when get_scores is set.

    Raises:
        KeyError: If `qnode`, or an id returned by the index, is missing from
            the maps.
    '''
    query = index.reconstruct(qnode_to_index[qnode]).reshape(1, -1)
    scores, ids = index.search(query, k)

    # An id of -1 is a placeholder, not a hit. Filter ids and scores as pairs
    # so every score stays attached to its own neighbor even if a placeholder
    # is not at the end of the row.
    hits = [(index_to_qnode[idx], score)
            for idx, score in zip(ids[0], scores[0]) if idx != -1]

    if get_scores:
        return hits
    return [neighbor for neighbor, _ in hits]

In [10]:
# Time the function-based lookup, returning only the neighbor qnodes
%time get_neighbors(query_qnode)

CPU times: user 620 ms, sys: 38.7 ms, total: 659 ms
Wall time: 18.9 ms


['Q7319603', 'Q11302945', 'Q7381229', 'Q5224589', 'Q5339902']

In [11]:
# Same function-based lookup, returning (qnode, score) pairs
%time get_neighbors(query_qnode, get_scores=True)

CPU times: user 703 ms, sys: 3.8 ms, total: 707 ms
Wall time: 19.8 ms


[('Q7319603', 0.0),
 ('Q11302945', 1.6781642),
 ('Q7381229', 1.7304077),
 ('Q5224589', 1.9623686),
 ('Q5339902', 2.3353627)]