# Retrieval observation

In [None]:
from index_files import LongDoc, read_jsonline, read_json, write_json
from datasets import load_dataset
import os

# Shared LongDoc helper; the cells below call its LLM, spaCy pipeline and
# indexing/retrieval methods.
longdoc = LongDoc(
    llm_name="mistralai/Mistral-7B-Instruct-v0.2",
    device='cuda:0',
)

## Get context

In [None]:
# task_i selects the document/sample, question_i the question within a QuALITY
# article; both are used as indices by the cells below.
task_i, question_i = 0, 4

### LongBench

In [None]:
# LongBench subsets. Only the first uncommented entry is used, so keep the
# subset to inspect as the first active line.
datasets = [
    "narrativeqa",
    # "qasper",
    # "multifieldqa_en",
    # "multifieldqa_zh",
    # "hotpotqa",
    # "2wikimqa",
    # "musique",
    # "dureader",
    # "gov_report",
    # "qmsum",
    # "multi_news",
    # "vcsum",
    # "trec",
    # "triviaqa",
    # "samsum",
    # "lsht",
    # "passage_count",
    # "passage_retrieval_en",
    # "passage_retrieval_zh",
    # "lcc",
    # "repobench-p"
]
task_name = datasets[0]

# The test split is stored under its subset name so later cells can look it up
# through task_name.
longbench_split = load_dataset('THUDM/LongBench', task_name, split='test')
dataset_dict = {task_name: longbench_split}

In [None]:
# Split the selected LongBench sample into its context, question and
# reference answers.
sample = dataset_dict[task_name][task_i]
context, query, answer = sample['context'], sample['input'], sample['answers']

### QuALITY

In [None]:
# Load the HTML-stripped QuALITY training file with the project's
# read_jsonline helper and switch the active task name to 'quality'.
quality_path = '../../data/QuALITY/QuALITY.v1.0.1.htmlstripped.train'
quality_qa = read_jsonline(quality_path)
task_name = 'quality'

In [None]:
# Pick the article selected by task_i and the question selected by question_i.
article_entry = quality_qa[task_i]
context = article_entry['article']
question = article_entry['questions'][question_i]
query = question['question']
# gold_label is treated as 1-based here, hence the -1 when indexing options.
answer = question['options'][question['gold_label'] - 1]

## LLM inference

In [None]:
import json

# For pages_0.json .. pages_9.json: join each page's lines with newlines, run
# longdoc.index_text over the pages and save the result as response_<eid>.json.
for eid in range(10):
    index_file = f'quality/response_{eid}.json'
    with open(f'quality/pages_{eid}.json') as f_in:
        pages = json.load(f_in)
    page_texts = ['\n'.join(page) for page in pages]
    write_json(index_file, longdoc.index_text(page_texts))

### Coreference Resolution (Test)

In [None]:
from typing import List, Tuple

