# Training the CrossEncoder

In [None]:
import sentence_transformers
from importlib import reload

# Re-execute the sentence_transformers package module in the running kernel.
# NOTE(review): importlib.reload only re-runs this top-level package module;
# submodules already present in sys.modules (e.g. the cross_encoder package)
# are not re-executed, so local edits to those files presumably need a kernel
# restart to take effect -- confirm.
reload(sentence_transformers)

In [None]:
import logging

from sentence_transformers import LoggingHandler, util
from tqdm.autonotebook import tqdm, trange

# Configure root logging once for the notebook: INFO level, timestamped
# messages, routed through sentence-transformers' LoggingHandler so that
# progress/debug information is printed to stdout.
LOG_FORMAT = "%(asctime)s - %(message)s"
LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

logging.basicConfig(
    format=LOG_FORMAT,
    datefmt=LOG_DATE_FORMAT,
    level=logging.INFO,
    handlers=[LoggingHandler()],
)

# Notebook-level logger used by later cells (e.g. to report warm-up steps).
logger = logging.getLogger(__name__)
#### /print debug information to stdout

In [None]:
import json
from datasets import load_dataset

# NOTE(review): if the JSONL files carry columns beyond text1/text2/label,
# Dataset.remove_columns can prune them after loading.

# One JSON-Lines file per split. The keys become the split names that later
# cells index (dataset['train'], dataset['validation'], dataset['test']).
datafiles = dict(
    train='training_dataset.jsonl',
    validation='eval_dataset.jsonl',
    test='test_dataset.jsonl',
)

dataset = load_dataset("json", data_files=datafiles)

In [None]:
from sentence_transformers.readers import InputExample


def _to_input_examples(split):
    """Convert one dataset split into a list of CrossEncoder InputExamples.

    Each row contributes its ``text1``/``text2`` columns as the text pair and
    its ``label`` column as the target, preserving the split's row order.
    """
    return [
        InputExample(texts=[row['text1'], row['text2']], label=row['label'])
        for row in split
    ]


train_samples = _to_input_examples(dataset['train'])
dev_samples = _to_input_examples(dataset['validation'])
test_samples = _to_input_examples(dataset['test'])

In [None]:
from torch.utils.data import DataLoader

train_batch_size = 16

# Training batches are reshuffled every epoch. The dev loader keeps the
# original row order: shuffling evaluation data adds no value and only makes
# iteration order differ between runs.
# NOTE(review): dev_dataloader is not referenced by the visible cells; the
# dev-set evaluators are built directly from dev_samples instead.
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
dev_dataloader = DataLoader(dev_samples, shuffle=False, batch_size=train_batch_size)

In [None]:
from sentence_transformers.cross_encoder import CrossEncoder
from datetime import datetime

# Timestamped output directory, e.g. output/training_allnli-2024-01-31_12-00-00.
# NOTE(review): not referenced in the visible cells -- the trained model is
# saved to 'compat_matrix_model' instead.
run_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
model_save_path = "output/training_allnli-" + run_timestamp

# CrossEncoder fine-tuned from distilroberta-base. With num_labels=1 the head
# produces a single score per (text1, text2) pair, i.e. a binary setup
# (later cells threshold that score at 0.5).
base_model_name = "distilroberta-base"
model = CrossEncoder(base_model_name, num_labels=1)

In [None]:
import math

import sentence_transformers.cross_encoder
import sentence_transformers.evaluation
from sentence_transformers.cross_encoder.evaluation import (
    CEBinaryAccuracyEvaluator,
    CEBinaryClassificationEvaluator,
    CEF1Evaluator,
)
from sentence_transformers.evaluation import SequentialEvaluator

num_epochs = 5

