# BERT预训练

两种预训练任务：
- `Mask LM`：将标记序列中一定数量的标记遮挡，然后预测该处的标记
- `Next Sentence Prediction`：标记序列的两部分在真实文本中是不是连续的
<img src="../images/pretrained_tasks.png" width="100%">

In [1]:
import tensorflow as tf

## 获取预训练数据
加载 [38-BERT创建训练数据(Tensorflow)](38-BERT创建训练数据(Tensorflow).ipynb) 中创建然后保存为 `TFRecord` 的预训练数据

In [2]:
# 读取 TFRecord，转换数据类型 tf.int64 --> tf.int32
def decode_record(record, name_to_features):
    """Parse one serialized example and downcast its int64 features to int32.

    Args:
        record: A serialized example, as accepted by
            `tf.io.parse_single_example`.
        name_to_features: Mapping from feature name to feature spec, passed
            straight to `tf.io.parse_single_example`.

    Returns:
        The parsed feature mapping in which every `tf.int64` value has been
        cast to `tf.int32`; values of any other dtype are left untouched.
    """
    parsed = tf.io.parse_single_example(record, name_to_features)
    # Iterate over a snapshot so the mapping can be updated while looping.
    for key, value in list(parsed.items()):
        if value.dtype == tf.int64:
            parsed[key] = tf.cast(value, tf.int32)
    return parsed


# 读取 TFRecord 文件
def single_file_dataset(input_file, name_to_features):
    """Create a decoded `TFRecordDataset` from one file or a list of files.

    Args:
        input_file: A path string, or a list of path strings.
        name_to_features: Mapping from feature name to feature spec, used by
            `decode_record` to parse every record.

    Returns:
        A dataset of decoded records. When `input_file` is a single path (a
        string or a one-element list) auto sharding is switched off so that
        the same input file is sent to all workers.
    """
    def _parse(serialized):
        return decode_record(serialized, name_to_features)

    records = tf.data.TFRecordDataset(input_file).map(_parse)

    is_single_file = isinstance(input_file, str) or len(input_file) == 1
    if not is_single_file:
        return records

    # A single file cannot be split across workers, so disable auto sharding.
    no_shard = tf.data.Options()
    no_shard.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.OFF)
    return records.with_options(no_shard)


