In [None]:
from tqdm.notebook import tqdm
from nltk import sent_tokenize
from transformers import AutoTokenizer
import sys
sys.path.append('../..')

from src import *
from src.test_utils import * 

import torch  # explicit import: otherwise `torch` is only reachable through `from src import *`

# Turn off the memory-efficient and flash scaled-dot-product-attention CUDA backends,
# so PyTorch picks among the remaining SDPA implementations.
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)

In [None]:
# Standalone tokenizer for the multilingual E5 model; the next cell inspects its length limit.
E5_TOKENIZER_NAME = 'intfloat/multilingual-e5-large'
tokenizer = AutoTokenizer.from_pretrained(E5_TOKENIZER_NAME)

In [None]:
# Maximum input length (in tokens) that the E5 tokenizer reports.
tokenizer.model_max_length

In [None]:
# Models used by the experiments below.
# `gritlm` is needed by the token-count cell and every GritLM retrieval cell, so it must be
# defined on a fresh kernel. The `globals()` guard skips the expensive reload when this
# cell is re-run in the same session.
GRITLM_NAME = "GritLM/GritLM-7B"
GRITLM_DEVICE = "cuda:2"  # machine-specific GPU index; adjust for your setup
retriever = Retriever()
if 'gritlm' not in globals():
    gritlm = GritLM(GRITLM_NAME, device_map=GRITLM_DEVICE, torch_dtype="auto")
# Alternatives for the generator: llm = LLM() or llm = 'mistralai/Mistral-7B-Instruct-v0.2'
llm = None  # no generator model is loaded here
longdoc = LongDoc(retriever, llm)

# Experiment

In [None]:
# NOTE(review): `dataset` is reassigned by the QualityDataset cell that follows, so under
# Run All this NarrativeQA load is discarded. Run only the dataset cell you need.
dataset = NarrativeQADataset(llm)

In [None]:
# Dataset used by the rest of the notebook: QualityDataset, 'dev' split
# (replaces any dataset assigned in the cell above).
dataset = QualityDataset(llm, split='dev')

In [None]:
# ReadingAgent over the loaded dataset.
# NOTE(review): `reading_agent` is not referenced by any later cell visible here — confirm it is still needed.
reading_agent = ReadingAgent(dataset, llm)

In [None]:
# Load the pre-built chunk index for document `test_i`, build its relation graph,
# keep the raw passage text of each chunk, and collect the document's questions
# (only the first line of each question is kept).
test_i = 2
index_path = os.path.join(dataset.data_dir, f'index_wg_2_{test_i}.json')
results = list(map(lambda record: ChunkInfo(**record), read_json(index_path)))
relation_graph = longdoc.build_relation_graph(results)
pages = [chunk.passage for chunk in results]
questions, answers = dataset.get_questions_and_answers(dataset.data[test_i])
questions = [question.splitlines()[0] for question in questions]
questions

In [None]:
# Print each chunk's passage length in GritLM tokens, one line per chunk.
for ci in results:
    token_ids = gritlm.tokenizer(ci.passage)['input_ids']
    print(len(token_ids))

# Index passages

In [None]:
# Raw pages of the document whose index is loaded above. Each page record (an iterable of
# strings) is joined with newlines. `test_i` is used instead of a fixed document id so these
# pages match `results` and `questions`.
paragraphs = ['\n'.join(p) for p in read_json(os.path.join(dataset.data_dir, f'pages_{test_i}.json'))]

## Eval

In [None]:
# Inspect chunk 11 of the loaded index (`ChunkInfo.print` is a project helper).
results[11].print()

In [None]:
# `prev_summaries` stored on chunk 11 (presumably summaries of the chunks preceding it — confirm in ChunkInfo).
results[11].prev_summaries

In [None]:
# Recap text of chunk 11; the same `recap_str` property feeds the retrieval prompts in the cells below.
print(results[11].recap_str)

# Retrieval

