## 读取数据

In [1]:
import os

# Single location used for everything the Hugging Face `datasets` library caches.
HF_CACHE_DIR = "/root/lanyun-tmp/hf/cache"

# Export the cache location *before* importing `datasets`: the library reads
# HF_DATASETS_CACHE once, when it is first imported, so assigning the variable
# after the import does not change the cache directory it uses.
os.environ['HF_DATASETS_CACHE'] = HF_CACHE_DIR

from datasets import load_dataset  # noqa: E402  (must follow the env setup)

# Load the dataset stored at the local path below. The cache directory is also
# passed explicitly so this call does not depend on the environment variable.
dataset = load_dataset(
    path="/root/lanyun-tmp/hf/wikipedia-zh-mnbvc",
    cache_dir=HF_CACHE_DIR,
    num_proc=36,
)

  from .autonotebook import tqdm as notebook_tqdm


## 处理数据

In [None]:
import os
from tqdm import tqdm


def process_segment(
    sub_dataset,
    output_dir,
    shard_idx=0,
    *,
    eos_token="<|endoftext|>",
    max_seq_length=512,
):
    """Chunk the text of every page in a shard and write the chunks to a file.

    For each page, the ``"内容"`` value of every entry under ``"段落"`` is
    concatenated into one string. That string is cut into pieces of at most
    ``max_seq_length - len(eos_token)`` characters, and each piece is written
    followed by ``eos_token`` and a newline. Lengths are counted in characters
    (``len`` of a ``str``), not in tokenizer tokens.

    Args:
        sub_dataset: Iterable of pages; each page must support ``.get``.
        output_dir: Directory for the output file (created if missing).
        shard_idx: Shard index, used in the output file name and the
            progress-bar label.
        eos_token: Marker appended to every chunk.
        max_seq_length: Maximum characters per written sample, marker
            included.

    Returns:
        The path of the file that was written.

    Raises:
        ValueError: If ``max_seq_length`` leaves no room for text next to
            ``eos_token``.
    """
    chunk_size = max_seq_length - len(eos_token)
    if chunk_size <= 0:
        # A non-positive step would otherwise fail inside range() below.
        raise ValueError(
            "max_seq_length must be greater than len(eos_token); "
            f"got max_seq_length={max_seq_length}, eos_token={eos_token!r}"
        )

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"pretrain_{shard_idx:04d}.txt")

    # Both the file and the progress bar are context-managed so they are
    # released even when a page raises part-way through the shard.
    with open(output_path, "w", encoding="utf-8") as writer, tqdm(
        desc=f"Processing Shard {shard_idx}",
        unit="page",
        mininterval=1,  # refresh at most once per second to keep overhead low
    ) as progress_bar:
        for page in sub_dataset:
            # 1. Merge the paragraph contents of the page into one string.
            try:
                # `or []` also covers a page whose "段落" value is None.
                paragraphs = page.get("段落") or []
                text = "".join(
                    str(p["内容"]) for p in paragraphs
                    if "内容" in p  # ignore paragraphs without content
                )
            except (KeyError, TypeError) as e:
                # Malformed paragraph entries: report and skip the page.
                print(f"跳过无效段落：{str(e)}")
                continue

            # 2. Write fixed-size chunks. Every chunk holds at most chunk_size
            # characters, so chunk + eos_token never exceeds max_seq_length
            # and the marker is never truncated.
            # NOTE(review): chunks are written verbatim, so a newline inside
            # the page text also ends a line in the output file - confirm
            # that readers do not rely on one sample per line.
            for start_idx in range(0, len(text), chunk_size):
                chunk = text[start_idx:start_idx + chunk_size]
                writer.write(chunk + eos_token + "\n")
            progress_bar.update(1)

    return output_path


# Multi-process driver: hands one shard of the "train" split to each worker.
def parallel_chunking(dataset, output_dir, workers=32):
    """Process the ``"train"`` split in parallel, one shard per worker process.

    The split is divided into ``workers`` shards with ``.shard(...)``; each
    shard is submitted to a process pool and handled by ``process_segment``
    with the shard index as ``shard_idx``.

    Args:
        dataset: Mapping that provides a ``"train"`` entry exposing
            ``shard(num_shards=..., index=...)``.
        output_dir: Output directory forwarded to every worker.
        workers: Number of shards, which is also the size of the process pool.

    Returns:
        A list with the return value of each ``process_segment`` call,
        ordered by shard index.

    Raises:
        Exception: Whatever a worker raised is re-raised here by
            ``Future.result()``.
    """
    from concurrent.futures import ProcessPoolExecutor

    train_split = dataset["train"]
    with ProcessPoolExecutor(max_workers=workers) as executor:
        # Submit every shard first so all workers start before we block.
        futures = [
            executor.submit(
                process_segment,
                sub_dataset=train_split.shard(num_shards=workers, index=i),
                output_dir=output_dir,
                shard_idx=i,
            )
            for i in range(workers)
        ]

        # Wait for completion and hand the per-shard results back to the
        # caller instead of discarding them.
        return [f.result() for f in futures]


# Guard the entry point: worker processes started with the "spawn" method
# re-import the main module, and without this check each of them would try to
# launch its own process pool.
if __name__ == "__main__":
    # Holds whatever parallel_chunking returns for the "./data/" run.
    processed_data = parallel_chunking(dataset, output_dir="./data/")