In [1]:
import transformers
from transformers import BertModel

import torch
import os
from copy import deepcopy

class KobertBiEncoder(torch.nn.Module):
    """Bi-encoder made of two separate BERT models: one for passages, one for queries.

    Both encoders are built from the same ``config``. Earlier versions created
    them with ``BertModel.from_pretrained("skt/kobert-base-v1")``; that
    checkpoint is no longer used, so trained weights are expected to be
    supplied afterwards through :meth:`load`.

    Attributes:
        passage_encoder: ``BertModel`` used when ``type == "passage"``.
        query_encoder: ``BertModel`` used when ``type == "query"``.
        emb_sz: ``out_features`` of the passage encoder's pooler dense layer,
            i.e. the width of the vectors returned by :meth:`forward`.
        device: ``torch.device`` used as ``map_location`` when loading a
            checkpoint ("cuda" when available, otherwise "cpu").
    """

    def __init__(self, config):
        """Build both encoders from ``config`` (a BERT configuration object)."""
        super().__init__()
        self.passage_encoder = BertModel(config)
        self.query_encoder = BertModel(config)
        # Width of pooler_output (the pooled [CLS] representation).
        self.emb_sz = self.passage_encoder.pooler.dense.out_features
        # Always a torch.device; previously the CPU branch stored a plain str.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, x: torch.LongTensor, attn_mask: torch.LongTensor, type: str = "passage") -> torch.FloatTensor:
        """Encode passages or queries with the corresponding BERT encoder.

        Args:
            x: Token ids, forwarded to the encoder as ``input_ids``.
            attn_mask: Forwarded to the encoder as ``attention_mask``.
            type: Which encoder to use, ``"passage"`` (default) or ``"query"``.

        Returns:
            The selected encoder's ``pooler_output``.

        Raises:
            AssertionError: If ``type`` is neither ``"passage"`` nor ``"query"``.
        """
        assert type in (
            "passage",
            "query",
        ), "type should be either 'passage' or 'query'"
        encoder = self.passage_encoder if type == "passage" else self.query_encoder
        return encoder(input_ids=x, attention_mask=attn_mask).pooler_output

    def load(self, model_ckpt_path):
        """Load a state dict from ``model_ckpt_path`` into this module.

        Loading is non-strict, so a checkpoint whose keys do not fully match is
        still accepted. Any mismatch is reported with a warning instead of
        being dropped silently.

        Args:
            model_ckpt_path: Path to a file written with ``torch.save`` that
                contains a state dict for this module.

        Returns:
            The result of ``load_state_dict``, exposing ``missing_keys`` and
            ``unexpected_keys`` so callers can check what was actually loaded.
        """
        import warnings

        # NOTE(review): torch.load unpickles the file; only open trusted checkpoints.
        with open(model_ckpt_path, "rb") as f:
            state_dict = torch.load(f, map_location=self.device)
        incompatible = self.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            warnings.warn(
                f"Checkpoint {model_ckpt_path!r} did not fully match the model: "
                f"missing keys={list(incompatible.missing_keys)}, "
                f"unexpected keys={list(incompatible.unexpected_keys)}",
                stacklevel=2,
            )
        return incompatible

In [2]:
from transformers import BertConfig

# Read the BERT configuration from the local JSON file.
config = BertConfig.from_json_file("config.json")
# Build the bi-encoder from that config; checkpoint weights are loaded in a separate step via model.load().
model = KobertBiEncoder(config)

In [3]:
# Load the checkpoint into the bi-encoder. KobertBiEncoder.load passes strict=False, so key mismatches do not raise.
model.load("2050iter_model.pt")

In [4]:
# Save only the passage encoder (a BertModel) to the local "kordpr-psg-enc" directory; later cells read the config back from there.
model.passage_encoder.save_pretrained("kordpr-psg-enc")

## `transformers`로 export

In [60]:
# Display the passage encoder's module tree to inspect its layers.
model.passage_encoder
# Alternative export path, left disabled:
# model.passage_encoder.save_pretrained("../models/kordpr-psg-enc")

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(8002, 768, padding_idx=1)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )

KorDPR의 경우엔 pooling layer를 그대로 사용하고 있음.
- (output): BertOutput((dense): Linear(in_features=3072, out_features=768, bias=True))
- (pooler): BertPooler((dense): Linear(in_features=768, out_features=768, bias=True))

In [93]:
from transformers import DPRContextEncoder, DPRConfig

