# 数据处理脚本

调用bge—m3和langchain的数据处理脚本，将数据处理成模型训练所需的格式。
输入
- input_dir: 数据集目录
- output_dir: 输出目录
- tokenizer_dir: 分词器目录
- 输出格式
```json
[
    {
        "content": "text",
        "start": 0,
        "end": 10,
        "dense": "matrix",
        "sparse": "matrix",
        "vecs": "matrix"
    }
]
```

In [1]:
from pathlib import Path
import os
import sys
import re
import json
# import nbimporter
from sentence_splitor import ChineseRecursiveTextSplitter
import pickle
from transformers import AutoTokenizer
from FlagEmbedding import BGEM3FlagModel
# from find_corresponding_token_start_end import find_text_indices, read_lookup_table

data_input_dir = Path('./raw_datas/3body') # 三体书
data_output_dir = Path('./bgem3_output') # 映射表存放位置
tokenizer_dir = Path('../internlm2-7B') 
tokenized_book_prompt_dir = Path('./output_datas/tokenized_3body_without_prompt') # 三体书分词后的存放位置
tokenized_book_prompt_file = tokenized_book_prompt_dir / 'tokenized_3body.txt'
chunk_size = 250
chunk_overlap = 50
# check not exist then create
if not data_output_dir.exists():
    data_output_dir.mkdir()
if not tokenizer_dir.exists(): 
    raise Exception('tokenizer not found')
if not data_input_dir.exists():
    raise Exception('input_data not found')
if not tokenized_book_prompt_dir.exists():
    raise Exception('tokenized_book_prompt_dir not found')

In [2]:
template ={
        "content": "text",
        "start": 0,
        "end": 10,
        "dense_vecs": "matrix",
        "lexical_weights": "matrix",
        "colbert_vecs": "matrix"
    }

In [3]:
proxy = 'http://127.0.0.1:20171'
os.environ['http_proxy'] = proxy
os.environ['HTTP_PROXY'] = proxy
os.environ['https_proxy'] = proxy
os.environ['HTTPS_PROXY'] = proxy
os.environ['all_proxy'] = proxy

