Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pageindex/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ max_page_num_each_node: 10
max_token_num_each_node: 20000
if_add_node_id: "yes"
if_add_node_summary: "no"
if_add_doc_description: "yes"
if_add_doc_description: "yes"
if_add_node_text: "no"
179 changes: 94 additions & 85 deletions pageindex/page_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
from .utils import *
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
import argparse


################### check title in page #########################################################
def check_title_appearance(item, page_list, start_index=1, model=None):
async def check_title_appearance(item, page_list, start_index=1, model=None):
title=item['title']
if 'physical_index' not in item or item['physical_index'] is None:
return {'list_index': item.get('list_index'), 'answer': 'no', 'title':title, 'page_number': None}
Expand All @@ -37,7 +36,7 @@ def check_title_appearance(item, page_list, start_index=1, model=None):
}}
Directly return the final JSON structure. Do not output anything else."""

response = ChatGPT_API(model=model, prompt=prompt)
response = await ChatGPT_API_async(model=model, prompt=prompt)
response = extract_json(response)
if 'answer' in response:
answer = response['answer']
Expand All @@ -46,9 +45,9 @@ def check_title_appearance(item, page_list, start_index=1, model=None):
return {'list_index': item['list_index'], 'answer': answer, 'title': title, 'page_number': page_number}


def check_title_appearance_in_start(title, page_text, model=None, logger=None):
async def check_title_appearance_in_start(title, page_text, model=None, logger=None):
prompt = f"""
You will be given given the current section title and the current page_text.
You will be given the current section title and the current page_text.
Your job is to check if the current section starts in the beginning of the given page_text.
If there are other contents before the current section title, then the current section does not start in the beginning of the given page_text.
If the current section title is the first content in the given page_text, then the current section starts in the beginning of the given page_text.
Expand All @@ -65,36 +64,40 @@ def check_title_appearance_in_start(title, page_text, model=None, logger=None):
}}
Directly return the final JSON structure. Do not output anything else."""

response = ChatGPT_API(model=model, prompt=prompt)
response = await ChatGPT_API_async(model=model, prompt=prompt)
response = extract_json(response)
if logger:
logger.info(f"Response: {response}")
if 'start_begin' in response:
return response['start_begin']
else:
return 'no'
return response.get("start_begin", "no")


def check_title_appearance_in_start_parallel(structure, page_list, model=None, logger=None):
async def check_title_appearance_in_start_concurrent(structure, page_list, model=None, logger=None):
if logger:
logger.info(f"Checking title appearance in start parallel")
with ThreadPoolExecutor(max_workers=10) as executor:
future_to_item = {
executor.submit(check_title_appearance_in_start, item['title'], page_list[item['physical_index']-1][0], model=model, logger=logger): item
for item in structure
}

# Process completed futures and attach results to items
for future in as_completed(future_to_item):
item = future_to_item[future]
try:
result = future.result()
item['appear_start'] = result
except Exception as e:
if logger:
logger.error(f"Error processing item {item['title']}: {str(e)}")
item['appear_start'] = 'no'
logger.info("Checking title appearance in start concurrently")

# skip items without physical_index
for item in structure:
if item.get('physical_index') is None:
item['appear_start'] = 'no'

# only for items with valid physical_index
tasks = []
valid_items = []
for item in structure:
if item.get('physical_index') is not None:
page_text = page_list[item['physical_index'] - 1][0]
tasks.append(check_title_appearance_in_start(item['title'], page_text, model=model, logger=logger))
valid_items.append(item)

results = await asyncio.gather(*tasks, return_exceptions=True)
for item, result in zip(valid_items, results):
if isinstance(result, Exception):
if logger:
logger.error(f"Error checking start for {item['title']}: {result}")
item['appear_start'] = 'no'
else:
item['appear_start'] = result

return structure


Expand Down Expand Up @@ -505,14 +508,15 @@ def generate_toc_continue(toc_content, part, model="gpt-4o-2024-11-20"):
For the title, you need to extract the original title from the text, only fix the space inconsistency.

The provided text contains tags like <physical_index_X> and <physical_index_X> to indicate the start and end of page X. \


For the physical_index, you need to extract the physical index of the start of the section from the text. Keep the <physical_index_X> format.

The response should be in the following format.
[
{
"structure": <structure index, "x.x.x" or None> (string),
"structure": <structure index, "x.x.x"> (string),
"title": <title of the section, keep the original title>,
"physical_index": "<physical_index_X> (keep the format)" or None
"physical_index": "<physical_index_X> (keep the format)"
},
...
]
Expand All @@ -538,13 +542,15 @@ def generate_toc_init(part, model=None):

The provided text contains tags like <physical_index_X> and <physical_index_X> to indicate the start and end of page X.

For the physical_index, you need to extract the physical index of the start of the section from the text. Keep the <physical_index_X> format.

The response should be in the following format.
[
{
"structure": <structure index, "x.x.x" or None> (string),
{{
"structure": <structure index, "x.x.x"> (string),
"title": <title of the section, keep the original title>,
"physical_index": "<physical_index_X> (keep the format)" or None
},
"physical_index": "<physical_index_X> (keep the format)"
}},

],

Expand Down Expand Up @@ -738,15 +744,15 @@ def single_toc_item_index_fixer(section_title, content, model="gpt-4o-2024-11-20



def fix_incorrect_toc(toc_with_page_number, page_list, incorrect_results, start_index=1, model=None, logger=None):
async def fix_incorrect_toc(toc_with_page_number, page_list, incorrect_results, start_index=1, model=None, logger=None):
print(f'start fix_incorrect_toc with {len(incorrect_results)} incorrect results')
incorrect_indices = {result['list_index'] for result in incorrect_results}

end_index = len(page_list) + start_index - 1

incorrect_results_and_range_logs = []
# Helper function to process and check a single incorrect item
def process_and_check_item(incorrect_item):
async def process_and_check_item(incorrect_item):
list_index = incorrect_item['list_index']
# Find the previous correct item
prev_correct = None
Expand Down Expand Up @@ -786,28 +792,27 @@ def process_and_check_item(incorrect_item):
# Check if the result is correct
check_item = incorrect_item.copy()
check_item['physical_index'] = physical_index_int
check_result = check_title_appearance(check_item, page_list, start_index, model)
check_result = await check_title_appearance(check_item, page_list, start_index, model)

return {
'list_index': list_index,
'title': incorrect_item['title'],
'physical_index': physical_index_int,
'is_valid': check_result['answer'] == 'yes'
}


results = []
with ThreadPoolExecutor() as executor:
future_to_item = {executor.submit(process_and_check_item, item): item for item in incorrect_results}
for future in as_completed(future_to_item):
item = future_to_item[future]

try:
result = future.result()
results.append(result)
except Exception as exc:
print(f"Processing item {item} generated an exception: {exc}")


# Process incorrect items concurrently
tasks = [
process_and_check_item(item)
for item in incorrect_results
]
results = await asyncio.gather(*tasks, return_exceptions=True)
for item, result in zip(incorrect_results, results):
if isinstance(result, Exception):
print(f"Processing item {item} generated an exception: {result}")
continue
results = [result for result in results if not isinstance(result, Exception)]

# Update the toc_with_page_number with the fixed indices and check for any invalid results
invalid_results = []
for result in results:
Expand All @@ -827,7 +832,7 @@ def process_and_check_item(incorrect_item):



def fix_incorrect_toc_with_retries(toc_with_page_number, page_list, incorrect_results, start_index=1, max_attempts=3, model=None, logger=None):
async def fix_incorrect_toc_with_retries(toc_with_page_number, page_list, incorrect_results, start_index=1, max_attempts=3, model=None, logger=None):
print('start fix_incorrect_toc')
fix_attempt = 0
current_toc = toc_with_page_number
Expand All @@ -836,7 +841,7 @@ def fix_incorrect_toc_with_retries(toc_with_page_number, page_list, incorrect_re
while current_incorrect:
print(f"Fixing {len(current_incorrect)} incorrect results")

current_toc, current_incorrect = fix_incorrect_toc(current_toc, page_list, current_incorrect, start_index, model, logger)
current_toc, current_incorrect = await fix_incorrect_toc(current_toc, page_list, current_incorrect, start_index, model, logger)

fix_attempt += 1
if fix_attempt >= max_attempts:
Expand All @@ -849,7 +854,7 @@ def fix_incorrect_toc_with_retries(toc_with_page_number, page_list, incorrect_re


################### verify toc #########################################################
def verify_toc(page_list, list_result, start_index=1, N=None, model=None):
async def verify_toc(page_list, list_result, start_index=1, N=None, model=None):
print('start verify_toc')
# Find the last non-None physical_index
last_physical_index = None
Expand Down Expand Up @@ -879,16 +884,12 @@ def verify_toc(page_list, list_result, start_index=1, N=None, model=None):
item_with_index['list_index'] = idx # Add the original index in list_result
indexed_sample_list.append(item_with_index)

# Run checks in parallel
results = []
with ThreadPoolExecutor(max_workers=10) as executor:
future_to_item = {
executor.submit(check_title_appearance, item, page_list, start_index, model): item
for item in indexed_sample_list
}

for future in as_completed(future_to_item):
results.append(future.result())
# Run checks concurrently
tasks = [
check_title_appearance(item, page_list, start_index, model)
for item in indexed_sample_list
]
results = await asyncio.gather(*tasks)

# Process results
correct_count = 0
Expand All @@ -910,7 +911,7 @@ def verify_toc(page_list, list_result, start_index=1, N=None, model=None):


################### main process #########################################################
def meta_processor(page_list, mode=None, toc_content=None, toc_page_list=None, start_index=1, opt=None, logger=None):
async def meta_processor(page_list, mode=None, toc_content=None, toc_page_list=None, start_index=1, opt=None, logger=None):
print(mode)
print(f'start_index: {start_index}')

Expand All @@ -922,7 +923,7 @@ def meta_processor(page_list, mode=None, toc_content=None, toc_page_list=None, s
toc_with_page_number = process_no_toc(page_list, start_index=start_index, model=opt.model, logger=logger)

toc_with_page_number = [item for item in toc_with_page_number if item.get('physical_index') is not None]
accuracy, incorrect_results = verify_toc(page_list, toc_with_page_number, start_index=start_index, model=opt.model)
accuracy, incorrect_results = await verify_toc(page_list, toc_with_page_number, start_index=start_index, model=opt.model)

logger.info({
'mode': 'process_toc_with_page_numbers',
Expand All @@ -932,26 +933,26 @@ def meta_processor(page_list, mode=None, toc_content=None, toc_page_list=None, s
if accuracy == 1.0 and len(incorrect_results) == 0:
return toc_with_page_number
if accuracy > 0.6 and len(incorrect_results) > 0:
toc_with_page_number, incorrect_results = fix_incorrect_toc_with_retries(toc_with_page_number, page_list, incorrect_results,start_index=start_index, max_attempts=3, model=opt.model, logger=logger)
toc_with_page_number, incorrect_results = await fix_incorrect_toc_with_retries(toc_with_page_number, page_list, incorrect_results,start_index=start_index, max_attempts=3, model=opt.model, logger=logger)
return toc_with_page_number
else:
if mode == 'process_toc_with_page_numbers':
return meta_processor(page_list, mode='process_toc_no_page_numbers', toc_content=toc_content, toc_page_list=toc_page_list, start_index=start_index, opt=opt, logger=logger)
return await meta_processor(page_list, mode='process_toc_no_page_numbers', toc_content=toc_content, toc_page_list=toc_page_list, start_index=start_index, opt=opt, logger=logger)
elif mode == 'process_toc_no_page_numbers':
return meta_processor(page_list, mode='process_no_toc', start_index=start_index, opt=opt, logger=logger)
return await meta_processor(page_list, mode='process_no_toc', start_index=start_index, opt=opt, logger=logger)
else:
raise Exception('Processing failed')


def process_large_node_recursively(node, page_list, opt=None, logger=None):
node_page_list = page_list[node['start_index']-1:node['end_index']-1]
async def process_large_node_recursively(node, page_list, opt=None, logger=None):
node_page_list = page_list[node['start_index']-1:node['end_index']]
token_num = sum([page[1] for page in node_page_list])

if node['end_index'] - node['start_index'] > opt.max_page_num_each_node and token_num >= opt.max_token_num_each_node:
print('large node:', node['title'], 'start_index:', node['start_index'], 'end_index:', node['end_index'], 'token_num:', token_num)

node_toc_tree = meta_processor(node_page_list, mode='process_no_toc', start_index=node['start_index'], opt=opt, logger=logger)
node_toc_tree = check_title_appearance_in_start_parallel(node_toc_tree, page_list, model=opt.model, logger=logger)
node_toc_tree = await meta_processor(node_page_list, mode='process_no_toc', start_index=node['start_index'], opt=opt, logger=logger)
node_toc_tree = await check_title_appearance_in_start_concurrent(node_toc_tree, page_list, model=opt.model, logger=logger)

if node['title'].strip() == node_toc_tree[0]['title'].strip():
node['nodes'] = post_processing(node_toc_tree[1:], node['end_index'])
Expand All @@ -961,17 +962,20 @@ def process_large_node_recursively(node, page_list, opt=None, logger=None):
node['end_index'] = node_toc_tree[0]['start_index']

if 'nodes' in node and node['nodes']:
for child_node in node['nodes']:
tasks = [
process_large_node_recursively(child_node, page_list, opt, logger=logger)
for child_node in node['nodes']
]
await asyncio.gather(*tasks)

return node

def tree_parser(page_list, opt, logger=None):
check_toc_result = check_toc(page_list, opt)
async def tree_parser(page_list, opt, doc=None, logger=None):
check_toc_result = check_toc(page_list, opt)
logger.info(check_toc_result)

if check_toc_result['toc_content'] is not None and check_toc_result['page_index_given_in_toc'] == 'yes':
toc_with_page_number = meta_processor(
if check_toc_result.get("toc_content") and check_toc_result["toc_content"].strip() and check_toc_result["page_index_given_in_toc"] == "yes":
toc_with_page_number = await meta_processor(
page_list,
mode='process_toc_with_page_numbers',
start_index=1,
Expand All @@ -980,18 +984,21 @@ def tree_parser(page_list, opt, logger=None):
opt=opt,
logger=logger)
else:
toc_with_page_number = meta_processor(
toc_with_page_number = await meta_processor(
page_list,
mode='process_no_toc',
start_index=1,
opt=opt,
logger=logger)

toc_with_page_number = add_preface_if_needed(toc_with_page_number)
toc_with_page_number = check_title_appearance_in_start_parallel(toc_with_page_number, page_list, model=opt.model, logger=logger)
toc_with_page_number = await check_title_appearance_in_start_concurrent(toc_with_page_number, page_list, model=opt.model, logger=logger)
toc_tree = post_processing(toc_with_page_number, len(page_list))
for node in toc_tree:
tasks = [
process_large_node_recursively(node, page_list, opt, logger=logger)
for node in toc_tree
]
await asyncio.gather(*tasks)

return toc_tree

Expand All @@ -1012,13 +1019,15 @@ def page_index_main(doc, opt=None):
logger.info({'total_page_number': len(page_list)})
logger.info({'total_token': sum([page[1] for page in page_list])})

structure = tree_parser(page_list, opt, logger=logger)
structure = asyncio.run(tree_parser(page_list, opt, doc=doc, logger=logger))
if opt.if_add_node_id == 'yes':
write_node_id(structure)
if opt.if_add_node_summary == 'yes':
add_node_text(structure, page_list)
asyncio.run(generate_summaries_for_structure(structure, model=opt.model))
remove_structure_text(structure)
remove_structure_text(structure)
if opt.if_add_node_text == 'yes':
add_node_text_with_labels(structure, page_list)
if opt.if_add_doc_description == 'yes':
doc_description = generate_doc_description(structure, model=opt.model)
return {
Expand All @@ -1033,7 +1042,7 @@ def page_index_main(doc, opt=None):


def page_index(doc, model=None, toc_check_page_num=None, max_page_num_each_node=None, max_token_num_each_node=None,
if_add_node_id=None, if_add_node_summary=None, if_add_doc_description=None):
if_add_node_id=None, if_add_node_summary=None, if_add_doc_description=None, if_add_node_text=None):

user_opt = {
arg: value for arg, value in locals().items()
Expand Down
Loading