
<p class="h4 alert alert-danger">
    Please make sure that the pretrained BERT model (multi_cased_L-12_H-768_A-12) exists in the "./models" directory.
</p>


<p class="alert alert-danger">
Also make sure that the BERT server is running and has access to the pretrained model. You can start the BERT server by running the following commands from the command line:
</p>
<code>> cd [ repo_dir ]/model_training/models</code>
<br />
<code>> bert-serving-start -model_dir ./multi_cased_L-12_H-768_A-12/ -num_worker=1 -show_tokens_to_client</code>


In [None]:
# import required libraries
import pandas as pd
import numpy as np
from pymongo import MongoClient
import bert_serving.client as bert
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
# define params for db operations
uri = 'mongodb://localhost:27017/'
database = 'zs_database'
collection_fetch = 'autotags'
collection_push = 'similarities'

# connect to db. TODO: Handle exception cases
# (named mongo_client so that the BERT client created in a later cell does not shadow it)
mongo_client = MongoClient(uri)
db = mongo_client[database]

# retrieve only the fields we need. Passing `columns` keeps a stable column order and
# still yields the expected (empty) columns when the collection has no documents, so
# later cells that index df['lemmas'] / df['story_id'] do not fail with a KeyError.
df = pd.DataFrame(
    list(db[collection_fetch].find({}, {"_id": 0, "lemmas": 1, "story_id": 1})),
    columns=["lemmas", "story_id"],
)

df

In [None]:
# create a bert client. The `with` block closes the client's connection to the BERT
# server once encoding is done. It is named bert_client so it does not rebind (shadow)
# the MongoDB client created in the previous cell.
with bert.BertClient(check_length=False) as bert_client:
    # encode the pre-tokenized lemmas from our data into bert vectors
    # (is_tokenized=True: each entry of df['lemmas'] is passed as a list of tokens)
    vectors = bert_client.encode(
        df['lemmas'].values.tolist(), show_tokens=False, is_tokenized=True
    )

# calculate cosine similarities for all pairs of vectors. This can take a while...
cos_sim = cosine_similarity(vectors)

cos_sim

In [None]:
# number of related story ids to store for each story
TOP_K = 5

# list of all story ids, in the same order as the rows/columns of cos_sim
id_list = df['story_id'].values.tolist()

# if the push collection (similarities) already exists and holds documents, drop it so
# this run's results replace the previous ones instead of being appended to them
# TODO: handle exception
if collection_push in db.list_collection_names():
    collection = db[collection_push]
    if collection.estimated_document_count() != 0:
        print('Dropping the old collection (' + collection_push + ') ...')
        collection.drop()

collection = db[collection_push]

print("Top five similar story ids: ")

# iterate through the story ids and find the TOP_K most similar *other* stories
documents = []
for i, story_id in enumerate(id_list):
    # indices ordered from most to least similar; a stable sort keeps tied scores
    # in retrieval order so the output is deterministic
    ranked = np.argsort(-cos_sim[i], kind='stable')
    # exclude the story itself explicitly before truncating: with duplicate vectors
    # several entries tie at the maximum score, so the story is not guaranteed to be
    # among the first TOP_K + 1 indices
    similar_story_ids = [id_list[j] for j in ranked if j != i][:TOP_K]
    print("For story id - " + str(story_id) + ":", end=" ")
    print(similar_story_ids)
    documents.append({
        "story_id": story_id,
        "related_story_id": similar_story_ids,
    })

# insert all related story ids in one round trip instead of one insert per story;
# insert_many rejects an empty list, hence the guard
# TODO: exception handling
if documents:
    collection.insert_many(documents)