In [None]:
from tqdm.notebook import tqdm
from nltk import sent_tokenize
from transformers import AutoTokenizer
import sys
import seaborn as sb
sys.path.append('../..')
from rank_bm25 import BM25Okapi
from spacy.tokens import Span, Doc

from src import *
from src.test_utils import *
# Imported explicitly rather than relying on `from src import *` to provide them.
import os
import torch

# Disable HuggingFace tokenizers' internal parallelism (presumably to avoid the fork warning /
# deadlock when other libraries fork worker processes — TODO confirm).
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
# Turn off the memory-efficient and flash scaled-dot-product-attention backends, leaving
# PyTorch's other SDPA backend(s). Reason not stated here — TODO document.
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)

In [None]:
# Shared components for the experiments below. Commented-out lines are alternative configurations.
# gritlm = GritLM("GritLM/GritLM-7B", device_map="cuda:2", torch_dtype="auto")
# NOTE(review): meaning of syn_dist=0.1 is defined in src — not visible here.
retriever = Retriever(device='cpu', syn_dist=0.1)
# The splitter is built on the retriever's tokenizer (presumably so split sizes are counted in
# the same tokens — TODO confirm).
doc_split = DocSplit(retriever.retriever_tokenizer)
# llm = LLM()
# `llm` is passed on as a model-name string here rather than an LLM object.
llm = 'mistralai/Mistral-7B-Instruct-v0.2'
# llm = None
longdoc = LongDoc(retriever, llm)
# dataset = NarrativeQADataset(llm)
dataset = QualityDataset(llm, split='dev')
# reading_agent = ReadingAgent(dataset, llm)

In [None]:
# Pick one dataset sample and extract its questions, answers and article text.
test_i = 2
sample = dataset.data[test_i]
questions, answers = dataset.get_questions_and_answers(sample)
article = dataset.get_article(sample)
# Keep only the first line of each question (presumably the question stem, dropping answer
# options on later lines — TODO confirm). An empty question string would raise IndexError here.
questions = [q.splitlines()[0] for q in questions]
questions

# Index passages

# Retrieval

## Test Code

