In [1]:
%load_ext autoreload
%autoreload 2


In [2]:
from dkouqe.document_retrieval.multiple_search import (MultipleRetrievalManager, MergeStrategy,
                             MultipleOntologiesManager, SortByBestScore)
                             

from dkouqe.document_retrieval.semantic_search import FaissIndex, VectorSearch, SearchContext,SearchType
import os
from sentence_transformers import SentenceTransformer

import json
from bef.search_utils import (vector_search,
                              search_and_print, search_apply_kernel,)
from bef.vector_utils import transform_and_normalize


import faiss
import pickle
import numpy as np
from typing import Dict

In [3]:
# NOTE(review): absolute, user-specific path; not used by any visible cell
# (the loading loop below uses the relative BASE_DIR instead) — consider removing.
DIR = '/home/julio/repos/event_finder/data/pubmed_full/results/cg/'



In [4]:
# Sentence-embedding model used to encode queries.
# Alternative (larger) choice that was tried: 'all-mpnet-base-v2'.
BERT_MODEL = 'all-MiniLM-L12-v2'
sentence_emb_model = SentenceTransformer(model_name_or_path=BERT_MODEL)


In [5]:
class VectorSearchDimRedu(SearchType):
    """Search strategy: embed a query, apply a kernel/bias transform, query FAISS.

    The ``kernel``/``bias`` pair is applied via ``transform_and_normalize``
    (presumably the dimensionality reduction the class name refers to —
    TODO confirm against ``bef.vector_utils``).
    """

    def retrieve_records(self, query, model, index, kernel, bias, num_results=10):
        """Encode ``query`` with ``model``, transform it and search ``index``.

        Args:
            query (str): Free-text user query.
            model (sentence_transformers.SentenceTransformer): Model used to
                encode the query (whichever model the caller loaded).
            index (faiss.Index): An already-loaded FAISS index, e.g. the
                result of ``faiss.read_index``.
            kernel: Transform matrix forwarded to ``transform_and_normalize``.
            bias: Transform offset forwarded to ``transform_and_normalize``.
            num_results (int): Number of nearest neighbours to request.

        Returns:
            tuple: ``(distances, ids)``, both flattened to 1-D numpy arrays.
            ``ids`` are positions in the index; FAISS uses -1 for slots it
            could not fill (fewer than ``num_results`` vectors available).
        """
        vector = model.encode([query])
        vector = transform_and_normalize(vector, kernel, bias)
        # FAISS expects a C-contiguous float32 matrix of shape (n_queries, d).
        query_matrix = np.ascontiguousarray(vector, dtype=np.float32)
        distances, ids = index.search(query_matrix, k=num_results)

        return distances.flatten(), ids.flatten()


In [6]:
class FaissIndexGraphEventsDimRedu(SearchContext):
    """Holds a FAISS index together with the event graphs it was built from.

    FAISS ids returned by this index are presumably positions in
    ``self.graphs`` (TODO confirm against the script that built the index).
    """
    #TODO use dataclass
    def __init__(self,
        search_type,
        graphs_file,
        index_file,
        embedding_model,
        kernel,
        bias
        ) -> None:
        """Load the graphs JSON file and the FAISS index from disk.

        Args:
            search_type: Strategy whose ``retrieve_records`` performs the
                actual search (e.g. ``VectorSearchDimRedu``).
            graphs_file (str): Path to a JSON file whose top-level value is
                iterated to build ``self.graphs``.
            index_file (str): Path to a serialized FAISS index.
            embedding_model: Sentence-embedding model used to encode queries.
            kernel: Query transform matrix forwarded to ``search_type``.
            bias: Query transform offset forwarded to ``search_type``.
        """
        self._search_type = search_type
        self._embedding_model = embedding_model
        self.kernel = kernel
        self.bias = bias
        self._is_active = True
        print('Loading graph ...')
        # JSON is UTF-8 by specification; don't depend on the locale encoding.
        with open(graphs_file, encoding='utf-8') as ff:
            self.graphs = list(json.load(ff))

        self.index = faiss.read_index(index_file)

    def retrieve_records(self, query: str, num_docs: int):
        """Search the index for ``query``.

        Args:
            query (str): Free-text query.
            num_docs (int): Number of nearest neighbours to request.

        Returns:
            dict: ``{'faiss_scores': distances, 'records_ids': ids}`` as
            returned by the search strategy.
        """
        distances, ids = self._search_type.retrieve_records(query,
                                                            self._embedding_model,
                                                            self.index,
                                                            self.kernel,
                                                            self.bias,
                                                            num_docs)

        return {'faiss_scores': distances, 'records_ids': ids}

    def search_and_print_ontology(self, query: str, num_docs: int,
                                  sentence_emb_model=None, kernel=None,
                                  bias=None) -> None:
        """Search the index and print the matching graphs, for debugging.

        Args:
            query (str): Free-text query.
            num_docs (int): Currently NOT forwarded to ``search_apply_kernel``;
                kept for interface compatibility (TODO wire it through once
                that helper's signature is confirmed).
            sentence_emb_model: Ignored; the instance's embedding model is used.
            kernel: Query transform matrix; defaults to ``self.kernel``.
            bias: Query transform offset; defaults to ``self.bias``.
        """
        kernel = self.kernel if kernel is None else kernel
        bias = self.bias if bias is None else bias
        search_apply_kernel(query, self._embedding_model,
                            self.index, self.graphs, kernel, bias)




