In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell executes, so edits to the helper modules imported
# below are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
!pip install transformers==4.5.0 fugashi==1.1.0 ipadic==1.0.0 pytorch-lightning==1.2.7



In [3]:
import random 
from tqdm import tqdm
import unicodedata

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from transformers import BertJapaneseTokenizer, BertForMaskedLM
import pytorch_lightning as pl

# Pretrained checkpoint identifier passed to every `from_pretrained(...)` call
# and to BertForMaskedLM_pl in this notebook.
MODEL_NAME = 'cl-tohoku/bert-base-japanese-whole-word-masking'

In [6]:
# SC_tokenizer comes from proofreading_tokenizer, which is not among the
# packages installed above — presumably a module shipped next to this notebook.
from proofreading_tokenizer import SC_tokenizer

# Load the tokenizer for the checkpoint named by MODEL_NAME.
tokenizer = SC_tokenizer.from_pretrained(MODEL_NAME)

## encode_plus_tagged
ファインチューニング時に使用。
誤変換を含む文章と正しい文章を入力とし、
符号化を行いBERTに入力できる形式にする。

In [17]:
# A sentence pair that differs only in one kanji conversion (変換 vs. 返還).
wrong_text = '優勝トロフィーを変換した'
correct_text = '優勝トロフィーを返還した'
# Encode the pair with max_length=12. In the recorded output every list has
# 12 entries (zero-padded), and 'labels' differs from 'input_ids' only at the
# position of the misconverted token.
encoding = tokenizer.encode_plus_tagged(
    wrong_text, correct_text, max_length=12
)
print(encoding)

{'input_ids': [2, 759, 18204, 11, 4618, 15, 10, 3, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0], 'labels': [2, 759, 18204, 11, 8274, 15, 10, 3, 0, 0, 0, 0]}


## encode_plus_untagged
文章を符号化し、それぞれのトークンの文章中の位置も特定しておく。

In [18]:
wrong_text = '優勝トロフィーを変換した'
# Encode a single sentence as PyTorch tensors and also get `spans`, one
# [start, end] pair per token. In the recorded output the first and last
# tokens have the span [-1, -1].
encoding, spans = tokenizer.encode_plus_untagged(
    wrong_text, return_tensors='pt'
)
print('# encoding')
print(encoding)
print('# spans')
print(spans)

