In [25]:
from torch.utils.data import DataLoader
from sentence_transformers import models, losses, util, LoggingHandler, SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, BinaryClassificationEvaluator
from sentence_transformers.readers import InputExample
from datetime import datetime
from zipfile import ZipFile
import logging
import csv
import sys
import torch
import math
import gzip
import os

In [26]:
#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
#### /print debug information to stdout

# You can specify any huggingface/transformers pre-trained model here,
# for example: bert-base-uncased, roberta-base, xlm-roberta-base
model_name = 'bert-base-uncased'
batch_size = 16
num_epochs = 1
max_seq_length = 128
# NOTE(review): `use_cuda` is not referenced by any cell visible in this notebook.
use_cuda = torch.cuda.is_available()

###### Read Datasets ######
sts_dataset_path = 'datasets/stsbenchmark.tsv.gz'
qqp_dataset_path = 'quora-IR-dataset'


# Check if the STSb dataset exists. If not, download it (the file is kept as .tsv.gz).
if not os.path.exists(sts_dataset_path):
    util.http_get('https://sbert.net/datasets/stsbenchmark.tsv.gz', sts_dataset_path)


# Check if the QQP dataset exists. If not, download the archive and extract it
# into the `qqp_dataset_path` directory.
if not os.path.exists(qqp_dataset_path):
    logging.info("Dataset not found. Download")
    zip_save_path = 'quora-IR-dataset.zip'
    util.http_get(url='https://sbert.net/datasets/quora-IR-dataset.zip', path=zip_save_path)
    with ZipFile(zip_save_path, 'r') as zipIn:
        zipIn.extractall(qqp_dataset_path)


# Take the timestamp once so that both output directories of this run carry the
# same suffix (two separate datetime.now() calls can straddle a second boundary).
run_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# Model names such as "org/model" contain a slash, which must not end up in a directory name.
model_tag = model_name.replace("/", "-")

cross_encoder_path = 'output/cross-encoder/stsb_indomain_' + model_tag + '-' + run_timestamp
bi_encoder_path = 'output/bi-encoder/qqp_cross_domain_' + model_tag + '-' + run_timestamp

2022-04-04 14:27:24 - Dataset not found. Download


  0%|          | 0.00/93.6M [00:00<?, ?B/s]

In [None]:
###### Cross-encoder (sentence-transformers CrossEncoder) -- currently disabled ######

# logging.info("Loading cross-encoder model: {}".format(model_name))
# Any Huggingface/transformers model (BERT, RoBERTa, XLNet, XLM-R, ...) can back the cross-encoder
# cross_encoder = CrossEncoder(model_name, num_labels=1)

###### Bi-encoder (sentence-transformers) ######

logging.info("Loading bi-encoder model: {}".format(model_name))

# Token-level encoder: a Huggingface/transformers model (BERT, RoBERTa, XLNet, XLM-R, ...)
# that maps tokens to embeddings, configured with `max_seq_length`.
word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)

# Sentence-level vector: mean over the token embeddings (CLS and max pooling are switched off),
# which yields one fixed-size vector per sentence.
embedding_dim = word_embedding_model.get_word_embedding_dimension()
pooling_model = models.Pooling(
    embedding_dim,
    pooling_mode_mean_tokens=True,
    pooling_mode_cls_token=False,
    pooling_mode_max_tokens=False,
)

# Chain both modules into the bi-encoder: tokens -> embeddings -> pooled sentence vector.
bi_encoder_modules = [word_embedding_model, pooling_model]
bi_encoder = SentenceTransformer(modules=bi_encoder_modules)

In [None]:
###########################################################################
#
# Train the bi-encoder (BERT) model with previously scored activity label pairs.
# This bi-encoder is the final language model, which can then be used as the
# basis for synonym detection applications.
#
###########################################################################

logging.info("Train bi-encoder: {} over labeled QQP (target dataset)".format(model_name))

# Convert the dataset to a DataLoader ready for training.
# NOTE(review): `silver_data` and `binary_silver_scores` are not defined in any cell
# visible in this notebook; they must be produced by an earlier labeling step before
# this cell runs -- confirm, otherwise Restart & Run All fails here with a NameError.
logging.info("Loading BERT labeled QQP dataset")
to_be_scored = [
    InputExample(texts=[data[0], data[1]], label=score)
    for (data, score) in zip(silver_data, binary_silver_scores)
]


