In [1]:
import torch
from torch.utils import data
from transformers import AutoTokenizer, AutoModel
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from itertools import islice
import os
import json
from time import time
from collections import Counter
import numpy as np
import pandas as pd
import torch.nn.functional as F
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from sentence_transformers import SentenceTransformer
from tqdm.notebook import tnrange, tqdm
from utils import *
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go
# Enable IPython's autoreload extension in mode 2 so edits to imported modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Path to the preprocessed-text JSON file.
text_path = 'data/preprocessed_text.json'

In [2]:
# Load the preprocessed articles from `text_path` (defined in the first cell)
# instead of repeating the path literal here.
# JSON is UTF-8 by specification, so the encoding is stated explicitly rather
# than left to the platform default.
with open(text_path, 'r', encoding='utf-8') as f:
    articles = json.load(f)
# Number of top-level entries in the loaded mapping.
len(articles)

33375

In [None]:
# Checkpoint identifiers, named once so tokenizer and model always agree.
scibert_checkpoint = 'allenai/scibert_scivocab_uncased'
covid_checkpoint = 'deepset/covid_bert_base'

# Original SciBERT: tokenizer first, then the model weights.
tokenizer_scibert = AutoTokenizer.from_pretrained(scibert_checkpoint)
model_scibert = AutoModel.from_pretrained(scibert_checkpoint)

# BERT fine-tuned on COVID text.
tokenizer_covid = AutoTokenizer.from_pretrained(covid_checkpoint)
model_covid = AutoModel.from_pretrained(covid_checkpoint)

# Sentence-level BERT; no separate tokenizer object is created for it.
model_sent = SentenceTransformer('bert-base-nli-mean-tokens')

 68%|█████████████████████████████████████████████████████▋                         | 275M/405M [01:39<00:41, 3.17MB/s]

In [None]:
def sentence_similarity(first, second):
    """Score how similar two sentences are under SciBERT.

    Each sentence is embedded with `sentence_embedding` using the SciBERT
    tokenizer and model, and the two embeddings are passed to
    `cosine_similarity`. Neither helper is defined in this notebook
    (presumably they come from `utils` — TODO confirm).

    Args:
        first: The first sentence.
        second: The second sentence.

    Returns:
        Whatever `cosine_similarity` returns for the two embeddings.
    """
    first_vector = sentence_embedding(tokenizer_scibert, model_scibert, first)
    second_vector = sentence_embedding(tokenizer_scibert, model_scibert, second)
    return cosine_similarity(first_vector, second_vector)

In [None]:
# Example: compare a question about the virus with a statement about its symptoms.
sentence_similarity("What are the risk factors for the virus?", "Fever was one of the symptoms of the virus.")

## Generate title embeddings OR just load them
Generating embeddings for the full set of titles takes about 30 minutes per model.

### Crop paper titles to the first sentence. Drop those that are still too long

In [None]:
# Select n papers and set the length cap passed to the tokenizer-based filter.
n = 2000
max_length = 30

# `take` is not defined in this notebook (presumably from `utils` — TODO confirm).
selection = take(n, articles)
selected_papers = {paper_id: articles[paper_id] for paper_id in selection}
titles = [paper['title'] for paper in selected_papers.values()]

# Keep each title up to and including its first "."; titles without one are kept whole.
cropped_titles = ["".join(title.partition(".")[:2]) for title in titles]

# Two passes over both tokenizers:
#   pass 1 removes titles that are too long for either tokenizer,
#   pass 2 builds the actual encodings once both tokenizers have removed
#   those that are too long.
# `drop_from_lists` is applied to `cropped_titles` and `titles` together so the
# two lists stay aligned.
for _ in range(2):
    encoded_scibert, indices_to_drop = get_encodings_drop_long(cropped_titles, tokenizer_scibert, max_length=max_length)
    drop_from_lists([cropped_titles, titles], indices_to_drop)

    encoded_covid, indices_to_drop = get_encodings_drop_long(cropped_titles, tokenizer_covid, max_length=max_length)
    drop_from_lists([cropped_titles, titles], indices_to_drop)

