In [1]:
import pyarrow.parquet as pq
import pyarrow as pa
import ujson
import numpy as np
from rich import progress

# 1. 处理预训练阶段数据
## 1.1 处理Wiki数据

In [2]:
origin_wiki_file = './data/wiki.simple.txt'

# Read the whole wiki dump into memory, one list element per line
# (line terminators are kept by readlines()).
with open(origin_wiki_file, 'r', encoding='utf-8') as f:
    lines = f.readlines()

In [None]:
# Preview the first five raw lines.
lines[0:5]

In [None]:
# Total number of raw lines read.
len(lines)

合并词条和内容

In [5]:
# Merge each headword line with the content lines that follow it into one
# text item per entry, terminated by the EOS token.
items, content = [], []
key_word, kw_line_idx = '', 0
content_start = False  # marks the start of an entry's content (not used below)

bos_token, eos_token = '[BOS]', '[EOS]'
for i, line in enumerate(lines):

    line_strip = line.strip()

    # A headword line ends with a colon (ASCII or full-width). Remember the
    # headword and its position; the next line decides whether a new entry
    # really starts here.
    if line_strip and line_strip[-1] in (':', '：'):
        key_word = line_strip[:-1]
        kw_line_idx = i
        continue

    # The headword shows up again on the line right after it: close the
    # previous entry and start a new one prefixed with "<headword>：".
    if i == kw_line_idx + 1 and key_word in line_strip:
        txt = ''.join(content)

        if txt:
            items.append(f"{txt}{eos_token}")

        content = [f"{key_word}："]

    content.append(line)

# Flush the entry that is still being accumulated when the input ends.
# Without this the last entry of the file is silently dropped.
txt = ''.join(content)
if txt:
    items.append(f"{txt}{eos_token}")


In [None]:
# Number of merged entries.
len(items)

In [None]:
# Inspect one merged entry.
items[20]

将Wiki数据合并为长度固定的行

In [8]:
def split_txt_cropus_to_chunk_data(texts: list[str], batch_size: int=512 ** 2, max_len: int=320, window_size: int = 2) -> list[str]:
    """Concatenate texts and cut them into overlapping fixed-length chunks.

    Texts are accumulated into a buffer; once the buffered character count
    reaches ``batch_size`` (or the last text has been added) the buffer is
    joined into one string and sliced into windows of ``max_len`` characters.
    Consecutive windows start ``max_len - window_size`` characters apart, so
    each window repeats the last ``window_size`` characters of the previous
    one. Windows never span two buffers, and the last window of a buffer may
    be shorter than ``max_len``.

    Args:
        texts: Input strings, processed in order.
        batch_size: Buffered character count that triggers a flush.
        max_len: Maximum length of each chunk, in characters.
        window_size: Number of characters shared by consecutive chunks.

    Returns:
        The list of chunks, in input order.

    Raises:
        ValueError: If ``window_size`` is not smaller than ``max_len``. A
            non-positive stride would otherwise either fail inside ``range``
            or silently produce no chunks at all.
    """
    stride = max_len - window_size
    if stride <= 0:
        raise ValueError(
            f"window_size ({window_size}) must be smaller than max_len ({max_len})"
        )

    buffer, buffer_len = [], 0
    chunk_data = []
    last_idx = len(texts) - 1

    for idx, line in enumerate(texts):
        buffer_len += len(line)
        buffer.append(line)

        if buffer_len >= batch_size or idx == last_idx:
            buffer_txt = ''.join(buffer)

            # Slide a max_len-wide window over the buffer with the given stride.
            chunk_data.extend(
                buffer_txt[start: start + max_len]
                for start in range(0, len(buffer_txt), stride)
            )

            buffer, buffer_len = [], 0

    return chunk_data

In [9]:
# Chunk the merged wiki entries with the default parameters and report the chunk count.
chunk_data = split_txt_cropus_to_chunk_data(items)
print(len(chunk_data))

2045355


In [None]:
# Preview the first three chunks.
chunk_data[0: 3]

In [11]:
# Wrap the chunks in a single-column ('text') Arrow table and write it to parquet.
tb = pa.Table.from_arrays([chunk_data], names=['text'])
# compression='GZIP'
pq.write_table(table=tb, where='./data/wiki_chunk_320_2.2M.parquet', row_group_size=50000, data_page_size=50000, )

