## Chain

In [None]:
import random
import json
from openai import OpenAI
import os
from typing import List
from collections import Counter
from tqdm import tqdm
import matplotlib.pyplot as plt
from src.my_util import openai_get

# OpenAI client used by the experiment cells below; the key is read from the
# OPENAI_API_KEY environment variable (a KeyError is raised here if it is unset).
client = OpenAI(api_key=os.environ['OPENAI_API_KEY'])

In [None]:
def parse_response(prediction:str):
    """Extract the integers contained in a whitespace-separated model response.

    Each whitespace-delimited token that holds at least one digit character
    yields one integer, built by concatenating all digit characters of that
    token (so ``'12,'`` gives 12 and ``'a1b2'`` gives 12). Tokens without any
    digit are skipped.
    """
    numbers = []
    for token in prediction.split():
        digits = ''.join(filter(str.isdigit, token))
        if digits:
            numbers.append(int(digits))
    return numbers

def locate_chain(chains:List[List[int]], number:int):
    """Return the first chain that contains ``number``, or ``None`` if no chain does."""
    return next((candidate for candidate in chains if number in candidate), None)

def locate_error(chains:List[List[int]], prediction:List[int], start_num:int):
    """Classify the positions at which a predicted chain deviates from ``chains``.

    Returns a dict mapping a position in ``prediction`` to one of:

    * ``'start_number_error'``   - position 0 is not ``start_num``;
    * ``'previous_not_in_list'`` - the preceding number belongs to no chain;
    * ``'previous_at_end'``      - the preceding number already ended its chain;
    * ``'wrong_next_number'``    - the number is not the successor of the preceding one;
    * ``'stop_before_end'``      - the final number is a correct successor but
      is not the last element of its chain.

    Positions without an error are absent from the result.
    """
    errors = {}
    final_pos = len(prediction) - 1

    # The first number can only be checked against the requested start.
    if len(prediction) > 0 and prediction[0] != start_num:
        errors[0] = 'start_number_error'

    # Every later number is judged relative to the number right before it.
    for pos in range(1, len(prediction)):
        previous, current = prediction[pos - 1], prediction[pos]
        owner = locate_chain(chains, previous)
        if owner is None:
            errors[pos] = 'previous_not_in_list'
            continue
        where = owner.index(previous)
        if where == len(owner) - 1:
            errors[pos] = 'previous_at_end'
        elif owner[where + 1] != current:
            errors[pos] = 'wrong_next_number'
        elif pos == final_pos and current != owner[-1]:
            errors[pos] = 'stop_before_end'
    return errors

## Build dataset

In [None]:
# Number of integers in every chain.
list_length = 20
# Size of the number pool: each sample is built from the integers 0 .. total_num - 1.
total_num = 400
# Number of chains per sample (the pool is cut into consecutive chunks of list_length).
chain_num = total_num // list_length

### Fixed chain

In [None]:
# Build the "fixed chain" dataset: every sample shuffles the number pool, cuts
# it into chain_num chains and describes each link with one sentence. Sentences
# are interleaved step by step (link j of every chain, in chain order, before
# link j+1 of any chain). The question is the first number of the middle chain.
num_samples = 400
fixed_chain_dir = 'data/chain_generation/fixed_chain'
# open(..., 'w') does not create missing directories, so make sure it exists.
os.makedirs(fixed_chain_dir, exist_ok=True)