In [7]:
# One seeker (graphs + FAISS index + query transform) per publication year.
years = [str(i) for i in range(2000, 2003)]
strategy = VectorSearchDimRedu()
emb = sentence_emb_model
seekers = {}
BASE_DIR = '../data/pubmed_full/results/cg/'
for year in years:
    # All artifacts for a year share the same path prefix.
    prefix = os.path.join(BASE_DIR, 'graph_pubmed_' + year)
    graphs_file = prefix + '.json'
    index_file = prefix + '.index'
    kernel_file = prefix + '.kernel'
    bias_file = prefix + '.bias'

    # Fail early with a clear message instead of part-way through loading.
    missing = [path for path in (graphs_file, index_file, kernel_file, bias_file)
               if not os.path.isfile(path)]
    if missing:
        raise FileNotFoundError(f'Missing artifacts for year {year}: {missing}')

    # NOTE(security): pickle.load can execute arbitrary code — only load
    # kernel/bias files produced by our own pipeline.
    with open(kernel_file, 'rb') as handle:
        kernel = pickle.load(handle)

    with open(bias_file, 'rb') as handle:
        bias = pickle.load(handle)

    seek_args = {'search_type': strategy,
                 'index_file': index_file,
                 'embedding_model': emb,
                 'kernel': kernel,
                 'bias': bias,
                 'graphs_file': graphs_file}

    seekers[year] = FaissIndexGraphEventsDimRedu(**seek_args)


Loading graph ...
Loading graph ...
Loading graph ...


In [8]:
# import traceback
# try:
#     seekers['2000'].retrieve_records('the acetycholine produces catabolism', 5)
# except:
#     traceback.print_exc()


In [9]:
# Sanity check: query a single year's seeker directly.
seekers['2000'].retrieve_records(query='the acetycholine produces catabolism', num_docs=5)


{'faiss_scores': array([0.74260783, 0.74260783, 0.7528304 , 0.77426314, 0.7861366 ],
       dtype=float32),
 'records_ids': array([ 410983,  410984,  994276, 1389012, 1246086])}

