In [None]:
import sys
sys.path.append('../..')

from src.index_files import *
from src.corenlp_base import Doc, Mention, Sentence

In [None]:
dataset = QualityDataset(split='dev')
# `f` (the Factory) and `article` are used by the LongDoc cells below. Leaving them
# commented out made the notebook fail under Restart & Run All (NameError on `f` / `article`).
f = Factory(chunk_size=100, llm_name=None)
article = dataset.get_article(dataset.data[2])

# Graph Tools

### GraphRAG

In [None]:
import graphrag

### Custom Tools: `LongDoc` entity-graph index

In [None]:
import spacy
from rank_bm25 import BM25Okapi
import spacy.tokens

class ChunkInfo(BaseModel):
    """One indexed chunk of the article, as stored in the chunk-info JSON file."""
    i: int  # chunk index; equals the chunk's position in LongDoc.chunk_infos
    chunk_text: str  # text of the chunk (space-joined pieces)
    statements: List[str] = []  # LLM-generated atomic statements for this chunk
    entities: List[List[str]] = []  # entities[k]: entity strings extracted from statements[k]
    ent_modifiers: list = []  # ent_modifiers[k]: [entity, modifiers] / [modifiers, entity] relations of statements[k]
        
        

class LongDoc:
    """Long-document index built from LLM-generated atomic statements.

    Each chunk of the article is rewritten by the LLM into a list of statements
    (``ChunkInfo.statements``). spaCy then extracts entity mentions and
    entity/modifier relations from every statement. Three retrieval structures are
    derived from them:

    * ``ent_graph`` (undirected): entities linked by co-occurrence in a statement
      and by identical / BM25-similar normalized names;
    * ``semantic_graph`` (directed): per-chunk entity nodes named ``"<entity>_<chunk id>"``;
    * ``raw_bm25``: a BM25 index over the normalized chunk texts.
    """

    def __init__(self, factory:Factory, chunk_info_file:str=None) -> None:
        """Load spaCy and, if ``chunk_info_file`` is given, load and enrich a saved index.

        Args:
            factory: project ``Factory``; this class uses its ``split_text``, ``llm`` and ``rouge``.
            chunk_info_file: JSON file written by ``build_index``. When omitted the index
                stays empty until ``build_index`` is called.
        """
        self.factory = factory
        self.nlp = spacy.load('en_core_web_lg')
        # Always define the attribute so an un-built index is empty rather than missing.
        self.chunk_infos = []
        if chunk_info_file:
            self.chunk_infos = [ChunkInfo.parse_obj(line) for line in read_json(chunk_info_file)]
            self.enrich_index()

    # Index functions
    def build_index(self, article:str, chunk_info_file:str=None):
        """Build the index for ``article`` from scratch using LLM calls.

        The article is split into pieces and rewritten into per-chunk statements. The
        statements are optionally saved to ``chunk_info_file`` (the format ``__init__``
        reads back). The index is then enriched so the object is immediately queryable.
        """
        pieces = self.factory.split_text(article)
        chunks, statements = self.generate_statements(pieces, chunk_size=3, summary_size=12)
        self.chunk_infos = [ChunkInfo(i=i, chunk_text=chunk, statements=statements[i]) for i, chunk in enumerate(chunks)]
        # LLM-based entity extraction (extract_entities) is not used here; entities come
        # from spaCy in enrich_index.
        if chunk_info_file:
            # Written before enrichment, i.e. the same content as before this fix.
            write_json(chunk_info_file, [ci.dict() for ci in self.chunk_infos])
        # Without this the entity lists stayed empty and the graphs / BM25 stores were
        # missing or stale (built from previously loaded chunk_infos).
        self.enrich_index()

    def enrich_index(self):
        """Extract entities and relations for every statement, then build all retrieval stores.

        ``ci.entities[k]`` and ``ci.ent_modifiers[k]`` describe ``ci.statements[k]``. Both
        lists are rebuilt from scratch, so the alignment holds even if ``chunk_infos`` were
        loaded with these fields already filled or this method runs twice.
        """
        for ci in self.chunk_infos:
            ci.entities = []
            ci.ent_modifiers = []
            for statement in ci.statements:
                ents, ent_modifiers = self.collect_keywords_from_text(statement)
                ci.entities.append(list(ents))
                # Each relation is (entity, modifiers) or (modifiers, entity). Serialize to
                # JSON to de-duplicate (lists are unhashable). dict.fromkeys keeps first-seen
                # order, unlike a set, whose order varies with string hash randomization.
                unique_serialized = dict.fromkeys(json.dumps(pair) for pair in ent_modifiers)
                ci.ent_modifiers.append([json.loads(s) for s in unique_serialized])

        self.build_ent_graph()
        self.build_semantic_graph()
        self.build_lexical_store()

    @staticmethod
    def _append_edge_loc(graph, u, v, loc):
        """Append ``loc`` to the ``locs`` list of edge (u, v), creating the edge if needed."""
        if not graph.has_edge(u, v):
            graph.add_edge(u, v, locs=[])
        graph[u][v]['locs'].append(loc)

    def build_ent_graph(self):
        """Build ``self.ent_graph`` and the BM25 store over normalized entity names.

        Node attributes:
            ``locs``: list of ``(chunk id, statement id)`` where the entity occurs;
            ``norm``: space-joined ``normalize_entity`` tokens.

        Edge attributes:
            ``locs``: one ``(chunk id, statement id)`` per statement where both entities
                co-occur, plus one ``None`` per identical/similar-norm link;
            ``weight``: ``log(1 + number of co-occurrences)``.
        """
        self.ent_graph = nx.Graph()
        # Co-occurrence edges between entities mentioned in the same statement
        for ci in self.chunk_infos:
            for sid, related_ents in enumerate(ci.entities):
                loc = (ci.i, sid)
                for e in related_ents:
                    if not self.ent_graph.has_node(e):
                        self.ent_graph.add_node(e, locs=[], norm=' '.join(self.normalize_entity(e)))
                    ent_locs:list = self.ent_graph.nodes[e]['locs']
                    if loc not in ent_locs:
                        ent_locs.insert(0, loc)
                for ent1, ent2 in itertools.combinations(related_ents, 2):
                    self._append_edge_loc(self.ent_graph, ent1, ent2, loc)

        self.normal2ents:Dict[str, List[str]] = defaultdict(list)
        for ent, normal in self.ent_graph.nodes(data='norm'):
            self.normal2ents[normal].append(ent)
        # ent_normals[k] is the exact norm string of BM25 document ent_corpus[k].
        # Re-joining split tokens differs from the norm whenever it contains extra
        # whitespace. That previously looked up a missing key of the defaultdict while
        # iterating it, which raised RuntimeError.
        self.ent_normals:List[str] = sorted(self.normal2ents)
        self.ent_corpus = [normal.split() for normal in self.ent_normals]
        self.ent_bm25 = BM25Okapi(self.ent_corpus)

        for normal in self.ent_normals:
            same_norm_ents = self.normal2ents[normal]
            # Add edges between entities that have the same norm
            for ent1, ent2 in itertools.combinations(same_norm_ents, 2):
                self._append_edge_loc(self.ent_graph, ent1, ent2, None)
            # Add edges between entities whose norms get a positive BM25 score
            scores = self.ent_bm25.get_scores(normal.split())
            for score, similar_normal in zip(scores, self.ent_normals):
                if score > 0 and similar_normal != normal:
                    for ent1, ent2 in itertools.product(same_norm_ents, self.normal2ents[similar_normal]):
                        self._append_edge_loc(self.ent_graph, ent1, ent2, None)

        for _, _, edge_data in self.ent_graph.edges.data():
            edge_data['weight'] = np.log(sum(loc is not None for loc in edge_data['locs']) + 1)

    def build_semantic_graph(self):
        """Build ``self.semantic_graph``, a DiGraph over per-chunk entity nodes ``f'{ent}_{chunk id}'``.

        Edges:
        * within a statement: earlier-listed entity -> later-listed entity, ``locs`` = statement locs;
        * the same entity in two chunks: lower chunk id -> higher chunk id, ``locs=[None]``;
        * entities linked in ``ent_graph`` by an identical/similar norm: directed by chunk id,
          or by statement id inside one chunk, with ``None`` in ``locs``.

        ``explore_weight`` is 0 for edges carrying a ``None`` loc and 1 otherwise.
        """
        self.semantic_graph = nx.DiGraph()
        for ci in self.chunk_infos:
            for sid, entities in enumerate(ci.entities):
                loc = (ci.i, sid)
                for ent1, ent2 in itertools.combinations(entities, 2):
                    self._append_edge_loc(self.semantic_graph, f'{ent1}_{ci.i}', f'{ent2}_{ci.i}', loc)

        # Link occurrences of the same entity across chunks, in chunk order
        for ent, locs in self.ent_graph.nodes(data='locs'):
            chunk_ids:Set[int] = {loc[0] for loc in locs}
            for cid1, cid2 in itertools.combinations(chunk_ids, 2):
                cid1, cid2 = sorted([cid1, cid2])
                self.semantic_graph.add_edge(f'{ent}_{cid1}', f'{ent}_{cid2}', locs=[None])
        # Link entities that ent_graph connects through an identical or similar norm
        for ent1, ent2, edge_locs in self.ent_graph.edges.data('locs'):
            if None not in edge_locs:
                continue
            for ent1_loc, ent2_loc in itertools.product(self.ent_graph.nodes[ent1]['locs'], self.ent_graph.nodes[ent2]['locs']):
                if ent1_loc == ent2_loc:
                    continue
                temp_ent1, temp_ent2 = f'{ent1}_{ent1_loc[0]}', f'{ent2}_{ent2_loc[0]}'
                if ent1_loc[0] < ent2_loc[0]:
                    # Different chunks: point forward in the document
                    self.semantic_graph.add_edge(temp_ent1, temp_ent2, locs=[None])
                elif ent1_loc[0] > ent2_loc[0]:
                    self.semantic_graph.add_edge(temp_ent2, temp_ent1, locs=[None])
                elif ent1_loc[1] < ent2_loc[1]:
                    # Same chunk: point from the earlier statement to the later one
                    self._append_edge_loc(self.semantic_graph, temp_ent1, temp_ent2, None)
                elif ent1_loc[1] > ent2_loc[1]:
                    self._append_edge_loc(self.semantic_graph, temp_ent2, temp_ent1, None)

        for _, _, edge_data in self.semantic_graph.edges.data():
            edge_data['explore_weight'] = 0 if None in edge_data['locs'] else 1

    def build_lexical_store(self):
        """Build the BM25 index over the normalized text of every chunk."""
        self.raw_corpus = [self.normalize_text(ci.chunk_text) for ci in self.chunk_infos]
        self.raw_bm25 = BM25Okapi(self.raw_corpus)

    # Retrieve functions
    def lexical_retrieval_chunks(self, query:str, n:int=5):
        """Return up to ``n`` ``(chunk_text, chunk index)`` pairs ranked by BM25 over normalized chunk text."""
        chunk_idxs = self.bm25_retrieve(self.normalize_text(query), self.raw_bm25)[:n]
        return [(self.chunk_infos[idx].chunk_text, idx) for idx in chunk_idxs]

    def lexical_retrieval_entities(self, query:str, n:int=5):
        """Return up to ``n`` entity names matching ``query``.

        Candidates are the entities whose normalized name BM25-matches the normalized
        query. They are re-ranked by ROUGE-L between the normalized query and the
        lower-cased entity text. Returns ``[]`` when nothing matches.
        """
        tokenized_query = self.normalize_entity(query)
        normal_idxs = self.bm25_retrieve(tokenized_query, self.ent_bm25)
        temp_ents:List[str] = []
        for idx in normal_idxs:
            temp_ents.extend(self.normal2ents[self.ent_normals[idx]])
        if not temp_ents:
            # Avoid calling ROUGE with empty predictions/references.
            return []
        temp_ent_refs = [' '.join(self.split_lower_text(ent)) for ent in temp_ents]
        rouge_l = self.factory.rouge.compute(predictions=[' '.join(tokenized_query)] * len(temp_ent_refs), references=temp_ent_refs, use_aggregator=False)['rougeL']
        return [temp_ents[idx] for idx in np.argsort(rouge_l)[::-1]][:n]

    def exact_match_chunks(self, query:str):
        """Return ``(chunk_text, chunk index)`` for chunks containing the normalized query as a token run.

        Matching is on whole tokens: the joined strings are padded with spaces, so
        ``"cat"`` no longer matches inside ``"concatenate"``. A query that normalizes to
        no tokens (only stop words) matches nothing instead of every chunk.
        """
        query_tokens = self.normalize_text(query)
        if not query_tokens:
            return []
        normalized_query = f" {' '.join(query_tokens)} "
        return [(self.chunk_infos[idx].chunk_text, idx) for idx, normalized_chunk in enumerate(self.raw_corpus) if normalized_query in f" {' '.join(normalized_chunk)} "]

    # LLM call and parser functions
    def generate_statements(self, pieces:List[str], chunk_size:int=5, summary_size:int=25, overlap:int=1):
        """Rewrite ``pieces`` into atomic statements, grouped per chunk.

        ``concate_with_overlap`` groups the pieces into windows, and the LLM summarizes
        each window. Every run of ``chunk_size`` pieces from a window's first
        ``summary_size`` pieces is then turned into numbered statements, with the window
        summary as context.

        Returns:
            ``(chunks, statements)``: ``chunks[k]`` is the space-joined text of chunk k
            and ``statements[k]`` is its parsed statement list.
        """
        summary_chunks = concate_with_overlap(pieces, summary_size, overlap=overlap)
        summaries = self.factory.llm.generate([[HumanMessage(content=summary_prompt.format(chunk=' '.join(summary_chunk)))] for summary_chunk in summary_chunks])
        prompts:List[str] = []
        chunks:List[str] = []
        for summary_chunk, summary in zip(summary_chunks, summaries.generations):
            temp_summary_size = min(summary_size, len(summary_chunk))
            summary_chunk = summary_chunk[:temp_summary_size]
            # Ceiling division, so a short trailing run is still indexed. The previous
            # (n + 1) // chunk_size silently dropped up to chunk_size - 2 trailing pieces.
            num_chunks = (temp_summary_size + chunk_size - 1) // chunk_size
            for batch_start in range(num_chunks):
                chunk = ' '.join(summary_chunk[batch_start * chunk_size : (batch_start + 1) * chunk_size])
                prompts.append(statement_prompt.format(summary=summary[0].text, chunk=chunk))
                chunks.append(chunk)
        generations = self.factory.llm.generate([[HumanMessage(content=prompt)] for prompt in prompts]).generations
        return chunks, [self.parse_statements(gen[0].text) for gen in generations]

    def parse_statements(self, text:str):
        """Parse an LLM answer of the form ``1. ...\\n2. ...`` into a list of statements.

        Only consecutively numbered lines are kept. Other lines (headers, blanks,
        continuation text) are ignored.
        """
        statements:List[str] = []
        for line in text.strip().splitlines():
            # Strip indentation so "  3. foo" is still recognized (as parse_entities does)
            line = line.strip()
            if line.startswith(f'{len(statements) + 1}. '):
                statements.append(line.split(' ', 1)[1].strip())
        return statements

    def extract_entities(self, statements:List[List[str]]):
        """Ask the LLM for the entities of each statement, one request per statement group.

        Returns one ``parse_entities`` result (a list of entity lists) per group.
        """
        prompts = [
            [HumanMessage(content='List the entities in each line of the following statements.\nAvoid resolving the pronoun unless you are absolutely certain.\n\nStatements:\n' + '\n'.join([f'{sid+1}. {s}' for sid, s in enumerate(statement)]))]
            for statement in statements
        ]
        return [self.parse_entities(gen[0].text) for gen in self.factory.llm.generate(prompts).generations]

    def parse_entities(self, text:str):
        """Parse numbered lines ``k. ent1, ent2, ...`` into one entity list per line.

        Comma-separated items are re-joined when a parenthesis spans a comma, e.g.
        ``Mars (planet, fourth)``. A leading ``Entities:`` label is removed, each entity
        is cleaned with ``clean_entity``, and empty results are dropped.
        """
        list_of_ent_list:List[List[str]] = []
        for line in text.strip().splitlines():
            line = line.strip()
            if not line.startswith(f'{len(list_of_ent_list) + 1}. '):
                continue
            ent_list:List[str] = []
            incomplete_ent:List[str] = []
            for ent in line.split(' ', 1)[1].strip().split(','):
                if '(' in ent and ')' not in ent:
                    incomplete_ent.append(ent)
                elif '(' not in ent and ')' in ent:
                    incomplete_ent.append(ent)
                    ent_list.append(','.join(incomplete_ent).strip().strip('.'))
                    incomplete_ent.clear()
                elif incomplete_ent:
                    incomplete_ent.append(ent)
                else:
                    ent_list.append(ent.strip().strip('.'))
            if incomplete_ent:
                # An unbalanced "(" at the end of the line: keep the text instead of dropping it
                ent_list.append(','.join(incomplete_ent).strip().strip('.'))
            ent_list = [ent.split(':', 1)[1].strip() if ent.startswith('Entities:') else ent for ent in ent_list]
            ent_list = [self.clean_entity(ent) for ent in ent_list]
            list_of_ent_list.append([ent for ent in ent_list if ent])
        return list_of_ent_list

    # Helper functions
    def collect_keywords_from_text(self, text:str):
        """Extract entity candidates and entity/modifier relations from one statement with spaCy.

        Candidates are built as follows:
        * take noun chunks and named entities, trimming leading determiners and
          excluding those whose root is NUM/PRON;
        * merge them where they overlap;
        * split them at commas;
        * keep only those containing a NOUN/PROPN token.

        Returns:
            ``(ent_candidates, ent_modifiers)``:
            * ``ent_candidates``: the unique candidate strings;
            * ``ent_modifiers``: a list of relations, each ``(entity, modifiers)``
              (subject side) or ``(modifiers, entity)`` (object side). ``modifiers`` is a
              list of ``"<word>_<token index>"`` strings; several relations may share
              the same list object.
        """

        def trim_det(noun_chunk:spacy.tokens.Span):
            # Drop leading determiners. Returns None when the span is only determiners.
            for tid, t in enumerate(noun_chunk):
                if t.pos_ not in ['DET']:
                    return noun_chunk[tid:]

        doc = self.nlp(text)
        ncs = [trim_det(nc) for nc in doc.noun_chunks if nc.root.pos_ not in ['NUM', 'PRON']]
        ents = [trim_det(ent) for ent in doc.ents if ent.root.pos_ not in ['NUM', 'PRON']]

        # Merge noun-chunk and named-entity token spans; an overlapping pair becomes one span.
        ncs_spans = [(nc.start, nc.end) for nc in ncs if nc]
        ents_spans = [(ent.start, ent.end) for ent in ents if ent]
        nc_id, eid = 0, 0
        spans = []
        while nc_id < len(ncs_spans) and eid < len(ents_spans):
            nc_span, ent_span = ncs_spans[nc_id], ents_spans[eid]
            if set(range(*nc_span)).intersection(range(*ent_span)):
                spans.append((min(nc_span[0], ent_span[0]), max(nc_span[1], ent_span[1])))
                nc_id += 1
                eid += 1
            elif nc_span[0] < ent_span[0]:
                spans.append(nc_span)
                nc_id += 1
            else:
                spans.append(ent_span)
                eid += 1
        spans.extend(ncs_spans[nc_id:])
        spans.extend(ents_spans[eid:])

        # Split spans at comma tokens
        updated_spans:List[Tuple[int, int]] = []
        for span in spans:
            doc_span = doc[span[0]:span[1]]
            if ',' in doc_span.text:
                start = doc_span.start
                for t in doc_span:
                    if t.text == ',':
                        if t.i != start:
                            updated_spans.append((start, t.i))
                        start = t.i + 1
                if start < span[1]:
                    updated_spans.append((start, span[1]))
            else:
                updated_spans.append(span)
        # Keep spans containing a noun; drop a leading pronoun
        updated_spans = [span for span in updated_spans if any([t.pos_ in ['NOUN', 'PROPN'] for t in doc[span[0]:span[1]]])]
        updated_spans = sorted([span if doc[span[0]].pos_ != 'PRON' else (span[0]+1, span[1]) for span in updated_spans])

        # ent_mask[token index] = index into ent_candidates, or -1 outside any entity
        ent_candidates:List[str] = []
        ent_mask = -np.ones(len(doc), dtype=np.int32)
        for start, end in updated_spans:
            ent = doc[start:end].text.strip('"\'')
            if len(ent) < 2:
                continue
            if ent not in ent_candidates:
                ent_candidates.append(ent)
            # Mark every occurrence, not only the first. Previously a repeated entity's
            # later mentions stayed at -1: their relations were lost and their tokens
            # were treated as free verbs/adjectives.
            ent_mask[start:end] = ent_candidates.index(ent)

        ent_modifiers:List[Tuple[str, List[str]] | Tuple[List[str], str]] = []
        for t in doc:
            if t.pos_ not in ['VERB', 'ADJ', 'AUX', 'ADP'] or ent_mask[t.i] >= 0:
                continue
            modifiers = []
            if t.pos_ == 'VERB':
                # Passive voice records "<auxpass lemma>_i" + "<verb text>_j"; active voice records the verb lemma
                is_passive = False
                for child in t.children:
                    if child.dep_ == 'auxpass':
                        modifiers.extend([f'{child.lemma_}_{child.i}', f'{t.text}_{t.i}'])
                        is_passive = True
                if not is_passive:
                    modifiers.append(f'{t.lemma_}_{t.i}')

                subj_found = False
                for child in t.children:
                    if 'subj' in child.dep_:
                        ent_modifiers.extend([(ent_candidates[ent_mask[subj.i]], modifiers) for subj in self.collect_parallel_ents(child, ent_mask)])
                        subj_found = True
                    elif 'obj' in child.dep_:
                        ent_modifiers.extend([(modifiers, ent_candidates[ent_mask[obj.i]]) for obj in self.collect_parallel_ents(child, ent_mask)])
                    elif 'advmod' == child.dep_ and ent_mask[child.i] < 0 and not child.is_stop:
                        # Adverbs are placed before/after the verb following surface order
                        if child.i < t.i:
                            modifiers.insert(0, f'{child.text}_{child.i}')
                        else:
                            modifiers.append(f'{child.text}_{child.i}')
                    elif 'prep' == child.dep_:
                        ent_modifiers.extend([(modifiers + new_modifiers, obj) for new_modifiers, obj in self.collect_prep_pobj(child, ent_mask, ent_candidates)])
                if not subj_found:
                    # Borrow the subject(s) of the nearest ancestor that has one
                    for ancestor in t.ancestors:
                        for child in ancestor.children:
                            if 'subj' in child.dep_:
                                ent_modifiers.extend([(ent_candidates[ent_mask[subj.i]], modifiers) for subj in self.collect_parallel_ents(child, ent_mask)])
                                subj_found = True
                        if subj_found:
                            break

            elif t.pos_ == 'ADJ':
                modifiers.append(f'{t.text}_{t.i}')
                # Attach to the nearest ancestor token that lies inside an entity
                for ancestor in t.ancestors:
                    if ent_mask[ancestor.i] >= 0:
                        ent_modifiers.extend([(ent_candidates[ent_mask[subj.i]], modifiers) for subj in self.collect_parallel_ents(ancestor, ent_mask)])
                        break
                if t.dep_ == 'acomp':
                    # Adjectival complement: attach to the subject(s) of the head word
                    for child in t.head.children:
                        if 'subj' in child.dep_:
                            ent_modifiers.extend([(ent_candidates[ent_mask[subj.i]], modifiers) for subj in self.collect_parallel_ents(child, ent_mask)])
                for child in t.children:
                    if child.dep_ == 'prep':
                        ent_modifiers.extend([(modifiers + new_modifiers, obj) for new_modifiers, obj in self.collect_prep_pobj(child, ent_mask, ent_candidates)])

            elif t.pos_ == 'ADP':
                modifiers.append(f'{t.text.lower()}_{t.i}')
                subjs:List[spacy.tokens.Token] = []
                for ancestor in t.ancestors:
                    if ent_mask[ancestor.i] >= 0:
                        subjs.extend(self.collect_parallel_ents(ancestor, ent_mask))
                        break
                    if ancestor.pos_ == 'ADP':
                        # Chained prepositions ("out of") are prepended
                        modifiers.insert(0, f'{ancestor.text.lower()}_{ancestor.i}')
                    elif ancestor.pos_ == 'AUX':
                        for child in ancestor.children:
                            if 'subj' in child.dep_:
                                subjs.extend(self.collect_parallel_ents(child, ent_mask))
                        break
                    else:
                        break
                objs:List[spacy.tokens.Token] = []
                for child in t.children:
                    if 'obj' in child.dep_:
                        objs.extend(self.collect_parallel_ents(child, ent_mask))
                # Only emit when both sides of the preposition are entities
                if subjs and objs:
                    ent_modifiers.extend([(ent_candidates[ent_mask[s.i]], modifiers) for s in subjs])
                    ent_modifiers.extend([(modifiers, ent_candidates[ent_mask[o.i]]) for o in objs])

            elif t.pos_ == 'AUX':
                modifiers.append(f'{t.lemma_}_{t.i}')
                subjs:List[spacy.tokens.Token] = []
                objs:List[spacy.tokens.Token] = []
                for child in t.children:
                    if 'subj' in child.dep_:
                        subjs.extend(self.collect_parallel_ents(child, ent_mask))
                    if 'obj' in child.dep_:
                        objs.extend(self.collect_parallel_ents(child, ent_mask))
                if subjs and objs:
                    ent_modifiers.extend([(ent_candidates[ent_mask[s.i]], modifiers) for s in subjs])
                    ent_modifiers.extend([(modifiers, ent_candidates[ent_mask[o.i]]) for o in objs])

        return ent_candidates, ent_modifiers

    def clean_entity(self, ent_text:str):
        """Strip leading DET/CCONJ/PRON tokens; return None if nothing else remains."""
        ent_doc = self.nlp(ent_text, disable=['parser', 'ner'])
        for tid, t in enumerate(ent_doc):
            if t.pos_ not in ['DET', 'CCONJ', 'PRON']:
                return ent_doc[tid:].text
        return None

    def normalize_entity(self, ent_text:str):
        """Lower-cased tokens of ``ent_text`` (nouns lemmatized), without stop words and punctuation."""
        return [t.text.lower() if t.pos_ != 'NOUN' else t.lemma_.lower() for t in self.nlp(ent_text, disable=['parser', 'ner']) if not (t.is_stop or t.pos_ == "PUNCT")]

    def normalize_text(self, text:str):
        """Lower-cased tokens of ``text`` (nouns and verbs lemmatized), without stop words."""
        return [t.lemma_.lower() if t.pos_ in ['NOUN', 'VERB'] else t.text.lower() for t in self.nlp(text, disable=['ner', 'parser']) if not t.is_stop]

    def split_lower_text(self, text:str) -> List[str]:
        """All tokens of ``text``, lower-cased."""
        return [t.text.lower() for t in self.nlp(text, disable=['ner', 'parser'])]

    def bm25_retrieve(self, tokenized_query:List[str], bm25:BM25Okapi):
        """Indices of documents with a positive BM25 score, best first."""
        index_score_pairs = [(idx, score) for idx, score in enumerate(bm25.get_scores(tokenized_query)) if score > 0]
        index_score_pairs.sort(key=lambda x: x[1], reverse=True)
        return [idx for idx, _ in index_score_pairs]

    def collect_parallel_ents(self, ent:spacy.tokens.Token, ent_mask:np.ndarray):
        """Return ``ent`` and the tokens reachable from it via ``conj``/``appos`` edges that lie inside an entity.

        Tokens are visited breadth-first.
        """
        ret_list:List[spacy.tokens.Token] = []
        queue:List[spacy.tokens.Token] = [ent]
        # Iterating a list while appending to it also visits the appended tokens (BFS
        # order) without the O(n) cost of list.pop(0).
        for temp_ent in queue:
            if ent_mask[temp_ent.i] >= 0:
                ret_list.append(temp_ent)
            queue.extend(child for child in temp_ent.children if child.dep_ in ['conj', 'appos'])
        return ret_list

    def collect_prep_pobj(self, child:spacy.tokens.Token, ent_mask:np.ndarray, ent_candidates:List[str]):
        """Collect relations of the preposition token ``child`` with its entity objects.

        Returns ``([prep], entity)`` for each entity object of ``child``, and
        ``([prep, nested prep], entity)`` for entity objects of a preposition nested
        under it. Returns ``[]`` when ``child`` itself lies inside an entity.
        """
        ent_modifiers:List[Tuple[List[str], str]] = []
        if ent_mask[child.i] < 0:
            for grand_child in child.children:
                if 'obj' in grand_child.dep_:
                    ent_modifiers.extend([([f'{child.text}_{child.i}'], ent_candidates[ent_mask[obj.i]]) for obj in self.collect_parallel_ents(grand_child, ent_mask)])
                elif 'prep' == grand_child.dep_:
                    for grand_grand_child in grand_child.children:
                        if 'obj' in grand_grand_child.dep_:
                            ent_modifiers.extend([([f'{child.text}_{child.i}', f'{grand_child.text}_{grand_child.i}'], ent_candidates[ent_mask[obj.i]]) for obj in self.collect_parallel_ents(grand_grand_child, ent_mask)])
        return ent_modifiers

