In [1]:
import os
import tensorflow as tf
from transformers import (
    BertForSequenceClassification,
    BertConfig,
)
import numpy as np
import json

def bert_pytorch_to_tensorflow(pt_model:BertForSequenceClassification,
                               tf_bert_config_file:str,
                               tf_output_dir:str,
                               tf_model_name:str = "bert_model"):
    """
    Converts a PyTorch transformers BertForSequenceClassification model to a TensorFlow 1.x checkpoint.

    Every tensor of ``pt_model.state_dict()`` (plus a ``global_step`` variable) is renamed to the
    TensorFlow BERT naming scheme, transposed where the weight layouts differ, and saved with
    ``tf.train.Saver`` to ``<tf_output_dir>/<tf_model_name with '-' replaced by '_'>.ckpt``.
    A copy of ``tf_bert_config_file`` whose ``vocab_size`` is taken from the word-embedding matrix
    is written to ``<tf_output_dir>/bert_config.json``.

    :param pt_model: PyTorch model instance to be converted. Its tensors may live on any device;
      they are copied to host memory before conversion.
    :param tf_bert_config_file: path to a bert_config.json file with the TensorFlow BERT configuration.
      It should correspond to the architecture (N layers, N hidden units, etc.) of the PyTorch model;
      its ``vocab_size`` entry is overwritten from the model weights.
      Generating this file on the fly would be a welcome contribution to
      https://gist.github.com/artoby/b13d2b9d2d6d7f21e195bdb8542709c6
    :param tf_output_dir: directory to write the resulting TensorFlow model to (created if missing)
    :param tf_model_name: resulting TensorFlow model name, used in the checkpoint file name
      ('-' is replaced by '_')
    :return: None
    """
    # Substrings of PyTorch parameter names whose tensors get transposed. torch.nn.Linear stores
    # weights as (out_features, in_features); the transposed (in, out) layout is what the TF kernels
    # are expected to use. Biases also match the attention.self.* entries, but .T on a 1-D array
    # is a no-op.
    tensors_to_transpose = (
        "dense.weight",
        "attention.self.query",
        "attention.self.key",
        "attention.self.value"
    )
    # (PyTorch substring, TF replacement, keep applying later patterns?) -- applied in order.
    # The classifier entries stop immediately because they map to top-level TF names.
    name_patterns_map = (
        ('classifier.weight', 'output_weights', False),
        ('classifier.bias', 'output_bias', False),
        ('layer.', 'layer_', True),
        ('word_embeddings.weight', 'word_embeddings', True),
        ('position_embeddings.weight', 'position_embeddings', True),
        ('token_type_embeddings.weight', 'token_type_embeddings', True),
        ('.', '/', True),
        ('LayerNorm/weight', 'LayerNorm/gamma', True),
        ('LayerNorm/bias', 'LayerNorm/beta', True),
        ('weight', 'kernel', True)
    )

    os.makedirs(tf_output_dir, exist_ok=True)

    def to_tf_var_name(name:str):
        """Map a PyTorch state_dict key to its TensorFlow variable name."""
        for patt, repl, cont in name_patterns_map:
            if patt in name:
                name = name.replace(patt, repl)
                if not cont:
                    break
        return name

    def create_tf_var(tensor:np.ndarray, name:str, session:tf.Session):
        """Create and initialise a zero-filled TF variable with ``tensor``'s dtype and shape."""
        tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
        tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
        session.run(tf.variables_initializer([tf_var]))
        return tf_var

    tf.reset_default_graph()

    # .cpu() lets models that live on a GPU be converted too; .numpy() only works on host tensors.
    pytorch_vars = {k: v.detach().cpu().numpy() for k, v in pt_model.state_dict().items()}
    all_vars = dict(pytorch_vars)
    # Pin the dtype: np.array([0]) is int32 on some platforms (e.g. Windows, NumPy < 2), while a TF
    # global step is int64.
    # NOTE(review): kept as a shape-(1,) array as before rather than a scalar -- confirm what
    # consumers of the checkpoint expect.
    all_vars["global_step"] = np.array([0], dtype=np.int64)

    ckpt_path = os.path.join(tf_output_dir, tf_model_name.replace("-", "_") + ".ckpt")

    with tf.Session() as session:
        for var_name, np_value in all_vars.items():
            print(var_name)
            tf_name = to_tf_var_name(var_name)
            if any(x in var_name for x in tensors_to_transpose):
                np_value = np_value.T
            tf_var = create_tf_var(tensor=np_value, name=tf_name, session=session)
            # Presumably assigns through the default session (``session`` inside this ``with``);
            # the allclose check below verifies the value actually landed.
            tf.keras.backend.set_value(tf_var, np_value)
            tf_weight = session.run(tf_var)
            print("Successfully created {}: {}".format(tf_name, np.allclose(tf_weight, np_value)))

        saver = tf.train.Saver(tf.trainable_variables())
        print("Se guardará el modelo en:", ckpt_path)
        saver.save(session, ckpt_path)

    with open(tf_bert_config_file, encoding="utf-8") as f:
        tf_bert_config = json.load(f)

    # Keep the config consistent with the converted embedding matrix (rows = vocabulary entries).
    tf_bert_config["vocab_size"] = pytorch_vars["bert.embeddings.word_embeddings.weight"].shape[0]
    with open(os.path.join(tf_output_dir, "bert_config.json"), "w", encoding="utf-8") as f:
        json.dump(tf_bert_config, f, indent=2)

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
# Source PyTorch model directory, loaded with from_pretrained below.
pt_model_dir = "../../../../model_save/Dos-Fases-all_Harassment/"
# The TensorFlow checkpoint is written to a sibling directory with a "_TENSORFLOW" suffix,
# derived from pt_model_dir so both always point at the same model.
tf_output_dir = pt_model_dir.rstrip("/") + "_TENSORFLOW/"

