In [1]:
import os

import json
import numpy as np


from langchain_community.vectorstores import FAISS
import faiss
import torch
# from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
from langchain_core.prompts import ChatPromptTemplate, FewShotChatMessagePromptTemplate

# from abc import abstractmethod
# import concurrent.futures
# from concurrent.futures import ThreadPoolExecutor

from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from dotenv import load_dotenv
from langchain_core.documents import Document
from typing_extensions import List, TypedDict

from sentence_transformers import SentenceTransformer


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
"""Normally User-Inputted Args"""
dataset = 'musique'
model_label = 'SBERT'
unit = 'hippo'

vector_path = f'data/{dataset}/{dataset}_{model_label}_{unit}_vectors_norm.npy'
index_path = f'data/{dataset}/{dataset}_{model_label}_{unit}_ip_norm.index'

# Corpus file for each supported dataset.
corpus_paths = {
    'musique': 'data/musique_corpus.json',
    '2wikimultihopqa': 'data/2wikimultihopqa_corpus.json',
}

# The encoder is needed on both paths: to embed the corpus when the index has
# to be built, and to embed free-text queries in later cells.
model = SentenceTransformer('bert-base-nli-mean-tokens')

# Check if the index has been built, otherwise build it using Sentence BERT.
if os.path.isfile(index_path):
    # Load exactly the file whose existence was just checked.
    # Note: `sentence_embeddings` is only produced by the build branch below.
    index = faiss.read_index(index_path)
else:
    if dataset not in corpus_paths:
        raise NotImplementedError(f'Dataset {dataset} not implemented')
    with open(corpus_paths[dataset], 'r') as corpus_file:
        corpus = json.load(corpus_file)

    # One text per corpus entry: title on the first line, body below it.
    corpus_contents = [item['title'] + '\n' + item['text'] for item in corpus]
    sentence_embeddings = model.encode(corpus_contents)

    d = sentence_embeddings.shape[1]  # embedding dimension
    nlist = 50  # how many cells
    # NOTE(review): the file name says "ip_norm", but this builds an L2 index
    # over the embeddings as returned by `encode` — confirm which is intended.
    quantizer = faiss.IndexFlatL2(d)
    index = faiss.IndexIVFFlat(quantizer, d, nlist)
    index.train(sentence_embeddings)
    print(f'status of index training: {index.is_trained}')
    index.add(sentence_embeddings)
    print(f'number of embeddings indexed: {index.ntotal}')
    index.nprobe = 10

    # write_index opens the target file itself; it only needs the directory.
    os.makedirs(os.path.dirname(index_path), exist_ok=True)
    faiss.write_index(index, index_path)
    print('index saved to {}'.format(index_path))
    print('index size: {}'.format(index.ntotal))

# Later cells refer to the index under both names; bind them to the same object
# regardless of which branch ran.
faiss_index = index

In [9]:
# Sanity check: query the index with the first five corpus embeddings and show
# the neighbour ids followed by their distances. In the recorded output each
# vector comes back as its own first hit at distance 0.
k = 8
D, I = index.search(sentence_embeddings[:5], k)
print(I, D, sep='\n')

[[   0   15    5 3237 3467 9778 4049   12]
 [   1 1151 5183 1182 9383 5416    2    6]
 [   2 1151    6 1401 1392 1157 1182    9]
 [   3   10   17 1153 1161 4157 1148    7]
 [   4   15 4157 1161 4816 1148 1162 4812]]
[[ 0.       62.610523 68.035484 69.34587  70.68082  75.65885  76.97651
  81.67229 ]
 [ 0.       62.02321  74.02515  75.6939   78.519455 81.68285  82.059555
  85.76392 ]
 [ 0.       55.484562 57.086422 59.82897  61.155815 65.204315 67.402985
  67.54554 ]
 [ 0.       26.270998 37.556023 45.910477 49.71403  50.727356 52.845726
  57.460052]
 [ 0.       78.98981  82.67015  87.23726  89.43833  91.34279  92.65403
  92.92813 ]]


In [None]:
# Embed a single free-text question and print the ids and distances of its
# k nearest entries in the index.
k = 8
query = model.encode(
    ["When was the person who Messi's goals in Copa del Rey compared to get signed by Barcelona?"]
)
D, I = index.search(query, k)
print(I, D, sep='\n')

Recall Evaluation

In [None]:
"""User-inputted arg"""
max_steps = 1


def _load_json(path):
    """Parse the JSON file at `path`, closing the file handle afterwards."""
    with open(path, 'r') as json_file:
        return json.load(json_file)


# Load the evaluation data and its corpus for the selected dataset.
# The corpus is always (re)loaded here so this cell does not depend on whether
# an earlier cell happened to bind `corpus`.
# `max_steps` is set just above, so the per-dataset fallback (4 / 2) only
# applies when it is explicitly set to None.
if dataset == 'musique':
    data = _load_json('data/musique.json')
    corpus = _load_json('data/musique_corpus.json')
    # prompt_path = 'data/ircot_prompts/musique/gold_with_3_distractors_context_cot_qa_codex.txt'
    max_steps = max_steps if max_steps is not None else 4
elif dataset == '2wikimultihopqa':
    data = _load_json('data/2wikimultihopqa.json')
    corpus = _load_json('data/2wikimultihopqa_corpus.json')
    # prompt_path = 'data/ircot_prompts/2wikimultihopqa/gold_with_3_distractors_context_cot_qa_codex.txt'
    max_steps = max_steps if max_steps is not None else 2
else:
    raise NotImplementedError(f'Dataset {dataset} not implemented')

# Number of neighbours requested per query by `retrieve`.
top_k = 100

In [None]:
# Dictionary type passed between pipeline steps: a question string, a list of
# Document objects as context, and an answer string.
State = TypedDict(
    'State',
    {
        'question': str,
        'context': List[Document],
        'answer': str,
    },
)

In [None]:
def retrieve(state: State):
    """Return ids and distances of the `top_k` index entries nearest to the question.

    Args:
        state: Pipeline state; only ``state["question"]`` is read. A plain
            string is embedded with the module-level ``model`` first. Any
            other value is handed to FAISS unchanged, so callers that already
            pass query vectors keep working.

    Returns:
        A ``(ids, distances)`` tuple of lists taken from the first row of the
        arrays returned by ``faiss_index.search``.
    """
    question = state["question"]
    if isinstance(question, str):
        # FAISS searches over vectors, not text, so embed the question into a
        # one-row batch before searching.
        # NOTE(review): `model` must be the same encoder that produced the
        # vectors stored in `faiss_index` — confirm.
        question = model.encode([question])
    D, I = faiss_index.search(question, top_k)
    return I.tolist()[0], D.tolist()[0]

def generate(state: State):
    