# Load the saved statements, then build entity graphs and BM25 stores from them.
longdoc = LongDoc(factory=f, chunk_info_file='atomic_facts2.json')

## Index Construction

In [None]:
# Re-runs the LLM statement pipeline for `article` and writes the result to
# atomic_facts2.json (the file loaded in the cell above; presumably overwritten).
longdoc.build_index(article, 'atomic_facts2.json')

In [None]:
# Inspect the spaCy entity/relation extraction for a single statement.
longdoc.collect_keywords_from_text(longdoc.chunk_infos[1].statements[3])

In [None]:
cid = 0
longdoc.chunk_infos[cid].statements

In [None]:
# Entities and relations extracted from statement `sid` of chunk `cid`.
sid = 2
print(longdoc.chunk_infos[cid].entities[sid])
print()
print(longdoc.chunk_infos[cid].ent_modifiers[sid])

In [None]:
from spacy.displacy import render
# Visualize the dependency parse that collect_keywords_from_text walks.
render(longdoc.nlp("The spaceship Leo is stranded on Mars' moon Phobos."))

In [None]:
# NOTE(review): `doc` is not used by any later cell.
doc = longdoc.nlp('Finding a cook on Phobos was difficult because it had only a handful of settlers and most of them had good-paying jobs.')

In [None]:
def cal_compression(source_text:str, output_text:str, tokenize=None):
    """Return the compression ratio ``#tokens(output_text) / #tokens(source_text)``.

    Args:
        source_text: original passage (denominator).
        output_text: rewritten passage (numerator).
        tokenize: callable mapping a string to a token list; defaults to ``word_tokenize``.

    Returns:
        The ratio as a float, or ``nan`` when ``source_text`` has no tokens
        (previously a ZeroDivisionError).
    """
    if tokenize is None:
        tokenize = word_tokenize
    source_len = len(tokenize(source_text))
    if source_len == 0:
        return float('nan')
    return float(len(tokenize(output_text))) / source_len

