![read_agent_teaser](https://read-agent.github.io/img/teaser.png)

In [1]:
# !wget https://github.com/nyu-mll/quality/raw/main/data/v1.0.1/QuALITY.v1.0.1.htmlstripped.dev
import re, time, datetime, json, string, copy, os

from index_files import LongDoc
# Shared LongDoc instance built for the Mistral-7B-Instruct model on the first
# CUDA device. `query_gpt_model` in the following cell sends prompts through
# its `_call_llm` method.
longdoc = LongDoc(llm_name="mistralai/Mistral-7B-Instruct-v0.2", device='cuda:0')

In [2]:
# Selects the "gpt" branch of `query_model`, which calls `query_gpt_model`.
model_type = 'gpt'
def query_gpt_model(
    prompt: str
) -> str:
    """Send `prompt` through the shared `longdoc` object and return the reply text.

    Args:
      prompt: full prompt text to submit.

    Returns:
      The `message.content` of the first choice in the object returned by
      `longdoc._call_llm`.
    """
    completion = longdoc._call_llm(prompt)
    first_choice = completion.choices[0]
    return first_choice.message.content

In [None]:
# @title Using OpenAI GPT model (DO NOT run the next cell if using GPT)
!pip3 install openai
import openai

# Placeholder API key; the trailing `#@param` marker is kept verbatim
# (presumably a notebook form field — confirm).
# NOTE(review): do not commit a real key here; consider loading it from an
# environment variable instead.
key = 'YOUR API KEY'  #@param {type: "string"}
# Client used by `query_gpt_model` below for chat-completion requests.
gpt_client = openai.OpenAI(api_key=key)
# Selects the "gpt" branch of `query_model`.
model_type = 'gpt'

def query_gpt_model(
    prompt: str,
    lm: str = 'gpt-3.5-turbo-1106',
    temperature: float = 0.0,
    max_decode_steps: int = 512,
    seconds_to_reset_tokens: float = 30.0,
) -> str:
  """Ask the OpenAI chat model for a completion, retrying until one succeeds.

  Args:
    prompt: text sent as the single user message.
    lm: model name passed as `model`.
    temperature: sampling temperature passed to the API.
    max_decode_steps: passed to the API as `max_tokens`.
    seconds_to_reset_tokens: sleep time after a rate-limit error.

  Returns:
    The content of the first choice of the parsed completion.

  Note:
    There is no retry cap: rate-limit errors and other API errors are logged
    and retried indefinitely.
  """
  # The request payload never changes between attempts, so build it once.
  request_messages = [
    {'role': 'user', 'content': prompt},
  ]
  while True:
    try:
      raw_response = gpt_client.chat.completions.with_raw_response.create(
        model=lm,
        max_tokens=max_decode_steps,
        temperature=temperature,
        messages=request_messages,
      )
      return raw_response.parse().choices[0].message.content
    # Rate-limit errors are caught first and get the longer back-off.
    except openai.RateLimitError as e:
      print(f'{datetime.datetime.now()}: query_gpt_model: RateLimitError {e.message}: {e}')
      time.sleep(seconds_to_reset_tokens)
    except openai.APIError as e:
      print(f'{datetime.datetime.now()}: query_gpt_model: APIError {e.message}: {e}')
      print(f'{datetime.datetime.now()}: query_gpt_model: Retrying after 5 seconds...')
      time.sleep(5)

In [None]:
# @title Using Google Gemini model (DO NOT run this if using GPT)
!pip3 install -q -U google-generativeai
import google.generativeai as genai

# Placeholder API key; the trailing `#@param` marker is kept verbatim
# (presumably a notebook form field — confirm).
# NOTE(review): do not commit a real key here; consider loading it from an
# environment variable instead.
key = 'YOUR API KEY'  #@param {type: "string"}

genai.configure(api_key=key)
# Model handle used by `query_gemini_model` below.
model = genai.GenerativeModel('gemini-pro')
# Selects the "gemini" branch of `query_model`.
model_type = 'gemini'

def query_gemini_model(
    prompt: str,
    retries: int = 10,
) -> str:
  """Generate a reply with the module-level Gemini `model`, retrying on errors.

  Args:
    prompt: text passed to `model.generate_content`.
    retries: maximum number of attempts.

  Returns:
    The response text with every "**" sequence removed.

  Raises:
    RuntimeError: if no attempt succeeds. Previously the function fell off
      the end and returned None, which only surfaced later as an
      AttributeError in callers that immediately call `.strip()`.
  """
  last_error = None
  for attempt in range(retries):
    try:
      response = model.generate_content(prompt)
      return response.text.replace("**", "")
    except Exception as e:
      last_error = e
      print(f'{datetime.datetime.now()}: query_gemini_model: Error: {e}')
      # Only wait when another attempt will actually follow.
      if attempt + 1 < retries:
        print(f'{datetime.datetime.now()}: query_gemini_model: Retrying after 5 seconds...')
        time.sleep(5)
  raise RuntimeError(
      f'query_gemini_model: no response after {retries} attempt(s)'
  ) from last_error

In [3]:
def query_model(prompt):
  """Route `prompt` to the backend selected by the global `model_type`.

  Args:
    prompt: prompt text forwarded unchanged to the backend function.

  Returns:
    The string returned by `query_gpt_model` or `query_gemini_model`.

  Raises:
    ValueError: if `model_type` is neither "gpt" nor "gemini". Previously this
      case silently returned None.
  """
  if model_type == "gpt":
    return query_gpt_model(prompt)
  if model_type == "gemini":
    return query_gemini_model(prompt)
  raise ValueError(
      f"Unsupported model_type {model_type!r}; expected 'gpt' or 'gemini'."
  )

In [5]:
#@title Load a QuALITY example

# Fields that are straight text copies from raw example to processed example.
_ONE2ONE_FIELDS = (
    'article',
    'article_id',
    'set_unique_id',
    'writer_id',
    'source',
    'title',
    'topic',
    'url',
    'author',
)

# NOTE(review): the list is called `quality_dev`, but the file opened below is
# the `.train` split — confirm which split is intended.
quality_dev = []

# One JSON object per line. The encoding is given explicitly so the read does
# not depend on the platform default (the articles contain non-ASCII
# punctuation).
with open('../../data/QuALITY/QuALITY.v1.0.1.htmlstripped.train', 'r', encoding='utf-8') as f:
  for line in f:
    # Tolerate blank lines (e.g. a trailing newline at end of file).
    if not line.strip():
      continue
    j = json.loads(line)
    fields = {k: j[k] for k in _ONE2ONE_FIELDS}
    # Flatten the per-question records into parallel lists, one entry per
    # question, all in the same order.
    raw_questions = j['questions']
    fields.update({
        'questions': [q['question'] for q in raw_questions],
        'question_ids': [q['question_unique_id'] for q in raw_questions],
        'difficults': [q['difficult'] for q in raw_questions],
        'options': [q['options'] for q in raw_questions],
        'gold_labels': [q['gold_label'] for q in raw_questions],
        'writer_labels': [q['writer_label'] for q in raw_questions],
    })

    quality_dev.append(fields)

In [6]:
#@title Helper functions

all_lowercase_letters = string.ascii_lowercase  # "abcd...xyz"
bracketed_lowercase_letters_set = {
    f"({letter})" for letter in all_lowercase_letters
}  # {"(a)", ...}
bracketed_uppercase_letters_set = {
    f"({letter.upper()})" for letter in all_lowercase_letters
}  # {"(A)", ...}

# Option labels shown in front of the answer options of each question.
choices = [f"({letter})" for letter in "ABCD"]

def get_index_from_symbol(answer):
  """Convert an answer symbol such as "(B)" or "b" into a zero-based index.

  Args:
    answer: the answer symbol; it is converted with `str()` and lower-cased
      before being interpreted.

  Returns:
    index (int): offset of the letter from "a", e.g. 1 for "(B)".
  """
  symbol = str(answer).lower()
  if symbol in bracketed_lowercase_letters_set:
    # "(b)" -> "b": the letter is the character between the brackets.
    symbol = symbol[1]
  return ord(symbol) - ord("a")

def count_words(text):
  """Return the number of whitespace-separated tokens in `text`."""
  tokens = text.split()
  return len(tokens)

def quality_gutenberg_parser(raw_article):
  """Flatten a raw QuALITY/Gutenberg article into one space-joined string.

  Each input line is stripped. A blank line that directly follows a non-blank
  line is dropped; every further consecutive blank line is emitted as a "\n"
  token. All kept pieces are then joined with single spaces.
  """
  pieces = []
  last_seen = None
  for raw_line in raw_article.split('\n'):
    stripped = raw_line.strip()
    if stripped != '':
      pieces.append(stripped)
    elif last_seen == '':
      # Second (or later) blank line in a row: keep a newline marker.
      pieces.append('\n')
    last_seen = stripped
  return ' '.join(pieces)

In [7]:
#@title ReadAgent (1) Episode Pagination

# Prompt used by `quality_pagination` to ask the model for a break-point label.
# Placeholders: {0} = text of the preceding page (or empty), {1} = the current
# passage with <N> labels inserted between paragraphs, {2} = the paragraph that
# follows the passage (or empty). The wording is sent to the model verbatim.
prompt_pagination_template = """
You are given a passage that is taken from a larger text (article, book, ...) and some numbered labels between the paragraphs in the passage.
Numbered label are in angeled brackets. For example, if the label number is 19, it shows as <19> in text.
Please choose one label that it is natural to break reading.
Such point can be scene transition, end of a dialogue, end of an argument, narrative transition, etc.
Please answer the break point label and explain.
For example, if <57> is a good point to break, answer with \"Break point: <57>\n Because ...\"

Passage:

{0}
{1}
{2}

"""

def parse_pause_point(text):
  """Extract the break-point label number from a model reply.

  The reply is expected to look like "Break point: <57>\n Because ...".
  A reply that simply starts with a label ("<57> ...") is also accepted.

  The previous implementation used `str.strip("Break point: ")`, which removes
  a *set of characters* from both ends rather than a prefix. It raised
  IndexError on an empty reply and failed whenever any text came before
  "Break point:".

  Args:
    text (str): raw model reply.

  Returns:
    int label number, or None when no label can be found.
  """
  found = re.search(r'Break point:\s*<(\d+)>', text, flags=re.IGNORECASE)
  if found is None:
    # Fall back to a reply that begins directly with the label.
    found = re.match(r'\s*<(\d+)>', text)
  if found is None:
    return None
  return int(found.group(1))


def quality_pagination(example,
                       word_limit=600,
                       start_threshold=280,
                       max_retires=10,
                       verbose=True,
                       allow_fallback_to_last=True):
  """Split an article into pages at break points chosen by the model.

  Starting at paragraph `i`, paragraphs are accumulated until `word_limit`
  words are reached. Once the running count passes `start_threshold`, a label
  `<k>` is inserted before paragraph `k`. The model is asked to pick one label
  and the page ends just before that paragraph.

  Args:
    example: dict with at least 'article' and 'title'.
    word_limit: stop growing the candidate passage at this many words.
    start_threshold: word count after which break labels are offered.
    max_retires: currently unused; kept for interface compatibility.
    verbose: print each page as it is produced.
    allow_fallback_to_last: if the model reply is unusable, break at the end
      of the candidate passage instead of raising.

  Returns:
    List of pages, each a list of paragraph strings.

  Raises:
    ValueError: if the reply is unusable and `allow_fallback_to_last` is False.
  """
  article = example['article']
  title = example['title']
  print(f"[Pagination][Article {title}]")
  paragraphs = quality_gutenberg_parser(article).split('\n')

  i = 0
  pages = []
  while i < len(paragraphs):
    preceding = "" if i == 0 else "...\n" + '\n'.join(pages[-1])
    passage = [paragraphs[i]]
    wcount = count_words(paragraphs[i])
    j = i + 1
    while wcount < word_limit and j < len(paragraphs):
      wcount += count_words(paragraphs[j])
      if wcount >= start_threshold:
        passage.append(f"<{j}>")
      passage.append(paragraphs[j])
      j += 1
    passage.append(f"<{j}>")
    end_tag = "" if j == len(paragraphs) else paragraphs[j] + "\n..."

    pause_point = None
    if wcount < 350:
      # Short passage: put everything that is left on one page without
      # asking the model.
      pause_point = len(paragraphs)
    else:
      prompt = prompt_pagination_template.format(preceding, '\n'.join(passage), end_tag)
      response = query_model(prompt=prompt).strip()
      pause_point = parse_pause_point(response)
      # Only labels in (i, j] were offered. The check must use `is not None`:
      # a reply of <0> is falsy and previously slipped through, resetting `i`
      # to 0 so the loop never finished.
      if pause_point is not None and not (i < pause_point <= j):
        print(f"prompt:\n{prompt},\nresponse:\n{response}\n")
        print(f"i:{i} j:{j} pause_point:{pause_point}")
        pause_point = None
      if pause_point is None:
        if allow_fallback_to_last:
          pause_point = j
        else:
          raise ValueError(f"prompt:\n{prompt},\nresponse:\n{response}\n")

    page = paragraphs[i:pause_point]
    pages.append(page)
    if verbose:
      print(f"Paragraph {i}-{pause_point-1}", page)
    i = pause_point
  print(f"[Pagination] Done with {len(pages)} pages")
  return pages

In [8]:
#@title ReadAgent (2) Memory Gisting

# Prompt used by `quality_gisting`; the single placeholder receives the page
# text (the page's paragraphs joined with newlines). Sent to the model verbatim.
prompt_shorten_template = """
Please shorten the following passage.
Just give me a shortened version. DO NOT explain your reason.

Passage:
{}

"""

def quality_gisting(example, pages, word_limit=600, start_threshold=280, verbose=True):
  """Ask the model for a shortened version ("gist") of every page.

  Args:
    example: dict with at least 'article' and 'title'; it is not modified.
    pages: list of pages, each a list of paragraph strings.
    word_limit: not used by this function.
    start_threshold: not used by this function.
    verbose: print each gist, the joined gists and the compression rate.

  Returns:
    A deep copy of `example` extended with 'title', 'word_count',
    'gist_word_count', 'shortened_pages' and 'pages'.
  """
  article = example['article']
  title = example['title']
  word_count = count_words(article)
  print(f"[Gisting][Article {title}], {word_count} words")

  shortened_pages = []
  for page_number, page in enumerate(pages):
    page_text = '\n'.join(page)
    gist = query_model(prompt_shorten_template.format(page_text)).strip()
    shortened_pages.append(gist)
    if verbose:
      print(f"[gist] page {page_number}:", gist, flush=True)

  shortened_article = '\n'.join(shortened_pages)
  gist_word_count = count_words(shortened_article)
  if verbose:
    print("Shortened article:\n", shortened_article, flush=True)

  output = copy.deepcopy(example)
  output['title'] = title
  output['word_count'] = word_count
  output['gist_word_count'] = gist_word_count
  output['shortened_pages'] = shortened_pages
  output['pages'] = pages
  if verbose:
    compression = round(100.0 - gist_word_count/word_count*100, 2)
    print(f"compression rate {compression}% ({gist_word_count}/{word_count})")
  return output

In [9]:
#@title ReadAgent (3) Look-Up

# Prompt used by `quality_parallel_lookup` to ask which pages to re-read.
# Placeholders, in order: the gist text with "<Page i>" headers, the question,
# and the answer options joined with newlines. Sent to the model verbatim.
# NOTE(review): the last example asks for pages 3, 4, 5, 12, 13 and 16 but its
# sample reply reads [3, 3, 4, 12, 13, 16]; left unchanged here because
# editing the prompt may change model behaviour — confirm before fixing.
prompt_lookup_template = """
The following text is what you remembered from reading an article and a question related to it.
You may read 1 to 6 page(s) of the article again to refresh your memory to prepare yourselve for the question.
Please respond with which page(s) you would like to read.
For example, if your only need to read Page 8, respond with \"I want to look up Page [8] to ...\";
if your would like to read Page 7 and 12, respond with \"I want to look up Page [7, 12] to ...\";
if your would like to read Page 2, 3, 7, 15 and 18, respond with \"I want to look up Page [2, 3, 7, 15, 18] to ...\".
if your would like to read Page 3, 4, 5, 12, 13 and 16, respond with \"I want to look up Page [3, 3, 4, 12, 13, 16] to ...\".
DO NOT select more pages if you don't need to.
DO NOT answer the question yet.

Text:
{}

Question:
{}
{}

Take a deep breath and tell me: Which page(s) would you like to read again?
"""

# Prompt used by `quality_parallel_lookup` for the final answer.
# Placeholders, in order: the article text (gists with the looked-up pages
# replaced by their original text), the question, and the answer options
# joined with newlines. Sent to the model verbatim.
prompt_answer_template = '''
Read the following article and answer a multiple choice question.
For example, if (C) is correct, answer with \"Answer: (C) ...\"

Article:
{}

Question:
{}
{}

'''

def quality_parallel_lookup(example, verbose=True):
  """Answer every question of an example with a gist look-up step.

  For each question the model first sees the page gists and names the pages it
  wants to re-read. Those gists are then replaced by the original page text
  and the model is asked the multiple-choice question on the expanded text.

  Args:
    example: dict with 'title', 'word_count', 'pages', 'shortened_pages',
      'questions' and 'options' (the output of `quality_gisting`).
    verbose: print the chosen pages and the expanded article.

  Returns:
    (responses, retrieved, shortened_pages_pidx):
      responses: stripped raw model answer per question (not parsed into a
        choice here).
      retrieved: per question, one string listing the looked-up pages.
      shortened_pages_pidx: the gists, each prefixed with "<Page i>\n".
  """
  title = example['title']
  word_count = example['word_count']
  pages = example['pages']
  shortened_pages = example['shortened_pages']
  questions = example['questions']
  options = example['options']

  print(f"[Look-Up][Article {title}] {word_count} words")

  shortened_pages_pidx = [
      "<Page {}>\n".format(i) + shortened_text
      for i, shortened_text in enumerate(shortened_pages)
  ]
  shortened_article = '\n'.join(shortened_pages_pidx)

  lookup_marker = 'I want to look up Page'
  responses = []
  retrieved = []
  for q, question_options in zip(questions, options):
    print("question: ", q)
    options_i = [f"{ol} {o}" for ol, o in zip(choices, question_options)]
    print("options: ", "\n".join(options_i))
    prompt_lookup = prompt_lookup_template.format(shortened_article, q, '\n'.join(options_i))

    response = query_model(prompt=prompt_lookup).strip()

    page_ids = []
    marker_at = response.lower().find(lookup_marker.lower())
    if marker_at >= 0:
      # Page numbers sit between the marker phrase and the following "to".
      segment = response[marker_at + len(lookup_marker):].split('to', 1)[0]
      # Split on whitespace, commas and brackets so that "[7,12]" (no space
      # after the comma) yields both numbers; it was previously dropped.
      for token in re.split(r'[\s,\[\]]+', segment):
        token = token.strip('.')
        if not token.isdecimal():
          continue
        page_id = int(token)
        if page_id >= len(pages):
          print("Skip invalid page number: ", page_id, flush=True)
        elif page_id not in page_ids:
          # Ignore repeated ids so a page is retrieved only once.
          page_ids.append(page_id)

    if verbose:
      print("Model chose to look up page {}".format(page_ids))

    # Memory expansion after look-up, replacing the target shortened page with the original page
    expanded_shortened_pages = shortened_pages[:]
    for page_id in page_ids:
      expanded_shortened_pages[page_id] = '\n'.join(pages[page_id])

    expanded_shortened_article = '\n'.join(expanded_shortened_pages)
    # NOTE(review): `pages[page_id]` is a list of paragraphs, so it is
    # rendered here with its list repr rather than joined text — confirm that
    # consumers of `retrieved` expect this before changing it.
    retrieved.append('\n\n'.join([f'Passage {page_id}:\n\n{pages[page_id]}' for page_id in page_ids]))
    if verbose:
      print("Expanded shortened article:\n", expanded_shortened_article, flush=True)
    prompt_answer = prompt_answer_template.format(expanded_shortened_article, q, '\n'.join(options_i))

    responses.append(query_model(prompt=prompt_answer).strip())
  return responses, retrieved, shortened_pages_pidx

In [12]:
# Make sure the output directory exists; every path below is written under it.
os.makedirs('quality', exist_ok=True)

for eid, example in enumerate(quality_dev[:10]):
    # Pagination needs one model call per page, so its result is cached on
    # disk and reused on later runs.
    pages_path = f'quality/pages_{eid}.json'
    if not os.path.exists(pages_path):
        pages = quality_pagination(example, verbose=False)
        with open(pages_path, 'w') as f_out:
            json.dump(pages, f_out)
    else:
        with open(pages_path) as f_in:
            pages = json.load(f_in)

    example_with_gists = quality_gisting(example, pages, verbose=False)
    responses, retrieved, shortened_pages_pidx = quality_parallel_lookup(example_with_gists, verbose=False)

    # Open the log once per example (append mode, so re-running this cell adds
    # further records) and write four JSON lines per question.
    with open(f'quality/response_gist_{eid}_log.jsonl', 'a') as f_out:
        for q, ret, res in zip(example['questions'], retrieved, responses):
            records = (
                ['query', q],
                ['menu', shortened_pages_pidx],
                ['retrieval_result', ret],
                ['current_summary', res],
            )
            for record in records:
                f_out.write(json.dumps(record))
                f_out.write('\n')

[Gisting][Article Spaceman on a Spree], 4800 words
[Look-Up][Article Spaceman on a Spree] 4800 words
question:  Why is Si retirement so significant to the Space Exploration Team? 
options:  (A) There aren’t enough working people in the world. They won’t be able to find a replacement.
(B) As one of two remaining spacemen, it would likely mean the defunding and shut down of the Space Exploration Team.
(C) Training new spacemen is costly and time consuming. They won’t have anyone else ready after him.
(D) His retirement may inspire others to stop working as well, which would be hugely detrimental as most people don't feel the drive to work as is.  
question:  What makes Gubelin an outlier in the present day?
options:  (A) He is much older than the rest of the population.
(B) He refuses new operations that could improve his health.
(C) His mind is still active, and he values hard work.
(D) He still wears glasses and value objects like the gold watch given to Si.
question:  What is the main