def create_pretrain_dataset(input_patterns,
                            seq_length,
                            max_predictions_per_seq,
                            batch_size,
                            is_training=True,
                            input_pipeline_context=None):
    """Build the batched `tf.data` pipeline that feeds BERT pretraining.

    Args:
        input_patterns: List of file patterns (globs) of TFRecord files.
        seq_length: Fixed length of `input_ids`, `input_mask` and
            `segment_ids` in every record.
        max_predictions_per_seq: Fixed length of the `masked_lm_*` features
            in every record.
        batch_size: Number of examples per batch; incomplete final batches
            are dropped.
        is_training: When true, file order and examples are shuffled. When
            false, no shuffling is applied at any level.
        input_pipeline_context: Optional context object exposing
            `num_input_pipelines` and `input_pipeline_id`; when it reports
            more than one pipeline the file list is sharded accordingly.

    Returns:
        A repeated, batched and prefetched dataset of `(x, y)` pairs. `x` is
        a dict with the keys `input_word_ids`, `input_mask`,
        `input_type_ids`, `masked_lm_positions`, `masked_lm_ids`,
        `masked_lm_weights` and `next_sentence_labels`; `y` is the
        `masked_lm_weights` tensor.
    """
    # Schema of the serialized examples.
    name_to_features = {
        'input_ids':
        tf.io.FixedLenFeature([seq_length], tf.int64),
        'input_mask':
        tf.io.FixedLenFeature([seq_length], tf.int64),
        'segment_ids':
        tf.io.FixedLenFeature([seq_length], tf.int64),
        'masked_lm_positions':
        tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
        'masked_lm_ids':
        tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
        'masked_lm_weights':
        tf.io.FixedLenFeature([max_predictions_per_seq], tf.float32),
        'next_sentence_labels':
        tf.io.FixedLenFeature([1], tf.int64),
    }

    # Dataset of matching file names (shuffled only for training).
    dataset = tf.data.Dataset.list_files(input_patterns, shuffle=is_training)

    # Give every input pipeline its own shard of the file list.
    if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
        dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
                                input_pipeline_context.input_pipeline_id)

    dataset = dataset.repeat()

    # Shuffle file names with a buffer as large as the number of matching
    # files. Only done for training: shuffling here during evaluation would
    # undo the deterministic order requested from `list_files` above.
    if is_training:
        num_input_files = sum(
            len(tf.io.gfile.glob(input_pattern))
            for input_pattern in input_patterns)
        dataset = dataset.shuffle(num_input_files)

    # Read up to 8 files concurrently and interleave their records.
    dataset = dataset.interleave(
        tf.data.TFRecordDataset,
        cycle_length=8,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # Parse every record and downcast int64 features to int32.
    decode_fn = lambda record: decode_record(record, name_to_features)
    dataset = dataset.map(decode_fn,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)

    def _select_data_from_record(record):
        """Rename the parsed features to the model's input names."""
        x = {
            'input_word_ids': record['input_ids'],
            'input_mask': record['input_mask'],
            'input_type_ids': record['segment_ids'],
            'masked_lm_positions': record['masked_lm_positions'],
            'masked_lm_ids': record['masked_lm_ids'],
            'masked_lm_weights': record['masked_lm_weights'],
            'next_sentence_labels': record['next_sentence_labels'],
        }
        y = record['masked_lm_weights']
        return (x, y)

    dataset = dataset.map(_select_data_from_record,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
    if is_training:
        dataset = dataset.shuffle(100)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(1024)
    return dataset

In [29]:
import random

input_files = "../datasets/sample.record"
input_patterns = input_files.split(',')
dataset = create_pretrain_dataset(input_patterns,
                                  seq_length=128,
                                  max_predictions_per_seq=20,
                                  batch_size=32)

# Inspect the first batch only: for every feature, print one randomly chosen
# sample of that feature.
for data in dataset:
    x, y = data
    for k in x:
        batch_size = x[k].shape[0]
        # `randrange` excludes its upper bound, so `ind` is always a valid
        # row index; `randint(0, batch_size)` could return `batch_size`
        # itself and index one past the end of the batch.
        ind = random.randrange(batch_size)
        print(k + f" batch_size: {batch_size} ; the {ind}th sample: \n")
        print(x[k][ind])
        print("=" * 80)
    break

input_word_ids batch_size: 32 ; the 24th sample: 

tf.Tensor(
[  101  1103  4458  1125  1178  6445   103  1103  5021 24177  1104  2106
  1120 20013  2851   103  1105  1103  3433 21167  1106   170  7279  2305
  1104  4044   103   117  1105  1103  4006  1104  6278 14726   117 14086
 13624   117  1105  2964  3227   190 23826   103   103  1187  1103   103
  5946  1116  1125  8589  1283  1103 11829   103   103 19943  1116  1196
  1103  6493  3581   119   102  1103 12325  1104   103 12304   117   177
  1183   103 10691  1941   119   102     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0], shape=(128,), dtype=int32)
input_mask batch_size: 32 ; the 15th sample: 

tf.Tensor(
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1

In [32]:
# Within one batch, the number of non-zero masked positions differs from
# sequence to sequence; print that count for every sequence.
masked_flags = x['masked_lm_positions'].numpy() != 0
for row in masked_flags:
    print(sum(row))

8
10
19
8
19
19
10
12
19
6
8
19
19
8
19
10
19
19
19
19
19
19
6
19
12
19
19
9
8
19
19
19


In [None]:
# 读取数据封装成函数

def get_pretrain_dataset_fn(input_file_pattern, seq_length,
                            max_predictions_per_seq, global_batch_size):
    """Return a factory that builds the pretraining dataset.

    Args:
        input_file_pattern: Comma-separated string of input file patterns.
        seq_length: Sequence length forwarded to `create_pretrain_dataset`.
        max_predictions_per_seq: Number of masked positions per sequence,
            forwarded to `create_pretrain_dataset`.
        global_batch_size: Batch size across all replicas.

    Returns:
        A function `_dataset_fn(ctx=None)` that returns the training dataset.
    """
    def _dataset_fn(ctx=None):
        """Build the training dataset for the given input context.

        Args:
            ctx: Optional context object providing
                `get_per_replica_batch_size`; it is also forwarded as
                `input_pipeline_context`. When omitted, the global batch
                size is used unchanged.
        """
        input_patterns = input_file_pattern.split(',')
        # `ctx` defaults to None, so it must not be dereferenced
        # unconditionally; without a context the global batch is not split.
        batch_size = (ctx.get_per_replica_batch_size(global_batch_size)
                      if ctx else global_batch_size)
        # NOTE(review): `input_pipeline` is not imported in the visible part
        # of this file — confirm it resolves, or call the local
        # `create_pretrain_dataset` instead.
        train_dataset = input_pipeline.create_pretrain_dataset(
            input_patterns,
            seq_length,
            max_predictions_per_seq,
            batch_size,
            is_training=True,
            input_pipeline_context=ctx,
        )
        return train_dataset

    return _dataset_fn

## 模型

### `Mask LM`
- 从`transformer`编码器的输出，和 被遮挡标记位置列表，获取被遮挡标记的输出

- 将被遮挡标记的输出与 `transformer` 的 `embedding` 矩阵相乘，得到预测输出，即每个被遮挡标记在词汇表上的概率分布

In [None]:
# @tf.keras.utils.register_keras_serializable(package='Text')
class MaskedLM(network.Network):
    """Masked language model network head for BERT modeling.

    This network implements a masked language model based on the provided
    network. It assumes that the network being passed has a
    "get_embedding_table()" method.

    Attributes:
        input_width: The innermost dimension of the input tensor to this
            network.
        num_predictions: The number of predictions to make per sequence.
        source_network: The network whose embedding table is reused for the
            output projection.
        activation: The activation, if any, for the dense layer in this
            network.
        initializer: The initializer for the dense layer in this network.
            Defaults to a Glorot uniform initializer.
        output: The output style for this network. Can be either 'logits' or
            'predictions'.
    """
    def __init__(
            self,
            input_width,
            num_predictions,
            source_network,  # network that provides get_embedding_table()
            activation=None,
            initializer='glorot_uniform',
            output='logits',
            **kwargs):

        embedding_table = source_network.get_embedding_table()
        vocab_size, hidden_size = embedding_table.shape

        # Encoder sequence output: (batch, seq_len, input_width).
        sequence_data = tf.keras.layers.Input(shape=(None, input_width),
                                              name='sequence_data',
                                              dtype=tf.float32)
        # Positions of the masked tokens: (batch, num_predictions).
        masked_lm_positions = tf.keras.layers.Input(shape=(num_predictions, ),
                                                    name='masked_lm_positions',
                                                    dtype=tf.int32)

        # Encoder outputs gathered at the masked positions:
        # (batch * num_predictions, input_width).
        masked_lm_input = tf.keras.layers.Lambda(
            lambda x: self._gather_indexes(x[0], x[1]))(
                [sequence_data, masked_lm_positions])

        # Dense transform followed by layer normalization.
        lm_data = (tf.keras.layers.Dense(
            hidden_size,
            activation=activation,
            kernel_initializer=initializer,
            name='cls/predictions/transform/dense',
        )(masked_lm_input))
        lm_data = tf.keras.layers.LayerNormalization(
            axis=-1, epsilon=1e-12,
            name='cls/predictions/transform/LayerNorm')(lm_data)

        # Dot product with every row of the embedding table
        # (x @ embedding_table^T), then the `Bias` layer (defined elsewhere).
        lm_data = tf.keras.layers.Lambda(
            lambda x: tf.matmul(x, embedding_table, transpose_b=True))(lm_data)
        logits = Bias(initializer=tf.keras.initializers.Zeros(),
                      name='cls/predictions/output_bias')(lm_data)

        # We can't use the standard Keras reshape layer here, since it expects
        # the input and output batch size to be the same.
        reshape_layer = tf.keras.layers.Lambda(
            lambda x: tf.reshape(x, [-1, num_predictions, vocab_size]))

        # (batch, num_predictions, vocab_size)
        self.logits = reshape_layer(logits)

        # Log-softmax over the logits: these are log-probabilities, not
        # probabilities.
        predictions = tf.keras.layers.Activation(tf.nn.log_softmax)(
            self.logits)

        if output == 'logits':
            output_tensors = self.logits
        elif output == 'predictions':
            output_tensors = predictions
        else:
            raise ValueError((
                'Unknown `output` value "%s". `output` can be either "logits" or '
                '"predictions"') % output)

        super(MaskedLM,
              self).__init__(inputs=[sequence_data, masked_lm_positions],
                             outputs=output_tensors,
                             **kwargs)

    def get_config(self):
        """Always raises: this network does not support direct serialization."""
        raise NotImplementedError(
            'MaskedLM cannot be directly serialized at this '
            'time. Please use it only in Layers or '
            'functionally subclassed Models/Networks.')

    def _gather_indexes(self, sequence_tensor, positions):
        """Gathers the vectors at the specific positions.

        Args:
            sequence_tensor: Sequence output of `BertModel` layer of shape
              (`batch_size`, `seq_length`, num_hidden) where num_hidden is
              number of hidden units of `BertModel` layer.
            positions: Position ids of tokens in sequence to mask for
              pretraining, with dimension (batch_size, num_predictions) where
              `num_predictions` is maximum number of tokens to mask out and
              predict per each sequence.

        Returns:
            Masked out sequence tensor of shape (batch_size * num_predictions,
            num_hidden).
        """
        sequence_shape = tf_utils.get_shape_list(sequence_tensor,
                                                 name='sequence_output_tensor')
        batch_size, seq_length, width = sequence_shape

        # Offset of the first token of every sequence once the batch and
        # sequence dimensions are flattened into one: (batch_size, 1).
        flat_offsets = tf.keras.backend.reshape(
            tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
        # Per-sequence positions turned into indices of the flattened tensor.
        flat_positions = tf.keras.backend.reshape(positions + flat_offsets,
                                                  [-1])
        flat_sequence_tensor = tf.keras.backend.reshape(
            sequence_tensor, [batch_size * seq_length, width])
        output_tensor = tf.gather(flat_sequence_tensor, flat_positions)

        return output_tensor

### `NSP`
- 开始标记`[CLS]`对应的输出，再加一个二分类的全连接层

In [None]:
# @tf.keras.utils.register_keras_serializable(package='Text')
class Classification(network.Network):
    """Classification network head for BERT modeling.

    This network implements a simple classifier head based on a dense layer.

    Attributes:
        input_width: The innermost dimension of the input tensor to this
            network.
        num_classes: The number of classes that this network should classify
            to.
        initializer: The initializer for the dense layer in this network.
            Defaults to a Glorot uniform initializer.
        output: The output style for this network. Can be either 'logits' or
            'predictions'.
    """
    def __init__(self,
                 input_width,
                 num_classes,
                 initializer='glorot_uniform',
                 output='logits',
                 **kwargs):
        # NOTE(review): presumably turns off Keras attribute tracking so that
        # attributes can be assigned before `super().__init__` — confirm.
        self._self_setattr_tracking = False
        # Constructor arguments, returned verbatim by `get_config`.
        self._config_dict = {
            'input_width': input_width,
            'num_classes': num_classes,
            'initializer': initializer,
            'output': output,
        }

        cls_output = tf.keras.layers.Input(shape=(input_width, ),
                                           name='cls_output',
                                           dtype=tf.float32)


        # Dense classification layer producing `num_classes` logits.
        self.logits = tf.keras.layers.Dense(
            num_classes,
            activation=None,
            kernel_initializer=initializer,
            name='predictions/transform/logits')(cls_output)
        # Log-softmax of the logits (log-probabilities).
        predictions = tf.keras.layers.Activation(tf.nn.log_softmax)(
            self.logits)

        if output == 'logits':
            output_tensors = self.logits
        elif output == 'predictions':
            output_tensors = predictions
        else:
            raise ValueError((
                'Unknown `output` value "%s". `output` can be either "logits" or '
                '"predictions"') % output)

        super(Classification, self).__init__(inputs=[cls_output],
                                             outputs=output_tensors,
                                             **kwargs)

    def get_config(self):
        """Return the constructor arguments recorded in `__init__`."""
        return self._config_dict

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Rebuild the network from a config dict; `custom_objects` is unused."""
        return cls(**config)

### 整合了预训练任务的模型

In [None]:
class BertPretrainer(tf.keras.Model):
    """Encoder network combined with the Masked LM and NSP pretraining heads.

    The model takes the inputs of `network` plus one extra
    `masked_lm_positions` input and returns `[lm_outputs, sentence_outputs]`.

    Args:
        network: Transformer encoder. Calling it on its own inputs must
            return `(sequence_output, cls_output)`, and it is passed to
            `MaskedLM` as `source_network`.
        num_classes: Number of classes of the sentence-level classifier.
        num_token_predictions: Number of masked positions predicted per
            sequence.
        activation: Activation of the dense layer inside `MaskedLM`.
        output_activation: Recorded in the config; not otherwise used here.
        initializer: Initializer passed to both heads.
        output: Output style of both heads, 'logits' or 'predictions'.
        **kwargs: Forwarded to `tf.keras.Model.__init__`.

    Raises:
        ValueError: If the encoder's sequence length is statically known and
            smaller than `num_token_predictions`.
    """
    def __init__(
            self,
            network,  # transformer encoder
            num_classes,
            num_token_predictions,
            activation=None,
            output_activation=None,
            initializer='glorot_uniform',
            output='logits',
            **kwargs):
        self._self_setattr_tracking = False
        self._config = {
            'network': network,
            'num_classes': num_classes,
            'num_token_predictions': num_token_predictions,
            'activation': activation,
            'output_activation': output_activation,
            'initializer': initializer,
            'output': output,
        }

        # We want to use the inputs of the passed network as the inputs to this
        # Model. To do this, we need to keep a copy of the network inputs for use
        # when we construct the Model object at the end of init. (We keep a copy
        # because we'll be adding another tensor to the copy later.)
        network_inputs = network.inputs
        inputs = copy.copy(network_inputs)

        # Because we have a copy of inputs to create this Model object, we can
        # invoke the Network object with its own input tensors to start the Model.

        # `sequence_output` holds the outputs of all tokens; `cls_output` is
        # the output used for the NSP task.
        sequence_output, cls_output = network(network_inputs)

        # The sequence dimension may be unknown (None) at graph-construction
        # time; comparing None with an int raises TypeError, so the length
        # check only runs when the dimension is static.
        sequence_output_length = sequence_output.shape.as_list()[1]
        if (sequence_output_length is not None
                and sequence_output_length < num_token_predictions):
            raise ValueError(
                "The passed network's output length is %s, which is less than the "
                'requested num_token_predictions %s.' %
                (sequence_output_length, num_token_predictions))

        # Masked LM task.
        masked_lm_positions = tf.keras.layers.Input(
            shape=(num_token_predictions, ),
            name='masked_lm_positions',
            dtype=tf.int32,
        )
        inputs.append(masked_lm_positions)

        self.masked_lm = MaskedLM(
            num_predictions=num_token_predictions,
            input_width=sequence_output.shape[-1],
            source_network=network,
            activation=activation,
            initializer=initializer,
            output=output,
            name='masked_lm',
        )
        lm_outputs = self.masked_lm([sequence_output, masked_lm_positions])

        # NSP task.
        self.classification = Classification(
            input_width=cls_output.shape[-1],
            num_classes=num_classes,
            initializer=initializer,
            output=output,
            name='classification',
        )
        sentence_outputs = self.classification(cls_output)

        super(BertPretrainer,
              self).__init__(inputs=inputs,
                             outputs=[lm_outputs, sentence_outputs],
                             **kwargs)

    def get_config(self):
        """Return the constructor arguments recorded in `__init__`."""
        return self._config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Rebuild the model from a config dict; `custom_objects` is unused."""
        return cls(**config)

### 损失函数

In [None]:
class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
    """Returns layer that computes custom loss and metrics for pretraining."""
    def __init__(self, vocab_size, **kwargs):
        """Store `vocab_size`; remaining kwargs go to the Keras `Layer`."""
        super(BertPretrainLossAndMetricLayer, self).__init__(**kwargs)
        self._vocab_size = vocab_size
        self.config = {
            'vocab_size': vocab_size,
        }

    def __call__(self,
                 lm_output,
                 sentence_output=None,
                 lm_label_ids=None,
                 lm_label_weights=None,
                 sentence_labels=None,
                 **kwargs):
        """Pack the five tensors into one input and invoke the layer.

        The packing order here is the order `call` relies on when it unpacks
        them by index.
        """
        inputs = tf_utils.pack_inputs([
            lm_output, sentence_output, lm_label_ids, lm_label_weights,
            sentence_labels
        ])
        return super(BertPretrainLossAndMetricLayer,
                     self).__call__(inputs, **kwargs)

    def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
                     lm_example_loss, sentence_output, sentence_labels,
                     next_sentence_loss):
        """Adds metrics."""
        masked_lm_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
            lm_labels, lm_output)
        # Weighted mean accuracy over the masked positions; the small
        # constant keeps the denominator non-zero when all weights are zero.
        numerator = tf.reduce_sum(masked_lm_accuracy * lm_label_weights)
        denominator = tf.reduce_sum(lm_label_weights) + 1e-5
        masked_lm_accuracy = numerator / denominator
        self.add_metric(masked_lm_accuracy,
                        name='masked_lm_accuracy',
                        aggregation='mean')

        self.add_metric(lm_example_loss,
                        name='lm_example_loss',
                        aggregation='mean')

        next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
            sentence_labels, sentence_output)
        self.add_metric(next_sentence_accuracy,
                        name='next_sentence_accuracy',
                        aggregation='mean')

        self.add_metric(next_sentence_loss,
                        name='next_sentence_loss',
                        aggregation='mean')

    def call(self, inputs):
        """Implements call() for the layer."""
        unpacked_inputs = tf_utils.unpack_inputs(inputs)

        # Masked LM output.
        lm_output = unpacked_inputs[0]

        # NSP output.
        sentence_output = unpacked_inputs[1]

        # Masked LM labels.
        lm_label_ids = unpacked_inputs[2]

        # Masked LM label weights, cast to float32.
        lm_label_weights = tf.keras.backend.cast(unpacked_inputs[3],
                                                 tf.float32)
        # NSP labels.
        sentence_labels = unpacked_inputs[4]

        # Masked LM loss.
        mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
            labels=lm_label_ids,
            predictions=lm_output,
            weights=lm_label_weights)

        # NSP loss.
        sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
            labels=sentence_labels, predictions=sentence_output)

        # Total loss.
        loss = mask_label_loss + sentence_loss
        # One-element shape holding the first dimension of `sentence_labels`.
        batch_shape = tf.slice(tf.keras.backend.shape(sentence_labels), [0],
                               [1])
        # TODO(hongkuny): Avoids the hack and switches add_loss.
        # Repeat the loss value once per example of the batch.
        final_loss = tf.fill(batch_shape, loss)

        self._add_metrics(lm_output, lm_label_ids, lm_label_weights,
                          mask_label_loss, sentence_output, sentence_labels,
                          sentence_loss)
        return final_loss

