In [1]:
# Load IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes (handy while editing the helper modules
# this notebook imports).
%load_ext autoreload
%autoreload 2

In [2]:
!pip install transformers==4.5.0 fugashi==1.1.0 ipadic==1.0.0 pytorch-lightning==1.2.7



In [3]:
import random 
from tqdm import tqdm
import unicodedata

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from transformers import BertJapaneseTokenizer, BertForMaskedLM
import pytorch_lightning as pl

# Identifier of the pretrained Japanese BERT checkpoint; the same value is
# passed to the tokenizer's and the masked-LM's `from_pretrained` calls below.
MODEL_NAME = 'cl-tohoku/bert-base-japanese-whole-word-masking'

In [5]:
# NOTE(review): `proofreading_tokenizer` is presumably a project-local module; it
# must be importable (e.g. placed next to this notebook) for this cell to run.
from proofreading_tokenizer import SC_tokenizer

# Build the tokenizer from the pretrained checkpoint named by MODEL_NAME.
tokenizer = SC_tokenizer.from_pretrained(MODEL_NAME)

ModuleNotFoundError: No module named 'proofreading_tokenizer'

## encode_plus_tagged
ファインチューニング時に使用。
誤変換を含む文章と正しい文章を入力とし、
符号化を行いBERTに入力できる形式にする。

In [None]:
# Example pair: a sentence containing a misconversion and its corrected form.
wrong_text = '優勝トロフィーを変換した'
correct_text = '優勝トロフィーを返還した'

# Encode the pair with a maximum length of 12 and show the result.
encoding = tokenizer.encode_plus_tagged(wrong_text, correct_text, max_length=12)
print(encoding)

## encode_plus_untagged
文章を符号化し、それぞれのトークンの文章中の位置も特定しておく。

In [None]:
# Encode a single sentence; the tokenizer returns the encoding together with
# `spans` (per the note above: each token's position in the sentence).
wrong_text = '優勝トロフィーを変換した'
encoding, spans = tokenizer.encode_plus_untagged(wrong_text, return_tensors='pt')

# Print each result under its own heading.
for heading, value in (('# encoding', encoding), ('# spans', spans)):
    print(heading)
    print(value)

## convert_bert_output_to_text
推論時に使用。
文章と、各トークンのラベルの予測値、文章中での位置を入力とする。
そこから、BERTによって予測された文章に変換。

In [None]:
# Hand-written stand-in for per-token label predictions.
# NOTE(review): presumably vocabulary ids with the special tokens included — confirm.
predicted_labels = [2, 759, 18204, 11, 8274, 15, 10, 3]

# Rebuild the sentence from the labels; `wrong_text` and `spans` are the values
# produced by the encode_plus_untagged cell.
predicted_text = tokenizer.convert_bert_output_to_text(wrong_text, predicted_labels, spans)
print(predicted_text)

In [None]:
# Load the pretrained masked-LM weights for the checkpoint named by MODEL_NAME.
bert_mlm = BertForMaskedLM.from_pretrained(MODEL_NAME)

### 学習①

In [None]:
text = '優勝トロフィーを変換した。'

# Encode the sentence and keep the spans needed to map labels back onto the text.
encoding, spans = tokenizer.encode_plus_untagged(text, return_tensors='pt')

# Forward pass only; gradients are not needed here.
with torch.no_grad():
    output = bert_mlm(**encoding)

# Highest-scoring id at every position of the first sequence in the batch.
scores = output.logits
labels_predicted = scores[0].argmax(dim=-1).numpy().tolist()

predict_text = tokenizer.convert_bert_output_to_text(text, labels_predicted, spans)
predict_text

### 学習②

In [None]:
# Two toy (misconverted, corrected) sentence pairs.
data = [
    {'wrong_text': '優勝トロフィーを変換した。', 'correct_text': '優勝トロフィーを返還した。'},
    {'wrong_text': '人と森は強制している。', 'correct_text': '人と森は共生している。'},
]

