In [1]:
from tqdm.notebook import tqdm
from nltk import sent_tokenize
from transformers import AutoTokenizer
import sys
import seaborn as sb
sys.path.append('../..')
from rank_bm25 import BM25Okapi
from spacy.tokens import Span, Doc

from src import *
from src.test_utils import *
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)

In [2]:
# gritlm = GritLM("GritLM/GritLM-7B", device_map="cuda:2", torch_dtype="auto")
retriever = Retriever(device='cuda:2', syn_dist=0.1)
doc_split = DocSplit(retriever.retriever_tokenizer)
# llm = LLM()
llm = 'mistralai/Mistral-7B-Instruct-v0.2'
# llm = None
longdoc = LongDoc(retriever, llm)
# dataset = NarrativeQADataset(llm)
dataset = QualityDataset(llm, split='dev')
# reading_agent = ReadingAgent(dataset, llm)

In [3]:
test_i = 2
sample = dataset.data[test_i]
questions, answers = dataset.get_questions_and_answers(sample)
article = dataset.get_article(sample)
questions = [q.splitlines()[0] for q in questions]
questions

['What is the most likely meaning of the slang O.Q.? (in twentieth-century American English)',
 'Why does the Skipper stop abruptly after he says "when you\'re running a blockade"?',
 'Who or what is Leo?',
 'Why does the Skipper allow the new chef to use the heat-cannon as an incinerator?',
 ' Lieutenant Dugan brings up the examples of "High G" Gordon and "Runt" Hake in order to illustrates that...',
 "Why didn't the Skipper follow the new cook's advice about avoiding Vesta?",
 'Why was the new cook so upset that the Skipper decided to surrender?',
 'What does the Skipper mean by "lady-logic"?',
 "What would've happened if the new cook had told the Skipper about the ekalastron deposits earlier?"]

# Index passages

In [None]:
# Load the pre-split pages for the currently selected sample. The index used to be the
# literal 2, which silently ignored any change to test_i.
# NOTE(review): assumes pages_<N>.json is keyed by the dataset sample index — confirm.
paragraphs = read_json(os.path.join(dataset.data_dir, f'pages_{test_i}.json'))

# Retrieval

In [None]:
qid = 5
question = questions[qid]
print(question)

## Test Code

