fairseqによる機械翻訳(英日: 英語→日本語)

In [None]:
# Mount Google Drive at /content/drive so the corpora, SentencePiece models
# and pretrained checkpoint stored under /content/drive/MyDrive are
# reachable from this Colab runtime.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
# Install SentencePiece and import it as `spm`; the preprocessing cells use
# it to segment the English/Japanese text into subword pieces.
!pip install sentencepiece
import sentencepiece as spm

### 前処理



*   NFKC正規化
*   文長比率の高い文対の削除



### 学習とデータ数調整

In [None]:
# 1万件に設定
!sed -n '1,10000p' /content/drive/MyDrive/ja-news/news-crawl21-500k-rem-nfkc.en >> train.en
!sed -n '1,10000p' /content/drive/MyDrive/ja-news/news-crawl21-500k-rem-nfkc.ja >> train.ja
!wc -l train.en
!wc -l train.ja

In [None]:
#/content/drive/MyDrive/translation
# ライブラリのインストール
! pip install fairseq==0.12.1

# sacrebleu install(bleuを測るツール)
! pip install sacrebleu
! pip install sacrebleu[ja]

In [None]:
# Fetch the WMT20 en-ja evaluation data with sacrebleu:
#   wmt20      -> wmt.test.{en,ja}   (test set)
#   wmt20/dev  -> wmt.valid.{en,ja}  (validation set)
# `--echo src` prints the English source side, `--echo ref` the Japanese
# reference side.
# NOTE(review): these files are written to the current working directory,
# but the preprocessing cells read /content/drive/MyDrive/wmt.*.{en,ja};
# confirm the files are copied to Drive (or align the paths).
! sacrebleu -t wmt20 -l en-ja --echo src > wmt.test.en
! sacrebleu -t wmt20 -l en-ja --echo ref > wmt.test.ja
! sacrebleu -t wmt20/dev -l en-ja --echo src > wmt.valid.en
! sacrebleu -t wmt20/dev -l en-ja --echo ref > wmt.valid.ja

In [18]:
import unicodedata

# Preprocessing, Japanese side: NFKC-normalize each line of the WMT20
# valid/test references and write it as space-joined SentencePiece pieces
# to valid.ja / test.ja (the --validpref/--testpref inputs of
# fairseq-preprocess).
# Load the Japanese SentencePiece model; `sp` is reused by the English cell.
sp = spm.SentencePieceProcessor()
sp.Load("/content/drive/MyDrive/enja_spm_models-3.0/spm.ja.nopretok.model")

for split in ("valid", "test"):
    # `with` closes both files even if encoding fails midway.
    # Explicit UTF-8 avoids depending on the platform locale for Japanese text.
    with open(f"/content/drive/MyDrive/wmt.{split}.ja", "r", encoding="utf-8") as fin, \
            open(f"{split}.ja", "w", encoding="utf-8") as fout:
        for line in fin:
            # Drop the newline before encoding so it cannot end up inside the
            # last piece; exactly one newline is written back below.
            line = unicodedata.normalize("NFKC", line.rstrip("\n"))
            fout.write(" ".join(sp.EncodeAsPieces(line)) + "\n")

In [None]:
# Preprocessing, English side: NFKC-normalize each line of the WMT20
# valid/test sources and write it as space-joined SentencePiece pieces to
# valid.en / test.en.
# Reuses the `sp` SentencePieceProcessor created in the Japanese cell,
# switching it to the English model.
import unicodedata

sp.Load("/content/drive/MyDrive/enja_spm_models-3.0/spm.en.nopretok.model")

for split in ("valid", "test"):
    # `with` closes both files even if encoding fails midway.
    # The encoding is explicit UTF-8.
    with open(f"/content/drive/MyDrive/wmt.{split}.en", "r", encoding="utf-8") as fin, \
            open(f"{split}.en", "w", encoding="utf-8") as fout:
        for line in fin:
            # Drop the newline before encoding so it cannot end up inside the
            # last piece; exactly one newline is written back below.
            line = unicodedata.normalize("NFKC", line.rstrip("\n"))
            fout.write(" ".join(sp.EncodeAsPieces(line)) + "\n")

# Fairseq

