# Comparing TensorFlow (original) and PyTorch models

You can use this small notebook to check the conversion of the model's weights from the TensorFlow model to the PyTorch model. In the following, we compare the layer outputs of both models on a simple example (in `input.txt`). Both models return all the hidden layers, so you can check every stage of the model.

To run this notebook, follow these instructions:
- make sure that your Python environment has both TensorFlow and PyTorch installed,
- download the original TensorFlow implementation,
- download a pre-trained TensorFlow model as indicated in the TensorFlow implementation readme,
- run the script `convert_tf_checkpoint_to_pytorch.py` as indicated in the `README` to convert the pre-trained TensorFlow model to PyTorch.

If needed, change the relative paths indicated in this notebook (at the beginning of Sections 1 and 2) to point to the relevant models and code.
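Before running the cells, it can help to confirm that the paths actually point at the expected files. The following is a minimal sketch, not part of the original notebook: `find_missing_files` is a hypothetical helper, and the file names simply mirror the variables set in Section 1. The checkpoint itself is left out because a TensorFlow checkpoint name such as `model.ckpt-20` is a prefix shared by several files rather than a single file.

```python
import os


def find_missing_files(model_dir, filenames):
    """Return the names in `filenames` that are not regular files in `model_dir`."""
    return [name for name in filenames
            if not os.path.isfile(os.path.join(model_dir, name))]


# Assumed file names, taken from the variables defined in Section 1.
expected = ["vocab.txt", "bert_config.json"]
missing = find_missing_files("/tmp/pretraining_output/", expected)
print("missing files:", missing)
```

An empty list means both files were found; otherwise adjust `model_dir` before going further.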

In [1]:
import os
os.chdir('../')

In [2]:
import tensorflow as tf

W0628 22:47:49.683971 139974244267776 __init__.py:308] Limited tf.compat.v2.summary API due to missing TensorBoard installation.


## 1/ TensorFlow code

In [3]:
original_tf_inplem_dir = "../bert/"
model_dir = "/tmp/pretraining_output/"

vocab_file = model_dir + "vocab.txt"
bert_config_file = model_dir + "bert_config.json"
init_checkpoint = model_dir + "model.ckpt-20"

input_file = "./samples/input.txt"
max_seq_length = 128

In [4]:
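import importlib.util
import sys

# The original extract_features.py imports its sibling modules (modeling,
# tokenization), so the TensorFlow repository must be on sys.path before
# the module is executed.
sys.path.append(original_tf_inplem_dir)

# Load the script under a distinct name so that it does not clash with the
# PyTorch examples/extract_features.py imported in Section 2.
spec = importlib.util.spec_from_file_location(
    'extract_features_tensorflow', original_tf_inplem_dir + '/extract_features.py')
module = importlib.util.module_from_spec(spec)
sys.modules['extract_features_tensorflow'] = module
spec.loader.exec_module(module)
from extract_features_tensorflow import *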

In [6]:
# with tf.variable_scope("test", dtype=tf.float64):
layer_indexes = list(range(12))
bert_config = modeling.BertConfig.from_json_file(bert_config_file)
tokenizer = tokenization.FullTokenizer(
    vocab_file=vocab_file, do_lower_case=True)
examples = read_examples(input_file)

features = convert_examples_to_features(
    examples=examples, seq_length=max_seq_length, tokenizer=tokenizer)
unique_id_to_feature = {}
for feature in features:
    unique_id_to_feature[feature.unique_id] = feature

W0628 11:10:21.640357 140112437987072 deprecation_wrapper.py:119] From ../bert/modeling.py:93: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.

W0628 11:10:21.792764 140112437987072 deprecation_wrapper.py:119] From ../bert//extract_features.py:297: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.



In [7]:
# with tf.variable_scope("test", dtype=tf.float64):
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
run_config = tf.contrib.tpu.RunConfig(
    master=None,
    tpu_config=tf.contrib.tpu.TPUConfig(
        num_shards=1,
        per_host_input_for_training=is_per_host))

model_fn = model_fn_builder(
    bert_config=bert_config,
    init_checkpoint=init_checkpoint,
    layer_indexes=layer_indexes,
    use_tpu=False,
    use_one_hot_embeddings=False)

# If TPU is not available, this will fall back to normal Estimator on CPU
# or GPU.
estimator = tf.contrib.tpu.TPUEstimator(
    use_tpu=False,
    model_fn=model_fn,
    config=run_config,
    predict_batch_size=1)

input_fn = input_fn_builder(
    features=features, seq_length=max_seq_length)

W0628 11:10:23.922789 140112437987072 lazy_loader.py:50] 
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

W0628 11:10:23.925731 140112437987072 estimator.py:1984] Estimator's model_fn (<function model_fn_builder.<locals>.model_fn at 0x7f6e40398b70>) includes params argument, but params are not passed to Estimator.
W0628 11:10:23.930257 140112437987072 estimator.py:1811] Using temporary folder as model directory: /tmp/tmpalg_r0p1
W0628 11:10:23.933089 140112437987072 tpu_context.py:750] Setting TPUConfig.num_shards==1 is an unsupported behavior. Please fix as soon as possible (leaving num_shards as None.)
W0628 11:10:23.934185 140112437987072 tpu_context.py:211] eval_on_tpu ig

In [8]:
# with tf.variable_scope("test", dtype=tf.float64):
tensorflow_all_out = []
for result in estimator.predict(input_fn, yield_single_examples=True):
    unique_id = int(result["unique_id"])
    feature = unique_id_to_feature[unique_id]
    output_json = collections.OrderedDict()
    output_json["linex_index"] = unique_id
    tensorflow_all_out_features = []
    # for (i, token) in enumerate(feature.tokens):
    all_layers = []
    for (j, layer_index) in enumerate(layer_indexes):
        print("extracting layer {}".format(j))
        layer_output = result["layer_output_%d" % j]
        layers = collections.OrderedDict()
        layers["index"] = layer_index
        layers["values"] = layer_output
        all_layers.append(layers)
    tensorflow_out_features = collections.OrderedDict()
    tensorflow_out_features["layers"] = all_layers
    tensorflow_all_out_features.append(tensorflow_out_features)

    output_json["features"] = tensorflow_all_out_features
    tensorflow_all_out.append(output_json)

W0628 11:10:24.777451 140112437987072 deprecation_wrapper.py:119] From ../bert//extract_features.py:162: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.

W0628 11:10:24.785151 140112437987072 deprecation_wrapper.py:119] From ../bert/modeling.py:409: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.

W0628 11:10:24.818706 140112437987072 deprecation.py:323] From ../bert/modeling.py:485: to_double (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.cast` instead.
W0628 11:10:24.825600 140112437987072 deprecation_wrapper.py:119] From ../bert/modeling.py:495: The name tf.assert_less_equal is deprecated. Please use tf.compat.v1.assert_less_equal instead.

W0628 11:10:24.893762 140112437987072 deprecation.py:323] From ../bert/modeling.py:676: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.
I

double check  Tensor("bert/embeddings/ToDouble:0", shape=(?, 2), dtype=float64)


W0628 11:10:27.659510 140112437987072 deprecation_wrapper.py:119] From ../bert//extract_features.py:174: The name tf.trainable_variables is deprecated. Please use tf.compat.v1.trainable_variables instead.

W0628 11:10:27.670760 140112437987072 deprecation_wrapper.py:119] From ../bert//extract_features.py:187: The name tf.train.init_from_checkpoint is deprecated. Please use tf.compat.v1.train.init_from_checkpoint instead.

W0628 11:10:28.989188 140112437987072 deprecation.py:323] From /lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py:1354: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


[<tf.Variable 'bert/embeddings/word_embeddings:0' shape=(30522, 768) dtype=float64_ref>, <tf.Variable 'bert/embeddings/token_type_embeddings:0' shape=(2, 768) dtype=float64_ref>, <tf.Variable 'bert/embeddings/position_embeddings:0' shape=(512, 768) dtype=float64_ref>, <tf.Variable 'bert/embeddings/LayerNorm/beta:0' shape=(768,) dtype=float64_ref>, <tf.Variable 'bert/embeddings/LayerNorm/gamma:0' shape=(768,) dtype=float64_ref>, <tf.Variable 'bert/encoder/layer_0/attention/self/query/kernel:0' shape=(768, 768) dtype=float64_ref>, <tf.Variable 'bert/encoder/layer_0/attention/self/query/bias:0' shape=(768,) dtype=float64_ref>, <tf.Variable 'bert/encoder/layer_0/attention/self/key/kernel:0' shape=(768, 768) dtype=float64_ref>, <tf.Variable 'bert/encoder/layer_0/attention/self/key/bias:0' shape=(768,) dtype=float64_ref>, <tf.Variable 'bert/encoder/layer_0/attention/self/value/kernel:0' shape=(768, 768) dtype=float64_ref>, <tf.Variable 'bert/encoder/layer_0/attention/self/value/bias:0' shape

W0628 11:10:30.638432 140112437987072 error_handling.py:130] Reraising captured error


InvalidArgumentError: tensor_name = bert/encoder/layer_7/attention/self/value/kernel; expected dtype double does not equal original dtype float
	 [[node checkpoint_initializer_158 (defined at ../bert//extract_features.py:187) ]]

Original stack trace for 'checkpoint_initializer_158':
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/ipykernel/kernelapp.py", line 505, in start
    self.io_loop.start()
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/tornado/platform/asyncio.py", line 148, in start
    self.asyncio_loop.run_forever()
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/asyncio/base_events.py", line 438, in run_forever
    self._run_once()
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/asyncio/base_events.py", line 1451, in _run_once
    handle._run()
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/asyncio/events.py", line 145, in _run
    self._callback(*self._args)
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/tornado/ioloop.py", line 690, in <lambda>
    lambda f: self._run_callback(functools.partial(callback, future))
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/tornado/ioloop.py", line 743, in _run_callback
    ret = callback()
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/tornado/gen.py", line 781, in inner
    self.run()
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/tornado/gen.py", line 742, in run
    yielded = self.gen.send(value)
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 365, in process_one
    yield gen.maybe_future(dispatch(*args))
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/tornado/gen.py", line 209, in wrapper
    yielded = next(result)
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 272, in dispatch_shell
    yield gen.maybe_future(handler(stream, idents, msg))
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/tornado/gen.py", line 209, in wrapper
    yielded = next(result)
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 542, in execute_request
    user_expressions, allow_stdin,
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/tornado/gen.py", line 209, in wrapper
    yielded = next(result)
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/ipykernel/ipkernel.py", line 294, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/ipykernel/zmqshell.py", line 536, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2848, in run_cell
    raw_cell, store_history, silent, shell_futures)
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2874, in _run_cell
    return runner(coro)
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/IPython/core/async_helpers.py", line 67, in _pseudo_sync_runner
    coro.send(None)
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3049, in run_cell_async
    interactivity=interactivity, compiler=compiler, result=result)
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3214, in run_ast_nodes
    if (yield from self.run_code(code, result)):
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3296, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-8-4dd15cdb6ee2>", line 3, in <module>
    for result in estimator.predict(input_fn, yield_single_examples=True):
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py", line 2913, in predict
    yield_single_examples=yield_single_examples):
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 619, in predict
    features, None, ModeKeys.PREDICT, self.config)
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py", line 2709, in _call_model_fn
    config)
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1146, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py", line 2967, in _model_fn
    features, labels, is_export_mode=is_export_mode)
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py", line 1549, in call_without_tpu
    return self._call_model_fn(features, labels, is_export_mode=is_export_mode)
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py", line 1867, in _call_model_fn
    estimator_spec = self._model_fn(features=features, **kwargs)
  File "../bert//extract_features.py", line 187, in model_fn
    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/tensorflow/python/training/checkpoint_utils.py", line 291, in init_from_checkpoint
    init_from_checkpoint_fn)
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py", line 1684, in merge_call
    return self._merge_call(merge_fn, args, kwargs)
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py", line 1691, in _merge_call
    return merge_fn(self._strategy, *args, **kwargs)
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/tensorflow/python/training/checkpoint_utils.py", line 286, in <lambda>
    ckpt_dir_or_file, assignment_map)
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/tensorflow/python/training/checkpoint_utils.py", line 334, in _init_from_checkpoint
    _set_variable_or_list_initializer(var, ckpt_file, tensor_name_in_ckpt)
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/tensorflow/python/training/checkpoint_utils.py", line 458, in _set_variable_or_list_initializer
    _set_checkpoint_initializer(variable_or_list, ckpt_file, tensor_name, "")
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/tensorflow/python/training/checkpoint_utils.py", line 412, in _set_checkpoint_initializer
    ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/tensorflow/python/ops/gen_io_ops.py", line 1696, in restore_v2
    name=name)
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 788, in _apply_op_helper
    op_def=op_def)
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3616, in create_op
    op_def=op_def)
  File "/lfs/1/zjian/anaconda2/envs/bert-pretraining/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2005, in __init__
    self._traceback = tf_stack.extract_stack()


In [None]:
tf.executing_eagerly()

In [12]:
# with sess as tf.session():

In [13]:
tvars = tf.trainable_variables()
print(tvars)

[]


In [9]:
print(len(tensorflow_all_out))
print(len(tensorflow_all_out[0]))
print(tensorflow_all_out[0].keys())
print("number of tokens", len(tensorflow_all_out[0]['features']))
print("number of layers", len(tensorflow_all_out[0]['features'][0]['layers']))
tensorflow_all_out[0]['features'][0]['layers'][0]['values'].shape

1
2
odict_keys(['linex_index', 'features'])
number of tokens 1
number of layers 12


(128, 768)

In [10]:
tensorflow_outputs = list(tensorflow_all_out[0]['features'][0]['layers'][t]['values'] for t in layer_indexes)

In [11]:
print(tensorflow_outputs[0])

[[-2.32029069  1.17394683 -0.26360368 ...  1.85500242 -1.81151755
   1.21505852]
 [-1.36823489  0.71090095  0.44132799 ...  0.79780859 -1.56258003
   0.10156749]
 [-1.35045821  1.10464981  0.45813221 ...  0.37352803 -0.58642476
   1.81977958]
 ...
 [ 0.59371397  0.54279067  1.2337462  ...  1.43765635 -0.35358792
   0.17982318]
 [ 0.05889777  0.53335308 -0.3478505  ...  0.16594686  0.8880447
  -0.95944832]
 [-1.09464921 -0.2095373   1.7236201  ...  1.49718659 -0.1028087
  -1.12150647]]


In [None]:
[[ 0.09768458  0.00179938 -0.14077505 ...  0.08461832  0.06095919
  -0.00853015]
 [-0.00433884  0.6136221  -0.28382796 ...  0.0972865   0.099303
  -0.7878085 ]
 [-0.32750753 -0.82382095  0.16448613 ... -0.19547357  0.1602732
   0.09975615]
 ...
 [ 0.07915473 -0.3328963   0.61744314 ...  0.46285754 -0.32566088
   0.02301384]
 [ 0.00915309 -0.38213047  0.48889852 ...  0.47541893 -0.24534532
  -0.08870013]
 [ 0.14412603 -0.27417877  0.5007422  ...  0.763762   -0.5418124
  -0.13567454]]

## 2/ PyTorch code

In [12]:
os.chdir('./examples')

In [13]:
import extract_features
import pytorch_pretrained_bert as ppb
from extract_features import *

In [14]:
init_checkpoint_pt = "/tmp/pretraining_output/"

In [15]:
device = torch.device("cpu")
model = ppb.BertModel.from_pretrained(init_checkpoint_pt)
model.to(device)

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): BertLayerNorm()
    (dropout): Dropout(p=0.1)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): BertLayerNorm()
            (dropout): Dropout(p=0.1)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=

In [16]:
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
all_input_type_ids = torch.tensor([f.input_type_ids for f in features], dtype=torch.long)
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)

eval_data = TensorDataset(all_input_ids, all_input_mask, all_input_type_ids, all_example_index)
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=1)

model.eval()

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): BertLayerNorm()
    (dropout): Dropout(p=0.1)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): BertLayerNorm()
            (dropout): Dropout(p=0.1)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=

In [17]:
layer_indexes = list(range(12))

pytorch_all_out = []
for input_ids, input_mask, input_type_ids, example_indices in eval_dataloader:
    print(input_ids)
    print(input_mask)
    print(example_indices)
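    # Move every model input to the same device as the model.
    input_ids = input_ids.to(device)
    input_mask = input_mask.to(device)
    input_type_ids = input_type_ids.to(device)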

    all_encoder_layers, _ = model(input_ids, token_type_ids=input_type_ids, attention_mask=input_mask)

    for b, example_index in enumerate(example_indices):
        feature = features[example_index.item()]
        unique_id = int(feature.unique_id)
        # feature = unique_id_to_feature[unique_id]
        output_json = collections.OrderedDict()
        output_json["linex_index"] = unique_id
        all_out_features = []
        # for (i, token) in enumerate(feature.tokens):
        all_layers = []
        for (j, layer_index) in enumerate(layer_indexes):
            print("layer", j, layer_index)
            # Hidden states of this layer for example b, as a NumPy array.
            layer_output = all_encoder_layers[int(layer_index)].detach().cpu().numpy()
            layer_output = layer_output[b]
            layers = collections.OrderedDict()
            layers["index"] = layer_index
            layers["values"] = layer_output if not isinstance(layer_output, (int, float)) else [layer_output]
            all_layers.append(layers)

        # Append once per example (outside the layer loop), as in the TensorFlow cell.
        out_features = collections.OrderedDict()
        out_features["layers"] = all_layers
        all_out_features.append(out_features)
        output_json["features"] = all_out_features
        pytorch_all_out.append(output_json)

tensor([[  101,  2040,  2001,  3958, 27227,  1029,   102,  3958, 27227,  2001,
          1037, 13997, 11510,   102,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,  

In [18]:
print(len(pytorch_all_out))
print(len(pytorch_all_out[0]))
print(pytorch_all_out[0].keys())
print("number of tokens", len(pytorch_all_out))
print("number of layers", len(pytorch_all_out[0]['features'][0]['layers']))
print("sequence length", len(pytorch_all_out[0]['features'][0]['layers'][0]['values']))
pytorch_all_out[0]['features'][0]['layers'][0]['values'].shape

1
2
odict_keys(['linex_index', 'features'])
number of tokens 1
number of layers 12
sequence length 128


(128, 768)

In [19]:
pytorch_outputs = list(pytorch_all_out[0]['features'][0]['layers'][t]['values'] for t in layer_indexes)
print(pytorch_outputs[0].shape)
print(pytorch_outputs[1].shape)

(128, 768)
(128, 768)


In [20]:
print(tensorflow_outputs[0].shape)
print(tensorflow_outputs[1].shape)

(128, 768)
(128, 768)


In [21]:
print(tensorflow_outputs[0])

[[ 0.09768458  0.00179938 -0.14077505 ...  0.08461832  0.06095919
  -0.00853015]
 [-0.00433884  0.6136221  -0.28382796 ...  0.0972865   0.099303
  -0.7878085 ]
 [-0.32750753 -0.82382095  0.16448613 ... -0.19547357  0.1602732
   0.09975615]
 ...
 [ 0.07915473 -0.3328963   0.61744314 ...  0.46285754 -0.32566088
   0.02301384]
 [ 0.00915309 -0.38213047  0.48889852 ...  0.47541893 -0.24534532
  -0.08870013]
 [ 0.14412603 -0.27417877  0.5007422  ...  0.763762   -0.5418124
  -0.13567454]]


In [22]:
print(pytorch_outputs[0])

[[ 0.09786137  0.00116818 -0.14105    ...  0.08469906  0.06104695
  -0.00876771]
 [-0.00450904  0.6127007  -0.2837876  ...  0.09707876  0.09934989
  -0.78852844]
 [-0.32704192 -0.82473737  0.16454487 ... -0.19533129  0.16034643
   0.09961645]
 ...
 [ 0.07926201 -0.33316708  0.6171076  ...  0.4627899  -0.32563412
   0.02256444]
 [ 0.00925751 -0.3823599   0.48858032 ...  0.47534108 -0.24534258
  -0.08917644]
 [ 0.14433005 -0.27446604  0.5004202  ...  0.7637472  -0.5418068
  -0.1361167 ]]


## 3/ Comparing the root-mean-square difference between the layer outputs of both models

In [21]:
import numpy as np

In [22]:
# For each layer, print both output shapes and the root-mean-square (RMS)
# difference between the TensorFlow and PyTorch hidden states.
print('shape tensorflow layer, shape pytorch layer, RMS difference')
for i in range(12):
    tf_layer = np.array(tensorflow_outputs[i])
    pt_layer = np.array(pytorch_outputs[i])
    rms_diff = np.sqrt(np.mean((tf_layer - pt_layer) ** 2.0))
    print((tf_layer.shape, pt_layer.shape, rms_diff))

shape tensorflow layer, shape pytorch layer, RMS difference
((128, 768), (128, 768), 0.00020616385)
((128, 768), (128, 768), 0.0005438742)
((128, 768), (128, 768), 0.00066474144)
((128, 768), (128, 768), 0.0008685303)
((128, 768), (128, 768), 0.0013178637)
((128, 768), (128, 768), 0.001600817)
((128, 768), (128, 768), 0.0021130955)
((128, 768), (128, 768), 0.0023845905)
((128, 768), (128, 768), 0.002587812)
((128, 768), (128, 768), 0.00279683)
((128, 768), (128, 768), 0.0029927932)
((128, 768), (128, 768), 0.0014235352)