pt_config = BertConfig.from_pretrained(pt_model_dir)
pt_model = BertForSequenceClassification.from_pretrained(pt_model_dir, config=pt_config)
bert_pytorch_to_tensorflow(
    pt_model=pt_model,
    # The PyTorch config.json is reused as the TF BERT config; the conversion overwrites its
    # vocab_size from the weights.
    tf_bert_config_file=os.path.join(pt_model_dir, "config.json"),
    tf_output_dir=tf_output_dir,
)


bert.embeddings.word_embeddings.weight
Successfully created bert/embeddings/word_embeddings: True
bert.embeddings.position_embeddings.weight
Successfully created bert/embeddings/position_embeddings: True
bert.embeddings.token_type_embeddings.weight
Successfully created bert/embeddings/token_type_embeddings: True
bert.embeddings.LayerNorm.weight
Successfully created bert/embeddings/LayerNorm/gamma: True
bert.embeddings.LayerNorm.bias
Successfully created bert/embeddings/LayerNorm/beta: True
bert.encoder.layer.0.attention.self.query.weight
Successfully created bert/encoder/layer_0/attention/self/query/kernel: True
bert.encoder.layer.0.attention.self.query.bias
Successfully created bert/encoder/layer_0/attention/self/query/bias: True
bert.encoder.layer.0.attention.self.key.weight
Successfully created bert/encoder/layer_0/attention/self/key/kernel: True
bert.encoder.layer.0.attention.self.key.bias
Successfully created bert/encoder/layer_0/attention/self/key/bias: True
bert.encoder.layer.0.

Successfully created bert/encoder/layer_4/attention/self/query/bias: True
bert.encoder.layer.4.attention.self.key.weight
Successfully created bert/encoder/layer_4/attention/self/key/kernel: True
bert.encoder.layer.4.attention.self.key.bias
Successfully created bert/encoder/layer_4/attention/self/key/bias: True
bert.encoder.layer.4.attention.self.value.weight
Successfully created bert/encoder/layer_4/attention/self/value/kernel: True
bert.encoder.layer.4.attention.self.value.bias
Successfully created bert/encoder/layer_4/attention/self/value/bias: True
bert.encoder.layer.4.attention.output.dense.weight
Successfully created bert/encoder/layer_4/attention/output/dense/kernel: True
bert.encoder.layer.4.attention.output.dense.bias
Successfully created bert/encoder/layer_4/attention/output/dense/bias: True
bert.encoder.layer.4.attention.output.LayerNorm.weight
Successfully created bert/encoder/layer_4/attention/output/LayerNorm/gamma: True
bert.encoder.layer.4.attention.output.LayerNorm.bias

Successfully created bert/encoder/layer_8/attention/output/LayerNorm/gamma: True
bert.encoder.layer.8.attention.output.LayerNorm.bias
Successfully created bert/encoder/layer_8/attention/output/LayerNorm/beta: True
bert.encoder.layer.8.intermediate.dense.weight
Successfully created bert/encoder/layer_8/intermediate/dense/kernel: True
bert.encoder.layer.8.intermediate.dense.bias
Successfully created bert/encoder/layer_8/intermediate/dense/bias: True
bert.encoder.layer.8.output.dense.weight
Successfully created bert/encoder/layer_8/output/dense/kernel: True
bert.encoder.layer.8.output.dense.bias
Successfully created bert/encoder/layer_8/output/dense/bias: True
bert.encoder.layer.8.output.LayerNorm.weight
Successfully created bert/encoder/layer_8/output/LayerNorm/gamma: True
bert.encoder.layer.8.output.LayerNorm.bias
Successfully created bert/encoder/layer_8/output/LayerNorm/beta: True
bert.encoder.layer.9.attention.self.query.weight
Successfully created bert/encoder/layer_9/attention/self