# Install useful libraries

In [1]:
!pip install sentence_transformers



# Imports

In [2]:
import os
import re
import time
import json
import torch
import pickle
import nltk
import nltk.data
nltk.download('punkt')

from sentence_transformers import SentenceTransformer, util

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [3]:
from preprocessing import *

# Preprocessing

## Start by creating the class that will be used to store an article

If the Dataset has already been parsed and stored in a pickle file, set this variable to True, to avoid parsing it again. If you are running this Notebook for the first time, then of course, the Dataset has not been parsed, so set the variable to False.

In [4]:
has_already_been_parsed = True

In [5]:
class CovidArticle:
  """ class used to keep the information of an article """

  def __init__(self, article_data):
    tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
    self._id = article_data["paper_id"]
    self._name = article_data["metadata"]["title"]
    self._abstract = self._preprocess_abstract(article_data["abstract"], tokenizer)
    self._corpus_text = self._preprocess_bodytext(article_data["body_text"], tokenizer)

  @property
  def id(self):
    return self._id

  @property
  def title(self):
    return self._name

  @property
  def abstract(self):
    return self._abstract

  @property
  def corpus(self):
    return self._corpus_text

  @id.setter
  def id(self, _id):
    self._id = _id

  @title.setter
  def title(self, _title):
    self._name = _title

  @abstract.setter
  def abstract(self, _abstract):
    self._abstract = _abstract

  @corpus.setter
  def corpus(self, _corpus):
    self._corpus_text = _corpus

  @staticmethod
  def _preprocess(text_data):
    result = remove_urls(text_data)
    result = remove_references(result)
    result = remove_multiple_full_stops(result)
    result = remove_et_al(result)
    result = remove_figure_references(result)
    result = remove_multiple_whitespace(result)
    return result

  def _preprocess_abstract(self, abstract, tokenizer):
    if not abstract:
      sentences = [""]
    else:
      sentences = []
      for paragraph in abstract:
        sentences = sentences + self._preprocess_paragraph(paragraph, tokenizer)[1]
    return ("Abstract.", sentences)

  def _preprocess_bodytext(self, body_text, tokenizer):
    return [self._preprocess_paragraph(paragraph, tokenizer)
            for paragraph in body_text] if body_text else [("", "")]

  def _preprocess_paragraph(self, paragraph, tokenizer):
    return (paragraph['section'] + '.',
            [sentence.strip()
             for sentence in tokenizer.tokenize(self._preprocess(paragraph['text']))
             if sentence != '.'])

  @property
  def summary(self):
    """ returns a list containing the title + the sentences in the abtract plus section names """
    sentences = [paragraph_and_section[0]
                 for paragraph_and_section in self._corpus_text
                 if paragraph_and_section[0] != '.']
    return list(set([self.title + '.', self._abstract[0], *self._abstract[1]] + sentences))

  @property
  def text(self):
    """ returns all the text in the article: abstract + sections + paragraph text """
    sentences = [self._abstract[0], *self._abstract[1]]
    for paragraph_and_section in self._corpus_text:
      if paragraph_and_section[0] != '.':
        sentences.append(paragraph_and_section[0])
      if paragraph_and_section[1] != '.':
        sentences = sentences + paragraph_and_section[1]
    return sentences

  def get_section_from_index(self, idx):
    """ given an index of a sentence in the whole text of the article, it
        returns the section and the index of the sentence inside the section """
    text = self.text
    abstract = self.abstract
    corpus_text = self.corpus

    length_of_abstract = 1 + len(abstract[1])
    target_index = None
    target_section = None

    # indexed sentence is in abstract
    if idx < length_of_abstract:
      target_index = idx
      target_section = abstract
    
    # else, indexed sentence is in body text
    else:
      current_index = length_of_abstract
      for paragraph in corpus_text:
        paragraph_length = 1 + len(paragraph[1])
        if idx <= current_index + paragraph_length:
          target_index = idx - current_index
          target_section = paragraph
          break
        else:
          current_index += paragraph_length

    return target_index, target_section

## Helper method to parse all the articles