#### Generate

In [None]:
batch_size = 32


def _cls_embeddings(model, encoded, batch_size):
    """Embed every encoded title with `model`, one row per title.

    For each batch, the model's first output is indexed with `[:, 0, :]`,
    i.e. the vector at sequence position 0 of every item is kept.

    Args:
        model: A transformers model whose first output is indexable as
            (batch, position, feature) and which exposes `config.hidden_size`.
        encoded: Tensor of encoded titles; `encoded.shape[0]` is the number
            of titles.
        batch_size: Number of titles fed to the model per forward pass.

    Returns:
        A float tensor of shape (encoded.shape[0], model.config.hidden_size).
    """
    loader = data.DataLoader(encoded, batch_size=batch_size, num_workers=4)
    # Size the buffer from this model and these encodings rather than from a
    # hard-coded width or another model's input.
    embeddings = torch.zeros(encoded.shape[0], model.config.hidden_size)
    start = 0
    with torch.no_grad():
        for batch in tqdm(loader, leave=False, total=len(loader)):
            first_position = model(batch)[0][:, 0, :]
            # Advance by the real batch length so the shorter final batch is
            # written to exactly the rows it covers.
            stop = start + first_position.shape[0]
            embeddings[start:stop] = first_position
            start = stop
    return embeddings


embeddings_scibert = _cls_embeddings(model_scibert, encoded_scibert, batch_size)
embeddings_covid = _cls_embeddings(model_covid, encoded_covid, batch_size)

# The sentence model encodes the cropped title strings directly.
embeddings_sent = torch.tensor(model_sent.encode(cropped_titles))

In [None]:
#torch.save(embeddings, "embeddings.pt")

#### Load

In [None]:
#embeddings = torch.load("embeddings.pt")

## Similarity search

In [None]:
# Embed one query with each of the three models. The sentence model is paired
# with no tokenizer (None), matching how it is used elsewhere in the notebook.
query = "Risk factors for covid-19 death"
query_embedding_scibert, query_embedding_covid, query_embedding_sent = (
    get_query_embedding(tokenizer, model, query, max_length=max_length)
    for tokenizer, model in (
        (tokenizer_scibert, model_scibert),
        (tokenizer_covid, model_covid),
        (None, model_sent),
    )
)

In [None]:
# Number of results to retrieve. Note: this rebinds `n`, which earlier held the
# number of selected papers.
n = 20
# SciBERT search: `find_top_n_similar` is not defined in this notebook
# (presumably from `utils`); it returns a pair bound here to indices and titles.
indices_scibert, titles_scibert = find_top_n_similar(embeddings_scibert, query_embedding_scibert, titles, n=n)
# Display the titles returned for SciBERT.
titles_scibert

In [None]:
# Same search using the COVID BERT title embeddings and query embedding.
indices_covid, titles_covid = find_top_n_similar(embeddings_covid, query_embedding_covid, titles, n=n)
# Display the titles returned for COVID BERT.
titles_covid

In [None]:
# Same search using the sentence-model title embeddings and query embedding.
indices_sent, titles_sent = find_top_n_similar(embeddings_sent, query_embedding_sent, titles, n=n)
# Display the titles returned for the sentence model.
titles_sent

## Visualization

In [None]:
# Project each model's title embeddings with `get_tsne_embeddings` (not defined
# in this notebook; the plotting cells read columns 0 and 1 of each result).
tsne_scibert, tsne_covid, tsne_sent = map(
    get_tsne_embeddings,
    (embeddings_scibert, embeddings_covid, embeddings_sent),
)

