In [1]:
from typing import Optional

import torch
from transformers import PreTrainedModel, RobertaConfig, RobertaModel, RobertaTokenizer


class AnceEncoder(PreTrainedModel):
    """RoBERTa-based encoder that turns each input sequence into a single vector.

    The last-layer hidden state of the first token produced by ``RobertaModel``
    is passed through a linear projection (``embeddingHead``) followed by a
    ``LayerNorm`` (``norm``); both have an output width of 768.
    """
    config_class = RobertaConfig
    base_model_prefix = 'ance_encoder'
    load_tf_weights = None
    _keys_to_ignore_on_load_missing = [r'position_ids']
    _keys_to_ignore_on_load_unexpected = [r'pooler', r'classifier']

    def __init__(self, config: RobertaConfig):
        """Build the RoBERTa backbone, the projection head and the layer norm.

        Args:
            config: RoBERTa configuration; ``hidden_size`` is the input width
                of the projection head and ``initializer_range`` is read by
                ``_init_weights``.
        """
        super().__init__(config)
        self.config = config
        self.roberta = RobertaModel(config)
        # The attribute names `roberta`, `embeddingHead` and `norm` are accessed
        # by name elsewhere in this notebook, so they must not be renamed.
        self.embeddingHead = torch.nn.Linear(config.hidden_size, 768)
        self.norm = torch.nn.LayerNorm(768)
        self.init_weights()

    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
    def _init_weights(self, module):
        """Initialize the weights of a single ``module``.

        ``Linear``/``Embedding`` weights are drawn from a normal distribution
        with std ``config.initializer_range``; ``LayerNorm`` is reset to
        weight 1 / bias 0; ``Linear`` biases are zeroed.
        """
        if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, torch.nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, torch.nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def init_weights(self):
        """Initialize the backbone with its own ``init_weights`` and the two
        extra layers with ``_init_weights``."""
        self.roberta.init_weights()
        self.embeddingHead.apply(self._init_weights)
        self.norm.apply(self._init_weights)

    def forward(
            self,
            input_ids: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
    ):
        """Encode ``input_ids`` into one normalized vector per sequence.

        Args:
            input_ids: Token id tensor (presumably ``(batch, seq_len)`` --
                TODO confirm against callers).
            attention_mask: Optional mask forwarded to the backbone. When
                omitted, every position whose id differs from the backbone's
                ``pad_token_id`` is attended to.

        Returns:
            ``norm(embeddingHead(h))`` where ``h`` is the last hidden state at
            sequence position 0; the last dimension has size 768.
        """
        if attention_mask is None:
            # `input_ids` is a required tensor, so the mask can always be
            # derived from it; the former all-ones fallback for a missing
            # `input_ids` could never be reached.
            attention_mask = input_ids != self.roberta.config.pad_token_id
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state
        # First-token pooling: keep only position 0 of every sequence.
        pooled_output = sequence_output[:, 0, :]
        pooled_output = self.norm(self.embeddingHead(pooled_output))
        return pooled_output

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Load the original ANCE checkpoint into the AnceEncoder wrapper and dump
# all of its parameters for inspection.
ance_checkpoint_dir = "/data1/yushi/pretrained_models/ance-msmarco-doc-firstp"
ance_model = AnceEncoder.from_pretrained(ance_checkpoint_dir)
ance_state = ance_model.state_dict()
print(ance_state)

In [3]:
# Show the parameters of the ANCE projection head (weight and bias).
embedding_head_state = ance_model.embeddingHead.state_dict()
print(embedding_head_state)