In [6]:
def get_articles(root_dir, filenames, log_every=None, stop_at=None):
  """
  :param str root_dir:         The root directory containing all the articles in json format.
  :param list[str] filenames:  A list containing the names of the json files.
  :param int log_every:        Frequency of prints that show how many files have been parsed at a
                                  specific timestep.
  :param int stop_at:          The number of articles to parsed. If None, then all the articles
                                  available will be parsed.
  
  :return:  A list containing Article objectes, one for every article.
  :rtype:   list[CovidArticle]
  """

  covid_articles = []

  for filename in filenames:
    full_filepath = os.path.join(root_dir, filename)
    with open(full_filepath) as f:
      data = json.load(f)
      covid_articles.append(CovidArticle(data))

    if log_every is not None and len(covid_articles) % log_every == 0:
      print('{} articles parsed'.format(len(covid_articles)))

    if stop_at is not None and len(covid_articles) == stop_at:
      break

  return covid_articles

Get a list with the articles either by
- Parsing it using the ``` get_articles() ``` function, which takes ~ 1 hour, and then saving it in a pickle file.
- Loading an already parsed pickle file containing the list.

I would suggest first parsing the Dataset in CPU, and then switching to GPU to load it.

In [7]:
# edit your paths here
root_dir = os.path.join('.', 'drive', 'My Drive', 'Colab Notebooks', 'AI2', 'Project4', 'Dataset', 'comm_use_subset')
save_path = os.path.join('.', 'drive', 'My Drive', 'Colab Notebooks', 'AI2', 'Project4', 'Preprocessed_Dataset', 'processed_articles.pickle')

# if the dataset is to be parsed for the first time
if not has_already_been_parsed:
  filenames = sorted(os.listdir(root_dir))
  articles = get_articles(root_dir, filenames, log_every=100)
  with open(save_path, 'wb') as f:
    pickle.dump(articles, f)
# else, it has already been parsed, just load it
else:
  with open(save_path, 'rb') as f:
    articles = pickle.load(f)

In [8]:
len(articles)

9009

Let's take a look at IDs and the titles of the first 20 articles

In [9]:
for article in articles[:20]:
  print('ID: {},\tTitle: {}'.format(article.id, article.title))

ID: 000b7d1517ceebb34e1e3e817695b6de03e2fa78,	Title: Supplementary Information An eco-epidemiological study of Morbilli-related paramyxovirus infection in Madagascar bats reveals host-switching as the dominant macro-evolutionary mechanism
ID: 00142f93c18b07350be89e96372d240372437ed9,	Title: immunity to pathogens taught by specialized human dendritic cell subsets
ID: 0022796bb2112abd2e6423ba2d57751db06049fb,	Title: Public Health Responses to and Challenges for the Control of Dengue Transmission in High-Income Countries: Four Case Studies
ID: 00326efcca0852dc6e39dc6b7786267e1bc4f194,	Title: a section of the journal Frontiers in Pediatrics A Review of Pediatric Critical Care in Resource-Limited Settings: A Look at Past, Present, and Future Directions
ID: 00352a58c8766861effed18a4b079d1683fec2ec,	Title: MINI REVIEW Function of the Deubiquitinating Enzyme USP46 in the Nervous System and Its Regulation by WD40-Repeat Proteins
ID: 0043d044273b8eb1585d3a66061e9b4e03edc062,	Title: Evaluation of

Let's also take a look at the contents of one article

In [10]:
art = articles[3]

In [11]:
art.id

'00326efcca0852dc6e39dc6b7786267e1bc4f194'

In [12]:
art.title

'a section of the journal Frontiers in Pediatrics A Review of Pediatric Critical Care in Resource-Limited Settings: A Look at Past, Present, and Future Directions'

In [13]:
art.abstract

