Skip to content

THUDM/Chinese-Transformer-XL

master
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Code

Latest commit

 

Git stats

Files

Permalink
Failed to load latest commit information.
Type
Name
Latest commit message
Commit time
 
 
 
 
 
 
 
 
mpu
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Chinese-Transformer-XL

Under construction

本项目提供了智源研究院"文汇" 预训练模型Chinese-Transformer-XL的预训练和文本生成代码。[应用主页] [模型下载]

数据

本模型使用了智源研究院发布的中文预训练语料WuDaoCorpus 。具体地,我们使用了WuDaoCorpus中来自百度百科+搜狗百科(133G)、知乎(131G)、百度知道(38G)的语料,一共303GB数据。

模型

本模型使用了GPT-3 的训练目标,同时使用能够更好地处理长序列建模的Transformer-XL 替代了GPT中的Transformer。模型的结构与GPT-3 2.7B(32层,隐表示维度2560,每层32个注意力头)基本相同,因为Transformer-XL的结构改动,模型参数增加到了29亿。

结果

为了验证模型的生成能力,我们在中文的开放域长文问答上进行了评测。我们从知乎 上随机选择了100个不同领域的、不在训练语料中的问题。对每个问题,由人类测试员对一个高赞同数回答、3个模型生成的回答和3个CPM 生成的回答在流畅度、信息量、相关度、总体四个维度进行打分。测评结果如下:

模型 流畅度(1-5) 信息量(1-5) 相关度(1-5) 总体(1-10)
CPM 2.66 2.47 2.36 4.32
文汇 3.44 3.25 3.21 5.97
人类答案 3.80 3.61 3.67 6.85

可以看到相比起CPM,"文汇"更接近人类所写的高赞答案。

安装

根据requirements.txt安装pytorch等基础依赖

pip install -r requirements.txt

如果要finetune模型参数,还需要安装DeepSpeed

DS_BUILD_OPS=1 pip install deepspeed

也可以使用我们提供的Docker镜像

推理

首先下载模型的checkpoint ,目录结构如下

.
└─ txl-2.9B
       └─ mp_rank_00_model_states.pt

然后运行交互式生成脚本

bash scripts/generate_text.sh ./txl-2.9B

Finetune

模型的finetune基于使用DeepSpeed。首先在scripts/ds_finetune_gpt_2.9B.sh中修改NUM_WORKERSNUM_GPUS_PER_WORKER 为使用的节点数目和每个节点的GPU数量。如果使用多机训练的话,还要修改HOST_FILE_PATH 为hostfile的路径(DeepSpeed使用OpenMPI风格的hostfile )。

然后运行finetune脚本

bash scripts/ds_finetune_gpt_2.9B.sh ./txl-2.9B ./data.json

其中./txl-2.9B为checkpoint目录。./data.json为finetune数据,格式为jsonl文件 ,每条数据的格式为{"prompt": .., "text": ...}。其中prompt为生成的context,text为生成的内容。

如果你在finetune的遇到了OOM错误(一般是因为GPU数量或者显存不足导致的),可以尝试在scripts/ds_config_2.9B_finetune.jsonzero_optimization部分添加"cpu_offload": true,来开启ZeRO-Offload 以减少显存消耗。

模型并行

如果你的显存大小比较有限,可以尝试使用模型并行来减少显存消耗。我们提供的模型checkpoint是在单卡上运行的。首先使用change_mp.py来对hceckpoint进行切分

python change_mp.py ./txl-2.9B 2

其中2表示2路模型并行。在推理和finetune的时候,将脚本中的MP_SIZE改为2,然后使用./txl-2.9B_MP2作为运行脚本时的checkpoint路径。

引用

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published