In [None]:
pages = doc_split.split_paragraphs(article, 512 // 5)
results, raw = longdoc.index_text_into_map(pages, 3)
write_json('temp.json', [ci.to_json() for ci in results])
write_json('raw.json', raw)

### Test Navigation

In [4]:
pages = doc_split.split_paragraphs(article, 50)
all_summary = longdoc.lossless_index(pages, 5, 5, 5, 'relation')
write_json('all_summary.json', all_summary)

37
8


In [None]:
def build_summary_pyramid(
    longdoc: 'LongDoc',
    summaries: List[str],
    summary_chunk_num: int = 5,
    prev_chunk_num: int = 5,
    post_chunk_num: int = 5):
    """Plan the context windows for one level of a summary pyramid.

    ``summaries`` is cut into consecutive batches of ``summary_chunk_num`` items.
    Each batch is padded with up to ``prev_chunk_num`` preceding and up to
    ``post_chunk_num`` following summaries, which serve as context.

    Args:
        longdoc: reserved for the (not yet implemented) summarization step; unused.
        summaries: ordered chunk summaries.
        summary_chunk_num: number of summaries per batch; must be positive.
        prev_chunk_num: maximum number of context summaries before each batch.
        post_chunk_num: maximum number of context summaries after each batch.

    Returns:
        A list of ``(prev_start, chunk_start, chunk_end, post_end)`` tuples, one
        per batch. ``summaries[prev_start:post_end]`` is the whole window.
        ``chunk_start``/``chunk_end`` are offsets *relative to* ``prev_start``
        that mark the batch itself inside that window.

    Raises:
        ValueError: if ``summary_chunk_num`` is not positive.
    """
    if summary_chunk_num <= 0:
        raise ValueError(f'summary_chunk_num must be positive, got {summary_chunk_num}')
    n_summaries = len(summaries)
    batched_summary_ranges = []
    for batch_start in range(0, n_summaries, summary_chunk_num):
        prev_start = max(batch_start - prev_chunk_num, 0)
        batch_end = min(batch_start + summary_chunk_num, n_summaries)
        post_end = min(batch_end + post_chunk_num, n_summaries)
        # Offsets of the batch inside the window summaries[prev_start:post_end].
        chunk_start = batch_start - prev_start
        chunk_end = batch_end - prev_start
        batched_summary_ranges.append((prev_start, chunk_start, chunk_end, post_end))
    # TODO: the original cell ended with an empty
    # `for prev_start, chunk_start, chunk_end, post_end in batched_summary_ranges:`
    # loop (a SyntaxError). The per-window summarization step using `longdoc`
    # was never written, so only the window plan is returned for now.
    return batched_summary_ranges

In [None]:
pages = doc_split.split_paragraphs(article, 50)
all_summary = read_json('all_summary.json')

### TextGraph

In [None]:
def remove_unimportant(doc: 'Span', additional_pos_labels: 'Set[str]' = frozenset()):
    """Split a span at "unimportant" tokens and return the surviving pieces.

    Tokens whose ``pos_`` is DET, PRON, CCONJ, PUNCT, AUX or PART, plus any tag in
    ``additional_pos_labels``, act as separators. They are dropped, and every
    maximal run of remaining tokens is returned as a slice of ``doc``.

    Args:
        doc: a spaCy ``Span`` (or anything iterable/sliceable whose items expose
            ``pos_``).
        additional_pos_labels: extra POS tags to treat as separators. The default
            is an immutable empty set, which avoids the shared-mutable-default pitfall.

    Returns:
        List of non-empty slices of ``doc``, in their original order.
    """
    separator_pos = {'DET', 'PRON', 'CCONJ', 'PUNCT', 'AUX', 'PART'}.union(additional_pos_labels)
    pieces = []
    piece_start = 0
    for tid, token in enumerate(doc):
        if token.pos_ in separator_pos:
            # Close the current run (if any) and start a new one after the separator.
            if piece_start < tid:
                pieces.append(doc[piece_start:tid])
            piece_start = tid + 1
    if piece_start < len(doc):
        pieces.append(doc[piece_start:len(doc)])
    return pieces

def collect_keywords_from_text(doc:Doc):
    """Return the set of lemmatized keyword phrases found in a parsed document.

    The function proceeds in three steps:

    1. Noun chunks and named entities are walked together in document order. When
       the current chunk and entity overlap, they are fused into one span covering
       both. Otherwise whichever one comes first is kept unchanged.
    2. Every resulting span is split with ``remove_unimportant``, which also drops
       adjectives and adverbs.
    3. Each remaining piece becomes a space-joined string of token lemmas. Lemmas
       are not lower-cased.
    """
    chunks = list(doc.noun_chunks)
    entities = doc.ents
    i = j = 0
    merged: List[Span] = []
    while i < len(chunks) and j < len(entities):
        chunk, entity = chunks[i], entities[j]
        # Two half-open token ranges intersect iff max(starts) < min(ends).
        if max(chunk.start, entity.start) < min(chunk.end, entity.end):
            merged.append(doc[min(chunk.start, entity.start) : max(chunk.end, entity.end)])
            i += 1
            j += 1
        elif chunk.start < entity.end:
            merged.append(chunk)
            i += 1
        else:
            merged.append(entity)
            j += 1
    # Whatever is left of either sequence is taken as-is.
    merged += chunks[i:]
    merged += entities[j:]
    pieces = [piece for span in merged for piece in remove_unimportant(span, {'ADJ', 'ADV'})]
    return {' '.join(token.lemma_ for token in piece) for piece in pieces}
    

In [None]:
class TextGraph:
    """Graph over document pages connected by shared keywords.

    Attributes:
        text_graph: ``nx.DiGraph``. Node ``pid`` (the page's index in ``docs``)
            stores ``tokenized_page`` and ``nouns`` (its keyword set). A directed
            edge ``pid1 -> pid2`` exists whenever the two pages share at least one
            keyword. It stores ``overlap``, ``ent_importance``, ``dist``,
            ``bm25_score`` and the combined ``weight``.
        ent_graph: ``nx.Graph`` of keyword co-occurrence. An edge links two
            keywords seen on the same page, with ``log_freq = log(count + 1)``.
        tokenized_corpus: lower-cased token lemmas of every page (the BM25 input).
        ent_general_importance: PageRank of each keyword in ``ent_graph``.
        bm25: ``BM25Okapi`` index over ``tokenized_corpus``.
    """

    def __init__(self, docs: 'List[Doc]') -> None:
        """Build both graphs from parsed pages.

        Args:
            docs: parsed pages in document order. The list position is used as
                the page id and for the distance penalty between pages.
        """
        self.text_graph = nx.DiGraph()
        self.ent_graph = nx.Graph()
        self.tokenized_corpus: List[List[str]] = []
        self._add_page_nodes(docs)
        self.ent_general_importance: Dict[str, float] = nx.pagerank(self.ent_graph, weight='log_freq')
        self.bm25 = BM25Okapi(self.tokenized_corpus)
        self._add_page_edges(len(docs))

    def _add_page_nodes(self, docs) -> None:
        """Add one node per page and build the keyword co-occurrence graph."""
        ent_pair_counter = Counter()
        for pid, doc in enumerate(docs):
            tokenized_page = [t.lemma_.lower() for t in doc]
            nouns = collect_keywords_from_text(doc)
            if len(nouns) >= 2:
                # frozenset makes each keyword pair order-independent.
                ent_pair_counter.update(map(frozenset, itertools.combinations(nouns, 2)))
            self.tokenized_corpus.append(tokenized_page)
            self.text_graph.add_node(pid, tokenized_page=tokenized_page, nouns=nouns)
        for (ent1, ent2), cnt in ent_pair_counter.items():
            self.ent_graph.add_edge(ent1, ent2, log_freq=np.log(cnt + 1))

    def _normalized_bm25_scores(self, pid: int) -> np.ndarray:
        """Return BM25 scores of every page, using page ``pid`` as the query, normalized to sum 1.

        rank_bm25's BM25Okapi can return negative scores when its IDF floor is
        negative (e.g. on very small corpora). ``statistics.harmonic_mean`` raises
        on negative input, so scores are clipped at 0. When nothing positive is
        left (e.g. an empty page), the zeros are returned unchanged instead of
        dividing by zero.
        """
        scores = np.clip(self.bm25.get_scores(self.tokenized_corpus[pid]), 0.0, None)
        total = scores.sum()
        return scores / total if total > 0 else scores

    def _add_page_edges(self, num_pages: int) -> None:
        """Link every ordered pair of distinct pages that share at least one keyword."""
        for pid1 in range(num_pages):
            bm25_scores = self._normalized_bm25_scores(pid1)
            nouns1: Set[str] = self.text_graph.nodes[pid1]['nouns']
            for pid2 in range(num_pages):
                if pid1 == pid2:
                    continue
                overlap = nouns1.intersection(self.text_graph.nodes[pid2]['nouns'])
                if not overlap:
                    continue
                # A keyword that only ever appears on pages with fewer than two
                # keywords never gets an ent_graph edge, so it has no PageRank
                # entry. Count it as 0 instead of raising KeyError.
                ent_importance = sum(self.ent_general_importance.get(ent, 0.0) for ent in overlap)
                # Penalty that decays logarithmically with the page-index gap.
                dist = 1 / np.log(np.e + abs(pid2 - pid1))
                bm25_score = bm25_scores[pid2]
                weight = statistics.harmonic_mean([ent_importance, bm25_score]) * dist
                self.text_graph.add_edge(pid1, pid2, overlap=overlap, ent_importance=ent_importance, dist=dist, bm25_score=bm25_score, weight=weight)
        

In [None]:
tg = TextGraph([longdoc.nlp(p) for p in all_summary])

In [None]:
list(tg.text_graph.edges.data())[:5]

### Topic Modeling

In [None]:
from gensim import corpora
from gensim.parsing.preprocessing import preprocess_string, DEFAULT_FILTERS
from gensim.models import Phrases, CoherenceModel, LdaModel, EnsembleLda, LdaMulticore

In [None]:
pages = doc_split.split_paragraphs(article, 500)
all_summary = read_json('all_summary.json')

In [None]:
longdoc.llm_server(f'''
Summarize the following passage.

Passage:
{pages[1]}
''')

In [None]:
longdoc.llm_server(f'''
Summarize the following passage.

Passage:
{pages[2]}
''')

In [None]:
longdoc.llm_server(f'''
What are the common information in the following 2 passages.

Passage 1:
{pages[1]}

Passage 2:
{pages[2]}
''')

In [None]:
longdoc.llm_server(f'''
What are the different information between the following 2 passages.

Passage 1:
{pages[1]}

Passage 2:
{pages[2]}
''')

In [None]:
pages[2]

In [None]:
len(all_summary)

In [None]:
preprocess_funcs = DEFAULT_FILTERS[:-1] # Drop the final default filter (stem_text) so tokens stay lemmas, not stems
# Lemmatize each summary with spaCy (parser and NER disabled), then apply the gensim filters.
preprocessed_summary = [preprocess_string(' '.join([t.lemma_ for t in longdoc.nlp(p, disable=['parser', 'ner'])]), preprocess_funcs) for p in all_summary]

# Optional bigram merging, currently disabled:
# bigram = Phrases(preprocessed_summary, min_count=2, threshold=1)

# texts = [bigram[p] for p in preprocessed_summary]
texts = preprocessed_summary

# Create a dictionary (token <-> id mapping) from the corpus
dictionary = corpora.Dictionary(texts)

# Drop tokens seen in fewer than 2 summaries. NOTE: gensim's filter_extremes also
# applies its defaults no_above=0.5 and keep_n=100000, so tokens occurring in more
# than half of the summaries are removed as well.
dictionary.filter_extremes(no_below=2)

# Convert the corpus into a bag-of-words representation
corpus = [dictionary.doc2bow(text) for text in texts]

In [None]:
lda_model = EnsembleLda(
    corpus=corpus, 
    id2word=dictionary, 
    passes=5, 
    iterations=100, 
    num_models=5, 
    # min_cores=10, 
    # min_samples=4,
    epsilon=0.05
    )

In [None]:
lda_model.print_topics()

In [None]:
# Assign each summary to its most probable topic and report how many summaries each topic got.
topic2p = defaultdict(list)
for pid, p in enumerate(corpus):
    topic_dist = lda_model[p]
    if not topic_dist:
        # The list can be empty when every topic falls below the model's
        # minimum_probability threshold; leave such summaries unassigned.
        continue
    topic_id = max(topic_dist, key=lambda x: x[1])[0]
    topic2p[topic_id].append(all_summary[pid])
print(lda_model.stable_topics.shape)
# Topic ids need not be contiguous (a topic that wins no summary is absent), so
# iterate over the keys actually present. The old range(len(topic2p)) form skipped
# high ids and inserted empty lists into the defaultdict as a side effect.
print([(tid, len(topic_summaries)) for tid, topic_summaries in sorted(topic2p.items())])

In [None]:
topic2p[0]

In [None]:
topics = []
score = []
topic_models:Dict[int, LdaModel] = {}
min_docs_per_topic = 4
for topic_num in tqdm(range(4, len(all_summary) // min_docs_per_topic, 4)):
    # Build the LDA model
    lda_model = LdaMulticore(corpus, topic_num, dictionary, iterations=100, passes=5, workers=5)
    cm = CoherenceModel(lda_model, texts = texts, corpus=corpus, dictionary=dictionary, coherence='c_v')
    topics.append(topic_num)
    score.append(cm.get_coherence())
    topic_models[topic_num] = lda_model
    
plt.plot(topics, score)
plt.xlabel('Number of Topics')
plt.ylabel('Coherence Score')
plt.show()

In [None]:
# NOTE(review): 44 is presumably the topic count read off the coherence plot above.
# It must be one of the counts swept there (range(4, len(all_summary) // min_docs_per_topic, 4)),
# otherwise this raises KeyError. topics[int(np.argmax(score))] would select the best count automatically.
lda_model = topic_models[44]

In [None]:
list(lda_model.get_document_topics(corpus))

In [None]:
for idx, topic in lda_model.print_topics(-1):
    print('Topic: {} \nWords: {}'.format(idx, topic))

### Test plot

In [None]:
raw = read_json('raw.json')

In [None]:
results = [ChunkInfo(**ci) for ci in read_json('temp.json')]

In [None]:
queries = ["the Skipper", "the new cook", "advice", 
        #    "avoiding", 
           "Vesta",
           "the cook",
        #    "avoiding Vesta"
           ]

In [None]:
plot_map(results, queries, retriever)

In [None]:
# Encode pages
pages = doc_split.split_paragraphs(article, 512 // 5)
p_emb = retriever.embed_paragraphs(pages, normalize=True, complete_return=True)
p_strs, p_lhs = [], []
for pid in range(len(p_emb.embeddings)):
    word_spans = sent_split(p_emb.input_ids[pid], retriever.retriever_tokenizer, retriever.retriever_tokenizer.bos_token, retriever.retriever_tokenizer.eos_token)
    temp_p_strs, temp_p_lhs = merge_words_and_embeddings(retriever.retriever_tokenizer, p_emb.input_ids[pid], p_emb.last_hidden_states[pid], word_spans, False)
    p_strs.append(temp_p_strs)
    p_lhs.append(temp_p_lhs)
tsne_plot(p_emb.embeddings)

In [None]:
score_mat = q_emb.embeddings @ p_emb.embeddings.T
fig, ax = plt.subplots(figsize=(score_mat.shape[1], score_mat.shape[0]))
sb.heatmap(score_mat, xticklabels=range(len(pages)), yticklabels=queries, annot=True, ax=ax)

In [None]:
score_mat = q_emb.embeddings @ p_emb.embeddings.T
score_mat_min = score_mat.min(1, keepdims=True)
score_mat_max = score_mat.max(1, keepdims=True)
score_mat = (score_mat - score_mat_min) / (score_mat_max - score_mat_min)
fig, ax = plt.subplots(figsize=(score_mat.shape[1], score_mat.shape[0]))
sb.heatmap(score_mat, xticklabels=range(len(pages)), yticklabels=queries, annot=True, ax=ax)

In [None]:
score_mat_max.shape

In [None]:
print_input_ids(p_strs, range(10))

In [None]:
print_input_ids(p_strs, [57])

In [None]:
print_pages(pages, range(80, 85))

In [None]:
# Question-page token-sent matching
xid, yid = 1, 0
x_start, x_end = 0, None
y_start, y_end = 0, None

score_mat = (q_lhs[yid] / np.expand_dims(np.linalg.norm(q_lhs[yid], axis=1), axis=1)) @ (p_lhs[xid] / np.expand_dims(np.linalg.norm(p_lhs[xid], axis=1), axis=1)).T
score_mat = score_mat[y_start:y_end, x_start:x_end]
fig, ax = plt.subplots(figsize=(score_mat.shape[1], score_mat.shape[0]))
sb.heatmap(score_mat, xticklabels=p_strs[xid][x_start:x_end], yticklabels=q_strs[yid][y_start:y_end], annot=True, ax=ax)
fig.savefig('qp.pdf')

In [None]:
# Page-page matching
score_mat = p_emb.embeddings @ p_emb.embeddings.T
fig, ax = plt.subplots(figsize=(score_mat.shape[1], score_mat.shape[0]))
sb.heatmap(score_mat, xticklabels=range(score_mat.shape[1]), yticklabels=range(score_mat.shape[0]), annot=True, ax=ax)
fig.savefig('pp_all.pdf')

In [None]:
# Page-page sent-sent matching
xid, yid = 9, 10
x_start, x_end = 0, None
y_start, y_end = 0, None

score_mat = (p_lhs[yid] / np.expand_dims(np.linalg.norm(p_lhs[yid], axis=1), axis=1)) @ (p_lhs[xid] / np.expand_dims(np.linalg.norm(p_lhs[xid], axis=1), axis=1)).T
score_mat = score_mat[y_start:y_end, x_start:x_end]
fig, ax = plt.subplots(figsize=(score_mat.shape[1]/2, score_mat.shape[0]/2))
sb.heatmap(score_mat, xticklabels=range(x_start, score_mat.shape[1] + x_start), yticklabels=range(y_start, score_mat.shape[0] + y_start), annot=True, ax=ax)
fig.savefig('pp.pdf')
print('x:\n', pages[xid])
print('y:\n', pages[yid])

In [None]:
test_x_sent = 2
test_y_sent = 2
print(score_mat[test_y_sent, test_x_sent])
print(p_strs[xid][test_x_sent])
print(p_strs[yid][test_y_sent])

In [None]:
p_input_ids, pid2embs_3, pid2lhs_3 = slide_encode(pages, retriever, 3)
# pid2embs_5 = slide_encode(pages, retriever, 5)
p_input_ids, pid2embs_1, pid2lhs_1 = slide_encode(pages, retriever, 1)

In [None]:
# p_weight_5 = np.array([[0., 0., 0.3, 0., 0.]])
p_weight_3 = np.array([0., 0., 0.])
p_weight_1 = np.array([1.0])
# p_embeddings = np.vstack([(p_weight_5 @ embs_5)[0] + (p_weight_1 @ embs_1)[0] for embs_5, embs_1 in zip(pid2embs_5, pid2embs_1)])
p_embeddings = np.vstack([(np.expand_dims(p_weight_3, 0) @ embs_3)[0] + (np.expand_dims(p_weight_1, 0) @ embs_1)[0] for embs_3, embs_1 in zip(pid2embs_3, pid2embs_1)])
p_lhs = [(lhs_3 * np.expand_dims(p_weight_3, (1,2))).mean(0) + (lhs_1 * np.expand_dims(p_weight_1, (1,2))).mean(0) for lhs_3, lhs_1 in zip(pid2lhs_3, pid2lhs_1)]
tsne_plot(p_embeddings, 4)

In [None]:
# p_weight_5 = np.array([[0., 0., 0.3, 0., 0.]])
p_weight_3 = np.array([0., 0.5, 0.])
p_weight_1 = np.array([0.5])
# p_embeddings = np.vstack([(p_weight_5 @ embs_5)[0] + (p_weight_1 @ embs_1)[0] for embs_5, embs_1 in zip(pid2embs_5, pid2embs_1)])
p_embeddings = np.vstack([(np.expand_dims(p_weight_3, 0) @ embs_3)[0] + (np.expand_dims(p_weight_1, 0) @ embs_1)[0] for embs_3, embs_1 in zip(pid2embs_3, pid2embs_1)])
p_lhs = [(lhs_3 * np.expand_dims(p_weight_3, (1,2))).mean(0) + (lhs_1 * np.expand_dims(p_weight_1, (1,2))).mean(0) for lhs_3, lhs_1 in zip(pid2lhs_3, pid2lhs_1)]
tsne_plot(p_embeddings, 4)

In [None]:
# Page-page matching
normalized_p_embeddings = p_embeddings / np.expand_dims(np.linalg.norm(p_embeddings, axis=1), 1)
score_mat = normalized_p_embeddings @ normalized_p_embeddings.T
fig, ax = plt.subplots(figsize=(score_mat.shape[1], score_mat.shape[0]))
sb.heatmap(score_mat, xticklabels=range(score_mat.shape[1]), yticklabels=range(score_mat.shape[0]), annot=True, ax=ax)
fig.savefig('pp_all.pdf')

In [None]:
score_mat[60, 20]

In [None]:
print_pages(pages, range(60, 70))

In [None]:
# Page-page sent-sent matching
xid, yid = 60, 20

x_word_spans = sent_split(p_input_ids[xid], retriever.retriever_tokenizer)
y_word_spans = sent_split(p_input_ids[yid], retriever.retriever_tokenizer)
plot_score_matrix(retriever.retriever_tokenizer, p_input_ids[xid], p_lhs[xid], x_word_spans, p_input_ids[yid], p_lhs[yid], y_word_spans, False, False)

In [None]:
def norm(x, axis=None):
    """Scale ``x`` to unit L2 norm.

    Args:
        x: array-like to normalize.
        axis: ``None`` (default) divides by the norm of the whole array, which is
            the original behavior. An int (e.g. ``1``) normalizes each vector along
            that axis. That replaces the repeated
            ``x / np.expand_dims(np.linalg.norm(x, axis=1), axis=1)`` pattern.

    Returns:
        ``x`` divided by its norm. An all-zero input yields NaNs (no epsilon is added).
    """
    # keepdims keeps the norm broadcastable against x for any axis choice.
    return x / np.linalg.norm(x, axis=axis, keepdims=True)

In [None]:
norm(p_lhs[yid][0:21].mean(0)).dot(norm(p_lhs[xid][33:77].mean(0)))

In [None]:
# Page-page sent-sent matching
xid, yid = 60, 20
# x_start, x_end = 33, 77
# y_start, y_end = 0, 21

x_start, x_end = 0, None
y_start, y_end = 0, None

x_word_spans = word_split(p_input_ids[xid], retriever.retriever_tokenizer)
x_strs, x_lhs = merge_words_and_embeddings(retriever.retriever_tokenizer, p_input_ids[xid], p_lhs[xid], [], False, True)
y_word_spans = word_split(p_input_ids[yid], retriever.retriever_tokenizer)
y_strs, y_lhs = merge_words_and_embeddings(retriever.retriever_tokenizer, p_input_ids[yid], p_lhs[yid], [], False, True)

score_mat = (y_lhs) @ (x_lhs).T
score_mat = score_mat[y_start:y_end, x_start:x_end]
fig, ax = plt.subplots(figsize=(score_mat.shape[1], score_mat.shape[0]))
sb.heatmap(score_mat, xticklabels=x_strs[x_start:score_mat.shape[1] + x_start], yticklabels=y_strs[y_start:score_mat.shape[0] + y_start], annot=True, ax=ax)
fig.savefig('pp.pdf')
print('x:\n', x_strs)
print('y:\n', y_strs)

In [None]:
p_lhs = [np.array(pid2embs[pid]).mean(0) for pid in range(len(pid2embs))]
p_embeddings = np.array([lhs.mean(0) for lhs in p_lhs])
p_norm = np.linalg.norm(p_embeddings, axis=1)
p_embeddings = p_embeddings / np.expand_dims(p_norm, 1)
p_lhs = [lhs / n for lhs, n in zip(p_lhs, p_norm)]
pids, scores = retriever.dense_retrieval(q_emb.embeddings, p_embeddings, None, normalize=False, return_score=True)
pids

In [None]:
query_distribution(retriever.retriever_tokenizer, q_emb.last_hidden_states[0], q_emb.input_ids[0], p_lhs, 5, q_spans=word_spans[3:-1])

In [None]:
query_indicatiors(retriever.retriever_tokenizer, question, [f'passage: {p}' for p in pages], q_emb.last_hidden_states[0], q_emb.input_ids[0], p_lhs, p_input_ids, pids[:10], scores, 5, q_spans=word_spans)

In [None]:
p_emb = retriever.embed_paragraphs([f'passage: {p}' for p in pages], normalize=True, complete_return=True)
pids, scores = retriever.dense_retrieval(q_emb.embeddings, p_emb.embeddings, None, normalize=False, return_score=True)
pids

In [None]:
query_indicatiors(retriever.retriever_tokenizer, question, [f'passage: {p}' for p in pages], q_emb.last_hidden_states[0], q_emb.input_ids[0], p_emb.last_hidden_states, p_emb.input_ids, pids, scores, q_spans=word_spans)

In [None]:
%matplotlib widget
# Scratch cell: an interactive-backend check that appears to be the matplotlib
# "Cursor" widget gallery example. It is unrelated to the retrieval analysis above.
# NOTE(review): `%matplotlib widget` requires the ipympl backend. This cell also
# re-imports plt/np and overwrites the notebook globals fig, ax, x and y.

import matplotlib.pyplot as plt
import numpy as np

from matplotlib.widgets import Cursor

# Fixing random state for reproducibility
np.random.seed(19680801)

fig, ax = plt.subplots(figsize=(8, 6))

# 100 random points, uniform in [-2, 2) on each axis
x, y = 4*(np.random.rand(2, 100) - .5)
ax.plot(x, y, 'o')
ax.set_xlim(-2, 2)
ax.set_ylim(-2, 2)

# Set useblit=True on most backends for enhanced performance.
cursor = Cursor(ax, useblit=True, color='red', linewidth=2)

plt.show()