# How strongly does the statement rewrite compress passages of different lengths?
# `pieces` was never defined in this notebook (NameError under Restart & Run All).
pieces = f.split_text(article)
for chunk_size in tqdm((5, 10, 15, 20)):
    chunks = concate_with_overlap(pieces, chunk_size=chunk_size)
    # generate_statements treats concate_with_overlap's output as lists of pieces and joins
    # them with ' '. Do the same here instead of formatting a Python list repr into the
    # prompt and passing a list to word_tokenize.
    chunk_texts = [chunk if isinstance(chunk, str) else ' '.join(chunk) for chunk in chunks]
    results = f.llm.generate([[HumanMessage(content=f'Rewrite the following passage into a list of statements.\nEach statement should tell an atomic fact in the passage.\nAll the statements together should cover all the information in the passage.\nTry to use the original words from the passage.\n\nPassage:\n{chunk_text}')] for chunk_text in chunk_texts])
    # The last chunk is excluded (presumably because it can be shorter than chunk_size).
    # `chunk_idx` rather than `cid`, so the notebook-level `cid` is not overwritten.
    compression_ratios = [
        cal_compression(source, output[0].text)
        for chunk_idx, (output, source) in enumerate(zip(results.generations, chunk_texts))
        if chunk_idx != len(chunk_texts) - 1
    ]
    print(chunk_size, np.mean(compression_ratios) if compression_ratios else float('nan'))

