In [None]:
import pandas as pd
import json 
import glob
import faiss
from fuzzywuzzy import fuzz
from llmsherpa.readers import LayoutPDFReader
import matplotlib.pyplot as plt
import numpy as np
import os
import datetime
import torch
from transformers import AutoTokenizer, AutoModel

In [None]:
import tqdm

In [None]:
from src.rechunker import Rechunker
from src.encoder.sentence_transformer import Encoder
from src.faiss.flat_idx import FlatIdx
from utils.utils import flatten_list, write_list_to_file, read_list_from_file
from src.eval import Eval


## Data

In [None]:
# Directory that holds the pre-extracted paragraph/table dumps read in the next cell.
# The default is machine-specific; set the TAKE_HOME_DATA_DIR environment variable
# to point the notebook at another location without editing this cell.
save_path = os.environ.get(
    "TAKE_HOME_DATA_DIR",
    r"C:\Users\J C SINGLA\Downloads\External - take_home_challenge_(withJSONs)\take_home_challenge_(withJSONs)\data",
)

In [None]:
# Load the two lists stored under `save_path`: the chunk texts and the list of
# filenames (presumably one source filename per chunk — the assert below only
# checks that the two lists have the same length).
all_data_sherpa, filenames_sherpa = (
    read_list_from_file(save_path, stem)
    for stem in ("sherpa_paras_and_tables", "sherpa_paras_and_tables_filenames")
)
assert len(filenames_sherpa) == len(all_data_sherpa)

In [None]:
# Spreadsheet with the evaluation questions and their expected answers.
# The default is machine-specific; set TAKE_HOME_QUESTIONS_XLSX to override it
# without editing this cell.
ground_truth_path = os.environ.get(
    "TAKE_HOME_QUESTIONS_XLSX",
    r"C:\Users\J C SINGLA\Downloads\External - take_home_challenge_(withJSONs)\take_home_challenge_(withJSONs)\document_questions.xlsx",
)
ground_truth = pd.read_excel(ground_truth_path)
# Alias of the same DataFrame (no copy is made); kept so both names stay available.
ground_truth_text = ground_truth
# Queries come from the "relevant questions" column, labels from the "answer" column.
test_data = list(ground_truth_text["relevant questions"])
test_labels = list(ground_truth_text["answer"])

In [None]:
# Display the number of labelled questions next to the number of context chunks.
len(test_labels), len(all_data_sherpa)

## Model

In [None]:
def normalize_vectors(vectors):
    """Scale every row of a 2-D array to unit Euclidean length.

    Rows whose norm is exactly zero are divided by 1 instead, so they come
    back as all zeros rather than NaN.

    Args:
        vectors: array-like with one vector per row (norm is taken over axis 1).

    Returns:
        numpy.ndarray of the same shape holding the row-normalized values.
    """
    row_norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    # Substitute 1 for zero norms to avoid a division by zero.
    safe_divisors = np.where(row_norms == 0, 1, row_norms)
    return vectors / safe_divisors

In [None]:
# Checkpoint names for the two DRAGON+ encoders.
QUERY_ENCODER_NAME = 'facebook/dragon-plus-query-encoder'
CONTEXT_ENCODER_NAME = 'facebook/dragon-plus-context-encoder'

# The tokenizer is loaded from the query-encoder checkpoint; the cells below use
# this same tokenizer for both the context passages and the questions.
tokenizer = AutoTokenizer.from_pretrained(QUERY_ENCODER_NAME)
query_encoder = AutoModel.from_pretrained(QUERY_ENCODER_NAME)
context_encoder = AutoModel.from_pretrained(CONTEXT_ENCODER_NAME)

In [None]:
# Short aliases (same list objects, no copies): questions and context chunks.
query, contexts = test_data, all_data_sherpa

In [None]:
# Embed every context chunk with the context encoder, one chunk per forward pass,
# keeping the hidden state of the first token as the chunk embedding.
#
# The loop runs under torch.no_grad(): this is inference only, and without it every
# forward pass keeps its autograd graph alive, so memory grows with the number of
# chunks. Per-chunk embeddings are collected in a list and concatenated once,
# instead of re-allocating the full matrix with torch.cat on every iteration.
ctx_chunks = []
with torch.no_grad():
    for i in tqdm.tqdm(range(len(contexts))):
        ctx_input = tokenizer(contexts[i:i+1], padding=True, truncation=True, return_tensors='pt', max_length=512)
        ctx_chunks.append(context_encoder(**ctx_input).last_hidden_state[:, 0, :])
