# Investigating general sentence embedding models for semantic textual similarity

# Setup

In [None]:
# get datasets
%cd /tmp
!gdown https://drive.google.com/uc?id=1J7Ptu7aKUw0IfjHAKjj--CUFktdaQ7qf
!unzip 'datasets.zip'

/tmp
Downloading...
From: https://drive.google.com/uc?id=1J7Ptu7aKUw0IfjHAKjj--CUFktdaQ7qf
To: /tmp/datasets.zip
4.64MB [00:00, 21.7MB/s]
Archive:  datasets.zip
   creating: datasets/
   creating: datasets/dscs/
  inflating: datasets/dscs/dscs_dataset.tsv  
   creating: datasets/PIT/
  inflating: datasets/PIT/test.label  
  inflating: datasets/PIT/test.data  
   creating: datasets/biosses/
  inflating: datasets/biosses/scores_f.txt  
  inflating: datasets/biosses/pairs_f.txt  
   creating: datasets/opusparcus/
  inflating: datasets/opusparcus/opusparcus.txt  
   creating: datasets/parade/
  inflating: datasets/parade/test.txt  
   creating: datasets/TURL/
  inflating: datasets/TURL/turl_test.txt  
   creating: datasets/MRPC/
  inflating: datasets/MRPC/msr_paraphrase_test.txt  
  inflating: datasets/MRPC/msr_paraphrase_train.txt  
   creating: datasets/paws/
   creating: datasets/paws/wiki/
  inflating: datasets/paws/wiki/test.tsv  
   creating: datasets/SentEval/
   creating: datasets/

In [None]:
# install packages
!pip install git+https://github.com/facebookresearch/SentEval
!pip install sentence_transformers

Collecting git+https://github.com/facebookresearch/SentEval
  Cloning https://github.com/facebookresearch/SentEval to ./pip-req-build-r25970i2
  Running command git clone -q https://github.com/facebookresearch/SentEval /tmp/pip-req-build-r25970i2
