# tf2-keras-models-sharing-weights

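Based on the TensorFlow Datasets tutorial "Training a neural network on MNIST with Keras": https://www.tensorflow.org/datasets/keras_example

This notebook builds two Keras models that share the same input and hidden layer but have separate output layers, trains only the first model, and then checks how the second model's weights and evaluation results are affected.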

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np

In [2]:
# Load the MNIST train/test splits as (image, label) pairs, plus the dataset
# metadata (ds_info) that is used below to size the shuffle buffer.
mnist_splits, ds_info = tfds.load(
    name='mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
ds_train, ds_test = mnist_splits

In [3]:
def normalize_img(image, label):
  """Cast an image to `float32` and divide it by 255.

  Args:
    image: Image tensor to rescale (the `uint8` pixels of the dataset).
    label: Label belonging to the image; returned untouched.

  Returns:
    A `(scaled_image, label)` tuple.
  """
  scaled_image = tf.cast(image, tf.float32) / 255.
  return scaled_image, label

# Training input pipeline: normalize, cache the normalized examples, shuffle
# with a buffer as large as the whole training split, then batch and prefetch.
# NOTE: `ds_train` is rebuilt from itself, so re-running this cell stacks the
# transformations on top of the already-transformed dataset.
ds_train = (
    ds_train
    .map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .cache()
    .shuffle(ds_info.splits['train'].num_examples)
    .batch(128)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

In [4]:
# Evaluation input pipeline: normalize, batch, then cache the batches and
# prefetch. No shuffling is applied to the test split.
# NOTE: `ds_test` is rebuilt from itself, so re-running this cell stacks the
# transformations on top of the already-transformed dataset.
ds_test = (
    ds_test
    .map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(128)
    .cache()
    .prefetch(tf.data.experimental.AUTOTUNE)
)

In [5]:
# Shared trunk: a single Input -> Flatten -> Dense(128) graph. Both models are
# built on these same layer objects, so they share the hidden layer's weights.
# NOTE(review): the input is declared as (28, 28); the tfds MNIST images
# presumably carry a trailing channel axis — confirm the shapes agree.
model_input = tf.keras.layers.Input(shape=(28, 28))
flattened = tf.keras.layers.Flatten()(model_input)
shared_hidden = tf.keras.layers.Dense(128, activation='relu')(flattened)

# Two independent 10-unit output heads on top of the shared hidden layer.
# They have no activation, i.e. they emit logits.
model_output_1 = tf.keras.layers.Dense(10)(shared_hidden)
model_output_2 = tf.keras.layers.Dense(10)(shared_hidden)

model_1 = tf.keras.Model(model_input, model_output_1, name="model_1")
model_1.summary()

model_2 = tf.keras.Model(model_input, model_output_2, name="model_2")
model_2.summary()

# Same training configuration for both models; each one gets its own
# optimizer, loss and metric instances. `from_logits=True` matches the
# activation-free output heads.
for shared_model in (model_1, model_2):
    shared_model.compile(
        optimizer=tf.keras.optimizers.Adam(0.001),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 28, 28)]          0         
_________________________________________________________________
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 128)               100480    
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1290      
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________
Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 28, 28)]          0         
____________________________________

In [6]:
# layers[2] is the hidden Dense(128) layer that both models were built on.
# Compare its weights between the two models before any training.
hidden_weights_1 = model_1.layers[2].get_weights()
hidden_weights_2 = model_2.layers[2].get_weights()
for idx in (0, 1):  # 0: kernel, 1: bias
    print(np.array_equal(hidden_weights_1[idx].ravel(), hidden_weights_2[idx].ravel()))

True
True


In [7]:
# layers[3] is each model's own Dense(10) output head (not shared).
# Compare the two heads' weights before any training.
head_weights_1 = model_1.layers[3].get_weights()
head_weights_2 = model_2.layers[3].get_weights()
for idx in (0, 1):  # 0: kernel, 1: bias
    print(np.array_equal(head_weights_1[idx].ravel(), head_weights_2[idx].ravel()))

False
True


In [8]:
# Baseline for model_2 before anything has been trained. The evaluation is run
# twice on the training set so the two [loss, accuracy] results can be compared.
for run in range(2):
    print(model_2.evaluate(ds_train))

[2.393404960632324, 0.11423332989215851]
[2.3934051990509033, 0.11423332989215851]


In [9]:
# Train model_1 only; model_2 is never fitted in this notebook.
model_1.fit(ds_train, validation_data=ds_test, epochs=6)

Epoch 1/6
Epoch 2/6
Epoch 3/6
Epoch 4/6
Epoch 5/6
Epoch 6/6


<tensorflow.python.keras.callbacks.History at 0x7f13a004c3c8>

In [10]:
# Re-evaluate model_2 after model_1 was trained. model_2 itself was not fitted,
# so any change from the baseline comes from the layers it shares with model_1.
for run in range(2):
    print(model_2.evaluate(ds_train))

[4.08781099319458, 0.04309999942779541]
[4.087810516357422, 0.04309999942779541]


In [11]:
# After training model_1: check the hidden Dense(128) layer (layers[2]) still
# holds identical weights in both models.
hidden_weights_1 = model_1.layers[2].get_weights()
hidden_weights_2 = model_2.layers[2].get_weights()
for idx in (0, 1):  # 0: kernel, 1: bias
    print(np.array_equal(hidden_weights_1[idx], hidden_weights_2[idx]))

True
True


In [12]:
# After training model_1: compare the two separate Dense(10) output heads
# (layers[3] of each model).
head_weights_1 = model_1.layers[3].get_weights()
head_weights_2 = model_2.layers[3].get_weights()
for idx in (0, 1):  # 0: kernel, 1: bias
    print(np.array_equal(head_weights_1[idx].ravel(), head_weights_2[idx].ravel()))

False
False