max_length = 32
dataset_for_loader = []
for sample in data:
    wrong_text, correct_text = sample['wrong_text'], sample['correct_text']
    encoding = tokenizer.encode_plus_tagged(wrong_text, correct_text, max_length=max_length)
    # Turn every encoded field into a tensor so the DataLoader can batch it.
    encoding = {name: torch.tensor(values) for name, values in encoding.items()}
    dataset_for_loader.append(encoding)

dataloader = DataLoader(dataset_for_loader, batch_size=2)

# Feed each batch to the masked LM and read the loss from its output.
for batch in dataloader:
    encoding = dict(batch.items())
    output = bert_mlm(**encoding)
    loss = output.loss

In [None]:
# Download the JWTD archive (-L follows redirects) and save it as JWTD.tar.gz.
!curl -L "https://nlp.ist.i.kyoto-u.ac.jp/DLcounter/lime.cgi?down=https://nlp.ist.i.kyoto-u.ac.jp/nl-resource/JWTD/jwtd.tar.gz&name=JWTD.tar.gz" -o JWTD.tar.gz
# Extract the gzip-compressed archive into the current directory, listing files.
!tar zxvf JWTD.tar.gz

In [None]:
from dataset import create_dataset

# Fixed seed so the train/validation split is the same on every run.
SPLIT_SEED = 42

# Load the data
train_df = pd.read_json(
    './jwtd/train.jsonl', orient='records', lines=True
)
test_df = pd.read_json(
    './jwtd/test.jsonl', orient='records', lines=True
)

print('学習と検証用のデータセット：')
dataset = create_dataset(train_df)
# Shuffle with a dedicated seeded generator: the split is reproducible and the
# global `random` state is left untouched.
random.Random(SPLIT_SEED).shuffle(dataset)
n = len(dataset)
n_train = int(n*0.8)  # first 80% for training, the remainder for validation
dataset_train = dataset[:n_train]
dataset_val = dataset[n_train:]

print('テスト用のデータセット：')
dataset_test = create_dataset(test_df)

In [None]:
from dataset_loader import create_dataset_for_loader

tokenizer = SC_tokenizer.from_pretrained(MODEL_NAME)

# Convert the training split, then the validation split, into loader-ready datasets.
max_length = 32
dataset_train_for_loader, dataset_val_for_loader = (
    create_dataset_for_loader(tokenizer, split, max_length)
    for split in (dataset_train, dataset_val)
)

# Data loaders: shuffled batches of 32 for training, batches of 256 for validation.
dataloader_train = DataLoader(dataset_train_for_loader, shuffle=True, batch_size=32)
dataloader_val = DataLoader(dataset_val_for_loader, batch_size=256)

In [None]:
from Lightning_model import BertForMaskedLM_pl

# Checkpoint callback: keep only the single best checkpoint (lowest 'val_loss'),
# saving weights only, under model/.
checkpoint = pl.callbacks.ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    save_weights_only=True,
    dirpath='model/'
)

# Train for at most 5 epochs with the checkpoint callback attached.
trainer = pl.Trainer(
    max_epochs=5,
    callbacks=[checkpoint]
)

# NOTE(review): BertForMaskedLM_pl is defined in Lightning_model (not visible here);
# presumably a LightningModule that logs 'val_loss' — confirm.
model = BertForMaskedLM_pl(MODEL_NAME, lr=1e-5)
trainer.fit(model, dataloader_train, dataloader_val)
# Path of the best checkpoint; used by the next cell to reload the model.
best_model_path = checkpoint.best_model_path

