In [None]:
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, LoggingHandler, util, models, evaluation, losses, InputExample
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CEBinaryClassificationEvaluator
from datetime import datetime
from collections import defaultdict
from torch.utils.data import IterableDataset
from tqdm.notebook import tqdm

import numpy as np

import json
import time
import torch
import os
import logging

In [None]:
# Training configuration shared by the cells below.
model_name = "cross-encoder/ms-marco-TinyBERT-L-6"  # checkpoint handed to CrossEncoder
model_save_path = "models/crenc-exp7"               # passed to model.fit as output_path
train_batch_size = 256                              # batch size of the training DataLoader

In [None]:
def get_triplets(Passage_dict):
    """Expand a passage mapping into a flat list of three-element lists.

    Args:
        Passage_dict: mapping whose values are indexable; ``value[0]`` and
            ``value[1]`` must each be iterable.

    Returns:
        A list of ``[key, a, b]`` lists, one for every combination of an
        element ``a`` from ``value[0]`` and an element ``b`` from ``value[1]``,
        in mapping order. ``get_dataset`` reads these as
        ``(query id, positive id, negative id)``.
    """
    return [
        [key, first, second]
        for key, groups in Passage_dict.items()
        for first in groups[0]
        for second in groups[1]
    ]

def get_dataset(triplets, corpus):
    """Turn id triplets into labelled query/passage ``InputExample`` pairs.

    Args:
        triplets: iterable of ``(query id, positive id, negative id)``
            three-element sequences.
        corpus: mapping from the string form of an id to its text.

    Returns:
        A list holding two examples per triplet, in order: the
        (query, positive) pair with ``label=1`` followed by the
        (query, negative) pair with ``label=0``.

    Raises:
        KeyError: if the string form of any id is missing from ``corpus``.
    """
    examples = []
    for triplet in triplets:
        qid, pos_id, neg_id = triplet

        # Ids are converted to str before every corpus lookup.
        query_text, pos_text, neg_text = [
            corpus[key] for key in (str(qid), str(pos_id), str(neg_id))
        ]

        examples.extend([
            InputExample(texts=[query_text, pos_text], label=1),
            InputExample(texts=[query_text, neg_text], label=0),
        ])

    return examples


