In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import models

In the official implementation of MSN (Masked Siamese Networks), the encoder and the target encoder start from the same initial state:

```
    # -- init model
    encoder = init_model(
        device=device,
        model_name=model_name,
        two_layer=two_layer,
        use_pred=use_pred_head,
        use_bn=use_bn,
        bottleneck=bottleneck,
        hidden_dim=hidden_dim,
        output_dim=output_dim,
        drop_path_rate=drop_path_rate,
    )
    # Target model is a deepcopy of the encoder model.
    target_encoder = copy.deepcopy(encoder)
```

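Because the target encoder is created with `copy.deepcopy`, it starts with exactly the same weights as the encoder while remaining a separate object. In this notebook we explore the best way to reproduce this in TensorFlow.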
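To make the role of `copy.deepcopy` concrete, here is a minimal sketch in plain Python. `ToyEncoder` is a hypothetical stand-in, not part of MSN or any library; it only illustrates the difference between a second reference to the same object and a deep copy that shares the initial values but not the storage.

```python
import copy


class ToyEncoder:
    """Hypothetical stand-in for a model: it only holds nested weight lists."""

    def __init__(self):
        self.weights = [[0.1, 0.2], [0.3, 0.4]]


encoder = ToyEncoder()
alias = encoder                   # second name for the same object
target = copy.deepcopy(encoder)   # separate object with the same initial values

# Simulate a training update on the encoder.
encoder.weights[0][0] = 9.9

print(alias.weights[0][0])   # follows the encoder, since it is the same object
print(target.weights[0][0])  # keeps the initial value
```

This is the property we want from the TensorFlow copy as well: identical starting weights, but no shared state afterwards.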

In [2]:
def get_model():
    inputs = layers.Input(shape=(3,))
    x = layers.Dense(3, activation="gelu")(inputs)
    outputs = layers.Dense(3, activation="sigmoid")(x)
    
    return models.Model(inputs, outputs)

In [3]:
tf.keras.backend.clear_session()
model = get_model()
model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 3)]               0         
                                                                 
 dense (Dense)               (None, 3)                 12        
                                                                 
 dense_1 (Dense)             (None, 3)                 12        
                                                                 
Total params: 24
Trainable params: 24
Non-trainable params: 0
_________________________________________________________________


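This is how we copy a model into a new model with the same initial weights: `clone_model` rebuilds the architecture with freshly initialized weights, and `set_weights` then overwrites them with the original values. Note that the optimizer state is not copied, which is not important for our use case.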

Reference: https://github.com/keras-team/keras/issues/1765#issuecomment-367235276

In [4]:
model_copy = tf.keras.models.clone_model(model)
model_copy.set_weights(model.get_weights())

In [5]:
model.get_weights()

[array([[ 0.4656961 ,  0.15759563,  0.6949961 ],
        [ 0.32020545, -0.42861724,  0.5538626 ],
        [ 0.05646205, -0.7538266 ,  0.17306066]], dtype=float32),
 array([0., 0., 0.], dtype=float32),
 array([[-0.34945416, -0.36945963, -0.08475637],
        [ 0.42724228,  0.7825148 , -0.00794744],
        [-0.1974678 , -0.67979026, -0.25060534]], dtype=float32),
 array([0., 0., 0.], dtype=float32)]

In [6]:
model_copy.get_weights()

[array([[ 0.4656961 ,  0.15759563,  0.6949961 ],
        [ 0.32020545, -0.42861724,  0.5538626 ],
        [ 0.05646205, -0.7538266 ,  0.17306066]], dtype=float32),
 array([0., 0., 0.], dtype=float32),
 array([[-0.34945416, -0.36945963, -0.08475637],
        [ 0.42724228,  0.7825148 , -0.00794744],
        [-0.1974678 , -0.67979026, -0.25060534]], dtype=float32),
 array([0., 0., 0.], dtype=float32)]

In [7]:
for w, w_copy in zip(model.get_weights(), model_copy.get_weights()):
    print(np.array_equal(w, w_copy))

True
True
True
True
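Equal weights alone do not show that the two models are independent. As a minimal sketch, assuming the same small Functional model as above, the check below perturbs the copy and confirms that the original is left untouched. The names `original`, `unchanged`, and `diverged` are introduced here for illustration only.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models


def get_model():
    inputs = layers.Input(shape=(3,))
    x = layers.Dense(3, activation="gelu")(inputs)
    outputs = layers.Dense(3, activation="sigmoid")(x)
    return models.Model(inputs, outputs)


model = get_model()
model_copy = tf.keras.models.clone_model(model)
model_copy.set_weights(model.get_weights())

# Snapshot the original weights before touching the copy.
original = [w.copy() for w in model.get_weights()]

# Shift every weight of the copy by 1.0.
model_copy.set_weights([w + 1.0 for w in model_copy.get_weights()])

# The original model still matches its snapshot ...
unchanged = all(
    np.array_equal(a, b) for a, b in zip(original, model.get_weights())
)
# ... while every weight array of the copy now differs from it.
diverged = all(
    not np.array_equal(a, b)
    for a, b in zip(model.get_weights(), model_copy.get_weights())
)
print(unchanged, diverged)
```

If the two models shared variables, the shift applied to the copy would also show up in `model`.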


Unfortunately, this only works with Sequential or Functional models, not with subclassed models.

Calling `clone_model` on a subclassed model raises this error: ```ValueError: Expected `model` argument to be a functional `Model` instance, but got a subclassed model instead:```

In [14]:
class SubclassedModel(tf.keras.Model):
    def __init__(self):
        super(SubclassedModel, self).__init__()
        self.hidden = layers.Dense(3, activation="gelu")
        self.outputs = layers.Dense(3, activation="sigmoid")
        
    def call(self, inputs):
        x = self.hidden(inputs)
        x = self.outputs(x)
        
        return x

In [24]:
tf.keras.backend.clear_session()
model = SubclassedModel()

model.build(input_shape=(1, 3))
model.summary()

Model: "subclassed_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               multiple                  12        
                                                                 
 dense_1 (Dense)             multiple                  12        
                                                                 
Total params: 24
Trainable params: 24
Non-trainable params: 0
_________________________________________________________________


> The subclassed model has the same layers and parameter counts as the one we created with the Functional API.

In [40]:
tf.keras.backend.clear_session()
model = SubclassedModel()

model.get_weights()

[]

> Merely instantiating a subclassed model does not create its weights, which makes sense because the input shape is needed to size them.

In [41]:
model.build(input_shape=(1, 3))
model.get_weights()

[array([[-0.8726375 , -0.351403  ,  0.21757174],
        [ 0.93361306, -0.7142036 ,  0.87285376],
        [-0.74052906,  0.9485116 ,  0.73050404]], dtype=float32),
 array([0., 0., 0.], dtype=float32),
 array([[-0.61596346,  0.28087306,  0.66120934],
        [ 0.7427747 , -0.58226514, -0.897063  ],
        [ 0.2804525 ,  0.35987568, -0.734334  ]], dtype=float32),
 array([0., 0., 0.], dtype=float32)]
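Calling `build` is one way to supply the input shape. Another is to call the model once on a sample batch, which also creates the variables. The sketch below assumes a small subclassed model equivalent to `SubclassedModel`; the name `TinyModel` and the attribute `head` are chosen here for illustration.

```python
import tensorflow as tf
from tensorflow.keras import layers


class TinyModel(tf.keras.Model):
    """Illustrative subclassed model with the same two Dense layers."""

    def __init__(self):
        super().__init__()
        self.hidden = layers.Dense(3, activation="gelu")
        self.head = layers.Dense(3, activation="sigmoid")

    def call(self, inputs):
        return self.head(self.hidden(inputs))


model = TinyModel()
n_before = len(model.get_weights())  # no variables exist yet

# A forward pass on a dummy batch of shape (1, 3) creates the variables.
model(tf.zeros((1, 3)))
n_after = len(model.get_weights())
shapes = [w.shape for w in model.get_weights()]

print(n_before, n_after, shapes)
```

Either route works for the copy step that follows; what matters is that both models have their variables created before `set_weights` is called.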

> Maybe we can create a `model_copy`, build it, and set its weights from the original model.

In [44]:
model_copy = SubclassedModel()
model_copy.build(input_shape=(1, 3))

In [45]:
model_copy.set_weights(model.get_weights())

In [84]:
for w, w_copy in zip(model.get_weights(), model_copy.get_weights()):
    print(np.array_equal(w, w_copy))

True
True
True
True


> It works! The rebuilt copy holds exactly the same weights as the original.
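The three steps (instantiate, build, copy weights) can be wrapped in a small helper. This is only a sketch: `copy_subclassed` and its parameters `model_fn` and `sample_input` are hypothetical names introduced here, and `TinyModel` stands in for `SubclassedModel`. It assumes the model can be re-created by calling `model_fn()` with no arguments.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers


class TinyModel(tf.keras.Model):
    """Illustrative subclassed model with the same two Dense layers."""

    def __init__(self):
        super().__init__()
        self.hidden = layers.Dense(3, activation="gelu")
        self.head = layers.Dense(3, activation="sigmoid")

    def call(self, inputs):
        return self.head(self.hidden(inputs))


def copy_subclassed(model, model_fn, sample_input):
    """Create a fresh model with `model_fn` and load `model`'s weights into it."""
    model_copy = model_fn()
    model_copy(sample_input)  # one forward pass creates the variables
    model_copy.set_weights(model.get_weights())
    return model_copy


sample = tf.zeros((1, 3))
model = TinyModel()
model(sample)  # the source model must have its variables created too

model_copy = copy_subclassed(model, TinyModel, sample)

same = all(
    np.array_equal(w, w_copy)
    for w, w_copy in zip(model.get_weights(), model_copy.get_weights())
)
print(same)
```

As with `clone_model`, only the weights are transferred; optimizer state is not part of this copy.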