In [None]:
from tqdm.notebook import tqdm
from nltk import sent_tokenize
from transformers import AutoTokenizer
import sys
import seaborn as sb
sys.path.append('../..')

from src import *
from src.test_utils import *

# Switch off the memory-efficient and the flash scaled-dot-product-attention CUDA backends.
# NOTE(review): `torch` is never imported by name in this notebook; presumably it is
# re-exported by `from src import *` — confirm. The reason for disabling both backends
# is not recorded here.
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)

In [None]:
# Build the objects every later cell relies on: retriever, document splitter,
# LongDoc indexer and the dataset. Commented-out lines are alternative setups.
# gritlm = GritLM("GritLM/GritLM-7B", device_map="cuda:2", torch_dtype="auto")
# NOTE(review): the device 'cuda:2' is hard-coded; adjust for the machine in use.
retriever = Retriever(device='cuda:2', syn_dist=0.1)
# The splitter is constructed from the retriever's own tokenizer.
doc_split = DocSplit(retriever.retriever_tokenizer)
# llm = LLM()
# `llm` is a plain string here; it is handed unchanged to LongDoc and QualityDataset.
llm = 'mistralai/Mistral-7B-Instruct-v0.2'
# llm = None
longdoc = LongDoc(retriever, llm)
# dataset = NarrativeQADataset(llm)
dataset = QualityDataset(llm, split='dev')
# reading_agent = ReadingAgent(dataset, llm)

In [None]:
# Pick one dataset sample and pull out its question/answer lists and its article.
test_i = 2
sample = dataset.data[test_i]
questions, answers = dataset.get_questions_and_answers(sample)
article = dataset.get_article(sample)
# Keep only the first line of every question string.
questions = list(map(lambda q: q.splitlines()[0], questions))
questions

# Index passages

In [None]:
# Load the cached page split for the selected sample from the dataset directory.
# The file index follows `test_i` (set in the sample-selection cell) instead of a
# literal 2, so changing the sample no longer silently loads another sample's pages.
# NOTE(review): assumes the cache files are named pages_<sample index>.json — confirm.
paragraphs = read_json(os.path.join(dataset.data_dir, f'pages_{test_i}.json'))

# Retrieval

In [None]:
# Select the question to study by its index in `questions` and show it.
qid = 5
question = questions[qid]
print(question)

## Test Code

In [None]:
# Encode questions
# Embed a handful of hand-written probe queries, then for each query compute spans
# with word_split and pass them to merge_words_and_embeddings, collecting the
# returned strings in q_strs and the returned vectors in q_lhs.
queries = ["the Skipper", "the new cook", "advice", "avoiding", "Vesta", "avoiding Vesta"]
q_emb = retriever.embed_paragraphs(queries, normalize=True, complete_return=True)
q_tokenizer = retriever.retriever_tokenizer
q_strs, q_lhs = [], []
# The loop index is `query_i`, not `qid`, so the question index chosen in the
# Retrieval cell is not overwritten by this loop.
for query_i in range(len(q_emb.embeddings)):
    word_spans = word_split(q_emb.input_ids[query_i], q_tokenizer, q_tokenizer.bos_token, q_tokenizer.eos_token)
    temp_q_strs, temp_q_lhs = merge_words_and_embeddings(q_tokenizer, q_emb.input_ids[query_i], q_emb.last_hidden_states[query_i], word_spans, False)
    q_strs.append(temp_q_strs)
    q_lhs.append(temp_q_lhs)

