<a href="https://colab.research.google.com/github/DaisukeSugiyama-MT/NLP_samples/blob/main/bert2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install transformers===4.5.0 fugashi==1.1.0 ipadic==1.0.0
import torch
from transformers import BertJapaneseTokenizer,BertModel

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers===4.5.0
  Downloading transformers-4.5.0-py3-none-any.whl (2.1 MB)
[K     |████████████████████████████████| 2.1 MB 9.5 MB/s 
[?25hCollecting fugashi==1.1.0
  Downloading fugashi-1.1.0-cp37-cp37m-manylinux1_x86_64.whl (486 kB)
[K     |████████████████████████████████| 486 kB 73.1 MB/s 
[?25hCollecting ipadic==1.0.0
  Downloading ipadic-1.0.0.tar.gz (13.4 MB)
[K     |████████████████████████████████| 13.4 MB 18.8 MB/s 
[?25hCollecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 57.8 MB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.53.tar.gz (880 kB)
[K     |████████████████████████████████| 880 kB 57.2 MB/s 
Building wheels for collected packages: ipadic, sacremoses
  Building wheel fo

In [3]:
# Name of the pretrained Japanese BERT checkpoint (whole-word-masking variant).
model_name = 'cl-tohoku/bert-base-japanese-whole-word-masking'

# Load the model and move it onto the GPU in one step.
bert = BertModel.from_pretrained(model_name).cuda()

# Show the loaded model's configuration.
print(bert.config)


Downloading:   0%|          | 0.00/479 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/445M [00:00<?, ?B/s]

BertConfig {
  "_name_or_path": "cl-tohoku/bert-base-japanese-whole-word-masking",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "tokenizer_class": "BertJapaneseTokenizer",
  "transformers_version": "4.5.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 32000
}



In [6]:
# Two sample sentences to encode.
text_list = [
    "明日は自然言語処理の勉強をしよう。",
    "明日はマシーンラーニングの勉強をしよう。",
]

# Tokenizer loaded from the same checkpoint name as the model.
tokenizer = BertJapaneseTokenizer.from_pretrained(model_name)

# Encode the sentences: pad/truncate every sentence to 32 tokens and
# return PyTorch tensors.
encoding = tokenizer(
    text_list,
    padding='max_length',
    truncation=True,
    max_length=32,
    return_tensors='pt',
)

# Move every encoded tensor onto the GPU (the result is a plain dict).
encoding = dict(
    (name, tensor.cuda()) for name, tensor in encoding.items()
)

# Run BERT on the batch and keep its last_hidden_state output.
output = bert(**encoding)
last_hidden_state = output.last_hidden_state

Downloading:   0%|          | 0.00/258k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/110 [00:00<?, ?B/s]

In [8]:
# Same forward pass as bert(**encoding), but with each input passed
# explicitly by keyword.
output = bert(
    input_ids=encoding['input_ids'],
    attention_mask=encoding['attention_mask'],
    token_type_ids=encoding['token_type_ids']
)
# Take the hidden states from THIS call's output, so the printed size
# describes the call above rather than a value left over from an earlier cell.
last_hidden_state = output.last_hidden_state
print(last_hidden_state.size())

torch.Size([2, 32, 768])


In [12]:
# When BERT is used for inference only, wrapping the call in torch.no_grad()
# reduces memory use and computation time.
with torch.no_grad():
    output = bert(**encoding)

last_hidden_state = output.last_hidden_state

In [13]:
# Convert in one chain: move to CPU -> numpy.ndarray -> nested Python list.
last_hidden_state = last_hidden_state.cpu().numpy().tolist()
print(last_hidden_state)

[[[-0.1925278753042221, 0.01197525393217802, -0.5876621603965759, -0.230732262134552, -0.23260611295700073, 0.14993831515312195, -0.20874914526939392, -0.2596367597579956, 0.3164978623390198, -0.15749220550060272, -0.024604568257927895, -0.14330513775348663, -0.06961685419082642, 0.07929160445928574, 0.00011807527334894985, -0.3409557342529297, -0.7248815894126892, 0.4103004038333893, -0.014490310102701187, 0.4469870328903198, -0.179174542427063, -0.11338815838098526, -0.6846212148666382, 0.06800048798322678, 0.5399682521820068, -0.6441159844398499, 0.3807668089866638, -0.9615087509155273, -0.31968218088150024, 0.13129346072673798, 0.21575486660003662, -0.27473923563957214, 0.04693571478128433, -0.6227548718452454, 0.05283420905470848, -0.4819393754005432, 0.2708776891231537, -0.3297897279262543, 0.9375247955322266, -0.7310634255409241, 0.38679757714271545, 0.17567098140716553, -0.05920346826314926, 0.3582152724266052, 0.42478346824645996, -0.3546150326728821, -0.4135079085826874, -0.1