<a href="https://colab.research.google.com/github/ShinAsakawa/ShinAsakawa.github.io/blob/master/2022notebooks/2022_0529iwashita_yoshihara_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

---
- date: 2022_0529
- title: `2022_0529iwashita_yoshihara_demo.ipynb`
- author: 浅川伸一
---

# BERT のマスク化言語モデルを使った穴埋め問題のデモ

In [None]:
import IPython
isColab = 'google.colab' in str(IPython.get_ipython())
if isColab:
    !pip install --upgrade openpyxl
    !pip install --upgrade pandas
    !pip install --upgrade fugashi[unidic-lite]
    !pip install --upgrade ipadic
    !python -m unidic download
    !pip install transformers

    !pip install --upgrade jaconv

# BERT のインポート

In [None]:
from transformers import BertJapaneseTokenizer
from transformers import BertForMaskedLM
import torch
import transformers

# Pretrained checkpoint selection.
# Earlier choices, kept for reference (the first two date from the "stockmarket" book setup):
#   model_name    = 'cl-tohoku/bert-base-japanese-whole-word-masking'
#   model_ja_name = 'cl-tohoku/bert-base-japanese-whole-word-masking'
#   model_ja_name = 'sonoisa/sentence-bert-base-ja-mean-tokens-v2'
#       see https://huggingface.co/sonoisa/sentence-bert-base-ja-mean-tokens-v2
model_ja_name = 'cl-tohoku/bert-base-japanese'  # Japanese BERT from the Inui lab, Tohoku University

# Select the compute resource: first CUDA GPU when available, otherwise the CPU.
# NOTE(review): `device` is defined here, but neither the model nor the inputs are
# moved to it in the cells visible in this notebook.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Masked-language-model head and its matching tokenizer, loaded from the same checkpoint.
bert_lm = BertForMaskedLM.from_pretrained(model_ja_name, return_dict=True)
tknz = BertJapaneseTokenizer.from_pretrained(model_ja_name)
# model_orig = BertForMaskedLM.from_pretrained(model_ja_name, return_dict=True)

# Character-level alternative (not used):
#   from transformers import AutoModel, AutoTokenizer
#   bertjapanese = AutoModel.from_pretrained("cl-tohoku/bert-base-japanese-char")
#   tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-char")

# やさしい日本語のデータを取得する

In [None]:
import IPython
isColab = 'google.colab' in str(IPython.get_ipython())

import os
import sys
import requests
import pandas as pd

# Excel workbooks of the SNOW corpora, keyed by corpus name.
# After the loop below each entry also carries the loaded DataFrame under 'df'.
SNOWs={'T15': {'url':"https://filedn.com/lit4DCIlHwxfS1gj9zcYuDJ/SNOW/T15-2020.1.7.xlsx"},
       'T23': {'url':"https://filedn.com/lit4DCIlHwxfS1gj9zcYuDJ/SNOW/T23-2020.1.7.xlsx"},}


for corpus in SNOWs:
    url = SNOWs[corpus]['url']
    excel_fname = corpus + '-2020.1.7.xlsx'
    if not os.path.exists(excel_fname):
        # Fetch and validate the response BEFORE creating the local file, so that a
        # failed download never leaves behind an empty or bogus workbook that the
        # `os.path.exists` check above would mistake for a valid cache on the next run.
        r = requests.get(url, timeout=60)
        r.raise_for_status()
        # The Content-Length header is optional; fall back to the size of the body.
        total_length = int(r.headers.get('content-length', len(r.content)))
        print(f'{excel_fname} をダウンロード中 {total_length} バイト')
        with open(excel_fname, 'wb') as f:
            f.write(r.content)

    # Load the workbook and give the three text columns short ASCII names:
    # original Japanese -> 'ja', simplified Japanese -> 'easy_ja', English -> 'en'.
    SNOWs[corpus]['df'] = pd.read_excel(excel_fname, engine='openpyxl').rename(
        columns={'#日本語(原文)': 'ja',
                 '#やさしい日本語':'easy_ja',
                 '#英語(原文)':'en'})


In [None]:
import jaconv

# Collect the simplified-Japanese column of both corpora, T15 first and then T23.
_snow_sents = [
    sent
    for corpus_name in ('T15', 'T23')
    for sent in SNOWs[corpus_name]['df']['easy_ja'].to_list()
]

# Apply NFKC normalization to every sentence.
snow_sents = []
for line in _snow_sents:
    snow_sents.append(jaconv.normalize(line, 'NFKC'))

# Show the first few normalized sentences as a sanity check.
print(snow_sents[:3])

In [None]:
from termcolor import colored