In [12]:
# Keep the first 100,000 chunks as the wiki sample; the commented-out line is a
# random-sampling alternative (1,000,000 chunks) that needs numpy.
# wiki_samples = np.random.choice(chunk_data, size=100_0000).tolist()
wiki_samples = chunk_data[0: 10_0000]

In [None]:
# Number of sampled chunks.
len(wiki_samples)

In [None]:
# Inspect the last five sampled chunks.
wiki_samples[-5:]

##  1.2 处理百度百科数据库

In [1]:
import pyarrow.parquet as pq
import pyarrow as pa
import ujson
from unicodedata import normalize

In [2]:
def split_txt_cropus_to_chunk_data(texts: list[str], batch_size: int=512 ** 2, max_len: int=320, window_size: int = 2) -> list[str]:
    """Concatenate texts and cut them into overlapping fixed-length chunks.

    Texts are accumulated into a buffer; once the buffered character count
    reaches ``batch_size`` (or the last text has been added) the buffer is
    joined into one string and sliced into windows of ``max_len`` characters.
    Consecutive windows start ``max_len - window_size`` characters apart, so
    each window repeats the last ``window_size`` characters of the previous
    one. Windows never span two buffers, and the last window of a buffer may
    be shorter than ``max_len``.

    Args:
        texts: Input strings, processed in order.
        batch_size: Buffered character count that triggers a flush.
        max_len: Maximum length of each chunk, in characters.
        window_size: Number of characters shared by consecutive chunks.

    Returns:
        The list of chunks, in input order.

    Raises:
        ValueError: If ``window_size`` is not smaller than ``max_len``. A
            non-positive stride would otherwise either fail inside ``range``
            or silently produce no chunks at all.
    """
    stride = max_len - window_size
    if stride <= 0:
        raise ValueError(
            f"window_size ({window_size}) must be smaller than max_len ({max_len})"
        )

    buffer, buffer_len = [], 0
    chunk_data = []
    last_idx = len(texts) - 1

    for idx, line in enumerate(texts):
        buffer_len += len(line)
        buffer.append(line)

        if buffer_len >= batch_size or idx == last_idx:
            buffer_txt = ''.join(buffer)

            # Slide a max_len-wide window over the buffer with the given stride.
            chunk_data.extend(
                buffer_txt[start: start + max_len]
                for start in range(0, len(buffer_txt), stride)
            )

            buffer, buffer_len = [], 0

    return chunk_data

In [3]:
# Path of the Baidu Baike dump (read line by line as JSON by the cells below).
bd_baike_563w_file = './data/563w_baidubaike.json'
# Processed entries waiting to be written out.
baike_items = []
eos_token = '[EOS]' 
# Entries longer than this many characters are truncated by the processing cell.
max_len = 320
# batch_size: entries per output parquet file; batch_cnt: index of the next output file.
batch_size, batch_cnt = 200_0000, 0

In [None]:
# Peek at the first record: read one line, NFKC-normalize it, parse it as JSON and print it.
with open(bd_baike_563w_file, 'r', encoding='utf-8') as f:
    line = f.readline()
    line = normalize('NFKC', line)
    item = ujson.loads(line)
    print(item)

In [None]:
def process_none(s: str) -> str:
    """Return ``s`` unchanged, or an empty string when it is falsy (e.g. ``None``)."""
    return s if s else ''


def save_baike_batch(items: list[str], batch_idx: int) -> list[str]:
    """Chunk ``items`` and write the chunks to the parquet file of batch ``batch_idx``.

    Args:
        items: Processed Baike entries (each already ends with the EOS token).
        batch_idx: Index used in the output file name.

    Returns:
        The chunk list that was written.
    """
    chunks = split_txt_cropus_to_chunk_data(items)
    table = pa.Table.from_arrays([chunks], names=['text'])

    file_name = f'./data/baike_chunk_320_5.6M_{batch_idx}.parquet'
    pq.write_table(table=table, where=file_name, row_group_size=50000, )

    print(f"save to {file_name}")
    return chunks


# Start from a clean state so that re-running this cell does not continue
# from the previous run's batch counter or leftover entries.
baike_items, batch_cnt = [], 0

