In [None]:
# %pip installs into the running kernel's environment (!pip may target a different Python);
# the quotes stop the shell from treating "[SentencePiece]" as a glob pattern.
# TODO: pin package versions for reproducibility.
%pip install "transformers[SentencePiece]" datasets evaluate rouge-score SentencePiece accelerate sentence_transformers faiss-cpu

In [None]:
import csv
import os
import random

import evaluate
import faiss
import nltk
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from transformers import (
    T5Tokenizer, 
    T5ForConditionalGeneration, 
    DataCollatorForSeq2Seq, 
    Seq2SeqTrainingArguments, 
    Seq2SeqTrainer
    ) 

nltk.download('punkt')

In [None]:
# QASPER train / validation (used as dev) / test splits, loaded with datasets.load_dataset.
# NOTE(review): none of the three is referenced later in the visible notebook — confirm they are needed.
qasper_train, qasper_dev, qasper_test = (
    load_dataset("qasper", split=split_name) for split_name in ("train", "validation", "test")
)

In [None]:
# Local question/context JSON files for the train and dev splits.
# NOTE(review): not referenced later in the visible notebook — confirm they are needed.
qasper_qc_train, qasper_qc_dev = (
    load_dataset("json", data_files=json_path)
    for json_path in ("question_context_train.json", "question_context_dev.json")
)

In [None]:
# Generated question / context pairs stored as a tab-separated file with a header row.
# The file name suggests FLAN-T5 query generation — TODO confirm provenance.
genq_pairs_path = "/content/qasper-flant5-genq-pairs-2-final-cleaned.tsv"
genq_df = pd.read_csv(genq_pairs_path, sep="\t")

In [None]:
# Discard every row that has a missing value in any column.
genq_df = genq_df.dropna(axis="index", how="any")

In [None]:
# Bi-encoder used below to embed both the unique passages and the generated questions
# for hard-negative mining.
bi_encoder_name = 'multi-qa-mpnet-base-dot-v1'
model = SentenceTransformer(bi_encoder_name)

In [None]:
# Name the two columns explicitly; raises ValueError if the file does not have exactly two columns.
genq_df = genq_df.set_axis(["question", "context"], axis="columns")

In [None]:
# Unique passages in first-seen order. dict.fromkeys keeps the order stable across runs, unlike
# list(set(...)), whose order depends on per-process string hashing. FAISS row ids found during
# mining are looked up by position in this list.
genq_dedup_list = list(dict.fromkeys(genq_df["context"].tolist()))

In [None]:
# Embed every unique passage with a single batched encode() call instead of one call per
# passage; rows come back in the same order as genq_dedup_list.
embed = model.encode(genq_dedup_list, batch_size=32, show_progress_bar=True, convert_to_numpy=True)

In [None]:
# FAISS expects a C-contiguous float32 matrix with one row per unique passage.
embed_array = np.ascontiguousarray(embed, dtype=np.float32)

In [None]:
# Embedding width, taken from the passage matrix.
d = embed_array.shape[1]
# Exact (brute-force) inner-product index. The bi-encoder is a dot-product model
# (multi-qa-mpnet-base-dot-v1), so neighbours must be ranked by dot product, not L2 distance.
index = faiss.IndexFlatIP(d)

In [None]:
# Populate the index: row i of embed_array is the embedding of genq_dedup_list[i].
# Without this the index is empty, and every search returns only -1 ids.
index.add(embed_array)

In [None]:
# Sanity check: mining looks up search results in genq_dedup_list by row id, so every unique
# passage must be in the index before searching.
assert index.ntotal == len(genq_dedup_list), "FAISS index is not populated; run index.add(embed_array)"
index.ntotal

In [None]:
import random

In [None]:
k = 5  # passages retrieved per query from the FAISS index
# Cap on hard negatives kept per query. k keeps every retrieved passage other than the positive
# (the previous behaviour); set to 1 to keep a single random hard negative per query.
max_negatives_per_query = k
rng = random.Random(42)  # seeded so the shuffled negative order is reproducible


def _tsv_field(text):
    """Return `text` as a string with tabs/newlines replaced by spaces so it cannot break a TSV row."""
    return str(text).replace('\t', ' ').replace('\r', ' ').replace('\n', ' ')


questions = genq_df["question"].tolist()
pos_contexts = genq_df["context"].tolist()

# Encode all queries in batches and run one FAISS search, instead of one encode + search per row.
query_embs = np.ascontiguousarray(model.encode(questions, show_progress_bar=True), dtype=np.float32)
_, neighbour_ids = index.search(query_embs, k)

triplets = []
for query, pos, top_k in tqdm(zip(questions, pos_contexts, neighbour_ids), total=len(questions)):
    # FAISS pads results with id -1 when fewer than k vectors are available; as a list index -1
    # would silently select the last passage, so drop those ids.
    candidates = [int(i) for i in top_k if i >= 0]
    rng.shuffle(candidates)
    n_kept = 0
    for i in candidates:
        if n_kept >= max_negatives_per_query:
            break
        neg = genq_dedup_list[i]
        if neg != pos:
            triplets.append(_tsv_field(query) + '\t' + _tsv_field(pos) + '\t' + _tsv_field(neg))
            n_kept += 1

