In [None]:
# !pip install transformers[Sentencepiece] sentence-transformers faiss-cpu datasets evaluate

In [None]:
import json
import torch
from datasets import load_dataset, concatenate_datasets 
import evaluate
from transformers import (
    T5Tokenizer, 
    T5ForConditionalGeneration
    )
import nltk
# Fetch the NLTK 'punkt' resource at import time.
# NOTE(review): no visible cell calls nltk directly — presumably a leftover or
# needed by a dependency; confirm before removing.
nltk.download('punkt')

In [None]:
from qasper_utils import get_QAE2, get_all_paragraphs, get_all_questions
import json
import pandas as pd
from tqdm import tqdm
import numpy as np
import faiss

In [None]:
# !cp /content/drive/MyDrive/multi-qa-distilbert-dot-v1-qasper-retriever.zip multi-qa-distilbert-dot-v1-qasper-retriever.zip

In [None]:
# !unzip multi-qa-distilbert-dot-v1-qasper-retriever.zip

In [None]:
# !cp /content/drive/MyDrive/flant5_reader.zip flant5_reader.zip

In [None]:
# !unzip flant5_reader.zip

In [None]:
# Locations of the QASPER v0.3 dev and test JSON files (Colab Drive paths).
dev_path = "/content/drive/MyDrive/qasper-dev-v0.3.json"
test_path = "/content/drive/MyDrive/qasper-test-v0.3.json"

# Read explicitly as UTF-8 so parsing does not depend on the platform's
# default text encoding.
with open(dev_path, 'r', encoding='utf-8') as f:
    dev_data = json.load(f)

with open(test_path, 'r', encoding='utf-8') as f:
    test_data = json.load(f)

In [None]:
from sentence_transformers import SentenceTransformer

# Load the retriever from a local directory; it is used below to embed both
# paragraphs and questions.
# NOTE(review): the doubled "/content/content/" prefix presumably comes from how
# the archive was unzipped — confirm the path on a fresh runtime.
retriever_model = SentenceTransformer("/content/content/multi-qa-distilbert-dot-v1-qasper-retriever")

In [None]:
# Load the reader (T5 conditional-generation model) and its tokenizer from the
# same local checkpoint directory; both are used by get_answer.
model_checkpoint = "/content/content/flant5_reader"
tokenizer = T5Tokenizer.from_pretrained(model_checkpoint)
reader_model = T5ForConditionalGeneration.from_pretrained(model_checkpoint)

In [None]:
def get_answer(question, context, max_new_tokens=128):
  """Generate an answer to `question` from `context` with the reader model.

  The prompt is formatted as "question: ...  context: ... </s>", tokenized with
  the module-level `tokenizer`, and passed to `reader_model.generate`.

  Args:
    question: Question text.
    context: Evidence text the reader should answer from.
    max_new_tokens: Upper bound on generated tokens (defaults to 128).

  Returns:
    The decoded first generated sequence. Special tokens are not stripped here
    (decode is called without skip_special_tokens); callers clean them up.
  """
  input_text = f"question: {question}  context: {context} </s>"
  # Move the encoded inputs to the device the model lives on, so the function
  # keeps working if reader_model has been moved to a GPU.
  features = tokenizer([input_text], return_tensors='pt').to(reader_model.device)

  output = reader_model.generate(
      input_ids=features['input_ids'],
      attention_mask=features['attention_mask'],
      max_new_tokens=max_new_tokens,
  )

  return tokenizer.decode(output[0])

In [None]:
# Collect the dev paragraphs via the qasper_utils helper and tabulate them.
# Later cells read the "paper_id" and "paragraph" columns of this frame.
dev_paragraphs = get_all_paragraphs(dev_data)
dev_paragraph_df = pd.DataFrame(dev_paragraphs)

In [None]:
# Build one flat inner-product FAISS index per dev paper, keyed by paper_id,
# and keep each paper's paragraph texts alongside it for lookup by position.
dev_indexes, dev_paragraph_dict = {}, {}
for paper_id, paper_rows in tqdm(dev_paragraph_df.groupby("paper_id")):
  paragraph_texts = paper_rows["paragraph"].values
  paragraph_vectors = retriever_model.encode(paragraph_texts)
  paper_index = faiss.IndexFlatIP(paragraph_vectors.shape[1])
  paper_index.add(paragraph_vectors)
  dev_indexes[paper_id] = paper_index
  # Texts stay in the order they were embedded, so index id i maps to element i.
  dev_paragraph_dict[paper_id] = paragraph_texts.tolist()

In [None]:
# Collect the dev questions via the qasper_utils helper.
dev_questions = get_all_questions(dev_data)

In [None]:
# Tabulate the questions; the prediction loop reads the "paper_id", "question"
# and "question_id" columns.
dev_questions_df = pd.DataFrame(dev_questions)