### 带损失层的完整模型

In [None]:
def pretrain_model(bert_config,
                   seq_length,
                   max_predictions_pre_seq,
                   initializer=None):
    """Build the BERT pretraining model that maps input features to the loss.

    Args:
        bert_config: Configuration object; `initializer_range` and
            `vocab_size` are read here and the object is forwarded to
            `_get_transformer_encoder`.
        seq_length: Length of every input token sequence.
        max_predictions_pre_seq: Maximum number of masked positions per
            sequence. The parameter keeps its existing spelling so keyword
            callers keep working.
        initializer: Initializer for the pretraining heads. Defaults to a
            truncated normal with `stddev=bert_config.initializer_range`.

    Returns:
        A tuple `(keras_model, transformer_encoder)`: the full model whose
        output is the loss layer's output, and the encoder it is built on.
    """
    # Local alias under the spelling used by the rest of the pipeline.
    max_predictions_per_seq = max_predictions_pre_seq

    # Model inputs.
    input_word_ids = tf.keras.layers.Input(
        shape=(seq_length, ),
        name="input_word_ids",
        dtype=tf.int32,
    )
    input_mask = tf.keras.layers.Input(
        shape=(seq_length, ),
        name='input_mask',
        dtype=tf.int32,
    )
    input_type_ids = tf.keras.layers.Input(
        shape=(seq_length, ),
        name='input_type_ids',
        dtype=tf.int32,
    )
    masked_lm_positions = tf.keras.layers.Input(
        shape=(max_predictions_per_seq, ),
        name='masked_lm_positions',
        dtype=tf.int32)
    masked_lm_ids = tf.keras.layers.Input(
        shape=(max_predictions_per_seq, ),
        name='masked_lm_ids',
        dtype=tf.int32,
    )
    masked_lm_weights = tf.keras.layers.Input(
        shape=(max_predictions_per_seq, ),
        name='masked_lm_weights',
        dtype=tf.int32)
    next_sentence_labels = tf.keras.layers.Input(
        shape=(1, ),
        name='next_sentence_labels',
        dtype=tf.int32,
    )

    # Transformer encoder.
    transformer_encoder = _get_transformer_encoder(bert_config, seq_length)
    if initializer is None:
        initializer = tf.keras.initializers.TruncatedNormal(
            stddev=bert_config.initializer_range)

    # Encoder plus the two pretraining heads.
    pretrainer_model = BertPretrainer(
        network=transformer_encoder,
        num_classes=2,  # The next sentence prediction label has two classes.
        num_token_predictions=max_predictions_per_seq,
        initializer=initializer,
        output='predictions',
    )

    # Head outputs.
    lm_output, sentence_output = pretrainer_model(
        [input_word_ids, input_mask, input_type_ids, masked_lm_positions])

    # Loss layer.
    pretrain_loss_layer = BertPretrainLossAndMetricLayer(
        vocab_size=bert_config.vocab_size)
    output_loss = pretrain_loss_layer(lm_output, sentence_output,
                                      masked_lm_ids, masked_lm_weights,
                                      next_sentence_labels)

    # Full model: named input features -> loss.
    keras_model = tf.keras.Model(inputs={
        'input_word_ids': input_word_ids,
        'input_mask': input_mask,
        'input_type_ids': input_type_ids,
        'masked_lm_positions': masked_lm_positions,
        'masked_lm_ids': masked_lm_ids,
        'masked_lm_weights': masked_lm_weights,
        'next_sentence_labels': next_sentence_labels
    },
                                 outputs=output_loss)
    return keras_model, transformer_encoder