def _load_json(path):
    """Read and parse one JSON file.

    The file is decoded explicitly as UTF-8 so the result does not depend on
    the platform's locale default encoding.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


# Passage files feed get_triplets; corpus files map id strings to text.
train_passage = _load_json('./data/generated4/train_passage.json')
train_corpus = _load_json('./data/generated4/train_corpus.json')
val_passage = _load_json('./data/generated4/val_passage.json')
val_corpus = _load_json('./data/generated4/val_corpus.json')

# Expand each split into id triplets, then into labelled InputExample pairs.
train_triplets = get_triplets(train_passage)
train_dataset = get_dataset(train_triplets, train_corpus)

val_triplets = get_triplets(val_passage)
val_dataset = get_dataset(val_triplets, val_corpus)

In [None]:
# Send INFO-level (and above) log records through LoggingHandler.
# The format string has no %(asctime)s field, so datefmt has no visible effect.
logging.basicConfig(
    level=logging.INFO,
    format='- %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[LoggingHandler()],
)

# Cross-encoder initialised from the checkpoint named by model_name.
model = CrossEncoder(model_name)

# Shuffled mini-batches over the training examples.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
)

# Evaluator built from the validation examples; passed to model.fit below.
evaluator = CEBinaryClassificationEvaluator.from_input_examples(
    val_dataset,
    name='cross_encoder_val',
)

In [None]:
# Number of passes over the training data. Used for both the warm-up
# computation and the fit() call so the two values cannot drift apart
# (the warm-up was previously computed for 5 epochs while fit() ran 10).
num_epochs = 10

# Warm up over the first 10% of all training steps.
total_train_steps = len(train_dataloader) * num_epochs
warmup_steps = int(total_train_steps * 0.1)

model.fit(
    train_dataloader=train_dataloader,
    evaluator=evaluator,
    epochs=num_epochs,
    # Run the evaluator roughly every half epoch.
    evaluation_steps=int(len(train_dataloader) / 2),
    warmup_steps=warmup_steps,
    save_best_model=True,
    output_path=model_save_path
)

In [None]:
# Drop the notebook's reference to the trained model, then ask PyTorch to
# release unused cached CUDA memory before a model is loaded again.
# Note: re-running this cell without re-creating `model` raises NameError.
del model
torch.cuda.empty_cache()

In [None]:
# Reload the checkpoint written by model.fit (output_path=model_save_path)
# instead of a hard-coded path from a different experiment, so the analysis
# cells below inspect the model trained in this notebook.
model = CrossEncoder(model_save_path)

In [None]:
# Build [query, passage] text pairs from the validation triplets, one
# positive pair and one negative pair per triplet, in val_triplets order.
positives, negatives = [], []

for qid, pos_id, neg_id in val_triplets:
    query_text = val_corpus[qid]
    pos_text = val_corpus[str(pos_id)]
    neg_text = val_corpus[str(neg_id)]

    positives.append([query_text, pos_text])
    negatives.append([query_text, neg_text])

# Score both groups with the loaded cross-encoder.
positive_scores = model.predict(positives)
negative_scores = model.predict(negatives)

In [None]:
# Positions of positives scored below 0.5 and of negatives scored at or
# above 0.5. Positions line up with val_triplets because the scored pairs
# were built in val_triplets order.
positive_out = np.where(positive_scores < 0.5)[0]
negative_out = np.where(negative_scores >= 0.5)[0]

# Print up to 10 randomly chosen low-scoring positive pairs.
# The sample size is capped: np.random.choice(..., replace=False) raises
# ValueError when asked for more items than are available.
n_samples = min(10, len(positive_out))
for idx in np.random.choice(positive_out, n_samples, replace=False):
    score = positive_scores[idx]
    query = val_corpus[val_triplets[idx][0]]
    text = val_corpus[str(val_triplets[idx][1])]

    print(f'Query: {query}\nText: {text}\nScore:{score:.4f}\n')

In [None]:
# Print up to 10 randomly chosen high-scoring negative pairs.
# The sample size is capped: np.random.choice(..., replace=False) raises
# ValueError when asked for more items than are available.
n_samples = min(10, len(negative_out))
for idx in np.random.choice(negative_out, n_samples, replace=False):
    score = negative_scores[idx]
    query = val_corpus[val_triplets[idx][0]]
    text = val_corpus[str(val_triplets[idx][2])]

    print(f'Query: {query}\nText: {text}\nScore:{score:.4f}\n')

In [None]:
# Positions of positives scored above 0.8 and of negatives scored below 0.2.
# Positions line up with val_triplets because the scored pairs were built in
# val_triplets order.
positive_good = np.where(positive_scores > 0.8)[0]
negative_good = np.where(negative_scores < 0.2)[0]

# Print up to 10 randomly chosen confidently scored positive pairs.
# The sample size is capped: np.random.choice(..., replace=False) raises
# ValueError when asked for more items than are available.
n_samples = min(10, len(positive_good))
for idx in np.random.choice(positive_good, n_samples, replace=False):
    score = positive_scores[idx]
    query = val_corpus[val_triplets[idx][0]]
    text = val_corpus[str(val_triplets[idx][1])]

    print(f'Query: {query}\nText: {text}\nScore:{score:.4f}\n')

In [None]:
# Print up to 10 randomly chosen confidently scored negative pairs.
# The sample size is capped: np.random.choice(..., replace=False) raises
# ValueError when asked for more items than are available.
n_samples = min(10, len(negative_good))
for idx in np.random.choice(negative_good, n_samples, replace=False):
    score = negative_scores[idx]
    query = val_corpus[val_triplets[idx][0]]
    # Element 2 of a triplet is the negative passage id; the score printed
    # here belongs to the (query, negative) pair, so show that passage
    # (this previously displayed the positive passage, element 1).
    text = val_corpus[str(val_triplets[idx][2])]

    print(f'Query: {query}\nText: {text}\nScore:{score:.4f}\n')