with open(bd_baike_563w_file, 'r', encoding='utf-8') as f:

    for line in f:
        item = ujson.loads(line)
        cur_txt, cur_len = [], 0

        # Entries without a title are skipped.
        if not item['title']: continue

        temp_txt = f"{item['title']}：{process_none(item['summary'])}"

        cur_len += len(temp_txt)
        cur_txt.append(temp_txt)

        for section in item['sections']:

            # Already long enough: the remaining sections are dropped.
            if cur_len > max_len:
                break

            title = f"{section['title']}：" if section['title'] else ""
            temp_txt = f"{title}{process_none(section['content'])}"

            cur_len += len(temp_txt)
            cur_txt.append(temp_txt)

        # NFKC normalization handles \u3000 / \xa0 and converts full-width
        # characters to half-width.
        temp_txt = normalize('NFKC', ''.join(cur_txt))

        if len(temp_txt) > max_len:
            # Starting at max_len, look for the first '。' or '！' and cut right
            # after it; if none is found the text is kept in full.
            n, i = len(temp_txt), max_len
            while i < n and temp_txt[i] not in ('。', '！'):
                i += 1
            temp_txt = temp_txt[: i + 1]

        # Append the EOS token.
        temp_txt = f"{temp_txt}{eos_token}"

        baike_items.append(temp_txt)

        # A full batch has been collected: write it out and start a new one.
        if len(baike_items) % batch_size == 0:
            chunk_data = save_baike_batch(baike_items, batch_cnt)
            batch_cnt += 1
            baike_items = []

# Write the final, partially filled batch.
if len(baike_items) > 0:
    chunk_data = save_baike_batch(baike_items, batch_cnt)
    batch_cnt += 1
    baike_items = []

In [None]:
# Count the rows of the three Baike parquet files.
file_list = [
    f'./data/baike_chunk_320_5.6M_{idx}.parquet' for idx in range(3)
]

line_cnt = 0
for file in file_list:
    # The row count is stored in the parquet footer, so only the metadata is
    # read instead of loading the whole table into memory.
    line_cnt += pq.read_metadata(file).num_rows

print(f"bake all lines: {line_cnt}")

In [None]:
# Inspect a few chunks of the most recently produced chunk list.
chunk_data[20: 25]

## 1.3 处理bell指令数据
尝试在预训练阶段加入prompt指令数据，也就是在预训练阶段加入部分SFT数据

In [2]:
# Accumulators for the pre-training texts and the held-out evaluation texts.
train_data = []
eval_data = []
# Upper bound on the number of evaluation samples collected.
eval_size = 1_0000
# Character-length threshold: texts of this length or longer are not used for training.
max_len = 400
# NOTE(review): absolute, machine-specific path — adjust to the local checkout before running.
root = 'D:/GitHub/ChatLM-mini-Chinese/data/raw_data'

In [3]:
# Build "<human turn>\n<other turn>" texts from two-turn conversations.
belle_conversation_file = root + '/bell_open_source/train_3.5M_CN.json'

with open(belle_conversation_file, 'r', encoding='utf-8') as f:
    for line in f:
        item = ujson.loads(line)
        conversation = item['conversations']

        # Only conversations with exactly two turns are used.
        if len(conversation) != 2:
            continue

        first_turn, second_turn = conversation
        if first_turn['from'] == 'human':
            prompt, reply = first_turn['value'], second_turn['value']
        else:
            prompt, reply = second_turn['value'], first_turn['value']
        txt = f"{prompt}\n{reply}"
        txt_len = len(txt)

        # Short enough: goes to the training set.
        if txt_len < max_len:
            train_data.append(txt)
            continue

        # Too long for training. Texts at most 7 characters over the limit are
        # randomly sampled into the evaluation set while it still has room;
        # everything else is dropped.
        if txt_len < max_len + 8 and len(eval_data) < eval_size and np.random.rand() <= 0.12:
            eval_data.append(txt)

In [4]:
# Sizes of the evaluation and training sets collected so far.
print(len(eval_data), len(train_data))

5429 1084177


In [5]:

# Build texts from instruction-style records: instruction, optional input, output.
belle_instruction_files = [
    root + '/bell_open_source/train_2M_CN.json',
    root + '/bell_open_source/Belle_open_source_1M.json',
]