dataset = []
for _ in range(num_samples):
    l = list(range(total_num))
    random.shuffle(l)
    # Consecutive chunks of the shuffled pool are the ground-truth chains.
    num_lists = [l[i*list_length:(i+1)*list_length] for i in range(chain_num)]
    # One sentence per link (a -> b) of every chain.
    chains = [
        [f'The next number of {a} is {b}.' for a, b in zip(sub_l, sub_l[1:])]
        for sub_l in num_lists
    ]

    merged_list = [chains[i][j] for j in range(list_length-1) for i in range(chain_num)]
    target_chain = num_lists[chain_num // 2]
    dataset.append({'question': target_chain[0], 'list': merged_list, 'chains': num_lists})

with open(f'{fixed_chain_dir}/test_left2right_{total_num}_{list_length}.txt', 'w') as f_out:
    for data in dataset:
        f_out.write(json.dumps(data) + '\n')


In [None]:
# Derive the right-to-left variant of the fixed-chain file: the sentence list is
# cut into batches of chain_num sentences (one batch per chain step) and the
# batch order is reversed, while the order inside each batch is kept.
left2right_file = 'data/chain_generation/fixed_chain/test_left2right_400_20.txt'
right2left_file = left2right_file.replace('left2right', 'right2left')
# The output is opened once in 'w' mode so that re-running the cell regenerates
# the file instead of appending a duplicate copy of every sample.
with open(left2right_file) as f_in, open(right2left_file, 'w') as f_out:
    for line in f_in:
        if not line.strip():
            # A blank line still contains '\n' (truthy) and would crash json.loads.
            continue
        sample = json.loads(line)
        sentences = sample['list']
        batches = [sentences[start:start + chain_num] for start in range(0, len(sentences), chain_num)]
        batches.reverse()
        sample['list'] = [sent for batch in batches for sent in batch]
        f_out.write(json.dumps(sample) + '\n')

### Shuffled chain

In [None]:
# Build the "shuffled chain" dataset: same construction as the fixed-chain
# variant, except that the chain order is reshuffled at every step and the
# target chain is picked at random.
n_chains = total_num // list_length
dataset = []
for _ in range(200):
    l = list(range(total_num))
    random.shuffle(l)
    # Consecutive chunks of the shuffled pool are the ground-truth chains.
    num_lists = [l[start:start + list_length] for start in range(0, n_chains * list_length, list_length)]
    # One sentence per link (a -> b) of every chain.
    chains = [
        [f'The next number of {a} is {b}.' for a, b in zip(nums, nums[1:])]
        for nums in num_lists
    ]

    merged_list = []
    for step in range(list_length - 1):
        # A fresh random chain order for every step.
        order = list(range(n_chains))
        random.shuffle(order)
        merged_list.extend(chains[c][step] for c in order)
    target_chain = random.choice(num_lists)
    dataset.append({'question': target_chain[0], 'list': merged_list, 'chains': num_lists})

# NOTE(review): this writes 'test_left2right_<total_num>_<list_length>.txt', while
# the cells that follow read 'test_left2right.txt' - confirm which name is intended.
with open(f'test_left2right_{total_num}_{list_length}.txt', 'w') as f_out:
    for data in dataset:
        f_out.write(json.dumps(data) + '\n')


In [None]:
# Derive the right-to-left variant: same samples, sentence list fully reversed.
# The output is opened once in 'w' mode so that re-running the cell regenerates
# the file instead of appending a duplicate copy of every sample.
# NOTE(review): the dataset cell above writes 'test_left2right_<total_num>_<list_length>.txt',
# not 'test_left2right.txt' - confirm the intended input name.
with open('test_left2right.txt') as f_in, open('test_right2left.txt', 'w') as f_out:
    for line in f_in:
        if not line.strip():
            # A blank line still contains '\n' (truthy) and would crash json.loads.
            continue
        sample = json.loads(line)
        sample['list'].reverse()
        f_out.write(json.dumps(sample) + '\n')

In [None]:
# Derive the random-order variant: same samples, sentence list shuffled in place.
# The output is opened once in 'w' mode so that re-running the cell regenerates
# the file instead of appending a duplicate copy of every sample.
# NOTE(review): the dataset cell above writes 'test_left2right_<total_num>_<list_length>.txt',
# not 'test_left2right.txt' - confirm the intended input name.
with open('test_left2right.txt') as f_in, open('test_random.txt', 'w') as f_out:
    for line in f_in:
        if not line.strip():
            # A blank line still contains '\n' (truthy) and would crash json.loads.
            continue
        sample = json.loads(line)
        random.shuffle(sample['list'])
        f_out.write(json.dumps(sample) + '\n')

## Experiment

In [None]:
# Prompt sent to the model. `{start_number}` and `{description}` are filled in
# with str.format by the experiment cells.
# NOTE(review): the wording is part of the experiment, so it is left untouched
# here (including the spelling "seperate").
prompt_template = '''The text below describes the order of numbers in several chains. Find the complete chain that starts with number {start_number}. Only output the chain and seperate the numbers with space.

{description}

The chain starts with {start_number} is: '''

In [None]:
# Selects the sentence ordering under test: True -> the left2right files,
# False -> the right2left files (used for both the input data and the prediction file name).
test_left2right = True

In [None]:
# Load every JSON line of the test file selected by `test_left2right` into `data`.
data_file = 'test_left2right.txt' if test_left2right else 'test_right2left.txt'
with open(data_file) as f_in:
    data = [json.loads(l) for l in f_in]
    

### Single test

In [None]:
# Run one sample end to end: build its prompt, query the model and show the raw answer.
data_idx = 0
single_sample = data[data_idx]
question, chains = single_sample['question'], single_sample['chains']
prompt = prompt_template.format(
    description='\n'.join(single_sample['list']),
    start_number=question,
)
response = openai_get(client, 'gpt-3.5-turbo', prompt)
print(response)

In [None]:
# Ground truth for the single test: the chain that contains the question number.
locate_chain(chains, question)

In [None]:
prediction = parse_response(response)
# locate_error accepts exactly (chains, prediction, start_num); the former
# fourth positional argument (test_left2right) made this cell raise TypeError.
locate_error(chains, prediction, question)

In [None]:
# Ad-hoc inspection: chain that contains the number at position 27 of the prediction.
# NOTE(review): the index is hard-coded; this raises IndexError when the
# prediction holds fewer than 28 numbers.
print(locate_chain(chains, prediction[27]))

In [None]:
# Ad-hoc inspection of the predicted number at position 27 (hard-coded index;
# IndexError when the prediction is shorter).
prediction[27]

In [None]:
# Ad-hoc inspection of the predicted number at position 28 (hard-coded index;
# IndexError when the prediction is shorter).
prediction[28]

### Test

In [None]:
# Query the model for the first 100 samples. Each result is appended to the
# prediction file right after its request returns, so finished samples are
# already on disk if a later request fails.
prediction_path = f'prediction_{"left2right" if test_left2right else "right2left"}.txt'
for data_idx in tqdm(range(len(data))):
    # Lower bound of the processed window (0 means: start with the first sample).
    if data_idx < 0:
        continue
    # Upper bound of the processed window.
    if data_idx >= 100:
        break
    item = data[data_idx]
    question = item['question']
    chains = item['chains']
    prompt = prompt_template.format(start_number=question, description='\n'.join(item['list']))
    response = openai_get(client, 'gpt-3.5-turbo', prompt)

    record = {
        'chains': chains,
        'question': question,
        'prediction': response.strip()
    }
    with open(prediction_path, 'a') as f_out:
        f_out.write(json.dumps(record) + '\n')

## Evaluation

In [None]:
def readlines(file:str):
    """Load a JSON-lines file and return the decoded objects in file order.

    Blank or whitespace-only lines (for example a stray empty line at the end
    of the file) are skipped instead of being handed to ``json.loads``, which
    would raise ``JSONDecodeError`` on them.
    """
    with open(file) as f_in:
        return [json.loads(line) for line in f_in if line.strip()]
    
def correct_num(prediction_file:str):
    """Count, for every stored prediction, how many leading numbers are correct.

    Each record of ``prediction_file`` is compared against the chain that
    contains its question number; counting stops at the first mismatch (or at
    the end of the shorter sequence). Returns one count per record.
    """
    corrects = []
    for record in readlines(prediction_file):
        answer = locate_chain(record['chains'], record['question'])
        predicted = parse_response(record['prediction'])
        matched = 0
        for expected, got in zip(answer, predicted):
            if expected != got:
                break
            matched += 1
        corrects.append(matched)
    return corrects

In [None]:
# Per-sample count of correctly predicted leading numbers, for each sentence ordering.
prediction_left2right = correct_num('data/chain_generation/fixed_chain/prediction_left2right_400_20.txt')
prediction_right2left = correct_num('data/chain_generation/fixed_chain/prediction_right2left_400_20.txt')
# NOTE(review): the random-order scores are not loaded (line below is disabled),
# so `prediction_random` is undefined on a fresh kernel although a later
# histogram cell reads it.
# prediction_random = correct_num('prediction_random.txt')

In [None]:
# Mean number of correct leading numbers per sample. The denominator is the
# actual number of predictions rather than a hard-coded 400, so the mean stays
# right when a prediction file holds a different number of records
# (max(..., 1) keeps the result at 0.0 for an empty file).
print(sum(prediction_left2right) / max(len(prediction_left2right), 1))
print(sum(prediction_right2left) / max(len(prediction_right2left), 1))

In [None]:
# One bin per unit of the value range. The bin count is clamped to at least 1,
# since plt.hist raises for bins < 1 (constant data) and max()/min() raise on
# an empty list.
plt.hist(prediction_left2right, bins=max(1, max(prediction_left2right, default=0) - min(prediction_left2right, default=0)))

In [None]:
# One bin per unit of the value range. The bin count is clamped to at least 1,
# since plt.hist raises for bins < 1 (constant data) and max()/min() raise on
# an empty list.
plt.hist(prediction_right2left, bins=max(1, max(prediction_right2left, default=0) - min(prediction_right2left, default=0)))

In [None]:
# Bin count derived from the number of distinct values, clamped to at least 1
# (it would be 0 for constant data and -1 for an empty list, and plt.hist raises).
plt.hist(prediction_left2right, bins=max(1, len(set(prediction_left2right)) - 1))

In [None]:
# Bin count derived from the number of distinct values, clamped to at least 1
# (it would be 0 for constant data and -1 for an empty list, and plt.hist raises).
plt.hist(prediction_right2left, bins=max(1, len(set(prediction_right2left)) - 1))

In [None]:
# `prediction_random` only exists when its (currently commented-out) loader has
# been run, so the plot is skipped instead of raising NameError on a fresh kernel.
if 'prediction_random' in globals():
    # Bin count clamped to at least 1 (plt.hist raises for bins < 1; max()/min()
    # raise on an empty list).
    plt.hist(prediction_random, bins=max(1, max(prediction_random, default=0) - min(prediction_random, default=0)))
else:
    print('prediction_random is not loaded; skipping its histogram.')

In [None]:
# Re-parse every stored prediction and record where and why it deviates from
# its chains; each entry also keeps the parsed prediction under 'prediction'.
prediction_file = 'prediction_left2right.txt' if test_left2right else 'prediction_right2left.txt'
errors = []
with open(prediction_file) as f_in:
    for line in f_in:
        sample = json.loads(line)
        prediction = parse_response(sample['prediction'])
        error_log = {
            **locate_error(sample['chains'], prediction, sample['question']),
            'prediction': prediction,
        }
        errors.append(error_log)

In [None]:
# Position, inside its own chain, of a number that was emitted as a wrong continuation.
wrong_next_number_loc = []
# Position, inside its chain, of the number that preceded a wrong continuation.
wrong_number_loc = []
# Position, inside its chain, of a number emitted after the chain should have stopped.
new_start_loc = []
# Index of the earliest error in every prediction that has at least one error.
first_error_loc = []
for idx, error_log in enumerate(errors):
    # NOTE(review): assumes errors[idx] was produced from data[idx] (same file
    # order and length) - confirm, since the prediction file is appended to.
    chains = data[idx]['chains']
    prediction:List[int] = error_log['prediction']
    pids = []
    for key, value in error_log.items():
        if key == 'prediction':
            continue
        pid = int(key)
        pids.append(pid)
        curr = prediction[pid]
        curr_chain = locate_chain(chains, curr)
        if value == 'wrong_next_number':
            prev = prediction[pid-1]
            prev_chain = locate_chain(chains, prev)
            # Guarded like curr_chain below: locate_chain returns None when the
            # number belongs to no chain of this sample.
            if prev_chain:
                wrong_number_loc.append(prev_chain.index(prev))
            if curr_chain:
                wrong_next_number_loc.append(curr_chain.index(curr))
        elif value in ('previous_not_in_list', 'previous_at_end'):
            if curr_chain:
                new_start_loc.append(curr_chain.index(curr))
    if pids:
        first_error_loc.append(min(pids))
    

In [None]:
# Bin count derived from the number of distinct values, clamped to at least 1
# (it would be 0 for constant data and -1 for an empty list, and plt.hist raises).
plt.hist(first_error_loc, bins=max(1, len(set(first_error_loc)) - 1))

In [None]:
# Same plot as the previous cell (kept as in the original notebook). Bin count
# clamped to at least 1, since plt.hist raises for bins < 1.
plt.hist(first_error_loc, bins=max(1, len(set(first_error_loc)) - 1))

In [None]:
# Bin count derived from the number of distinct values, clamped to at least 1
# (it would be 0 for constant data and -1 for an empty list, and plt.hist raises).
plt.hist(wrong_number_loc, bins=max(1, len(set(wrong_number_loc)) - 1))

In [None]:
# Bin count derived from the number of distinct values, clamped to at least 1
# (it would be 0 for constant data and -1 for an empty list, and plt.hist raises).
plt.hist(wrong_next_number_loc, bins=max(1, len(set(wrong_next_number_loc)) - 1))

In [None]:
# Bin count derived from the number of distinct values, clamped to at least 1
# (it would be 0 for constant data and -1 for an empty list, and plt.hist raises).
plt.hist(new_start_loc, bins=max(1, len(set(new_start_loc)) - 1))

In [None]:
# Use chain extraction to locate the preferred zone
# Dense retrieval + rank by score / original order / score aligned with the preferred zone
# Importance of context: bare text pieces vs. wrapping context

# Retrieval observation

In [1]:
from transformers import AutoTokenizer, AutoModel
from typing import List
import torch
from datasets import load_dataset
from nltk import sent_tokenize
from openai import OpenAI

In [None]:

# Point the OpenAI client at a local vLLM API server (the key is a placeholder)
# and send one chat request as a smoke test.
# NOTE(review): this rebinds `client`, which the chain-experiment cells above
# also use; running those cells afterwards would target the local server.
openai_api_base = "http://localhost:8000/v1"
openai_api_key = "EMPTY"
client = OpenAI(base_url=openai_api_base, api_key=openai_api_key)

chat_messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Tell me a joke."},
]
chat_response = client.chat.completions.create(messages=chat_messages, model="facebook/opt-125m")
print("Chat response:", chat_response)