In [None]:
# Index into `questions` of the question used by every retrieval comparison below.
qid = 8
print(questions[qid])

## Eval

### Contriever

#### Query Encode With Note, Doc Encode Without Note

In [None]:
# Contriever: the query is prefixed with a recap built from the entity/relation descriptions
# retrieved for the question; pages are embedded without a prefix. The prefix masks passed to
# `hidden_states_wo_instruction` (recap for the query, '' for pages) presumably tell it which
# leading tokens to drop — confirm in the helper.
question = questions[qid]
ent_candidates = longdoc.collect_entities_from_text(question)
prev_ent_descriptions, prev_relation_descriptions = longdoc.retrieve_descriptions(
    results, relation_graph, ent_candidates, 1, 2
)
q_info = ChunkInfo(
    len(results),
    question,
    prev_ent_descriptions=prev_ent_descriptions,
    prev_relation_descriptions=prev_relation_descriptions,
)
recap_str = f'Recap:\n{q_info.recap_str}\n\nQuery:\n'
full_input = recap_str + question
c_retriever_tokenizer = retriever.retriever_tokenizer
print(len(c_retriever_tokenizer(full_input)['input_ids']))

q_embedding = retriever.embed_paragraphs([full_input], normalize=False, complete_return=True)
page_embeddings = retriever.embed_paragraphs(pages, normalize=False, complete_return=True)

c_q_input_ids, c_q_emb, c_q_lhs = hidden_states_wo_instruction(
    q_embedding.input_ids.copy(),
    q_embedding.last_hidden_states.copy(),
    q_embedding.attention_mask.copy(),
    c_retriever_tokenizer([recap_str])['attention_mask'],
    True,
)
c_p_input_ids, c_p_emb, c_p_lhs = hidden_states_wo_instruction(
    page_embeddings.input_ids.copy(),
    page_embeddings.last_hidden_states.copy(),
    page_embeddings.attention_mask.copy(),
    c_retriever_tokenizer([''])['attention_mask'],
    True,
)
c_pids, c_scores = retriever.dense_retrieval(c_q_emb, c_p_emb, None, normalize=False, return_score=True)
q_spans = word_split(c_q_input_ids[0], c_retriever_tokenizer, False, True)
query_indicatiors(
    c_retriever_tokenizer, question, pages,
    c_q_lhs[0], c_q_input_ids[0],
    c_p_lhs, c_p_input_ids,
    c_pids, c_scores,
    q_spans=q_spans,
)

#### Query Encode Without Note, Doc Encode Without Note

In [None]:
# Contriever baseline: neither the query nor the pages get a recap prefix.
recap_str = ''  # empty prefix, so the query input is just the bare question
full_input = recap_str + questions[qid]
c_retriever_tokenizer = retriever.retriever_tokenizer
print(len(c_retriever_tokenizer(full_input)['input_ids']))

q_embedding = retriever.embed_paragraphs([full_input], normalize=False, complete_return=True)
page_embeddings = retriever.embed_paragraphs(pages, normalize=False, complete_return=True)

c_q_input_ids, c_q_emb, c_q_lhs = hidden_states_wo_instruction(
    q_embedding.input_ids.copy(),
    q_embedding.last_hidden_states.copy(),
    q_embedding.attention_mask.copy(),
    c_retriever_tokenizer([recap_str])['attention_mask'],
    True,
)
c_p_input_ids, c_p_emb, c_p_lhs = hidden_states_wo_instruction(
    page_embeddings.input_ids.copy(),
    page_embeddings.last_hidden_states.copy(),
    page_embeddings.attention_mask.copy(),
    c_retriever_tokenizer([''])['attention_mask'],
    True,
)
c_pids, c_scores = retriever.dense_retrieval(c_q_emb, c_p_emb, None, normalize=False, return_score=True)
# q_spans is computed but not passed to query_indicatiors in this variant.
q_spans = word_split(c_q_input_ids[0], c_retriever_tokenizer, False, True)
query_indicatiors(
    c_retriever_tokenizer, questions[qid], pages,
    c_q_lhs[0], c_q_input_ids[0],
    c_p_lhs, c_p_input_ids,
    c_pids, c_scores,
)  # q_spans=q_spans omitted