In [None]:
# Answer every dev question: retrieve the best-scoring paragraphs from the
# question's own paper, then let the reader generate an answer from them.
top_k = 2  # paragraphs retrieved per question and passed to the reader
dev_predictions = []
for name, group in tqdm(dev_questions_df.groupby("paper_id")):
  # Per-paper lookups are loop-invariant for all questions of this paper.
  paper_index = dev_indexes[name]
  paper_paragraphs = dev_paragraph_dict[name]
  for idx, row in group.iterrows():
    question = row["question"]
    xq = retriever_model.encode([question])
    _, I = paper_index.search(xq, top_k)
    # FAISS pads the result with -1 when the index holds fewer than top_k
    # vectors; skip those ids so -1 does not silently pick the last paragraph.
    evidence = [paper_paragraphs[i] for i in I[0] if i >= 0]
    context = " ".join(evidence)
    answer = get_answer(question, context)
    dev_predictions.append(
        {
            "question_id": row["question_id"],
            "predicted_answer": answer,
            "predicted_evidence": evidence
        }
    )

In [None]:
# Copy each raw dev prediction, removing the "<pad>" and "</s>" markers left in
# the decoded answer and trimming surrounding whitespace.
cleaned_dev_predictions = [
    {
        "question_id": raw_pred["question_id"],
        "predicted_answer": (
            raw_pred["predicted_answer"]
            .replace("<pad>", "")
            .replace("</s>", "")
            .strip()
        ),
        "predicted_evidence": raw_pred["predicted_evidence"],
    }
    for raw_pred in dev_predictions
]

In [None]:
# Persist the cleaned dev predictions as JSON Lines: one JSON object per line.
with open("cleaned_dev_predictions.jsonl", 'w') as jsonl_file:
    jsonl_file.writelines(
        json.dumps(record) + '\n' for record in cleaned_dev_predictions
    )

In [None]:
# Tabulate the cleaned dev predictions for quick inspection.
dev_predictions_df = pd.DataFrame(cleaned_dev_predictions)

In [None]:
# NOTE(review): predicted_answer holds the cleaned answer strings, so comparing
# it to the integer 104 presumably never matches and this count is always 0 —
# confirm what this check was meant to measure.
sum(dev_predictions_df["predicted_answer"]==104)

### Test Data

In [None]:
# Collect the test paragraphs via the qasper_utils helper and tabulate them.
# Later cells read the "paper_id" and "paragraph" columns of this frame.
test_paragraphs = get_all_paragraphs(test_data)
test_paragraph_df = pd.DataFrame(test_paragraphs)

In [None]:
# Build one flat inner-product FAISS index per test paper, keyed by paper_id,
# and keep each paper's paragraph texts alongside it for lookup by position.
test_indexes, test_paragraph_dict = {}, {}
for paper_id, paper_rows in tqdm(test_paragraph_df.groupby("paper_id")):
  paragraph_texts = paper_rows["paragraph"].values
  paragraph_vectors = retriever_model.encode(paragraph_texts)
  paper_index = faiss.IndexFlatIP(paragraph_vectors.shape[1])
  paper_index.add(paragraph_vectors)
  test_indexes[paper_id] = paper_index
  # Texts stay in the order they were embedded, so index id i maps to element i.
  test_paragraph_dict[paper_id] = paragraph_texts.tolist()

In [None]:
# Collect the test questions via the qasper_utils helper.
test_questions = get_all_questions(test_data)

In [None]:
# Tabulate the questions; the prediction loop reads the "paper_id", "question"
# and "question_id" columns.
test_questions_df = pd.DataFrame(test_questions)

In [None]:
# Answer every test question: retrieve the best-scoring paragraphs from the
# question's own paper, let the reader generate an answer, and store it with
# the "<pad>" / "</s>" markers already stripped.
top_k = 2  # paragraphs retrieved per question and passed to the reader
test_predictions = []
for name, group in tqdm(test_questions_df.groupby("paper_id")):
  # Per-paper lookups are loop-invariant for all questions of this paper.
  paper_index = test_indexes[name]
  paper_paragraphs = test_paragraph_dict[name]
  for idx, row in group.iterrows():
    question = row["question"]
    xq = retriever_model.encode([question])
    _, I = paper_index.search(xq, top_k)
    # FAISS pads the result with -1 when the index holds fewer than top_k
    # vectors; skip those ids so -1 does not silently pick the last paragraph.
    evidence = [paper_paragraphs[i] for i in I[0] if i >= 0]
    context = " ".join(evidence)
    answer = get_answer(question, context)
    test_predictions.append(
        {
            "question_id": row["question_id"],
            "predicted_answer": answer.replace("<pad>","").replace("</s>", "").strip(),
            "predicted_evidence": evidence
        }
    )

In [None]:
# Persist the test predictions as JSON Lines: one JSON object per line.
with open("cleaned_test_predictions.jsonl", 'w') as jsonl_file:
    jsonl_file.writelines(
        json.dumps(record) + '\n' for record in test_predictions
    )