### 训练过程

In [None]:
def run_customized_training_loop(  
        _sentinel=None,  
        strategy=None,
        model_fn=None,
        loss_fn=None,
        model_dir=None,
        train_input_fn=None,
        steps_per_epoch=None,
        steps_per_loop=1,
        epochs=1,
        eval_input_fn=None,
        eval_steps=None,
        metric_fn=None,
        init_checkpoint=None,
        custom_callbacks=None,
        run_eagerly=False,
        sub_model_export_name=None):
    """Run BERT pretrain model training using low-level API.
  
    Arguments:
        _sentinel: Used to prevent positional parameters. Internal, do not use.
        strategy: Distribution strategy on which to run low level training loop.
        model_fn: Function that returns a tuple (model, sub_model). Caller of this
          function should add optimizer to the `model` via calling
          `model.compile()` API or manually setting `model.optimizer` attribute.
          Second element of the returned tuple(sub_model) is an optional sub model
          to be used for initial checkpoint -- if provided.
        loss_fn: Function with signature func(labels, logits) and returns a loss
          tensor.
        model_dir: Model directory used during training for restoring/saving model
          weights.
        train_input_fn: Function that returns a tf.data.Dataset used for training.
        steps_per_epoch: Number of steps to run per epoch. At the end of each
          epoch, model checkpoint will be saved and evaluation will be conducted
          if evaluation dataset is provided.
        steps_per_loop: Number of steps per graph-mode loop. In order to reduce
          communication in eager context, training logs are printed every
          steps_per_loop.
        epochs: Number of epochs to train.
        eval_input_fn: Function that returns evaluation dataset. If none,
          evaluation is skipped.
        eval_steps: Number of steps to run evaluation. Required if `eval_input_fn`
          is not none.
        metric_fn: A metrics function that returns a Keras Metric object to record
          evaluation result using evaluation dataset or with training dataset
          after every epoch.
        init_checkpoint: Optional checkpoint to load to `sub_model` returned by
          `model_fn`.
        custom_callbacks: A list of Keras Callbacks objects to run during
          training. More specifically, `on_batch_begin()`, `on_batch_end()`,
          methods are invoked during training.
        run_eagerly: Whether to run model training in pure eager execution. This
          should be disable for TPUStrategy.
        sub_model_export_name: If not None, will export `sub_model` returned by
          `model_fn` into checkpoint files. The name of intermediate checkpoint
          file is {sub_model_export_name}_step_{step}.ckpt and the last
          checkpint's name is {sub_model_export_name}.ckpt;
          if None, `sub_model` will not be exported as checkpoint.
  
    Returns:
        Trained model.
  
    Raises:
        ValueError: (1) When model returned by `model_fn` does not have optimizer
          attribute or when required parameters are set to none. (2) eval args are
          not specified correctly. (3) metric_fn must be a callable if specified.
          (4) sub_model_checkpoint_name is specified, but `sub_model` returned
          by `model_fn` is None.
    """

    if _sentinel is not None:
        raise ValueError('only call `run_customized_training_loop()` '
                         'with named arguments.')

    required_arguments = [
        strategy, model_fn, loss_fn, model_dir, steps_per_epoch, train_input_fn
    ]
    if [arg for arg in required_arguments if arg is None]:
        raise ValueError('`strategy`, `model_fn`, `loss_fn`, `model_dir`, '
                         '`steps_per_loop` and `steps_per_epoch` are required '
                         'parameters.')
    if steps_per_loop > steps_per_epoch:
        logging.error(
            'steps_per_loop: %d is specified to be greater than '
            ' steps_per_epoch: %d, we will use steps_per_epoch as'
            ' steps_per_loop.', steps_per_loop, steps_per_epoch)
        steps_per_loop = steps_per_epoch
    assert tf.executing_eagerly()

    if run_eagerly:
        if steps_per_loop > 1:
            raise ValueError(
                'steps_per_loop is used for performance optimization. When you want '
                'to run eagerly, you cannot leverage graph mode loop.')
        if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
            raise ValueError(
                'TPUStrategy should not run eagerly as it heavily replies on graph'
                ' optimization for the distributed system.')

    if eval_input_fn and (eval_steps is None or metric_fn is None):
        raise ValueError(
            '`eval_step` and `metric_fn` are required when `eval_input_fn ` '
            'is not none.')
    if metric_fn and not callable(metric_fn):
        raise ValueError(
            'if `metric_fn` is specified, metric_fn must be a callable.')

    total_training_steps = steps_per_epoch * epochs

    # To reduce unnecessary send/receive input pipeline operation, we place input
    # pipeline ops in worker task.
    train_iterator = _get_input_iterator(train_input_fn, strategy)

    with distribution_utils.get_strategy_scope(strategy):
        # To correctly place the model weights on accelerators,
        # model and optimizer should be created in scope.
        
        # 输入-->损失的完整模型， 和 transformer 编码器部分
        model, sub_model = model_fn()
        if not hasattr(model, 'optimizer'):
            raise ValueError('User should set optimizer attribute to model '
                             'inside `model_fn`.')
        if sub_model_export_name and sub_model is None:
            raise ValueError('sub_model_export_name is specified as %s, but '
                             'sub_model is None.' % sub_model_export_name)

        # 优化器
        optimizer = model.optimizer
        use_float16 = isinstance(
            optimizer,
            tf.keras.mixed_precision.experimental.LossScaleOptimizer)

        if init_checkpoint:
            logging.info(
                'Checkpoint file %s found and restoring from '
                'initial checkpoint for core model.', init_checkpoint)
            checkpoint = tf.train.Checkpoint(model=sub_model)
            checkpoint.restore(
                init_checkpoint).assert_existing_objects_matched()
            logging.info('Loading from checkpoint file completed')

        train_loss_metric = tf.keras.metrics.Mean('training_loss',
                                                  dtype=tf.float32)
        eval_metrics = [metric_fn()] if metric_fn else []
        # If evaluation is required, make a copy of metric as it will be used by
        # both train and evaluation.
        train_metrics = [
            metric.__class__.from_config(metric.get_config())
            for metric in eval_metrics
        ]

        # Create summary writers
        summary_dir = os.path.join(model_dir, 'summaries')
        eval_summary_writer = tf.summary.create_file_writer(
            os.path.join(summary_dir, 'eval'))
        if steps_per_loop >= _MIN_SUMMARY_STEPS:
            # Only writes summary when the stats are collected sufficiently over
            # enough steps.
            train_summary_writer = tf.summary.create_file_writer(
                os.path.join(summary_dir, 'train'))
        else:
            train_summary_writer = None

        # Collects training variables.
        training_vars = model.trainable_variables

        def _replicated_step(inputs):
            """Run one forward/backward pass and one optimizer update.

            Args:
              inputs: a `(features, labels)` pair. The features are fed to
                `model`; the labels go to `loss_fn` and the training metrics.
            """
            features, labels = inputs
            with tf.GradientTape() as tape:
                model_outputs = model(features, training=True)
                loss = loss_fn(labels, model_outputs)
                # In float16 mode gradients are taken of the scaled loss, so
                # the scaling is computed while the tape is still recording.
                grad_target = (optimizer.get_scaled_loss(loss)
                               if use_float16 else loss)

            grads = tape.gradient(grad_target, training_vars)
            if use_float16:
                # Undo the loss scaling before the gradients are applied.
                grads = optimizer.get_unscaled_gradients(grads)
            optimizer.apply_gradients(zip(grads, training_vars))

            # For reporting, the loss metric averages the unscaled losses.
            train_loss_metric.update_state(loss)
            for train_metric in train_metrics:
                train_metric.update_state(labels, model_outputs)

        @tf.function
        def train_steps(iterator, steps):
            """Run several distributed training steps inside one traced call.

            Args:
              iterator: the distributed iterator of training datasets.
              steps: a tf.int32 integer tensor giving the number of steps to
                run inside the host training loop.

            Raises:
              ValueError: if `steps` is not a `tf.Tensor`.
            """
            # Only tensors are accepted: per the error text below, a plain
            # Python object passed as `steps` may cause retracing.
            if not isinstance(steps, tf.Tensor):
                message = ('steps should be an Tensor. Python object may cause '
                           'retracing.')
                raise ValueError(message)

            for _ in tf.range(steps):
                strategy.experimental_run_v2(_replicated_step,
                                             args=(next(iterator), ))

        def train_single_step(iterator):
            """Performs a single distributed training step.

            Pulls the next element from `iterator` and runs `_replicated_step`
            on it through `strategy`.

            Args:
              iterator: the distributed iterator of training datasets.
            """
            strategy.experimental_run_v2(_replicated_step,
                                         args=(next(iterator), ))

        def test_step(iterator):
            """Run one evaluation step through `strategy`, updating `eval_metrics`.

            Args:
              iterator: the distributed iterator of the evaluation dataset.
            """
            def _test_step_fn(inputs):
                """Forward pass with `training=False` followed by metric updates."""
                features, labels = inputs
                predictions = model(features, training=False)
                for eval_metric in eval_metrics:
                    eval_metric.update_state(labels, predictions)

            strategy.experimental_run_v2(_test_step_fn,
                                         args=(next(iterator), ))

        if not run_eagerly:
            train_single_step = tf.function(train_single_step)
            test_step = tf.function(test_step)

        def _run_evaluation(current_training_step, test_iterator):
            """Run `eval_steps` evaluation steps, then log and record the metrics.

            Args:
              current_training_step: step number used to tag both the log lines
                and the summary scalars.
              test_iterator: iterator handed to `test_step` on every
                evaluation step.
            """
            for _ in range(eval_steps):
                test_step(test_iterator)

            with eval_summary_writer.as_default():
                for eval_metric in eval_metrics + model.metrics:
                    value = _float_metric_value(eval_metric)
                    name = eval_metric.name
                    logging.info('Step: [%d] Validation %s = %f',
                                 current_training_step, name, value)
                    tf.summary.scalar(name, value, step=current_training_step)
                eval_summary_writer.flush()

        def _run_callbacks_on_batch_begin(batch):
            """Call `on_batch_begin(batch)` on every custom callback, if any."""
            # A missing or empty `custom_callbacks` collapses to an empty tuple,
            # so the loop body simply never runs.
            for callback in custom_callbacks or ():
                callback.on_batch_begin(batch)

        def _run_callbacks_on_batch_end(batch):
            """Call `on_batch_end(batch)` on every custom callback, if any."""
            # A missing or empty `custom_callbacks` collapses to an empty tuple,
            # so the loop body simply never runs.
            for callback in custom_callbacks or ():
                callback.on_batch_end(batch)

        # Training loop starts here.
        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        sub_model_checkpoint = tf.train.Checkpoint(
            model=sub_model) if sub_model_export_name else None

        latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
        if latest_checkpoint_file:
            logging.info(
                'Checkpoint file %s found and restoring from '
                'checkpoint', latest_checkpoint_file)
            checkpoint.restore(latest_checkpoint_file)
            logging.info('Loading from checkpoint file completed')

        current_step = optimizer.iterations.numpy()
        checkpoint_name = 'ctl_step_{step}.ckpt'

        while current_step < total_training_steps:
            # Training loss and metrics are averaged over the steps of the inner
            # training loop, so we reset their values before each round.
            train_loss_metric.reset_states()
            for metric in train_metrics + model.metrics:
                metric.reset_states()

            _run_callbacks_on_batch_begin(current_step)
            # Runs several steps in the host while loop.
            steps = steps_to_run(current_step, steps_per_epoch, steps_per_loop)

            if steps == 1:
                # TODO(zongweiz): merge with train_steps once tf.while_loop
                # GPU performance bugs are fixed.
                train_single_step(train_iterator)
            else:
                # Converts steps to a Tensor to avoid tf.function retracing.
                train_steps(train_iterator,
                            tf.convert_to_tensor(steps, dtype=tf.int32))
            _run_callbacks_on_batch_end(current_step)
            current_step += steps

            train_loss = _float_metric_value(train_loss_metric)
            # Updates training logging.
            training_status = 'Train Step: %d/%d  / loss = %s' % (
                current_step, total_training_steps, train_loss)

            if train_summary_writer:
                with train_summary_writer.as_default():
                    tf.summary.scalar(train_loss_metric.name,
                                      train_loss,
                                      step=current_step)
                    for metric in train_metrics + model.metrics:
                        metric_value = _float_metric_value(metric)
                        training_status += '  %s = %f' % (metric.name,
                                                          metric_value)
                        tf.summary.scalar(metric.name,
                                          metric_value,
                                          step=current_step)
                    train_summary_writer.flush()
            logging.info(training_status)

            # Save model checkpoints and run validation at the end of every epoch.
            if current_step % steps_per_epoch == 0:
                # To avoid repeated model saving, we do not save after the last
                # step of training.
                if current_step < total_training_steps:
                    _save_checkpoint(checkpoint, model_dir,
                                     checkpoint_name.format(step=current_step))
                    if sub_model_export_name:
                        _save_checkpoint(
                            sub_model_checkpoint, model_dir,
                            '%s_step_%d.ckpt' %
                            (sub_model_export_name, current_step))
                if eval_input_fn:
                    logging.info('Running evaluation after step: %s.',
                                 current_step)
                    _run_evaluation(
                        current_step,
                        _get_input_iterator(eval_input_fn, strategy))
                    # Re-initialize evaluation metric.
                    for metric in eval_metrics + model.metrics:
                        metric.reset_states()

        _save_checkpoint(checkpoint, model_dir,
                         checkpoint_name.format(step=current_step))
        if sub_model_export_name:
            _save_checkpoint(sub_model_checkpoint, model_dir,
                             '%s.ckpt' % sub_model_export_name)

        if eval_input_fn:
            logging.info(
                'Running final evaluation after training is complete.')
            _run_evaluation(current_step,
                            _get_input_iterator(eval_input_fn, strategy))

        training_summary = {
            'total_training_steps': total_training_steps,
            'train_loss': _float_metric_value(train_loss_metric),
        }
        if eval_metrics:
            # TODO(hongkuny): Cleans up summary reporting in text.
            training_summary['last_train_metrics'] = _float_metric_value(
                train_metrics[0])
            training_summary['eval_metrics'] = _float_metric_value(
                eval_metrics[0])

        write_txt_summary(training_summary, summary_dir)

        return model

