In [None]:
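"""
Settings you are expected to tune in this notebook:
1. Selection of the cross-encoder model: `cross_encoder_model_name`
2. "n" value: number of sentence pairs generated from the unlabeled dataset
3. "k" value: number of top-scoring pairs to keep (used by `select_top_k`)
"""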

from torch.utils.data import DataLoader
from sentence_transformers import models, losses, util
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator
from sentence_transformers import LoggingHandler, SentenceTransformer
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.readers import InputExample
from elasticsearch import Elasticsearch
from datetime import datetime
import logging
import csv
import sys
import tqdm
import math
import gzip
import os
import random
import pandas as pd
import numpy


def get_random_n_pairs(List, n):
    """Build ``n`` randomly drawn pairs of items from ``List``.

    Every pair comes from a single ``random.sample(List, 2)`` call, so its two
    members are taken from two different positions of ``List``. Pairs are
    drawn independently of each other, so the same pair may appear twice.

    Args:
        List: sequence to draw from; must hold at least two items, otherwise
            ``random.sample`` raises ``ValueError``.
        n: number of pairs to generate.

    Returns:
        A list of ``n`` two-element lists whose items were converted with ``str``.
    """
    return [
        [str(member) for member in random.sample(List, 2)]
        for _ in range(n)
    ]


def select_top_k(silver_data, silver_scores, k):
    """Keep the ``k`` highest-scoring pairs and return them with their scores.

    Each pair is copied with its score appended, the copies are sorted by the
    element at index 2 in descending order, and the first ``k`` are kept. The
    score is then popped back off each kept copy. The input lists are not
    modified.

    Args:
        silver_data: list of pairs (lists); index 2 of the extended copy is
            used as the sort key, which is the score when pairs have exactly
            two elements.
        silver_scores: one score per pair, in the same order as ``silver_data``.
        k: number of pairs to keep.

    Returns:
        Tuple ``(pairs, scores)``: the kept pairs as lists and a numpy array
        of the matching scores, both ordered from highest to lowest score.
    """
    scored_pairs = []
    for pair, score in zip(silver_data, silver_scores):
        scored_pairs.append(pair + [score])

    # list.sort is stable, like sorted(), so ties keep their input order.
    scored_pairs.sort(key=lambda entry: entry[2], reverse=True)
    best_pairs = scored_pairs[:k]

    best_scores = []
    for entry in best_pairs:
        best_scores.append(entry.pop(2))

    return best_pairs, numpy.array(best_scores)


#### Route log records (timestamp + message, INFO and above) through the
#### sentence-transformers LoggingHandler
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
#### /logging setup

# Suppress everything below CRITICAL from the 'elasticsearch' logger
tracer = logging.getLogger('elasticsearch')
tracer.setLevel(logging.CRITICAL)
# NOTE(review): `es` is not referenced by any other cell visible in this notebook — confirm it is still needed
es = Elasticsearch()

# You can specify any huggingface/transformers pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base
cross_encoder_model_name = 'bert-base-uncased'
bi_encoder_model_name = 'sentence-transformers/paraphrase-distilroberta-base-v1'

# Training hyper-parameters shared by the cross-encoder and bi-encoder steps
batch_size = 16
num_epochs = 1
max_seq_length = 128

###### Read Datasets ######

# Path to the gold (labeled) dataset TSV.
# NOTE(review): no existence check or download is performed here; the file must already be present.
Gold_dataset_path = 'gold_dataset.tsv'

# Output directories for the trained cross-encoder and bi-encoder.
# "/" in the model name is replaced by "-" and a timestamp is appended; each line
# calls datetime.now() on its own, so the two timestamps are taken separately.
cross_encoder_path = 'model/cross-encoder/'+cross_encoder_model_name.replace("/", "-")+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
bi_encoder_path = 'model/bi-encoder/'+bi_encoder_model_name.replace("/", "-")+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

###### Cross-encoder (sentence-transformers CrossEncoder) ######
logging.info("Loading sentence-transformers model: {}".format(cross_encoder_model_name))
# Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for cross-encoder model;
# num_labels=1 requests a single output score per sentence pair
cross_encoder = CrossEncoder(cross_encoder_model_name, num_labels=1)


###### Bi-encoder (sentence-transformers) ######
logging.info("Loading bi-encoder model: {}".format(bi_encoder_model_name))
# Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings
word_embedding_model = models.Transformer(bi_encoder_model_name, max_seq_length=max_seq_length)

# Pooling layer that turns token embeddings into one fixed-size sentence vector
# (default pooling mode; mean pooling according to the original author's note)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())

# Bi-encoder = transformer followed by the pooling layer
bi_encoder = SentenceTransformer(modules=[word_embedding_model, pooling_model])

In [None]:
#####################################################################
#
# Step 1: Train cross-encoder model with Gold dataset
#
#####################################################################

logging.info("Step 1: Train cross-encoder: ({}) with Gold dataset".format(cross_encoder_model_name))

# Gold rows are routed into these lists by their 'split' column
gold_samples = []
dev_samples = []
test_samples = []

# The TSV is expected to provide the columns 'sentence1', 'sentence2', 'score' and 'split'
with open(Gold_dataset_path, 'rt', encoding='utf8') as fIn:
    reader = csv.DictReader(fIn, delimiter='\t')  # the quoting=csv.QUOTE_NONE option is commented out
    for row in reader:
        score = float(row['score'])   # Cast only; no rescaling happens here. NOTE(review): presumably the gold scores are already in [0, 1] — confirm

        if row['split'] == 'dev':
            dev_samples.append(InputExample(texts=[row['sentence1'], row['sentence2']], label=score))
        elif row['split'] == 'test':
            test_samples.append(InputExample(texts=[row['sentence1'], row['sentence2']], label=score))
        else:
            # Every other split value is treated as training data.
            # As we want to get symmetric scores, i.e. CrossEncoder(A,B) = CrossEncoder(B,A), we pass both combinations to the train set
            gold_samples.append(InputExample(texts=[row['sentence1'], row['sentence2']], label=score))
            gold_samples.append(InputExample(texts=[row['sentence2'], row['sentence1']], label=score))


# We wrap gold_samples (which is a List[InputExample]) into a pytorch DataLoader
train_dataloader = DataLoader(gold_samples, shuffle=True, batch_size=batch_size)


# We add an evaluator built from the dev split; it is handed to fit() below
evaluator = CECorrelationEvaluator.from_input_examples(dev_samples, name='Gold-dev_split')

# Configure the training
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)  # 10% of the total number of training batches
logging.info("Warmup-steps: {}".format(warmup_steps))

# Train the cross-encoder model; output is written to cross_encoder_path
cross_encoder.fit(train_dataloader=train_dataloader,
          evaluator=evaluator,
          epochs=num_epochs,
          warmup_steps=warmup_steps,
          output_path=cross_encoder_path)

In [None]:
############################################################################
#
# Step 2: Label sampled unlabeled dataset (silver dataset) using cross-encoder model
#
############################################################################

# Read the unlabeled dataset: only the 'definitions' column is loaded and every
# double-quote character is stripped from it.
col_list = ['definitions']
df = pd.read_csv('all_entries.csv', sep=',',
                 usecols=col_list).replace('"', '', regex=True)

# "n" no. of sentence pairs to be generated from unlabeled dataset
n = 50000

# Selecting top "k" pairs based on their similarity scores
# (only consumed by the optional select_top_k step that is commented out below)
k = 1000

# Missing definitions are read by pandas as NaN; get_random_n_pairs would turn
# them into the literal text "nan" via str() and they would be scored as real
# sentences. Drop them before sampling.
unlabeled_definitions = df['definitions'].dropna().tolist()
silver_data = get_random_n_pairs(unlabeled_definitions, n)

logging.info("Number of generated pairs from unlabeled dataset: {}".format(len(silver_data)))
logging.info("Step 2.2: Label the generated pairs with cross-encoder: {}".format(cross_encoder_model_name))

# Reload the cross-encoder from the directory written by Step 1 and score every pair
cross_encoder = CrossEncoder(cross_encoder_path)
silver_scores = cross_encoder.predict(silver_data)

# Optional: keep only the k best-scoring pairs
#logging.info("Number of silver pairs selected from the labeled pairs: {}".format(k))
#silver_data, silver_scores = select_top_k(silver_data, silver_scores, k)

# All model predictions should be between [0,1]
assert all(0.0 <= score <= 1.0 for score in silver_scores), \
    "cross-encoder produced a silver score outside [0, 1]"

In [None]:
#################################################################################################
#
# Step 3: Train bi-encoder model with both gold + silver dataset - Augmented SBERT
#
#################################################################################################

logging.info("Step 3: Train bi-encoder: {} with gold + silver dataset".format(bi_encoder_model_name))

# Turn each silver pair and its cross-encoder score (from Step 2) into an InputExample
logging.info("Read gold and silver train dataset")
silver_samples = list(InputExample(texts=[data[0], data[1]], label=score) for data, score in zip(silver_data, silver_scores))


# Train on the gold samples collected in Step 1 plus the silver samples, shuffled together;
# the loss is CosineSimilarityLoss applied to the bi-encoder
train_dataloader = DataLoader(gold_samples + silver_samples, shuffle=True, batch_size=batch_size)
train_loss = losses.CosineSimilarityLoss(model=bi_encoder)

# Evaluator built from the dev split of the gold dataset (reported under the name 'sts-dev')
logging.info("Read development dataset")
evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='sts-dev')

# Configure the training.
warmup_steps = math.ceil(len(train_dataloader) * num_epochs  * 0.1)  # 10% of the total number of training batches
logging.info("Warmup-steps: {}".format(warmup_steps))

# Train the bi-encoder model; evaluation_steps=1000 is passed alongside the evaluator
# and output is written to bi_encoder_path
bi_encoder.fit(train_objectives=[(train_dataloader, train_loss)],
          evaluator=evaluator,
          epochs=num_epochs,
          evaluation_steps=1000,
          warmup_steps=warmup_steps,
          output_path=bi_encoder_path
          )

In [None]:
######################################################################
#
# Evaluate Augmented SBERT performance on the Gold dataset test split
#
######################################################################

# Load the stored Augmented SBERT model from the directory passed to
# bi_encoder.fit(output_path=...) in Step 3. Loading the pretrained hub
# checkpoint here instead would report the baseline score, not the score of
# the model trained on gold + silver data.
bi_encoder = SentenceTransformer(bi_encoder_path)
logging.info("Read test dataset")
test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name='Gold-test_split')
# The evaluator receives the model directory as its output_path
test_evaluator(bi_encoder, output_path=bi_encoder_path)