## Question Examples

In [None]:
# Questions/answers for dataset.data[2].
# NOTE(review): presumably the same article whose statements are in atomic_facts2.json — confirm.
questions, answers = dataset.get_questions_and_answers(dataset.data[2])
questions

### Example 1

Why does the Skipper stop abruptly after he says "when you're running a blockade"?

In [None]:
# Chunks whose normalized text contains the quoted phrase. normalize_text drops stop
# words, so only the content words of the phrase are matched.
longdoc.exact_match_chunks('''when you\'re running a blockade''')

In [None]:
longdoc.exact_match_chunks('''stop abruptly''')

In [None]:
# NOTE(review): lists every entity node; consider a count or a slice for large graphs.
list(longdoc.ent_graph.nodes)

### Example 2

Why does the Skipper allow the new chef to use the heat-cannon as an incinerator?

In [None]:
longdoc.lexical_retrieval_entities('heat-cannon', 10)

In [None]:
longdoc.lexical_retrieval_entities('incinerator', 10)

In [None]:
longdoc.lexical_retrieval_entities('skipper', 20)

In [None]:
longdoc.lexical_retrieval_entities('new chef', 20)

In [None]:
pr = nx.pagerank(longdoc.ent_graph.to_undirected(), personalization={'new cook': 1.0}, weight='weight')
sorted(list(pr.items()), key=lambda x: x[1], reverse=True)

