# Approximate Nearest Neighbours Search

Sometimes, when processing a user query, it is **acceptable to retrieve a "good guess"** at the nearest neighbours of the query **instead of the true nearest neighbours**. In those cases, one can use an algorithm that does not guarantee to return the actual nearest neighbour in every case, **in return for improved speed or memory savings**. Such algorithms make it possible to run a **fast approximate search in a very large dataset**. Today we will explore one such approach, based on graphs.

This is what we are going to do in this lab: 

1. Complete the implementation of a small-world graph;
2. Implement search in this graph;
3. Build a *navigable* small-world graph;
4. Compare search quality in the resulting graphs.


## 1. Dataset preparation
We will use the same dataset as in the last lab: the [dataset with curious facts](https://github.com/hsu-ai-course/hsu.ai/blob/master/code/datasets/nlp/facts.txt). Using a trained `doc2vec` [model](https://github.com/jhlau/doc2vec) (Associated Press News DBOW, 0.6GB), we will infer a vector for every fact and normalize the vectors.


### 1.1 Loading doc2vec model

In [1]:
from gensim.models.doc2vec import Doc2Vec

# unpack a model into 3 files and target the main one
# doc2vec.bin  <---------- this
# doc2vec.bin.syn0.npy
# doc2vec.bin.syn1neg.npy
model = Doc2Vec.load('doc2vec.bin', mmap=None)
print(type(model))
print(type(model.infer_vector(["to", "be", "or", "not"])))

<class 'gensim.models.doc2vec.Doc2Vec'>
<class 'numpy.ndarray'>


### 1.2 Reading data

In [2]:
import urllib.request
data_url = "https://raw.githubusercontent.com/hsu-ai-course/hsu.ai/master/code/datasets/nlp/facts.txt"
file_name = "facts.txt"
urllib.request.urlretrieve(data_url, file_name)

# read one fact per line, dropping the trailing newline
facts = []
with open(file_name) as fp:
    for line in fp:
        facts.append(line.strip('\n'))

### 1.3 Transforming sentences into vectors

In [3]:
import nltk
import numpy as np

def word_tokenize(sentence):
    return nltk.word_tokenize(sentence.lower())

def get_words_from_sentence(sentences):
    for sentence in sentences:
        # each fact looks like "68. Some text"; drop the leading number before tokenizing
        yield word_tokenize(sentence.split('.', 1)[1])

# infer one vector per fact
sent_vecs = np.array([model.infer_vector(words) for words in get_words_from_sentence(facts)])

### 1.4 Normalizing vectors

In [4]:
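def norm_vectors(A):
    '''Returns a copy of A with every row scaled to unit L2 norm'''
    An = A.copy()
    for i, row in enumerate(An):
        An[i, :] /= np.linalg.norm(row)
    return An

def find_k_closest(query, dataset, k=5):
    '''Exact search: returns k triples (index, vector, similarity), most similar first'''
    index = list((i, v, np.dot(query, v)) for i, v in enumerate(dataset))
    return sorted(index, key=lambda item: item[2], reverse=True)[:k]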

sent_vecs_normed = norm_vectors(sent_vecs)
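
Normalization is what lets `find_k_closest` rank by a plain dot product: for unit-length vectors the dot product equals the cosine similarity. The following is a minimal vectorized sketch of the same idea on toy data. The names `normalize_rows` and `top_k_by_dot` are hypothetical helpers introduced here for illustration and are not part of the lab code.

```python
import numpy as np

def normalize_rows(A):
    """Return a copy of A with every row scaled to unit L2 norm."""
    return A / np.linalg.norm(A, axis=1, keepdims=True)

def top_k_by_dot(query, dataset, k=5):
    """Return (index, similarity) pairs for the k most similar rows, best first."""
    sims = dataset @ query            # one dot product per row
    order = np.argsort(-sims)[:k]     # indices sorted by descending similarity
    return [(int(i), float(sims[i])) for i in order]

# toy 2d data, not the lab's embeddings
toy = np.array([[3.0, 4.0], [2.0, 1.0], [0.0, 2.0], [-1.0, -1.0]])
toy_normed = normalize_rows(toy)
toy_query = normalize_rows(np.array([[1.0, 1.0]]))[0]

# cosine similarity computed from the raw vectors, for comparison
raw_cosine = toy @ toy_query / np.linalg.norm(toy, axis=1)
toy_top = top_k_by_dot(toy_query, toy_normed, k=2)
print(toy_top)
```

Here `raw_cosine` and the dot products on the normalized rows agree, so the ranking is identical while the per-query work drops to a single matrix-vector product.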

## 2. Small world network
We discussed [small world networks](https://en.wikipedia.org/wiki/Small-world_network) in the lecture. This beautiful concept uses the skip-list idea to reach the neighbourhood of a query quickly from any random graph node. Practically all of the code is already written; you only need to complete the `rewire()` function according to the [Watts–Strogatz process](https://en.wikipedia.org/wiki/Watts%E2%80%93Strogatz_model).

**Please write rewiring code.**

The `build_graph` function accepts a collection of `values` (it must support `len()`). In our case these are the embeddings.

- `K` is a parameter of the Watts–Strogatz model that sets the average degree of the graph nodes.
- `p` is the probability of "rewiring" an edge.

In [5]:
import random
class Node:
    ''' Graph node class. Major properties are `value` to access embedding and `neighbourhood` for adjacent nodes '''
    def __init__(self, value, idx):
        self.value = value
        self.idx = idx
        self.neighbourhood = set()
        

def build_graph(values, K, p=0.4):
    '''Accepts container with values. Returns list with graph nodes'''
    
    def rewire(nodes, i, j, k):
        #TODO remove i-j connection and add i-k connection, bi-directional
        
        return True
    
        
    N = len(values)
    nodes = [None] * N
    
    # create nodes
    for i, val in enumerate(values):
        nodes[i] = Node(val, i)
    
    # create K-regular lattice
    for i, val in enumerate(nodes):
        for j in range(i - K // 2, i + K // 2 + 1):
            if i != j:
                nodes[i].neighbourhood.add(j % N)
                nodes[j % N].neighbourhood.add(i)
        
    for i, node in enumerate(nodes):
        #TODO for each node rewire right hand side i-j edge to some other random node
        # See Watts–Strogatz model for details
        
        pass
                
    return nodes
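
To make the rewiring step concrete, here is a minimal sketch of the Watts–Strogatz process on plain adjacency sets instead of the `Node` class above. It is one possible reading of the process, not the reference solution: the names `ring_lattice`, `rewire_edge` and `watts_strogatz` are hypothetical, the sketch assumes an even `k` and `n > k`, and it picks the new endpoint uniformly among nodes that would not create a self-loop or a duplicate edge.

```python
import random

def ring_lattice(n, k):
    """Adjacency sets of a ring where each node links to k // 2 nodes on each side."""
    adj = [set() for _ in range(n)]
    for i in range(n):
        for step in range(1, k // 2 + 1):
            j = (i + step) % n
            adj[i].add(j)
            adj[j].add(i)
    return adj

def rewire_edge(adj, i, j, target):
    """Replace the undirected edge i-j with i-target; return False if not allowed."""
    if target == i or target in adj[i] or j not in adj[i]:
        return False
    adj[i].discard(j)
    adj[j].discard(i)
    adj[i].add(target)
    adj[target].add(i)
    return True

def watts_strogatz(n, k, p, seed=None):
    """Build a ring lattice, then rewire each right-hand edge with probability p."""
    rng = random.Random(seed)
    adj = ring_lattice(n, k)
    for i in range(n):
        for step in range(1, k // 2 + 1):
            j = (i + step) % n
            if j in adj[i] and rng.random() < p:
                # candidates exclude i itself and its current neighbours
                candidates = [c for c in range(n) if c != i and c not in adj[i]]
                if candidates:
                    rewire_edge(adj, i, j, rng.choice(candidates))
    return adj

toy_graph = watts_strogatz(20, 4, 0.3, seed=1)
print(sum(len(a) for a in toy_graph) // 2, "edges")
```

Because every rewiring removes one edge and adds one, the number of edges stays at `n * k / 2`; only the share of long-range links changes with `p`.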

The bigger the `K` and `p` you choose, the longer the method runs. A bigger `K` leads to bigger near-cliques in the graph and, as a consequence, a bigger context to consider at each step of the search. A bigger `p` gives more "remote hops", but it should not be close to 1, as that makes the graph random (not small-world).

In [6]:
import time
start = time.time()
G = build_graph(sent_vecs_normed, K=10, p=0.2)
finish = time.time()
print("Graph built in {:.2f} ms".format(1000 * (finish - start)))

Graph built in 2.00 ms


### 2.1 Searching in a small-world graph

Now you need to implement an efficient search procedure that exploits the small-world properties. Starting from a random node, at each step you should move towards the closest node (in terms of cosine similarity, in our case), while keeping and refreshing a collection of the top-K nearest neighbours.

**Please implement basic NSW search**. 

You can refer to the `K-NNSearch` algorithm, whose pseudocode is given in section 4.2 of the [original paper](https://publications.hse.ru/mirror/pubs/share/folder/x5p6h7thif/direct/128296059).

`search_nsw_basic()`
- `query` - `vector` (`np.ndarray`) representing a query.
- `nsw` - SW graph.
- `top` - re-ranking set size.
- `guard_hops` - if the method does not converge, we terminate after `guard_hops` steps.
- `returns` - a pair of a `set` of indices and number of hops `(nearest_neighbours_set, hops)`

In [7]:
import sortedcontainers
from scipy.spatial import distance

def search_nsw_basic(query, nsw, top=5, guard_hops=100):
    ''' basic search algorithm, takes vector query and returns a pair (nearest_neighbours, hops)'''
    #TODO implement basic NSW search
    
    hops = 0    
    return [], hops
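
As a starting point, the greedy idea can be sketched on a toy graph. This is a simplified single-start variant and not the paper's `K-NNSearch` (which restarts from several entry points). It works on a plain dictionary of neighbour sets rather than the `Node` objects above, and the name `greedy_search` and its `entry` parameter are assumptions made for this illustration.

```python
import heapq
import numpy as np

def greedy_search(query, vectors, adj, entry, top=3, guard_hops=100):
    """Expand the most similar unexpanded node until it cannot improve
    the current top list. Returns (indices best-first, number of hops)."""
    def sim(i):
        return float(np.dot(query, vectors[i]))

    visited = {entry}
    candidates = [(-sim(entry), entry)]   # max-heap via negated similarity
    best = [(sim(entry), entry)]          # min-heap holding at most `top` results
    hops = 0
    while candidates and hops < guard_hops:
        neg_s, node = heapq.heappop(candidates)
        # stop when the best remaining candidate is worse than the worst kept result
        if len(best) == top and -neg_s < best[0][0]:
            break
        hops += 1
        for nb in adj[node]:
            if nb in visited:
                continue
            visited.add(nb)
            s = sim(nb)
            heapq.heappush(candidates, (-s, nb))
            if len(best) < top:
                heapq.heappush(best, (s, nb))
            elif s > best[0][0]:
                heapq.heapreplace(best, (s, nb))
    return [i for _, i in sorted(best, reverse=True)], hops

# toy data: 12 unit vectors on a circle, each linked to its two ring neighbours
angles = np.linspace(0.0, 2.0 * np.pi, 12, endpoint=False)
toy_vectors = np.column_stack([np.cos(angles), np.sin(angles)])
toy_adj = {i: {(i - 1) % 12, (i + 1) % 12} for i in range(12)}
found, hops = greedy_search(toy_vectors[7], toy_vectors, toy_adj, entry=0, top=3)
print(found, hops)
```

On this ring the walk has to pass through every intermediate node, which is exactly the cost that long-range links in a small-world graph are meant to cut.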

### 2.2 Test the search procedure

In [9]:
test_queries = ["good mood", "birds", "virus and bacteria"]
test_vectors = np.array([model.infer_vector(word_tokenize(q)) for q in test_queries])
test_queries_normed = norm_vectors(test_vectors)

First, let's display the true nearest neighbours and measure average search time. 

In [10]:
search_time = 0
for i, query in enumerate(test_queries):
    start = time.time()
    r = find_k_closest(test_queries_normed[i], sent_vecs_normed)
    finish = time.time()
    search_time += finish - start  

    print("\nResults for query:", query)
    for k, v, p in r:
        print("\t", facts[k], "sim=", p)

print("\nExact search took {:.4f} ms on average".format(1000 * (search_time/len(test_queries))))


Results for query: good mood
	 68. Cherophobia is the fear of fun. sim= 0.59210587
	 76. You breathe on average about 8,409,600 times a year sim= 0.5562223
	 144. Dolphins sleep with one eye open! sim= 0.5458862
	 97. 111,111,111 X 111,111,111 = 12,345,678,987,654,321 sim= 0.5400203
	 18. You cannot snore and dream at the same time. sim= 0.5364364

Results for query: birds
	 47. Avocados are poisonous to birds. sim= 0.7138059
	 111. Butterflies taste their food with their feet. sim= 0.66970384
	 121. Birds don’t urinate. sim= 0.6401714
	 109. Cows kill more people than sharks do. sim= 0.6386259
	 144. Dolphins sleep with one eye open! sim= 0.6125039

Results for query: virus and bacteria
	 47. Avocados are poisonous to birds. sim= 0.6077043
	 39. A 2010 study found that 48% of soda fountain contained fecal bacteria, and 11% contained E. Coli. sim= 0.6056562
	 54. Coconut water can be used as blood plasma. sim= 0.60471106
	 109. Cows kill more people than sharks do. sim= 0.5947235
	 83

Now, let's see `search_nsw_basic` in action. It should work way faster than pairwise comparisons above.

In [11]:
search_time = 0
for i, query in enumerate(test_queries):
    start = time.time()
    ans, hops = search_nsw_basic(test_queries_normed[i], G)
    finish = time.time()
    search_time += finish - start

    print("\nResults for query:", query)
    for k in ans:
        print("\t", facts[k], "sim=", np.dot(test_queries_normed[i], sent_vecs_normed[k]))
        
print("\nBasic nsw search took {:.4f} ms on average".format(1000 * (search_time/len(test_queries))))    


Results for query: good mood
	 76. You breathe on average about 8,409,600 times a year sim= 0.5562223
	 41. Blueberries will not ripen until they are picked. sim= 0.48288202
	 29. Chewing gum while you cut an onion will help keep you from crying. sim= 0.4531188
	 71. Human thigh bones are stronger than concrete. sim= 0.43797714
	 79. A waterfall in Hawaii goes up sometimes instead of down. sim= 0.41813037

Results for query: birds
	 47. Avocados are poisonous to birds. sim= 0.7138059
	 44. Honey never spoils. sim= 0.5701026
	 49. The number of animals killed for meat every hour in the U.S. is 500,000. sim= 0.5694398
	 42. About 150 people per year are killed by coconuts. sim= 0.56624955
	 46. A hardboiled egg will spin, but a soft-boiled egg will not. sim= 0.52303654

Results for query: virus and bacteria
	 39. A 2010 study found that 48% of soda fountain contained fecal bacteria, and 11% contained E. Coli. sim= 0.6056562
	 137. Human birth control pills work on gorillas. sim= 0.53696

The results you see should be worse than the exact nearest neighbours, but not completely random. Pay attention to the similarity values.

## 3. Navigable small-world graph

When we built the small-world graph with the Watts–Strogatz model, there was no notion of proximity between the nodes: it was completely ignored. In navigable small-world graphs, however, the idea is to insert nodes in such a way that the cliques form real neighbourhoods, meaning that connected points are close to each other. Please refer to section 5 of the [paper](https://publications.hse.ru/mirror/pubs/share/folder/x5p6h7thif/direct/128296059) for the details.

In [12]:
def build_navigable_graph(values, K):
    '''Accepts container with values. Returns list with graph nodes.
    K parameter stands for the size of the set of closest neighbors to connect to when adding a node'''
    #TODO implement navigable small-world graph construction
                
    return []
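
One way to picture the construction is incremental insertion: each new point is linked to the `K` most similar points that are already in the graph. The sketch below makes a simplifying assumption: it finds those neighbours by brute force, whereas the paper uses the approximate search over the graph built so far. The name `build_nsw_adjacency` is hypothetical, and the result is a list of neighbour sets rather than `Node` objects.

```python
import numpy as np

def build_nsw_adjacency(vectors, K):
    """Insert points one by one; link each new point to the K most similar
    points inserted before it. Returns a list of neighbour sets."""
    adj = [set() for _ in range(len(vectors))]
    for i in range(1, len(vectors)):
        sims = vectors[:i] @ vectors[i]      # similarity to earlier points only
        nearest = np.argsort(-sims)[:K]      # at most K indices, best first
        for j in nearest:
            adj[i].add(int(j))
            adj[int(j)].add(i)               # links are bi-directional
    return adj

# toy data: 30 random unit vectors in 2d
rng = np.random.default_rng(0)
toy_points = rng.normal(size=(30, 2))
toy_points /= np.linalg.norm(toy_points, axis=1, keepdims=True)
toy_nsw = build_nsw_adjacency(toy_points, K=3)
print([len(neighbours) for neighbours in toy_nsw][:5])
```

Points inserted early tend to collect more links, because later points keep connecting back to them. Those early links, made when the graph was still sparse, play the role of long-range connections.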

### 3.1 Building and testing the graph

In [13]:
navigable_G = build_navigable_graph(sent_vecs_normed, K=10)

In [14]:
search_time = 0
for i, query in enumerate(test_queries):
    start = time.time()
    ans, hops = search_nsw_basic(test_queries_normed[i], navigable_G) 
    finish = time.time()
    search_time += finish - start
    
    print("\nResults for query:", query)
    for k in ans:
        print("\t", facts[k], "sim=", np.dot(test_queries_normed[i], sent_vecs_normed[k]))
        
print("\nSearch in navigable graph took {:.4f} ms on average".format(1000 * (search_time/len(test_queries))))  


Results for query: good mood
	 68. Cherophobia is the fear of fun. sim= 0.59210587
	 45. About half of all Americans are on a diet on any given day. sim= 0.526565
	 57. Gorillas burp when they are happy sim= 0.5145933
	 70. Pirates wore earrings because they believed it improved their eyesight. sim= 0.49504918
	 6. There are more lifeforms living on your skin than there are people on the planet. sim= 0.48445895

Results for query: birds
	 47. Avocados are poisonous to birds. sim= 0.7138059
	 111. Butterflies taste their food with their feet. sim= 0.66970384
	 12. A human will eat on average 70 assorted insects and 10 spiders while sleeping. sim= 0.593714
	 112. A tarantula can live without food for more than two years. sim= 0.5825779
	 110. Cats have 32 muscles in each of their ears. sim= 0.5541166

Results for query: virus and bacteria
	 6. There are more lifeforms living on your skin than there are people on the planet. sim= 0.5208382
	 12. A human will eat on average 70 assorted in

## 4. Comparing search quality in resulting graphs

For both graphs, for each data sample, retrieve K nearest neighbours and compare them to the true nearest neighbours of the sample. If the retrieved result is present in the true top-k list of the sample, then it is counted as a hit. For both graphs, report the total number of hits, and the average number of hits per sample.

For example: `Number of hits 394 out of 795, avg per query 2.48`
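
The hit count itself is a set intersection per query. A minimal sketch on made-up index lists follows. `count_hits` is a hypothetical helper, and the toy lists stand in for the results of the exact and approximate searches.

```python
def count_hits(retrieved, exact):
    """retrieved, exact: one iterable of indices per query.
    Returns (total hits, average hits per query)."""
    total = sum(len(set(r) & set(e)) for r, e in zip(retrieved, exact))
    return total, total / len(exact)

# made-up results for two queries, three neighbours each
toy_exact = [[1, 2, 3], [4, 5, 6]]
toy_retrieved = [[1, 3, 9], [6, 7, 8]]

hits, avg = count_hits(toy_retrieved, toy_exact)
possible = sum(len(e) for e in toy_exact)
report = "Number of hits {} out of {}, avg per query {:.2f}".format(hits, possible, avg)
print(report)  # Number of hits 3 out of 6, avg per query 1.50
```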

In [None]:
#TODO measure and report the described metrics

### Bonus task

Generate a small set of 2d points and build 2 types of graphs for this set: small-world graph based on Watts–Strogatz algorithm, and Navigable small-world graph. Visualize both graphs and analyze their structures - do they differ? Does Navigable small-world graph capture geometric proximity better?

In [None]:
#TODO bonus task