texts = ['誰が一番に着くか私には分かりません。', '多くの動物が人間によって殺された。', '私はテニス部員です。']
mask_token = tknz.special_tokens_map['mask_token']

# The surface string to hide is the same for every sentence, so it is set once.
# Sentences that do not contain it are printed unchanged.
token_to_be_masked = '分かり'

for text in texts:
    # Tokenize once and show both the ids and the corresponding tokens.
    input_ids = tknz(text)['input_ids']
    print(input_ids)
    print(tknz.convert_ids_to_tokens(input_ids))

    # Plain string substitution of the target by the tokenizer's mask token;
    # print the masked sentence followed by the original one.
    text_masked = text.replace(token_to_be_masked, mask_token)
    print(text_masked, text, sep='\n')

In [None]:
#tknz.special_tokens_map

In [None]:
# text = '誰が一番に着くか私には分かりません。'
# _text = text.replace('誰が','[MASK]')
# print(_text)

In [None]:
from torch.nn import functional as F

texts = ['誰が一番に着くか私には分かりません。', '多くの動物が人間によって殺された。', '私はテニス部員です。']
masked_texts = ['誰が一番に着くか私には[MASK]ません。', '多くの[MASK]が人間によって殺された。', '私はテニス部員[MASK]。']

# Number of candidate fillers to display per sentence.
n_max = 5
for i, text in enumerate(masked_texts):
    print(colored(f'{i:3d} text:{text}', color='blue', attrs=['bold']))
    inputs = tknz.encode_plus(text, return_tensors="pt")
    # Position(s) of the mask token in the encoded sentence (tuple returned by torch.where).
    mask_index = torch.where(inputs['input_ids'][0] == tknz.mask_token_id)

    # Inference only: do not build the autograd graph.
    with torch.no_grad():
        outputs  = bert_lm(**inputs)
    logits = outputs.logits
    softmax = F.softmax(logits, dim=-1)
    mask_word = softmax[0, mask_index, :]
    # topk returns (values, indices); keep the indices for the first mask position.
    topN = torch.topk(mask_word, n_max, dim=1)[1][0]

    # Use a dedicated loop variable so the sentence index `i` is not overwritten.
    for rank, token in enumerate(topN.tolist(), start=1):
        wrd = tknz.convert_ids_to_tokens([token])
        sentence_replaced = text.replace(tknz.mask_token, wrd[0])
        print(f'{rank:2d}', colored(sentence_replaced, color='grey', attrs=['bold']))


In [None]:
import numpy as np

# For each sentence: mask one randomly chosen token, let the model predict it,
# and print the sentence with each of the top candidates substituted in.
texts = ['誰が一番に着くか私には分かりません。', '多くの動物が人間によって殺された。', '私はテニス部員です。']

# Number of candidate fillers to display per sentence.
n_max = 5
for i, text in enumerate(texts):
    print(colored(f'{i:3d} text:{text}', color='blue', attrs=['bold']))
    inputs = tknz.encode_plus(text, return_tensors="pt")

    # Draw a position in 1 .. len-2, i.e. never the first or the last id of the
    # encoded sequence (presumably the special tokens added by the tokenizer).
    mask_pos = int(np.random.choice(inputs['input_ids'].size()[1]-2)) + 1
    # Clone: indexing yields a view, which would otherwise turn into the mask id
    # as soon as the position is overwritten below.
    mask_idx = inputs['input_ids'][0][mask_pos].clone()
    word_masked = tknz.convert_ids_to_tokens([int(mask_idx)])
    print(f'mask_pos:{mask_pos}', f'mask_idx:{mask_idx}',   f'mask_word:{word_masked}')

    # Mask the position drawn above. The previous code used `mask_index`, a stale
    # variable left over from an earlier cell, so the wrong token was masked and read.
    inputs['input_ids'][0][mask_pos] = tknz.mask_token_id

    # Inference only: do not build the autograd graph.
    with torch.no_grad():
        outputs  = bert_lm(**inputs)
    logits = outputs.logits
    _softmax = F.softmax(logits, dim=-1)
    _words_preded = _softmax[0, mask_pos, :]
    # topk returns (values, indices); keep the indices of the n_max best tokens.
    _topN_token = torch.topk(_words_preded, n_max)[1]

    for j, token in enumerate(_topN_token.tolist()):
        word_pred = tknz.convert_ids_to_tokens([token])
        # NOTE(review): str.replace substitutes every occurrence of the masked token's
        # string and finds nothing when that string does not appear verbatim in `text`;
        # acceptable for these demo sentences, confirm before reusing elsewhere.
        sentence_replaced = text.replace(word_masked[0], word_pred[0])
        print(f'{j+1:2d}', colored(sentence_replaced, color='grey', attrs=['bold']))
