# Loading weights trained with a ForMaskedLM model into a model loaded with AutoModel

In [1]:
# Import the required libraries
import torch
from torchinfo import summary
from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer

# Preparing the models
`model_MLM`: the model loaded with AutoModelForMaskedLM<br>
`model`: the model loaded with AutoModel

In [2]:
# Load the model
model_name = 'bert-base-uncased'
model_MLM = AutoModelForMaskedLM.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## モデル構造の比較

In [3]:
print(model_MLM)
print("="*200)
print(model)

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwi

In the output above, the first lines<br>
```
BertForMaskedLM(
  (bert): BertModel(
```

show that `BertForMaskedLM.bert` holds a `BertModel`.

In [4]:
summary(model_MLM.bert)

Layer (type:depth-idx)                                  Param #
BertModel                                               --
├─BertEmbeddings: 1-1                                   --
│    └─Embedding: 2-1                                   23,440,896
│    └─Embedding: 2-2                                   393,216
│    └─Embedding: 2-3                                   1,536
│    └─LayerNorm: 2-4                                   1,536
│    └─Dropout: 2-5                                     --
├─BertEncoder: 1-2                                      --
│    └─ModuleList: 2-6                                  --
│    │    └─BertLayer: 3-1                              7,087,872
│    │    └─BertLayer: 3-2                              7,087,872
│    │    └─BertLayer: 3-3                              7,087,872
│    │    └─BertLayer: 3-4                              7,087,872
│    │    └─BertLayer: 3-5                              7,087,872
│    │    └─BertLayer: 3-6                              

In [5]:
summary(model)

Layer (type:depth-idx)                                  Param #
BertModel                                               --
├─BertEmbeddings: 1-1                                   --
│    └─Embedding: 2-1                                   23,440,896
│    └─Embedding: 2-2                                   393,216
│    └─Embedding: 2-3                                   1,536
│    └─LayerNorm: 2-4                                   1,536
│    └─Dropout: 2-5                                     --
├─BertEncoder: 1-2                                      --
│    └─ModuleList: 2-6                                  --
│    │    └─BertLayer: 3-1                              7,087,872
│    │    └─BertLayer: 3-2                              7,087,872
│    │    └─BertLayer: 3-3                              7,087,872
│    │    └─BertLayer: 3-4                              7,087,872
│    │    └─BertLayer: 3-5                              7,087,872
│    │    └─BertLayer: 3-6                              

The `summary` outputs show that `model_MLM.bert` and `model` have the same structure; the only difference is `BertPooler`, which `model` has and `model_MLM.bert` lacks.
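The same comparison can also be made on the `state_dict` keys rather than by reading the printed summaries. The following is a minimal sketch, assuming a tiny randomly initialised `BertConfig` (the sizes are hypothetical and chosen only so that nothing has to be downloaded); it is not the pretrained `bert-base-uncased` checkpoint used above.

```python
from transformers import BertConfig, BertForMaskedLM, BertModel

# Tiny, randomly initialised config (hypothetical sizes) so no download is needed.
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    max_position_embeddings=64,
)
mlm = BertForMaskedLM(config)   # plays the role of model_MLM
base = BertModel(config)        # plays the role of model

# Compare the parameter names of the MLM backbone and the plain BertModel.
mlm_keys = set(mlm.bert.state_dict().keys())
base_keys = set(base.state_dict().keys())

only_in_base = sorted(base_keys - mlm_keys)
only_in_mlm = sorted(mlm_keys - base_keys)

print("only in BertModel:", only_in_base)
print("only in BertForMaskedLM.bert:", only_in_mlm)
```

If the two structures really differ only by the pooler, the first list should contain only `pooler.*` keys and the second should be empty.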

## How to save the model
The investigation so far confirms that `BertForMaskedLM.bert` holds a `BertModel`.<br>
It also shows that `model_MLM.bert` and `model` have the same structure except for the presence or absence of `BertPooler`.

In [6]:
# For now, save it as test_model.pth
torch.save(model_MLM.bert.to('cpu').state_dict(),'test_model.pth')

## モデルの読み込み

In [7]:
# Load the model state
state_dict = torch.load('test_model.pth')
msg = model.load_state_dict(state_dict, strict=False)
print("msg : ", msg)

msg :  _IncompatibleKeys(missing_keys=['pooler.dense.weight', 'pooler.dense.bias'], unexpected_keys=[])


The returned `_IncompatibleKeys` confirms that only the pooler weights could not be loaded; everything else was loaded successfully.
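To make the effect of `strict=False` concrete, here is a small sketch with a toy module. `Backbone` is a hypothetical stand-in, not part of `transformers`: the version without a pooler plays the role of `model_MLM.bert`, and the version with a pooler plays the role of `model`. It checks that the shared weights are copied and that the pooler keeps its initial values.

```python
import torch
from torch import nn


class Backbone(nn.Module):
    """Toy stand-in for an encoder; optionally adds a pooler like BertModel."""

    def __init__(self, with_pooler: bool):
        super().__init__()
        self.encoder = nn.Linear(4, 4)
        if with_pooler:
            self.pooler = nn.Linear(4, 4)


torch.manual_seed(0)
source = Backbone(with_pooler=False)  # plays the role of model_MLM.bert
target = Backbone(with_pooler=True)   # plays the role of model

pooler_before = target.pooler.weight.detach().clone()
result = target.load_state_dict(source.state_dict(), strict=False)

# Every key present in the source should now match the target exactly.
target_state = target.state_dict()
copied = all(
    torch.equal(tensor, target_state[name])
    for name, tensor in source.state_dict().items()
)

# The pooler was absent from the source, so it keeps its initial values.
pooler_untouched = torch.equal(pooler_before, target.pooler.weight.detach())

print(result.missing_keys, result.unexpected_keys, copied, pooler_untouched)
```

The practical consequence is that the pooler in `model` is not taken from the saved file, so any output that depends on it (such as `pooler_output`) does not reflect the MLM training.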

## Doing the same thing with sentence-transformers

In [8]:
model_name="sentence-transformers/all-mpnet-base-v2"
model_MLM = AutoModelForMaskedLM.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

Some weights of the model checkpoint at sentence-transformers/all-mpnet-base-v2 were not used when initializing MPNetForMaskedLM: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing MPNetForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing MPNetForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of MPNetForMaskedLM were not initialized from the model checkpoint at sentence-transformers/all-mpnet-base-v2 and are newly initialized: ['lm_head.bias', 'lm_head.decoder.bias', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight']
You should probably TRAIN this model on a down-stream task to be 

In [9]:
torch.save(model_MLM.mpnet.to('cpu').state_dict(),'test_mpnet.pth')

In [10]:
# Load the model state
state_dict = torch.load('test_mpnet.pth')
msg = model.load_state_dict(state_dict, strict=False)
print("msg : ", msg)

msg :  _IncompatibleKeys(missing_keys=['pooler.dense.weight', 'pooler.dense.bias'], unexpected_keys=[])
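The backbone attribute has a different name in each architecture (`.bert` for BERT, `.mpnet` for MPNet). One way to avoid hard-coding it is the `base_model_prefix` attribute and the `base_model` property of `PreTrainedModel`. The sketch below assumes tiny randomly initialised configs (hypothetical sizes, nothing is downloaded) rather than the checkpoints used above.

```python
from transformers import (
    BertConfig,
    BertForMaskedLM,
    MPNetConfig,
    MPNetForMaskedLM,
)

# Hypothetical tiny sizes, used only to build the models quickly.
small = dict(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=64,
    max_position_embeddings=64,
)

models = [
    BertForMaskedLM(BertConfig(**small)),
    MPNetForMaskedLM(MPNetConfig(**small)),
]

prefixes = []
is_same_object = []
for m in models:
    # base_model_prefix is the attribute name of the backbone ("bert", "mpnet", ...)
    prefixes.append(m.base_model_prefix)
    # base_model returns that backbone module without naming it explicitly
    is_same_object.append(m.base_model is getattr(m, m.base_model_prefix))

print(prefixes, is_same_object)
```

With this, the save step could be written as `torch.save(model_MLM.base_model.state_dict(), ...)` for either architecture.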


## Exploring how to load a model saved with 高本-san's method (`save_pretrained`)

In [11]:
model_MLM.save_pretrained('./output_test')

In [23]:
from safetensors.torch import load_file

model_path = "/taiga/sentence-transformers_mlm/output_test/model.safetensors"
state_dict = load_file(model_path)

Mapping to the correct keys

In [24]:
mpnet_state_dict = {}
prefix = "mpnet."
for key, value in state_dict.items():
    # Keep only the backbone weights and strip the leading "mpnet." prefix
    if key.startswith(prefix):
        mpnet_state_dict[key[len(prefix):]] = value

In [25]:
missing_keys, unexpected_keys = model.load_state_dict(mpnet_state_dict, strict=False)

# Print the missing keys and the unexpected keys
if missing_keys:
    print(f"Missing keys: {missing_keys}")
if unexpected_keys:
    print(f"Unexpected keys: {unexpected_keys}")

Missing keys: ['pooler.dense.weight', 'pooler.dense.bias']


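`missing_keys`: the list of keys that the model requires but that are not in the state dictionary.<br>
A parameter of the model being loaded is added to this list when it cannot be found in the state dictionary.<br>
<br>
`unexpected_keys`: the list of keys that are in the state dictionary but that the model being loaded does not need.<br>
Any extra parameter in the state dictionary is added to this list.<br>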
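As a possible alternative to remapping the keys by hand, `AutoModel.from_pretrained` can usually be pointed directly at a directory written by `save_pretrained` from a `...ForMaskedLM` model, because `from_pretrained` accounts for the backbone prefix when loading. The sketch below is a round trip under assumptions: it uses a tiny randomly initialised `BertConfig` (hypothetical sizes) and a temporary directory instead of `./output_test`, and the behaviour may vary between `transformers` versions.

```python
import tempfile

import torch
from transformers import AutoModel, BertConfig, BertForMaskedLM

# Tiny, randomly initialised config (hypothetical sizes) so no download is needed.
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    max_position_embeddings=64,
)
mlm = BertForMaskedLM(config)

with tempfile.TemporaryDirectory() as tmp_dir:
    # Save the full MLM model, then load only the backbone from the same folder.
    mlm.save_pretrained(tmp_dir)
    base = AutoModel.from_pretrained(tmp_dir)

    # Compare every backbone tensor of the MLM model with the reloaded model.
    source_state = mlm.bert.state_dict()
    loaded_state = base.state_dict()
    backbone_matches = all(
        torch.equal(tensor, loaded_state[name])
        for name, tensor in source_state.items()
    )

print(type(base).__name__, backbone_matches)
```

As with the manual approach, the pooler is not part of the saved MLM model, so it is expected to be newly initialised in the reloaded model.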