In [None]:
def predict(text, tokenizer, bert_mlm):
    """Return the sentence the model predicts for ``text``.

    Args:
        text: Input sentence, possibly containing misconverted kanji.
        tokenizer: Tokenizer providing ``encode_plus_untagged`` and
            ``convert_bert_output_to_text``.
        bert_mlm: Masked-LM model whose output exposes ``logits``.

    Returns:
        The text rebuilt by the tokenizer from the predicted labels.
    """
    encoding, spans = tokenizer.encode_plus_untagged(text, return_tensors='pt')

    # Run on whatever device the model's parameters live on.
    device = next(bert_mlm.parameters()).device
    encoding = {key: value.to(device) for key, value in encoding.items()}

    # Highest-scoring label at every position of the first sequence in the batch.
    with torch.no_grad():
        scores = bert_mlm(**encoding).logits
    labels_predicted = scores[0].argmax(-1).cpu().numpy().tolist()

    return tokenizer.convert_bert_output_to_text(text, labels_predicted, spans)


# Sample sentences to run through the fine-tuned model.
text_list = [
    'ユーザーの試行に合わせた楽曲を配信する。',
    'メールに明日の会議の史料を添付した。',
    '乳酸菌で牛乳を発行するとヨーグルトができる。',
    '突然、子供が帰省を発した。'
]

# Reload the best checkpoint and take its underlying masked-LM model.
tokenizer = SC_tokenizer.from_pretrained(MODEL_NAME)
model = BertForMaskedLM_pl.load_from_checkpoint(best_model_path)
bert_mlm = model.bert_mlm

for text in text_list:
    predict_text = predict(text, tokenizer, bert_mlm)  # prediction by BERT
    print('---')
    print(f'入力：{text}')
    print(f'出力：{predict_text}')

### テストデータ

#### ・予測が完全に一致

In [None]:
# Exact-match accuracy: a sample counts only when the whole predicted sentence
# equals the reference correction.
correct_num = 0
for sample in tqdm(dataset_test):
    wrong_text = sample['wrong_text']
    correct_text = sample['correct_text']
    predict_text = predict(wrong_text, tokenizer, bert_mlm)

    # Compare against this sample's prediction (`predict_text`), not the
    # `predicted_text` variable left over from the earlier demo cell.
    if correct_text == predict_text:
        correct_num += 1

print(f'Accuracy: {correct_num/len(dataset_test):.2f}')

#### ・誤変換の漢字の特定

In [None]:
# Detection accuracy: a sample counts as detected when the model changes every
# misconverted token and leaves every already-correct token unchanged (the
# replacement itself does not have to be the right one).
device = next(bert_mlm.parameters()).device  # run on the model's own device
correct_position_num = 0
for sample in tqdm(dataset_test):
    wrong_text = sample['wrong_text']
    correct_text = sample['correct_text']

    # Encode both sentences; keep the plain id lists for the comparison below.
    encoding = tokenizer(wrong_text)
    wrong_input_ids = encoding['input_ids']
    correct_encoding = tokenizer(correct_text)
    correct_input_ids = correct_encoding['input_ids']

    # The model needs batched tensors, not Python lists.
    encoding = {k: torch.tensor([v]).to(device) for k, v in encoding.items()}

    # Predict the highest-scoring id at every position.
    with torch.no_grad():
        output = bert_mlm(**encoding)
        scores = output.logits
        predict_input_ids = scores[0].argmax(-1).cpu().numpy().tolist()

    # Drop the first and last (special) tokens.
    wrong_input_ids = wrong_input_ids[1:-1]
    correct_input_ids = correct_input_ids[1:-1]
    predict_input_ids = predict_input_ids[1:-1]

    # At every position, "the input token was already correct" must coincide
    # with "the model left it unchanged".
    # NOTE(review): zip stops at the shortest list; this assumes the wrong and
    # correct sentences tokenize to the same length — confirm in create_dataset.
    detect_flag = all(
        (wrong_token == correct_token) == (wrong_token == predict_token)
        for wrong_token, correct_token, predict_token
        in zip(wrong_input_ids, correct_input_ids, predict_input_ids)
    )

    if detect_flag:
        correct_position_num += 1

print(f'Accuracy: {correct_position_num/len(dataset_test):.2f}')