In [None]:
# Preparation: binarize train/valid/test (en -> ja) for fairseq.
# fairseq-preprocess prints the fraction of tokens replaced by <unk>, a quick
# check that the text matches the dictionaries.
# --destdir: where the binarized data is written (default: data-bin).
# --srcdict/--tgtdict reuse the big-3.0-enja dictionaries, presumably so the
# token ids match the pretrained checkpoint loaded in the training cell.
# NOTE(review): these dictionary paths are relative, while the training
# cell loads /content/drive/MyDrive/big-3.0-enja/big.pretrain.pt; confirm
# that big-3.0-enja/ also exists in the working directory.
# NOTE(review): train.en/train.ja are cut from the corpus with sed and are
# never passed through SentencePiece in this notebook (valid/test are).
# Confirm the corpus is already segmented; otherwise expect a high <unk> rate.
ENDICT='big-3.0-enja/dict.en.txt'
JADICT='big-3.0-enja/dict.ja.txt'
! fairseq-preprocess --source-lang en --target-lang ja --trainpref train --validpref valid --testpref test \
                     --destdir pbl-big3-enja3 --srcdict $ENDICT --tgtdict $JADICT

In [None]:
# Training: fine-tune from the pretrained checkpoint on the binarized data
# in pbl-big3-enja3.
# --restore-file loads big.pretrain.pt. The --reset-* flags discard the
# checkpoint's optimizer, meter, dataloader and LR-scheduler state.
# The embed/FFN/attention-head sizes (1024/4096/16) override the
# `transformer` arch defaults, presumably to match the pretrained
# checkpoint; confirm.
# --no-epoch-checkpoints / --no-last-checkpoints limit the checkpoint files
# that are kept.
# --patience 10 is fairseq's early-stopping patience, counted in validation
# runs.
# NOTE(review): validation is requested every 9999 updates but checkpoints
# every 1000; confirm how often validation actually runs, since that
# determines when patience triggers.
PRETRAINED_MODEL="/content/drive/MyDrive/big-3.0-enja/big.pretrain.pt"
SEED=10
! fairseq-train pbl-big3-enja3 --arch transformer --restore-file $PRETRAINED_MODEL \
    --no-epoch-checkpoints --no-last-checkpoints\
    --seed $SEED \
    --patience 10 \
    --batch-size 16 --optimizer adam --adam-betas '(0.9,0.98)' --lr 1e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
    --dropout 0.1 --weight-decay 0.0001 --clip-norm 1.0 \
    --reset-optimizer --reset-meters --reset-dataloader --reset-lr-scheduler \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --encoder-embed-dim 1024 --decoder-embed-dim 1024 \
    --encoder-ffn-embed-dim 4096 --decoder-ffn-embed-dim 4096 \
    --encoder-attention-heads 16 --decoder-attention-heads 16 \
    --log-interval 100 --validate-interval-updates 9999 --save-interval-updates 1000

### 評価

In [None]:
# 評価
# max-len-b :出力単語数の制限
ref_file = 'test.ja'
REF_FILE="test.en"

# 学習済みモデル(サンプル)
MODEL1='model1.pt'
MODEL2='model2.pt'

! fairseq-interactive pbl-big3-enja3 --path $MODEL1:$MODEL2 --input $REF_FILE --batch-size 128 \
 --remove-bpe sentencepiece --buffer-size 1024 --nbest 1 --max-len-b 50\
--beam 5 --task translation\
| grep "^H-" | sort -V | cut -f3 > result.txt
bleu(ref_file, "result.txt")

In [None]:
# NFKC-normalize the WMT20 Japanese test reference (test.ja on Drive) and
# write it to test-det.ja, the reference used by the sacrebleu scoring cell.
# Fixes:
#  - the output file was opened in the default read mode, so fw.write()
#    raised io.UnsupportedOperation;
#  - normalize() had its arguments swapped (the form comes first), which
#    raised ValueError.
# NOTE(review): confirm /content/drive/MyDrive/test.ja is the plain-text
# reference. The local test.ja written by the preprocessing cell contains
# SentencePiece pieces, which NFKC normalization does not undo.
import unicodedata

with open('/content/drive/MyDrive/test.ja', encoding='utf-8') as f, \
        open('/content/drive/MyDrive/test-det.ja', 'w', encoding='utf-8') as fw:
    for line in f:
        line = unicodedata.normalize('NFKC', line)
        fw.write(line)

In [None]:
# BLEU score with sacrebleu: compare the system output (result.txt) against
# the reference test-det.ja. Japanese is word-segmented with MeCab
# (--tokenize ja-mecab), which needs the sacrebleu[ja] extra.
!sacrebleu  /content/drive/MyDrive/test-det.ja -i result.txt --tokenize ja-mecab