Skip to content

JunnYu/FLASHQuad_pytorch

Repository files navigation

FLASHQuad_pytorch & FLASH_pytorch

pytorch implement of FLASHQuad and FLASH

Describtion

个人实现pytorch版本的《Transformer Quality in Linear Time》

存在的问题

  • A = square(relu(qk / seq_len + bias))感觉不对劲,假设训练的时候都是在seq_len=512的长度上进行的,如果预测的时候seq_len=16时,A的结果会发生很大的变化。
  • embedding部分和MLM head不确定是否使用的是ScaleNorm,不确定是否使用到dropout
  • 发现当前代码训练出来的模型结果不理想,n-1层的输出和n层的输出差距不大。

更新

    # 之前的代码.
    bias = self.rel_pos_bias(seq_len)
    kernel = torch.square(torch.relu(qk / seq_len + bias))
    # 更新后的代码.
    self.max_position_embeddings = 512
    bias = self.rel_pos_bias(self.max_position_embeddings)[:, :seq_len, :seq_len]
    kernel = torch.square(torch.relu(qk / self.max_position_embeddings + bias))

Usage

# flashquad
from flash import FLASHQuadConfig, FLASHQuadModel
import torch
config = FLASHQuadConfig()
model = FLASHQuadModel(config)
model.eval()
input_ids = torch.randint(0,12000,(4,128))
with torch.no_grad():
    outputs = model(input_ids=input_ids, output_attentions=True, output_hidden_states=True)
    print(outputs)

# flash
from flash import FLASHConfig, FLASHModel
import torch
config = FLASHConfig()
model = FLASHModel(config)
model.eval()
input_ids = torch.randint(0, 12000, (4, 128))
with torch.no_grad():
    outputs = model(
        input_ids=input_ids, output_attentions=True, output_hidden_states=True
    )
    print(outputs)

Pretrain

准备数据

CLUECorpusSmall 数据集处理教程(摘抄自paddlenlp)

数据集简介:可用于语言建模、预训练或生成型任务等,数据量超过14G,近4000个定义良好的txt文件、50亿个字。主要部分来自于nlp_chinese_corpus项目 包含如下子语料库(总共14G语料):新闻语料news2016zh_corpus.zip, 社区互动语料webText2019zh_corpus.zip,维基百科语料wiki2019zh_corpus.zip,评论数据语料comment2019zh_corpus.zip

数据集下载: 用户可以通过官方github网页下载,https://github.com/CLUEbenchmark/CLUECorpus2020 。同时,为方便用户,我们也提供了aistudio数据集下载地址。part1part2。使用aistudio版本的数据,下载好后,可以核对md5值:

> md5sum ./*
 8a8be341ebce39cfe9524fb0b46b08c5  ./comment2019zh_corpus.zip
 4bdc2c941a7adb4a061caf273fea42b8  ./news2016zh_corpus.zip
 fc582409f078b10d717caf233cc58ddd  ./webText2019zh_corpus.zip
 157dacde91dcbd2e52a60af49f710fa5  ./wiki2019zh_corpus.zip

(1) 解压文件

unzip comment2019zh_corpus.zip -d  clue_corpus_small_14g/comment2019zh_corpus
unzip news2016zh_corpus.zip    -d  clue_corpus_small_14g/news2016zh_corpus  
unzip webText2019zh_corpus.zip -d  clue_corpus_small_14g/webText2019zh_corpus
unzip wiki2019zh_corpus.zip    -d  clue_corpus_small_14g/wiki2019zh_corpus  

(2) 将txt文件转换为jsonl格式

python trans_to_json.py  --input_path ./clue_corpus_small_14g --output_path clue_corpus_small_14g.jsonl
mkdir data #创建data文件夹
mv clue_corpus_small_14g.jsonl ./data #将jsonl放进该目录

(3) 使用rjieba进行中文分词,会得到data/refids.txtdata/reftext.txt两个文件,并组合data/refids.txtdata/reftext.txt两个文件保存成huggingfacedataset

python run_chinese_ref.py  --model_name junnyu/roformer_chinese_char_base --input_path ./data/clue_corpus_small_14g.jsonl

开始训练(small版本模型)

TRAIN_DIR=./clue_small_wwm_data
OUTPUT_DIR=./wwm_flash_small/
BATCH_SIZE=32
ACCUMULATION=4
LR=1e-4
python run_mlm_wwm.py \
    --do_train \
    --tokenizer_name junnyu/roformer_chinese_char_base \
    --train_dir $TRAIN_DIR \
    --output_dir $OUTPUT_DIR \
    --logging_dir $OUTPUT_DIR/logs \
    --per_device_train_batch_size $BATCH_SIZE \
    --gradient_accumulation_steps $ACCUMULATION \
    --learning_rate $LR \
    --weight_decay 0.01 \
    --adam_epsilon 1e-6 \
    --max_steps 250000 \
    --warmup_steps 5000 \
    --logging_steps 100 \
    --save_steps 5000 \
    --seed 2022 \
    --max_grad_norm 3.0 \
    --dataloader_num_workers 6 \
    --fp16

MLM测试

# flashquad
import torch
from flash import FLASHQuadForMaskedLM
from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained("junnyu/flashquad_small_wwm_cluecorpussmall")
model = FLASHQuadForMaskedLM.from_pretrained("junnyu/flashquad_small_wwm_cluecorpussmall")
model.eval()
text = "天气预报说今天的天[MASK]很好,那么我[MASK]一起去公园玩吧!"
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    pt_outputs = model(**inputs).logits[0]

pt_outputs_sentence = "pytorch: "
for i, id in enumerate(tokenizer.encode(text)):
    if id == tokenizer.mask_token_id:
        val,idx = pt_outputs[i].softmax(-1).topk(k=5)
        tokens = tokenizer.convert_ids_to_tokens(idx)
        new_tokens = []
        for v,t in zip(val.cpu(),tokens):
            new_tokens.append(f"{t}+{round(v.item(),4)}")
        pt_outputs_sentence += "[" + "||".join(new_tokens) + "]"
    else:
        pt_outputs_sentence += "".join(
            tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True))
print(pt_outputs_sentence)
# pytorch: 天气预报说今天的天[气+0.9948||空+0.0011||色+0.0007||候+0.0004||势+0.0003]很好,那么我[就+0.4915||们+0.4186||也+0.0753||还+0.0021||都+0.0016]一起去公园玩吧!

# flash
import torch
from flash import FLASHForMaskedLM
from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained("junnyu/flash_small_wwm_cluecorpussmall")
model = FLASHForMaskedLM.from_pretrained("junnyu/flash_small_wwm_cluecorpussmall")
model.eval()
text = "天气预报说今天的天[MASK]很好,那么我[MASK]一起去公园玩吧!"
inputs = tokenizer(text, return_tensors="pt", padding="max_length", max_length=512,  return_token_type_ids=False) #这里必须是512,不然结果可能不对。
with torch.no_grad():
    pt_outputs = model(**inputs).logits[0]

pt_outputs_sentence = "pytorch: "
for i, id in enumerate(tokenizer.encode(text)):
    if id == tokenizer.mask_token_id:
        val,idx = pt_outputs[i].softmax(-1).topk(k=5)
        tokens = tokenizer.convert_ids_to_tokens(idx)
        new_tokens = []
        for v,t in zip(val.cpu(),tokens):
            new_tokens.append(f"{t}+{round(v.item(),4)}")
        pt_outputs_sentence += "[" + "||".join(new_tokens) + "]"
    else:
        pt_outputs_sentence += "".join(
            tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True))
print(pt_outputs_sentence)
# pytorch: 天气预报说今天的天[气+0.9938||天+0.0017||空+0.0011||晴+0.0007||阳+0.0002]很好,那么我[们+0.9367||就+0.0554||也+0.0041||俩+0.0005||还+0.0004]一起去公园玩吧!

Tnews分类

Tips

不怎么确定实现的对不对,如果代码有错误的话,请帮我指出来,谢谢~

About

FLASHQuad_pytorch

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published