# Dev-set evaluators, run periodically while training.
binary_clf_eval = CEBinaryClassificationEvaluator.from_input_examples(dev_samples)
binary_acc_eval = CEBinaryAccuracyEvaluator.from_input_examples(dev_samples)
# NOTE(review): built but deliberately left out of the SequentialEvaluator
# below -- presumably because of the argmax(axis=1) failure on single-score
# (num_labels=1) predictions noted before the training cell. Confirm before
# adding it back.
binary_f1_eval = CEF1Evaluator.from_input_examples(dev_samples)

evaluator = SequentialEvaluator([binary_clf_eval, binary_acc_eval])

# Warm up over the first 10% of all optimisation steps
# (batches per epoch x number of epochs), rounded up.
total_training_steps = len(train_dataloader) * num_epochs
warmup_steps = math.ceil(total_training_steps * 0.1)
logger.info("Warmup-steps: {}".format(warmup_steps))

In [None]:

## There is a bug in the library that won't be fixed until the major overhaul of CrossEncoder is completed:
# https://github.com/UKPLab/sentence-transformers/issues/2737

## As a temporary fix, I have commented out lines 495-498 in 'sentence_transformers/cross_encoder/CrossEncoder.py'

## Identified a second bug: argmax is called on a list with axis=1, which raises an error

# Fine-tune the CrossEncoder on the training pairs. The dev-set evaluator runs
# every 500 training steps. With save_best_model=False, fit() is not asked to
# save the best-scoring model; the model is saved explicitly in a later cell.
fit_config = dict(
    train_dataloader=train_dataloader,
    epochs=num_epochs,
    evaluation_steps=500,
    warmup_steps=warmup_steps,
    evaluator=evaluator,
    save_best_model=False,
)
model.fit(**fit_config)

In [None]:
# Persist the fine-tuned CrossEncoder to the local 'compat_matrix_model'
# path so it can be reloaded later without retraining.
model.save('compat_matrix_model')

In [None]:
# To reuse the saved model instead of retraining, uncomment:
# model = CrossEncoder('compat_matrix_model')

In [None]:
def validate_prediction(row_number):
    """Score a single test-set pair with the trained model.

    Args:
        row_number: Index of the row in ``dataset['test']``.

    Returns:
        ``(pred, result)``: the model's score for the row's
        (``text1``, ``text2``) pair and the row's gold ``label``.
    """
    # Fetch the row once rather than re-reading it for every column.
    row = dataset['test'][row_number]
    pred = model.predict([row['text1'], row['text2']])
    return pred, row['label']


# Score the whole test split with one batched predict() call instead of one
# forward pass (and one progress bar) per row. y_pred[i] / y[i] are the score
# and gold label for test row i, as before.
_test_split = dataset['test']
_test_pairs = [
    [text1, text2]
    for text1, text2 in zip(_test_split['text1'], _test_split['text2'])
]
y = list(_test_split['label'])
# Avoid calling predict() with no pairs at all (empty test split).
y_pred = list(model.predict(_test_pairs)) if _test_pairs else []

In [None]:
def _binarise_score(score):
    """Map a raw model score to a hard label at the 0.5 decision threshold.

    Returns 1 above the threshold and 0 below it. A score exactly on the
    threshold (or anything that compares false both ways) is reported and
    mapped to the sentinel -1.
    """
    if score > 0.5:
        return 1
    if score < 0.5:
        return 0
    print(f'{score}: appending -1')
    return -1


normalised_ypred = [_binarise_score(score) for score in y_pred]

In [None]:
import pandas as pd

# Sanity check: the distinct thresholded labels. Any -1 means some score sat
# exactly on the 0.5 threshold.
distinct_labels = pd.Series(normalised_ypred).unique()
distinct_labels


In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Binary confusion matrix of gold labels vs thresholded predictions.
# Restricting to labels=[0, 1] means any -1 predictions (scores exactly at the
# threshold) are excluded from the counts rather than shown as a third class.
class_labels = [0, 1]
cm = confusion_matrix(y, normalised_ypred, labels=class_labels)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_labels)

disp.plot()
plt.show()