# Retrieval observation

In [1]:
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM
from typing import List, Tuple
import torch
from datasets import load_dataset
from nltk import sent_tokenize
from openai import OpenAI
import networkx as nx
from collections import defaultdict, Counter
import itertools
from tqdm import tqdm
import json
import pickle
from rouge_metric import PyRouge
import spacy
import numpy as np
import fastcoref.spacy_component

In [2]:
datasets = [
    "narrativeqa", 
    # "qasper", 
    # "multifieldqa_en", 
    # "multifieldqa_zh", 
    # "hotpotqa", 
    # "2wikimqa", 
    # "musique", 
    # "dureader", 
    # "gov_report", 
    # "qmsum", 
    # "multi_news", 
    # "vcsum", 
    # "trec", 
    # "triviaqa", 
    # "samsum", 
    # "lsht", 
    # "passage_count", 
    # "passage_retrieval_en", 
    # "passage_retrieval_zh", 
    # "lcc", 
    # "repobench-p"
]
task_name = datasets[0]

dataset_dict = {task_name: load_dataset('THUDM/LongBench', task_name, split='test')}

In [3]:
def split_trec(text:str):
    lines = text.splitlines()
    return ['\n'.join(lines[i * 2 : i * 2 + 1]) for i in range(len(lines) // 2)]

def split_triviaqa(text:str):
    lines = text.splitlines()
    paragraphs = []
    paragraph = []
    lid = 0
    while lid < len(lines):
        paragraph.append(lines[lid])
        if lines[lid] == 'Answer:':
            lid += 1
            paragraph.append(lines[lid])
            paragraphs.append('\n'.join(paragraph))
            paragraph.clear()
        lid += 1
    return paragraphs

def split_samsum(text:str):
    paragraphs = []
    paragraph = []
    for line in text.splitlines():
        paragraph.append(line)
        if line.startswith('Summary: '):
            paragraphs.append('\n'.join(paragraph))
            paragraph.clear()
    return paragraphs

class LongDoc:
    paragraph_sep_map = {
        'qasper': '\n', 
        'multifieldqa_zh': '\n', 
        'qmsum': '\n', 
        'multi_news': '\n', 
        'vcsum': '\n', 
        'trec': (split_trec, '\n'), 
        'triviaqa': (split_triviaqa, '\n'), 
        'samsum': (split_samsum, '\n'), 
    }
    
    def __init__(self, retriever_model_name:str='facebook/contriever', llm_name:str='meta-llama/Llama-2-7b-hf') -> None:
        self.device = torch.device('cuda:0')
        self.llm_tokenizer = AutoTokenizer.from_pretrained(llm_name)
        # self.nlp = spacy.load('en_core_web_lg')#, disable=['attribute_ruler', 'lemmatizer', 'ner'])
        # self.nlp.add_pipe('coreferee')
        # self.nlp = spacy.load("en_core_web_lg")
        # self.nlp.add_pipe("fastcoref", 
        #      config={'model_architecture': 'LingMessCoref', 'model_path': 'biu-nlp/lingmess-coref', 'device': 'cuda:1', 'enable_progress_bar': False}
        # )
        self.retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_model_name)
        self.retriever_model = AutoModel.from_pretrained(retriever_model_name)
        self.retriever_model.cuda(device=self.device)
        
    def get_task_paragraph_sep(self, task_name:str):
        sep = self.paragraph_sep_map.get(task_name, '\n\n')
        if not isinstance(sep, str):
            func, sep = sep
        return sep
    
    def split_context_to_paragraphs(self, context:str, task_name:str):
        sep = self.paragraph_sep_map.get(task_name, '\n\n')
        if isinstance(sep, str):
            return context.split(sep)
        else:
            func, sep = self.paragraph_sep_map[task_name]
            return func(context)
    
    # Mean pooling
    @staticmethod
    def _mean_pooling(token_embeddings:torch.Tensor, mask):
        token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
        sentence_embeddings:torch.Tensor  = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
        return sentence_embeddings

    def _append_paragraph(self, paragraphs:list, tokenized_p:List[str]):
        paragraph = self.llm_tokenizer.decode(tokenized_p)
        paragraphs.append(paragraph)
        tokenized_p.clear()
    
    def split_single_paragraph(self, text:str, paragraph_size:int=300, is_natural_language:bool=True):
        splited_paragraphs:List[str] = []
        splited_paragraph = []
        sentences:List[str] = sent_tokenize(text) if is_natural_language else text.split('\n')
        for sent in sentences:
            tokenized_s = self.llm_tokenizer.encode(sent)[1:]
            if len(tokenized_s) <= paragraph_size:
                if len(splited_paragraph) + len(tokenized_s) > paragraph_size:
                    self._append_paragraph(splited_paragraphs, splited_paragraph)
                splited_paragraph.extend(tokenized_s)
            else:
                if splited_paragraph:
                    self._append_paragraph(splited_paragraphs, splited_paragraph)
                chunk_size = (len(tokenized_s) - 1) // paragraph_size + 1
                for i in range(chunk_size - 1):
                    self._append_paragraph(splited_paragraphs, tokenized_s[i * paragraph_size: (i+1) * paragraph_size])
                splited_paragraph = tokenized_s[(chunk_size - 1) * paragraph_size:]
        
        return splited_paragraphs, splited_paragraph
            
        
    def split_paragraphs(self, text:str, task_name:str, paragraph_size:int=300):
        reformated_paragraphs:List[str] = []
        completion_labels:List[bool] = []
        reformated_paragraph = []
        
        paragraph_sep = self.get_task_paragraph_sep(task_name)
        paragraphs = text.split(paragraph_sep)
        for p in paragraphs:
            tokenized_p = self.llm_tokenizer.encode(p + paragraph_sep)[1:]
            if len(tokenized_p) <= paragraph_size:
                if len(reformated_paragraph) + len(tokenized_p) > paragraph_size:
                    self._append_paragraph(reformated_paragraphs, reformated_paragraph)
                    completion_labels.append(True)
                reformated_paragraph.extend(tokenized_p)
            else:
                if reformated_paragraph:
                    self._append_paragraph(reformated_paragraphs, reformated_paragraph)
                    completion_labels.append(True)
                splited_paragraphs, splited_paragraph = self.split_single_paragraph(p, paragraph_size)
                reformated_paragraphs.extend(splited_paragraphs)
                completion_labels.extend([False] * len(splited_paragraphs))
                reformated_paragraph = splited_paragraph
                
        if reformated_paragraph:
            self._append_paragraph(reformated_paragraphs, reformated_paragraph)
            completion_labels.append(True)
        
        return reformated_paragraphs, completion_labels
    
    def embed_paragraphs(self, paragraphs:List[str]):
        retriever_input = self.retriever_tokenizer.batch_encode_plus(paragraphs, padding=True, truncation=True, return_tensors='pt').to(self.device)
        with torch.no_grad():
            retriever_output = self.retriever_model(**retriever_input)
            paragraph_embeddings= self._mean_pooling(retriever_output[0], retriever_input['attention_mask'])
        return paragraph_embeddings
    
    def retrieve_paragraphs(self, question:str, paragraphs:List[str], paragraph_embeddings:torch.Tensor, k:int=5, order_by_rank:bool=True):
        question_embeddings = self.embed_paragraphs([question])
        with torch.no_grad():
            ranks = torch.matmul(paragraph_embeddings, question_embeddings.T).squeeze()
            indices = torch.topk(ranks, k).indices.tolist()
            if not order_by_rank:
                indices.sort()
        return [paragraphs[idx] for idx in indices], indices
    
    def parse_to_graph(self, response:str):
        start_ents = False
        start_summary = False
        summary:List[str] = []
        line:str
        temp_graph = nx.DiGraph()
        for line in response.splitlines():
            if line.strip().startswith('Important ent'):
                start_ents = True
            elif line.strip().startswith('Entity summary'):
                start_summary = True
            else:
                if start_summary and line and ':' in line:
                    summary.append(line.strip('*+ '))

        summary = [sum.split('. ', 1)[1] if sum.startswith(f'{sid+1}.') else sum for sid, sum in enumerate(summary)]
        for sum in summary:
            ent, rel = sum.split(':', 1)
            temp_graph.add_node(ent.strip())

        for sid, sum in enumerate(summary):
            ent, rel = sum.split(':', 1)
            rel = rel.lower()
            ent = ent.strip()
            for other_ent in temp_graph.nodes:
                other_ent_mention = other_ent
                if '(' in other_ent_mention:
                    other_ent_mention = other_ent_mention.split('(')[0].strip()
                if ',' in other_ent_mention:
                    other_ent_mention = other_ent_mention.split(',')[0].strip()
                if other_ent != ent and other_ent_mention.lower() in rel:
                    if not temp_graph.has_edge(ent, other_ent):
                        temp_graph.add_edge(ent, other_ent, sum=[])
                    temp_graph.get_edge_data(ent, other_ent)['sum'].append(sid)
        return temp_graph, summary
    
    def retrieve_node(self, all_graph:nx.DiGraph, targets:List[str], k:int=5):
        nodes = list(all_graph.nodes)
        nodes.sort()
        ret:List[List[str]] = []
        node_embeds = self.embed_paragraphs(nodes).cpu().numpy()
        for target in targets:
            ent_embed = self.embed_paragraphs([target]).cpu().numpy()
            scores = node_embeds.dot(ent_embed.squeeze())
            # scores = np.array([cal_rouge(node, target)['rouge-l']['f'] for node in nodes])
            max_indices = np.argsort(scores)[::-1][:k]
            # max_indices = max_indices[scores[max_indices] > 0]
            ret.append([nodes[i] for i in max_indices])
        return ret
    
    # def retrieve_paragraphs(self, task_name:str, question:str, paragraphs:List[str], completion_labels:List[bool], paragraph_embeddings:torch.Tensor, k:int=5, order_by_rank:bool=True):
    #     question_input = self.retriever_tokenizer.batch_encode_plus([question], truncation=True, return_tensors='pt').to(self.device)
    #     with torch.no_grad():
    #         question_output = self.retriever_model(**question_input)
    #         question_embeddings = self._mean_pooling(question_output[0], question_input['attention_mask'])
    #         ranks = torch.matmul(paragraph_embeddings, question_embeddings.T).squeeze()
    #         indices = torch.topk(ranks, k).indices.tolist()
    #         if not order_by_rank:
    #             indices.sort()
    #         paragraph_sep = self.get_task_paragraph_sep(task_name)
    #         retrieved_text = ''
    #         for i, idx in enumerate(indices):
    #             retrieved_text += paragraphs[idx]
    #             if idx == len(indices) - 1:
    #                 break
    #             if not completion_labels[idx]:
    #                 if indices[i+1] == idx + 1:
    #                     retrieved_text += ' '
    #                 else:
    #                     retrieved_text += paragraph_sep
    #     return retrieved_text, indices
    
    # def index_paragraph(self, paragraph:str):
    #     doc = self.nlp(paragraph)
    #     g = nx.DiGraph()
    #     for i in range(len(doc)):
    #         for child in doc[i].children:
    #             g.add_edge(i, child.i, type=child.dep_)
    #     ug = g.to_undirected()
    #     doc_ncs = list(doc.noun_chunks) + list(doc.ents)
    #     root_id_to_nc = {nc.root.i: nc for nc in doc_ncs}
    #     results = []
    #     sents = list(doc.sents)
    #     s_start2sid = {sent.start: sid for sid, sent in enumerate(sents)}
    #     for sid, sent in enumerate(sents):
    #         if not sent.text.strip():
    #             continue
    #         sent_ncs = list(sent.noun_chunks)
    #         for pair in combinations(sent_ncs, 2):
    #             if sum([pair[0].root.dep_.startswith('nsubj'), pair[1].root.dep_.startswith('nsubj')]) != 1:
    #                 continue
    #             paths:List[List[int]] = nx.simple_paths.shortest_simple_paths(ug, pair[0].root.i, pair[1].root.i) if pair[0].root.dep_.startswith('nsubj') else nx.simple_paths.shortest_simple_paths(ug, pair[1].root.i, pair[0].root.i)
    #             for path in paths:
    #                 results.append((path[0], sid, path[-1]))
        
    #     appos_to_ref = defaultdict(set)
        
    #     for head, tail, dep in g.edges.data('type'):
    #         if dep == 'appos' and doc[tail].pos_ in ['NOUN', 'PROPN', 'NUM']:
    #             # if head in root_id_to_nc and tail in root_id_to_nc:
    #             appos_to_ref[head].add(tail)
        
    #     head2sids = defaultdict(set)
    #     tail2sids = defaultdict(set)
    #     nc2sids = defaultdict(set)
    #     for head, sid, tail in results:
    #         pair_ref = []
    #         for node, node2sids in [(head, head2sids), (tail, tail2sids)]:
    #             if doc[node].pos_ == 'PRON':
    #                 refs = doc._.coref_chains.resolve(doc[node])
    #                 if refs is not None:
    #                     for ref in refs:
    #                         if ref.i not in root_id_to_nc:
    #                             root_id_to_nc[ref.i] = doc[ref.i : ref.i + 1]
    #                     nodes = [ref.i for ref in refs]
    #                 else:
    #                     nodes = [node]
    #             elif head in appos_to_ref:
    #                 nodes = list(appos_to_ref[node]) + [node]
    #             else:
    #                 nodes = [node]
    #             pair_ref.append(nodes)
    #             for n in nodes:
    #                 nc = root_id_to_nc[n].text
    #                 nc_sid = s_start2sid[doc[n].sent.start]
    #                 if nc_sid != sid:
    #                     node2sids[nc].add((nc_sid, sid))
    #                     nc2sids[nc].add((nc_sid, sid))
    #                 else:
    #                     node2sids[nc].add(sid)
    #                     nc2sids[nc].add(sid)
    #         # heads, tails = pair_ref
            
    #         # for h, t in product(heads, tails):
    #         #     index_graph.append((root_id_to_nc[h].text, sid, root_id_to_nc[t].text))

    #     # return index_graph
    #     return nc2sids, head2sids, tail2sids, sents
                
longdoc = LongDoc()

In [4]:
def cal_rouge(hp, ref):
    return PyRouge().evaluate([hp], [[ref]])

## LLM inference

In [106]:
task_name = 'narrativeqa'
test_i = 2
paragraphs, completion_labels = longdoc.split_paragraphs(dataset_dict[task_name][test_i]['context'], task_name, 400)
# retrieved_text, indices = longdoc.retrieve_paragraphs(task_name, dataset_dict['narrativeqa'][0]['input'], paragraphs, completion_labels, embeddings)

### Coreference Resolution (Test)

In [8]:
coref_batch = 4
coref_resolver:fastcoref.spacy_component.FastCorefResolver = longdoc.nlp.get_pipe('fastcoref')
coref_resolved_paragraphs = []
for bid in tqdm(range((len(paragraphs[:10]) - 1) // coref_batch + 1)):
    if bid == 0:
        frozen_paragraphs = []
        update_paragraphs = paragraphs[:coref_batch]
    else:
        frozen_paragraphs = coref_resolved_paragraphs[-1:]
        update_paragraphs = paragraphs[bid * coref_batch : (bid + 1) * coref_batch]
    
    batch_paragraphs = frozen_paragraphs + update_paragraphs
    doc = longdoc.nlp(''.join(batch_paragraphs))
    prev_char_num = 0
    paragraph_char_seps = []
    for paragraph in batch_paragraphs:
        paragraph_char_seps.append((prev_char_num, prev_char_num + len(paragraph)))
        prev_char_num = paragraph_char_seps[-1][1]
    
    clusters:List[List[Tuple[int, int]]] = doc._.coref_clusters
    
    # Normalize the referred entities
    idx2nc = {nc.root.i: nc for nc in doc.ents if '\n\n' not in nc.text}
    new_clusters = []
    for cluster in clusters:
        indices = coref_resolver._get_span_noun_indices(doc, cluster)
        if indices:
            new_cluster = []
            for sid, span in enumerate(cluster):
                if sid in indices:
                    doc_span = doc.char_span(span[0], span[1])
                    if ',' in doc_span.text:
                        root_i = doc_span.root.i
                        if root_i in idx2nc:
                            new_cluster.append((idx2nc[root_i].start_char, idx2nc[root_i].end_char))
                        continue
                new_cluster.append(span)
            new_clusters.append(new_cluster)
    clusters = new_clusters
    
    # Resolve part of prons
    resolved = list(tok.text_with_ws for tok in doc)
    all_spans = [span for cluster in clusters for span in cluster]
    for cluster in clusters:
        indices = coref_resolver._get_span_noun_indices(doc, cluster)
        if indices and doc.char_span(cluster[indices[0]][0], cluster[indices[0]][1]).root.i in idx2nc:
            mention_span, mention = coref_resolver._get_cluster_head(doc, cluster, indices)
            marked = ([True] * len(frozen_paragraphs)) + ([False] * len(update_paragraphs))
            pid = 0
            for pid, (p_start, p_end) in enumerate(paragraph_char_seps):
                if mention[0] >= p_start and mention[0] < p_end:
                    marked[pid] = True
                    break
            for coref in cluster:
                if coref != mention and not coref_resolver._is_containing_other_spans(coref, all_spans):
                    while pid < len(marked) and marked[pid]:
                        pid += 1
                    if pid == len(marked):
                        break
                    if coref[0] >= paragraph_char_seps[pid][0] and coref[0] < paragraph_char_seps[pid][1]:
                        marked[pid] = True
                        coref_resolver._core_logic_part(doc, coref, resolved, mention_span)
    coref_resolved_paragraphs.append("".join(resolved)[sum([len(p) for p in frozen_paragraphs]):])

  0%|          | 0/3 [00:00<?, ?it/s]

02/28/2024 06:23:15 - INFO - 	 Tokenize 1 inputs...


Map:   0%|          | 0/1 [00:00<?, ? examples/s]

02/28/2024 06:23:18 - INFO - 	 ***** Running Inference on 1 texts *****
 33%|███▎      | 1/3 [00:03<00:07,  3.82s/it]02/28/2024 06:23:19 - INFO - 	 Tokenize 1 inputs...


Map:   0%|          | 0/1 [00:00<?, ? examples/s]

02/28/2024 06:23:22 - INFO - 	 ***** Running Inference on 1 texts *****
 67%|██████▋   | 2/3 [00:08<00:04,  4.16s/it]02/28/2024 06:23:23 - INFO - 	 Tokenize 1 inputs...


Map:   0%|          | 0/1 [00:00<?, ? examples/s]

02/28/2024 06:23:26 - INFO - 	 ***** Running Inference on 1 texts *****
100%|██████████| 3/3 [00:12<00:00,  4.12s/it]


In [9]:
print(''.join(coref_resolved_paragraphs))

E-text prepared by Jonathan Ingram, Janet Blenkinship, and the Project
Gutenberg Online Distributed Proofreading Team (https://www.pgdp.net/)

 

 Transcriber's note: The author is Mary Wollstonecraft (1759-1797).

 

 

 MARY,

 A Fiction

 L'exercice des plus sublimes vertus éleve et nourrit le génie.
                                                     ROUSSEAU.

 London,
Printed for J. Johnson, St. Paul's Church-Yard.

 MDCCLXXXVIII

 

 

 

 ADVERTISEMENT.

 
In delineating the Heroine of this Fiction, the Author attempts to
develop a character different from those generally portrayed. This woman
is neither a Clarissa, a Lady G----, nor a[A] Sophie.--It would be vain
to mention the various modifications of these models, as it would to
remark, how widely artists wander from nature, when they copy the
originals of great masters. They catch the gross parts; but the subtile
spirit evaporates; and not having the just ties, affectation disgusts,
when grace was expected to charm.

 Thos

In [None]:
print(''.join(paragraphs))

### Context Decomposition

In [5]:
# Set OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

In [107]:
print(dataset_dict[task_name][test_i]['input'])

Where does the witch live?


In [108]:
print(dataset_dict[task_name][test_i]['answers'])

['The Atlas Mountains']


In [None]:
print(dataset_dict[task_name][test_i]['context'])

## Index text

In [44]:
context_type = 'novel'
ent_num = 10
n = 20
temperature = 0.8

In [45]:
results = []
for paragraph in tqdm(paragraphs):
    list_entity_prompt = f'''{context_type.upper()}:\n\n{paragraph}\n\nAbove is part of a {context_type}. First, list the important entities in the above passages that are relevant to most of the content. You may synthesis entities to avoid ambiguity. Don't give any explanation. Then, summarize the information in the above context for each of the important entities and try to include other important entities in each entity's summary if they are related. The two steps should be generated in the following format: "Important entities:\n1. Entity 1\n2. Entity 2\n...\nEntity summary:\nEntity 1: Entity 1's summary\nEntity 2: Entity 2's summary\n..."'''
    chat_response = client.chat.completions.create(
        model="mistralai/Mistral-7B-Instruct-v0.2",
        messages=[
            {"role": "user", "content": list_entity_prompt},
        ],
        # n=n,
        # temperature=temperature
    )
    results.append((paragraph, chat_response.choices[0].message.content))

with open(f'response_{test_i}.json', 'w') as f_out:
    json.dump(results, f_out)

  0%|          | 0/29 [00:00<?, ?it/s]02/28/2024 13:48:39 - INFO - 	 HTTP Request: POST http://localhost:8000/v1/chat/completions "HTTP/1.1 200 OK"
  3%|▎         | 1/29 [00:13<06:31, 13.98s/it]02/28/2024 13:48:50 - INFO - 	 HTTP Request: POST http://localhost:8000/v1/chat/completions "HTTP/1.1 200 OK"
  7%|▋         | 2/29 [00:24<05:21, 11.91s/it]02/28/2024 13:49:00 - INFO - 	 HTTP Request: POST http://localhost:8000/v1/chat/completions "HTTP/1.1 200 OK"
 10%|█         | 3/29 [00:35<04:54, 11.33s/it]02/28/2024 13:49:14 - INFO - 	 HTTP Request: POST http://localhost:8000/v1/chat/completions "HTTP/1.1 200 OK"
 14%|█▍        | 4/29 [00:48<05:08, 12.35s/it]02/28/2024 13:49:24 - INFO - 	 HTTP Request: POST http://localhost:8000/v1/chat/completions "HTTP/1.1 200 OK"
 17%|█▋        | 5/29 [00:58<04:29, 11.21s/it]02/28/2024 13:49:35 - INFO - 	 HTTP Request: POST http://localhost:8000/v1/chat/completions "HTTP/1.1 200 OK"
 21%|██        | 6/29 [01:10<04:22, 11.43s/it]02/28/2024 13:50:03 - INFO

## Interact with LLM

In [109]:
# Load index
with open(f'response_{test_i}.json') as f_in:
    results = json.load(f_in)

paragraph:str
all_graph = nx.DiGraph()
all_summary:List[str] = []
for pid, (paragraph, response) in enumerate(results):
    temp_graph, temp_summary = longdoc.parse_to_graph(response)
    s_offset = len(all_summary)
    for head, tail, sids in temp_graph.edges.data('sum'):
        if not all_graph.has_edge(head, tail):
            all_graph.add_edge(head, tail, sum=[])
        all_graph.get_edge_data(head, tail)['sum'].extend([(sid + s_offset, pid) for sid in sids])
    all_summary.extend(temp_summary)

### Step 1: identity entities of interest

In [110]:
question = dataset_dict[task_name][test_i]['input']
query_entity_prompt = f'''Question: {question}\nYou need to answer the above question based on a given story. Before searching in the story, identify some entities you want to query to gain useful information. Generate your response in the following format:\n"Query entities:\nthe first entity\nthe second entity\n...". Don't give any explanation.'''

In [111]:
chat_response = client.chat.completions.create(
    model="mistralai/Mistral-7B-Instruct-v0.2",
    messages=[
        {"role": "user", "content": query_entity_prompt},
    ]
)
print(chat_response.choices[0].message.content)

02/28/2024 15:25:10 - INFO - 	 HTTP Request: POST http://localhost:8000/v1/chat/completions "HTTP/1.1 200 OK"


 Query entities:
the witch's character
the setting of the story.


In [112]:
ent_sets = longdoc.retrieve_node(all_graph, ["Ann", "Mary", "feelings of affection"])
# ent_sets = longdoc.retrieve_node(all_graph, ['the witch', "the character's residence in the story"])

In [93]:
ent_sets

[['Ann', 'Ann and Mary', "Ann's mother", 'Mother of Ann', 'Mary and Ann'],
 ['Mary', 'Mary and Ann', 'Ann and Mary', "Mary's husband", "Mary's maid"],
 ['Affection',
  'Feelings',
  'Affections',
  'Understanding and affections',
  'Affection/support']]

### Step 2: retrieve summary/original text

In [113]:
pair2sids = set()
for ent_set in ent_sets:
    for ent in ent_set:
        for pair in all_graph.edges(ent):
            pair2sids.update([((ent, ), sid, pid) for sid, pid in all_graph.get_edge_data(*pair)['sum']])
for pairs in itertools.combinations(ent_sets, 2):
    for m_node, a_node in itertools.product(*pairs):
        if all_graph.has_edge(m_node, a_node):
            pair2sids.update([((m_node, a_node), sid, pid) for sid, pid in all_graph.get_edge_data(m_node, a_node)['sum']])
        if all_graph.has_edge(a_node, m_node):
            pair2sids.update([((a_node, m_node), sid, pid) for sid, pid in all_graph.get_edge_data(a_node, m_node)['sum']])
pair2sids = list(pair2sids)
pair2sids.sort(key=lambda x: x[1])

In [30]:
pair2sids

[(('Mary',), 14, 2),
 (('Mary',), 65, 10),
 (('Mary',), 74, 11),
 (('Mary', 'Ann'), 86, 13),
 (('Mary',), 86, 13),
 (('Ann', 'Mary'), 87, 13),
 (('Ann',), 87, 13),
 (('Mary',), 93, 14),
 (('Affection',), 97, 14),
 (('Affection', 'Mary'), 97, 14),
 (('Mary',), 140, 19),
 (('Understanding and affections', 'Mary'), 148, 19),
 (('Understanding and affections',), 148, 19),
 (('Ann',), 155, 20),
 (("Ann's mother",), 156, 20),
 (('Mary', 'Ann'), 165, 22),
 (('Mary',), 165, 22),
 (('Ann',), 166, 22),
 (('Mary',), 172, 23),
 (('Mary', 'Ann'), 172, 23),
 (('Ann',), 176, 23),
 (('Ann', 'Mary'), 176, 23),
 (('Mary',), 177, 24),
 (('Ann', 'Mary'), 181, 24),
 (('Ann',), 181, 24),
 (('Mary',), 182, 25),
 (('Mary', 'Ann'), 182, 25),
 (('Ann',), 184, 25),
 (('Ann', 'Mary'), 184, 25),
 (('Mary', 'Ann'), 191, 27),
 (('Mary',), 191, 27),
 (('Ann',), 193, 27),
 (('Ann', 'Mary'), 193, 27),
 (('Mary',), 196, 28),
 (('Mary',), 200, 29),
 (('Mary', 'Ann'), 200, 29),
 (('Ann',), 201, 29),
 (('Ann', 'Mary'), 201

In [114]:
pid2pairs = defaultdict(set)
for pair, sid, pid in pair2sids:
    pid2pairs[pid].add(frozenset(pair))
    
passage_retrieve_prompt = f'''Question: {question}\nYou need to answer the above question based on a given story.\nBelow is a list of related entities and entity pairs contained in each passage from the story. The passage numbers are assigned based on the original order of the passages in the text.\n\n'''
menu = ''
for pid, pairs in pid2pairs.items():
    menu += f"Passage {pid}: {', '.join([str(tuple(pair)) if len(pair) == 2 else tuple(pair)[0] for pair in pairs])}\n\n"
passage_retrieve_prompt += menu
passage_retrieve_prompt += '''Now, you need to gain more information from the passages to answer the queston. Select a retrieval type and the passage numbers. For the retrieval type, you may choose "original text" to retrieve the original passages, or "summary" to retrieve the summary of the entities in the passage. For passage selection, you may select passage numbers that do not exist in the above list to obtain continuous contextual information. You may retrieve either 5 passages for "original text" or 10 passages for "summary". Generate your response in the following format:\n"Retrieval type: summary/original text\nPassage numbers: first passage number, second passage number, ...".\nDon't give any explanation.'''

In [96]:
print(passage_retrieve_prompt)

Question: Why does Ann not return Mary's feelings of affection?
You need to answer the above question based on a given story.
Below is a list of related entities and entity pairs contained in each passage from the story. The passage numbers are assigned based on the original order of the passages in the text.

Passage 2: Mary

Passage 10: Mary

Passage 11: Mary

Passage 13: Mary, ('Mary', 'Ann'), Ann

Passage 14: Mary, ('Affection', 'Mary'), Affection

Passage 19: Mary, Understanding and affections, ('Mary', 'Understanding and affections')

Passage 20: Ann, Ann's mother

Passage 22: Mary, ('Mary', 'Ann'), Ann

Passage 23: Mary, ('Mary', 'Ann'), Ann

Passage 24: Mary, ('Mary', 'Ann'), Ann

Passage 25: Mary, ('Mary', 'Ann'), Ann

Passage 27: Mary, ('Mary', 'Ann'), Ann

Passage 28: Mary

Passage 29: Mary, ('Mary', 'Ann'), Ann

Passage 32: Mary, ('Mary', 'Ann'), Ann

Passage 33: Mary

Passage 35: Mary, ('Mary', 'Ann'), Ann

Passage 37: Mary

Passage 39: Ann and Mary

Passage 41: Mary, ('Ma

In [115]:
chat_response2 = client.chat.completions.create(
    model="mistralai/Mistral-7B-Instruct-v0.2",
    messages=[
        {"role": "user", "content": passage_retrieve_prompt},
    ]
)
print(chat_response2.choices[0].message.content)

02/28/2024 15:25:27 - INFO - 	 HTTP Request: POST http://localhost:8000/v1/chat/completions "HTTP/1.1 200 OK"


 Retrieval type: original text
Passage numbers: 1, 13, 27.


### Step 3: analyze retrieved info

In [116]:
passage_indices = [1, 13, 27]
# passage_indices = [13, 22, 23, 24, 25, 27, 28, 29, 32, 33]
retrieve_result = ''
pid2sids = defaultdict(set)
for pair, sid, pid in pair2sids:
    pid2sids[pid].add(sid)
for pid, sids in pid2sids.items():
    if pid in passage_indices:
        temp_summary = '\n'.join([all_summary[sid] for sid in sids])
        retrieve_result += f'''Passage {pid}:\n{temp_summary}\n\n'''

In [100]:
print(retrieve_result)

Passage 13:
Mary: Mary frequently made mistakes while sending messages to her new friend Ann. To make their communication more agreeable, Ann proposed writing letters instead. Mary had little instruction, but with practice and observing Ann's handwriting, she became proficient. Mary felt less pain about her mother's favoritism towards her brother, as she hoped to be beloved by someone else. However, this hope led to new sorrows and disappointment. Mary's manners were softened, but her spirits were still fluctuating, and her movements were rapid.
Ann: Ann felt gratitude towards Mary and her friendship. Her heart was entirely engrossed by one object, making friendship an insufficient substitute. Ann recalled past scenes and made unavailing wishes, causing time to loiter. She proposed writing letters instead of verbal communication to avoid mistakes.

Passage 22:
Mary: She was not at home when summoned by her father. She had visited Ann, who was in an hysteric fit due to threats of evicti

In [117]:
analyze_retrieve_prompt = f'''Question: {question}\nYou need to answer the above question based on a given story.\nBelow is a list of related entities and entity pairs contained in each passage from the story. The passage numbers are assigned based on the original order of the passages in the text.\n\n'''
analyze_retrieve_prompt += menu
analyze_retrieve_prompt += f'''Below are some selected passages with the entity summary.\n\n'''
analyze_retrieve_prompt += retrieve_result
analyze_retrieve_prompt += '''Now, summarize any useful information from the above selected passages. Generate your response in the following format:\n"Summary: the summary of current useful information".'''

In [118]:
chat_response3 = client.chat.completions.create(
    model="mistralai/Mistral-7B-Instruct-v0.2",
    messages=[
        {"role": "user", "content": analyze_retrieve_prompt},
    ]
)
print(chat_response3.choices[0].message.content)

02/28/2024 15:25:49 - INFO - 	 HTTP Request: POST http://localhost:8000/v1/chat/completions "HTTP/1.1 200 OK"


 Summary: The women named Ruth and Lucy are praised by Wordsworth and are likely the subjects of his poetry as grandsons' sweethearts. The entity "Love" is mentioned multiple times in the story, and in one passage, it is described as helping lovers experience joy together for ten consecutive moons, with the Witch ensuring their well-being during this time. There is no direct mention of a witch's residence in the given passages.


### Step 4: continue searching or start answering

In [83]:
current_summary = chat_response3.choices[0].message.content + '\n\n'

In [104]:
decision_prompt = f'''Question: {question}\nYou need to answer the above question based on a given story.\nBelow is a list of related entities and entity pairs contained in each passage from the story. The passage numbers are assigned based on the original order of the passages in the text.\n\n'''
decision_prompt += menu
retrieved_passage_idx_str = ', '.join(map(str, passage_indices))
decision_prompt += f'''Below is the summary of useful information from passage {retrieved_passage_idx_str}.\n\n'''
decision_prompt += current_summary
decision_prompt += '''Now, you need to choose whether to continue searching for more information or to start answering the question.

If the information is not adequate, you may choose to continue searching. Select a retrieval type and the passage numbers. For the retrieval type, you may choose "original text" to retrieve the original passages, or "summary" to retrieve the summary of the entities in the passage. For passage selection, you may select passage numbers that do not exist in the above list to obtain continuous contextual information. You may retrieve either 5 passages for "original text" or 10 passages for "summary". Generate your response in the following format:\n"Retrieval type: summary/original text\nPassage numbers: first passage number, second passage number, ...".

Otherwise, if the information is adequate, you may choose to start answering the question. Generate your answer to the question in the following format:\n"Answer: your answer here".

For either choice, don't give any explanation.'''

In [105]:
chat_response4 = client.chat.completions.create(
    model="mistralai/Mistral-7B-Instruct-v0.2",
    messages=[
        {"role": "user", "content": decision_prompt},
    ]
)
print(chat_response4.choices[0].message.content)

02/28/2024 15:20:47 - INFO - 	 HTTP Request: POST http://localhost:8000/v1/chat/completions "HTTP/1.1 200 OK"


 Retrieval type: summary
Passage numbers: 13, 22, 23, 24, 25, 27, 28, 29, 32, 33

Answer: Ann does not return Mary's feelings of affection due to her intense focus on one object and her belief that friendship is an insufficient substitute for deeper connection.


In [121]:
results3 = []
for paragraph in tqdm(paragraphs[8:9]):
    list_entity_prompt = f"{context_type.upper()}:\n\n{paragraph}\n\nAbove is part of a {context_type}. List {ent_num} important entities in the above passages that are relevant to most of the content. You may synthesis entities to avoid ambiguity. List and separate the entities with '\n' and don't give any explanation."
    chat_response = client.chat.completions.create(
        model="mistralai/Mistral-7B-Instruct-v0.2",
        messages=[
            {"role": "user", "content": list_entity_prompt},
        ],
        n=n,
        temperature=temperature
    )
    ent_cnt = Counter()
    for i in range(n):
        for p in chat_response.choices[i].message.content.split('\n\n'):
            if p.count('\n') >= ent_num - 1:
                entities = [ent.strip('*. ') for ent in p.split('\n')]
                if all([ent.startswith(str(eid + 1)) for eid, ent in enumerate(entities)]):
                    entities = [ent[len(str(eid + 1)):].strip('. ') for eid, ent in enumerate(entities)]
                ent_cnt.update(entities)
    important_entities = [ent for ent, cnt in ent_cnt.most_common(ent_num) if cnt > n // 3]
    important_entities_str = '\n'.join(important_entities)
    describe_relation_prompt = f"{context_type.upper()}:\n\n{paragraph}\n\nAbove is part of a {context_type}. Summarize the information in the above context for each of the following entities:\n\n{important_entities_str}\n\nTry to include other important entities in each entity's summary if they are related. Separate the passages with '\n'."
    chat_response2 = client.chat.completions.create(
        model="mistralai/Mistral-7B-Instruct-v0.2",
        messages=[
            {"role": "user", "content": describe_relation_prompt}
        ]
    )
    results3.append((paragraph, important_entities, chat_response2.choices[0].message.content))

# with open('temp2.json', 'w') as f_out:
#     json.dump(results, f_out)

  0%|          | 0/1 [00:00<?, ?it/s]02/28/2024 04:55:09 - INFO - 	 HTTP Request: POST http://localhost:8000/v1/chat/completions "HTTP/1.1 200 OK"
02/28/2024 04:55:18 - INFO - 	 HTTP Request: POST http://localhost:8000/v1/chat/completions "HTTP/1.1 200 OK"
100%|██████████| 1/1 [00:14<00:00, 14.51s/it]


In [122]:
important_entities

['Her father',
 'Angels',
 'Wood in the park',
 'Her husband',
 'She (Mary)',
 'Female acquirements',
 'Her mother',
 'Separate state',
 'Gay world',
 'Indolence and ill health']

## Long context and multi-hop reasoning

In [None]:
from datasets import load_dataset
import json

### HotpotQA

In [None]:
hotpot_qa = load_dataset('hotpot_qa', 'distractor', split='validation')

In [None]:
hotpot_qa[1]

### NarrativeQA

In [None]:
narrative_qa = load_dataset('narrativeqa', split='train')

In [None]:
narrative_qa[0]

### QASPER

In [None]:
qasper = load_dataset('allenai/qasper', split='train')

In [None]:
qasper[0]

### QuALITY

In [None]:
quality_qa = [json.loads(l) for l in open('../../QuALITY.v1.0.1.train')]

In [None]:
quality_qa[0]

### openbookqa

In [None]:
openbookqa = load_dataset('openbookqa', 'main', split='train')

In [None]:
openbookqa[0]

### LongBench

In [None]:
datasets = [
    # "narrativeqa", 
    # "qasper", 
    # "multifieldqa_en", 
    # "multifieldqa_zh", 
    # "hotpotqa", 
    # "2wikimqa", 
    "musique", 
    # "dureader", 
    # "gov_report", 
    # "qmsum", 
    # "multi_news", 
    # "vcsum", 
    # "trec", 
    # "triviaqa", 
    # "samsum", 
    # "lsht", 
    # "passage_count", 
    # "passage_retrieval_en", 
    # "passage_retrieval_zh", 
    # "lcc", 
    # "repobench-p"
]
task_name = datasets[0]

dataset_dict = {task_name: load_dataset('THUDM/LongBench', task_name, split='test')}
print(dataset_dict[task_name][1]['context'])

In [None]:
dataset_dict[task_name][1]

### LooGLE

In [None]:
datasets = [
    # "shortdep_qa", 
    # "shortdep_cloze", 
    "longdep_qa", 
    # "longdep_summarization"
]

for testset in datasets:
    data = load_dataset('bigainlco/LooGLE', testset, split='test')

In [None]:
data[0]