In [None]:
from google.colab import drive

# Colab only: mount Google Drive; load_dataset reads the multilingual CSVs
# from /content/gdrive/MyDrive/Dataset/.
GDRIVE_MOUNT_POINT = '/content/gdrive'
drive.mount(GDRIVE_MOUNT_POINT)

In [None]:
!pip3 install -U sentence-transformers



In [None]:
import torch
import pandas as pd
from tqdm import tqdm
import nltk
nltk.download('punkt')
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer, util
import pickle
import io
import os
import numpy
from sklearn.metrics import accuracy_score
import re
import time

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [5]:
class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that maps pickled torch storages onto the CPU while loading."""

    def find_class(self, module, name):
        """Resolve a pickled global, intercepting torch's storage loader.

        For ``torch.storage._load_from_bytes`` a replacement callable is
        returned that runs ``torch.load`` with ``map_location='cpu'``; every
        other lookup is delegated to ``pickle.Unpickler.find_class``.
        """
        wants_storage_loader = (module == 'torch.storage'
                                and name == '_load_from_bytes')
        if not wants_storage_loader:
            return super().find_class(module, name)

        def load_bytes_on_cpu(raw_bytes):
            return torch.load(io.BytesIO(raw_bytes), map_location='cpu')

        return load_bytes_on_cpu

class SBERT_WikiSimilarity:
    """Score Wikipedia page pairs by cosine similarity of SBERT page embeddings.

    On construction the instance loads the pair dataset together with its
    multilingual page CSV (columns ``title``, ``lang``, ``content``), obtains
    one embedding per page (either computed with the
    'distiluse-base-multilingual-cased' model or loaded from a pickle) and
    stores the per-language scores in ``self.results``.
    """

    def __init__(self, dataset_name=None, save=False, saved_embeddings_path=None):
        """Load the data, get the page embeddings and score every pair.

        Args:
            dataset_name: "wikipediaSimilarity353", "WikiSRS_relatedness" or
                "WikiSRS_similarity".
            save: when truthy and the embeddings are computed here, write them
                to 'multilingual_<dataset_name>_embeddings.csv' and to
                '<dataset_name>_embeddings.pkl'.
            saved_embeddings_path: path (without the '.pkl' suffix) of
                previously pickled embeddings; when given, no model is loaded
                and nothing is recomputed.

        Raises:
            ValueError: if ``dataset_name`` is not a supported dataset.
        """
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.multilingual_dataset, self.dataset = self.load_dataset(dataset_name)
        self.wiki_languages = ['en','ar','es','de','ja','fr','fa','pl']
        if saved_embeddings_path is None:
            self.model = SentenceTransformer('distiluse-base-multilingual-cased')
            self.pages_embeddings = self.compute_pages_embeddings()
            if save:
                self.save_embeddings_in_file(self.pages_embeddings, dataset_name)
                with open(dataset_name+'_embeddings.pkl', "wb") as fOut:
                    pickle.dump(self.pages_embeddings, fOut, protocol=pickle.HIGHEST_PROTOCOL)
        else:
            # NOTE(security): unpickling executes arbitrary code; only load
            # embedding files that you created yourself.
            if self.device == "cpu":
                # Embeddings pickled on a GPU machine hold CUDA storages;
                # CPU_Unpickler remaps them to the CPU while loading.
                with open(saved_embeddings_path+".pkl", "rb") as f:
                    self.pages_embeddings = CPU_Unpickler(f).load()
            else:
                self.pages_embeddings = pd.read_pickle(saved_embeddings_path+".pkl")

        self.results = self.get_similarity_scores()

    def load_dataset(self, dataset_name):
        """Read the multilingual page CSV and the pair dataset for ``dataset_name``.

        Returns:
            ``(multilingual_data, data)``: the page DataFrame and the pair
            DataFrame. For the WikiSRS datasets the columns are renamed to
            ``termA``/``titleA``/``termB``/``titleB``/``relatedness`` and
            ``relatedness`` is divided by 10.

        Raises:
            ValueError: if ``dataset_name`` is not a supported dataset.
        """
        wiki_srs_names = ("WikiSRS_relatedness", "WikiSRS_similarity")
        # Fail fast with a clear message instead of a TypeError (None) or an
        # UnboundLocalError (unknown name) further down.
        if dataset_name != "wikipediaSimilarity353" and dataset_name not in wiki_srs_names:
            raise ValueError(
                "Unsupported dataset_name {!r}; expected 'wikipediaSimilarity353', "
                "'WikiSRS_relatedness' or 'WikiSRS_similarity'".format(dataset_name))

        multilingual_data = pd.read_csv("/content/gdrive/MyDrive/Dataset/multilingual_"+dataset_name+".csv")
        if dataset_name == "wikipediaSimilarity353":
            data = pd.read_csv("wikipediaSimilarity353.csv")
            # no wikipedia page for 'Production, costs, and pricing'
            data['titleA'] = data['titleA'].replace(['Production, costs, and pricing'],'Production')
        else:
            data = pd.read_csv(dataset_name+".csv", sep='\t')
            data = data.drop(['RawScores', 'StdDev'], axis=1)
            data = data.rename(columns={'Term1':'termA', 'String1':'titleA',
                                        'Term2':'termB', 'String2':'titleB',
                                        'Mean' :'relatedness'})
            # get_similarity_scores divides 'relatedness' by 10 again for every
            # dataset; presumably WikiSRS means are on a 0-100 scale - confirm.
            data['relatedness'] = data['relatedness'].div(10)
        return multilingual_data, data

    def save_embeddings_in_file(self, dataset, dataset_name):
        """Write ``dataset`` as a DataFrame to 'multilingual_<dataset_name>_embeddings.csv'."""
        pd.DataFrame(dataset).to_csv('multilingual_'+dataset_name+'_embeddings.csv')

    def split_sentence(self, sentence, max_words=200):
        """Split ``sentence`` into chunks of at most ``max_words`` words.

        A sentence that already fits is returned unchanged as a one-element
        list. Longer sentences are split on whitespace and re-joined with
        single spaces, so every chunk (including the tail) respects the limit.

        Raises:
            ValueError: if ``max_words`` is smaller than 1.
        """
        if max_words < 1:
            raise ValueError("max_words must be at least 1")
        words = sentence.split()
        if len(words) <= max_words:
            return [sentence]
        return [' '.join(words[start:start + max_words])
                for start in range(0, len(words), max_words)]

    def compute_pages_embeddings(self):
        """Embed every page of the multilingual dataset.

        Each page is represented by the mean of the embeddings of its title
        and of its (length-limited) content sentences.

        Returns:
            A list of dicts with keys ``title``, ``lang`` and ``embedding``.
        """
        embeddings = []
        rows = self.multilingual_dataset.iterrows()
        for _, row in tqdm(rows, total=len(self.multilingual_dataset), desc="compute page's Embeddings: "):
            sentences = [row['title']]
            for sentence in sent_tokenize(row['content']):
                sentences.extend(self.split_sentence(sentence))
            # Encode once; the mean over all sentence vectors is the page vector.
            sentence_embeddings = self.model.encode(sentences, convert_to_tensor=True)
            embeddings.append({"title": row['title'], "lang": row['lang'],
                               "embedding": torch.mean(sentence_embeddings, dim=0)})
        return embeddings

    def get_similarity_scores(self):
        """Score every dataset pair in every language that has both pages.

        Returns:
            A list of dicts with keys ``titleA``, ``titleB``, ``lang``,
            ``predicted`` (cosine similarity mapped from [-1, 1] to [0, 1])
            and ``actual`` (the dataset's ``relatedness`` divided by 10).
            Pairs with a missing page in a language are skipped for that
            language.
        """
        # Index the embeddings once; setdefault keeps the first entry for a
        # (title, lang) key, matching a first-match linear search.
        embedding_by_page = {}
        for item in self.pages_embeddings:
            embedding_by_page.setdefault((item["title"], item["lang"]), item["embedding"])

        results = []
        for _, row in self.dataset.iterrows():
            title_a, title_b = row['titleA'], row['titleB']
            for lang in self.wiki_languages:
                embed_a = embedding_by_page.get((title_a, lang))
                embed_b = embedding_by_page.get((title_b, lang))
                if embed_a is None or embed_b is None:
                    continue
                cosine_score = util.pytorch_cos_sim(embed_a, embed_b).item()
                results.append({"titleA": title_a, "titleB": title_b, "lang": lang,
                                "predicted": (cosine_score+1) / 2.0,
                                "actual": row['relatedness']/10.0})
        return results

In [None]:
# First run only: uncomment to compute the embeddings from scratch and cache them on disk (save=True).
#M1 = SBERT_WikiSimilarity(dataset_name="wikipediaSimilarity353", save=True)  # 1208 wiki pages
#M2 = SBERT_WikiSimilarity(dataset_name="WikiSRS_similarity", save=True)
#M3 = SBERT_WikiSimilarity(dataset_name="WikiSRS_relatedness", save=True)

In [None]:
# Build one evaluator per dataset from its cached '<name>_embeddings.pkl'
# file, so no embeddings are recomputed here.
M1 = SBERT_WikiSimilarity(
    saved_embeddings_path="wikipediaSimilarity353_embeddings",
    dataset_name="wikipediaSimilarity353",
)
M2 = SBERT_WikiSimilarity(
    saved_embeddings_path="WikiSRS_similarity_embeddings",
    dataset_name="WikiSRS_similarity",
)
M3 = SBERT_WikiSimilarity(
    saved_embeddings_path="WikiSRS_relatedness_embeddings",
    dataset_name="WikiSRS_relatedness",
)

In [None]:
# Per-language predicted vs. actual scores for wikipediaSimilarity353 (first 50 rows).
res1 = pd.DataFrame(data=M1.results)
res1.head(n=50)

Unnamed: 0,titleA,titleB,lang,predicted,actual
0,Love,Sexual intercourse,en,0.714859,0.677
1,Love,Sexual intercourse,ar,0.759976,0.677
2,Love,Sexual intercourse,es,0.759873,0.677
3,Love,Sexual intercourse,de,0.770995,0.677
4,Love,Sexual intercourse,ja,0.645355,0.677
5,Love,Sexual intercourse,fr,0.779876,0.677
6,Love,Sexual intercourse,fa,0.727255,0.677
7,Love,Sexual intercourse,pl,0.810451,0.677
8,Tiger,Cat,en,0.820682,0.735
9,Tiger,Cat,ar,0.772624,0.735


In [None]:
# Per-language predicted vs. actual scores for WikiSRS_similarity (first 50 rows).
res2 = pd.DataFrame(data=M2.results)
res2.head(n=50)

Unnamed: 0,titleA,titleB,lang,predicted,actual
0,Ferrari,Lamborghini,en,0.776059,0.976
1,Ferrari,Lamborghini,ar,0.882602,0.976
2,Ferrari,Lamborghini,es,0.927886,0.976
3,Ferrari,Lamborghini,de,0.90672,0.976
4,Ferrari,Lamborghini,ja,0.82768,0.976
5,Ferrari,Lamborghini,fr,0.932222,0.976
6,Ferrari,Lamborghini,fa,0.886546,0.976
7,Ferrari,Lamborghini,pl,0.895754,0.976
8,River Kent,River Thames,en,0.854385,0.936
9,Ronald Reagan,Barack Obama,en,0.818703,0.935


In [None]:
# Per-language predicted vs. actual scores for WikiSRS_relatedness (first 50 rows).
res3 = pd.DataFrame(data=M3.results)
res3.head(n=50)

Unnamed: 0,titleA,titleB,lang,predicted,actual
0,Vladimir Putin,Moscow,en,0.73147,0.988333
1,Vladimir Putin,Moscow,ar,0.733321,0.988333
2,Vladimir Putin,Moscow,es,0.734599,0.988333
3,Vladimir Putin,Moscow,de,0.786782,0.988333
4,Vladimir Putin,Moscow,ja,0.676469,0.988333
5,Vladimir Putin,Moscow,fr,0.749125,0.988333
6,Vladimir Putin,Moscow,fa,0.712074,0.988333
7,Vladimir Putin,Moscow,pl,0.752692,0.988333
8,Asia,China,en,0.776775,0.974
9,Asia,China,ar,0.767197,0.974


In [None]:
# Spot-check one cached page vector: the Arabic "Ferrari" page in the WikiSRS_similarity set.
next(
    entry
    for entry in M2.pages_embeddings
    if entry["title"] == "Ferrari" and entry['lang'] == "ar"
)['embedding']

tensor([ 8.6102e-03, -1.2887e-02, -2.0471e-02,  4.8995e-04, -1.0507e-07,
        -1.9774e-02, -7.8557e-03,  1.6040e-02, -2.6895e-03, -1.3566e-02,
         1.3331e-04, -1.8660e-02, -2.0423e-02, -4.7293e-02, -9.4456e-03,
         1.9522e-02, -5.9723e-03,  6.7705e-05,  2.4307e-03,  1.3704e-02,
        -4.7827e-03, -1.0307e-02,  2.6249e-02,  1.3863e-02,  1.2934e-02,
         3.1949e-02,  1.1624e-02,  7.1191e-03, -8.6954e-03, -9.2462e-04,
         1.9326e-02, -3.7833e-03, -1.5200e-02, -1.7077e-02,  8.4244e-04,
         1.9544e-03,  1.5087e-04, -3.7390e-03,  7.9802e-03, -2.9853e-02,
        -1.2353e-02, -1.2938e-02,  3.0935e-02, -8.0540e-03,  2.2373e-02,
        -1.3837e-02,  8.5626e-03, -2.9976e-03, -1.6498e-02, -1.2389e-02,
         1.8140e-03,  2.6158e-02, -1.1838e-02,  6.4189e-05, -2.8256e-03,
         4.0803e-03, -1.6233e-02,  1.6566e-03,  2.5857e-02,  1.6329e-02,
        -2.0254e-02,  1.5187e-03,  1.8141e-02, -1.6628e-03, -1.0292e-02,
        -2.2492e-02, -2.7215e-02,  5.0040e-03, -2.2