# Q2Q (Question-to-Question) Retrieval

## Imports

In [None]:
%pip install tqdm==4.66.5 sentence-transformers einops

In [None]:
import glob
import json
import os

import numpy as np
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from torch import Tensor
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

In [None]:
# Sentence encoder shared by the question cache and the evaluation queries.
MODEL_NAME = 'sentence-transformers/all-mpnet-base-v2'
model = SentenceTransformer(MODEL_NAME)

## Data Loading

In [None]:
!git clone https://github.com/RegNLP/ObliQADataset.git

In [8]:
def load_json_files_from_directory(directory_path):
    """Load every ``*.json`` file located directly inside *directory_path*.

    Files are visited in sorted path order, so the result does not depend on the
    filesystem's directory-listing order.

    Args:
        directory_path: Directory to scan (non-recursive). Glob metacharacters in
            the name (e.g. ``[``) are treated literally.

    Returns:
        A list of ``(file_path, parsed_json)`` tuples, one per file that parsed
        successfully. Files that are not valid JSON are reported on stdout and
        skipped.
    """
    pattern = os.path.join(glob.escape(directory_path), "*.json")
    json_data_list = []
    for json_file in sorted(glob.glob(pattern)):
        with open(json_file, 'r', encoding='utf-8') as f:
            try:
                json_data = json.load(f)
            except json.JSONDecodeError as e:
                # Best effort: report the broken file and keep loading the rest.
                print(f"Error decoding JSON in file {json_file}: {e}")
                continue
        json_data_list.append((json_file, json_data))
    return json_data_list

directory_path = "ObliQADataset/StructuredRegulatoryDocuments"
json_data_list = load_json_files_from_directory(directory_path)

# Each file is expected to hold a list of passage records; merge them into one list.
flattened_json_data_list = [
    passage
    for _path, passages in json_data_list
    for passage in passages
]

# Lookup table "DocumentID:PassageID" -> passage "ID". Every space in the combined
# key is replaced with "_", and the question passages are looked up by the same key.
dp_id_to_id = {
    f'{passage["DocumentID"]}:{passage["PassageID"]}'.replace(' ', '_'): passage["ID"]
    for passage in flattened_json_data_list
}

## Initializing

In [10]:
# Cap the encoder's input window at 1024 tokens to bound memory use; models whose
# limit is already at or below 1024 are left untouched.
if not model.max_seq_length <= 1024:
    model.max_seq_length = 1024

In [9]:
def get_embedding(e, task=None):
    """Embed a batch of sentences with the global ``model`` and L2-normalise each row.

    Args:
        e: Sequence of sentences to encode.
        task: Unused; kept so existing call sites remain valid.

    Returns:
        A NumPy array holding one unit-length embedding per sentence, or ``e``
        itself, unchanged, when it is empty.
    """
    if len(e) == 0:
        return e
    # encode() batches internally; its own progress bar is disabled because
    # compute_batch_embeddings already shows a tqdm bar over the batches.
    raw_vectors = model.encode(list(e), batch_size=32, show_progress_bar=False)
    unit_vectors = F.normalize(torch.Tensor(raw_vectors), p=2, dim=1)
    return np.array(unit_vectors)

def compute_batch_embeddings(sentences, batch_size=32):
    """Embed *sentences* in chunks of *batch_size* and stack the results.

    Args:
        sentences: Sliceable sequence of sentences.
        batch_size: Number of sentences handed to ``get_embedding`` per call.

    Returns:
        A 2-D NumPy array with one embedding row per input sentence, in input
        order. An empty input yields an empty ``(0, 0)`` float32 array instead of
        raising.

    Raises:
        Whatever ``get_embedding`` raises; the offending batch is printed first so
        the failing input can be identified.
    """
    output_embedding = []
    for start in tqdm(range(0, len(sentences), batch_size)):
        batch = sentences[start:start + batch_size]
        try:
            output_embedding.append(get_embedding(batch))
        except Exception:
            # Show which batch broke, then propagate the *original* error
            # (previously an undefined name was used to force a NameError).
            print(batch)
            raise

    if not output_embedding:
        # np.concatenate([]) raises ValueError; return an empty matrix instead.
        return np.empty((0, 0), dtype=np.float32)
    return np.concatenate(output_embedding)

In [None]:
# Build the question-to-question (Q2Q) cache from the train + dev splits.
# For every cached question, record which passage IDs answer it, and embed the
# question text once so later queries can be matched against it.

train_dataset_map = {}  # QuestionID -> list of passage "ID"s (resolved through dp_id_to_id)
train_dataset_ind = []  # row i of train_embeddings belongs to train_dataset_ind[i]

# Context managers close both split files (json.load(open(...)) leaked the handles).
with open('ObliQADataset/ObliQA_train.json', 'r') as train_file, \
        open('ObliQADataset/ObliQA_dev.json', 'r') as dev_file:
    data = json.load(train_file) + json.load(dev_file)

questions = []

for each_question in tqdm(data):
    questions.append(each_question['Question'])
    # Build the key exactly as dp_id_to_id does: spaces become "_" across the whole
    # "DocumentID:PassageID" string, not only inside PassageID.
    train_dataset_map[each_question['QuestionID']] = [
        dp_id_to_id[f"{e['DocumentID']}:{e['PassageID']}".replace(' ', '_')]
        for e in each_question['Passages']
    ]
    train_dataset_ind.append(each_question['QuestionID'])

train_embeddings = np.asarray(compute_batch_embeddings(questions))

## Inference

In [11]:
# Split whose questions are used as queries: "train", "dev" or "test".
# The Q-A cache is always built from train + dev, so querying with "train" or "dev"
# lets each question retrieve itself from the cache.
eval_set = 'test'

In [None]:
# Q2Q inference: for each question of the `eval_set` split, find the most similar
# cached (train + dev) questions and write their gold passages to a TREC-style run file.
# Line format: "<QuestionID> 0 <passage ID> <rank> <score> alg", where <score> is the
# similarity between the query and the cached question that contributed the passage.
num_similar_questions = 35  # nearest cached questions to harvest passages from

with open("q2q.trec", "w") as pred_rels, \
        open(f'ObliQADataset/ObliQA_{eval_set}.json', 'r') as file:
    data = json.load(file)
    questions = [each_question['Question'] for each_question in data]

    question_embeddings = compute_batch_embeddings(questions)

    for each_embedding, each_question in tqdm(zip(question_embeddings, data), total=len(data)):
        # Rows are L2-normalised by get_embedding, so this dot product is the cosine
        # similarity between the query and every cached question.
        relevant_passages = each_embedding.dot(train_embeddings.T)
        # np.argsort is ascending: reverse it so the MOST similar cached question comes
        # first. Ranks then increase as scores decrease, and a passage shared by several
        # neighbours is emitted once, with its best (first-seen) score.
        indices = np.argsort(relevant_passages)[::-1]
        top_questions = indices[:num_similar_questions]

        i = 0
        cache_passages = set()  # passage IDs already written for this query
        for e in top_questions:
            for each_passage_id in train_dataset_map[train_dataset_ind[e]]:
                if each_passage_id in cache_passages:
                    continue

                cache_passages.add(each_passage_id)
                line = f"{each_question['QuestionID']} 0 {each_passage_id} {i+1} {relevant_passages[e]} alg"
                pred_rels.write(line + "\n")
                i += 1