train_dataloader = DataLoader(to_be_scored, shuffle=True, batch_size=batch_size)
# NOTE(review): per the sentence-transformers documentation, MultipleNegativesRankingLoss
# expects (anchor, positive) pairs and does not use `label` -- confirm that pairs scored 0
# are meant to be part of the training set.
train_loss = losses.MultipleNegativesRankingLoss(bi_encoder)

###### Classification ######
# Given (activity label 1, activity label 2), is this a duplicate or not?
# The evaluator will compute the embeddings for both labels and then compute
# a cosine similarity. If the similarity is above a threshold, we have a duplicate.
logging.info("Read activity label dataset")

dev_sentences1 = []
dev_sentences2 = []
dev_labels = []

# The dev pairs live inside the extracted QQP dataset directory. The path must be built
# from `qqp_dataset_path`; joining onto `to_be_scored` (a list of InputExample objects)
# raises a TypeError.
dev_pairs_path = os.path.join(qqp_dataset_path, "classification/dev_pairs.tsv")
with open(dev_pairs_path, encoding='utf8') as fIn:
    reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
    for row in reader:
        dev_sentences1.append(row['question1'])
        dev_sentences2.append(row['question2'])
        dev_labels.append(int(row['is_duplicate']))

evaluator = BinaryClassificationEvaluator(dev_sentences1, dev_sentences2, dev_labels)

# Configure the training.
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)  # 10% of the training steps for warm-up
logging.info("Warmup-steps: {}".format(warmup_steps))

# Train the bi-encoder model; the evaluator runs every 1000 steps and the model is
# written to `bi_encoder_path`.
bi_encoder.fit(train_objectives=[(train_dataloader, train_loss)],
               evaluator=evaluator,
               epochs=num_epochs,
               evaluation_steps=1000,
               warmup_steps=warmup_steps,
               output_path=bi_encoder_path
               )

In [62]:
###############################################################
#
# Evaluate performance on a test sample
# (the original note mentions 15-fold cross-validation to check performance on
# each of the available domains; this cell itself scores only the single file
# that `path` points to)
#
###############################################################

# TODO: point this at the test file before running -- a ';'-separated file with the
# columns `activity1`, `activity2` and `annotation` (integer label).
path = ''

# Reload the trained (augmented) SBERT model from the training output directory
bi_encoder = SentenceTransformer(bi_encoder_path)

logging.info("Read test dataset")

# One pass over the file: each row becomes (sentence 1, sentence 2, integer label)
with open(path, encoding='utf8') as fIn:
    test_rows = [
        (row['activity1'], row['activity2'], int(row['annotation']))
        for row in csv.DictReader(fIn, delimiter=';', quoting=csv.QUOTE_NONE)
    ]

# Split the row triples into the three parallel lists the evaluator expects
test_sentences1 = [test_row[0] for test_row in test_rows]
test_sentences2 = [test_row[1] for test_row in test_rows]
test_labels = [test_row[2] for test_row in test_rows]

evaluator = BinaryClassificationEvaluator(test_sentences1, test_sentences2, test_labels)
bi_encoder.evaluate(evaluator)

2022-03-17 16:21:26 - Load pretrained SentenceTransformer: output/bi-encoder/qqp_cross_domain_bert-base-uncased-2022-03-17_15-25-31
2022-03-17 16:21:27 - Use pytorch device: cuda
2022-03-17 16:21:27 - Read QQP test dataset
2022-03-17 16:21:27 - Binary Accuracy Evaluation of the model on  dataset:
2022-03-17 16:21:27 - Accuracy with Cosine-Similarity:           78.08	(Threshold: 0.5953)
2022-03-17 16:21:27 - F1 with Cosine-Similarity:                 79.12	(Threshold: 0.4262)
2022-03-17 16:21:27 - Precision with Cosine-Similarity:          66.85
2022-03-17 16:21:27 - Recall with Cosine-Similarity:             96.90
2022-03-17 16:21:27 - Average Precision with Cosine-Similarity:  84.36

2022-03-17 16:21:27 - Accuracy with Manhatten-Distance:           80.25	(Threshold: 287.1740)
2022-03-17 16:21:27 - F1 with Manhatten-Distance:                 80.55	(Threshold: 297.9510)
2022-03-17 16:21:27 - Precision with Manhatten-Distance:          72.03
2022-03-17 16:21:27 - Recall with Manhatten-Di

0.8579594191354043