In [None]:
def plot_query_embeddings(query, n=40):
    """Plot the t-SNE title maps of all three models, highlighting query hits.

    Draws one matplotlib panel per model (SciBERT, COVID BERT, sentence BERT).
    In each panel the titles returned by `find_top_n_similar` for `query` are
    drawn as larger red markers and all other titles as small blue markers.

    Reads the notebook-level models, tokenizers, title embeddings, t-SNE
    projections and `titles` list.

    Args:
        query: Query text to embed and search for.
        n: Number of most similar titles to highlight per model.

    Returns:
        None. The figure is left as the current matplotlib figure.
    """
    panels = [
        ("Scibert", tokenizer_scibert, model_scibert, embeddings_scibert, tsne_scibert),
        ("Covid", tokenizer_covid, model_covid, embeddings_covid, tsne_covid),
        ("Bert Sentence", None, model_sent, embeddings_sent, tsne_sent),
    ]
    fig, ax = plt.subplots(1, 3, figsize=(20, 5))
    for axis, (plot_title, tokenizer, model, embeddings, tsne) in zip(ax, panels):
        query_embedding = get_query_embedding(tokenizer, model, query)
        similar, _ = find_top_n_similar(embeddings, query_embedding, titles, n=n)
        points = np.asarray(tsne)
        # Boolean mask over all points: True where the point index is a hit.
        is_hit = np.isin(np.arange(points.shape[0]), similar[:n].tolist())
        # Two scatter calls per panel instead of one per point. Hits are drawn
        # last so they sit on top of the background points.
        axis.scatter(points[~is_hit, 0], points[~is_hit, 1], c='b', s=4)
        axis.scatter(points[is_hit, 0], points[is_hit, 1], c='r', s=16)
        axis.set_title(plot_title)

In [None]:
def plot_query_embeddings_plotly(query, titles, n=40):
    """Show interactive t-SNE title maps for all three models, highlighting query hits.

    Builds a 1x3 Plotly figure with one labelled panel per model (SciBERT,
    COVID BERT, sentence BERT). Each marker carries its title as hover text.
    Titles returned by `find_top_n_similar` for `query` are red and slightly
    larger; the rest are blue. After showing the figure, the first ten hits
    are printed.

    Reads the notebook-level models, tokenizers, title embeddings and t-SNE
    projections.

    Args:
        query: Query text to embed and search for.
        titles: Titles aligned with the rows of the embeddings and projections.
        n: Number of most similar titles to highlight per model.

    Returns:
        None.
    """
    panels = [
        ("Scibert", tokenizer_scibert, model_scibert, embeddings_scibert, tsne_scibert),
        ("Covid", tokenizer_covid, model_covid, embeddings_covid, tsne_covid),
        ("Bert Sentence", None, model_sent, embeddings_sent, tsne_sent),
    ]
    # Label each panel with its model name so the three plots can be told apart.
    fig = make_subplots(rows=1, cols=3, subplot_titles=[panel[0] for panel in panels])
    for column, (plot_title, tokenizer, model, embeddings, tsne) in enumerate(panels, start=1):
        query_embedding = get_query_embedding(tokenizer, model, query)
        similar, _ = find_top_n_similar(embeddings, query_embedding, titles, n=n)
        similar_set = set(similar[:n].tolist())
        # Compute hit membership once and reuse it for both size and colour.
        is_hit = [i in similar_set for i in range(len(titles))]
        fig.add_trace(
            go.Scatter(
                x=tsne[:, 0],
                y=tsne[:, 1],
                mode="markers",
                text=titles,
                name=plot_title,
                marker=dict(
                    size=[6 if hit else 4 for hit in is_hit],
                    color=['red' if hit else 'blue' for hit in is_hit],
                ),
            ),
            1,
            column,
        )
    fig.update_layout(height=400, width=1000, title_text="Visualization of Search Results for '{}'".format(query))
    fig.show()
    # NOTE(review): `similar` here is the result from the last panel
    # ("Bert Sentence"), so only that model's hits are listed.
    print("Top 10 results:")
    for i in similar[:10]:
        print(titles[i])

In [None]:
# Visualize search results for a risk-factor query across all three models.
plot_query_embeddings_plotly("Risk factors for covid-19 death", titles)

In [None]:
# Visualize search results for an asymptomatic-carrier query across all three models.
plot_query_embeddings_plotly("Asymptomatic carriers of the virus", titles)