# Batched coreference-resolution test: paragraphs are resolved `coref_batch` at
# a time, with the previously resolved batch prepended as read-only context.
# NOTE(review): `tqdm` and `paragraphs` are not defined in any visible cell;
# they are expected to come from earlier kernel state.
coref_batch = 4
# The test is limited to the first ten paragraphs. Batches are cut from this
# slice (not from the full list) so the last batch cannot run past the limit.
test_paragraphs = paragraphs[:10]
# String annotation on purpose: module-level annotations are evaluated at run
# time and `fastcoref` itself is never imported in the visible cells.
coref_resolver: 'fastcoref.spacy_component.FastCorefResolver' = longdoc.nlp.get_pipe('fastcoref')
coref_resolved_paragraphs = []
for bid in tqdm(range((len(test_paragraphs) - 1) // coref_batch + 1)):
    # The last resolved batch (one joined string) is the frozen context; for
    # the first batch this slice is simply empty.
    frozen_paragraphs = coref_resolved_paragraphs[-1:]
    update_paragraphs = test_paragraphs[bid * coref_batch : (bid + 1) * coref_batch]

    batch_paragraphs = frozen_paragraphs + update_paragraphs
    doc = longdoc.nlp(''.join(batch_paragraphs))
    # Character range [start, end) of every paragraph inside the joined text.
    prev_char_num = 0
    paragraph_char_seps = []
    for paragraph in batch_paragraphs:
        paragraph_char_seps.append((prev_char_num, prev_char_num + len(paragraph)))
        prev_char_num = paragraph_char_seps[-1][1]

    # Each cluster is a list of (start_char, end_char) mention spans.
    clusters: List[List[Tuple[int, int]]] = doc._.coref_clusters

    # Normalize the referred entities: a noun span containing a comma is
    # replaced by the named entity sharing its root token, or dropped when no
    # such entity exists. Entities spanning a paragraph break are ignored.
    idx2nc = {nc.root.i: nc for nc in doc.ents if '\n\n' not in nc.text}
    new_clusters = []
    for cluster in clusters:
        # presumably the positions (within the cluster) of spans that contain
        # a noun -- verify against fastcoref
        indices = coref_resolver._get_span_noun_indices(doc, cluster)
        if indices:
            new_cluster = []
            for sid, span in enumerate(cluster):
                if sid in indices:
                    doc_span = doc.char_span(span[0], span[1])
                    if ',' in doc_span.text:
                        root_i = doc_span.root.i
                        if root_i in idx2nc:
                            new_cluster.append((idx2nc[root_i].start_char, idx2nc[root_i].end_char))
                        continue
                new_cluster.append(span)
            new_clusters.append(new_cluster)
    clusters = new_clusters

    # Resolve only part of the mentions: per cluster, at most one mention is
    # rewritten in each not-yet-marked paragraph, starting from the paragraph
    # that holds the head mention. Frozen paragraphs are never rewritten.
    resolved = [tok.text_with_ws for tok in doc]
    all_spans = [span for cluster in clusters for span in cluster]
    for cluster in clusters:
        indices = coref_resolver._get_span_noun_indices(doc, cluster)
        # Only clusters whose first noun span is rooted in a named entity are used.
        if indices and doc.char_span(cluster[indices[0]][0], cluster[indices[0]][1]).root.i in idx2nc:
            mention_span, mention = coref_resolver._get_cluster_head(doc, cluster, indices)
            marked = ([True] * len(frozen_paragraphs)) + ([False] * len(update_paragraphs))
            # Mark the paragraph containing the head mention; pid is left
            # pointing at it.
            pid = 0
            for pid, (p_start, p_end) in enumerate(paragraph_char_seps):
                if mention[0] >= p_start and mention[0] < p_end:
                    marked[pid] = True
                    break
            for coref in cluster:
                if coref != mention and not coref_resolver._is_containing_other_spans(coref, all_spans):
                    # Advance to the next paragraph that has not been handled yet.
                    while pid < len(marked) and marked[pid]:
                        pid += 1
                    if pid == len(marked):
                        break
                    if coref[0] >= paragraph_char_seps[pid][0] and coref[0] < paragraph_char_seps[pid][1]:
                        marked[pid] = True
                        # presumably rewrites `resolved` in place with the head
                        # mention -- verify against fastcoref
                        coref_resolver._core_logic_part(doc, coref, resolved, mention_span)
    # Drop the frozen prefix (it is never rewritten, so its length is
    # unchanged) and keep only the newly resolved paragraphs.
    coref_resolved_paragraphs.append("".join(resolved)[sum(len(p) for p in frozen_paragraphs):])

In [None]:
# Print the concatenated coreference-resolved text.
print(''.join(coref_resolved_paragraphs))

In [None]:
# Print the concatenated original paragraphs (presumably for comparison with the resolved text).
print(''.join(paragraphs))

## Interact with LLM

In [None]:
# For every retrieval tool and each of the first ten articles, answer every
# QuALITY question from the logged summary and save the raw generations.
for r_tool in ['dpr', 'index', 'gist']:
    for task_i in range(10):
        # The log is a flat stream of [key, value] records; a 'current_summary'
        # record closes the group of records collected so far.
        index_results = []
        temp_result = {}
        for line in read_jsonline(f'quality/response_{r_tool}_{task_i}_log.jsonl'):
            key, value = line[0], line[1]
            temp_result[key] = value
            if key == 'current_summary':
                index_results.append(temp_result)
                temp_result = {}
        results = []
        for qid, question in enumerate(quality_qa[task_i]['questions']):
            current_summary = index_results[qid]['current_summary']
            query = question['question']
            options = '\n'.join(f'{oid + 1}: {option}' for oid, option in enumerate(question['options']))
            answer = question['gold_label']
            writer_answer = question['writer_label']
            prompt = f'''Answer the question based on a given summary.\n\n{current_summary}\n\nQuestion: {query}\n{options}\n\nChoose the correct option above and return the option number. Generate your answer in the following format:"Answer: the option number".'''
            # Up to five attempts; the last generation is kept even when it
            # never matches the "Answer: <number>" format.
            for _ in range(5):
                gen = longdoc._call_llm(prompt).choices[0].message.content
                stripped = gen.strip()
                if stripped.lower().startswith('answer: ') and stripped[8].isnumeric():
                    break
            results.append({'prompt': prompt, 'gen': gen, 'gold': answer, 'writer': writer_answer})
            print(task_i, qid)
        write_json(f'quality/generation_{r_tool}_{task_i}.json', results)

In [None]:
# Score the saved generations of every retrieval tool and collect the passages
# each tool retrieved, in the same article order.
accuracy = {}
results = {}
retrieved_passages = {}
for r_tool in ['index', 'dpr', 'gist']:
    accuracy[r_tool] = []
    results[r_tool] = []
    retrieved_passages[r_tool] = []
    for task_i in range(10):
        results[r_tool].extend(read_json(f'quality/generation_{r_tool}_{task_i}.json'))
        retrieved_passages[r_tool].extend(
            (task_i, line[1])
            for line in read_jsonline(f'quality/response_{r_tool}_{task_i}_log.jsonl')
            if line[0] == 'retrieval_result'
        )
    for result in results[r_tool]:
        # A generation counts as wrong when it has no "answer: " marker
        # (ValueError from index), nothing after it (IndexError) or a
        # non-digit option (ValueError from int). Other errors, e.g. a
        # malformed record, are no longer hidden.
        try:
            answer_start = result['gen'].lower().index('answer: ')
            predicted = int(result['gen'][answer_start + 8])
        except (ValueError, IndexError):
            accuracy[r_tool].append(False)
        else:
            # NOTE(review): scored against 'writer', not 'gold' -- confirm this is intended.
            accuracy[r_tool].append(predicted == result['writer'])
    # Avoid a ZeroDivisionError when no generations were loaded.
    n_scored = len(accuracy[r_tool])
    print(r_tool, sum(accuracy[r_tool]) / n_scored if n_scored else float('nan'))

In [None]:
# Collect the questions that 'index' answered correctly and 'dpr' did not,
# together with the generations and retrieved passages of all three tools.
examples = []
for i, (task_id, _) in enumerate(retrieved_passages['index']):
    if not accuracy['index'][i] or accuracy['dpr'][i]:
        continue
    record = {tool: results[tool][i] for tool in ('index', 'gist', 'dpr')}
    record['task_id'] = task_id
    for tool in ('index', 'gist', 'dpr'):
        record[f'{tool}_ret'] = retrieved_passages[tool][i][1]
    examples.append(record)

In [None]:
# Print one collected example: prompt and generation per tool, then the
# passages each tool retrieved. Sections are separated by three blank lines.
eid = 0
example = examples[eid]
sections = [
    f"Task id: {example['task_id']}\n\n"
    f"Index result:\n\n{example['index']['prompt']}\n\n{example['index']['gen']}",
    f"DPR result:\n\n{example['dpr']['prompt']}\n\n{example['dpr']['gen']}",
    f"GIST result:\n\n{example['gist']['prompt']}\n\n{example['gist']['gen']}",
    f"Index Retrieved:\n\n{example['index_ret']}",
    f"DPR Retrieved:\n\n{example['dpr_ret']}",
    f"GIST Retrieved:\n\n{example['gist_ret']}",
]
print('\n\n\n\n'.join(sections) + '\n')

In [None]:
# Print the full article text of QuALITY entry 3.
print(quality_qa[3]['article'])

In [None]:
# Print the text of question 1 of QuALITY entry 2.
print(quality_qa[2]['questions'][1]['question'])

In [None]:
# Display the whole record of question 1 of QuALITY entry 2.
quality_qa[2]['questions'][1]

In [None]:
# Print the menu, retrieval command and retrieval result logged for one
# question, separated by three blank lines.
lid = 4
log_entry = index_results[lid]
print('\n\n\n\n'.join(
    f'{log_entry[field]}' for field in ('menu', 'retrieval_command', 'retrieval_result')
))

In [None]:
# Display every logged field of the selected question.
index_results[lid]

In [None]:
# Run longdoc.main for the current query, passing the path of the saved
# response file for the selected task and document (presumably the index
# written by the indexing cell -- confirm against LongDoc.main).
results = longdoc.main(query, f'{task_name}/response_{task_i}.json')

In [None]:
# Display the value currently bound to `results`.
results

### Step 1: identify entities of interest

In [None]:
# Prompt asking the LLM to list the entities and keywords of the question
# before it sees the story.
query_entity_prompt = (
    f"Question: {query}\n"
    "You need to answer the above question based on a given story. "
    "Before reading the story, identify some entities and keywords in the question "
    "you want to query from the story for useful information. "
    "Don't give any explanation. "
    "Generate your response in the following format:\n"
    '"Query entities:\n'
    "the first entity, the second entity, ...\n\n"
    "Query keywords:\n"
    'the first keyword, the second keyword, ...".'
)

In [None]:
# Send the entity/keyword prompt as a single user message and print the reply.
messages = [{"role": "user", "content": query_entity_prompt}]
chat_response = longdoc.llm_server.chat.completions.create(
    model="mistralai/Mistral-7B-Instruct-v0.2",
    messages=messages,
)
reply = chat_response.choices[0].message.content
print(reply)

In [None]:
# Ask LongDoc for the entities and keywords of the current query, then display
# the two results concatenated.
entities, keywords = longdoc.identify_entity_keyword(query)
entities + keywords

In [None]:
# Retrieve nodes from the document graph for the extracted entities and
# keywords (third argument: 10) and display them.
# NOTE(review): `doc_index` is not defined in any visible cell; it must come
# from earlier kernel state.
mention_sets = longdoc.retrieve_node(doc_index.graph, entities + keywords, 10)
# Alternative manual query kept for reference:
# ent_sets = longdoc.retrieve_node(all_graph, ['the witch', "the character's residence in the story"])
mention_sets

### Step 2: retrieve summary/original text

In [None]:
# Build the menu for the retrieved mention sets and print it; pair2sids is the
# second value returned by retrieve_menu and is not used in this cell.
menu, pair2sids = longdoc.retrieve_menu(mention_sets, doc_index)
print(menu)

### Step 3: analyze retrieved info

### Step 4: continue searching or start answering

In [None]:
# Assemble the prompt that asks the LLM either to retrieve more passages or to
# answer: question header, menu, summary header, current summary, instructions.
# NOTE(review): `passage_indices` and `current_summary` are not set in this
# cell; they are expected from earlier kernel state.
prompt_parts = [
    f'''Question: {query}\n\nYou need to answer the above question based on a given story.\nBelow is a list of related entities and entity pairs contained in each passage from the story. The passage numbers are assigned based on the original order of the passages in the text.\n\n''',
    menu,
]
retrieved_passage_idx_str = ', '.join(str(idx) for idx in passage_indices)
prompt_parts += [
    f'''Below is the summary of useful information from passage {retrieved_passage_idx_str}.\n\n''',
    current_summary,
    '''Now, you need to choose whether to continue searching for more information or to start answering the question.

If the information is not adequate, you may choose to continue searching. Select a retrieval type and the passage numbers. For the retrieval type, you may choose "original text" to retrieve the original passages, or "summary" to retrieve the summary of the entities in the passage. For passage selection, you may select passage numbers that do not exist in the above list to obtain continuous contextual information. You may retrieve either 5 passages for "original text" or 10 passages for "summary". Generate your response in the following format:\n"Retrieval type: summary/original text\nPassage numbers: first passage number, second passage number, ...".

Otherwise, if the information is adequate, you may choose to start answering the question. Generate your answer to the question in the following format:\n"Answer: your answer here".

For either choice, don't give any explanation.''',
]
decision_prompt = ''.join(prompt_parts)

## Long context and multi-hop reasoning

In [None]:
from datasets import load_dataset
import json

### HotpotQA

In [None]:
# Load the validation split of the HotpotQA 'distractor' configuration.
hotpot_qa = load_dataset('hotpot_qa', 'distractor', split='validation')

In [None]:
# Display the HotpotQA record at index 1.
hotpot_qa[1]

### NarrativeQA

In [None]:
# Load the training split of NarrativeQA.
narrative_qa = load_dataset('narrativeqa', split='train')

In [None]:
# Display the first NarrativeQA record.
narrative_qa[0]

### QASPER

In [None]:
# Load the training split of QASPER.
qasper = load_dataset('allenai/qasper', split='train')

In [None]:
# Display the first QASPER record.
qasper[0]

### QuALITY

In [None]:
# Read the QuALITY training file, one JSON object per line. The context
# manager closes the file handle (it used to be left open), the encoding is
# explicit, and blank lines are skipped instead of crashing json.loads.
# NOTE(review): this rebinds `quality_qa`, which an earlier cell loads from a
# different (HTML-stripped) file.
with open('../../QuALITY.v1.0.1.train', encoding='utf-8') as quality_file:
    quality_qa = [json.loads(line) for line in quality_file if line.strip()]

In [None]:
# Display the first QuALITY record.
quality_qa[0]

### OpenBookQA

In [None]:
# Load the training split of the OpenBookQA 'main' configuration.
openbookqa = load_dataset('openbookqa', 'main', split='train')

In [None]:
# Display the first OpenBookQA record.
openbookqa[0]

### LongBench

In [None]:
# LongBench subsets. Only the first uncommented entry is used, so keep the
# subset to inspect as the first active line.
datasets = [
    # "narrativeqa",
    # "qasper",
    # "multifieldqa_en",
    # "multifieldqa_zh",
    # "hotpotqa",
    # "2wikimqa",
    "musique",
    # "dureader",
    # "gov_report",
    # "qmsum",
    # "multi_news",
    # "vcsum",
    # "trec",
    # "triviaqa",
    # "samsum",
    # "lsht",
    # "passage_count",
    # "passage_retrieval_en",
    # "passage_retrieval_zh",
    # "lcc",
    # "repobench-p"
]
task_name = datasets[0]

# Load the test split, keep it under its subset name and print the context of
# the sample at index 1.
longbench_split = load_dataset('THUDM/LongBench', task_name, split='test')
dataset_dict = {task_name: longbench_split}
print(longbench_split[1]['context'])

In [None]:
# Display the full LongBench sample at index 1.
dataset_dict[task_name][1]

### LooGLE

In [None]:
# LooGLE subsets to load; uncomment the ones to inspect.
datasets = [
    # "shortdep_qa",
    # "shortdep_cloze",
    "longdep_qa",
    # "longdep_summarization"
]

# Keep every enabled subset under its name; previously each iteration
# overwrote `data`, so only the last subset survived the loop.
loogle_dict = {}
for testset in datasets:
    loogle_dict[testset] = load_dataset('bigainlco/LooGLE', testset, split='test')
    # `data` still ends up bound to the last enabled subset, as before.
    data = loogle_dict[testset]

In [None]:
# Display the first record of the loaded LooGLE subset.
data[0]