In [None]:
nx.shortest_path(longdoc.ent_graph.to_undirected(), 'new cook', 'Old Man')

In [None]:
longdoc.ent_graph.to_undirected()['Old Man']['meal']

In [None]:
longdoc.chunk_infos[6].statements[2]

In [None]:
info_collection = defaultdict(Counter)
target_ents = ['heat-cannon', 'incinerator', 'skipper', 'new chef', 'new cook', 'old man']
for target_ent in target_ents:
    for real_ent in longdoc.lexical_retrieval_entities(target_ent, 20):
        if (target_ent, real_ent) in [('new chef', 'new incinerator shipshape'), ('new chef', 'new age'), ('new chef', 'new course'), ('new chef', 'new incinerator')]:
            continue
        for pid, sid in longdoc.ent_graph.nodes[real_ent]['locs']:
            info_collection[pid][target_ent] += 1
df = pd.DataFrame({target_ent: [info_collection[i][target_ent] for i in range(len(longdoc.chunk_infos))] for target_ent in target_ents}, index=range(len(longdoc.chunk_infos)))
df.plot(kind='bar', stacked=True)
plt.xlabel('Chunks')
plt.ylabel('Entity occurrence')

In [None]:
nx.has_path(longdoc.semantic_graph, 'skipper_0', 'Old Man_2')

In [None]:
list(nx.all_shortest_paths(longdoc.semantic_graph, 'skipper_0', 'Old Man_2'))

### Example 3

What would've happened if the new cook had told the Skipper about the ekalastron deposits earlier?

To effectively filter passages from the story that might answer the question "What would've happened if the new cook had told the Skipper about the ekalastron deposits earlier?", here are some key pieces of information to look for:

1. **Information about the new cook**:
   - The role and significance of the new cook in the story.
   - Any specific interactions the new cook has with the Skipper or other characters.

2. **Details about the ekalastron deposits**:
   - What the ekalastron deposits are and why they are important.
   - Any known impact or potential impact of discovering these deposits.

3. **The Skipper's role and decision-making**:
   - The Skipper's authority and responsibilities.
   - How the Skipper typically responds to important information.

