文件目的：
1. 输入召回的文本数据，寻找到这些文本在原始文本中的位置
2. 扩充召回的文本数据，使得每个召回的文本数据都有一个上下文
3. 寻找扩充后的文本数据在tokenized后的文本中的位置
4. 输出```[{"content": "", "start": 0, "end": 0, "length": 0, "doc_id": 0}]```格式的json文件（start/end 为在分词后文本中的位置，左闭右开），未能定位的句子另存为同名的 `.not_found.json` 文件

In [31]:
import json, os, sys
from pathlib import Path
import re
from typing import List, Optional, Any
import argparse

# from langchain.text_splitter import RecursiveCharacterTextSplitter
import logging

In [32]:
curr_dir = Path().resolve()
book_prompt_dir = curr_dir / "book_prompt_txt"  # book text, already split into paragraphs
retrived_data_dir = curr_dir / "retrived_paragraph"  # retrieved paragraphs (inputs)
output_dir = curr_dir / "processed_data_merged_content_merged"  # where processed JSON is written
tokenized_book_dir = curr_dir / "tokenized_book_prompt"  # tokenized book text (lookup tables)

# Forwarded (via process_retrived_data) to find_and_expand, which drops this many
# CHARACTERS from the start / end of each book file before searching it.
start_ignore = 562
end_ignore = 10


# Warn about missing input directories; each message names the directory checked.
if not book_prompt_dir.exists():
    print(f"The directory {book_prompt_dir.name} does not exist")
if not retrived_data_dir.exists():
    print(f"The directory {retrived_data_dir.name} does not exist")
if not output_dir.exists():
    # Results are written here later, so create it instead of only warning.
    print(f"The directory {output_dir.name} does not exist, creating it")
    output_dir.mkdir(parents=True, exist_ok=True)
if not tokenized_book_dir.exists():
    print(f"The directory {tokenized_book_dir.name} does not exist")

# Template for one output record; process_retrived_data copies json_template[0]
# for every sentence it processes.
json_template = [{"content": "", "start": 0, "end": 0, "doc_id": 0}]

In [33]:
def process_retrived_file(file_name):
    """Read a retrieval-result file and split it into paragraphs.

    Paragraphs are separated by a blank line (two consecutive newlines).
    Paragraphs that are empty or whitespace-only are dropped, and each
    remaining paragraph is stripped of leading/trailing whitespace.

    Args:
        file_name: Path (or str) of the UTF-8 encoded text file to read.

    Returns:
        List of non-empty, stripped paragraph strings in file order.
    """
    # Explicit encoding: the files contain Chinese text, and the platform
    # default encoding is not guaranteed to be UTF-8.
    with open(file_name, "r", encoding="utf-8") as f:
        data = f.read()
    return [chunk.strip() for chunk in data.split("\n\n") if chunk.strip()]


def find_and_expand(target_str, folder_path, expansion_length=1000, start_ignore=0, end_ignore=0):
    """Find ``target_str`` in the files under ``folder_path`` and return it with context.

    Files are visited in sorted order (directories and file names), so the
    result is deterministic. For each file, ``start_ignore`` characters are
    dropped from the beginning and ``end_ignore`` characters from the end
    before searching. The target is first searched for literally. If that
    fails, a fuzzy regex is used that allows up to 10 arbitrary characters
    (newlines included) between consecutive characters of the target.

    The first match is widened by ``expansion_length // 2`` characters on each
    side. The window is then snapped outward to whole lines: the start moves
    back to the beginning of its line and the end moves forward to the next
    newline (exclusive) or the end of the trimmed text.

    Args:
        target_str: Text to look for.
        folder_path: Root directory that is walked recursively.
        expansion_length: Total number of context characters to add around the match.
        start_ignore: Characters to skip at the start of every file.
        end_ignore: Characters to skip at the end of every file.

    Returns:
        ``(expanded_str, file_name)`` for the first file containing a match,
        where ``file_name`` is the base name of the file. Returns
        ``(None, None)`` if no file matches.
    """
    # NOTE(review): this bounded-wildcard pattern can backtrack heavily on long
    # targets that almost match; consider a bounded fuzzy matcher if it is slow.
    target_regex = re.compile(".{0,10}".join(map(re.escape, target_str)), re.DOTALL)
    half = expansion_length // 2

    for root, dirs, files in os.walk(folder_path):
        dirs.sort()  # in-place sort makes os.walk descend in a deterministic order
        for file in sorted(files):
            file_path = os.path.join(root, file)
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    content = f.read()
            except (OSError, UnicodeDecodeError) as e:
                print(f"Error reading {file_path}: {e}")
                continue
            if len(content) < start_ignore + end_ignore:
                print(f"Error reading {file_path}: File {file_path} is too short")
                continue
            # Explicit stop index: content[a:-0] would be empty when end_ignore == 0.
            content = content[start_ignore:len(content) - end_ignore]

            # Literal search first; fall back to the fuzzy regex.
            index = content.find(target_str)
            if index != -1:
                match_end = index + len(target_str)
            else:
                match = target_regex.search(content)
                if match is None:
                    continue
                # A fuzzy match can be longer than target_str; use its real span.
                index, match_end = match.span()

            start = max(0, index - half)
            end = min(len(content), match_end + half)
            # Snap to line boundaries so the context keeps whole sentences.
            line_start = content.rfind("\n", 0, start)
            start = 0 if line_start == -1 else line_start + 1
            line_end = content.find("\n", end)
            # No newline after `end`: keep everything up to the end (exclusive slice).
            end = len(content) if line_end == -1 else line_end
            return content[start:end], file

    print(f"Target string not found in any file. {target_str}")
    return None, None


