# Importing required libraries and dependencies

In [27]:
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator
from sentence_transformers import InputExample
from torch.utils.data import DataLoader
import math
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
import pickle
from sklearn.model_selection import train_test_split

# Using Hard Negatives by training on the bm25 retrieved documents

In [84]:
df = pd.read_csv("bm25scores_new_paras.csv", index_col="Unnamed: 0")

In [86]:
theme_count = []
for theme in df["Theme"].unique():
    theme_count.append((df.loc[df["Theme"] == theme].shape[0], theme))
theme_count.sort(reverse=True)

In [88]:
theme_count[:10]

[(212, 'IPod'),
 (186, '2008_Sichuan_earthquake'),
 (124, 'Pub'),
 (104, 'Catalan_language'),
 (70, 'Adult_contemporary_music'),
 (50, 'Canadian_Armed_Forces'),
 (37, 'Cardinal_(Catholicism)'),
 (34, 'Paper'),
 (29, 'Heresy'),
 (26, 'Human_Development_Index')]

In [90]:
theme_test = pd.DataFrame(columns=df.columns)
theme_train = pd.DataFrame(columns=df.columns)

for theme in df["Theme"].unique():
  group = df.loc[df["Theme"] == theme]
  group_train, group_test = train_test_split(group, test_size=0.2)
  theme_test = pd.concat([theme_test,group_test])
  theme_train = pd.concat([theme_train,group_train])

# Setting training args and loading the student and teacher models

In [76]:
train_batch_size = 32
num_epochs = 1
model_save_path = 'output/minilml2_mse_12_domain'

In [77]:
teacher_model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2', num_labels=1, device="cuda:3")
student_model = CrossEncoder('output/minilml2_mse_12', num_labels=1, device="cuda:3")

In [78]:
def get_kd_samples(df):
  samples = []
  
  for index, row in tqdm(df.iterrows(), total=df.shape[0]):
    
    teacher_input = []
    for i in range(10):
      if pd.isnull(row["resb_par" + str(i)]):
        continue
      teacher_input.append([row['Question'], row["resb_par" + str(i)]])
    
    if len(teacher_input) == 0:
        continue
    
    ce_logit = teacher_model.predict(teacher_input)
    ind = np.argmax(ce_logit)
    
    if (teacher_input[ind][1] != row["Paragraph"]):
      continue
      
    for i in range(len(ce_logit)):  
      samples.append(InputExample(texts=teacher_input[i], label=ce_logit[i]))

  return samples

X = get_kd_samples(df)

100%|█████████████████████████████████████████████████████████████████████████████████| 944/944 [00:27<00:00, 34.76it/s]


In [64]:
with open("samples_minilm12", "rb") as fp:   
  X = pickle.load(fp)

In [79]:
train_dataloader = DataLoader(X, shuffle=True, batch_size=train_batch_size)
evaluator = CECorrelationEvaluator.from_input_examples(X)
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) #10% of train data for warm-up

In [80]:
student_model.fit(train_dataloader=train_dataloader,
          loss_fct=torch.nn.MSELoss(),
          evaluator=evaluator,
          epochs=num_epochs,
          warmup_steps=warmup_steps,
          output_path=model_save_path)

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/246 [00:00<?, ?it/s]

# Evaluation

In [81]:
import pandas as pd
import csv
import os
from tqdm import tqdm
from haystack.document_stores import FAISSDocumentStore
from haystack.document_stores import ElasticsearchDocumentStore
from haystack.nodes import SentenceTransformersRanker
from haystack.nodes import BM25Retriever
from haystack import Document as document
import numpy as np
import time
import pickle

In [106]:
ranker = SentenceTransformersRanker(model_name_or_path="output/minilml2_mse_13", devices=["cuda:1"])

In [107]:
X = []
y = []
for index, row in tqdm(theme_test.iterrows(), total=theme_test.shape[0]):
  filtered_docs = []
  for i in range(10):
    if pd.isnull(row["resb_par" + str(i)]):
      continue
    filtered_docs.append(document.from_dict({'content': row["resb_par" + str(i)], 'meta': {'name': row["Theme"]}}))  

  if (len(filtered_docs) == 0):
    X.append(np.zeros((10, 2)))
    y.append(np.zeros((10, )))
    continue

  sample_res_2 = ranker.predict(
      query = row["Question"],
      top_k = len(filtered_docs),
      documents = filtered_docs
  )

  scores = {}
  gold = {}
  for res in sample_res_2:
    scores[res.id] = [[res.score]]
    gold[res.id] = (1 if res.content == row["Paragraph"] else 0)


  _X = np.concatenate(list(scores.values()))
  _X = np.pad(_X, [(0, 10 - _X.shape[0]),(0, 0)], "constant")
  _y = np.array(list(gold.values()))
  _y = np.pad(_y, (0, 10 - _y.shape[0]), "constant")

  X.append(_X)
  y.append(_y)


100%|█████████████████████████████████████████████████████████████████████████████████| 194/194 [00:04<00:00, 39.92it/s]


In [108]:
from collections import Counter

def get_predicted_rank():
  pos = []
  for i in range(len(X)):
    _pos = np.argmax(y[i])
    if (sum(y[i]) == 0):
      pos.append(-1)
      continue
    pos.append(_pos)
    

  return pos, Counter(pos)

In [109]:
pos, count_pos = get_predicted_rank()
print(count_pos)
print(count_pos[0]/len(pos))

Counter({0: 155, 1: 16, -1: 13, 2: 6, 4: 2, 3: 2})
0.7989690721649485


In [93]:
pos, count_pos = get_predicted_rank()
print(count_pos)
print(count_pos[0]/len(pos))

Counter({0: 155, 1: 17, -1: 13, 2: 6, 4: 2, 3: 1})
0.7989690721649485


In [70]:
pos, count_pos = get_predicted_rank()
print(count_pos)
print(count_pos[0]/len(pos))

Counter({0: 2832, 1: 151, -1: 92, 2: 43, 3: 18, 5: 10, 4: 9, 7: 4, 6: 3, 8: 2, 9: 1})
0.8947867298578199


In [None]:
# Finetuned theme overlap split, e = 9
# Counter({0: 2855, 1: 142, -1: 76, 2: 39, 3: 18, 4: 11, 5: 10, 6: 6, 7: 4, 9: 3, 8: 1})
# 0.9020537124802528

In [None]:
# Finetuned theme overlap split, e = 5
# Counter({0: 2831, 1: 163, -1: 76, 2: 37, 3: 22, 4: 14, 6: 8, 7: 6, 5: 5, 9: 3})
# 0.8944707740916271

In [None]:
# Finetuned theme overlap split, e = 3
# Counter({0: 2776, 1: 181, -1: 100, 2: 53, 3: 21, 4: 14, 6: 7, 7: 6, 5: 4, 9: 3})
# 0.8770932069510269

In [None]:
# Pretrained theme overlap split
# Counter({0: 2706, 1: 209, -1: 100, 2: 63, 3: 33, 4: 25, 6: 9, 5: 8, 7: 4, 8: 4, 9: 4})
# 0.8549763033175355