## Import Libraries

In [1]:
# Setup: first create a conda env from environment.yml, then select that env as the kernel for this notebook.

In [2]:
# Imports
from sentence_transformers import SentenceTransformer, util
import numpy as np
import pandas as pd
# import mlflow
import os
import pickle

from frequently_requested_docs.docs_helper import getModel, getSaveName, loadEmbeddings, getEmbeddingPath
from frequently_requested_docs.docs_config import TOP_K, MODEL_NAMES, DATA_CSV_PATH

## Model Selection and Initialization

In [3]:
# A list of models optimized for semantic textual similarity can be found at:
# https://docs.google.com/spreadsheets/d/14QplCdTCDwEmTqrn1LH4yrbKvdogK4oQvYO1K1aPR5M/edit#gid=0
#
# The candidate models to test live in MODEL_NAMES (imported from
# frequently_requested_docs.docs_config). Keep the list there instead of duplicating it in
# this notebook, so the two copies cannot drift out of sync.
model_name = 'princeton-nlp/unsup-simcse-roberta-large'

# getSaveName maps the model name to the name used for saved artifacts
# (for this model the printed output shows '/' replaced by '-').
save_name = getSaveName(model_name)
print(save_name, model_name)

# getModel returns the model instance; the printed type below confirms it is a SentenceTransformer.
model = getModel(model_name, save_name)
print(type(model))
        

princeton-nlp-unsup-simcse-roberta-large princeton-nlp/unsup-simcse-roberta-large
Loading model from disc
<class 'sentence_transformers.SentenceTransformer.SentenceTransformer'>


## Retrieve the top-K most similar documents from the freqdoc dataset for a given request

In [4]:
# Build the corpus: one pandas Series per CSV row that has a string 'Document' value
# (the sample output below shows the fields Component, Document, Tag, URL, Agency).
# Rows whose 'Document' is not a str (e.g. NaN for an empty cell) are skipped.
# Each kept row retains its original CSV index as its Series name.
data = pd.read_csv(DATA_CSV_PATH)
corpus_docs = [row for _, row in data.iterrows() if isinstance(row['Document'], str)]

# Fail with a clear message instead of an IndexError on corpus_docs[0] below.
if not corpus_docs:
    raise ValueError(f"No rows with a string 'Document' value found in {DATA_CSV_PATH}")

print(corpus_docs[0])
        

Component                                                 FEMA
Document                             FEMA Clothing- XYZ Legend
Tag          FEMA Individual and Public Assistance Claims P...
URL          https://www.dhs.gov/sites/default/files/public...
Agency                                                     DHS
Name: 0, dtype: object


In [5]:
# Load corpus embeddings for this model if they exist, otherwise encode them.
# NOTE(review): the load-vs-encode logic lives in docs_helper.loadEmbeddings and is not visible
# here. It also returns corpus_docs, which a later cell indexes with `.iloc`, so it is
# presumably converted to a DataFrame there. Confirm.
embedding_path = getEmbeddingPath(save_name)
corpus_docs, corpus_embeddings = loadEmbeddings(model, embedding_path, corpus_docs)

Loading pre-computed embeddings from disc


In [6]:
print(corpus_embeddings.__class__)  # sanity check: which container type loadEmbeddings returned


<class 'torch.Tensor'>


In [7]:
# Sample requests to query the corpus with; the one at index 2 is used below.
examples = [
    'I am searching for the Detention Facility Reviews for the Randall County Jail in Amarillo, Texas',
    'Statements made by former georgia senator david perdue about visas.',
    'All documents regarding the TSA’s throughput data for August 2017',
]
sentence = examples[2]

# Embed the chosen request as a tensor so it can be compared with corpus_embeddings.
sentence_embedding = model.encode(sentence, convert_to_tensor=True)

In [8]:
# Cosine similarity between the request and every corpus embedding;
# [0] selects the row for the single query sentence.
cos_scores = util.pytorch_cos_sim(sentence_embedding, corpus_embeddings)[0]

# Take the highest TOP_K scores in decreasing order. Tensor.topk works on any device
# (np.argpartition needs a CPU array) and returns the matching indices. k is clamped
# so a corpus smaller than TOP_K does not raise.
top_k = min(TOP_K, len(cos_scores))
top_scores, top_results = cos_scores.topk(k=top_k)

print("Sentence:", sentence, "\n")
print("Top", top_k, "most similar sentences in corpus:")
for score, idx in zip(top_scores, top_results):
    # NOTE(review): relies on corpus_docs (as returned by loadEmbeddings) supporting .iloc.
    print(corpus_docs.iloc[int(idx)]["Document"], "(Score: %.4f)" % float(score))

Sentence: All documents regarding the TSA’s throughput data for August 2017 

Top 25 most similar sentences in corpus:
August 2017 FOIA Log (Score: 0.7429)
FY18 August FOIA Logs  (Score: 0.7229)
TSA Throughput Data July 9, 2017 to July 15, 2017 (Score: 0.6852)
August 2016 FOIA Log (Score: 0.6851)
TSA Throughput Data July 30, 2017 to August 5, 2017 (Score: 0.6828)
August 2018 FOIA Log (Score: 0.6811)
August 2015 FOIA Log (Score: 0.6795)
TSA Throughput Data July 2, 2017 to July 8, 2017 (Score: 0.6714)
August 2010 FOIA Log (Score: 0.6644)
TSA Throughput Data July 16, 2017 to July 22, 2017 (Score: 0.6630)
FOIA Log August 2017 (Score: 0.6606)
TSA Throughput Data July 8, 2018 to July 14, 2018 Page Count: 942 (Score: 0.6575)
TSA Throughput Data August 13, 2017 to August 19, 2017 (Score: 0.6548)
July 2016 FOIA Log (Score: 0.6528)
TSA Throughput Data July 15, 2018 to July 21, 2018 (Score: 0.6523)
TSA Throughput Data August 20, 2017 to August 26, 2017 (Score: 0.6491)
FOIA Log August 2016 (Score: