## 数据预处理

**SFT训练**所用的原始数据采用Alpaca数据集格式，下载后的数据结构如下：

```python
Dataset({
    features: ['instruction', 'input', 'output'],
    num_rows: 48818
})
```

然后需要转换成GPT风格的对话格式，也就是`messages`列表：

```python
{
    "messages": [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": response},
    ]
}
```

最后用tokenizer转换成tokens形式：

```python
Dataset({
    features: ['input_ids', 'labels', 'prompt_len'],
    num_rows: 48818
})
```

按步骤运行下面的代码，即可完成数据集的处理。

In [None]:
# 下载数据集(我习惯下载到本地保存)
!modelscope download --dataset AI-ModelScope/alpaca-gpt4-data-zh --local_dir /data/lxy/diffusion/data/course/alpaca-test

In [1]:
# 查看数据集格式
from datasets import load_dataset

data_path = "/data/lxy/diffusion/data/course/alpaca-test"

# If loading raises an error, remember to delete dataset_infos.json from the dataset directory.
dataset = load_dataset(data_path, split="train")
print(dataset)

  from .autonotebook import tqdm as notebook_tqdm


Dataset({
    features: ['instruction', 'input', 'output'],
    num_rows: 48818
})


In [2]:
# 转换成gpt格式
def _build_alpaca_prompt(instruction: str, input_text: str | None) -> str:
    """
    Construct a clean text prompt from Alpaca fields.

    We intentionally *do not* include Anthropic-style role tags (e.g., "Human:", "Assistant:")
    in the returned prompt, to mirror the return shape of `load_hh_rlhf_dataset` which removes
    those tags from the prompt it returns.
    """
    instruction = (instruction or "").strip()
    input_text = (input_text or "").strip()

    if input_text:
        # Keep instruction and input separated by a blank line for readability.
        return f"{instruction}\n\n{input_text}"
    else:
        return instruction

def map_fn(example):
    """
    Convert one Alpaca row into a two-turn ``messages`` conversation.

    Args:
        example: a mapping that may contain ``instruction``, ``input`` and
            ``output``; absent keys fall back to an empty string.

    Returns:
        dict with a single ``messages`` key holding a user turn (the prompt
        built from instruction + input) followed by an assistant turn (the
        stripped output).
    """
    user_turn = {
        "role": "user",
        "content": _build_alpaca_prompt(
            example.get("instruction", ""), example.get("input", "")
        ),
    }
    # `or ""` guards against an explicit None stored under "output".
    raw_output = example.get("output", "") or ""
    assistant_turn = {"role": "assistant", "content": raw_output.strip()}
    return {"messages": [user_turn, assistant_turn]}

# Swap the original columns for the single `messages` column produced by map_fn.
dataset = dataset.map(
    map_fn,
    num_proc=4,
    remove_columns=dataset.column_names,
)
print(dataset)

Dataset({
    features: ['messages'],
    num_rows: 48818
})


In [None]:
# 下载模型文件（注意：--local_dir 需要与下方代码中的 model_path 保持一致）
!modelscope download --model GSAI-ML/LLaDA-8B-Base --local_dir /your/path/of/model

In [4]:
# 转换成tokens形式
from transformers import AutoTokenizer
from datasets import Dataset

# Local directory holding the model files; the tokenizer is loaded from it.
model_path = "/data/lxy/diffusion/llada-8b"
tokenizer = AutoTokenizer.from_pretrained(model_path)

def default_mdlm_sft_map_fn(row, *, tokenizer, mask_prompt_loss: bool = True) -> dict:
    """
    Tokenize one conversation into SFT ``input_ids`` and ``labels``.

    The whole conversation is rendered with the tokenizer's chat template to
    obtain ``input_ids``; ``labels`` starts out as a copy of them. When
    ``mask_prompt_loss`` is set, everything except the last message is rendered
    again (this time with the generation prompt appended) and that many leading
    label positions are overwritten with -100.

    Args:
        row: a dataset row with `messages`
        tokenizer: an object exposing ``apply_chat_template``
        mask_prompt_loss: whether to set the prompt positions of ``labels`` to -100

    Returns:
        dict with ``input_ids`` and ``labels``; ``prompt_len`` is included only
        when ``mask_prompt_loss`` is true.
    """
    conversation = row["messages"]
    full_ids = tokenizer.apply_chat_template(
        conversation, tokenize=True, add_generation_prompt=False
    )
    targets = full_ids.copy()
    example = {"input_ids": full_ids, "labels": targets}

    if not mask_prompt_loss:
        return example

    # Prompt = every message but the last, plus the generation prompt.
    prefix_ids = tokenizer.apply_chat_template(
        conversation[:-1], tokenize=True, add_generation_prompt=True
    )
    prefix_len = len(prefix_ids)
    # `targets` is the same list object stored under "labels", so this edits it in place.
    targets[:prefix_len] = [-100] * prefix_len
    example["prompt_len"] = prefix_len
    return example

import os

# Tokenize every conversation, drop the original columns, then shuffle with a fixed seed.
final_datasets: Dataset = dataset.map(
    default_mdlm_sft_map_fn,
    fn_kwargs=dict(tokenizer=tokenizer, mask_prompt_loss=True),
    remove_columns=dataset.column_names,
    num_proc=16,
    desc="Tokenizing",
).shuffle(seed=42)
print(final_datasets)

# Write the tokenized dataset to disk, creating the target directory if needed.
output_data_path = "/data/lxy/diffusion/data/course/alpaca-gpt-test"
os.makedirs(output_data_path, exist_ok=True)
final_datasets.save_to_disk(output_data_path)

Dataset({
    features: ['input_ids', 'labels', 'prompt_len'],
    num_rows: 48818
})


Saving the dataset (1/1 shards): 100%|██████████| 48818/48818 [00:00<00:00, 278937.65 examples/s]


In [None]:
# 查看数据集
from datasets import load_from_disk

# Reload the saved dataset from disk; the bare name shows it as the cell output.
ds = load_from_disk(output_data_path)
ds

Dataset({
    features: ['input_ids', 'labels', 'prompt_len'],
    num_rows: 48818
})

`output_data_path` 目录下保存的数据集就是SFT训练所需要的数据集。