# encoding
{'input_ids': tensor([[    2,   759, 18204,    11,  4618,    15,    10,     3]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}
# spans
[[-1, -1], [0, 2], [2, 7], [7, 8], [8, 10], [10, 11], [11, 12], [-1, -1]]


## convert_bert_output_to_text
推論時に使用。
文章と、各トークンのラベルの予測値、文章中での位置を入力とする。
そこから、BERTによって予測された文章に変換。

In [19]:
# Hand-written token ids standing in for a model prediction; they equal the
# 'labels' list printed by the encode_plus_tagged example, without padding.
predicted_labels = [2, 759, 18204, 11, 8274, 15, 10, 3]
# Rebuild a sentence from the original text, the per-token ids and the spans.
# `wrong_text` and `spans` come from the encode_plus_untagged cell above.
predicted_text = tokenizer.convert_bert_output_to_text(
    wrong_text, predicted_labels, spans
)
print(predicted_text)

優勝トロフィーを返還した


In [20]:
# Load the pretrained masked-language model; no fine-tuning has been applied
# to this object at this point in the notebook.
bert_mlm = BertForMaskedLM.from_pretrained(MODEL_NAME)

Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-whole-word-masking were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### 学習①

In [21]:
text = '優勝トロフィーを変換した。'

# Encode the sentence as PyTorch tensors and keep each token's span.
encoding, spans = tokenizer.encode_plus_untagged(
    text, return_tensors='pt'
)

# Forward pass without gradient tracking, then pick the highest-scoring
# vocabulary id at every token position of the (single) input sentence.
with torch.no_grad():
    output = bert_mlm(**encoding)
    scores = output.logits
    labels_predicted = scores[0].argmax(-1).numpy().tolist()
    
# Convert the predicted ids back into a sentence using the spans.
predict_text = tokenizer.convert_bert_output_to_text(
    text, labels_predicted, spans
)
predict_text

'優勝トロフィーを獲得した。'

### 学習②

In [22]:
# Two toy samples, each pairing a sentence that contains a kanji
# misconversion with its corrected form.
data = [
    {
        'wrong_text': '優勝トロフィーを変換した。',
        'correct_text': '優勝トロフィーを返還した。',
    },
    {
        'wrong_text': '人と森は強制している。',
        'correct_text': '人と森は共生している。',
    }
]

max_length=32
dataset_for_loader = []
for sample in data:
    wrong_text = sample['wrong_text']
    correct_text = sample['correct_text']
    encoding = tokenizer.encode_plus_tagged(
        wrong_text, correct_text, max_length=max_length
    )
    # Turn every field into a tensor so DataLoader can stack samples.
    encoding = { k: torch.tensor(v) for k, v in encoding.items() }
    dataset_for_loader.append(encoding)
    
dataloader = DataLoader(dataset_for_loader, batch_size=2)

# Run the model on each batch and read the loss from its output. The batch is
# passed as keyword arguments, so its keys must match the model's parameters.
for batch in dataloader:
    encoding = { k: v for k, v in batch.items() }
    output = bert_mlm(**encoding)
    loss = output.loss

In [23]:
# Download the JWTD archive (following redirects) and unpack it; per the
# recorded output this creates jwtd/train.jsonl and jwtd/test.jsonl.
!curl -L "https://nlp.ist.i.kyoto-u.ac.jp/DLcounter/lime.cgi?down=https://nlp.ist.i.kyoto-u.ac.jp/nl-resource/JWTD/jwtd.tar.gz&name=JWTD.tar.gz" -o JWTD.tar.gz
!tar zxvf JWTD.tar.gz

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100   243  100   243    0     0    820      0 --:--:-- --:--:-- --:--:--   818
100   268  100   268    0     0    738      0 --:--:-- --:--:-- --:--:--   738
100 64.9M  100 64.9M    0     0  2042k      0  0:00:32  0:00:32 --:--:-- 3646k
x jwtd/
x jwtd/train.jsonl
x jwtd/test.jsonl


In [28]:
from dataset import create_dataset

# Fixed seed so the train/validation split is the same on every run.
SPLIT_SEED = 42
# Fraction of the shuffled samples used for training; the rest is validation.
TRAIN_RATIO = 0.8

# Load the data (JSON Lines: one record per line)
train_df = pd.read_json(
    './jwtd/train.jsonl', orient='records', lines=True
)
test_df = pd.read_json(
    './jwtd/test.jsonl', orient='records', lines=True
)

print('学習と検証用のデータセット：')
dataset = create_dataset(train_df)
# Shuffle with a dedicated seeded generator: the split becomes reproducible
# and the global `random` state is left untouched.
random.Random(SPLIT_SEED).shuffle(dataset)
n = len(dataset)
n_train = int(n*TRAIN_RATIO)
dataset_train = dataset[:n_train]
dataset_val = dataset[n_train:]

print('テスト用のデータセット：')
dataset_test = create_dataset(test_df)

学習と検証用のデータセット：
- 漢字誤変換の総数：235490
- トークンの対応関係のつく文章の総数: 172883
  (全体の73%)
テスト用のデータセット：
- 漢字誤変換の総数：3061
- トークンの対応関係のつく文章の総数: 2263
  (全体の74%)


In [32]:
from dataset_loader import create_dataset_for_loader

tokenizer = SC_tokenizer.from_pretrained(MODEL_NAME)

# Build the datasets: both splits go through the same helper with the same
# tokenizer and max_length.
max_length = 32
dataset_train_for_loader = create_dataset_for_loader(
    tokenizer, dataset_train, max_length
)
dataset_val_for_loader = create_dataset_for_loader(
    tokenizer, dataset_val, max_length
)

# Build the data loaders: only the training loader shuffles; the validation
# loader uses a larger batch size.
dataloader_train = DataLoader(
    dataset_train_for_loader, batch_size=32, shuffle=True
)
dataloader_val = DataLoader(dataset_val_for_loader, batch_size=256)

100%|██████████| 138306/138306 [01:07<00:00, 2056.13it/s]
100%|██████████| 34577/34577 [00:16<00:00, 2156.59it/s]


In [36]:
from Lightning_model import BertForMaskedLM_pl

# Keep only the single best checkpoint (lowest 'val_loss'), weights only,
# under model/. 'val_loss' is presumably logged by BertForMaskedLM_pl during
# validation — confirm in Lightning_model.
checkpoint = pl.callbacks.ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    save_weights_only=True,
    dirpath='model/'
)

trainer = pl.Trainer(
    max_epochs=5,
    # Use one GPU when CUDA is available; None is the Trainer default (CPU),
    # so machines without a GPU behave exactly as before.
    gpus=1 if torch.cuda.is_available() else None,
    callbacks=[checkpoint]
)

model = BertForMaskedLM_pl(MODEL_NAME, lr=1e-5)
trainer.fit(model, dataloader_train, dataloader_val)
# Path of the best checkpoint, used by the inference cells below.
best_model_path = checkpoint.best_model_path

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-whole-word-masking were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

  | Name     | Type            | Params
---------------------------------------------
0 | bert_mlm | BertForMaskedLM | 110 M 
---------------------------------------------
110 M     Trainable params
0         Non-trainable params
110 M     Total params
442.604 

Validation sanity check: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]