4. **Consequences of the timing of information**:
   - Any events or outcomes directly influenced by the timing of discovering or sharing information about the ekalastron deposits.
   - Hypothetical scenarios or speculations within the story about different timings of revealing information.

5. **Reactions and outcomes**:
   - Characters’ reactions to discovering the ekalastron deposits.
   - Any explicit or implicit suggestions of what could have happened if the information was revealed earlier.

These points can help narrow down relevant passages that provide context, character motivations, and possible outcomes related to the timing of sharing information about the ekalastron deposits.

In [None]:
ents = longdoc.lexical_retrieval_entities('new cook', 20)
ent_cnts = [(ent, len(longdoc.ent_graph.nodes[ent]['locs'])) for ent in ents]
ent_cnts.sort(key=lambda x: x[1], reverse=True)
for eid, (ent, cnt) in enumerate(ent_cnts):
    print(f'{eid}. {ent}: {cnt}')

In [None]:
pr = nx.pagerank(longdoc.ent_graph, personalization={'new cook': 1.0}, weight='weight')
sorted(list(pr.items()), key=lambda x: x[1], reverse=True)

In [None]:
nx.shortest_path(longdoc.ent_graph, 'cook', 'Captain Slops')

In [None]:
longdoc.ent_graph['cook']['Mister Dugan']

In [None]:
longdoc.ent_graph['Mister Dugan']['Captain Slops']

In [None]:
longdoc.chunk_infos[4].statements[1]

In [None]:
longdoc.chunk_infos[4].statements[12]

In [None]:
nx.shortest_path(longdoc.ent_graph, 'new cook', 'Leo')

In [None]:
longdoc.ent_graph['new cook']['Leo']

In [None]:
longdoc.chunk_infos[9].statements[3]

In [None]:
ents = longdoc.lexical_retrieval_entities('ekalastron deposits', 20)
ent_cnts = [(ent, len(longdoc.ent_graph.nodes[ent]['locs'])) for ent in ents]
ent_cnts.sort(key=lambda x: x[1], reverse=True)
for eid, (ent, cnt) in enumerate(ent_cnts):
    print(f'{eid}. {ent}: {cnt}')

In [None]:
# Node attributes of an entity in the entity graph (earlier cells read its 'locs' list).
longdoc.ent_graph.nodes['ekalastron deposits']

In [None]:
longdoc.ent_graph.nodes['rich ekalastron deposits']

In [None]:
# Hand-picked statements from two adjacent chunks (indices presumably taken from the 'locs' above).
longdoc.chunk_infos[20].statements[17:18] + longdoc.chunk_infos[21].statements[1:3]

In [None]:
len(longdoc.chunk_infos)

In [None]:
# Personalized PageRank on an undirected copy of the semantic graph.
# NOTE(review): node names look like '<entity>_<chunk index>' — confirm against how semantic_graph is built.
pr = nx.pagerank(longdoc.semantic_graph.to_undirected(), personalization={'ekalastron deposits_21': 1}, alpha=0.3)
sorted(list(pr.items()), key=lambda x: x[1], reverse=True)

In [None]:
# Weighted distances/paths *into* the target: reversing the directed graph makes Dijkstra follow incoming edges.
dist, path = nx.single_source_dijkstra(longdoc.semantic_graph.reverse(), 'rich ekalastron deposits_20', cutoff=5, weight='explore_weight')

In [None]:
list(nx.all_shortest_paths(longdoc.semantic_graph, 'Leo_3', 'rich ekalastron deposits_20', weight='explore_weight'))

In [None]:
dist

In [None]:
path

In [None]:
longdoc.semantic_graph.has_edge("Captain O'Hara_21", "Captain O'Hara_22")

In [None]:
# Rebinds `dist` (previously from single_source_dijkstra) to distances from a different source node.
pred, dist = nx.dijkstra_predecessor_and_distance(longdoc.semantic_graph.reverse(), 'ekalastron deposits_21', cutoff=5, weight='explore_weight')

In [None]:
list(longdoc.semantic_graph.predecessors('ekalastron deposits_21'))

In [None]:
dist

In [None]:
# Duplicate of the previous cell.
dist

In [None]:
longdoc.lexical_retrieval_entities('advice')

In [None]:
# NOTE(review): pagerank treats `weight` as transition strength (larger = more likely), while the
# Dijkstra cells treat 'explore_weight' as a cost — confirm which meaning explore_weight has.
pr = nx.pagerank(longdoc.semantic_graph.reverse(), personalization={'rich ekalastron deposits_20': 1}, weight='explore_weight', alpha=0.3)
sorted(list(pr.items()), key=lambda x: x[1], reverse=True)

In [None]:
longdoc.chunk_infos[21].chunk_text

In [None]:
longdoc.chunk_infos[22].statements

In [None]:
# Numbered list of the entities retrieved for 'skipper'.
ents = longdoc.lexical_retrieval_entities('skipper')
for eid, ent in enumerate(ents):
    print(f'{eid}. {ent}')

# Dataset

In [None]:
# StrategyQA training split, read from a local copy.
strategy_qa = read_json('../../data/strategyqa/strategyqa_train.json')

In [None]:
strategy_qa[0]

In [None]:
TempQuestions = read_json('../../data/TempQuestions/TempQuestions.json')

In [None]:
TempQuestions[0]

In [None]:
commonsense_qa = load_dataset('tau/commonsense_qa', split='test')

In [None]:
kilt_eli5 = load_dataset('facebook/kilt_tasks', 'eli5', split='validation')

In [None]:
asqa = load_dataset('din0s/asqa', split='dev')

In [None]:
asqa[0]

In [None]:
kilt_eli5[1]

In [None]:
# LongBench subsets (test split); qmsum and vcsum are inspected in later cells.
narrativeqa = load_dataset('THUDM/LongBench', 'narrativeqa', split='test', trust_remote_code=True)
qasper = load_dataset('THUDM/LongBench', 'qasper', split='test', trust_remote_code=True)
gov_report = load_dataset('THUDM/LongBench', 'gov_report', split='test', trust_remote_code=True)
qmsum = load_dataset('THUDM/LongBench', 'qmsum', split='test', trust_remote_code=True)
multifieldqa_zh = load_dataset('THUDM/LongBench', 'multifieldqa_zh', split='test', trust_remote_code=True)
vcsum = load_dataset('THUDM/LongBench', 'vcsum', split='test', trust_remote_code=True)

quality = QualityDataset(split='dev')


def _parse_squality_sample(sample: dict) -> dict:
    """Split a SQuALITY sample's combined 'input' field into its question and story context.

    The context is everything before the last 'Question:\\n' marker, with its first line
    dropped and surrounding whitespace stripped; the question is the first line after the
    marker. Returns the same 'input'/'context' keys the LongBench samples use.
    """
    # rsplit(..., 1) tolerates the marker also appearing inside the story text, where a plain
    # split would yield more than two parts and break the unpacking.
    context, question = sample['input'].rsplit('Question:\n', 1)
    context = context.split('\n', 1)[1].strip()
    question = question.split('\n', 1)[0]
    return {'input': question, 'context': context}


# Parsing inside a helper keeps `input` (a builtin) and other temporaries out of the kernel namespace.
squality = [_parse_squality_sample(sample) for sample in read_jsonline('../../data/squality/test.jsonl')]

In [None]:
vcsum[7]

In [None]:
print(vcsum[7]['context'])