def read_lookup_table(file_path):
    """Parse a tokenized-book lookup file into ``(idx, token_id, token_string)`` tuples.

    Every record is matched as ``<idx>, <token_id>, '<token_string>'``. The
    token string may contain single quotes and may span several lines. It can
    never extend past a quote that is the last character on a line; such a
    quote always closes the string.

    Args:
        file_path: Path of the UTF-8 encoded lookup file.

    Returns:
        List of ``(int idx, int token_id, str token_string)`` in file order.
        Inside each token string, every ``"\\n'"`` sequence is replaced with ``"\\n"``.
    """
    record_re = re.compile(r"(\d+),\s*(\d+),\s*'((?:[^']|'(?!$))*)'", re.MULTILINE)
    with open(file_path, "r", encoding="utf-8") as fh:
        raw_text = fh.read()
        # NOTE(review): the newline+quote -> newline replacement presumably undoes a
        # quote the dump writes at the start of continuation lines; confirm with the writer.
        return [
            (int(raw_idx), int(raw_token_id), token.replace("\n'", "\n"))
            for raw_idx, raw_token_id, token in record_re.findall(raw_text)
        ]

def find_text_indices(text, lookup_table, n=20):
    """Locate ``text`` in a tokenized document and return its token index span.

    The token strings from ``lookup_table`` are concatenated, and every
    character is mapped back to the ``idx`` of the token it came from. As a
    shortcut, only the first ``n`` and last ``n`` characters of ``text`` are
    searched for directly; the tail is searched at or after the head. If
    either lookup fails, the whole ``text`` is matched with a fuzzy regex that
    allows up to 10 arbitrary characters between consecutive characters.

    Args:
        text: Text to locate.
        lookup_table: Sequence of ``(idx, token_id, token_string)`` tuples,
            e.g. as returned by ``read_lookup_table``.
        n: Number of leading/trailing characters used for the direct search.

    Returns:
        ``(start_idx, end_idx)``: the ``idx`` values of the tokens containing
        the first and the last matched character. Both are inclusive.

    Raises:
        Exception: if ``text`` cannot be located.
    """
    pieces = []
    position_idx_record = []  # position_idx_record[c] = idx of the token holding char c
    for idx, _, token_string in lookup_table:
        pieces.append(token_string)
        position_idx_record.extend([idx] * len(token_string))
    cat_text = "".join(pieces)

    start_txt = text[:n]
    end_txt = text[-n:]

    # Character offsets into cat_text; end_char is inclusive and -1 means "not found".
    start_char = cat_text.find(start_txt)
    end_char = -1
    if start_char != -1:
        # The tail must not lie before the head.
        tail_pos = cat_text.find(end_txt, start_char)
        if tail_pos != -1:
            end_char = tail_pos + len(end_txt) - 1

    if start_char == -1 or end_char == -1:
        # Fuzzy fallback on the whole text gives a consistent (start, end) pair.
        target_regex = re.compile(".{0,10}".join(map(re.escape, text)), re.DOTALL)
        match = target_regex.search(cat_text)
        if match:
            start_char = match.start()
            end_char = match.end() - 1  # match.end() is exclusive

    if start_char == -1 or end_char < start_char:
        raise Exception(f"Failed to find text: {text[:10]}")
    return position_idx_record[start_char], position_idx_record[end_char]


def merge_strings(str1: str, str2: str) -> str:
    """Merge two text fragments, collapsing the text they share.

    1. If one string contains the other, the containing string is returned.
    2. Otherwise, the longest suffix of ``str1`` that is also a prefix of
       ``str2`` is collapsed, with ``str1`` first. Failing that, the longest
       suffix of ``str2`` that is a prefix of ``str1`` is collapsed, with
       ``str2`` first.
    3. If there is no overlap at all, the strings are concatenated (``str1`` first).
    """
    if str1 in str2:
        return str2
    if str2 in str1:
        return str1

    max_overlap = min(len(str1), len(str2))
    # Try the LONGEST overlap first: a short overlap (e.g. one shared character)
    # is usually coincidental, and choosing it would duplicate the real overlap.
    for size in range(max_overlap, 0, -1):
        if str1.endswith(str2[:size]):
            return str1 + str2[size:]
    for size in range(max_overlap, 0, -1):
        if str2.endswith(str1[:size]):
            return str2 + str1[size:]

    return str1 + str2

