In [None]:
import time, datetime, json, os
from tqdm.notebook import tqdm
from collections import defaultdict, Counter
import numpy as np

from index_files import LongDoc, write_json, QualityDataset, NarrativeQADataset, ReadingAgent, read_json, read_jsonline, LLM, Retriever

In [None]:
# Shared model handles used by the rest of the notebook: `retriever` provides text
# embeddings (used by match_entities below) and `llm` is called for every prompt.
retriever = Retriever()
llm = LLM()
# Alternative model id kept for reference (presumably from an earlier setup); unused.
# llm = 'mistralai/Mistral-7B-Instruct-v0.2'

# Experiment

In [None]:
# NOTE(review): this cell and the next both assign `dataset`; under Restart & Run All only
# the QualityDataset (next cell) survives. Run just one of the two.
dataset = NarrativeQADataset(llm)

In [None]:
# QualityDataset, dev split; overwrites the NarrativeQA dataset from the previous cell.
dataset = QualityDataset(llm, split='dev')

In [None]:
# NOTE(review): `reading_agent` is not referenced anywhere else in the visible notebook.
reading_agent = ReadingAgent(dataset, llm)

## Index passages

In [None]:

from typing import List, Tuple, Set, Dict
import re
import itertools
import networkx as nx
from prompt import prompt_shorten_template, prompt_ent_description_template, \
    prompt_relation_description_template, prompt_shorten_w_note_template, \
    prompt_ent_description_w_note_template, prompt_relation_description_w_note_template

def match_entities(target_ents: List[str], refer_ents: List[str], threshold: float = 0.8) -> Dict[str, str]:
    """Map each target entity to its most similar reference entity by embedding similarity.

    Args:
        target_ents: entity names to be matched.
        refer_ents: candidate entity names to match against.
        threshold: a match is kept only if its similarity score is strictly greater
            than this value (default 0.8, the previously hard-coded cut-off).

    Returns:
        Dict from target entity to its best-scoring reference entity. Targets whose best
        score does not exceed ``threshold`` are omitted. Empty when either input is empty.
    """
    # argmax over an empty similarity row would raise; callers pass LLM-parsed lists
    # that can legitimately be empty.
    if not target_ents or not refer_ents:
        return {}
    # NOTE(review): the second argument `True` presumably L2-normalizes the embeddings so
    # the dot product below is a cosine similarity; TODO confirm in Retriever.embed_paragraphs.
    target_ents_emb = retriever.embed_paragraphs(target_ents, True)
    refer_ents_emb = retriever.embed_paragraphs(refer_ents, True)
    sim_mat: np.ndarray = np.matmul(target_ents_emb, refer_ents_emb.T)
    best_idx = sim_mat.argmax(axis=1)
    ent_map: Dict[str, str] = {}
    for eid, ent in enumerate(target_ents):
        max_idx = best_idx[eid]
        if sim_mat[eid, max_idx] > threshold:
            ent_map[ent] = refer_ents[max_idx]
    return ent_map

# Matches a bracketed span opened by "(" or "[" and closed by ")" or "]" (non-greedy,
# single line). A raw string avoids the invalid-escape warning the plain literal triggered.
_CITATION_RE = re.compile(r"[\(\[].*?[\)\]]")


def remove_citation(text: str) -> str:
    """Strip parenthesised/bracketed spans (e.g. "(source passage 3)", "[2]") from text.

    Surrounding whitespace is stripped afterwards; interior double spaces left by a
    removed span are kept as-is.
    """
    return _CITATION_RE.sub("", text).strip()

def _parse_numbered_list(response: str) -> List[str]:
    """Return the items of a '1. ...', '2. ...' list found in an LLM response.

    Only a line starting with the next expected number (1, then 2, ...) is accepted;
    every other line is skipped. Items are stripped of whitespace and surrounding dots.
    """
    items: List[str] = []
    expected = 1
    for line in response.splitlines():
        if line.startswith(f'{expected}. '):
            items.append(line.split(' ', 1)[1].strip().strip('.'))
            expected += 1
    return items


def _extract_important_ents(paragraph: str, context_type: str, n_samples: int,
                            temperature: float, min_votes: int) -> List[str]:
    """Sample entity lists from the LLM and keep the entities that recur across samples.

    Near-duplicate names from different samples are merged (via match_entities) into
    clusters. A cluster is kept when its total mention count is at least ``min_votes``
    and is represented by its most frequently mentioned name.
    """
    list_entity_prompt = f'''Context:\n\n{paragraph}\n\nAbove is part of a {context_type}. List the important named entities in the above context that are relevant to most of its content. Don't give any explanation. Generate your response in the following format: "Important entities:\n1. Entity 1\n2. Entity 2\n3. Entity 3\n..."'''
    # NOTE(review): assumes llm(prompt, n, temperature) returns, per prompt, a list of n
    # sampled completions; TODO confirm against LLM.__call__.
    responses = llm(list_entity_prompt, n_samples, temperature)[0]
    ent_lists = [_parse_numbered_list(response) for response in responses]
    ent_cnt = Counter(itertools.chain.from_iterable(ent_lists))

    g = nx.Graph()
    # Register every mentioned entity so one that no other sample matched still forms
    # its own cluster and can pass `min_votes`.
    g.add_nodes_from(ent_cnt)
    for list1, list2 in itertools.combinations(ent_lists, 2):
        if list1 and list2:
            g.add_edges_from(match_entities(list1, list2).items())

    important_ents: List[str] = []
    for ent_cluster in nx.connected_components(g):
        if sum(ent_cnt[ent] for ent in ent_cluster) >= min_votes:
            # Most frequent name; ties broken alphabetically so the choice does not depend
            # on set iteration order (which varies with the string hash seed).
            important_ents.append(min(ent_cluster, key=lambda ent: (-ent_cnt[ent], ent)))
    return important_ents