In [13]:
# LongBench task names; only the uncommented entries are active and the first
# active entry becomes `task_name`.
# NOTE(review): the list is named `datasets`, the same name as the library that
# provides `load_dataset`; only the function is imported in this notebook, so
# nothing is shadowed, but a rename would avoid confusion.
datasets = [
    # "narrativeqa", 
    # "qasper", 
    # "multifieldqa_en", 
    # "multifieldqa_zh", 
    # "hotpotqa", 
    # "2wikimqa", 
    # "musique", 
    # "dureader", 
    # "gov_report", 
    # "qmsum", 
    "multi_news", 
    # "vcsum", 
    # "trec", 
    # "triviaqa", 
    # "samsum", 
    # "lsht", 
    # "passage_count", 
    # "passage_retrieval_en", 
    # "passage_retrieval_zh", 
    # "lcc", 
    # "repobench-p"
]
task_name = datasets[0]

# Load the test split of the selected task and show the context of example 1.
dataset_dict = {task_name: load_dataset('THUDM/LongBench', task_name, split='test')}
print(dataset_dict[task_name][1]['context'])

Passage 1:
Starting in 1996, Alexa Internet has been donating their crawl data to the Internet Archive. Flowing in every day, these data are added to the Wayback Machine after an embargo period.
Passage 2:
Image copyright Getty Images Image caption Kalashnikov designed the AK-47 after being wounded fighting for the Red Army NEWLINE_CHAR NEWLINE_CHAR The inventor of the Kalashnikov assault rifle apparently wrote to the head of the Russian Orthodox Church before he died expressing fears he was morally responsible for the people it killed. NEWLINE_CHAR NEWLINE_CHAR Mikhail Kalashnikov, who died last month aged 94, wrote a long emotional letter to Patriarch Kirill in May 2012, church officials say. NEWLINE_CHAR NEWLINE_CHAR He said he was suffering "spiritual pain" over the many deaths it caused. NEWLINE_CHAR NEWLINE_CHAR Kalashnikov had previously refused to accept responsibility for those killed. NEWLINE_CHAR NEWLINE_CHAR 'Devilish desires' NEWLINE_CHAR NEWLINE_CHAR Analysis The letter p

