In [None]:
!pip install lingvo

In [None]:
import lingvo.compat as tf
from lingvo.core import base_model
from lingvo.core import insertion
from lingvo.core import metrics
from lingvo.core import py_utils
from lingvo.core import tpu_embedding_layers
from lingvo.tasks.mt import decoder
from lingvo.tasks.mt import encoder


Transformer model

Adapted from the machine-translation (MT) model framework in Lingvo (参考了lingvo的mt模型框架).

@misc{shen2019lingvo,
    title={Lingvo: a Modular and Scalable Framework for Sequence-to-Sequence Modeling},
    author={Jonathan Shen and Patrick Nguyen and Yonghui Wu and Zhifeng Chen and others},
    year={2019},
    eprint={1902.08295},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}

In [None]:
class TransformerModel(base_model.BaseTask):
  """Transformer sequence-to-sequence (MT) model.

  Implements "Attention Is All You Need":
  https://arxiv.org/abs/1706.03762

  The encoder and decoder are created as child layers named `enc` and `dec`
  from `params.encoder` and `params.decoder`. Training goes through
  `ComputePredictions` / `ComputeLoss`; evaluation goes through beam-search
  decoding in `Decode` / `PostProcessDecodeOut`.
  """

  @classmethod
  def Params(cls):
    """Returns default hyperparameters.

    Sets `encoder` to a Transformer encoder config and `decoder` to a
    Transformer decoder config.
    """
    p = super().Params()
    p.encoder = encoder.TransformerEncoder.Params()
    p.decoder = decoder.TransformerDecoder.Params()
    return p

  def __init__(self, params):
    """Validates params and creates the `enc` / `dec` child layers."""
    super().__init__(params)
    p = self.params
    if p.encoder:
      # The decoder attends over encoder outputs, so its source dimension must
      # match the encoder's model dimension.
      assert p.encoder.model_dim == p.decoder.source_dim

    # The rest of this class accesses `self.enc` / `self.dec`; they must be
    # registered as children or those accesses fail. Creating them under the
    # encoder/decoder device scopes keeps model-split placement on TPU.
    with tf.variable_scope(p.name):
      with self._EncoderDevice():
        if p.encoder:
          self.CreateChild('enc', p.encoder)
      with self._DecoderDevice():
        self.CreateChild('dec', p.decoder)

  def _ModelSplitDevice(self, split_id):
    """Returns a `tf.device` context for model split `split_id`.

    With SPMD partitioning (`device_mesh` set) no operation is pinned to a
    particular device, and off-TPU no explicit placement is made; both cases
    return an empty device scope.
    """
    if self.params.device_mesh is None and py_utils.use_tpu():
      return tf.device(self.cluster.WorkerDeviceInModelSplit(split_id))
    return tf.device('')

  def _EncoderDevice(self):
    """Returns the device to run the encoder computation (model split 0)."""
    return self._ModelSplitDevice(0)

  def _DecoderDevice(self):
    """Returns the device to run the decoder computation (model split 1)."""
    return self._ModelSplitDevice(1)

  def _PropagateEmbeddingIds(self, input_batch):
    """Copies TPU-embedding feature ids into the encoder/decoder sub-batches.

    For every feature name registered with the TPU embedding collection, the
    top-level `input_batch[name]` is copied into both `batch.src` and
    `batch.tgt`.

    Returns:
      A deep copy of `input_batch` with those entries added, or `input_batch`
      itself when no TPU embedding features are registered.
    """
    feature_names = (
        tpu_embedding_layers.TpuEmbeddingCollection.Get().feature_names)
    if not feature_names:
      return input_batch

    batch = input_batch.DeepCopy()
    for name in feature_names:
      assert name in batch
      assert name not in batch.src, f'Duplicate {name} in batch.src'
      assert name not in batch.tgt, f'Duplicate {name} in batch.tgt'
      batch.src[name] = batch[name]
      batch.tgt[name] = batch[name]
    return batch

  def ComputePredictions(self, theta, batch):
    """Runs the forward pass: encoder (if configured) then decoder.

    Args:
      theta: NestedMap of layer variables; `theta.enc` / `theta.dec` are used.
      batch: Input NestedMap with `src` and `tgt` sub-maps.

    Returns:
      The decoder's predictions. When they are a NestedMap, the encoder
      outputs are attached under 'encoder_outputs'.
    """
    p = self.params
    batch = self._PropagateEmbeddingIds(batch)

    with self._EncoderDevice():
      encoder_outputs = (
          self.enc.FProp(theta.enc, batch.src) if p.encoder else None)
    with self._DecoderDevice():
      predictions = self.dec.ComputePredictions(theta.dec, encoder_outputs,
                                                batch.tgt)
      if isinstance(predictions, py_utils.NestedMap):
        # Pass through encoder output as well for possible use as a FProp output
        # for various meta-MT modeling approaches, such as MT quality estimation
        # classification.
        predictions['encoder_outputs'] = encoder_outputs
      return predictions

  def ComputeLoss(self, theta, predictions, input_batch):
    """Computes the loss on the decoder device.

    Returns:
      Whatever `self.dec.ComputeLoss` returns, unchanged. The original note says
      this is the loss plus a dictionary of scalar metrics; confirm against the
      installed Lingvo version.
    """
    with self._DecoderDevice():
      return self.dec.ComputeLoss(theta.dec, predictions, input_batch.tgt)

  def _GetTokenizerKeyToUse(self, key):
    """Returns `key` if the input generator has a tokenizer for it, else None."""
    if key in self.input_generator.tokenizer_dict:
      return key
    return None

  def _BeamSearchDecode(self, input_batch):
    """Builds the beam-search decoding graph.

    Returns:
      A dict of tensors:
      - the target ids, labels, weights and paddings;
      - detokenized source and reference strings;
      - the top-k decoded strings, reshaped to `topk_hyps`' shape;
      - the top-k lengths and scores.
    """
    p = self.params
    with tf.name_scope('fprop'), tf.name_scope(p.name):
      encoder_outputs = self.enc.FPropDefaultTheta(input_batch.src)
      encoder_outputs = self.dec.AddExtraDecodingInfo(encoder_outputs,
                                                      input_batch.tgt)
      decoder_outs = self.dec.BeamSearchDecode(encoder_outputs)

      topk_hyps = decoder_outs.topk_hyps
      topk_ids = decoder_outs.topk_ids
      topk_lens = decoder_outs.topk_lens
      topk_scores = decoder_outs.topk_scores

      # Sequence lengths are the non-padded token counts minus one. Presumably
      # the minus one drops the final (EOS) token before detokenizing;
      # TODO confirm.
      slen = tf.cast(
          tf.round(tf.reduce_sum(1 - input_batch.src.paddings, 1) - 1),
          tf.int32)
      srcs = self.input_generator.IdsToStrings(
          input_batch.src.ids, slen, self._GetTokenizerKeyToUse('src'))
      topk_decoded = self.input_generator.IdsToStrings(
          topk_ids, topk_lens - 1, self._GetTokenizerKeyToUse('tgt'))
      topk_decoded = tf.reshape(topk_decoded, tf.shape(topk_hyps))
      topk_scores = tf.reshape(topk_scores, tf.shape(topk_hyps))

      refs = self.input_generator.IdsToStrings(
          input_batch.tgt.labels,
          tf.cast(
              tf.round(tf.reduce_sum(1.0 - input_batch.tgt.paddings, 1) - 1.0),
              tf.int32), self._GetTokenizerKeyToUse('tgt'))

      ret_dict = {
          'target_ids': input_batch.tgt.ids,
          'target_labels': input_batch.tgt.labels,
          'target_weights': input_batch.tgt.weights,
          'target_paddings': input_batch.tgt.paddings,
          'sources': srcs,
          'targets': refs,
          'topk_decoded': topk_decoded,
          'topk_lens': topk_lens,
          'topk_scores': topk_scores,
      }
      return ret_dict

  def _PostProcessBeamSearchDecodeOut(self, dec_out_dict, dec_metrics_dict):
    """Post-processes the numpy output of `_BeamSearchDecode`.

    For each sample it:
    - logs the source, the target and every hypothesis with its score;
    - updates corpus BLEU using only the top hypothesis;
    - records the sample count in `num_samples_in_batch`.

    Returns:
      A list of `(unsegmented_source, info_string)` pairs, one per sample.
    """
    p = self.params
    topk_scores = dec_out_dict['topk_scores']
    topk_decoded = dec_out_dict['topk_decoded']
    targets = dec_out_dict['targets']
    sources = dec_out_dict['sources']
    unsegment = dec_metrics_dict['corpus_bleu'].unsegmenter

    num_samples = len(targets)
    assert num_samples == len(topk_decoded), (
        '%s vs %s' % (num_samples, len(topk_decoded)))
    assert num_samples == len(sources)
    dec_metrics_dict['num_samples_in_batch'].Update(num_samples)

    key_value_pairs = []
    for i in range(num_samples):
      src, tgt = sources[i], targets[i]
      src_unseg, tgt_unseg = unsegment(src), unsegment(tgt)
      tf.logging.info('source: %s', src_unseg)
      tf.logging.info('target: %s', tgt_unseg)
      hyps = topk_decoded[i]
      assert p.decoder.beam_search.num_hyps_per_beam == len(hyps)
      info_str = 'src: {} tgt: {} '.format(src_unseg, tgt_unseg)
      for n, (score, hyp_str) in enumerate(zip(topk_scores[i], hyps)):
        hyp_str_unseg = unsegment(hyp_str)
        tf.logging.info('  %f: %s', score, hyp_str_unseg)
        info_str += ' hyp{n}: {hyp} score{n}: {score}'.format(
            n=n, hyp=hyp_str_unseg, score=score)
        # Only aggregate scores of the top hypothesis.
        if n == 0:
          dec_metrics_dict['corpus_bleu'].Update(tgt, hyp_str)
      key_value_pairs.append((src_unseg, info_str))
    return key_value_pairs

  def CreateDecoderMetrics(self):
    """Returns the decode-time metrics: sample count and corpus BLEU."""
    decoder_metrics = {
        'num_samples_in_batch': metrics.AverageMetric(),
        'corpus_bleu': metrics.CorpusBleuMetric(separator_type='wpm'),
    }
    return decoder_metrics

  def Decode(self, input_batch):
    """Constructs the decoding graph (beam search)."""
    input_batch = self._PropagateEmbeddingIds(input_batch)
    return self._BeamSearchDecode(input_batch)

  def PostProcessDecodeOut(self, dec_out, dec_metrics):
    """Post-processes decoder output; see `_PostProcessBeamSearchDecodeOut`."""
    return self._PostProcessBeamSearchDecodeOut(dec_out, dec_metrics)


NameError: ignored