## Purpose

Guidefinder enables users to design guide RNA targets for entire genomes using any PAM and any genome. The most computationally costly step of Guidefinder compares the Hamming distance of every potential guide RNA target in the genome to all other targets. For a typical bacterial genome and Cas9 (protospacer adjacent motif (PAM) site NGG), this could be (10^6 * (10^6 - 1))/2 ≈ 5 × 10^11 comparisons. To avoid that number of comparisons, we perform approximate nearest neighbor search using Hierarchical Navigable Small World (HNSW) graphs in the NMSLIB package. This is much faster, but it requires constructing an index and selecting index and search parameters that balance index build speed, search speed, and recall.
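
The comparison count above is the number of unordered pairs among n targets, n(n - 1)/2. A quick check of the arithmetic, assuming roughly 10^6 candidate targets as in the text (`pairwise_comparisons` is an illustrative helper, not part of Guidefinder):

```python
def pairwise_comparisons(n: int) -> int:
    """Number of unordered pairs among n items: n * (n - 1) / 2."""
    return n * (n - 1) // 2


# Assumed order of magnitude for a bacterial genome with an NGG PAM.
n_targets = 10**6
total = pairwise_comparisons(n_targets)
print(f"{total:.2e}")  # 5.00e+11
```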

## Bayesian optimization

We need to optimize the following index parameters:
* M (int) 10-100
* efC (int) 10-1000
* post (int) 0-2

And the search parameter:
* ef (int) 10-2000
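
One way to write down this search space in plain Python is sketched below. The `SEARCH_SPACE` dictionary and the `sample_params` helper are hypothetical names used only for illustration; they are not part of Guidefinder or GPflowOpt, and uniform random sampling merely stands in for the points a Bayesian optimizer would propose. The search parameter is called `ef` here to match the `ef` argument used in the code later in this notebook.

```python
import random

# Inclusive integer bounds copied from the lists above.
SEARCH_SPACE = {
    "M": (10, 100),     # index parameter
    "efC": (10, 1000),  # index parameter
    "post": (0, 2),     # index parameter
    "ef": (10, 2000),   # search parameter
}


def sample_params(rng: random.Random) -> dict:
    """Draw one integer value per parameter, uniformly within its bounds."""
    return {name: rng.randint(lo, hi) for name, (lo, hi) in SEARCH_SPACE.items()}


rng = random.Random(0)  # fixed seed so the draw is repeatable
candidate = sample_params(rng)
print(candidate)
```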

We want to minimize search time and maximize 1-NN recall.

To do this, we will use the Python package [GPflowOpt](https://gpflowopt.readthedocs.io/en/latest/notebooks/multiobjective.html).
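
With two competing objectives there is usually no single best configuration, only a set of trade-offs (the Pareto front). The sketch below illustrates that idea in plain Python on made-up (search time, recall) pairs. It is not GPflowOpt code, `pareto_front` is a hypothetical helper, and the numbers are illustrative values rather than measurements.

```python
def pareto_front(points):
    """Return the points that no other point dominates.

    Each point is (search_time, recall). Lower time and higher recall are
    better. A point is dominated if another point is at least as good on
    both objectives and strictly better on at least one.
    """
    front = []
    for i, (t_i, r_i) in enumerate(points):
        dominated = any(
            (t_j <= t_i and r_j >= r_i) and (t_j < t_i or r_j > r_i)
            for j, (t_j, r_j) in enumerate(points)
            if j != i
        )
        if not dominated:
            front.append((t_i, r_i))
    return front


# Illustrative values only, not measurements.
trials = [(1.0, 0.90), (2.0, 0.99), (1.5, 0.85), (3.0, 0.99), (0.8, 0.70)]
front = pareto_front(trials)
print(front)
```

In this toy data, (1.5, 0.85) is dropped because (1.0, 0.90) is both faster and more accurate, and (3.0, 0.99) is dropped because (2.0, 0.99) reaches the same recall in less time.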

In [5]:
import sys 
import time 
import math

from Bio import SeqIO
import nmslib 
from scipy.sparse import csr_matrix 
from sklearn.model_selection import train_test_split 

import guidefinder


## Calculate the ground truth data

Initially, a ground truth dataset is calculated using the brute-force method.
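
For intuition, brute-force search compares each query against every target and keeps the closest ones. The toy sketch below does this for short DNA strings. `hamming` and `brute_force_knn` are illustrative helpers, not Guidefinder functions; the notebook itself uses NMSLIB's `seq_search` method on encoded targets instead.

```python
def hamming(a: str, b: str) -> int:
    """Count the positions at which two equal-length strings differ."""
    if len(a) != len(b):
        raise ValueError("sequences must have the same length")
    return sum(x != y for x, y in zip(a, b))


def brute_force_knn(query, targets, k):
    """Return the k (index, distance) pairs with the smallest Hamming distance."""
    dists = [(i, hamming(query, t)) for i, t in enumerate(targets)]
    # sorted() is stable, so ties keep their original order.
    return sorted(dists, key=lambda pair: pair[1])[:k]


# Toy sequences for illustration only.
targets = ["ACGTACGT", "ACGTACGA", "TTTTACGT", "ACGAACGA"]
neighbors = brute_force_knn(targets[0], targets, k=3)
print(neighbors)  # [(0, 0), (1, 1), (3, 2)]
```

Because the query is itself one of the targets, its closest hit is always itself at distance 0.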

In [19]:
pamobj = guidefinder.core.Pam("NGG", "5prime")
gb = SeqIO.read("test_data/Pseudomonas_aeruginosa_PAO1_107.sample.fasta", "fasta")
pamtargets = pamobj.find_targets(seqrecord_obj=gb, strand="both", target_len=20)
tl = guidefinder.core.TargetList(targets=pamtargets, lcp=10, hammingdist=2, knum=2)
tl.find_unique_near_pam()
bintargets = tl._one_hot_encode(tl.targets)

# 'seq_search' performs an exhaustive (brute-force) scan, so it returns exact neighbors.
index = nmslib.init(space='bit_hamming',
                    dtype=nmslib.DistType.INT,
                    data_type=nmslib.DataType.OBJECT_AS_STRING,
                    method='seq_search')
index.addDataPointBatch(bintargets)
index.createIndex(print_progress=True)



In [35]:
start = time.time()
truth_list = index.knnQueryBatch(bintargets, k=3, num_threads=4)
end = time.time()

print('brute-force kNN time total=%f (sec), per query=%f (sec)' %
      (end - start, (end - start) / len(bintargets)))

brute-force kNN time total=0.598412 (sec), per query=0.000054 (sec)


In [30]:
def recall(results, truth):
    """Calculate recall on the closest kNN distance.

    Each element of results and truth is an (ids, distances) pair. Recall is
    calculated on the distance rather than the label, because we care that
    the algorithm finds the correct distance, not the exact identity of the
    neighbor: several neighbors can share the same distance. Note that when
    the queries are the indexed points themselves, as here, the closest
    ground-truth distance is always 0.
    """
    assert len(results) == len(truth)
    tot = len(results)
    correct = 0
    for res, tr in zip(results, truth):
        # Index 1 holds the distances; compare only the first (closest) one.
        if res[1][0] == tr[1][0]:
            correct += 1
    return correct / tot
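
A small worked example may help show what this measure counts. The sketch below reimplements the same comparison under a hypothetical name, `distance_recall`, and applies it to made-up data. Plain lists stand in for the NumPy arrays that `knnQueryBatch` returns.

```python
def distance_recall(results, truth):
    """Fraction of queries whose closest returned distance matches the truth."""
    assert len(results) == len(truth)
    correct = sum(
        1 for res, tr in zip(results, truth) if res[1][0] == tr[1][0]
    )
    return correct / len(results)


# Each entry mimics one (ids, distances) pair; values are made up.
truth = [
    ([0, 5], [0, 2]),
    ([1, 7], [0, 4]),
    ([9, 2], [2, 4]),
    ([3, 8], [0, 6]),
]
approx = [
    ([0, 5], [0, 2]),  # identical to the ground truth
    ([1, 6], [0, 4]),  # second neighbor differs, closest distance matches
    ([4, 2], [2, 4]),  # different closest label, same distance: counted correct
    ([8, 3], [6, 8]),  # missed the true closest neighbor
]
score = distance_recall(approx, truth)
print(score)  # 0.75
```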
    
    


In [43]:
def test_func(truth, bintargets, M, efC, post, ef, delaunay_type=2, threads=4):
    """Build an HNSW index, query it, and return (recall, elapsed seconds).

    The elapsed time covers both index construction and the batch query.
    """
    start = time.time()
    index_params = {'M': M, 'indexThreadQty': threads, 'efConstruction': efC,
                    'post': post, 'delaunay_type': delaunay_type}
    index = nmslib.init(space='bit_hamming',
                        dtype=nmslib.DistType.INT,
                        data_type=nmslib.DataType.OBJECT_AS_STRING,
                        method='hnsw')
    index.addDataPointBatch(bintargets)
    index.createIndex(index_params)
    index.setQueryTimeParams({'efSearch': ef})
    # Use the same thread count for querying as for index construction.
    results_list = index.knnQueryBatch(bintargets, k=3, num_threads=threads)
    end = time.time()
    rc = recall(results_list, truth)
    return rc, end - start

In [44]:
test_func(truth=truth_list, bintargets=bintargets, M=10, efC=50, post=1, ef=200, threads=4)

(1.0, 1.2706491947174072)
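
Note that `test_func` times index construction and querying together, whereas the brute-force timing above covers only the query, so the two numbers are not directly comparable. To report build time and search time separately, a small helper along these lines could wrap each step. `timed` is a hypothetical name, and the NMSLIB calls appear only as comments showing where it might be applied.

```python
import time


def timed(fn, *args, **kwargs):
    """Call fn and return (result, elapsed seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start


# Possible usage with the objects created inside test_func:
#   _, build_time = timed(index.createIndex, index_params)
#   results, query_time = timed(index.knnQueryBatch, bintargets,
#                               k=3, num_threads=threads)

# Self-contained demonstration on a built-in function.
total, elapsed = timed(sum, range(1000))
print(total, elapsed >= 0)
```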