# Reuse the exported BERT config as a DPRConfig (transformers warns about the bert -> dpr model-type mismatch, see output).
dpr_config = DPRConfig.from_pretrained("kordpr-psg-enc")
# projection_dim is not touched here; a later cell sets it to 768.
# Construct the DPR context encoder from the config only; no checkpoint weights are loaded into it in this cell.
dpr_ctx_enc = DPRContextEncoder(dpr_config)

You are using a model of type bert to instantiate a model of type dpr. This is not supported for all configurations of models and can yield errors.


In [94]:
# Display the DPR context encoder's module tree to compare it with the KorDPR passage encoder.
dpr_ctx_enc

DPRContextEncoder(
  (ctx_encoder): DPREncoder(
    (bert_model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(8002, 768, padding_idx=1)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
 

DPR의 경우 조금 다름 (내부 `bert_model`에 pooler가 없음):
- (output): BertOutput((dense): Linear(in_features=3072, out_features=768, bias=True))

In [95]:
# Number of named parameters: (KorDPR passage encoder, DPR's inner bert_model).
(
    sum(1 for _ in model.passage_encoder.named_parameters()),
    sum(1 for _ in dpr_ctx_enc.ctx_encoder.bert_model.named_parameters()),
)

(199, 197)

In [97]:
# Only the parameter names matter for this comparison, so the tensors are discarded.
dpr_bert_params = {
    param_name
    for param_name, _ in dpr_ctx_enc.ctx_encoder.bert_model.named_parameters()
}
kordpr_bert_params = {
    param_name
    for param_name, _ in model.passage_encoder.named_parameters()
}

# Names present in the KorDPR passage encoder but absent from DPR's BERT backbone.
kordpr_bert_params.difference(dpr_bert_params)

{'pooler.dense.bias', 'pooler.dense.weight'}

DPR config의 `projection_dim`이 pooler의 dense layer와 같은 역할을 하는 것 같음.

In [99]:
from transformers import DPRContextEncoder, DPRConfig

# Rebuild the DPR config from the exported BERT config, this time with a projection dimension.
dpr_config = DPRConfig.from_pretrained("kordpr-psg-enc")
# With projection_dim = 768 the encoder exposes encode_proj as Linear(768, 768) (see the printed comparison further down).
dpr_config.projection_dim = 768
# Re-create the context encoder from the updated config; this replaces the earlier dpr_ctx_enc.
dpr_ctx_enc = DPRContextEncoder(dpr_config)

You are using a model of type bert to instantiate a model of type dpr. This is not supported for all configurations of models and can yield errors.


In [104]:
# Counted against DPR's inner bert_model only:
print(
    sum(1 for _ in model.passage_encoder.named_parameters()),
    sum(1 for _ in dpr_ctx_enc.ctx_encoder.bert_model.named_parameters()),
)

# Counted against DPR's whole ctx_encoder:
print(
    sum(1 for _ in model.passage_encoder.named_parameters()),
    sum(1 for _ in dpr_ctx_enc.ctx_encoder.named_parameters()),
)

199 197
199 199


맞는 것 같음: `bert_model`만 세면 197개지만, `ctx_encoder` 전체를 세면 파라미터 개수가 199개로 KorDPR과 일치함.

In [118]:
# Print KorDPR's pooler module next to DPR's projection layer to compare their structure.
print(f"KorDPR 형태: {model.passage_encoder.pooler}")
print(f"DPR 형태: {dpr_ctx_enc.ctx_encoder.encode_proj}")

KorDPR 형태: BertPooler(
  (dense): Linear(in_features=768, out_features=768, bias=True)
  (activation): Tanh()
)
DPR 형태: Linear(in_features=768, out_features=768, bias=True)


`Tanh()`의 유무도 중요할 듯함: KorDPR의 pooler는 dense 뒤에 `Tanh()`를 적용하지만, DPR의 `encode_proj`는 activation이 없는 Linear임.

### 결론:

- https://github.com/huggingface/transformers/issues/19111
- https://github.com/huggingface/transformers/issues/14486

이슈/PR을 참고했을 때, DPR은 BertPooler의 output을 사용하지 않으므로 헷갈리지 않게 이를 제거하여 구현한 것으로 보임.
이를 해결하는 방법은 두 가지인데, 하나는 `BertModel`의 `add_pooling_layer`를 `True`로 바꾸는 방법임.