In [None]:
# Matching-style LongBench tasks (per the inline notes).
# NOTE(review): not referenced by any later cell shown here — confirm it is still needed.
matching_tasks = [
    "passage_retrieval_en", # Match passage with summary
    # "trec", # Match question type
]

In [None]:
# For a sample of LongBench / SQuALITY questions, ask the LLM what information it would want
# from a document before reading it, and save [(question, task), reply] pairs to needs.json.
# NOTE(review): relies on `f` being a Factory with an LLM configured (`f.llm`); the Factory
# constructed in the Connections section uses llm_name=None — confirm which `f` is live here.
# NOTE(review): expects `squality` to be the parsed list from the Dataset section; later cells rebind that name.
N_SAMPLES_PER_TASK = 20
LONGBENCH_QA_TASKS = [
    "narrativeqa", "qasper", "multifieldqa_en", "hotpotqa", "2wikimqa", "musique",
    # "gov_report",
    "qmsum",
    # "multi_news", "vcsum",
    # "triviaqa", # Few Shot QA
    # "samsum", # Few Shot Summary
    # "lsht", # Few Shot Summary
    # "passage_count",
    # "lcc", # Code task
    # "repobench-p", # Code task
]

inputs = []  # (question, source task) pairs
for task in LONGBENCH_QA_TASKS:
    # Task-specific name: rebinding the global `dataset` here would replace the QualityDataset
    # that other cells call `.get_article` / `.data` on. trust_remote_code matches the other LongBench loads.
    task_dataset = load_dataset('THUDM/LongBench', task, split='test', trust_remote_code=True)
    inputs.extend((task_dataset[i]['input'], task) for i in range(N_SAMPLES_PER_TASK))

inputs.extend((squality[i]['input'], 'squality') for i in range(N_SAMPLES_PER_TASK))

prompts = [[HumanMessage(
    content=f'''You need to answer the following question based on a given document. Before reading the document, what information would you like to know from the document to answer this question?\n\nQuestion: {question}''')] for question, _ in inputs]
generations = f.llm.generate(prompts).generations
# One list of candidate generations per prompt; keep the first candidate's text.
needs = [[input_pair, candidates[0].text] for input_pair, candidates in zip(inputs, generations)]

write_json('needs.json', needs)

In [None]:
inputs[0]

In [None]:
# In-place shuffle; no seed is set in the cells shown, so `needs[nid]` below varies between runs.
random.shuffle(needs)

In [None]:
# NOTE(review): 'trec' is not among the tasks `needs` was built from, so this prints nothing as written.
for n in needs:
    if n[0][1] == 'trec':
        print(n[0][0])

In [None]:
# Show one (question, task) input followed by the LLM's reply to it.
nid = 20
print(needs[nid][0])
print(needs[nid][1])

In [None]:
# Load the candidate questions, shuffle them, and print the first 20 as a numbered list.
# Reading the file here keeps the cell from depending on a `questions` variable left over in the kernel.
with open('questions.txt', encoding='utf-8') as f_in:
    questions = f_in.read().splitlines()
random.shuffle(questions)
print('\n'.join([f'{qid+1}. {q}' for qid, q in enumerate(questions[:20])]))

In [None]:
print(qmsum[0]['context'])

In [None]:
print(qmsum[0]['answers'])

In [None]:
# Index the first QMSum transcript; the second argument is build_index's `chunk_info_file`.
longdoc.build_index(qmsum[0]['context'], 'qmsum.json')

In [None]:
# NOTE(review): LongDoc.__init__ calls enrich_index() after loading chunk infos from a file; whether
# build_index already does so is not visible here — check before running this separately.
longdoc.enrich_index()

In [None]:
longdoc.lexical_retrieval_entities('remote control', 20)

In [None]:
longdoc.lexical_retrieval_entities('working design', 20)

In [None]:
# NOTE(review): the repository name marks this dataset as defunct; kilt_eli5 (loaded above) is an alternative ELI5 source.
eli5 = load_dataset('defunct-datasets/eli5')

In [None]:
!wget https://github.com/nyu-mll/SQuALITY/blob/main/data/v1-3/txt/dev.jsonl

In [None]:
# Read the downloaded file as one string (explicit UTF-8 rather than the platform default encoding).
# NOTE(review): this rebinds `squality`, replacing the parsed list that the needs-generation cell indexes.
with open('dev.jsonl', encoding='utf-8') as f_in:
    squality = f_in.read()
    # squality = [json.loads(l) for l in f_in]

In [None]:
# First 100 items/characters of whatever `squality` currently holds (a raw string after the read above).
squality[:100]

In [None]:
# Alternative source: SQuALITY v1.3 from the Hugging Face Hub (rebinds `squality` again).
squality = load_dataset('pszemraj/SQuALITY-v1.3')

# Coreference resolution

In [None]:
from stanza.server import CoreNLPClient, StartServer

def get_coref_pair(mentions:'List[Mention]') -> list:
    """Pair a coreference group's representative mention with its mentions in other sentences.

    Args:
        mentions: all mentions of one CoreNLP coreference group.

    Returns:
        A list of ``(first_sent, last_sent, rep_text, mention_text)`` tuples, one per
        non-representative mention whose sentence differs from the representative's and whose
        text differs from the representative's text. The two sentence numbers are sorted
        ascending. The list is empty (not ``None``) when the group has no non-pronominal
        representative, so callers can always iterate or take ``len()`` of the result.
    """
    rep = None
    corefs = []
    for mention in mentions:
        if mention.isRepresentativeMention:
            # Only a non-pronominal representative is used as the anchor text.
            if mention.type != 'PRONOMINAL':
                rep = mention
        else:
            corefs.append(mention)
    if rep is None:
        return []
    return [
        (*sorted((rep.sentNum, mention.sentNum)), rep.text, mention.text)
        for mention in corefs
        if mention.sentNum != rep.sentNum and mention.text != rep.text
    ]

def plot_sents_by_coref(article:str, doc:Doc, sent_span:Tuple[int, int]):
    """Print the slice of `article` covered by a range of sentences.

    Args:
        article: the text `doc` was annotated from; token character offsets index into it.
        doc: the CoreNLP annotation of `article`.
        sent_span: inclusive (first, last) sentence numbers, 1-based (shifted by -1 to index doc.sentences).
    """
    first_token = doc.sentences[sent_span[0] - 1].tokens[0]
    last_token = doc.sentences[sent_span[1] - 1].tokens[-1]
    # The extra '\n' argument leaves a blank line between consecutive calls.
    print(article[first_token.characterOffsetBegin : last_token.characterOffsetEnd], '\n')

In [None]:
# Annotate one QuALITY article with an already-running CoreNLP server (coreference included).
# text = "Chris Manning is a nice person. Chris wrote a simple sentence. It was a joke."
import os

# Use `quality` (the QualityDataset loaded in the Dataset section) so this cell does not depend on
# the generic name `dataset` still holding a QualityDataset.
text = ' '.join(quality.get_article(quality.data[2]).split())  # collapse every whitespace run to one space
# The server address can be overridden with the CORENLP_ENDPOINT environment variable.
CORENLP_ENDPOINT = os.environ.get('CORENLP_ENDPOINT', 'http://172.22.224.150:9000')
with CoreNLPClient(
    start_server=StartServer.DONT_START,
    annotators=['tokenize','ssplit','pos','lemma','ner','depparse','coref'],
    endpoint=CORENLP_ENDPOINT) as client:
    ann = client.annotate(text, output_format='json')
    doc = Doc(**ann)

