In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import models

In the official implementation of MSN, the encoder and the target encoder start from the same initial state:

```python
    # -- init model
    encoder = init_model(
        device=device,
        model_name=model_name,
        two_layer=two_layer,
        use_pred=use_pred_head,
        use_bn=use_bn,
        bottleneck=bottleneck,
        hidden_dim=hidden_dim,
        output_dim=output_dim,
        drop_path_rate=drop_path_rate,
    )
    # Target model is a deepcopy of the encoder model.
    target_encoder = copy.deepcopy(encoder)
```

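The call to `copy.deepcopy` is what gives the target encoder the same initial weights as the encoder while keeping the two models independent. In this notebook we explore the best way to reproduce this in TensorFlow.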

In [6]:
def get_model():
    inputs = layers.Input(shape=(3,))
    x = layers.Dense(3, activation="gelu")(inputs)
    outputs = layers.Dense(3, activation="sigmoid")(x)
    
    return models.Model(inputs, outputs)

In [7]:
tf.keras.backend.clear_session()
model = get_model()
model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 3)]               0         
                                                                 
 dense (Dense)               (None, 3)                 12        
                                                                 
 dense_1 (Dense)             (None, 3)                 12        
                                                                 
Total params: 24
Trainable params: 24
Non-trainable params: 0
_________________________________________________________________


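To copy a model into a new model with the same initial weights, we clone the architecture with `clone_model` and then copy the weights over with `set_weights`. `clone_model` on its own creates new layers with newly initialized weights, so the `set_weights` call is required. Note that the optimizer state is not copied, but that is not important for our use case.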

Reference: https://github.com/keras-team/keras/issues/1765#issuecomment-367235276

In [12]:
model_copy = tf.keras.models.clone_model(model)
model_copy.set_weights(model.get_weights())

In [13]:
model.get_weights()

[array([[ 0.5376537 ,  0.51263   , -0.35281944],
        [-0.20083904, -0.52574396, -0.70703864],
        [ 0.66147375, -0.31244206,  0.32125545]], dtype=float32),
 array([0., 0., 0.], dtype=float32),
 array([[ 0.92110085, -0.6732621 ,  0.42984366],
        [ 0.4803393 ,  0.31643176, -0.06728935],
        [ 0.9882252 ,  0.03659225, -0.36137056]], dtype=float32),
 array([0., 0., 0.], dtype=float32)]

In [14]:
model_copy.get_weights()

[array([[ 0.5376537 ,  0.51263   , -0.35281944],
        [-0.20083904, -0.52574396, -0.70703864],
        [ 0.66147375, -0.31244206,  0.32125545]], dtype=float32),
 array([0., 0., 0.], dtype=float32),
 array([[ 0.92110085, -0.6732621 ,  0.42984366],
        [ 0.4803393 ,  0.31643176, -0.06728935],
        [ 0.9882252 ,  0.03659225, -0.36137056]], dtype=float32),
 array([0., 0., 0.], dtype=float32)]

In [19]:
for w, w_copy in zip(model.get_weights(), model_copy.get_weights()):
    print(np.array_equal(w, w_copy))

True
True
True
True
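
Equal weights alone do not show that the copy behaves like a `deepcopy`: the two models must also be independent, so that updating one leaves the other untouched. The following is a minimal sketch of that check. The helper name `copy_model_with_weights` is hypothetical and only wraps the two calls used above; the perturbation of `+ 1.0` is an arbitrary choice for illustration.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models


def get_model():
    # Same toy model as in the notebook.
    inputs = layers.Input(shape=(3,))
    x = layers.Dense(3, activation="gelu")(inputs)
    outputs = layers.Dense(3, activation="sigmoid")(x)
    return models.Model(inputs, outputs)


def copy_model_with_weights(model):
    # Hypothetical helper: clone the architecture, then copy the weights over.
    model_copy = tf.keras.models.clone_model(model)
    model_copy.set_weights(model.get_weights())
    return model_copy


model = get_model()
model_copy = copy_model_with_weights(model)

# Right after copying, every weight array should match.
weights_match = all(
    np.array_equal(w, w_copy)
    for w, w_copy in zip(model.get_weights(), model_copy.get_weights())
)

# Perturb only the copy and check that the original is unaffected.
original_before = [w.copy() for w in model.get_weights()]
model_copy.set_weights([w + 1.0 for w in model_copy.get_weights()])

original_unchanged = all(
    np.array_equal(before, after)
    for before, after in zip(original_before, model.get_weights())
)
copy_changed = not np.array_equal(
    model.get_weights()[0], model_copy.get_weights()[0]
)

print(weights_match, original_unchanged, copy_changed)
```

If the two models shared variables, changing the copy would also change the original and `original_unchanged` would be `False`.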