In [None]:
# Split the article into pages, index them with LongDoc, and dump both results to
# JSON files in the current working directory.
# NOTE(review): `512 // 5` evaluates to 102 and is the second argument of
# split_paragraphs; presumably a per-page size budget — confirm against DocSplit.
pages = doc_split.split_paragraphs(article, 512 // 5)
results, raw = longdoc.index_text_into_map(pages, 3)
# Each indexed item is serialised through its own to_json() before writing.
write_json('temp.json', [ci.to_json() for ci in results])
write_json('raw.json', raw)

In [None]:
# Encode pages
# Re-split the article, embed every page, and for each page compute spans with
# sent_split and pass them to merge_words_and_embeddings; the returned strings go
# to p_strs and the returned vectors to p_lhs. Ends with a t-SNE plot of the page
# embeddings.
pages = doc_split.split_paragraphs(article, 512 // 5)
p_emb = retriever.embed_paragraphs(pages, normalize=True, complete_return=True)
p_tok = retriever.retriever_tokenizer
p_strs, p_lhs = [], []
for pid in range(len(p_emb.embeddings)):
    page_ids = p_emb.input_ids[pid]
    page_states = p_emb.last_hidden_states[pid]
    word_spans = sent_split(page_ids, p_tok, p_tok.bos_token, p_tok.eos_token)
    temp_p_strs, temp_p_lhs = merge_words_and_embeddings(p_tok, page_ids, page_states, word_spans, False)
    p_strs.append(temp_p_strs)
    p_lhs.append(temp_p_lhs)
tsne_plot(p_emb.embeddings)

In [None]:
# Heatmap of query-embedding x page-embedding dot products
# (rows: probe queries, columns: page indices). One figure inch per cell.
score_mat = q_emb.embeddings @ p_emb.embeddings.T
n_rows, n_cols = score_mat.shape[0], score_mat.shape[1]
fig, ax = plt.subplots(figsize=(n_cols, n_rows))
sb.heatmap(
    score_mat,
    xticklabels=range(len(pages)),
    yticklabels=queries,
    annot=True,
    ax=ax,
)

In [None]:
# Same query x page score matrix, min-max scaled per row (per query) to [0, 1]
# so that rows with different score ranges can be compared visually.
score_mat = q_emb.embeddings @ p_emb.embeddings.T
score_mat_min = score_mat.min(1, keepdims=True)
score_mat_max = score_mat.max(1, keepdims=True)
# A row whose scores are all equal has zero range; dividing by it would fill the
# row with NaN. Use 1 as the divisor there so such a row becomes all zeros.
score_range = score_mat_max - score_mat_min
score_range[score_range == 0] = 1
score_mat = (score_mat - score_mat_min) / score_range
fig, ax = plt.subplots(figsize=(score_mat.shape[1], score_mat.shape[0]))
sb.heatmap(score_mat, xticklabels=range(len(pages)), yticklabels=queries, annot=True, ax=ax)

In [None]:
# Inspect the shape of the per-row maxima produced by the min-max scaling cell.
score_mat_max.shape

In [None]:
# Print the p_strs entries selected by range(10) using the print_input_ids helper.
print_input_ids(p_strs, range(10))

In [None]:
# Print the p_strs entry selected by index 57 using the print_input_ids helper.
print_input_ids(p_strs, [57])

In [None]:
# Print the pages selected by range(80, 85) using the print_pages helper.
print_pages(pages, range(80, 85))

In [None]:
# Question-page token-sent matching
# xid selects a page (p_lhs / p_strs), yid selects a query (q_lhs / q_strs).
# The *_start / *_end pairs crop the score matrix; (0, None) keeps everything.
xid, yid = 1, 0
x_start, x_end = 0, None
y_start, y_end = 0, None

# L2-normalise every row vector so the matrix product yields cosine similarities.
q_vecs, p_vecs = q_lhs[yid], p_lhs[xid]
q_unit = q_vecs / np.linalg.norm(q_vecs, axis=1, keepdims=True)
p_unit = p_vecs / np.linalg.norm(p_vecs, axis=1, keepdims=True)
score_mat = (q_unit @ p_unit.T)[y_start:y_end, x_start:x_end]

fig, ax = plt.subplots(figsize=(score_mat.shape[1], score_mat.shape[0]))
sb.heatmap(
    score_mat,
    xticklabels=p_strs[xid][x_start:x_end],
    yticklabels=q_strs[yid][y_start:y_end],
    annot=True,
    ax=ax,
)
fig.savefig('qp.pdf')

In [None]:
# Page-page matching
# Dot products between all pairs of page embeddings, drawn as an annotated heatmap
# and saved to pp_all.pdf.
score_mat = p_emb.embeddings @ p_emb.embeddings.T
n_rows, n_cols = score_mat.shape[0], score_mat.shape[1]
fig, ax = plt.subplots(figsize=(n_cols, n_rows))
sb.heatmap(
    score_mat,
    xticklabels=range(n_cols),
    yticklabels=range(n_rows),
    annot=True,
    ax=ax,
)
fig.savefig('pp_all.pdf')

In [None]:
# Page-page sent-sent matching
# Cosine similarity between the row vectors of two pages (xid on the x axis,
# yid on the y axis); the *_start / *_end pairs crop the matrix.
xid, yid = 9, 10
x_start, x_end = 0, None
y_start, y_end = 0, None

y_vecs, x_vecs = p_lhs[yid], p_lhs[xid]
y_unit = y_vecs / np.linalg.norm(y_vecs, axis=1, keepdims=True)
x_unit = x_vecs / np.linalg.norm(x_vecs, axis=1, keepdims=True)
score_mat = (y_unit @ x_unit.T)[y_start:y_end, x_start:x_end]

n_rows, n_cols = score_mat.shape[0], score_mat.shape[1]
# Half an inch per cell; tick labels are absolute row indices, offset by the crop start.
fig, ax = plt.subplots(figsize=(n_cols/2, n_rows/2))
sb.heatmap(
    score_mat,
    xticklabels=range(x_start, n_cols + x_start),
    yticklabels=range(y_start, n_rows + y_start),
    annot=True,
    ax=ax,
)
fig.savefig('pp.pdf')
print('x:\n', pages[xid])
print('y:\n', pages[yid])

In [None]:
# Look up one cell of the current score matrix and the two p_strs entries it pairs.
# Relies on `score_mat`, `xid` and `yid` as left by the previously executed matching cell.
test_x_sent = 2
test_y_sent = 2
# Rows of score_mat follow yid, columns follow xid.
print(score_mat[test_y_sent, test_x_sent])
print(p_strs[xid][test_x_sent])
print(p_strs[yid][test_y_sent])

In [None]:
# Run slide_encode over the pages twice, with last argument 3 and then 1, keeping
# the embeddings and hidden states of each run under separate names.
# NOTE(review): the last argument is presumably the sliding-window size — confirm.
p_input_ids, pid2embs_3, pid2lhs_3 = slide_encode(pages, retriever, 3)
# pid2embs_5 = slide_encode(pages, retriever, 5)
# p_input_ids is reassigned here; only the value from this second call survives.
p_input_ids, pid2embs_1, pid2lhs_1 = slide_encode(pages, retriever, 1)

In [None]:
# Blend the two slide_encode runs (last argument 3 and 1) with fixed weights.
# With these weights only the second run contributes.
# p_weight_5 = np.array([[0., 0., 0.3, 0., 0.]])
p_weight_3 = np.array([0., 0., 0.])
p_weight_1 = np.array([1.0])
# p_embeddings = np.vstack([(p_weight_5 @ embs_5)[0] + (p_weight_1 @ embs_1)[0] for embs_5, embs_1 in zip(pid2embs_5, pid2embs_1)])

# Row-vector view for the embedding product, (k, 1, 1) view for broadcasting over hidden states.
w3_row, w1_row = p_weight_3[None, :], p_weight_1[None, :]
w3_cube, w1_cube = p_weight_3[:, None, None], p_weight_1[:, None, None]

# Page embeddings: weighted sum of each run's embeddings.
p_embeddings = np.vstack([
    (w3_row @ embs_3)[0] + (w1_row @ embs_1)[0]
    for embs_3, embs_1 in zip(pid2embs_3, pid2embs_1)
])
# Hidden states: weighted, then averaged over the first axis of each run.
p_lhs = [
    (lhs_3 * w3_cube).mean(0) + (lhs_1 * w1_cube).mean(0)
    for lhs_3, lhs_1 in zip(pid2lhs_3, pid2lhs_1)
]
tsne_plot(p_embeddings, 4)

In [None]:
# Blend the two slide_encode runs (last argument 3 and 1) with fixed weights:
# 0.5 on the middle entry of the first run and 0.5 on the second run.
# p_weight_5 = np.array([[0., 0., 0.3, 0., 0.]])
p_weight_3 = np.array([0., 0.5, 0.])
p_weight_1 = np.array([0.5])
# p_embeddings = np.vstack([(p_weight_5 @ embs_5)[0] + (p_weight_1 @ embs_1)[0] for embs_5, embs_1 in zip(pid2embs_5, pid2embs_1)])

# Row-vector view for the embedding product, (k, 1, 1) view for broadcasting over hidden states.
w3_row, w1_row = p_weight_3[None, :], p_weight_1[None, :]
w3_cube, w1_cube = p_weight_3[:, None, None], p_weight_1[:, None, None]

# Page embeddings: weighted sum of each run's embeddings.
p_embeddings = np.vstack([
    (w3_row @ embs_3)[0] + (w1_row @ embs_1)[0]
    for embs_3, embs_1 in zip(pid2embs_3, pid2embs_1)
])
# Hidden states: weighted, then averaged over the first axis of each run.
p_lhs = [
    (lhs_3 * w3_cube).mean(0) + (lhs_1 * w1_cube).mean(0)
    for lhs_3, lhs_1 in zip(pid2lhs_3, pid2lhs_1)
]
tsne_plot(p_embeddings, 4)

In [None]:
# Page-page matching
# Cosine similarity between all pairs of blended page embeddings, drawn as an
# annotated heatmap and saved to pp_all.pdf.
p_row_norms = np.linalg.norm(p_embeddings, axis=1, keepdims=True)
normalized_p_embeddings = p_embeddings / p_row_norms
score_mat = normalized_p_embeddings @ normalized_p_embeddings.T
n_rows, n_cols = score_mat.shape[0], score_mat.shape[1]
fig, ax = plt.subplots(figsize=(n_cols, n_rows))
sb.heatmap(
    score_mat,
    xticklabels=range(n_cols),
    yticklabels=range(n_rows),
    annot=True,
    ax=ax,
)
fig.savefig('pp_all.pdf')

In [None]:
# Read a single entry (row 60, column 20) of the current score matrix.
score_mat[60, 20]

In [None]:
# Print the pages selected by range(60, 70) using the print_pages helper.
print_pages(pages, range(60, 70))

In [None]:
# Page-page sent-sent matching
# Compute sent_split spans for two pages and hand both pages (ids, hidden states,
# spans) to plot_score_matrix.
xid, yid = 60, 20

pp_tok = retriever.retriever_tokenizer
x_word_spans = sent_split(p_input_ids[xid], pp_tok)
y_word_spans = sent_split(p_input_ids[yid], pp_tok)
plot_score_matrix(
    pp_tok,
    p_input_ids[xid], p_lhs[xid], x_word_spans,
    p_input_ids[yid], p_lhs[yid], y_word_spans,
    False, False,
)

In [None]:
def norm(x):
    """Scale ``x`` to unit Euclidean length.

    Args:
        x: NumPy array (any shape); its norm is taken over all elements,
            as computed by ``np.linalg.norm(x)``.

    Returns:
        ``x`` divided by its norm. If the norm is exactly zero, ``x`` is
        returned as all zeros instead of being divided by zero (which would
        yield NaN and a runtime warning).
    """
    length = np.linalg.norm(x)
    # `!= 0` (rather than `> 0`) keeps a NaN norm flowing through unchanged,
    # so NaN inputs still produce NaN outputs as before.
    divisor = length if length != 0 else 1.0
    return x / divisor

In [None]:
# Cosine similarity between the mean vector of rows 0:21 of page yid and the mean
# vector of rows 33:77 of page xid (each mean is unit-normalised with norm()).
# The two ranges match the commented-out y/x start-end values in the next cell.
norm(p_lhs[yid][0:21].mean(0)).dot(norm(p_lhs[xid][33:77].mean(0)))

In [None]:
# Page-page sent-sent matching
# Builds a score matrix between two pages from the outputs of
# merge_words_and_embeddings, plots it, saves it to pp.pdf and prints both string lists.
xid, yid = 60, 20
# x_start, x_end = 33, 77
# y_start, y_end = 0, 21

# (0, None) keeps the full matrix; the commented values above crop to a sub-range.
x_start, x_end = 0, None
y_start, y_end = 0, None

# NOTE(review): x_word_spans / y_word_spans are computed but never used below —
# merge_words_and_embeddings receives an empty list `[]` as its spans argument.
# Confirm whether that is intentional.
x_word_spans = word_split(p_input_ids[xid], retriever.retriever_tokenizer)
x_strs, x_lhs = merge_words_and_embeddings(retriever.retriever_tokenizer, p_input_ids[xid], p_lhs[xid], [], False, True)
y_word_spans = word_split(p_input_ids[yid], retriever.retriever_tokenizer)
y_strs, y_lhs = merge_words_and_embeddings(retriever.retriever_tokenizer, p_input_ids[yid], p_lhs[yid], [], False, True)

# No explicit normalisation here, unlike the earlier matching cells.
# NOTE(review): presumably the trailing `True` flag makes the helper return
# normalised vectors — confirm.
score_mat = (y_lhs) @ (x_lhs).T
score_mat = score_mat[y_start:y_end, x_start:x_end]
fig, ax = plt.subplots(figsize=(score_mat.shape[1], score_mat.shape[0]))
sb.heatmap(score_mat, xticklabels=x_strs[x_start:score_mat.shape[1] + x_start], yticklabels=y_strs[y_start:score_mat.shape[0] + y_start], annot=True, ax=ax)
# Writes to the same file name ('pp.pdf') as the earlier page-page matching cell,
# so that figure is overwritten.
fig.savefig('pp.pdf')
print('x:\n', x_strs)
print('y:\n', y_strs)

In [None]:
# Average stored per-page arrays into hidden states, derive page embeddings from
# them, normalise both by the embedding norm, and run dense retrieval for the queries.
# NOTE(review): `pid2embs` is not defined anywhere in the visible notebook (only
# pid2embs_1 / pid2embs_3 and pid2lhs_1 / pid2lhs_3 are); on a fresh kernel this
# cell raises NameError. Confirm which variable was intended.
p_lhs = [np.array(pid2embs[pid]).mean(0) for pid in range(len(pid2embs))]
# One embedding per page: mean over the first axis of that page's averaged array.
p_embeddings = np.array([lhs.mean(0) for lhs in p_lhs])
p_norm = np.linalg.norm(p_embeddings, axis=1)
p_embeddings = p_embeddings / np.expand_dims(p_norm, 1)
# Each page's hidden states are divided by that page's embedding norm (a scalar).
p_lhs = [lhs / n for lhs, n in zip(p_lhs, p_norm)]
# normalize=False is passed because the embeddings were normalised just above.
pids, scores = retriever.dense_retrieval(q_emb.embeddings, p_embeddings, None, normalize=False, return_score=True)
pids

In [None]:
# Run the query_distribution helper for the first query (index 0 of q_emb) against p_lhs.
# NOTE(review): when cells run top to bottom, `word_spans` was last assigned in the
# page-encoding loop (sent_split on the last page), not from this query — confirm
# that the spans passed as q_spans belong to query 0.
query_distribution(retriever.retriever_tokenizer, q_emb.last_hidden_states[0], q_emb.input_ids[0], p_lhs, 5, q_spans=word_spans[3:-1])

In [None]:
# Run the query_indicatiors helper for `question` over the 'passage: '-prefixed pages,
# restricted to the first 10 retrieved page ids.
# NOTE(review): the hidden states / input ids passed are those of q_emb entry 0 (a
# probe query), while the text passed is `question`; `word_spans` also comes from
# whichever loop ran last. Confirm these inputs are meant to go together.
query_indicatiors(retriever.retriever_tokenizer, question, [f'passage: {p}' for p in pages], q_emb.last_hidden_states[0], q_emb.input_ids[0], p_lhs, p_input_ids, pids[:10], scores, 5, q_spans=word_spans)

In [None]:
# Re-embed the pages with a 'passage: ' prefix and rank them against the query
# embeddings; normalize=False is passed to dense_retrieval because both sides were
# embedded with normalize=True.
prefixed_pages = [f'passage: {p}' for p in pages]
p_emb = retriever.embed_paragraphs(prefixed_pages, normalize=True, complete_return=True)
pids, scores = retriever.dense_retrieval(
    q_emb.embeddings,
    p_emb.embeddings,
    None,
    normalize=False,
    return_score=True,
)
pids

In [None]:
# Run the query_indicatiors helper for `question` using the directly embedded
# 'passage: '-prefixed pages (p_emb) and the retrieval result from the previous cell.
# NOTE(review): `word_spans` is whatever the most recently executed encoding loop
# left behind — confirm it matches q_emb entry 0.
query_indicatiors(retriever.retriever_tokenizer, question, [f'passage: {p}' for p in pages], q_emb.last_hidden_states[0], q_emb.input_ids[0], p_emb.last_hidden_states, p_emb.input_ids, pids, scores, q_spans=word_spans)

In [None]:
%matplotlib widget

import matplotlib.pyplot as plt
import numpy as np

from matplotlib.widgets import Cursor

# Fixing random state for reproducibility
np.random.seed(19680801)

fig, ax = plt.subplots(figsize=(8, 6))

x, y = 4*(np.random.rand(2, 100) - .5)
ax.plot(x, y, 'o')
ax.set_xlim(-2, 2)
ax.set_ylim(-2, 2)

# Set useblit=True on most backends for enhanced performance.
cursor = Cursor(ax, useblit=True, color='red', linewidth=2)

plt.show()