# Notebook 5: Model Refinement

The purpose of this notebook is to develop the BERT model from Notebook 4 further and then simplify it for use in the Flask app. The code and approach are inspired by this [blog](https://towardsdatascience.com/calculating-document-similarities-using-bert-and-other-models-b2c1a29c9630), with help from Caroline.

In [1]:
import os

import gensim
import numpy as np
import pandas as pd

from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics.pairwise import euclidean_distances
from sentence_transformers import SentenceTransformer

## Refining and Simplifying the BERT Model

In [6]:
# Load the cleaned, pre-processed corpus; each row is later treated as one document.
df = pd.read_csv(filepath_or_buffer='./sutta_csv/cleaned/df_all_prep.csv')

In [7]:
# Load the pretrained Sentence-BERT checkpoint used to embed every document.
# NOTE(review): sentence-transformers lists the *-nli-mean-tokens checkpoints as
# deprecated; consider a newer model if embedding quality matters.
sbert_model = SentenceTransformer(model_name_or_path='bert-base-nli-mean-tokens')

In [9]:
# Encode the full text of every document with Sentence-BERT, then keep the
# vectors in a DataFrame whose rows are labelled by each document's 'ref'.
document_embeddings = sbert_model.encode(df['text_full'], show_progress_bar=False)
embeddings_df = pd.DataFrame(document_embeddings, index=df['ref'].values)

In [10]:
# All-pairs comparison of the document embeddings: cosine similarity
# (higher = closer) and Euclidean distance (lower = closer).
pairwise_similarities, pairwise_differences = (
    cosine_similarity(document_embeddings),
    euclidean_distances(document_embeddings),
)

In [None]:
# Export the pairwise matrices for ease of access and to conserve computational
# resources. most_similar() below reads the similarity matrix back from
# './data/pairwise_similarities.npy', so this cell must run for the notebook to
# work from a fresh kernel / fresh checkout.
os.makedirs('./data', exist_ok=True)  # np.save does not create missing directories
np.save('./data/pairwise_similarities.npy', pairwise_similarities)
np.save('./data/pairwise_differences.npy', pairwise_differences)

In [11]:
#Function using only cosine similarities from the BERT model document embeddings
def most_similar(ref, top_n=5, similarities=None,
                 similarities_path='./data/pairwise_similarities.npy'):
    """Print the title and URL of the documents most similar to ``ref``.

    Similarity is read from a square matrix of pairwise cosine similarities
    whose row/column order must match the row order of the global ``df``.

    Parameters
    ----------
    ref : str
        Reference of the query document, e.g. ``'MN 1'``. Thin spaces
        (U+2009) are treated as ordinary spaces.
    top_n : int, default 5
        Maximum number of neighbours to print (non-negative).
    similarities : array-like, optional
        Precomputed similarity matrix. When omitted it is loaded from
        ``similarities_path``.
    similarities_path : str, default './data/pairwise_similarities.npy'
        ``.npy`` file read when ``similarities`` is not given.

    Returns
    -------
    None
        Results are printed as ``Title: ...`` / ``URL: ...`` line pairs.

    Raises
    ------
    IndexError
        If no row of ``df`` has ``ref`` (same type as before, now with a message).
    ValueError
        If the similarity matrix size does not match ``len(df)``. This usually
        means a stale ``.npy`` file from a different version of the corpus.
    """
    # Normalise thin spaces in the stored refs so plain-space queries such as
    # 'MN 1' match. This still writes back to the global df, as it always has,
    # in case later cells rely on it; the replacement is idempotent.
    df['ref'] = df['ref'].str.replace('\u2009', ' ')
    ref = ref.replace('\u2009', ' ')

    if similarities is None:
        with open(similarities_path, 'rb') as f:
            similarities = np.load(f)
    similarities = np.asarray(similarities)
    if similarities.shape[0] != len(df):
        raise ValueError(
            f'Similarity matrix has {similarities.shape[0]} rows but df has '
            f'{len(df)}; regenerate the pairwise similarities.'
        )

    # Use the *positional* row number: it indexes both the similarity matrix and
    # df.iloc, whereas an index label only coincides with it for a RangeIndex.
    matches = np.flatnonzero((df['ref'] == ref).to_numpy())
    if matches.size == 0:
        raise IndexError(f'No document with ref {ref!r}')
    doc_ix = int(matches[0])

    # Rank by descending similarity and drop the query document explicitly,
    # rather than assuming it sorts first. With ties (e.g. duplicate texts),
    # that assumption dropped a genuine neighbour and shortened the list.
    ranked = np.argsort(similarities[doc_ix])[::-1]
    neighbours = ranked[ranked != doc_ix][:top_n]

    for ix in neighbours:
        row = df.iloc[ix]
        print(f'Title: {row["title"]}')
        print(f'URL: {row["title_url"]}')

In [12]:
# Example query: print the five documents closest to reference 'MN 1'.
most_similar(ref='MN 1')

Title: MN 49  Brahma-nimantanika Sutta | The Brahmā Invitation
URL: https://www.dhammatalks.org/suttas/MN/MN49.html
Title: MN 113  Sappurisa Sutta | A Person of Integrity
URL: https://www.dhammatalks.org/suttas/MN/MN113.html
Title: MN 117  Mahā Cattārīsaka Sutta | The Great Forty
URL: https://www.dhammatalks.org/suttas/MN/MN117.html
Title: MN 138  Uddesa-vibhaṅga Sutta | An Analysis of the Statement
URL: https://www.dhammatalks.org/suttas/MN/MN138.html
Title: MN 131  Bhaddekaratta Sutta | An Auspicious Day
URL: https://www.dhammatalks.org/suttas/MN/MN131.html