def _parse_descriptions(ent_description: str, important_ents: List[str]) -> Dict[str, str]:
    """Parse 'Entity: description' lines, keeping only entities listed in important_ents.

    Lines without ': ' are skipped instead of raising.
    """
    important = set(important_ents)
    description_dict: Dict[str, str] = {}
    for line in ent_description.splitlines():
        if ': ' not in line:
            continue
        ent, description = line.split(': ', 1)
        ent = ent.strip()
        if ent in important:
            description_dict[ent] = description.strip()
    return description_dict


def _example_ents(important_ents: List[str]) -> List[str]:
    """Two entity names for the templates' important_ents_0/1 slots, padded when fewer exist."""
    # NOTE(review): placeholders 'Entity 1'/'Entity 2' are used only when fewer than two
    # important entities were found (previously an IndexError).
    return [important_ents[i] if i < len(important_ents) else f'Entity {i + 1}' for i in range(2)]


def _describe(paragraph: str, important_ents: List[str], prompt_ent_description: str,
              prompt_shorten: str) -> Dict:
    """Run the entity-description and summary prompts in one batch and package a record."""
    ent_description, shorten = llm([prompt_ent_description, prompt_shorten])
    # Each prompt yields a list of completions; keep the first.
    ent_description, shorten = ent_description[0], shorten[0]
    return {
        'paragraph': paragraph,
        'important_ents': important_ents,
        'description_dict': _parse_descriptions(ent_description, important_ents),
        'shorten': shorten,
        'prompt_ent_description': prompt_ent_description,
        'prompt_shorten': prompt_shorten,
    }


def _build_recap(important_ents: List[str], prev_results: List[Dict]) -> str:
    """Format recaps of prev_results (oldest first), labelled relative to the current passage.

    Each recap holds the earlier passage's descriptions of the entities matching the current
    ``important_ents`` (citations removed), followed by that passage's summary.
    """
    n_prev = len(prev_results)
    recaps = []
    for rid, result in enumerate(prev_results):
        prev_description_dict: Dict[str, str] = result['description_dict']
        if important_ents and prev_description_dict:
            match_dict = match_entities(important_ents, list(prev_description_dict))
        else:
            match_dict = {}
        # Several current entities may map to the same earlier entity; list it once.
        matched_prev_ents = dict.fromkeys(match_dict.values())
        prev_description = '\n'.join(
            f'{ent}: {remove_citation(prev_description_dict[ent])}' for ent in matched_prev_ents)
        # rid - n_prev is -1 for the immediately preceding passage even when fewer than
        # recap_num passages exist yet.
        recaps.append(f'Passage {rid - n_prev}:\nEntity descriptions:\n{prev_description}\nSummary:\n{result["shorten"]}')
    return '\n\n'.join(recaps)


def index_text(paragraphs: List[str], recap_num: int = 2, context_type: str = 'novel',
               n_samples: int = 5, temperature: float = 0.7, min_votes: int = 3):
    """Index passages: important entities, per-entity descriptions and a shortened summary.

    Each passage is processed twice:
      * without context; the record goes to ``test_results``;
      * with a recap of up to ``recap_num`` preceding passages (their descriptions of
        matching entities plus their summaries); the record goes to ``results``. Recaps
        are built from earlier ``results`` entries. When there is no preceding passage
        (or ``recap_num <= 0``), a copy of the context-free record is used instead.

    Args:
        paragraphs: passage texts, processed in order.
        recap_num: how many preceding passages to include in the recap.
        context_type: word used in prompts to describe the source (e.g. 'novel').
        n_samples: number of sampled entity lists per passage.
        temperature: sampling temperature for the entity-listing prompt.
        min_votes: minimum total mentions across samples for an entity cluster to be kept.

    Returns:
        (test_results, results): lists with one record per passage, each having keys
        'paragraph', 'important_ents', 'description_dict', 'shorten',
        'prompt_ent_description', 'prompt_shorten'.
    """
    results: List[Dict] = []
    test_results: List[Dict] = []
    for paragraph in tqdm(paragraphs):
        important_ents = _extract_important_ents(paragraph, context_type, n_samples, temperature, min_votes)
        important_ents_str = '\n'.join(important_ents)
        example_0, example_1 = _example_ents(important_ents)

        # Pass 1: no recap.
        prompt_ent_description = prompt_ent_description_template.format(
            paragraph=paragraph, context_type=context_type, important_ents_str=important_ents_str,
            important_ents_0=example_0, important_ents_1=example_1)
        prompt_shorten = prompt_shorten_template.format(paragraph)
        record = _describe(paragraph, important_ents, prompt_ent_description, prompt_shorten)
        test_results.append(record)

        # Pass 2: with recap. recap_num <= 0 is guarded explicitly because
        # results[-0:] would select every previous passage.
        prev_results = results[-recap_num:] if recap_num > 0 else []
        if not prev_results:
            results.append(dict(record))
            continue
        recap_str = _build_recap(important_ents, prev_results)
        prompt_ent_description = prompt_ent_description_w_note_template.format(
            recap=recap_str, paragraph=paragraph, context_type=context_type,
            important_ents_str=important_ents_str, important_ents_0=example_0, important_ents_1=example_1)
        prompt_shorten = prompt_shorten_w_note_template.format(recap_str, paragraph)
        results.append(_describe(paragraph, important_ents, prompt_ent_description, prompt_shorten))
    return test_results, results