OrderedDict([('weight', tensor([[-0.0341,  0.0045,  0.0044,  ..., -0.0211, -0.0520, -0.0090],
        [-0.0131, -0.0029, -0.0176,  ..., -0.0013,  0.0204, -0.0181],
        [-0.0193,  0.0138, -0.0323,  ..., -0.0102, -0.0300,  0.0058],
        ...,
        [-0.0253,  0.0032,  0.0159,  ...,  0.0047, -0.0380,  0.0231],
        [ 0.0448, -0.0007, -0.0048,  ...,  0.0102, -0.0102,  0.0074],
        [ 0.0232,  0.0296, -0.0184,  ...,  0.0052,  0.0276,  0.0167]])), ('bias', tensor([-1.4677e-02, -6.2046e-05,  1.7860e-02, -1.8284e-02, -2.6648e-02,
        -9.9468e-03, -2.2705e-02,  3.9847e-03, -1.5261e-02,  1.8067e-02,
         1.1441e-02,  2.5921e-03,  2.4571e-02, -7.2709e-03,  2.4661e-02,
        -6.8095e-03, -6.7075e-03, -2.6658e-03,  8.3833e-03, -2.5529e-03,
        -2.0994e-02,  1.0314e-02, -3.8494e-04,  2.2485e-02, -1.6493e-02,
        -2.3805e-03, -7.0093e-03, -6.1510e-03, -9.9726e-03, -1.1607e-02,
        -6.5083e-03, -2.2268e-02, -8.3861e-03, -1.8380e-02, -2.2023e-02,
        -2.1806e-02,

In [4]:
from openmatch.modeling.linear import LinearHead

# Freshly initialized OpenMatch head; its parameters are printed so they can be
# compared with the state after the ANCE weights are copied in.
# NOTE(review): the positional arguments are presumably (input_dim, output_dim,
# bias flag) -- confirm against LinearHead's signature.
head_width = 768
new_linear_head = LinearHead(head_width, head_width, True)
fresh_head_state = new_linear_head.state_dict()
print(fresh_head_state)

OrderedDict([('linear.weight', tensor([[ 0.0249,  0.0205, -0.0238,  ..., -0.0218, -0.0148,  0.0182],
        [-0.0172,  0.0269,  0.0198,  ..., -0.0256, -0.0361,  0.0154],
        [-0.0332,  0.0299, -0.0210,  ...,  0.0345, -0.0276, -0.0275],
        ...,
        [-0.0352, -0.0208, -0.0229,  ...,  0.0147,  0.0234,  0.0040],
        [ 0.0227, -0.0052, -0.0267,  ..., -0.0344, -0.0095, -0.0160],
        [ 0.0106, -0.0280, -0.0244,  ..., -0.0225, -0.0139,  0.0119]])), ('linear.bias', tensor([-1.5460e-02, -1.1468e-02,  2.8657e-02, -8.7796e-06,  3.2687e-02,
        -1.7543e-02,  4.9116e-03, -3.0813e-02,  3.6923e-03,  2.4054e-02,
         2.1349e-02, -2.1203e-02, -5.8887e-03, -1.5290e-02,  2.1982e-03,
        -2.2853e-02, -1.6303e-02, -1.5814e-03, -2.8677e-02,  1.1883e-03,
         2.0205e-03,  1.2654e-02,  1.5185e-02,  1.7850e-03, -3.4767e-02,
         2.6119e-02, -1.9799e-02, -4.8992e-03, -1.1099e-02, -1.5646e-02,
         2.4912e-02,  2.4298e-02,  2.8402e-02, -9.3261e-03, -3.2825e-02,
      

In [5]:
# Copy the trained ANCE projection parameters into the OpenMatch head.
# `load_state_dict` copies the values into the existing parameters and verifies
# that keys and shapes match, whereas rebinding `.data` would make both modules
# share the same tensor storage and skip every check.
new_linear_head.linear.load_state_dict(ance_model.embeddingHead.state_dict())
print(new_linear_head.state_dict())

OrderedDict([('linear.weight', tensor([[-0.0341,  0.0045,  0.0044,  ..., -0.0211, -0.0520, -0.0090],
        [-0.0131, -0.0029, -0.0176,  ..., -0.0013,  0.0204, -0.0181],
        [-0.0193,  0.0138, -0.0323,  ..., -0.0102, -0.0300,  0.0058],
        ...,
        [-0.0253,  0.0032,  0.0159,  ...,  0.0047, -0.0380,  0.0231],
        [ 0.0448, -0.0007, -0.0048,  ...,  0.0102, -0.0102,  0.0074],
        [ 0.0232,  0.0296, -0.0184,  ...,  0.0052,  0.0276,  0.0167]])), ('linear.bias', tensor([-1.4677e-02, -6.2046e-05,  1.7860e-02, -1.8284e-02, -2.6648e-02,
        -9.9468e-03, -2.2705e-02,  3.9847e-03, -1.5261e-02,  1.8067e-02,
         1.1441e-02,  2.5921e-03,  2.4571e-02, -7.2709e-03,  2.4661e-02,
        -6.8095e-03, -6.7075e-03, -2.6658e-03,  8.3833e-03, -2.5529e-03,
        -2.0994e-02,  1.0314e-02, -3.8494e-04,  2.2485e-02, -1.6493e-02,
        -2.3805e-03, -7.0093e-03, -6.1510e-03, -9.9726e-03, -1.1607e-02,
        -6.5083e-03, -2.2268e-02, -8.3861e-03, -1.8380e-02, -2.2023e-02,
      

In [6]:
# Persist the converted linear head into the OpenMatch-format checkpoint directory.
linear_head_out_dir = "/data1/yushi/pretrained_models/ance-msmarco-doc-firstp-openmatch"
new_linear_head.save(linear_head_out_dir)

In [7]:
# Show the parameters of the ANCE LayerNorm that follows the projection head.
ance_norm_state = ance_model.norm.state_dict()
print(ance_norm_state)

OrderedDict([('weight', tensor([0.9708, 0.9955, 1.0185, 1.0315, 0.9295, 0.9660, 0.9582, 1.0673, 0.9752,
        0.9782, 0.9338, 0.9858, 1.0718, 0.9989, 0.9620, 0.9338, 0.9553, 1.0657,
        1.0708, 0.9689, 0.9788, 1.0222, 0.9681, 1.0001, 0.8992, 0.9928, 1.0671,
        0.9518, 0.9993, 0.9721, 0.9680, 0.9129, 0.9813, 1.0401, 1.0053, 0.9719,
        1.1091, 0.9781, 1.0491, 0.8955, 1.0089, 1.0080, 0.9314, 1.0266, 1.0447,
        0.9087, 0.9050, 0.8942, 0.9826, 0.9986, 1.0622, 0.9686, 0.9314, 1.0202,
        1.0072, 0.9808, 1.0131, 0.9872, 0.9512, 0.9536, 0.9545, 0.9971, 1.0037,
        0.9934, 0.9990, 1.0510, 0.9971, 0.9902, 0.9830, 1.0670, 1.0600, 0.9960,
        0.9253, 0.9514, 0.9761, 0.9386, 1.0085, 0.9790, 0.9794, 0.9901, 1.0217,
        1.0663, 1.0358, 1.0436, 1.0059, 1.0151, 0.9537, 0.9998, 1.0085, 0.9630,
        0.9407, 1.0088, 1.0064, 1.0782, 0.9869, 0.9782, 1.0114, 0.9704, 1.0092,
        0.9931, 1.0098, 0.9902, 1.0652, 0.9820, 0.9383, 1.0611, 0.9834, 1.0424,
        1.0171, 

In [8]:
from openmatch.modeling.layernorm import LayerNorm

# Freshly initialized OpenMatch LayerNorm wrapper of width 768; printed so it can
# be compared with the state after the ANCE parameters are copied in.
norm_width = 768
new_layer_norm = LayerNorm(norm_width)
fresh_norm_state = new_layer_norm.state_dict()
print(fresh_norm_state)

OrderedDict([('layernorm.weight', tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
      

In [10]:
# Copy the trained ANCE LayerNorm parameters into the OpenMatch wrapper.
# `load_state_dict` copies the values into the existing parameters and verifies
# that keys and shapes match, whereas rebinding `.data` would make both modules
# share the same tensor storage and skip every check.
new_layer_norm.layernorm.load_state_dict(ance_model.norm.state_dict())
print(new_layer_norm.state_dict())

OrderedDict([('layernorm.weight', tensor([0.9708, 0.9955, 1.0185, 1.0315, 0.9295, 0.9660, 0.9582, 1.0673, 0.9752,
        0.9782, 0.9338, 0.9858, 1.0718, 0.9989, 0.9620, 0.9338, 0.9553, 1.0657,
        1.0708, 0.9689, 0.9788, 1.0222, 0.9681, 1.0001, 0.8992, 0.9928, 1.0671,
        0.9518, 0.9993, 0.9721, 0.9680, 0.9129, 0.9813, 1.0401, 1.0053, 0.9719,
        1.1091, 0.9781, 1.0491, 0.8955, 1.0089, 1.0080, 0.9314, 1.0266, 1.0447,
        0.9087, 0.9050, 0.8942, 0.9826, 0.9986, 1.0622, 0.9686, 0.9314, 1.0202,
        1.0072, 0.9808, 1.0131, 0.9872, 0.9512, 0.9536, 0.9545, 0.9971, 1.0037,
        0.9934, 0.9990, 1.0510, 0.9971, 0.9902, 0.9830, 1.0670, 1.0600, 0.9960,
        0.9253, 0.9514, 0.9761, 0.9386, 1.0085, 0.9790, 0.9794, 0.9901, 1.0217,
        1.0663, 1.0358, 1.0436, 1.0059, 1.0151, 0.9537, 0.9998, 1.0085, 0.9630,
        0.9407, 1.0088, 1.0064, 1.0782, 0.9869, 0.9782, 1.0114, 0.9704, 1.0092,
        0.9931, 1.0098, 0.9902, 1.0652, 0.9820, 0.9383, 1.0611, 0.9834, 1.0424,
      

In [11]:
# Persist the converted LayerNorm into the OpenMatch-format checkpoint directory.
layer_norm_out_dir = "/data1/yushi/pretrained_models/ance-msmarco-doc-firstp-openmatch"
new_layer_norm.save(layer_norm_out_dir)

In [16]:
from transformers import RobertaModel, RobertaTokenizer

# Export the RoBERTa backbone and the tokenizer next to the converted head and
# layer norm so the output directory is a self-contained checkpoint.
source_checkpoint_dir = "/data1/yushi/pretrained_models/ance-msmarco-doc-firstp"
openmatch_checkpoint_dir = "/data1/yushi/pretrained_models/ance-msmarco-doc-firstp-openmatch"

ance_model.roberta.save_pretrained(openmatch_checkpoint_dir)
tokenizer = RobertaTokenizer.from_pretrained(source_checkpoint_dir)
# Left as the final expression so the notebook displays the written file paths.
tokenizer.save_pretrained(openmatch_checkpoint_dir)

('/data1/yushi/pretrained_models/ance-msmarco-doc-firstp-openmatch/tokenizer_config.json',
 '/data1/yushi/pretrained_models/ance-msmarco-doc-firstp-openmatch/special_tokens_map.json',
 '/data1/yushi/pretrained_models/ance-msmarco-doc-firstp-openmatch/vocab.json',
 '/data1/yushi/pretrained_models/ance-msmarco-doc-firstp-openmatch/merges.txt',
 '/data1/yushi/pretrained_models/ance-msmarco-doc-firstp-openmatch/added_tokens.json')

In [15]:
import json

# Describe how the exported pieces fit together: backbone class name, which
# backbone output is read, the pooling mode and which extra layers are enabled.
backbone_spec = {
    "type": ance_model.roberta.__class__.__name__,
    "feature": "last_hidden_state",
}
config = dict(
    tied=True,
    plm_backbone=backbone_spec,
    pooling="first",
    linear_head=True,
    normalize=False,
    layernorm=True,
)

# Serialize first, then write, so the file receives the complete JSON document.
config_json = json.dumps(config, indent=4)
with open("/data1/yushi/pretrained_models/ance-msmarco-doc-firstp-openmatch/openmatch_config.json", "w") as f:
    f.write(config_json)