### GritLM

#### Query Encode With Note, Doc Encode With Note

In [None]:
# GritLM with notes on both sides: the query's instruction is built from the question's recap,
# and each page's instruction from that chunk's recap (both via LongDocPrompt.embed_w_note).
ent_candidates = longdoc.collect_entities_from_text(questions[qid])
prev_ent_descriptions, prev_relation_descriptions = longdoc.retrieve_descriptions(
    results, relation_graph, ent_candidates, 4, True
)
q_info = ChunkInfo(
    len(results),
    questions[qid],
    prev_ent_descriptions=prev_ent_descriptions,
    prev_relation_descriptions=prev_relation_descriptions,
)
query_instructions = [LongDocPrompt.embed_w_note(q_info.recap_str, 'query')]
g_q_emb, g_q_input_ids, g_q_lhs = gritlm.encode([questions[qid]], max_length=8192, instructions=query_instructions)

page_instructions = [LongDocPrompt.embed_w_note(chunk.recap_str, 'passage') for chunk in results]
g_p_emb, g_p_input_ids, g_p_lhs = gritlm.encode(pages, batch_size=5, max_length=8192, instructions=page_instructions)
g_retriever_tokenizer = gritlm.tokenizer
# q_spans is computed but not passed to query_indicatiors in this variant.
q_spans = word_split(g_q_input_ids[0], g_retriever_tokenizer)
g_pids, g_scores = retriever.dense_retrieval(g_q_emb, g_p_emb, None, normalize=False, return_score=True)
query_indicatiors(
    g_retriever_tokenizer, questions[qid], pages,
    g_q_lhs[0], g_q_input_ids[0],
    g_p_lhs, g_p_input_ids,
    g_pids, g_scores,
)  # q_spans=q_spans omitted

#### Query Encode Without Note, Doc Encode With Note

In [None]:
# GritLM with a note on pages only: the bare question is encoded without an explicit
# `instructions` argument, while each page's instruction carries its chunk recap.
# (Query instruction tried before and left disabled:
#  "Retrieve relevant passages from a story to answer a given question.")
g_q_emb, g_q_input_ids, g_q_lhs = gritlm.encode([questions[qid]])

page_instructions = [LongDocPrompt.embed_w_note(chunk.recap_str, 'passage') for chunk in results]
g_p_emb, g_p_input_ids, g_p_lhs = gritlm.encode(pages, batch_size=5, max_length=8192, instructions=page_instructions)
g_retriever_tokenizer = gritlm.tokenizer
# q_spans is computed but not passed to query_indicatiors in this variant.
q_spans = word_split(g_q_input_ids[0], g_retriever_tokenizer)
g_pids, g_scores = retriever.dense_retrieval(g_q_emb, g_p_emb, None, normalize=False, return_score=True)
query_indicatiors(
    g_retriever_tokenizer, questions[qid], pages,
    g_q_lhs[0], g_q_input_ids[0],
    g_p_lhs, g_p_input_ids,
    g_pids, g_scores,
)  # q_spans=q_spans omitted

#### Query Encode With Note, Doc Encode Without Note

In [None]:
# Retrieve the entity/relation descriptions for the current question so the next two cells
# can display them (same retrieve_descriptions arguments, 1 and 2, as the GritLM cell that follows).
ent_candidates = longdoc.collect_entities_from_text(questions[qid])
prev_ent_descriptions, prev_relation_descriptions = longdoc.retrieve_descriptions(results, relation_graph, ent_candidates, 1, 2)

In [None]:
# Entity descriptions retrieved for questions[qid].
prev_ent_descriptions