In [None]:
# Task (document) index whose pre-split pages are indexed in the next cell.
INDEX_TASK_I = 1
# Each page is a list of lines; join them into one passage string per page.
paragraphs = ['\n'.join(page) for page in read_json(os.path.join(dataset.data_dir, f'pages_{INDEX_TASK_I}.json'))]

In [None]:
# Number of preceding passages summarised into each recap prompt. It also names the output
# files, so the two cannot drift apart.
RECAP_NUM = 3
test_results, results = index_text(paragraphs, recap_num=RECAP_NUM)
write_json(f'results{RECAP_NUM}.json', results)
write_json(f'test_results{RECAP_NUM}.json', test_results)

## Analyze results

In [None]:
# Load saved, evaluated predictions for two methods, 'gist' and 'dpr'. Used below as
# results[method][i] -> per-question dict (keys such as 'task_i', 'q_i', 'query', 'gold',
# 'predict', 'acc', 'generation', 'steps'). Overwrites `results` from the indexing cell.
# NOTE(review): presumably (0, 10) is a task-index range and 'wo_c_' a per-method
# result-file prefix; TODO confirm against load_and_eval_result.
results = dataset.load_and_eval_result(0, 10, {'gist': 'wo_c_', 'dpr': 'wo_c_'})

In [None]:
# Pair the GIST and DPR records of each question, keeping its position as 'i'.
# To inspect only questions GIST gets right and DPR gets wrong, filter with
# `if results['gist'][i]['acc'] and not results['dpr'][i]['acc']`.
examples = [
    {'gist': results['gist'][i], 'dpr': results['dpr'][i], 'i': i}
    for i in range(len(results['gist']))
]
print(len(examples))

In [None]:
# Fields available in one per-question GIST record.
examples[0]['gist'].keys()

In [None]:
# Pick one evaluated question and show both methods' outputs side by side.
eid = 6
gist_res = examples[eid]['gist']
dpr_res = examples[eid]['dpr']
# Reload the pages of this question's document; the cells below inspect them.
pages_path = os.path.join(dataset.data_dir, f"pages_{gist_res['task_i']}.json")
paragraphs = ['\n'.join(page) for page in read_json(pages_path)]
print(f'''
Task id: {gist_res['task_i']}
Question id: {gist_res['q_i']}

Question: {gist_res['query']}

Gold answer: {gist_res['gold']}


DPR result:

{dpr_res['predict']}

{dpr_res['acc']}

{dpr_res['generation']}



GIST result:

{gist_res['predict']}

{gist_res['acc']}

{gist_res['generation']}
''')

In [None]:
# The 'steps' trace recorded for the GIST run on the selected question.
examples[eid]['gist']['steps']

In [None]:
# Second element of the first GIST step, printed rather than echoed (presumably text).
print(examples[eid]['gist']['steps'][0][1])

In [None]:
# The 'steps' trace recorded for the DPR run on the same question.
examples[eid]['dpr']['steps']

In [None]:
# Number of pages in the selected question's document.
len(paragraphs)

In [None]:
# NOTE(review): hand-picked page index for this example; may be out of range for other eids.
print(paragraphs[11])

In [None]:
# Show every page of the current document mentioning the hand-picked keyword "gold"
# (case-insensitive), with its page index.
for passage_idx, passage in enumerate(paragraphs):
    if 'gold' not in passage.lower():
        continue
    print(f'Passage {passage_idx}:\n{passage}')

In [None]:
# LongDoc indexer from index_files, on CPU (presumably the packaged counterpart of
# index_text above; TODO confirm).
longdoc = LongDoc(dataset, llm=llm, device='cpu')

In [None]:
# Index a single hand-picked page (index 14) of the currently loaded document.
longdoc.index_text(paragraphs[14:15])