In [1]:
%load_ext autoreload
%autoreload 2

# Search COVID papers with Deep Learning
*Transformers + Elastic Search = ❤️*

Good news everyone! In this article, we are not going to fit a linear regression on the sars-cov-19 data. Rather, we will do something more interesting. Most of the work is based on [this project](https://github.com/gsarti/covid-papers-browser) in which I am working with students from the Universita of Triest (Italy). A live demo is available TODO

## Data

We are going to use this dataset from Kaggle, we only need the `metadata.csv` file that contains information about the paper and the full text of the abstract. Let's take a look!

In [27]:
import pandas as pd
from Project import Project
# Project holds all the path
pr = Project()

df = pd.read_csv(pr.data_dir / 'metadata.csv')

df.head(1)

Unnamed: 0,cord_uid,sha,source_x,title,doi,pmcid,pubmed_id,license,abstract,publish_time,authors,journal,Microsoft Academic Paper ID,WHO #Covidence,has_pdf_parse,has_pmc_xml_parse,full_text_file,url
0,zjufx4fo,b2897e1277f56641193a6db73825f707eed3e4c9,PMC,Sequence requirements for RNA strand transfer ...,10.1093/emboj/20.24.7220,PMC125340,11742998.0,unk,Nidovirus subgenomic mRNAs contain a leader se...,2001-12-17,"Pasternak, Alexander O.; van den Born, Erwin; ...",The EMBO Journal,,,True,True,custom_license,http://europepmc.org/articles/pmc125340?pdf=re...


In [28]:
df['pubmed_id'] = df['pubmed_id']

In [29]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 57366 entries, 0 to 57365
Data columns (total 18 columns):
cord_uid                       57366 non-null object
sha                            43540 non-null object
source_x                       57366 non-null object
title                          57203 non-null object
doi                            54020 non-null object
pmcid                          46804 non-null object
pubmed_id                      40905 non-null float64
license                        57366 non-null object
abstract                       46847 non-null object
publish_time                   57358 non-null object
authors                        54840 non-null object
journal                        51576 non-null object
Microsoft Academic Paper ID    964 non-null float64
WHO #Covidence                 1768 non-null object
has_pdf_parse                  57366 non-null bool
has_pmc_xml_parse              57366 non-null bool
full_text_file                 48921 non-null ob

Let's create a `Dataset` to properly work with the data in the pytorch enviroment. To give our search engine more context, we will embed the `title` and the `abstract` together, this is sotred in the `title_abstract` key.

In [37]:
from torch.utils.data import Dataset, DataLoader

class CovidPapersDataset(Dataset):    
    def __init__(self, df, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.df = df
        self.df = self.df[['title', 'authors', 'abstract', 'url', 'pubmed_id', ]]
        self.df.title.fillna('', inplace = True)
        self.df.abstract.fillna('', inplace = True)
        self.df = self.df.fillna(0)
    def __getitem__(self, idx):
        row = self.df.loc[idx]
        self.df.loc[idx:, 'title_abstract'] = f"{row['title']} {row['abstract']}"
        return  self.df.loc[idx].to_dict()

    def __len__(self):
        return self.df.shape[0]
    
    @classmethod
    def from_path(cls, path, *args, **kwargs):
        df = pd.read_csv(path)
        return cls(df=df, *args, **kwargs)

In order, I have subclassed `torch.utils.data.Dataset` to create a custom dataset. The dataset is expecting a dataframe as input, we keep only the interesting columns as input and filling up the nans. 

We return a dictionary since the dataframe is not a support type in pytorch. We can test it out

In [38]:
ds = CovidPapersDataset.from_path(pr.data_dir / 'metadata.csv')

ds[0]['title']

'Sequence requirements for RNA strand transfer during nidovirus discontinuous subgenomic RNA synthesis'

## Embed

Now we need a way to create an embedding from the data. We defined a class `Embedder` that loads automatically a model from `hugging_faces`. On top of that we add a pooling layer to create one single `768` vector for each input.

**TODO** Explain model choice

In [39]:
from dataclasses import dataclass
from sentence_transformers import models, SentenceTransformer

@dataclass
class Embedder:
    name: str = 'gsarti/scibert-nli'
    max_seq_length: int  = 128
    do_lower_case: bool  = True
    
    def __post_init__(self):
        word_embedding_model = models.BERT(
            'gsarti/biobert-nli',
            max_seq_length=128,
            do_lower_case=True
        )
        pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
                pooling_mode_mean_tokens=True,
                pooling_mode_cls_token=False,
                pooling_mode_max_tokens=False
            )
    
        self.model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
        
    def __call__(self, text):
        return self.model.encode(text) 

Let's try it by embedding a data point in our dataset

In [8]:
embedder = Embedder()

emb = embedder(ds[0]['title_abstract'])

emb[0].shape

(768,)

Et voilà!

We can now create a Dataloader to process our papers in batch.

## Search

Okay, we now have a way to embed each paper, but how can we search in the data using a query? Assuming we have embedded **all** the papers we could also **embed the query** and compute the cosine similarity between the query and all the embedding and then show the results sorted by the distance. Intuitively, the closer the more similar. 

So, how we can do it? We need a proper way to handle the data and to run the cosine similarity fast enough. Fortunately, Elastic Search comes to the rescue!

### Elastic Search

[Elastic Search](https://www.elastic.co/) is a database with one goal, yes you guessed right: searching. We will first store all the embedding in elastic and then use its API to perform the searching. If you are lazy like me you can [install elastic search with docker](https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html)

```
docker pull docker.elastic.co/elasticsearch/elasticsearch:7.6.2
docker run -p 9200:9200 -p 9300:9300 -e "discovery.type=single-node" docker.elastic.co/elasticsearch/elasticsearch:7.6.2

```

Perfect. The next step is to store the embeddings and the papers' information on elastic search. It is a very straightforward process. We first need to create an `index` (a new database) and then to create one entry for each paper.

When we create an `index` we need to describe for elastic what we wish to store. In our case:


```
{
    "settings": {
        "number_of_shards": 2,
        "number_of_replicas": 1
    },
    "mappings": {
        "dynamic": "true",
        "_source": {
            "enabled": "true"
        },
        "properties": {
            "title": {
                "type": "text"
            },
            ... all other properties (columns of the datafarme)
            "embed": {
                "type": "dense_vector",
                "dims": 768
            }
        }
    }
}

```

The last entry we define a property `embed` as a dense vector with `768`. This is indeed our embed.

In [40]:
from dataclasses import dataclass, field
from pathlib import Path
from elasticsearch import Elasticsearch
from tqdm.autonotebook import tqdm
import pandas as pd
import json
from elasticsearch.helpers import bulk

@dataclass
class ElasticSearchProvider:
    index_file: dict
    client: Elasticsearch = Elasticsearch()
    index_name: str = 'covid'

    def drop(self):
        self.client.indices.delete(index=self.index_name, ignore=[404])
        return self

    def create_index(self):
        self.client.indices.create(index=self.index_name, body=self.index_file)
        return self

    def create_and_bulk_documents(self, entries:list):
        entries_elastic = []
        for entry in entries:
            entry_elastic = {
                **entry,
                **{
                    '_op_type': 'index',
                    '_index': self.index_name
                }
            }
        
            entries_elastic.append(entry_elastic)
            
        bulk(self.client, entries_elastic)

    def __call__(self, entries: list):
        self.create_and_bulk_documents(entries)

        return self

Unfortunatly, Elastic Search won't be able to serialize the `numpy` arrays. So we need to create an Adapter for our data. This class takes a list and "adapt" the data to work in our `ElasticSearchProvider`.

In [41]:
class CovidPapersEmbeddedAdapter:
        
    def __call__(self, x, embs):
        for el, emb in zip(x, embs):
            el['embed'] = np.array(emb).tolist()

        return x

Okay, we have everything in place. A way to represent the date, one to encode them in a vector and a way to store the result. Let's wrap 

In [None]:
dl = DataLoader(ds, batch_size=128, num_workers=4, collate_fn=lambda x: x)
es_adapter = CovidPapersEmbeddedAdapter()

import numpy as np
from utils import device



with open(pr.base_dir / 'es_index.json', 'r') as f:
    index_file = json.load(f)
    es_provider = ElasticSearchProvider(index_file)

es_provider.drop()
es_provider.create_index()

for batch in tqdm(dl):
    x = [b['title_abstract'] for b in batch]
    embs = embedder(x)
    es_provider(es_adapter(batch, embs))
    


HBox(children=(FloatProgress(value=0.0, max=449.0), HTML(value='')))

If we now go take a look at `http://localhost:9200/covid/_search?pretty=true&q=*:*` We will see our data correctly stored on elastic search

### Make a query

We are almost done. The last piece of the puzzle is a way to search in the database. Elastic search can perform cosine similarity between one input vector and a target vector field in all the documents. The syntax is very easy:

```
 {
    "query": {
        "match_all": {}
    },
    "script": {
        "source":
        "cosineSimilarity(params.query_vector, doc['embed']) + 1.0",
        "params": {
            "query_vector": vector
        }
    }
}

```

Where `vector` is our input. So, we define a class that does exactly that, takes an vector as an input an show all the results from the query

In [23]:
@dataclass
class Elasticsearcher:
    """
    This class implements the logic behind searching for a vector in elastic search.
    """
    client: Elasticsearch = Elasticsearch()
    index_name: str = 'covid'

    def __call__(self, vector: list):
        script_query = {
            "script_score": {
                "query": {
                    "match_all": {}
                },
                "script": {
                    "source":
                    "cosineSimilarity(params.query_vector, doc['embed']) + 1.0",
                    "params": {
                        "query_vector": vector
                    }
                }
            }
        }

        res = self.client.search(
            index=self.index_name,
            body={
                "size": 1000,
                "query": script_query,
                "_source": {
                    "includes": ["title", "abstract"]
                }
            })

        return res

In [25]:
es_search = Elasticsearcher()
es_search(embedder('covid syntopms')[0].tolist())

{'took': 270,
 'timed_out': False,
 '_shards': {'total': 2, 'successful': 2, 'skipped': 0, 'failed': 0},
 'hits': {'total': {'value': 10000, 'relation': 'gte'},
  'max_score': 1.8080369,
  'hits': [{'_index': 'covid',
    '_type': '_doc',
    '_id': 'Phn1sXEBuCitufzF0m8L',
    '_score': 1.8080369,
    '_source': {'abstract': 0, 'title': 'TOC'}},
   {'_index': 'covid',
    '_type': '_doc',
    '_id': 'Ehn4sXEBuCitufzFbr3k',
    '_score': 1.8080369,
    '_source': {'abstract': 0, 'title': 'TOC'}},
   {'_index': 'covid',
    '_type': '_doc',
    '_id': 'vxn4sXEBuCitufzF7stI',
    '_score': 1.8080369,
    '_source': {'abstract': 0, 'title': 'TOC'}},
   {'_index': 'covid',
    '_type': '_doc',
    '_id': 'YBn1sXEBuCitufzFBlFH',
    '_score': 1.8077158,
    '_source': {'abstract': 0, 'title': 'References'}},
   {'_index': 'covid',
    '_type': '_doc',
    '_id': 'ixn3sXEBuCitufzFgaEk',
    '_score': 1.8077158,
    '_source': {'abstract': 0, 'title': 'References'}},
   {'_index': 'covid',
   