In [None]:
# Relation descriptions retrieved for questions[qid].
prev_relation_descriptions

In [None]:
# GritLM with a note on the query only: the query instruction is a GritLM instruction followed
# by the question's recap; pages are encoded without instructions.
ent_candidates = longdoc.collect_entities_from_text(questions[qid])
prev_ent_descriptions, prev_relation_descriptions = longdoc.retrieve_descriptions(
    results, relation_graph, ent_candidates, 1, 2
)
q_info = ChunkInfo(
    len(results),
    questions[qid],
    prev_ent_descriptions=prev_ent_descriptions,
    prev_relation_descriptions=prev_relation_descriptions,
)
task_text = 'Use the recap context to help you understand the query and retrieve relevant passages from a story to answer the query.'
instruction = gritlm.gritlm_instruction(task_text)
recap_str = f'{instruction}\nRecap:\n{q_info.recap_str}\n\nQuery:\n'
# Token length of the full query input (instruction + recap + question).
print(len(gritlm.tokenizer(recap_str + questions[qid])['input_ids']))
g_q_emb, g_q_input_ids, g_q_lhs = gritlm.encode([questions[qid]], max_length=8192, instructions=[recap_str])

g_p_emb, g_p_input_ids, g_p_lhs = gritlm.encode(pages, max_length=8192)
g_retriever_tokenizer = gritlm.tokenizer
q_spans = word_split(g_q_input_ids[0], g_retriever_tokenizer)
g_pids, g_scores = retriever.dense_retrieval(g_q_emb, g_p_emb, None, normalize=False, return_score=True)
query_indicatiors(
    g_retriever_tokenizer, questions[qid], pages,
    g_q_lhs[0], g_q_input_ids[0],
    g_p_lhs, g_p_input_ids,
    g_pids, g_scores,
    q_spans=q_spans,
)

In [None]:
# Page-token inspection for the GritLM retrieval just run. It reuses the g_* query/page
# outputs of the previous cell (`important_page_tokens` is a project helper).
important_page_tokens(g_retriever_tokenizer, questions[qid], pages, g_q_lhs[0], g_q_input_ids[0], g_q_emb[0], g_p_lhs, g_p_input_ids, g_pids, g_scores)

#### Query Encode Without Note, Doc Encode Without Note

In [None]:
# GritLM baseline: the query gets a plain task instruction (no recap); pages get no instruction.
task_instruction = gritlm.gritlm_instruction("Retrieve relevant passages from a story to answer a given question.")
g_q_emb, g_q_input_ids, g_q_lhs = gritlm.encode([questions[qid]], instructions=[task_instruction])

g_p_emb, g_p_input_ids, g_p_lhs = gritlm.encode(pages, max_length=8192)
g_retriever_tokenizer = gritlm.tokenizer
q_spans = word_split(g_q_input_ids[0], g_retriever_tokenizer)
g_pids, g_scores = retriever.dense_retrieval(g_q_emb, g_p_emb, None, normalize=False, return_score=True)
query_indicatiors(
    g_retriever_tokenizer, questions[qid], pages,
    g_q_lhs[0], g_q_input_ids[0],
    g_p_lhs, g_p_input_ids,
    g_pids, g_scores,
    q_spans=q_spans,
)

## Test Code

In [None]:
# Re-display chunk 11 (same call as the inspection cell in the Eval section).
results[11].print()

In [None]:
%matplotlib widget

import matplotlib.pyplot as plt
import numpy as np

from matplotlib.widgets import Cursor

# Fixing random state for reproducibility
np.random.seed(19680801)

fig, ax = plt.subplots(figsize=(8, 6))

x, y = 4*(np.random.rand(2, 100) - .5)
ax.plot(x, y, 'o')
ax.set_xlim(-2, 2)
ax.set_ylim(-2, 2)

# Set useblit=True on most backends for enhanced performance.
cursor = Cursor(ax, useblit=True, color='red', linewidth=2)

plt.show()