('Abstract.',
 ['Fifteen years ago, United Nations world leaders defined millenium development goal 4 (MDG 4): to reduce under-5-year mortality rates by two-thirds by the year 2015.',
  'Unfortunately, only 27 of 138 developing countries are expected to achieve MDG 4.',
  'The majority of childhood deaths in these settings result from reversible causes, and developing effective pediatric emergency and critical care services could substantially reduce this mortality.',
  'The Ebola outbreak highlighted the fragility of health care systems in resource-limited settings and emphasized the urgent need for a paradigm shift in the global approach to healthcare delivery related to critical illness.',
  'This review provides an overview of pediatric critical care in resource-limited settings and outlines strategies to address challenges specific to these areas.',
  'Implementation of these tools has the potential to move us toward delivery of an adequate standard of critical care for all childr

In [14]:
art.summary

['Research Strategies for Clinical evidence.',
 'The majority of childhood deaths in these settings result from reversible causes, and developing effective pediatric emergency and critical care services could substantially reduce this mortality.',
 'Global Justice.',
 'Research Agenda for Critical Care in Resource-Poor Settings.',
 'This review provides an overview of pediatric critical care in resource-limited settings and outlines strategies to address challenges specific to these areas.',
 'Fifteen years ago, United Nations world leaders defined millenium development goal 4 (MDG 4): to reduce under-5-year mortality rates by two-thirds by the year 2015.',
 'eTHiCS OF PeDiATRiC CRiTiCAL CARe iN ReSOURCe-LiMiTeD COUNTRieS.',
 'AUTHOR CONTRiBUTiONS.',
 'Health Care work Force and education.',
 'CONCLUSiON.',
 'Category Specific exclusions Comments.',
 'Unfortunately, only 27 of 138 developing countries are expected to achieve MDG 4.',
 'Critical Care Guidelines and Toolkits.',
 'wHAT HA

In [15]:
art.text

['Abstract.',
 'Fifteen years ago, United Nations world leaders defined millenium development goal 4 (MDG 4): to reduce under-5-year mortality rates by two-thirds by the year 2015.',
 'Unfortunately, only 27 of 138 developing countries are expected to achieve MDG 4.',
 'The majority of childhood deaths in these settings result from reversible causes, and developing effective pediatric emergency and critical care services could substantially reduce this mortality.',
 'The Ebola outbreak highlighted the fragility of health care systems in resource-limited settings and emphasized the urgent need for a paradigm shift in the global approach to healthcare delivery related to critical illness.',
 'This review provides an overview of pediatric critical care in resource-limited settings and outlines strategies to address challenges specific to these areas.',
 'Implementation of these tools has the potential to move us toward delivery of an adequate standard of critical care for all children glo

# Load the pre-trained models

In [16]:
embedder1 = SentenceTransformer('stsb-distilbert-base')

In [17]:
embedder2 = SentenceTransformer('stsb-roberta-base')

## Define Look-up dictionaries to avoid recomputing some embeddings

Get the embeddings of the summaries of all the articles, for each model

In [18]:
# should take ~ 15 mins on GPU for the whole Dataset
index_to_summary_embeddings1 = {}
for idx, article in enumerate(articles):
  index_to_summary_embeddings1[idx] = embedder1.encode(article.summary, convert_to_tensor=True)

In [19]:
# should take ~ 27 mins on GPU for the whole Dataset
index_to_summary_embeddings2 = {}
for idx, article in enumerate(articles):
  index_to_summary_embeddings2[idx] = embedder2.encode(article.summary, convert_to_tensor=True)

Also define dictionaries that will be used to store the embeddings of the whole articles, to avoid computing them twice

In [20]:
index_to_article_embedding1 = {}
index_to_article_embedding2 = {}

# Predict

In [21]:
def max_similarity(corpus_embeddings, query_embedding, k=1):
  """ returns the indices and values (cosine similarity) of the sentences from the corpus with the
      top k cosine similarities with the query embedding """
  cos_scores = util.pytorch_cos_sim(query_embedding, corpus_embeddings)[0].cpu()
  return torch.topk(cos_scores, k=k)

In [22]:
def relative_articles(articles, index_to_summary_embeddings, query_embedding, threshold=0.5):
  """ returns a list with the indices of the articles, for which a sentence in its summary with
      cosine similarity > threshold, exists """
  return [idx for idx, article in enumerate(articles)
          if max_similarity(index_to_summary_embeddings[idx], query_embedding)[0][0] > threshold]

In [23]:
def get_passage(corpus, best_sentence_idx, index_in_section, paragraph, embedder):
  """ given a sentence inside the corpus, it returns a string containing all the similar sentences
      that belong in the same section of the corpus """
  best_sentence_embedding = embedder.encode([corpus[best_sentence_idx]], convert_to_tensor=True)
  previous_sentence_is_relevant = True
  next_sentence_is_relevant = True
  passage = corpus[best_sentence_idx]
  offset = 1

  # while either the previous sentence of the next are relevant, keep appending them in the passage
  while previous_sentence_is_relevant is True or next_sentence_is_relevant is True:
    
    if index_in_section - offset < 0:
      previous_sentence_is_relevant = False
    elif index_in_section + offset >= len(paragraph):
      next_sentence_is_relevant = False

    if previous_sentence_is_relevant is True:
      previous_sentence = corpus[best_sentence_idx - offset]
      previous_sentence_embedding = embedder.encode([previous_sentence], convert_to_tensor=True)
      cosine_similarity = util.pytorch_cos_sim(best_sentence_embedding, previous_sentence_embedding)[0][0].cpu()
      if cosine_similarity < 0.5:
        previous_sentence_is_relevant = False
      else:
        passage = ' '.join([previous_sentence, passage])

    if next_sentence_is_relevant is True:
      next_sentence = corpus[best_sentence_idx + offset]
      next_sentence_embedding = embedder.encode([next_sentence], convert_to_tensor=True)
      cosine_similarity = util.pytorch_cos_sim(best_sentence_embedding, next_sentence_embedding)[0][0].cpu()
      if cosine_similarity < 0.5:
        next_sentence_is_relevant = False
      else:
        passage = ' '.join([passage, next_sentence])

    offset += 1

  return passage

In [24]:
# deprecated
def threshold_is_ok(number_of_articles_included, min_articles=5, max_articles=150):
  """ returns a boolean that determines whether the number of articles that did not get filtered
      is acceptable """
  return min_articles <= number_of_articles_included <= max_articles

# deprecated
def fix_threshold(number_of_articles_included, threshold, min_articles=5, max_articles=150,
                  decrease_step=0.01, increase_step=0.05):
  """ computes a new value for a threshold, depending on how many articles were found relevant
      using the previous threshold """
  if number_of_articles_included < min_articles:
    return threshold + increase_step
  elif number_of_articles_included > max_articles:
    return threshold - decrease_step

In [25]:
def find_best_article(articles, index_to_summary_embeddings, index_to_article_embeddings, query, embedder, threshold=0.5):
  """ returns the index (in the articles list) of the article that best fits the given query """
  query_embedding = embedder.encode(query, convert_to_tensor=True)
  articles_to_explore = relative_articles(articles, index_to_summary_embeddings, query_embedding, threshold=threshold)
  while len(articles_to_explore) == 0:
    threshold -= 0.05
    articles_to_explore = relative_articles(articles, index_to_summary_embeddings, query_embedding, threshold=threshold)
  
  best_cos_sim = 0.0
  best_article_idx = None
  best_sentence_idx = None

  # print("For query '{}', found {} articles to explore.\n".format(query, len(articles_to_explore)))
  # print([articles[article].id for article in articles_to_explore])

  for idx in articles_to_explore:
    article = articles[idx]

    if idx not in index_to_article_embeddings:
      corpus = article.text
      corpus_embeddings = embedder.encode(corpus, convert_to_tensor=True)
      index_to_article_embeddings[idx] = corpus_embeddings
    else:
      corpus_embeddings = index_to_article_embeddings[idx]

    top_sentences_and_scores = max_similarity(corpus_embeddings, query_embedding)
    score = top_sentences_and_scores[0][0]

    if score > best_cos_sim:
      best_cos_sim = score
      best_article_idx = idx
      best_sentence_idx = top_sentences_and_scores[1][0]

  best_article = articles[best_article_idx]
  corpus = best_article.text
  index_in_section, section = best_article.get_section_from_index(best_sentence_idx)
  paragraph = [section[0], *section[1]]
  passage = get_passage(corpus, best_sentence_idx, index_in_section, paragraph, embedder)

  return best_article_idx, passage

These queries and their respective answers are written in the file queries.txt

In [26]:
queries = ['How is the diagnosis of pulmonary tuberculosis made in TB clinics and hospitals?',
           'Was the Porcine epidemic diarrhea virus first detected in Slovenia?',
           'Is handwashing the most important measure against infectious diseases?',
           'How was the importance of clathrin-mediated endocytosis for MHV confirmed?',
           'Which alternatives do we have for conventional chemotherapy?',
           'Is there a drug to treat the EV71 infection?',
           'Which is common cause for diarrhea and septicemia in calves?',
           'Which scanning technique was used to confirm hypotheses regarding the MAb-1G10 epitope structure?',
           'How can host translational inhibition be achieved?',
           'What is Multiple sclerosis?']

Define a threshold that will be use to filter out articles where their summary has an a cosine similarity with the query, lower than the threshold. The bigger the threshold, the more articles get filtered out. This has 2 effects:
- Greatly improves running time
- Increased probability of missing out the best article as it's summary might not be similar to the query, yet its body text may have the correct passage

In [27]:
threshold = 0.65

## Find the best articles for them

In [28]:
total_time = 0.0
for idx, query in enumerate(queries):
  
  start_time = time.time()
  best_article_index, passage = find_best_article(articles, index_to_summary_embeddings1, index_to_article_embedding1, query, embedder1, threshold=threshold)
  elapsed_time = time.time() - start_time
  total_time += elapsed_time
  best_article = articles[best_article_index]

  print('-' * 200)
  print("For Query {}: '{}', the most relevant article is:\n".format(idx + 1, query))
  print('ID: {},\tTitle: {}\n'.format(best_article.id, best_article.title))
  print('The Passage is:')
  print(passage)
  print('\nFound the answer in %.2f seconds' % elapsed_time)
  print('-' * 200)
  print('\n')

print('\n\n')
print('-' * 200)
print('\nAverage time per Query: {:.2f}\n'.format(total_time / len(queries)))
print('-' * 200)

--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
For Query 1: 'How is the diagnosis of pulmonary tuberculosis made in TB clinics and hospitals?', the most relevant article is:

ID: 0043d044273b8eb1585d3a66061e9b4e03edc062,	Title: Evaluation of the tuberculosis programme in Ningxia Hui Autonomous region, the People's Republic of China: a retrospective case study

The Passage is:
Diagnosis of pulmonary TB in hospitals and TB clinics is made on the basis of clinical examination; chest radiography and sputum smear microscopy and/or sputum culture. Following diagnosis, patients enter the DOTS program which prescribes short-course chemotherapy (SCC) comprising 2 months of isoniazid (H), rifampicin (R), pyrazinamide (Z) plus streptomycin (S) or ethambutol (E) followed by 4 months of H and R. This is the WHO recommended regimen for treating new

In [29]:
total_time = 0.0
for idx, query in enumerate(queries):
  start_time = time.time()
  best_article_index, passage = find_best_article(articles, index_to_summary_embeddings2, index_to_article_embedding2, query, embedder2, threshold=threshold)
  elapsed_time = time.time() - start_time
  total_time += elapsed_time
  best_article = articles[best_article_index]
  
  print('-' * 200)
  print("For Query {}: '{}', the most relevant article is:\n".format(idx + 1, query))
  print('ID: {},\tTitle: {}\n'.format(best_article.id, best_article.title))
  print('The Passage is:')
  print(passage)
  print('\nFound the answer in %.2f seconds' % elapsed_time)
  print('-' * 200)
  print('\n')

print('\n\n')
print('-' * 200)
print('\nAverage time per Query: {:.2f}\n'.format(total_time / len(queries)))
print('-' * 200)

--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
For Query 1: 'How is the diagnosis of pulmonary tuberculosis made in TB clinics and hospitals?', the most relevant article is:

ID: 0043d044273b8eb1585d3a66061e9b4e03edc062,	Title: Evaluation of the tuberculosis programme in Ningxia Hui Autonomous region, the People's Republic of China: a retrospective case study

The Passage is:
Diagnosis of pulmonary TB in hospitals and TB clinics is made on the basis of clinical examination; chest radiography and sputum smear microscopy and/or sputum culture. Following diagnosis, patients enter the DOTS program which prescribes short-course chemotherapy (SCC) comprising 2 months of isoniazid (H), rifampicin (R), pyrazinamide (Z) plus streptomycin (S) or ethambutol (E) followed by 4 months of H and R. This is the WHO recommended regimen for treating new

An analysis of the results can be found in the pdf report.