# Fall back to an empty (0, 768) tensor when there are no contexts, matching the
# shape the accumulator was originally initialised with.
ctx_emb = torch.cat(ctx_chunks, dim=0) if ctx_chunks else torch.empty((0, 768))

In [None]:
# Detach the context embeddings from autograd, convert them to a NumPy array,
# then scale every row to unit length.
ctx_emb_np = ctx_emb.detach().numpy()
ctx_emb_ = normalize_vectors(ctx_emb_np)

In [None]:
# Persist the normalized context embeddings under <save_path>/tensors/.
# A dedicated variable is used here: rebinding `save_path` to this file path would
# make the data-loading cell (which passes `save_path` as a directory) fail when it
# is re-run after this cell.
embeddings_path = os.path.join(save_path, "tensors", "dragon_paras_norm.pt")
# Create the target folder if it does not exist yet.
os.makedirs(os.path.dirname(embeddings_path), exist_ok=True)
# NOTE: ctx_emb_ is a NumPy array (output of normalize_vectors), so the file holds
# a pickled ndarray rather than a torch tensor.
torch.save(ctx_emb_, embeddings_path)

## Query vectors

In [None]:
# Embed every test question, one question per forward pass, keeping the hidden
# state of the first token as the query embedding.
#
# Questions must go through `query_encoder`: DRAGON+ ships separate query and
# context encoders, and encoding the questions with `context_encoder` (as this cell
# previously did) leaves the loaded query encoder unused and scores queries with
# the wrong tower.
#
# As with the contexts, the loop runs under torch.no_grad() so no autograd graphs
# are retained, and the per-question embeddings are concatenated once at the end.
query_chunks = []
with torch.no_grad():
    for i in tqdm.tqdm(range(len(test_data))):
        q_input = tokenizer(test_data[i:i+1], padding=True, truncation=True, return_tensors='pt', max_length=512)
        query_chunks.append(query_encoder(**q_input).last_hidden_state[:, 0, :])
# Fall back to an empty (0, 768) tensor when there are no questions, matching the
# shape the accumulator was originally initialised with.
xq = torch.cat(query_chunks, dim=0) if query_chunks else torch.empty((0, 768))

In [None]:
# Detach the query embeddings from autograd, convert them to a NumPy array,
# then scale every row to unit length.
xq_np = xq.detach().numpy()
xq_ = normalize_vectors(xq_np)

In [None]:
# Exact L2 index sized to the embedding width (second dimension of ctx_emb).
embedding_dim = ctx_emb.shape[1]
index = faiss.IndexFlatL2(embedding_dim)

In [None]:
# Add the normalized context embeddings to the FAISS index.
index.add(ctx_emb_)

## Search

In [None]:
# Look up the k nearest indexed context vectors for every normalized query vector.
# `distances` and `indices` are the two arrays returned by index.search.
k = 10
distances, indices = index.search(xq_, k)

In [None]:
# For each question, translate the row ids returned by the search back into the
# corresponding context chunks (one list of retrieved chunks per question).
ret_context = [
    [contexts[ctx_id] for ctx_id in indices[query_pos]]
    for query_pos in range(len(test_data))
]

In [None]:
# Evaluate with the same cut-off `k` that was used for the search, instead of a
# second hard-coded 10 that could drift from the retrieval depth.
metric = Eval(k=k)
# recall_k returns three values; presumably the recall score plus the incorrect and
# correct cases — confirm against src.eval.
recall, incorrect, correct = metric.recall_k(test_labels, ret_context)

In [None]:
# Report both retrieval metrics.
# NOTE(review): mean_reciprocal_rank is called with (ret_context, test_labels),
# the reverse of the (test_labels, ret_context) order passed to recall_k — confirm
# this matches the Eval signatures.
mrr = metric.mean_reciprocal_rank(ret_context, test_labels)
print("Recall is ", recall)
print("MRR is ", mrr)

In [None]:
# Display the entry at position 3 of `incorrect` (as returned by metric.recall_k).
incorrect[3]

In [None]:
# Display the test question at position 3.
# NOTE(review): whether this pairs with incorrect[3] depends on how Eval orders
# `incorrect` — confirm before reading the two side by side.
test_data[3]