In [1]:
import numpy as np
from bert4keras.backend import keras
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import to_array
import tensorflow as tf
import modeling
import utils
import tensorflow.keras.backend as K
from tensorflow.python import pywrap_tensorflow

Using TensorFlow backend.


In [2]:
config_path = 'chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = 'chinese_L-12_H-768_A-12/bert_model.ckpt'

In [3]:
model = build_transformer_model(config_path, checkpoint_path)

Instructions for updating:
If using Keras pass *_constraint arguments to layers.


2022-07-04 15:09:43.888141: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2022-07-04 15:09:43.908993: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7f867e8d75a0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2022-07-04 15:09:43.909007: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version


In [4]:
# model.summary()
names = [layer.name for layer in model.layers]
names

['Input-Token',
 'Input-Segment',
 'Embedding-Token',
 'Embedding-Segment',
 'Embedding-Token-Segment',
 'Embedding-Position',
 'Embedding-Norm',
 'Embedding-Dropout',
 'Transformer-0-MultiHeadSelfAttention',
 'Transformer-0-MultiHeadSelfAttention-Dropout',
 'Transformer-0-MultiHeadSelfAttention-Add',
 'Transformer-0-MultiHeadSelfAttention-Norm',
 'Transformer-0-FeedForward',
 'Transformer-0-FeedForward-Dropout',
 'Transformer-0-FeedForward-Add',
 'Transformer-0-FeedForward-Norm',
 'Transformer-1-MultiHeadSelfAttention',
 'Transformer-1-MultiHeadSelfAttention-Dropout',
 'Transformer-1-MultiHeadSelfAttention-Add',
 'Transformer-1-MultiHeadSelfAttention-Norm',
 'Transformer-1-FeedForward',
 'Transformer-1-FeedForward-Dropout',
 'Transformer-1-FeedForward-Add',
 'Transformer-1-FeedForward-Norm',
 'Transformer-2-MultiHeadSelfAttention',
 'Transformer-2-MultiHeadSelfAttention-Dropout',
 'Transformer-2-MultiHeadSelfAttention-Add',
 'Transformer-2-MultiHeadSelfAttention-Norm',
 'Transformer-2

In [9]:
for index in range(len(model.layers)):
    print(model.get_layer(index=index).name, model.get_layer(index=index).output_shape)

Input-Token (None, None)
Input-Segment (None, None)
Embedding-Token (None, None, 768)
Embedding-Segment (None, None, 768)
Embedding-Token-Segment (None, None, 768)
Embedding-Position (None, None, 768)
Embedding-Norm (None, None, 768)
Embedding-Dropout (None, None, 768)
Transformer-0-MultiHeadSelfAttention (None, None, 768)
Transformer-0-MultiHeadSelfAttention-Dropout (None, None, 768)
Transformer-0-MultiHeadSelfAttention-Add (None, None, 768)
Transformer-0-MultiHeadSelfAttention-Norm (None, None, 768)
Transformer-0-FeedForward (None, None, 768)
Transformer-0-FeedForward-Dropout (None, None, 768)
Transformer-0-FeedForward-Add (None, None, 768)
Transformer-0-FeedForward-Norm (None, None, 768)
Transformer-1-MultiHeadSelfAttention (None, None, 768)
Transformer-1-MultiHeadSelfAttention-Dropout (None, None, 768)
Transformer-1-MultiHeadSelfAttention-Add (None, None, 768)
Transformer-1-MultiHeadSelfAttention-Norm (None, None, 768)
Transformer-1-FeedForward (None, None, 768)
Transformer-1-FeedF

In [13]:
for layer in model.layers:
    for weight in layer.weights:
        print(weight.name, weight.shape)

Embedding-Token/embeddings:0 (21128, 768)
Embedding-Segment/embeddings:0 (2, 768)
Embedding-Position/embeddings:0 (512, 768)
Embedding-Norm/beta:0 (768,)
Embedding-Norm/gamma:0 (768,)
Transformer-0-MultiHeadSelfAttention/dense_1/kernel:0 (768, 768)
Transformer-0-MultiHeadSelfAttention/dense_1/bias:0 (768,)
Transformer-0-MultiHeadSelfAttention/dense_2/kernel:0 (768, 768)
Transformer-0-MultiHeadSelfAttention/dense_2/bias:0 (768,)
Transformer-0-MultiHeadSelfAttention/dense_3/kernel:0 (768, 768)
Transformer-0-MultiHeadSelfAttention/dense_3/bias:0 (768,)
Transformer-0-MultiHeadSelfAttention/dense_4/kernel:0 (768, 768)
Transformer-0-MultiHeadSelfAttention/dense_4/bias:0 (768,)
Transformer-0-MultiHeadSelfAttention-Norm/beta:0 (768,)
Transformer-0-MultiHeadSelfAttention-Norm/gamma:0 (768,)
Transformer-0-FeedForward/dense_5/kernel:0 (768, 3072)
Transformer-0-FeedForward/dense_5/bias:0 (3072,)
Transformer-0-FeedForward/dense_6/kernel:0 (3072, 768)
Transformer-0-FeedForward/dense_6/bias:0 (768,)


In [15]:
class Attention(tf.keras.Model):
    """keras风格的广义attention"""
    def __init__(self, dim=768, n_head=12, head_dim=64, initializer_range=0.2):
        """
        参数:
            dim: cross attention变量的feature维度
            n_head: multi-head attention的head个数
            head_dim: 每个head的维度
            initializer_range: 全连接层参数初始化参数
        """
        super(Attention, self).__init__()
        self.dim = dim
        self.n_head = n_head
        self.head_dim = head_dim
        self.initializer_range = initializer_range
        self.inner_dim = self.n_head * self.head_dim  # inner_dim是n_head * head_dim之后将要被reshape为n_head和head_dim
        self.scale = self.head_dim ** -0.5  # scale是1/根号下dk
        self.query_layer = tf.keras.layers.Dense(
            units=self.inner_dim,
            # use_bias=False,
            name='query',
            kernel_initializer=modeling.create_initializer(initializer_range)
        )
        self.key_layer = tf.keras.layers.Dense(
            units=self.inner_dim,
            # use_bias=False,
            name='key',
            kernel_initializer=modeling.create_initializer(initializer_range)
        )
        self.value_layer = tf.keras.layers.Dense(
            units=self.inner_dim,
            # use_bias=False,
            name='value',
            kernel_initializer=modeling.create_initializer(initializer_range)
        )
        self.output_layer = tf.keras.layers.Dense(
            units=self.dim,
            # use_bias=False,
            name='output',
            kernel_initializer=modeling.create_initializer(initializer_range)
        )

    def transpose_for_scores(self, input_tensor, batch_size, n_head, seq_length, head_dim):
        """
        该方法对对input_tensor进行reshape
        input_tensor [batch, seq_len, dim] -> [batch, seq_len, n_head, head_dim] -> [batch, n_head, seq_len, head_dim]
        """
        output_tensor = tf.reshape(
            input_tensor, [batch_size, seq_length, n_head, head_dim])
        output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
        return output_tensor

    def call(self, inputs, training=None, mask=None):
        """
        输入:
            from_tensor 用于query的tensor [batch, seq_len_from, dim]
            to_tensor 用于key和value的tensor [batch, seq_len_to, dim]
            若from_tensor == to_tensor 即为self-attention 否则为cross-attention
            attention_mask 用于对attention加以mask [batch, seq_len_from, seq_len_to] 值为0或1 对0位置进行mask 对1位置进行保留
        输出:
            output 经过广义attention得到的结果 [batch, seq_len_from, n_head*head_dim]
        """
        from_tensor, to_tensor = inputs[:2]
        from_shape = modeling.get_shape_list(from_tensor, expected_rank=3)
        to_shape = modeling.get_shape_list(to_tensor, expected_rank=3)
        # assert from_shape[0] == to_shape[0]
        # assert from_shape[-1] == to_shape[-1] == self.dim
        batch = from_shape[0]
        seq_len_from = from_shape[1]
        seq_len_to = to_shape[1]
        # 计算维度

        query = self.query_layer(from_tensor)
        # query [batch, seq_len_from, inner_dim]
        key = self.key_layer(to_tensor)
        # key [batch, seq_len_to, inner_dim]
        value = self.value_layer(to_tensor)
        # value [batch, seq_len_to, inner_dim]

        query = self.transpose_for_scores(query, batch, self.n_head, seq_len_from, self.head_dim)
        key = self.transpose_for_scores(key, batch, self.n_head, seq_len_to, self.head_dim)
        value = self.transpose_for_scores(value, batch, self.n_head, seq_len_to, self.head_dim)
        # query [batch, n_head, seq_len_from, head_dim]
        # key value [batch, n_head, seq_len_to, head_dim]

        query = query * self.scale
        # query = query / 根号下dk

        attention_score = tf.einsum('...ik,...jk->...ij', query, key)
        # i->seq_len_from j->seq_len_to k->head_dim
        # attention_score [batch, n_head, seq_len_from, seq_len_to]

        if len(inputs) == 3:
            attention_mask = inputs[2]
            # [batch, seq_len_from, seq_len_to]
            attention_mask = tf.expand_dims(attention_mask, axis=[1])
            # [batch, 1, seq_len_from, seq_len_to]
            adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0
            # 对于0进行mask对于1不mask 所以(1-mask_val) * -10000.0 0->-10000.0 1->0.0
            attention_score += adder
            # attention_score加上adder就是mask

        attention_probs = tf.nn.softmax(attention_score)
        # attention_score经过softmax得到概率分布
        # attention_probs [batch, n_head, seq_len_from, seq_len_to]

        output = tf.einsum('...ij,...jk->...ik', attention_probs, value)
        # i->seq_len_from j->seq_len_to k->head_dim
        # output [batch, n_head, seq_len_from, head_dim]

        output = tf.transpose(output, [0, 2, 1, 3])
        output = tf.reshape(output, shape=(batch, seq_len_from, self.inner_dim))
        output = self.output_layer(output)
        # [batch, seq_len_from, dim]
        return output


class LMBlock(tf.keras.Model):
    """LM Block 本质为self-attention + ffw"""
    def __init__(self, dim=768, n_head=12, ffw_mult=4):
        """
        参数:
            dim: cross attention变量的feature维度
            n_head: multi-head attention的head个数
            ffw_mult: feedforward的参数量相对于feature的倍数 经验值为4
        """
        super(LMBlock, self).__init__()
        self.dim = dim
        self.n_head = n_head
        self.ffw_mult = ffw_mult
        if self.dim % n_head != 0:
            raise ValueError(
                "参数dim必须整除n_head"
            )
        self.head_dim = self.dim // self.n_head
        self.all_layers = list()
        self.attn = Attention(dim=self.dim, n_head=self.n_head, head_dim=self.head_dim)
        self.norm = tf.keras.layers.LayerNormalization()
        self.ffw = utils.FeedForward(dim=self.dim, mult=self.ffw_mult)

    def call(self, inputs, training=None, mask=None):
        """
        输入:
            input_tensor: 输入的input_tensor [batch, seq_len, dim]
            目前LMBlock实现中 并没有mask 原因在于Flamingo所对应的场景不需要进行mask
        输出:
            output: 经过LM计算后的结果 [batch, seq_len, dim]
        """
        input_tensor = inputs
        input_shape = modeling.get_shape_list(input_tensor, expected_rank=3)
        assert self.dim == input_shape[2]

        # input_tensor = self.pre_norm(input_tensor)
        input_tensor = input_tensor + self.attn(inputs=(input_tensor, input_tensor))
        # self_attention + residual
        input_tensor = self.norm(input_tensor)
        # attn之后layer norm
        input_tensor = input_tensor + self.ffw(inputs=input_tensor)
        # ffw + residual
        return input_tensor

In [16]:
lm_block = LMBlock()

In [17]:
lm_block.build(input_shape=(None, None, 768))

In [13]:
input_tensor = tf.random.normal(shape=(32, 64, 768))
output = lm_block(input_tensor)
output.shape

TensorShape([Dimension(32), Dimension(64), Dimension(768)])

In [28]:
lm_block.summary()

Model: "lm_block_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
layer_normalization_6 (Layer multiple                  1536      
_________________________________________________________________
attention_3 (Attention)      multiple                  2362368   
_________________________________________________________________
feed_forward_3 (FeedForward) multiple                  4723968   
Total params: 7,087,872
Trainable params: 7,087,872
Non-trainable params: 0
_________________________________________________________________


In [18]:
for layer in lm_block.layers:
    print('-------------------------')
    print(f'layer name {layer.name}')
    for weight in layer.weights:
        print(weight.name, weight.shape)

-------------------------
layer name attention_2
attention_2/query/kernel:0 (768, 768)
attention_2/query/bias:0 (768,)
attention_2/key/kernel:0 (768, 768)
attention_2/key/bias:0 (768,)
attention_2/value/kernel:0 (768, 768)
attention_2/value/bias:0 (768,)
attention_2/output/kernel:0 (768, 768)
attention_2/output/bias:0 (768,)
-------------------------
layer name layer_normalization_4
layer_normalization_4/gamma:0 (768,)
layer_normalization_4/beta:0 (768,)
-------------------------
layer name feed_forward_2
feed_forward_2/dense_4/kernel:0 (768, 3072)
feed_forward_2/dense_4/bias:0 (3072,)
feed_forward_2/dense_5/kernel:0 (3072, 768)
feed_forward_2/dense_5/bias:0 (768,)
feed_forward_2/layer_normalization_5/gamma:0 (768,)
feed_forward_2/layer_normalization_5/beta:0 (768,)


In [19]:
for layer in model.layers:
    print('-------------------------')
    print(f'layer name {layer.name}')
    for weight in layer.weights:
        print(weight.name, weight.shape)

-------------------------
layer name Input-Token
-------------------------
layer name Input-Segment
-------------------------
layer name Embedding-Token
Embedding-Token/embeddings:0 (21128, 768)
-------------------------
layer name Embedding-Segment
Embedding-Segment/embeddings:0 (2, 768)
-------------------------
layer name Embedding-Token-Segment
-------------------------
layer name Embedding-Position
Embedding-Position/embeddings:0 (512, 768)
-------------------------
layer name Embedding-Norm
Embedding-Norm/beta:0 (768,)
Embedding-Norm/gamma:0 (768,)
-------------------------
layer name Embedding-Dropout
-------------------------
layer name Transformer-0-MultiHeadSelfAttention
Transformer-0-MultiHeadSelfAttention/dense_1/kernel:0 (768, 768)
Transformer-0-MultiHeadSelfAttention/dense_1/bias:0 (768,)
Transformer-0-MultiHeadSelfAttention/dense_2/kernel:0 (768, 768)
Transformer-0-MultiHeadSelfAttention/dense_2/bias:0 (768,)
Transformer-0-MultiHeadSelfAttention/dense_3/kernel:0 (768, 76

In [21]:
print(lm_block.get_layer(index=0).get_weights()[0])

[[-0.05685756 -0.01691563 -0.18896337 ...  0.24470313 -0.10127556
   0.01406233]
 [-0.05613437 -0.10988988  0.35343567 ... -0.06752219 -0.22311936
  -0.04273423]
 [-0.13579732 -0.18542694  0.0312222  ...  0.05397794 -0.02238331
  -0.23441242]
 ...
 [-0.06893474 -0.28299838  0.2841855  ...  0.0912203  -0.13070367
   0.13710706]
 [ 0.16877264  0.27536106  0.20753025 ... -0.06487323  0.0637659
  -0.15097073]
 [-0.04221877 -0.18164001  0.02198964 ... -0.07176036 -0.04935543
   0.19380973]]


In [39]:
lm_block_layer1 = lm_block.get_layer(index=0)
bert_layer8 = model.get_layer(name='Transformer-0-MultiHeadSelfAttention')
lm_block_layer1_weight = lm_block_layer1.get_weights()
bert_layer8_weight = bert_layer8.get_weights()
# 得到lm_block的第二层和bert的第八层(transformer的第一层)

In [40]:
for i in range(len(lm_block_layer1_weight)):
    print(lm_block_layer1_weight[i].shape == bert_layer8_weight[i].shape)
    # 维度完全相同

True
True
True
True
True
True
True
True


In [24]:
lm_block_layer1.set_weights(bert_layer8_weight)

In [25]:
print(lm_block.get_layer(index=0).get_weights()[0])

[[ 1.1491621e-01  6.5357164e-03  1.3464041e-02 ... -5.0526168e-02
   1.7295334e-02  1.5026121e-02]
 [-9.4047729e-03 -2.2129135e-02  2.4954821e-03 ...  2.3847362e-02
  -8.8263817e-02 -2.9466402e-02]
 [ 5.8782017e-03 -6.7798183e-03  3.5073500e-02 ...  1.0479877e-02
  -5.6218460e-02 -1.5802451e-03]
 ...
 [ 1.3729039e-02  4.7092853e-05  1.1137443e-01 ...  5.6518074e-02
  -4.2443074e-02  9.7083911e-02]
 [ 1.5892318e-02  4.5497168e-02 -3.1567805e-02 ...  4.1080634e-03
   3.8016669e-02 -3.4227695e-02]
 [-8.1710182e-02  1.3203139e-02 -1.4776111e-02 ...  7.1023442e-02
  -1.5387528e-02  2.2910309e-03]]


In [26]:
print(lm_block.get_layer(index=1).get_weights())

[array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1

In [27]:
lm_block_layer2 = lm_block.get_layer(index=1)
bert_layer11 = model.get_layer(index=11)
lm_block_layer2_weight = lm_block_layer2.get_weights()
bert_layer11_weight = bert_layer11.get_weights()

In [28]:
for i in range(len(lm_block_layer2_weight)):
    print(lm_block_layer2_weight[i].shape == bert_layer11_weight[i].shape)
    # 维度完全相同

True
True


In [29]:
lm_block_layer2.set_weights(bert_layer11_weight)

In [30]:
print(lm_block.get_layer(index=1).get_weights())

[array([ 1.16601745e-02,  5.40437289e-02, -9.16207060e-02,  1.05281323e-01,
        3.48078430e-01, -1.68094128e-01,  7.95071498e-02, -8.63028504e-03,
       -6.27969205e-02, -2.64445066e-01,  2.60950685e-01, -2.48833984e-01,
       -1.17190272e-01, -1.13512546e-01,  1.55974582e-01, -2.52695531e-02,
        1.71177953e-01, -5.15815839e-02,  7.09559023e-02, -1.18340828e-01,
        2.12990418e-02,  1.26445740e-01,  1.75389534e-04, -1.34746641e-01,
       -1.15555532e-01, -2.60010883e-02,  7.14812949e-02, -7.86032081e-02,
       -1.22626677e-01,  2.22578004e-01,  5.76628707e-02, -1.35067642e-01,
       -4.47699465e-02, -2.01051325e-01,  5.94254211e-02,  1.82048753e-01,
        1.29285902e-02, -1.16891749e-02, -2.27592885e-02,  2.86168993e-01,
        2.68879887e-02, -1.13643475e-01,  3.20790112e-02, -9.24040452e-02,
       -8.81051831e-03,  1.03911594e-01,  2.40146920e-01,  2.45545626e-01,
       -3.97668779e-02, -1.32390097e-01,  1.58876464e-01,  1.41543996e+00,
        3.39075364e-02, 

In [31]:
print(lm_block.get_layer(index=2).get_weights())

[array([[-0.00458032, -0.02573658,  0.02269527, ...,  0.02642405,
        -0.01966409, -0.02274757],
       [ 0.01214657, -0.02762798, -0.03058223, ...,  0.01804959,
         0.00809181, -0.00726894],
       [-0.03293017,  0.00669291,  0.02711615, ..., -0.01748941,
        -0.0174475 ,  0.0022877 ],
       ...,
       [ 0.01918899,  0.0020614 ,  0.03002295, ..., -0.01138652,
        -0.00483538, -0.00028231],
       [-0.02834468,  0.01324727,  0.02967318, ..., -0.01023217,
        -0.02230434, -0.01218567],
       [-0.03690905,  0.02856971,  0.02101301, ..., -0.01279723,
        -0.00109855,  0.03700184]], dtype=float32), array([0., 0., 0., ..., 0., 0., 0.], dtype=float32), array([[-0.00085744, -0.01362677,  0.0121907 , ...,  0.01864001,
         0.03681674, -0.01593674],
       [ 0.03150414, -0.01039497,  0.03417247, ..., -0.00574265,
         0.00591145,  0.02696376],
       [ 0.00651057,  0.00752371,  0.00105783, ...,  0.00465839,
         0.03282674,  0.00777529],
       ...,
     

In [32]:
lm_block_layer3 = lm_block.get_layer(index=2)
bert_layer12 = model.get_layer(index=12)
bert_layer15 = model.get_layer(index=15)
lm_block_layer3_weight = lm_block_layer3.get_weights()
bert_layer12_weight = bert_layer12.get_weights()
bert_layer15_weight = bert_layer15.get_weights()

In [34]:
for i in range(len(bert_layer12_weight)):
    print(lm_block_layer3_weight[i].shape == bert_layer12_weight[i].shape)
for i in range(len(bert_layer15_weight)):
    print(lm_block_layer3_weight[i+4].shape == bert_layer15_weight[i].shape)

True
True
True
True
True
True


In [36]:
combined_weight = bert_layer12_weight + bert_layer15_weight
len(combined_weight)

6

In [37]:
lm_block_layer3.set_weights(combined_weight)

In [38]:
print(lm_block.get_layer(index=2).get_weights())

[array([[-0.0058011 ,  0.05985398, -0.03170604, ...,  0.06887805,
         0.02423428, -0.03973034],
       [-0.007228  , -0.02539035, -0.00167863, ...,  0.01907857,
         0.01626231, -0.01045543],
       [-0.04602809,  0.03800549, -0.00385575, ..., -0.02603067,
        -0.06047573, -0.00679531],
       ...,
       [ 0.00474144,  0.10876587, -0.01808572, ...,  0.04912699,
        -0.0229429 ,  0.00875449],
       [-0.02682151, -0.10051435, -0.00056017, ..., -0.05286248,
        -0.04268262, -0.03293934],
       [-0.00666132,  0.02399873,  0.00313861, ..., -0.10766514,
         0.09482101,  0.0334359 ]], dtype=float32), array([-0.03333926, -0.15175405,  0.03061958, ..., -0.12057822,
       -0.08789343, -0.11828085], dtype=float32), array([[-0.04991841,  0.0067661 ,  0.00231446, ...,  0.04711771,
         0.07601267, -0.04818509],
       [ 0.03975736, -0.0166433 ,  0.0217378 , ..., -0.00293149,
        -0.03669158, -0.05077926],
       [ 0.02046659,  0.0257666 ,  0.06818264, ..., -0.0

In [67]:
bert_chinese_params = dict()

In [68]:
for layer_num in range(12):
    each_layer_params = dict()
    attention_layer = model.get_layer(name=f'Transformer-{str(layer_num)}-MultiHeadSelfAttention')
    attn_layer_norm = model.get_layer(name=f'Transformer-{str(layer_num)}-MultiHeadSelfAttention-Norm')
    ffw = model.get_layer(name=f'Transformer-{str(layer_num)}-FeedForward')
    ffw_layer_norm = model.get_layer(name=f'Transformer-{str(layer_num)}-FeedForward-Norm')

    attention_layer_weights = attention_layer.get_weights()
    attn_layer_norm_weights = attn_layer_norm.get_weights()
    ffw_weights = ffw.get_weights()
    ffw_layer_norm_weights = ffw_layer_norm.get_weights()
    ffw_plus_ffw_layer_norm_weights = ffw_weights + ffw_layer_norm_weights

    each_layer_params['attn'] = attention_layer_weights
    each_layer_params['attn_norm'] = attn_layer_norm_weights
    each_layer_params['ffw'] = ffw_plus_ffw_layer_norm_weights

    bert_chinese_params[layer_num] = each_layer_params

In [69]:
token_embedding = model.get_layer(name='Embedding-Token')
token_embedding_weights = token_embedding.get_weights()
token_segment_embedding = model.get_layer(name='Embedding-Segment')
token_segment_embedding_weights = token_segment_embedding.get_weights()
position_embedding = model.get_layer(name='Embedding-Position')
position_embedding_weights = position_embedding.get_weights()

In [70]:
token_embedding_weights[0].shape

(21128, 768)

In [71]:
token_segment_embedding_weights[0].shape

(2, 768)

In [72]:
position_embedding_weights[0].shape

(512, 768)

In [73]:
bert_chinese_params['word_emb'] = token_embedding_weights
bert_chinese_params['token_type_emb'] = token_segment_embedding_weights
bert_chinese_params['position_emb'] = position_embedding_weights

In [74]:
import pickle

In [75]:
pickle_file = open('all_bert_chinese_L-12_H-768_A-12_params.pkl', 'wb')
pickle.dump(bert_chinese_params, pickle_file)
pickle_file.close()