Building wheels for collected packages: SentEval
  Building wheel for SentEval (setup.py) ... [?25l[?25hdone
  Created wheel for SentEval: filename=SentEval-0.1.0-py3-none-any.whl size=34996 sha256=be76c908af1573def6b14e55c7ae16ddb235b97ac43d974919df613f714815c6
  Stored in directory: /tmp/pip-ephem-wheel-cache-5p21vks_/wheels/f2/30/24/158bcbeb1361691b11f52434aec28432627e4a6ae1dd00dfb7
Successfully built SentEval
Installing collected packages: SentEval
Successfully installed SentEval-0.1.0
Collecting sentence_transformers
  Downloading sentence-transformers-2.0.0.tar.gz (85 kB)
[K     |████████████████████████████████| 85 kB 3.5 MB/s 
[?25hCollecting transformers<5.0.0,>=4.6.0
  Downloading transformers-4.9.2-py3-none-any.whl (2.6 MB)
[K 

In [None]:
# import packages
import senteval
from sentence_transformers import SentenceTransformer

from absl import logging
import tensorflow as tf
import tensorflow_hub as hub
from scipy.stats import pearsonr
from sklearn.metrics import accuracy_score, average_precision_score
from scipy import spatial
from matplotlib import pyplot as plt
import numpy as np
from abc import ABC, abstractmethod

# Reduce logging output.
logging.set_verbosity(logging.ERROR)

# Common Code

In [None]:
#@title Generics

class AbstractModel(ABC):
  """Common interface for the sentence-embedding backends used in this notebook.

  Subclasses load a concrete model into ``self.model`` in ``__init__`` and
  implement ``encode`` to embed a batch of sentences.
  """

  # The loaded backend model; assigned by each subclass's __init__.
  model = None

  @abstractmethod
  def encode(self, batch):
    """Embed a list of sentences.

    Callers index the result positionally, so it must hold one embedding per
    input sentence, in input order.
    """


class SentenceTransformerWrapper(AbstractModel):
  """Adapter that exposes a sentence-transformers model via AbstractModel."""

  def __init__(self, model_name: str):
    # model_name is handed unchanged to the SentenceTransformer constructor.
    loaded = SentenceTransformer(model_name)
    self.model = loaded

  def encode(self, batch):
    """Embed ``batch`` using the wrapped model's own ``encode`` method."""
    backend = self.model
    return backend.encode(batch)


class USEWrapper(AbstractModel):
  """Adapter that exposes a TensorFlow Hub sentence encoder via AbstractModel."""

  def __init__(self, model_name: str):
    # model_name is a TF Hub handle/URL passed unchanged to hub.load.
    loaded = hub.load(model_name)
    self.model = loaded

  def encode(self, batch):
    """Run the hub model on ``batch`` and return the output as a NumPy array."""
    raw_output = self.model(batch)
    return np.array(raw_output)


class ModelFactory:
  """Chooses the AbstractModel wrapper that matches a model identifier."""

  @staticmethod
  def create(model_name):
    """Build a wrapper for ``model_name``.

    Identifiers containing 'universal-sentence-encoder' get a USEWrapper;
    everything else is treated as a sentence-transformers model name.
    """
    is_hub_encoder = 'universal-sentence-encoder' in model_name
    wrapper_cls = USEWrapper if is_hub_encoder else SentenceTransformerWrapper
    return wrapper_cls(model_name)


In [None]:
#@title Utility functions

"""
Note that `spatial.distance.cosine` computes the distance between two vectors, 
not the similarity, so you must subtract the value from 1 to get the similarity.
"""
def cosine_similarity(first_vector, second_vector):
  return 1 - spatial.distance.cosine(first_vector, second_vector)


def batch_generator(data, batch_size):
  """Yield successive slices of ``data`` holding at most ``batch_size`` items.

  The last slice is shorter when len(data) is not a multiple of batch_size.
  """
  offsets = range(0, len(data), batch_size)
  for offset in offsets:
    end = offset + batch_size
    yield data[offset:end]


def read_data(params):
  """Read a tab-separated sentence-pair dataset.

  Args:
    params: dict with keys
      'path_to_data': path of the tab-separated file;
      'remove_header': if True, the first line is dropped;
      'label_index': column holding the numeric label;
      'label_type': 'int' casts labels to int, anything else keeps floats;
      'first_sentence_index' / 'second_sentence_index': columns of the pair.

  Returns:
    (sentences, labels): sentences is a list of [first, second] stripped
    strings; labels is the parallel list of labels.
  """
  # Context manager so the file handle is always closed.
  with open(params['path_to_data'], 'r', encoding='utf-8') as tsv_file:
    lines = tsv_file.readlines()

  if params['remove_header']:
    lines.pop(0)
  # Skip whitespace-only lines (e.g. a trailing empty line): they have no
  # columns to parse and would otherwise crash the label conversion.
  rows = [line.split('\t') for line in lines if line.strip()]

  labels = [float(row[params['label_index']].strip()) for row in rows]
  if params['label_type'] == 'int':
    labels = [int(label) for label in labels]

  sentences = [
      [row[params['first_sentence_index']].strip(),
       row[params['second_sentence_index']].strip()]
      for row in rows
  ]

  return sentences, labels

In [None]:
#@title Evaluate functions

def calculate_similarity(model_name, sentences, batch_size):
  """Score sentence pairs by the cosine similarity of their embeddings.

  Args:
    model_name: identifier passed to ModelFactory.create.
    sentences: list of [first, second] sentence pairs.
    batch_size: number of pairs encoded per encoder call.

  Returns:
    List with one cosine similarity per pair, in input order.
  """
  encoder = ModelFactory.create(model_name)
  scores = []

  for chunk in batch_generator(sentences, batch_size):
    left_texts = [pair[0] for pair in chunk]
    right_texts = [pair[1] for pair in chunk]

    left_vectors = encoder.encode(left_texts)
    right_vectors = encoder.encode(right_texts)

    scores.extend(
        cosine_similarity(left, right)
        for left, right in zip(left_vectors, right_vectors))

  return scores


def concatenate_embeddings(embeds, first_sentences, second_sentences):
    """Embed both sides of a batch with every model and join the results.

    Args:
      embeds: non-empty list of AbstractModel-like objects with ``encode``.
      first_sentences: list of first sentences of each pair.
      second_sentences: list of second sentences of each pair.

    Returns:
      (first_embeddings, second_embeddings). With one model these are that
      model's outputs unchanged; with several, the per-model outputs are
      concatenated along axis 1 in the order of ``embeds``.

    Raises:
      ValueError: if ``embeds`` is empty.
    """
    if not embeds:
        raise ValueError('concatenate_embeddings needs at least one model')

    first_parts = []
    second_parts = []
    for embed in embeds:
        first_parts.append(embed.encode(first_sentences))
        second_parts.append(embed.encode(second_sentences))

    if len(first_parts) == 1:
        # Single model: pass its embeddings through untouched.
        return first_parts[0], second_parts[0]

    # One concatenate over all parts instead of re-copying a growing matrix
    # once per additional model.
    first_sent_embeddings = np.concatenate(first_parts, axis=1)
    second_sent_embeddings = np.concatenate(second_parts, axis=1)
    return first_sent_embeddings, second_sent_embeddings


def calculate_similarity_ansamble(models, sentences, batch_size):
  """Score sentence pairs using the concatenated embeddings of several models.

  Args:
    models: list of model identifiers, each passed to ModelFactory.create.
    sentences: list of [first, second] sentence pairs.
    batch_size: number of pairs encoded per batch.

  Returns:
    List with one cosine similarity per pair, in input order.
  """
  encoders = list(map(ModelFactory.create, models))
  scores = []

  for chunk in batch_generator(sentences, batch_size):
    left_texts = [pair[0] for pair in chunk]
    right_texts = [pair[1] for pair in chunk]

    left_matrix, right_matrix = concatenate_embeddings(encoders, left_texts, right_texts)

    for left, right in zip(left_matrix, right_matrix):
      scores.append(cosine_similarity(left, right))

  return scores


def get_predictions(model, sentences, batch_size, is_ansamble):
  """Return per-pair similarity predictions from a single model or an ensemble.

  ``model`` is a list of model identifiers when ``is_ansamble`` is truthy,
  otherwise a single identifier.
  """
  scorer = calculate_similarity_ansamble if is_ansamble else calculate_similarity
  return scorer(model, sentences, batch_size)


def get_performance(true_labels, predictions, metric, threshold=0.5):
  """Score predicted similarities against gold labels.

  Args:
    true_labels: gold labels, parallel to ``predictions``.
    predictions: predicted cosine similarities.
    metric: 'pearson', 'average_precision' or 'accuracy'.
    threshold: for 'accuracy' only; similarities >= threshold are predicted
      as class 1, everything else as class 0.

  Returns:
    The metric value.

  Raises:
    ValueError: if ``metric`` is not one of the supported names.
  """
  if metric == 'pearson':
    corr, _ = pearsonr(true_labels, predictions)
    return corr
  if metric == 'average_precision':
    return average_precision_score(true_labels, predictions)
  if metric == 'accuracy':
    # An explicit threshold instead of round(): round() maps negative cosine
    # similarities to -1 (never a valid label) and rounds half to even, so
    # round(0.5) == 0.
    binary = [1 if prediction >= threshold else 0 for prediction in predictions]
    return accuracy_score(true_labels, binary)
  raise ValueError(f'Unknown metric: {metric!r}')


def evaluate(model, params, is_ansamble):
  """Evaluate ``model`` on the tab-separated dataset described by ``params``.

  ``params`` holds 'data' (options for read_data), 'batch_size' and
  'metric'. Returns the metric value rounded to 4 decimal places.
  """
  sentences, labels = read_data(params['data'])
  batch_size = params['batch_size']
  predictions = get_predictions(model, sentences, batch_size, is_ansamble)
  score = get_performance(labels, predictions, params['metric'])
  return round(score, 4)

# Evaluate Models against Datasets

## Prepare Datasets

In [None]:
#@title Dataset Params

# Per-dataset settings consumed by evaluate():
#   'data'       -> options for read_data(): file path, label column and
#                   type, the two sentence columns, and whether to drop the
#                   first (header) line.
#   'batch_size' -> number of sentence pairs encoded per encoder call.
#   'metric'     -> 'pearson', 'average_precision' or 'accuracy'.

# STS Benchmark test file: float labels scored with Pearson correlation.
stsb_params = {
    'data': {
      'path_to_data': '/tmp/datasets/SentEval/downstream/STS/STSBenchmark/sts-test.csv', 
      'label_type': 'float',
      'label_index': 4, 
      'first_sentence_index': 5, 
      'second_sentence_index': 6,
      'remove_header': False
    },
    'batch_size': 512,
    'metric': 'pearson'
}

# MRPC test file: integer labels in column 0, header line skipped, accuracy.
mrpc_params = {
    'data': {
      'path_to_data': '/tmp/datasets/MRPC/msr_paraphrase_test.txt', 
      'label_type': 'int',
      'label_index': 0, 
      'first_sentence_index': 3, 
      'second_sentence_index': 4,
      'remove_header': True
    },
    'batch_size': 512,
    'metric': 'accuracy'
}

# DSCS: sentences in columns 0-1, float label in column 2, Pearson.
dscs_params = {
    'data': {
      'path_to_data': '/tmp/datasets/dscs/dscs_dataset.tsv', 
      'label_type': 'float',
      'label_index': 2, 
      'first_sentence_index': 0, 
      'second_sentence_index': 1,
      'remove_header': False
    },
    'batch_size': 512,
    'metric': 'pearson'
}

# PAWS (wiki) test file: integer labels, header line skipped, accuracy.
paws_params = {
    'data': {
      'path_to_data': '/tmp/datasets/paws/wiki/test.tsv', 
      'label_type': 'int',
      'label_index': 3, 
      'first_sentence_index': 1, 
      'second_sentence_index': 2,
      'remove_header': True
    },
    'batch_size': 512,
    'metric': 'accuracy'
}

# PARADE test file: integer labels, header line skipped, accuracy.
parade_params = {
    'data': {
      'path_to_data': '/tmp/datasets/parade/test.txt', 
      'label_type': 'int',
      'label_index': 1, 
      'first_sentence_index': 3, 
      'second_sentence_index': 4,
      'remove_header': True
    },
    'batch_size': 512,
    'metric': 'accuracy'
}

# Opusparcus: integer labels; evaluate_opusparcus() maps labels >= 3 to 1
# and the rest to 0 before computing accuracy.
opusparcus_params = {
    'data': {
      'path_to_data': '/tmp/datasets/opusparcus/opusparcus.txt', 
      'label_type': 'int',
      'label_index': 3, 
      'first_sentence_index': 1, 
      'second_sentence_index': 2,
      'remove_header': False
    },
    'batch_size': 512,
    'metric': 'accuracy'
}

# PIT test data: integer labels, but scored with Pearson correlation rather
# than accuracy.
pit_params = {
    'data': {
      'path_to_data': '/tmp/datasets/PIT/test.data', 
      'label_type': 'int',
      'label_index': 4, 
      'first_sentence_index': 2, 
      'second_sentence_index': 3,
      'remove_header': False
    },
    'batch_size': 512,
    'metric': 'pearson'
}

In [None]:
#@title Prepare SentEval

def evaluate_senteval(model, is_ansamble):
    """Run the SentEval STS12-16 and SICK-Relatedness tasks on a model.

    Args:
      model: a model identifier, or a list of identifiers when is_ansamble.
      is_ansamble: if True, the embeddings of every model in ``model`` are
        concatenated along the feature axis.

    Returns:
      The result of ``senteval.engine.SE.eval`` for the task list.
    """

    # SentEval prepare and batcher
    def prepare(params, samples):
        # Nothing to prepare: the encoders are used as-is.
        return

    def batcher(params, batch):
        # Sentences arrive as token lists; rejoin them with spaces and use '.'
        # for empty sentences so the encoder never gets an empty string.
        batch = [' '.join(sent) if sent != [] else '.' for sent in batch]
        if is_ansamble:
            embeddings = get_embeddings_conc(params['model'], batch)
        else:
            embeddings = params['model'].encode(batch)
        return embeddings

    def get_embeddings_conc(encoders, batch):
        # Concatenate one embedding block per encoder in THIS ensemble. (The
        # size must come from `encoders`, not from any outer/global list such
        # as the notebook's `models` config.)
        embeddings_list = [embed.encode(batch) for embed in encoders]
        if len(embeddings_list) == 1:
            return embeddings_list[0]
        return np.concatenate(embeddings_list, axis=1)

    if is_ansamble:
        encoder = [ModelFactory.create(model_name) for model_name in model]
    else:
        encoder = ModelFactory.create(model)

    # Set params for SentEval
    params_senteval = {
        'task_path': '/tmp/datasets/SentEval', 
        'batch_size': 512,
        'model': encoder }

    se = senteval.engine.SE(params_senteval, batcher, prepare)
    transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'SICKRelatedness']
    return se.eval(transfer_tasks)

In [None]:
#@title Prepare Biosses

def read_biosses(pairs_path='/tmp/datasets/biosses/pairs_f.txt',
                 scores_path='/tmp/datasets/biosses/scores_f.txt'):
    """Read the BIOSSES sentence pairs and their similarity labels.

    Both files are tab-separated and line-aligned: line i of ``scores_path``
    scores line i of ``pairs_path``. In both files the first column is
    skipped (presumably a pair id -- TODO confirm).

    Args:
      pairs_path: file whose columns 1 and 2 hold the two sentences.
      scores_path: file whose remaining columns hold integer scores.

    Returns:
      (sentences, labels): list of [first, second] pairs and, for each pair,
      the mean of its integer scores rounded to one decimal.
    """
    # Context managers close both files; blank lines (e.g. a trailing empty
    # line) are skipped so they cannot break parsing or the alignment.
    with open(scores_path, 'r', encoding='utf-8') as scores_f:
        score_rows = [line.strip().split('\t') for line in scores_f if line.strip()]
    with open(pairs_path, 'r', encoding='utf-8') as pairs_f:
        pair_rows = [line.strip().split('\t') for line in pairs_f if line.strip()]

    scores = [[int(value) for value in row[1:]] for row in score_rows]
    labels = [round(sum(row) / len(row), 1) for row in scores]

    sentences = [[row[1], row[2]] for row in pair_rows]

    return sentences, labels


def evaluate_biosses(model, is_ansamble):
    """Return the Pearson correlation on BIOSSES, rounded to 4 decimals."""
    batch_size, metric = 512, 'pearson'

    sentences, labels = read_biosses()
    predictions = get_predictions(model, sentences, batch_size, is_ansamble)
    score = get_performance(labels, predictions, metric)
    return round(score, 4)

In [None]:
#@title Prepare TURL

def fix_labels(sentences, labels):
    """Binarize TURL-style labels, dropping pairs whose label is exactly 3.

    Labels <= 2 become 0 and labels > 3 become 1; the sentence list is
    filtered in step so both outputs stay aligned.
    """
    kept_pairs = []
    binary_labels = []

    for position, score in enumerate(labels):
        if score == 3:
            continue  # middle value: excluded from the binary task
        binary_labels.append(0 if score <= 2 else 1)
        kept_pairs.append(sentences[position])

    return kept_pairs, binary_labels


def evaluate_turl(model, is_ansamble):
    """Return the average precision on the TURL test file, rounded to 4 decimals.

    Column 2 of each line is parsed by taking the text before the first comma
    and dropping its first character (presumably a value like "(n, m)" --
    TODO confirm). The resulting integer is binarized by fix_labels().
    """
    params = {
      'batch_size': 512,
      'metric': 'average_precision'
    }
    # Context manager closes the file; blank lines are skipped so a trailing
    # empty line cannot raise IndexError.
    with open('/tmp/datasets/TURL/turl_test.txt', 'r', encoding='utf-8') as tsv_file:
        lines = [line.strip().split('\t') for line in tsv_file if line.strip()]
    temp_labels = [int(line[2].split(',')[0][1:]) for line in lines]
    temp_sentences = [[line[0].strip(), line[1].strip()] for line in lines]

    sentences, labels = fix_labels(temp_sentences, temp_labels)
    predictions = get_predictions(model, sentences, params['batch_size'], is_ansamble)
    return round(get_performance(labels, predictions, params['metric']), 4)

In [None]:
#@title Prepare Opusparcus

def evaluate_opusparcus(model, is_ansamble):
    """Return the Opusparcus score, rounded to 4 decimals.

    Labels >= 3 count as positive (1), all others as negative (0).
    """
    sentences, raw_labels = read_data(opusparcus_params['data'])
    labels = [int(raw >= 3) for raw in raw_labels]
    batch_size = opusparcus_params['batch_size']
    predictions = get_predictions(model, sentences, batch_size, is_ansamble)
    score = get_performance(labels, predictions, opusparcus_params['metric'])
    return round(score, 4)


## Explanation

In [None]:
#@markdown ####sentence transformer models: https://www.sbert.net/docs/pretrained_models.html#sentence-embedding-models
#@markdown ####universal sentence encoder model: https://tfhub.dev/google/universal-sentence-encoder-large/5
#@markdown ---
#@markdown ####use = universal sentence encoder large v5 
#@markdown ####nli = nli-mpnet-base-v2
#@markdown ####para = paraphrase-mpnet-base-v2
#@markdown ####stsb = stsb-mpnet-base-v2
#@markdown ####quora = quora-distilbert-base
#@markdown ---
#@markdown ####ansamble_1 = nli + para + stsb
#@markdown ####ansamble_2 = use + para + stsb
#@markdown ####ansamble_3 = use + quora + stsb
#@markdown ####ansamble_4 = para + quora + stsb
#@markdown ####ansamble_5 = nli + para + use
#@markdown ####ansamble_6 = all single models
#@markdown ---
#@markdown ####SentEval = SemEval(STS12-16) + SICK-R
#@markdown #####SentEval results are printed as a nested dictionary; a JSON viewer such as https://jsonparseronline.com/ makes them easier to read.
#@markdown ---

## Evaluate

In [None]:
#@markdown ###Choose models:
use = False #@param {type:"boolean"}
nli = True #@param {type:"boolean"}
para = False #@param {type:"boolean"}
stsb = False #@param {type:"boolean"}
quora = False #@param {type:"boolean"}
ansamble_1 = False #@param {type:"boolean"}
ansamble_2 = False #@param {type:"boolean"}
ansamble_3 = False #@param {type:"boolean"}
ansamble_4 = False #@param {type:"boolean"}
ansamble_5 = False #@param {type:"boolean"}
ansamble_6 = False #@param {type:"boolean"}
#@markdown ---
#@markdown ###Choose datasets:
SentEval = False #@param {type:"boolean"}
STSb = False #@param {type:"boolean"}
MRPC = False #@param {type:"boolean"}
DSCS = False #@param {type:"boolean"}
Biosses = False #@param {type:"boolean"}
PAWS = False #@param {type:"boolean"}
PARADE = True #@param {type:"boolean"}
Pit = False #@param {type:"boolean"}
TURL = False #@param {type:"boolean"}
opusparcus = False #@param {type:"boolean"}


def get_config():
    """Collect the models and datasets selected by the form checkboxes.

    Returns:
      (models, datasets): ``models`` is a list of [alias, identifier] entries,
      where identifier is a string for one model or a list of strings for an
      ensemble. ``datasets`` is a list of [name] entries (datasets with a
      dedicated evaluator) or [name, params] entries (generic TSV datasets).
    """
    use_url = 'https://tfhub.dev/google/universal-sentence-encoder-large/5'
    nli_name = 'nli-mpnet-base-v2'
    para_name = 'paraphrase-mpnet-base-v2'
    stsb_name = 'stsb-mpnet-base-v2'
    quora_name = 'quora-distilbert-base'

    # (checkbox value, entry to include when it is checked), in display order.
    model_options = [
        (use, ['use', use_url]),
        (nli, ['nli', nli_name]),
        (para, ['para', para_name]),
        (stsb, ['stsb', stsb_name]),
        (quora, ['quora', quora_name]),
        (ansamble_1, ['ansamble_1', [nli_name, para_name, stsb_name]]),
        (ansamble_2, ['ansamble_2', [para_name, stsb_name, use_url]]),
        (ansamble_3, ['ansamble_3', [quora_name, stsb_name, use_url]]),
        (ansamble_4, ['ansamble_4', [quora_name, stsb_name, para_name]]),
        (ansamble_5, ['ansamble_5', [nli_name, use_url, para_name]]),
        (ansamble_6, ['ansamble_6', [nli_name, use_url, para_name, quora_name, stsb_name]]),
    ]

    dataset_options = [
        (SentEval, ['SentEval']),
        (STSb, ['STSb', stsb_params]),
        (MRPC, ['MRPC', mrpc_params]),
        (DSCS, ['DSCS', dscs_params]),
        (Biosses, ['Biosses']),
        (PAWS, ['PAWS', paws_params]),
        (PARADE, ['PARADE', parade_params]),
        (Pit, ['Pit', pit_params]),
        (TURL, ['TURL']),
        (opusparcus, ['opusparcus']),
    ]

    models = [entry for enabled, entry in model_options if enabled]
    datasets = [entry for enabled, entry in dataset_options if enabled]

    return models, datasets


In [None]:
#@title Evaluate!

def pretty_print_model(model_name):
    """Print an '=' divider line, then ``*** <model_name> ***``."""
    divider = '======================================='
    header = f'*** {model_name} ***'
    print(divider)
    print(header)


def pretty_print_dataset(dataset_name, result):
    """Print a '-' divider line, then ``<dataset_name> >> <result>``."""
    divider = '---------------------------------------'
    line = f'{dataset_name} >> {result}'
    print(divider)
    print(line)


def print_results(alias, model_name, is_ansamble):
    """Print one model's (or ensemble's) score on every selected dataset.

    Iterates the module-level ``datasets`` list built by get_config():
    [name, params] entries go through evaluate(); [name] entries use their
    dedicated evaluator, and unrecognised names are skipped.
    """
    pretty_print_model(alias)
    dedicated_evaluators = {
        'Biosses': evaluate_biosses,
        'TURL': evaluate_turl,
        'SentEval': evaluate_senteval,
        'opusparcus': evaluate_opusparcus,
    }
    for dataset in datasets:
        if len(dataset) == 2:
            result = evaluate(model_name, dataset[1], is_ansamble)
        elif dataset[0] in dedicated_evaluators:
            result = dedicated_evaluators[dataset[0]](model_name, is_ansamble)
        else:
            continue
        pretty_print_dataset(dataset[0], result)

# Evaluate every selected model on every selected dataset. `datasets` is
# read as a module-level name by print_results().
models, datasets = get_config()
if not models or not datasets:
  print('You must choose at least one model and one dataset!')
else:
  for model in models:
    alias, model_name = model[0], model[1]
    # A plain string names one model; a list names an ensemble.
    ensemble = not isinstance(model_name, str)
    print_results(alias, model_name, is_ansamble=ensemble)

*** use ***
---------------------------------------
PARADE >> 0.7237
*** nli ***
---------------------------------------
PARADE >> 0.6013
*** para ***


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=690.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=3699.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=594.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=122.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=229.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=438022897.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=53.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=239.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466166.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1193.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231536.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=190.0, style=ProgressStyle(description_…


---------------------------------------
PARADE >> 0.6352
*** stsb ***


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=868.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=3668.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=588.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=122.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=229.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=438022897.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=52.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=239.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466166.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1187.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231536.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=190.0, style=ProgressStyle(description_…


---------------------------------------
PARADE >> 0.6662
*** quora ***


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=345.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=3689.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=540.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=122.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=229.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=265486777.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=53.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=112.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466081.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=490.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=190.0, style=ProgressStyle(description_…


---------------------------------------
PARADE >> 0.5733
*** ansamble_1 ***
---------------------------------------
PARADE >> 0.6279
*** ansamble_2 ***
---------------------------------------
PARADE >> 0.6581
*** ansamble_3 ***
---------------------------------------
PARADE >> 0.5859
*** ansamble_4 ***
---------------------------------------
PARADE >> 0.5881
*** ansamble_5 ***
---------------------------------------
PARADE >> 0.6197
*** ansamble_6 ***
---------------------------------------
PARADE >> 0.5932


# Demo

In [None]:
#@title Demo time!
model_name = "nli" #@param ["use", "nli", "para", "stsb", "quora", "ansamble_1", "ansamble_2", "ansamble_3", "ansamble_4", "ansamble_5", "ansamble_6"]
sentence1 = "I do not enjoy to play basketball" #@param {type:"string"}
sentence2 = "basketball is my favourite game" #@param {type:"string"}

# Map each selectable alias to a model identifier (string) or an ensemble
# (list of identifiers). An unknown alias raises instead of silently reusing
# whatever `model` was bound to earlier in the notebook.
use_url = 'https://tfhub.dev/google/universal-sentence-encoder-large/5'
demo_models = {
    'use': use_url,
    'nli': 'nli-mpnet-base-v2',
    'para': 'paraphrase-mpnet-base-v2',
    'stsb': 'stsb-mpnet-base-v2',
    'quora': 'quora-distilbert-base',
    'ansamble_1': ['nli-mpnet-base-v2', 'paraphrase-mpnet-base-v2', 'stsb-mpnet-base-v2'],
    'ansamble_2': ['paraphrase-mpnet-base-v2', 'stsb-mpnet-base-v2', use_url],
    'ansamble_3': ['quora-distilbert-base', 'stsb-mpnet-base-v2', use_url],
    'ansamble_4': ['quora-distilbert-base', 'stsb-mpnet-base-v2', 'paraphrase-mpnet-base-v2'],
    'ansamble_5': ['nli-mpnet-base-v2', use_url, 'paraphrase-mpnet-base-v2'],
    'ansamble_6': ['nli-mpnet-base-v2', use_url, 'paraphrase-mpnet-base-v2', 'quora-distilbert-base', 'stsb-mpnet-base-v2'],
}
if model_name not in demo_models:
    raise ValueError(f'Unknown model alias: {model_name!r}')
model = demo_models[model_name]

is_ansamble = 'ansamble' in model_name

predictions = get_predictions(model, [[sentence1, sentence2]], 512, is_ansamble)
# The single prediction is a cosine similarity; display it as a percentage.
prediction = round(predictions[0]*100, 2)
print('Similarity >> ', prediction, '%')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=690.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=3663.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=587.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=122.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=229.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=438022897.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=52.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=239.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466166.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1186.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231536.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=190.0, style=ProgressStyle(description_…


Similarity >>  43.55 %