In [None]:
# Split the article and index the pieces with `index_text_into_map`; results are dumped to
# temp.json / raw.json for inspection.
# NOTE(review): the meaning of the size argument (512 // 5 == 102) and of `3` is not visible
# here — TODO name these as constants.
pages = doc_split.split_paragraphs(article, 512 // 5)
results, raw = longdoc.index_text_into_map(pages, 3)
write_json('temp.json', [ci.to_json() for ci in results])
write_json('raw.json', raw)

### Test Navigation

In [None]:
# Run `lossless_index` over 50-size pages and cache the result in all_summary.json so the
# next cell can reload it without recomputing.
# NOTE(review): the positional arguments (5, 5, 5, 'relation') are undocumented here — TODO name them.
pages = doc_split.split_paragraphs(article, 50)
all_summary = longdoc.lossless_index(pages, 5, 5, 5, 'relation')
write_json('all_summary.json', all_summary)

In [None]:
# Reload the cached summaries; pages are split with the same size (50) that was used when
# all_summary.json was written.
pages = doc_split.split_paragraphs(article, 50)
all_summary = read_json('all_summary.json')

In [None]:
# Build the summary tree from the pages and their summaries, and cache it to temp_tree.json.
tree = longdoc.build_summary_pyramid(pages, all_summary)
dump_tree('temp_tree.json', tree)

In [None]:
# Reload the cached tree instead of rebuilding it.
tree = load_tree('temp_tree.json')

In [None]:
# Number of nodes in the last level of the tree (presumably the top of the pyramid — TODO confirm).
len(tree[-1])

In [None]:
# Inspect the first node of the tree's last level.
test_node = tree[-1][0]

In [None]:
# Children of the inspected node.
test_node.children

In [None]:
# `unique_in_right` of the inspected node (its semantics are defined in src, not visible here).
test_node.unique_in_right

### TextGraph

In [None]:
def remove_unimportant(doc:'Span', additional_pos_labels:Set[str]=frozenset()):
    """Split ``doc`` into contiguous pieces at "unimportant" tokens.

    A token acts as a separator when its ``pos_`` tag is one of DET, PRON, CCONJ,
    PUNCT, AUX, PART, or appears in ``additional_pos_labels``. Separator tokens are
    dropped; every maximal run of the remaining tokens is returned as a slice of ``doc``.

    Args:
        doc: a sequence supporting ``len()``, integer indexing to objects with a
            ``pos_`` attribute, and slicing (a spaCy ``Span``/``Doc`` in this notebook).
        additional_pos_labels: extra POS tags to treat as separators. The default is
            an immutable empty set, so no mutable default object is shared between calls.

    Returns:
        List of ``doc[start:end]`` slices in document order; empty when ``doc`` is
        empty or every token is a separator.
    """
    # Build the full separator set once instead of testing two containers per token.
    separator_pos = {'DET', 'PRON', 'CCONJ', 'PUNCT', 'AUX', 'PART'}.union(additional_pos_labels)
    bounds = []
    piece_start = 0
    for tid in range(len(doc)):
        if doc[tid].pos_ in separator_pos:
            # Close the current run (if it is non-empty) and start after the separator.
            if piece_start != tid:
                bounds.append((piece_start, tid))
            piece_start = tid + 1
    # Trailing run that was not terminated by a separator.
    if piece_start < len(doc):
        bounds.append((piece_start, len(doc)))
    return [doc[start:end] for start, end in bounds]

def collect_keywords_from_text(doc:Doc):
    """Collect lemmatised keyword candidates from a parsed document.

    Noun chunks and named entities are walked in parallel (this assumes both come in
    document order). When the current chunk and entity overlap, they are fused into one
    span covering both and both cursors advance. Otherwise the one that comes first is
    emitted on its own. Each resulting span is then split with ``remove_unimportant``,
    with 'ADJ' and 'ADV' also treated as separators.

    NOTE(review): only one chunk/entity pair is fused per step, so an entity overlapping
    two noun chunks (or vice versa) yields overlapping spans.

    Returns:
        Set of space-joined lemma strings, one per surviving piece.
    """
    chunks = list(doc.noun_chunks)
    entities = doc.ents
    i = j = 0
    merged:List[Span] = []
    while i < len(chunks) and j < len(entities):
        chunk, entity = chunks[i], entities[j]
        # Two half-open token ranges share a token iff the later start precedes the earlier end.
        if max(chunk.start, entity.start) < min(chunk.end, entity.end):
            merged.append(doc[min(chunk.start, entity.start) : max(chunk.end, entity.end)])
            i += 1
            j += 1
        elif chunk.start < entity.end:
            # No overlap and the chunk lies before the entity.
            merged.append(chunk)
            i += 1
        else:
            merged.append(entity)
            j += 1
    # Whatever remains of either sequence is taken as-is.
    merged += chunks[i:]
    merged += entities[j:]
    pieces = [piece for span in merged for piece in remove_unimportant(span, {'ADJ', 'ADV'})]
    return {' '.join(token.lemma_ for token in piece) for piece in pieces}
    

In [None]:
class TextGraph:
    """Page graph plus keyword co-occurrence graph built from parsed pages.

    Attributes:
        text_graph (nx.DiGraph): one node per page index ``pid``, carrying
            ``tokenized_page`` (lower-cased lemmas) and ``nouns`` (the keyword set
            from ``collect_keywords_from_text``). A directed edge ``pid1 -> pid2`` is
            added for every ordered pair of distinct pages sharing at least one keyword.
            The edge stores ``overlap``, ``ent_importance``, ``dist``, ``bm25_score``
            and ``weight``.
        ent_graph (nx.Graph): keyword co-occurrence graph. Two keywords are linked when
            they appear together on a page that has at least two keywords. The edge
            attribute ``log_freq`` is ``log(count + 1)``.
        ent_general_importance (Dict[str, float]): PageRank of ``ent_graph`` weighted by
            ``log_freq``.
        tokenized_corpus (List[List[str]]): ``tokenized_page`` for every page.
        bm25 (BM25Okapi): BM25 index over ``tokenized_corpus``.
    """

    def __init__(self, docs:List[Doc]) -> None:
        self.text_graph = nx.DiGraph()
        self.ent_graph = nx.Graph()
        self.tokenized_corpus:List[List[str]] = []
        ent_pair_counter = Counter()
        for pid, doc in enumerate(docs):
            tokenized_page = [t.lemma_.lower() for t in doc]
            nouns = collect_keywords_from_text(doc)
            if len(nouns) >= 2:
                # frozenset keys make each pair order-independent (``nouns`` is a set).
                ent_pair_counter.update(map(frozenset, itertools.combinations(nouns, 2)))
            self.tokenized_corpus.append(tokenized_page)
            self.text_graph.add_node(pid, tokenized_page=tokenized_page, nouns=nouns)
        for (ent1, ent2), cnt in ent_pair_counter.items():
            self.ent_graph.add_edge(ent1, ent2, log_freq=np.log(cnt+1))
        self.ent_general_importance:Dict[str, float] = nx.pagerank(self.ent_graph, weight='log_freq')
        self.bm25 = BM25Okapi(self.tokenized_corpus)
        self._add_page_edges(len(docs))

    def _add_page_edges(self, num_pages:int) -> None:
        """Add a weighted edge between every ordered pair of distinct pages sharing a keyword."""
        for pid1 in range(num_pages):
            bm25_scores = self.bm25.get_scores(self.tokenized_corpus[pid1])
            total = bm25_scores.sum()
            # Normalise the scores to sum to 1; skip when the total is 0 (e.g. an empty
            # page) to avoid a 0/0 -> NaN division.
            if total != 0:
                bm25_scores = bm25_scores / total
            nouns1:Set[str] = self.text_graph.nodes[pid1]['nouns']
            for pid2 in range(num_pages):
                if pid1 == pid2:
                    continue
                overlap = nouns1.intersection(self.text_graph.nodes[pid2]['nouns'])
                if not overlap:
                    continue
                # A keyword only enters ``ent_graph`` if it co-occurred with another keyword
                # on some page. A keyword that never did (e.g. two pages whose only keyword
                # is the same word) has no PageRank score, so count it as 0 instead of
                # raising KeyError.
                ent_importance = sum(self.ent_general_importance.get(ent, 0.0) for ent in overlap)
                # 1 / log(e + |pid2 - pid1|): shrinks slowly as the pages get farther apart.
                dist = 1 / np.log(np.e + np.abs(pid2 - pid1))
                bm25_score = bm25_scores[pid2]
                # NOTE(review): statistics.harmonic_mean raises StatisticsError on negative
                # inputs; BM25 scores are assumed non-negative here — TODO confirm.
                weight = statistics.harmonic_mean([ent_importance, bm25_score]) * dist
                self.text_graph.add_edge(pid1, pid2, overlap=overlap, ent_importance=ent_importance, dist=dist, bm25_score=bm25_score, weight=weight)
        

In [None]:
# Build the TextGraph over the cached summaries, each parsed with `longdoc.nlp`.
tg = TextGraph([longdoc.nlp(p) for p in all_summary])

In [None]:
# Peek at the first five page-to-page edges as (pid1, pid2, attribute-dict) tuples.
list(tg.text_graph.edges.data())[:5]

### Topic Modeling

In [None]:
from gensim import corpora
from gensim.parsing.preprocessing import preprocess_string, DEFAULT_FILTERS
from gensim.models import Phrases, CoherenceModel, LdaModel, EnsembleLda, LdaMulticore

In [None]:
# NOTE(review): `pages` is re-split with size 500 here, while all_summary.json was written from
# pages split with size 50, so `pages[i]` and `all_summary[i]` presumably no longer refer to
# the same text — TODO confirm this is intended.
pages = doc_split.split_paragraphs(article, 500)
all_summary = read_json('all_summary.json')

In [None]:
# Ask the LLM (through longdoc.llm_server) to summarise pages[1]; the return value is the cell output.
longdoc.llm_server(f'''
Summarize the following passage.

Passage:
{pages[1]}
''')

In [None]:
# Same summarisation prompt, for pages[2].
longdoc.llm_server(f'''
Summarize the following passage.

Passage:
{pages[2]}
''')

In [None]:
# Ask the LLM what information pages[1] and pages[2] have in common.
# NOTE(review): these prompt cells repeat one pattern; a helper taking (instruction, pages)
# would avoid the copy-paste — TODO.
longdoc.llm_server(f'''
What are the common information in the following 2 passages.

Passage 1:
{pages[1]}

Passage 2:
{pages[2]}
''')

In [None]:
# Ask the LLM how the information in pages[1] and pages[2] differs.
longdoc.llm_server(f'''
What are the different information between the following 2 passages.

Passage 1:
{pages[1]}

Passage 2:
{pages[2]}
''')

In [None]:
# Raw text of pages[2], for comparison with the LLM answers.
pages[2]

In [None]:
# Number of cached summaries, i.e. the number of documents used for topic modelling.
len(all_summary)

In [None]:
# Lemmatise each summary with `longdoc.nlp` (parser and NER disabled), then apply gensim's
# default preprocessing filters except the last one.
preprocess_funcs = DEFAULT_FILTERS[:-1] # Remove the stemming
preprocessed_summary = [preprocess_string(' '.join([t.lemma_ for t in longdoc.nlp(p, disable=['parser', 'ner'])]), preprocess_funcs) for p in all_summary]

# Bigram detection, currently disabled:
# bigram = Phrases(preprocessed_summary, min_count=2, threshold=1)

# texts = [bigram[p] for p in preprocessed_summary]
texts = preprocessed_summary

# Create a dictionary from the corpus
dictionary = corpora.Dictionary(texts)

# Remove terms that occur in fewer than 2 summaries.
# NOTE(review): filter_extremes also applies its other defaults (no_above=0.5, keep_n=100000
# in gensim 4 — verify for the installed version), so terms found in more than half of the
# summaries are dropped as well.
dictionary.filter_extremes(no_below=2)

# Convert the corpus into a bag-of-words representation
corpus = [dictionary.doc2bow(text) for text in texts]

In [None]:
# Ensemble LDA: train several LDA models and keep the topics that recur across them
# (exposed as `stable_topics`).
# NOTE(review): num_topics is not passed, so each member model uses the underlying model
# class's default — TODO confirm this is intended.
ENSEMBLE_SEED = 42  # seeds the ensemble so re-running the notebook gives the same topics
lda_model = EnsembleLda(
    corpus=corpus,
    id2word=dictionary,
    passes=5,
    iterations=100,
    num_models=5,
    # Optional clustering knobs:
    # min_cores=10,
    # min_samples=4,
    epsilon=0.05,
    random_state=ENSEMBLE_SEED,
    )

In [None]:
# Show the topics of the ensemble model.
lda_model.print_topics()

In [None]:
# Group the summaries by their dominant topic (the highest-probability entry of lda_model[bow]).
topic2p = defaultdict(list)
unassigned_pids = []
for pid, p in enumerate(corpus):
    doc_topics = lda_model[p]
    if not doc_topics:
        # The model reported no topic for this document; record it instead of failing with
        # IndexError on an empty list.
        unassigned_pids.append(pid)
        continue
    topic_id = max(doc_topics, key=lambda x: x[1])[0]
    topic2p[topic_id].append(all_summary[pid])
print(lda_model.stable_topics.shape)
# Topic ids need not be contiguous (a topic may never be dominant). Iterate the ids that
# actually occur: range(len(topic2p)) would miss ids and insert empty lists into the defaultdict.
print(sorted((tid, len(ps)) for tid, ps in topic2p.items()))
if unassigned_pids:
    print(f'{len(unassigned_pids)} document(s) had no topic: {unassigned_pids}')

In [None]:
# Summaries whose dominant topic is 0. `.get` avoids inserting an empty entry into the
# defaultdict when topic 0 was never dominant.
topic2p.get(0, [])

In [None]:
# Sweep the number of LDA topics and record the c_v coherence of each model.
topics = []
score = []
topic_models:Dict[int, LdaModel] = {}
min_docs_per_topic = 4  # caps the sweep at len(all_summary) // 4 topics
LDA_SEED = 42  # seeds LDA so the sweep (and the chosen model) is reproducible across re-runs
for topic_num in tqdm(range(4, len(all_summary) // min_docs_per_topic, 4)):
    # Build the LDA model
    lda_model = LdaMulticore(corpus, num_topics=topic_num, id2word=dictionary, iterations=100, passes=5, workers=5, random_state=LDA_SEED)
    cm = CoherenceModel(lda_model, texts=texts, corpus=corpus, dictionary=dictionary, coherence='c_v')
    topics.append(topic_num)
    score.append(cm.get_coherence())
    topic_models[topic_num] = lda_model

fig, ax = plt.subplots()
ax.plot(topics, score)
ax.set(xlabel='Number of Topics', ylabel='Coherence Score', title='LDA coherence (c_v) vs. number of topics')
plt.show()

In [None]:
# NOTE(review): 44 looks hand-picked (presumably from the coherence plot — TODO confirm).
# The sweep only trains topic counts below len(all_summary) // 4, so fail with a clear message
# if this one was not trained.
chosen_num_topics = 44
if chosen_num_topics not in topic_models:
    raise KeyError(f'No LDA model was trained with {chosen_num_topics} topics; available: {sorted(topic_models)}')
lda_model = topic_models[chosen_num_topics]

In [None]:
# Topic distribution of every document under the selected model (materialised as a list).
list(lda_model.get_document_topics(corpus))

In [None]:
# Print every topic of the selected model (print_topics(-1) returns all of them) with its words.
for idx, topic in lda_model.print_topics(-1):
    print(f'Topic: {idx} \nWords: {topic}')