### 训练模型

In [None]:
def run_customized_training(strategy, bert_config, max_seq_length,
                            max_predictions_per_seq, model_dir,
                            steps_per_epoch, steps_per_loop, epochs,
                            initial_lr, warmup_steps, input_files,
                            train_batch_size):
    """Pretrain BERT with the customized (low-level) training loop.

    Args:
      strategy: distribution strategy handed to the training loop; its
        `num_replicas_in_sync` sets the loss factor when `FLAGS.scale_loss`
        is enabled.
      bert_config: configuration passed to `bert_models.pretrain_model`.
      max_seq_length: passed to both the dataset function and the model
        builder.
      max_predictions_per_seq: passed to both the dataset function and the
        model builder.
      model_dir: forwarded to the training loop.
      steps_per_epoch: forwarded to the training loop; multiplied by `epochs`
        for the optimizer.
      steps_per_loop: forwarded to the training loop.
      epochs: forwarded to the training loop; multiplied by `steps_per_epoch`
        for the optimizer.
      initial_lr: first argument of `optimization.create_optimizer`.
      warmup_steps: third argument of `optimization.create_optimizer`.
      input_files: passed to `get_pretrain_dataset_fn`.
      train_batch_size: passed to `get_pretrain_dataset_fn`.

    Returns:
      The value returned by
      `model_training_utils.run_customized_training_loop`.
    """
    train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
                                             max_predictions_per_seq,
                                             train_batch_size)

    def _get_pretrain_model():
        """Build the pretraining model pair and attach the optimizer."""
        pretrain_model, core_model = bert_models.pretrain_model(
            bert_config, max_seq_length, max_predictions_per_seq)
        num_train_steps = steps_per_epoch * epochs
        pretrain_model.optimizer = optimization.create_optimizer(
            initial_lr, num_train_steps, warmup_steps)
        if FLAGS.fp16_implementation == 'graph_rewrite':
            # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype
            # as determined by flags_core.get_tf_dtype(flags_obj) would be
            # 'float32', which ensures tf.compat.v2.keras.mixed_precision and
            # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
            # double up.
            rewrite = tf.train.experimental.enable_mixed_precision_graph_rewrite
            pretrain_model.optimizer = rewrite(pretrain_model.optimizer)
        return pretrain_model, core_model

    # Spelled out as if/else so that the division visibly belongs to the
    # `scale_loss` branch only.
    if FLAGS.scale_loss:
        loss_factor = 1.0 / strategy.num_replicas_in_sync
    else:
        loss_factor = 1.0
    loss_fn = get_loss_fn(loss_factor=loss_factor)

    return model_training_utils.run_customized_training_loop(
        strategy=strategy,
        model_fn=_get_pretrain_model,
        loss_fn=loss_fn,
        model_dir=model_dir,
        train_input_fn=train_input_fn,
        steps_per_epoch=steps_per_epoch,
        steps_per_loop=steps_per_loop,
        epochs=epochs,
        sub_model_export_name='pretrained/bert_model')