for file in belle_instruction_files:
    with open(file, 'r', encoding='utf-8') as f:
        for line in f:
            item = ujson.loads(line)

            has_input = item['input'].strip() != ''
            if has_input:
                txt = f"{item['instruction']}\n{item['input']}\n{item['output']}"
            else:
                txt = f"{item['instruction']}\n{item['output']}"
            txt_len = len(txt)

            # Texts at most 7 characters over the limit are randomly sampled
            # into the evaluation set while it still has room.
            in_eval_window = max_len <= txt_len < max_len + 8
            if in_eval_window and len(eval_data) < eval_size and np.random.rand() > 0.75:
                eval_data.append(txt)
                continue

            # Only non-empty texts below the length limit are kept for training.
            if 0 < txt_len < max_len:
                train_data.append(txt)

In [6]:
# Sizes of the evaluation and training sets after all files were processed.
print(len(eval_data), len(train_data))

10000 3150704


In [None]:
# Preview the first five training texts.
print(train_data[0:5])


In [8]:
# Write the training texts as a single-column ('text') parquet file; max_len is part of the file name.
tb = pa.Table.from_arrays([train_data], names=['text'])
# compression='GZIP'
pq.write_table(table=tb, where=f'./data/bell_pretrain_{max_len}_3M.parquet', row_group_size=20480, data_page_size=20480, )

In [9]:
# Write the evaluation texts as a single-column ('text') parquet file; max_len is part of the file name.
tb = pa.Table.from_arrays([eval_data], names=['text'])
# compression='GZIP'
pq.write_table(table=tb, where=f'./data/pretrain_eval_{max_len}_1w.parquet', row_group_size=20480, data_page_size=20480, )

# 2. 处理sft阶段数据

In [6]:
# Collect the SFT records whose instruction + output text is non-empty and
# shorter than 320 characters. The full parsed record is kept, not the text.
sft_source_file = './data/sft_0.8M_CN.json'

lines = []
with open(sft_source_file, 'r', encoding='utf-8') as f:
    for line in f:
        item = ujson.loads(line)

        txt = f"{item['instruction']}{item['output']}"

        if 0 < len(txt) < 320:
            lines.append(item)

In [7]:
# Number of SFT records kept.
print(len(lines))

726475


In [8]:
# Build an Arrow table from the list of record dicts and write it to parquet.
tb = pa.Table.from_pylist(lines)
# compression='GZIP'
pq.write_table(table=tb, where='./data/sft_train_data.parquet', row_group_size=20480, data_page_size=20480, )

## 统计 token数量

In [8]:
from pyarrow import parquet as pq
from transformers import PreTrainedTokenizerFast
from tqdm import tqdm

In [9]:
# Load the fast tokenizer saved under ./model_save/tokenizer/.
tokenizer = PreTrainedTokenizerFast.from_pretrained('./model_save/tokenizer/')

In [10]:
# 字符数量
files = [
    './data/baike_chunk_320_5.6M_0.parquet', 
    './data/baike_chunk_320_5.6M_1.parquet', 
    './data/baike_chunk_320_5.6M_2.parquet', 
    './data/bell_pretrain_400_3M.parquet',
    # './data/pretrain_eval_400_1w.parquet',
]

total_char = 0
for file in files: 
    pf = pq.read_table(file)
    for row in pf['text']:
        total_char += len(row.as_py())

In [None]:
# Total number of characters across the listed files.
print(total_char)

In [None]:
def count_tokens(texts: list[str]) -> int:
    """Tokenize ``texts`` in one call and return the total number of token ids.

    Returns 0 for an empty list without calling the tokenizer.
    """
    if not texts:
        return 0
    input_ids = tokenizer(texts, return_attention_mask=False)['input_ids']
    return sum(len(ids) for ids in input_ids)


# Token count over the same files, tokenizing in batches of up to 10,000 texts.
total_token = 0
for file in files:
    pf = pq.read_table(file)
    n = pf.num_rows
    buffer = []

    for row in tqdm(pf['text'], total=n):
        buffer.append(row.as_py())

        if len(buffer) >= 10000:
            total_token += count_tokens(buffer)
            buffer = []

    # Remainder of this file (fewer than 10,000 texts).
    total_token += count_tokens(buffer)
    buffer = []

In [None]:
# Total number of tokens across the listed files.
print(total_token)