In [4]:
# Raw 'context' field of example 2 of the currently selected task.
dataset_dict[task_name][2]['context']

"Question: What time does Wee Willie Winkie run through the town ?\nType: Date\nQuestion: What city is often called The Insurance Capital of the World ?\nType: City\nQuestion: What are the benefits of a rowing machine ?\nType: Description of something\nQuestion: What are the different approaches of systems analysis ?\nType: Techniques and method\nQuestion: When did the Jurassic Period end ?\nType: Date\nQuestion: What actor was the first man to appear on the cover of McCall 's ?\nType: Individual\nQuestion: What exactly , specifically does sleep do for you ?\nType: Description of something\nQuestion: What substance did Joseph Priestley name for its ability to erase pencil marks ?\nType: Element and substance\nQuestion: In what year was De Gaulle elected president of France ?\nType: Date\nQuestion: What is Britain 's possession on the Chinese mainland ?\nType: Other location\nQuestion: What is Jane Goodall known for ?\nType: Reason\nQuestion: What is the origin of the surname of Braun ?

In [3]:
# 'input' field (the query) of example 1 of the currently selected task.
print(dataset_dict[task_name][1]['input'])

Question: What was J.F.K. 's wife 's name ?
Type:


In [3]:
def split_trec(text:str):
    """Split a TREC-style context into paragraphs of two consecutive lines.

    Lines are grouped in pairs (for example a ``Question: ...`` line and the
    ``Type: ...`` line after it) and each pair is joined with a newline. A
    trailing unpaired line is dropped.
    """
    lines = text.splitlines()
    # The slice must cover both lines of the pair; the former end ``i * 2 + 1``
    # kept only the first line and silently discarded the second.
    return ['\n'.join(lines[i * 2 : i * 2 + 2]) for i in range(len(lines) // 2)]

def split_triviaqa(text:str):
    """Split a TriviaQA-style context into paragraphs that end with an answer.

    Lines are accumulated until a line equal to ``'Answer:'`` is met; that line
    and the line following it close the paragraph. Lines after the last closed
    paragraph are not returned.
    """
    paragraphs = []
    paragraph = []
    line_iter = iter(text.splitlines())
    for line in line_iter:
        paragraph.append(line)
        if line == 'Answer:':
            # The answer text is the next line. If 'Answer:' is the very last
            # line there is none; the paragraph is closed without it instead of
            # raising IndexError.
            answer = next(line_iter, None)
            if answer is not None:
                paragraph.append(answer)
            paragraphs.append('\n'.join(paragraph))
            paragraph = []
    return paragraphs

def split_samsum(text:str):
    """Split a SAMSum-style context into paragraphs that end with a summary line.

    A paragraph is every run of lines up to and including a line that starts
    with ``'Summary: '``. Lines after the last such line are not returned.
    """
    lines = text.splitlines()
    paragraphs = []
    start = 0
    for end, line in enumerate(lines, start=1):
        if line.startswith('Summary: '):
            paragraphs.append('\n'.join(lines[start:end]))
            start = end
    return paragraphs

class LongDoc:
    """Split long contexts into token-bounded paragraphs, embed them with a
    dense retriever and retrieve the paragraphs closest to a question.

    Token counts are measured with the LLM tokenizer; embeddings come from the
    retriever model (mean pooling over its token embeddings).
    """

    # Per-task paragraph separator. A plain string is used directly as the
    # separator; a ``(splitter, separator)`` tuple provides a custom split
    # function together with the separator string. Tasks that are not listed
    # fall back to a blank line ('\n\n').
    paragraph_sep_map = {
        'qasper': '\n', 
        'multifieldqa_zh': '\n', 
        'qmsum': '\n', 
        'multi_news': '\n', 
        'vcsum': '\n', 
        'trec': (split_trec, '\n'), 
        'triviaqa': (split_triviaqa, '\n'), 
        'samsum': (split_samsum, '\n'), 
    }
    
    def __init__(self, retriever_model_name:str='facebook/contriever', llm_name:str='meta-llama/Llama-2-7b-hf', device:str='cuda:2') -> None:
        """Load both tokenizers and the retriever model.

        Args:
            retriever_model_name: model id of the dense retriever.
            llm_name: model id whose tokenizer is used to measure paragraph size.
            device: torch device string for the retriever; defaults to the
                previously hard-coded 'cuda:2'.
        """
        self.device = torch.device(device)
        self.llm_tokenizer = AutoTokenizer.from_pretrained(llm_name)
        self.retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_model_name)
        self.retriever_model = AutoModel.from_pretrained(retriever_model_name)
        # .to() behaves like .cuda(device=...) for CUDA devices and also accepts 'cpu'.
        self.retriever_model.to(self.device)
        
    def get_task_paragraph_sep(self, task_name:str):
        """Return the separator string registered for ``task_name`` (a blank line if unlisted)."""
        sep = self.paragraph_sep_map.get(task_name, '\n\n')
        if not isinstance(sep, str):
            _, sep = sep
        return sep
    
    def split_context_to_paragraphs(self, context:str, task_name:str):
        """Split ``context`` with the task's custom splitter if it has one, else on its separator."""
        entry = self.paragraph_sep_map.get(task_name, '\n\n')
        if isinstance(entry, str):
            return context.split(entry)
        splitter, _ = entry
        return splitter(context)
    
    # Mean pooling
    @staticmethod
    def _mean_pooling(token_embeddings:torch.Tensor, mask):
        """Average the token embeddings over dim 1, ignoring positions where ``mask`` is 0."""
        token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
        sentence_embeddings:torch.Tensor  = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
        return sentence_embeddings

    def _append_paragraph(self, paragraphs:list, tokenized_p:List[int]):
        """Decode the token ids in ``tokenized_p``, append the text to ``paragraphs`` and empty ``tokenized_p`` in place."""
        paragraphs.append(self.llm_tokenizer.decode(tokenized_p))
        tokenized_p.clear()
    
    def split_single_paragraph(self, text:str, paragraph_size:int=300, is_natural_language:bool=True):
        """Cut one over-long paragraph into pieces of at most ``paragraph_size`` tokens.

        The text is split into sentences (``sent_tokenize``) or, when
        ``is_natural_language`` is False, into lines. Sentences are packed
        greedily; a sentence that alone exceeds the limit is cut into
        fixed-size token chunks.

        Returns:
            ``(pieces, remainder)``: the finished pieces as decoded strings and
            the token ids of the still-open last piece.
        """
        splited_paragraphs:List[str] = []
        splited_paragraph = []
        sentences:List[str] = sent_tokenize(text) if is_natural_language else text.split('\n')
        for sent in sentences:
            # The first token is dropped - presumably the BOS token the LLM
            # tokenizer prepends; TODO confirm for other tokenizers.
            tokenized_s = self.llm_tokenizer.encode(sent)[1:]
            if len(tokenized_s) <= paragraph_size:
                # Flush the open piece first if this sentence would overflow it.
                if len(splited_paragraph) + len(tokenized_s) > paragraph_size:
                    self._append_paragraph(splited_paragraphs, splited_paragraph)
                splited_paragraph.extend(tokenized_s)
            else:
                if splited_paragraph:
                    self._append_paragraph(splited_paragraphs, splited_paragraph)
                # Number of chunks needed for this sentence (ceiling division).
                chunk_size = (len(tokenized_s) - 1) // paragraph_size + 1
                for i in range(chunk_size - 1):
                    self._append_paragraph(splited_paragraphs, tokenized_s[i * paragraph_size: (i+1) * paragraph_size])
                # The last (possibly shorter) chunk stays open for following sentences.
                splited_paragraph = tokenized_s[(chunk_size - 1) * paragraph_size:]
        
        return splited_paragraphs, splited_paragraph
            
        
    def split_and_embed_paragraphs(self, text:str, task_name:str, paragraph_size:int=300):
        """Re-chunk ``text`` into paragraphs of at most ``paragraph_size`` tokens and embed them.

        The text is split on the task separator; short source paragraphs are
        merged greedily and over-long ones are cut with ``split_single_paragraph``.

        Returns:
            ``(paragraphs, completion_labels, embeddings)``. A label is False
            for pieces returned as finished by ``split_single_paragraph`` (the
            cut-off parts of an over-long source paragraph) and True for every
            other paragraph.
        """
        reformated_paragraphs:List[str] = []
        completion_labels:List[bool] = []
        reformated_paragraph = []
        
        paragraph_sep = self.get_task_paragraph_sep(task_name)
        paragraphs = text.split(paragraph_sep)
        for p in paragraphs:
            # The separator is re-attached so decoded paragraphs keep it; the
            # first token is dropped (presumably BOS - TODO confirm).
            tokenized_p = self.llm_tokenizer.encode(p + paragraph_sep)[1:]
            if len(tokenized_p) <= paragraph_size:
                if len(reformated_paragraph) + len(tokenized_p) > paragraph_size:
                    self._append_paragraph(reformated_paragraphs, reformated_paragraph)
                    completion_labels.append(True)
                reformated_paragraph.extend(tokenized_p)
            else:
                if reformated_paragraph:
                    self._append_paragraph(reformated_paragraphs, reformated_paragraph)
                    completion_labels.append(True)
                splited_paragraphs, splited_paragraph = self.split_single_paragraph(p, paragraph_size)
                reformated_paragraphs.extend(splited_paragraphs)
                completion_labels.extend([False] * len(splited_paragraphs))
                reformated_paragraph = splited_paragraph
                
        if reformated_paragraph:
            self._append_paragraph(reformated_paragraphs, reformated_paragraph)
            completion_labels.append(True)
        
        retriever_input = self.retriever_tokenizer.batch_encode_plus(reformated_paragraphs, padding=True, truncation=True, return_tensors='pt').to(self.device)
        with torch.no_grad():
            retriever_output = self.retriever_model(**retriever_input)
            paragraph_embeddings= self._mean_pooling(retriever_output[0], retriever_input['attention_mask'])
        
        return reformated_paragraphs, completion_labels, paragraph_embeddings
    
    def retrieve_paragraphs(self, task_name:str, question:str, paragraphs:List[str], completion_labels:List[bool], paragraph_embeddings:torch.Tensor, k:int=5, order_by_rank:bool=True):
        """Return the text of the ``k`` paragraphs most similar to ``question``.

        Similarity is the dot product between the question embedding and each
        paragraph embedding. With ``order_by_rank=False`` the selected
        paragraphs are emitted in document order instead of score order.

        Returns:
            ``(retrieved_text, indices)``: the concatenated paragraph texts and
            the selected paragraph indices in the order used.
        """
        question_input = self.retriever_tokenizer.batch_encode_plus([question], truncation=True, return_tensors='pt').to(self.device)
        with torch.no_grad():
            question_output = self.retriever_model(**question_input)
            question_embeddings = self._mean_pooling(question_output[0], question_input['attention_mask'])
            # squeeze(-1) removes only the question axis, so a single paragraph
            # still yields a 1-D score vector (a bare squeeze() made it a scalar).
            ranks = torch.matmul(paragraph_embeddings, question_embeddings.T).squeeze(-1)
            # topk raises when k exceeds the number of paragraphs.
            k = min(k, ranks.shape[0])
            indices = torch.topk(ranks, k).indices.tolist()
        if not order_by_rank:
            indices.sort()
        paragraph_sep = self.get_task_paragraph_sep(task_name)
        retrieved_text = ''
        last_position = len(indices) - 1
        for i, idx in enumerate(indices):
            retrieved_text += paragraphs[idx]
            # Compare the loop position, not the paragraph index: the former
            # ``idx == len(indices) - 1`` could stop early or run past the end
            # of ``indices`` at ``indices[i+1]``.
            if i == last_position:
                break
            if not completion_labels[idx]:
                # A cut-off piece is glued to its direct continuation with a
                # space; otherwise the paragraph separator is inserted.
                if indices[i+1] == idx + 1:
                    retrieved_text += ' '
                else:
                    retrieved_text += paragraph_sep
        return retrieved_text, indices
        

In [4]:
# Instantiate with the default retriever and LLM tokenizer (loads both in LongDoc.__init__).
longdoc = LongDoc()

In [6]:
task_name = 'narrativeqa'
# dataset_dict only holds the task chosen in the loading cell, so fetch this
# task on demand instead of failing with KeyError on a fresh kernel.
if task_name not in dataset_dict:
    dataset_dict[task_name] = load_dataset('THUDM/LongBench', task_name, split='test')
example = dataset_dict[task_name][0]
paragraphs, completion_labels, embeddings = longdoc.split_and_embed_paragraphs(example['context'], task_name)
retrieved_text, indices = longdoc.retrieve_paragraphs(task_name, example['input'], paragraphs, completion_labels, embeddings)

In [14]:
# Show the concatenated text of the retrieved paragraphs.
print(retrieved_text)

’s widow (he had been mayor of
Clockborough) would pass away and the heiress would return to her
inheritance.  I gathered with surprise that she had not communicated to
his wife the story of her attempt to hear Mr..Saltram, and I founded this
reticence on the easy supposition that Mrs. Saltram had fatigued by
overpressure the spring of the sympathy of which she boasted.  The girl
at any rate would forget the small adventure, be distracted, take a
husband; besides which she would lack occasion to repeat her experiment.

ONE of the consequences, for the Mulvilles, of the sacrifices they made
for Frank Saltram was that they had to give up their carriage.  Adelaide
drove gently into London in a one-horse greenish thing, an early
Victorian landau, hired, near at hand, imaginatively, from a broken-down
jobmaster whose wife was in consumption—a vehicle that made people turn
round all the more when her pensioner sat beside her in a soft white hat
and a shawl, one of the dear woman’s own.  This