def merge_same_content(data: List[dict]) -> List[dict]:
    """Merge records of the same document whose ``[start, end)`` ranges overlap or touch.

    Records are grouped by ``doc_id``. Groups are emitted in the order their
    ``doc_id`` first appears in ``data``, and within a group records are
    sorted by ``start``. A record whose ``start`` is not greater than the
    previous merged record's ``end`` is folded into it. The merged record's
    ``end`` becomes the max of both ends, ``length`` is recomputed as
    ``end - start``, and ``content`` is combined with ``merge_strings``.

    The input dicts are never modified; the result holds shallow copies.

    Args:
        data: Records with at least ``doc_id``, ``start``, ``end`` and ``content`` keys.

    Returns:
        The merged records.
    """
    grouped_data: dict = {}
    for item in data:
        grouped_data.setdefault(item["doc_id"], []).append(item)

    result = []
    for items in grouped_data.values():
        merged: List[dict] = []
        for item in sorted(items, key=lambda x: x["start"]):
            if not merged or item["start"] > merged[-1]["end"]:
                # Copy so that merging below never mutates the caller's records.
                merged.append(dict(item))
            else:
                last = merged[-1]
                last["end"] = max(last["end"], item["end"])
                last["length"] = last["end"] - last["start"]
                last["content"] = merge_strings(last["content"], item["content"])
        result.extend(merged)

    return result

def process_retrived_data(question_path: Path, store_path: Path, tokenized_book_dir,book_prompt_dir, start_ignore: int = 0, end_ignore: int = 0):
    """Locate, expand and index every retrieved paragraph of one question file.

    For each paragraph read by ``process_retrived_file(question_path)``:

    1. ``find_and_expand`` finds it in ``book_prompt_dir`` and widens it to whole lines.
    2. The matching lookup table ``tokenized_book_dir / f"answer_{file_name}"``
       is parsed. Each table is read only once per call.
    3. ``find_text_indices`` maps the expanded text to token positions.

    Located records are merged with ``merge_same_content`` and written as JSON
    to ``store_path``. Paragraphs that could not be located in either step are
    written to ``store_path`` with the suffix ``.not_found.json``.

    Args:
        question_path: Text file with the retrieved paragraphs.
        store_path: Output JSON path; its parent directory is created if needed.
        tokenized_book_dir: Directory (a ``Path``) holding the ``answer_*`` lookup files.
        book_prompt_dir: Directory searched by ``find_and_expand``.
        start_ignore: Characters skipped at the start of each book file.
        end_ignore: Characters skipped at the end of each book file.
    """
    processed_data = []
    not_found_sentence = []
    lookup_tables = {}  # book file name -> parsed lookup table

    for sentence in process_retrived_file(question_path):
        content, doc_file = find_and_expand(sentence, book_prompt_dir, start_ignore=start_ignore, end_ignore=end_ignore)
        if content is None:
            not_found_sentence.append(sentence)
            continue

        if doc_file not in lookup_tables:
            lookup_tables[doc_file] = read_lookup_table(tokenized_book_dir / f"answer_{doc_file}")
        try:
            start_idx, end_idx = find_text_indices(content, lookup_tables[doc_file])
        except Exception as e:  # find_text_indices signals "not found" with a bare Exception
            print(e)
            not_found_sentence.append(sentence)
            continue

        meta_data = json_template[0].copy()
        meta_data["content"] = content
        meta_data["doc_id"] = int(doc_file.split(".")[0])  # numeric file stem
        meta_data["start"] = start_idx
        meta_data["end"] = end_idx + 1  # half-open [start, end)
        # Same definition merge_same_content uses for merged records.
        meta_data["length"] = meta_data["end"] - meta_data["start"]
        processed_data.append(meta_data)

    processed_data = merge_same_content(processed_data)

    store_path.parent.mkdir(parents=True, exist_ok=True)
    with store_path.open("w", encoding="utf-8") as f:
        json.dump(processed_data, f, ensure_ascii=False, indent=4)
    with store_path.with_suffix(".not_found.json").open("w", encoding="utf-8") as f:
        json.dump(not_found_sentence, f, ensure_ascii=False, indent=4)

In [34]:
# Process retrieval files 0.txt ... 13.txt, writing <n>.json (and
# <n>.not_found.json) for each one into output_dir.
for file_no in range(14):
    question_file = retrived_data_dir / f"{file_no}.txt"
    result_file = output_dir / f"{file_no}.json"
    process_retrived_data(
        question_file,
        result_file,
        tokenized_book_dir,
        book_prompt_dir,
        start_ignore=start_ignore,
        end_ignore=end_ignore,
    )