In [None]:
def predict(text, tokenizer, bert_mlm):
    """Return the sentence the model predicts for ``text``.

    Args:
        text: Input sentence, possibly containing kanji misconversions.
        tokenizer: Tokenizer providing ``encode_plus_untagged`` and
            ``convert_bert_output_to_text``.
        bert_mlm: Masked-language model whose output exposes ``logits``.

    Returns:
        The text rebuilt from the highest-scoring token id at each position.
    """
    # Encode as PyTorch tensors and keep each token's span in the text.
    encoding, spans = tokenizer.encode_plus_untagged(
        text, return_tensors='pt'
    )

    # Inference only: no gradient tracking; argmax over the vocabulary axis
    # of the single sentence in the batch.
    with torch.no_grad():
        output = bert_mlm(**encoding)
        scores = output.logits
        labels_predicted = scores[0].argmax(-1).numpy().tolist()

    return tokenizer.convert_bert_output_to_text(
        text, labels_predicted, spans
    )


text_list = [
    'ユーザーの試行に合わせた楽曲を配信する。',
    'メールに明日の会議の史料を添付した。',
    '乳酸菌で牛乳を発行するとヨーグルトができる。',
    '突然、子供が帰省を発した。'
]

tokenizer = SC_tokenizer.from_pretrained(MODEL_NAME)
model = BertForMaskedLM_pl.load_from_checkpoint(best_model_path)
bert_mlm = model.bert_mlm
# Make sure dropout is disabled so predictions are deterministic.
bert_mlm.eval()

for text in text_list:
    predict_text = predict(text, tokenizer, bert_mlm) # prediction by BERT
    print('---')
    print(f'入力：{text}')
    print(f'出力：{predict_text}')

### テストデータ

#### ・予測が完全に一致

In [None]:
# Exact-match accuracy: a sample counts as correct only when the predicted
# sentence equals the reference sentence character for character.
correct_num = 0
for sample in tqdm(dataset_test):
    wrong_text = sample['wrong_text']
    correct_text = sample['correct_text']
    predict_text = predict(wrong_text, tokenizer, bert_mlm)

    # Compare against this sample's prediction (`predict_text`), not the
    # leftover `predicted_text` variable from the earlier demo cell.
    if correct_text == predict_text:
        correct_num += 1

print(f'Accuracy: {correct_num/len(dataset_test):.2f}')

#### ・誤変換の漢字の特定

In [None]:
# Detection accuracy: a sample counts as correct when the model changes
# exactly the tokens that differ between the wrong and the correct sentence
# (whether or not the replacement token itself is right).
correct_position_num = 0
for sample in tqdm(dataset_test):
    wrong_text = sample['wrong_text']
    correct_text = sample['correct_text']

    # Encode. The model needs tensors, so request them for the wrong sentence
    # and derive the plain id list from the single row of the batch.
    encoding = tokenizer(wrong_text, return_tensors='pt')
    wrong_input_ids = encoding['input_ids'][0].tolist()
    correct_encoding = tokenizer(correct_text)
    correct_input_ids = correct_encoding['input_ids']

    # Predict
    with torch.no_grad():
        output = bert_mlm(**encoding)
        scores = output.logits
        predict_input_ids = scores[0].argmax(-1).numpy().tolist()

    # Drop the first and last (special) tokens
    wrong_input_ids = wrong_input_ids[1:-1]
    correct_input_ids = correct_input_ids[1:-1]
    predict_input_ids = predict_input_ids[1:-1]

    # Detection check, token by token
    detect_flag = True
    for wrong_token, correct_token, predict_token \
        in zip(wrong_input_ids, correct_input_ids, predict_input_ids):

        if wrong_token == correct_token:  # token was already correct
            if wrong_token != predict_token:  # changed although no change was needed
                detect_flag = False
                break
        else:  # misconverted token
            if wrong_token == predict_token:  # left unchanged
                detect_flag = False
                break

    if detect_flag:
        correct_position_num += 1

print(f'Accuracy: {correct_position_num/len(dataset_test):.2f}')