In [22]:
class MultipleGraphIndexManager(MultipleOntologiesManager):
    """Queries several graph/FAISS seekers and merges their hits into one result set."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self._results = self.empty_results()

    def retrieve_records(self, query: str, num_docs: int) -> Dict:
        """Query every active seeker and merge the combined hits.

        Args:
            query: Free-text query forwarded to each seeker.
            num_docs: Number of hits requested from *each* seeker.

        Returns:
            Dict: Output of ``self.merge_strategy.merge_records`` applied to
            the accumulated ``scores``, ``records_ids``, ``seeker_ref`` and
            ``nodes`` entries.
        """
        #TODO numdocs should be specific for each seeker?
        self._results = self.empty_results()
        for seeker in self._seekers.values():
            if not seeker._is_active:
                continue
            new_results = seeker.retrieve_records(query, num_docs)
            if 'faiss_scores' not in new_results:
                continue

            scores = np.asarray(new_results['faiss_scores'])
            ids = np.asarray(new_results['records_ids'])
            # FAISS pads unfilled neighbour slots with id -1; graphs[-1] would
            # silently pick the *last* graph, so drop those slots everywhere.
            valid = ids >= 0
            scores, ids = scores[valid], ids[valid]

            self._results['scores'] = np.concatenate(
                [self._results['scores'], scores])
            self._results['records_ids'] = np.concatenate(
                [self._results['records_ids'], ids])
            self._results['seeker_ref'] += [seeker] * len(scores)

            graphs = seeker.graphs
            self._results['nodes'] += [graphs[id_]['nodes'] for id_ in ids]

        self._results = self.merge_strategy.merge_records(self._results, query)
        return self._results


In [23]:
# Build the multi-year manager with a SortByBestScore merge strategy.
# NOTE(review): `merge_strategy_descending` is constructed with ascending=True,
# so its name looks inverted — confirm the intended order before using it.
merge_strategy1 = SortByBestScore()
merge_strategy_descending = SortByBestScore(ascending=True)
mseeker = MultipleGraphIndexManager(seekers, merge_strategy1)
# Display the wrapped seekers (one per year).
mseeker._seekers


{'2000': <__main__.FaissIndexGraphEventsDimRedu at 0x7f4e94443730>,
 '2001': <__main__.FaissIndexGraphEventsDimRedu at 0x7f4e94443b50>,
 '2002': <__main__.FaissIndexGraphEventsDimRedu at 0x7f4d2ed4afa0>}

In [24]:
# mseeker._seekers['2000']._is_active = True
# mseeker._seekers['2001']._is_active = True
# mseeker._seekers['2002']._is_active = True


In [28]:
# Merged top hits across all years for a short query.
res = mseeker.retrieve_records(query='melanoma', num_docs=5)


In [29]:
# Rich display (last expression) instead of print().
res

{'scores': array([0.25744694, 0.26700145, 0.27029008, 0.27031076, 0.27691984,
       0.27691984, 0.27691984, 0.27713156, 0.27765232, 0.27765232,
       0.28131837, 0.28131837, 0.30437818, 0.30537075, 0.30699319]), 'records_ids': array([ 659612.,  876146., 1205325.,  680141.,  587142., 1500301.,
        847348., 1455204.,  148631.,  997861.,   74126.,  837122.,
        643174.,   98909.,  723237.]), 'seeker_ref': array([<__main__.FaissIndexGraphEventsDimRedu object at 0x7f4e94443b50>,
       <__main__.FaissIndexGraphEventsDimRedu object at 0x7f4e94443730>,
       <__main__.FaissIndexGraphEventsDimRedu object at 0x7f4e94443b50>,
       <__main__.FaissIndexGraphEventsDimRedu object at 0x7f4e94443b50>,
       <__main__.FaissIndexGraphEventsDimRedu object at 0x7f4e94443730>,
       <__main__.FaissIndexGraphEventsDimRedu object at 0x7f4e94443730>,
       <__main__.FaissIndexGraphEventsDimRedu object at 0x7f4e94443730>,
       <__main__.FaissIndexGraphEventsDimRedu object at 0x7f4e94443730>,


In [30]:
# Merged top hits across all years for a multi-term query.
res = mseeker.retrieve_records(query='gut microbiome HIV infection', num_docs=5)


  results[key] = np.array(val)[idx]


In [31]:
# Show the merged results of the query above.
res

{'scores': array([0.53720975, 0.72387248, 0.73587126, 0.73791671, 0.77120394,
        0.77847654, 0.77988464, 0.77988464, 0.77988464, 0.77988464,
        0.77988464, 0.7800796 , 0.78015149, 0.78419787, 0.78458649]),
 'records_ids': array([1018148.,  922531.,  861727.,  410088.,  488989.,  232619.,
         536810.,  638952.,  642068.,  957312.,  425604.,  618016.,
        1222098.,  139086.,  475011.]),
 'seeker_ref': array([<__main__.FaissIndexGraphEventsDimRedu object at 0x7f4e94443730>,
        <__main__.FaissIndexGraphEventsDimRedu object at 0x7f4e94443b50>,
        <__main__.FaissIndexGraphEventsDimRedu object at 0x7f4e94443730>,
        <__main__.FaissIndexGraphEventsDimRedu object at 0x7f4e94443730>,
        <__main__.FaissIndexGraphEventsDimRedu object at 0x7f4e94443730>,
        <__main__.FaissIndexGraphEventsDimRedu object at 0x7f4e94443b50>,
        <__main__.FaissIndexGraphEventsDimRedu object at 0x7f4d2ed4afa0>,
        <__main__.FaissIndexGraphEventsDimRedu object at 0x7f

In [27]:
# Print the results currently held by the manager.
# NOTE(review): this cell ran as In [27], before In [28]-[31], so its saved
# output belongs to an earlier query — re-run the notebook top to bottom.
mseeker.print_results()


id  51423.0
score  0.04396906495094299
array([{'type': 'Cell', 'name': 'Monocyte - derived macrophages\n', 'id': 'T9'},
       {'type': 'Planned_process', 'name': 'obtained\n', 'id': 'T41'}],
      dtype=object)
id  187973.0
score  0.07727345824241638
array([{'type': 'Planned_process', 'name': 'isolated\n', 'id': 'T74'},
       {'type': 'Cell', 'name': 'monocyte - derived macrophages\n', 'id': 'T24'}],
      dtype=object)
id  407890.0
score  0.09024284780025482
array([{'type': 'Cell', 'name': 'monocyte - derived macrophages\n', 'id': 'T19'},
       {'type': 'Cell_proliferation', 'name': 'grown\n', 'id': 'T25'}],
      dtype=object)
id  187971.0
score  0.09292612224817276
array([{'type': 'Cell', 'name': 'monocyte - derived macrophages\n', 'id': 'T13'},
       {'type': 'Planned_process', 'name': 'isolated\n', 'id': 'T63'}],
      dtype=object)
id  705331.0
score  0.112118199467659
array([{'type': 'Cell', 'name': 'monocyte - derived macrophages\n', 'id': 'T3'},
       {'type': 'Positive_r