In [None]:
# Number of sentences in the annotated article.
len(doc.sentences)

In [None]:
# Number of coreference groups (each key maps to a list of mentions).
len(doc.corefs)

In [None]:
doc.corefs.keys()

In [None]:
# Mentions of one group, looked up by its string key.
doc.corefs['285']

In [None]:
# For every cross-sentence coreference pair, count the tokens in sentences first..last (inclusive)
# and list the pairs from shortest to longest span.
dists = []
for _coref_key, mentions in doc.corefs.items():
    coref_pairs = get_coref_pair(mentions)
    if coref_pairs:
        dists.extend([
            # sentNum is 1-based while doc.sentences is 0-based, so sentences s1..s2 live at
            # indices s1-1..s2-1 (the old range(s1, s2+1) was shifted by one and could overrun the list).
            (sum(len(doc.sentences[sid].tokens) for sid in range(s1 - 1, s2)), s1, s2, rep, men)
            for s1, s2, rep, men in coref_pairs
        ])
sorted(dists)

In [None]:
# Split sentences by whether their basic dependency parse has an 'nsubj' relation together with
# at least one of 'obj', 'obl' or 'nmod:poss'.
dkg = list[Sentence]()
dkg_miss = list[Sentence]()
for sent in doc.sentences:
    dep_labels = {dep.dep for dep in sent.basicDependencies}
    has_subject = 'nsubj' in dep_labels
    has_object = bool(dep_labels & {'obj', 'obl', 'nmod:poss'})
    (dkg if has_subject and has_object else dkg_miss).append(sent)

In [None]:
# Fraction of sentences that landed in `dkg`.
len(dkg) * 1. / len(doc.sentences)

In [None]:
# NOTE(review): hard-coded ratios, presumably the `dkg` fraction from three earlier runs/articles — record their provenance.
np.mean([0.6272965879265092, 0.7296222664015904, 0.6307977736549165])

In [None]:
# Sentence.index is 0-based; plot_sents_by_coref expects 1-based sentence numbers, hence the +1.
i = 6
plot_sents_by_coref(text, doc, (dkg[i].index + 1, dkg[i].index + 1))

In [None]:
i = 11
plot_sents_by_coref(text, doc, (dkg_miss[i].index + 1, dkg_miss[i].index + 1))

In [None]:
# Print every sentence classified into `dkg`, one after another.
for sent in dkg:
    plot_sents_by_coref(text, doc, (sent.index + 1, sent.index + 1))

In [None]:
# Dependency edges of the missed sentence printed two cells above (same index 11).
dkg_miss[11].basicDependencies

In [None]:
# Text of sentences 297..304 (1-based, inclusive).
plot_sents_by_coref(text, doc, (297,
  304))

In [None]:
# NOTE(review): stray cell — it only echoes a string constant.
'PRONOMINAL'

In [None]:
doc.sentences[368].tokens[10]

In [None]:
# Rebuild two adjacent sentences from their token words (0-based indices into doc.sentences).
' '.join([t.word for t in doc.sentences[368].tokens])

In [None]:
' '.join([t.word for t in doc.sentences[369].tokens])

In [None]:
# Character offsets of sentence 47 (0-based) within `text`.
doc.sentences[47].tokens[0].characterOffsetBegin

In [None]:
doc.sentences[47].tokens[-1].characterOffsetEnd

In [None]:
# NOTE(review): magic offsets, presumably the two values above — compute them instead of hard-coding.
text[3631:3678]

In [None]:
# Link sentences that share a coreference group whose first mention is not a pronoun: each group's
# distinct sentence numbers are chained into a path, so connected components gather related sentences.
sent_coref_graph = nx.Graph()
for last_id, mentions in doc.corefs.items():
    sents = list({mention.sentNum for mention in mentions})
    if len(sents) <= 1 or mentions[0].type == 'PRONOMINAL':
        continue
    nx.add_path(sent_coref_graph, sents)

In [None]:
list(nx.connected_components(sent_coref_graph))

In [None]:
# Build sentence spans from coreference groups, then merge them into windows covering at most
# MAX_COREF_SPAN sentences. Sentence numbers are CoreNLP sentNum values (1-based).
MAX_COREF_SPAN = 30  # largest allowed (end - start), in sentences, for a span or a merged window

coref_spans = list[tuple[int, int]]()
for _coref_key, mentions in doc.corefs.items():
    sents = {mention.sentNum for mention in mentions}
    # Skip groups confined to one sentence or whose first mention is a pronoun.
    if len(sents) > 1 and mentions[0].type != 'PRONOMINAL':
        anchor = mentions[0].sentNum
        for mention in mentions[1:]:
            if mention.sentNum - anchor <= MAX_COREF_SPAN:
                coref_spans.append((anchor, mention.sentNum))
coref_spans.sort()

merged_coref_spans = list[tuple[int, int]]()
for start, end in coref_spans:
    if merged_coref_spans and end <= merged_coref_spans[-1][1]:
        # Spans are sorted by start, so this one already lies inside the last window.
        continue
    if merged_coref_spans and end - merged_coref_spans[-1][0] <= MAX_COREF_SPAN:
        # Extend the last window while it stays within MAX_COREF_SPAN sentences.
        merged_coref_spans[-1] = (merged_coref_spans[-1][0], end)
    else:
        merged_coref_spans.append((start, end))

In [None]:
# Raw (first-mention sentence, later-mention sentence) spans, sorted.
coref_spans

In [None]:
# Spans after merging into windows with end - start <= 30 sentences.
merged_coref_spans

# Connections

In [None]:
# Same QuALITY article (index 2) as the coreference section, with an embedding-only Factory.
# NOTE(review): this rebinds the global `f` with llm_name=None; cells that call `f.llm` need an LLM-backed Factory.
dataset = QualityDataset(split='dev')
f = Factory(chunk_size=50, llm_name=None, embeder_name='sentence-transformers/all-MiniLM-L6-v2')
article = dataset.get_article(dataset.data[2])

In [None]:
chunks = f.split_text(article)

In [None]:
# One embedding row per chunk.
chunk_embs = np.array(f.embeder.embed_documents(chunks))

In [None]:
# Density-based clustering of chunk embeddings by cosine distance; DBSCAN labels noise points -1.
# NOTE(review): mid-notebook import; consider moving it to the top imports cell.
from sklearn.cluster import DBSCAN
clustering = DBSCAN(eps=0.4, min_samples=2, metric='cosine').fit(chunk_embs)
# clustering.labels_
# clustering

In [None]:
clustering.labels_

In [None]:
from src.cluster_utils import *

In [None]:
# NOTE(review): perform_clustering presumably comes from the star import above; its dim/threshold
# semantics are not visible here. chunk_embs is already an ndarray, so np.array(...) only copies it.
clusters = perform_clustering(np.array(chunk_embs), dim=20, threshold=0.5)

In [None]:
clusters

In [None]:
# Group chunk texts by their DBSCAN label (plain int keys; noise points fall under -1).
cluster2chunks = defaultdict(list)
labelled_chunks = zip(map(int, clustering.labels_), chunks)
for label, chunk in labelled_chunks:
    cluster2chunks[label].append(chunk)

In [None]:
# Chunks that DBSCAN assigned to cluster 0.
cluster2chunks[0]

In [None]:
len(chunks)

In [None]:
chunks[0]