def run_bert_pretrain(strategy):
    """Runs BERT pre-training with the customized training loop.

    The BERT config path, hyper-parameters, input files and output directory
    are all read from the global `FLAGS`.

    Args:
      strategy: the distribution strategy to train with; must be truthy.

    Returns:
      The value returned by `run_customized_training`.

    Raises:
      ValueError: if `strategy` is not specified.
    """
    # Validate first so that a missing strategy is reported as such instead of
    # being masked by (or paying for) reading the config file.
    if not strategy:
        raise ValueError('Distribution strategy is not specified.')

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    # Runs customized training loop.
    # Adjacent string literals are concatenated verbatim, so the separating
    # space has to be part of the first literal.
    logging.info(
        'Training using customized training loop TF 2.0 with distributed '
        'strategy.')

    return run_customized_training(strategy, bert_config, FLAGS.max_seq_length,
                                   FLAGS.max_predictions_per_seq,
                                   FLAGS.model_dir, FLAGS.num_steps_per_epoch,
                                   FLAGS.steps_per_loop,
                                   FLAGS.num_train_epochs, FLAGS.learning_rate,
                                   FLAGS.warmup_steps, FLAGS.input_files,
                                   FLAGS.train_batch_size)


# Stop early unless the installed TensorFlow version string starts with '2.'.
assert tf.version.VERSION.startswith('2.')

# Fall back to a default output directory when FLAGS.model_dir is empty.
if not FLAGS.model_dir:
    FLAGS.model_dir = '/tmp/bert20/'
# Build the distribution strategy from the strategy / GPU / TPU flags.
strategy = distribution_utils.get_distribution_strategy(
    distribution_strategy=FLAGS.distribution_strategy,
    num_gpus=FLAGS.num_gpus,
    tpu_address=FLAGS.tpu)
if strategy:
    # The value printed is the strategy's number of replicas in sync.
    print('***** Number of cores used : ', strategy.num_replicas_in_sync)

# Launch pre-training with the strategy built above.
run_bert_pretrain(strategy)