In [4]:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, trust_remote_code=True)
text_splitter = ChineseRecursiveTextSplitter(
    keep_separator=True,
    is_separator_regex=True,
    chunk_size=chunk_size,
    chunk_overlap=chunk_overlap
)
model = BGEM3FlagModel('BAAI/bge-m3',  use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation

Fetching 23 files:   0%|          | 0/23 [00:00<?, ?it/s]

RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
def read_lookup_table(file_path):
    lookup_table = []
    # 定义正则表达式来匹配 idx, token_id, 和可能跨越多行的 token_string
    pattern = re.compile(r"(\d+),\s*(\d+),\s*'((?:[^']|'(?!$))*)'", re.MULTILINE)
    with open(file_path, "r", encoding="utf-8") as file:
        content = file.read()
        # 使用正则表达式匹配所有行
        matches = pattern.findall(content)
        for match in matches:
            idx, token_id, token_string = match
            # 替换代表换行符的单引号
            token_string = token_string.replace("\n'", "\n")
            lookup_table.append((int(idx), int(token_id), token_string))
    cat_text = ""
    position_idx_record = [] # 记录每个char在lookup_table中的idx
    for idx, _, token_string in lookup_table:
        cat_text += token_string
        position_idx_record+=[idx]*len(token_string)
    return lookup_table, cat_text, position_idx_record

def find_text_indices(text, lookup_table, cat_text, position_idx_record, n=20):
    # 取巧只匹配段落的前20个和后20个字符来定位整段的位置
    # 初始化 start_idx 和 end_idx 为 None
    start_idx = end_idx = None
    # # cat lookup_table[3][:]
    # cat_text = ""
    # position_idx_record = [] # 记录每个char在lookup_table中的idx
    # for idx, _, token_string in lookup_table:
    #     cat_text += token_string
    #     position_idx_record+=[idx]*len(token_string)

    # 通过头，尾各取20个字符，查找text在cat_text中的位置
    start_txt = text[:n]
    end_txt = text[-n:]

    start_idx = cat_text.find(start_txt)
    # 如果直接搜索失败，则尝试使用正则表达式搜索
    if start_idx == -1:
        target_regex = re.compile(".{0,10}".join(map(re.escape, text)), re.DOTALL)
        match = target_regex.search(cat_text)
        if match:
            start_idx = match.start()
    
    end_idx = cat_text.find(end_txt) + len(end_txt) - 1
    # 如果直接搜索失败，则尝试使用正则表达式搜索
    if end_idx == -1:
        target_regex = re.compile(".{0,10}".join(map(re.escape, text)), re.DOTALL)
        match = target_regex.search(cat_text)
        if match:
            end_idx = match.end()
    
    if start_idx != -1 and end_idx != -1:
        start_idx = position_idx_record[start_idx]
        end_idx = position_idx_record[end_idx]

        return start_idx, end_idx

    else:
        # raise Exception(f"Failed to find text: {text[:10]}")
        print(f"Failed to find text: {text[:10]}")
        return text, None

# def get_text_and_pos_from_lookup_table(lookup_table, start, end):
#     cat_text = ""
#     for idx, _, token_string in lookup_table:
#         cat_text += token_string
#         position_idx_record+=[idx]*len(token_string)
    

In [None]:
# walk through all files in input_data
for root, dirs, files in os.walk(data_input_dir):
    for file in files:
        if file.endswith('.txt'):
            with open(Path(root) / file, 'r', encoding='utf-8') as f:
                text = f.read()
                # split text into chunks
                look_up_table, cat_text, position_idx_record = read_lookup_table(tokenized_book_prompt_file)
                processed_chunk_list = []
                not_found_list = []
                chunks = text_splitter.split_text(text)
                for i, chunk in enumerate(chunks):
                    meta = template.copy()
                    start, end = find_text_indices(chunk, look_up_table, cat_text, position_idx_record)
                    if end == None:
                        not_found_list.append(chunk)
                        continue
                    output = model.encode(chunk, return_dense=True, return_sparse=True, return_colbert_vecs=True)
                    meta['content'] = chunk
                    meta['start'] = start
                    meta['end'] = end
                    meta['dense_vecs'] = output['dense_vecs']
                    meta['lexical_weights'] = output['lexical_weights'] # sparse_vecs is the same as lexical_weights
                    # meta['colbert_vecs'] = output['colbert_vecs']
                    meta['idx'] = i
                    meta['doc_id'] = file
                    processed_chunk_list.append(meta)
            # save to pickle
            with open(data_output_dir / f'{file}_no_colbert.pkl', 'wb') as f:
                pickle.dump(processed_chunk_list, f)
            
            with open(data_output_dir / 'not_found.txt', 'w', encoding='utf-8') as f:
                for chunk in not_found_list:
                    f.write(chunk + '\n')
                    


Failed to find text: 荒原依旧，但V装具感


In [None]:
# tokenize the book and store the tokenized book prompt
for root, dirs, files in os.walk(data_input_dir):
    for file in files:
        if file.endswith('.txt'):
            with open(Path(root) / file, 'r', encoding='utf-8') as f:
                text = f.read()
                encoded_input = tokenizer(text, return_tensors='pt', add_special_tokens=False)
                token_ids = encoded_input['input_ids'][0]
                tokens = [tokenizer.decode([token_id]) for token_id in token_ids]

            with open(tokenized_book_prompt_file , 'w', encoding='utf-8') as file:
                for i, (token_id, token) in enumerate(zip(token_ids, tokens)):
                    file.write(f"{i}, {token_id}, '{token}'\n")

In [None]:
# read the pickled files
with open(data_output_dir / 'data.pkl', 'rb') as f:
    data = pickle.load(f)
    print(data)

FileNotFoundError: [Errno 2] No such file or directory: 'bgem3_output/data.pkl'