In [None]:
# Persist the mined triplets: one tab-separated "query, positive, negative" row per line,
# with no trailing newline.
triplets_tsv = '\n'.join(triplets)
with open('negative_mine_triplets.tsv', mode='w', encoding='utf-8') as triplets_file:
    triplets_file.write(triplets_tsv)

In [None]:
n_triplets = len(triplets)  # number of (query, positive, negative) rows mined above
n_triplets

### GPL: Pseudo-Labelling with a Cross-Encoder

In [None]:
from sentence_transformers import CrossEncoder

In [None]:
# Cross-encoder that scores (query, passage) pairs to produce the margin labels.
# NOTE: this rebinds `model`; the bi-encoder loaded earlier is no longer reachable under that name.
cross_encoder_name = 'cross-encoder/ms-marco-MiniLM-L-6-v2'
model = CrossEncoder(cross_encoder_name)

In [None]:
# Read the mined triplets back. The file was written as raw text without CSV quoting, so:
# - QUOTE_NONE treats a '"' inside a question/passage as a literal character instead of
#   starting a quoted field that can swallow following rows;
# - keep_default_na=False keeps texts such as "NA" or "null" as strings instead of NaN.
negative_mine_triplets = pd.read_csv(
    "negative_mine_triplets.tsv",
    sep="\t",
    header=None,
    quoting=csv.QUOTE_NONE,
    keep_default_na=False,
)
negative_mine_triplets.columns = ["q", "p", "n"]

In [None]:
# GPL pseudo-labelling: score each (query, positive) and (query, negative) pair with the
# cross-encoder and label the triplet with the margin between the two scores.
q_texts = negative_mine_triplets["q"].tolist()
p_texts = negative_mine_triplets["p"].tolist()
n_texts = negative_mine_triplets["n"].tolist()

# Two batched predict() calls replace two single-pair predict() calls per row; each returns
# one score per pair, in input order.
p_scores = model.predict(list(zip(q_texts, p_texts)), batch_size=32, show_progress_bar=True)
n_scores = model.predict(list(zip(q_texts, n_texts)), batch_size=32, show_progress_bar=True)
margins = p_scores - n_scores

label_lines = [
    q + '\t' + p + '\t' + n + '\t' + str(margin)
    for q, p, n, margin in zip(q_texts, p_texts, n_texts, margins)
]

with open("triplets_margin.tsv", 'w', encoding='utf-8') as fp:
    fp.write('\n'.join(label_lines))

### Train the Sentence Transformer on Triplets with Margin-MSE Loss

In [None]:
import pandas as pd

In [None]:
# The pseudo-labelling cell writes triplets_margin.tsv to the working directory. Read that file
# when it exists, and fall back to the Google Drive copy only when it does not (e.g. a fresh
# runtime with Drive mounted and the file copied there by hand).
triplets_margin_path = "triplets_margin.tsv"
if not os.path.exists(triplets_margin_path):
    triplets_margin_path = "/content/drive/MyDrive/triplets_margin.tsv"

# Fields are raw, unquoted text: QUOTE_NONE keeps '"' literal, and keep_default_na=False keeps
# texts such as "NA" or "null" as strings instead of NaN.
triplets_w_margin = pd.read_csv(
    triplets_margin_path, sep="\t", header=None, quoting=csv.QUOTE_NONE, keep_default_na=False
)
triplets_w_margin.columns = ["q", "p", "n", "margin"]

In [None]:
from sentence_transformers import InputExample

# One InputExample per labelled triplet: texts are (query, positive, negative) and the label is
# the cross-encoder margin, cast to a plain float.
triplet_train = [
    InputExample(texts=[q, p, n], label=float(margin))
    for q, p, n, margin in tqdm(
        triplets_w_margin[["q", "p", "n", "margin"]].itertuples(index=False, name=None),
        total=len(triplets_w_margin),
    )
]

len(triplet_train)

In [None]:
import torch

# Release cached, unused GPU memory from earlier cells before training.
torch.cuda.empty_cache()

batch_size = 16
# Triplets are reshuffled every epoch.
loader = torch.utils.data.DataLoader(triplet_train, shuffle=True, batch_size=batch_size)

In [None]:
from sentence_transformers import SentenceTransformer

# Bi-encoder fine-tuned below on the margin-labelled triplets (rebinds `model` again).
student_model_name = 'multi-qa-distilbert-dot-v1'
model = SentenceTransformer(student_model_name)
model

In [None]:
from sentence_transformers import losses

# Margin-MSE objective over (query, positive, negative) examples whose labels are the
# cross-encoder margins computed above.
loss = losses.MarginMSELoss(model=model)

In [None]:
# Single pass over the triplets, with learning-rate warm-up over the first 10% of training steps.
epochs = 1
warmup_fraction = 0.1
total_steps = len(loader) * epochs
warmup_steps = int(total_steps * warmup_fraction)

model.fit(
    train_objectives=[(loader, loss)],
    epochs=epochs,
    warmup_steps=warmup_steps,
    output_path='multi-qa-distilbert-dot-v1-qasper-retriever',
    show_progress_bar=True,
)

In [None]:
# Archive the trained retriever. Relative paths match the relative output_path passed to
# model.fit (so this works outside /content too), and the archive entries are rooted at the
# model folder rather than at "content/".
!zip -r multi-qa-distilbert-dot-v1-qasper-retriever.zip multi-qa-distilbert-dot-v1-qasper-retriever