<a href="https://colab.research.google.com/github/ayulockin/TF-MSN/blob/main/notebooks/TF_Exponential_Moving_Average_Weights_Sharing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import models

import numpy as np

In [4]:
# Dummy MNIST-shaped data: 100 single-channel 28x28 "images" with integer
# pixel values in [0, 255] and integer class labels in [0, 9].
# A fixed seed keeps the data identical across Restart & Run All.
SEED = 42
rng = np.random.default_rng(SEED)

# The upper bound of `integers` is exclusive, so 256 covers the full 8-bit range.
x = rng.integers(0, 256, size=(100, 28, 28, 1)).astype(np.float32)
y = rng.integers(0, 10, size=(100,)).astype(np.float32)

x.shape, y.shape

((100, 28, 28, 1), (100,))

In [18]:
def _build_classifier(name, input_shape=(28, 28, 1), num_classes=10):
    """Build the small conv classifier shared by both Siamese branches.

    The anchor and target branches must have an identical architecture,
    because the EMA callback copies weights variable-by-variable from one
    to the other. Keeping the layer stack in one place guarantees that.

    Args:
        name: Keras model name; the branches are later looked up by this
            name via `get_layer`.
        input_shape: Shape of one input sample, without the batch dimension.
        num_classes: Number of softmax outputs.

    Returns:
        An uncompiled `tf.keras.Model` producing `(None, num_classes)`
        softmax probabilities.
    """
    inputs = layers.Input(shape=input_shape)
    x = layers.Conv2D(3, 3, activation="relu")(inputs)
    x = layers.GlobalAvgPool2D()(x)
    classifier = layers.Dense(num_classes, activation="softmax")(x)

    return models.Model(inputs, classifier, name=name)


def build_anchor_model(input_shape=(28, 28, 1), num_classes=10):
    """Build the trainable (online) branch, named "anchor_model"."""
    return _build_classifier("anchor_model", input_shape, num_classes)


def build_target_model(input_shape=(28, 28, 1), num_classes=10):
    """Build the target branch, named "target_model" (same architecture as the anchor)."""
    return _build_classifier("target_model", input_shape, num_classes)

In [27]:
# Reset Keras' global state (including the auto-generated layer-name
# counters) so the summary below starts again from input_1 / conv2d / dense.
tf.keras.backend.clear_session()

anchor_model = build_anchor_model()
anchor_model.summary()

Model: "anchor_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 conv2d (Conv2D)             (None, 26, 26, 3)         30        
                                                                 
 global_average_pooling2d (G  (None, 3)                0         
 lobalAveragePooling2D)                                          
                                                                 
 dense (Dense)               (None, 10)                40        
                                                                 
Total params: 70
Trainable params: 70
Non-trainable params: 0
_________________________________________________________________


In [20]:
# Fresh Keras state so the target's layer names mirror the anchor's summary.
tf.keras.backend.clear_session()
target_model = build_target_model()
target_model.summary()

# The EMA scheme copies weights variable-by-variable from anchor to target,
# so both must expose the same weight shapes in the same order.
assert [tuple(w.shape) for w in target_model.weights] == [
    tuple(w.shape) for w in anchor_model.weights
], "anchor_model and target_model architectures differ"

Model: "target_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 conv2d (Conv2D)             (None, 26, 26, 3)         30        
                                                                 
 global_average_pooling2d (G  (None, 3)                0         
 lobalAveragePooling2D)                                          
                                                                 
 dense (Dense)               (None, 10)                40        
                                                                 
Total params: 70
Trainable params: 70
Non-trainable params: 0
_________________________________________________________________


In [70]:
def siamese_network():
    """Build a two-branch model: a trainable anchor and a frozen target.

    Both branches receive the same input. The target is initialised as an
    exact copy of the anchor's initial weights and is frozen
    (`trainable = False`), so the optimizer never updates it. Its weights are
    expected to be driven externally (presumably by an EMA of the anchor).

    Returns:
        A `tf.keras.Model` named "siamese_network" with outputs
        `[anchor_probs, target_probs]`. The sub-models are reachable via
        `get_layer("anchor_model")` and `get_layer("target_model")`.
    """
    inputs = layers.Input(shape=(28, 28, 1))
    # Init anchor model
    anchor_model = build_anchor_model()
    # Init target model from the anchor's weights so both branches agree at
    # step 0, then freeze it so gradients never touch it.
    target_model = build_target_model()
    target_model.set_weights(anchor_model.get_weights())
    target_model.trainable = False

    z1 = anchor_model(inputs)
    z2 = target_model(inputs)

    return models.Model(inputs, outputs=[z1, z2], name="siamese_network")

In [73]:
# Both branches share one Input; in the summary the frozen target's weights
# are reported as non-trainable.
model = siamese_network()
model.summary()

Model: "model_3"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_11 (InputLayer)          [(None, 28, 28, 1)]  0           []                               
                                                                                                  
 anchor_model (Functional)      (None, 10)           70          ['input_11[0][0]']               
                                                                                                  
 target_model (Functional)      (None, 10)           70          ['input_11[0][0]']               
                                                                                                  
Total params: 140
Trainable params: 70
Non-trainable params: 70
__________________________________________________________________________________________________


In [72]:
# Baseline run without the EMA callback: the frozen target branch keeps the
# weights it was built with; only the anchor is optimized.
branch_losses = {
    branch: 'sparse_categorical_crossentropy'
    for branch in ('anchor_model', 'target_model')
}
model.compile(optimizer='adam', loss=branch_losses)

model.fit(x, y, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7f0d818543d0>

In [67]:
class EMA(tf.keras.callbacks.Callback):
    """Keep a frozen target branch as an exponential moving average of the anchor.

    After every training batch, the moving averages of the anchor's trainable
    variables are updated:

        shadow = decay * shadow + (1 - decay) * anchor_weight

    The averages are copied into the target branch's (non-trainable)
    variables when training starts and at the end of every epoch.

    Args:
        decay: EMA decay applied once per training batch; values closer to 1
            make the target follow the anchor more slowly.
        anchor_name: Name of the trainable sub-model inside `self.model`.
        target_name: Name of the frozen sub-model that receives the averages.
    """

    def __init__(self, decay=0.996, anchor_name='anchor_model', target_name='target_model'):
        super().__init__()
        self.decay = decay
        self.anchor_name = anchor_name
        self.target_name = target_name

        # Create an ExponentialMovingAverage object; its shadow variables are
        # created by the first `apply()` call.
        self.ema = tf.train.ExponentialMovingAverage(decay=self.decay)
        self._anchor_vars = []
        self._target_vars = []

    def on_train_begin(self, logs=None):
        # Resolve the variable lists once instead of on every batch/epoch.
        self._anchor_vars = self.model.get_layer(self.anchor_name).trainable_variables
        self._target_vars = self.model.get_layer(self.target_name).non_trainable_variables
        self._check_compatible()

        # The first apply() creates the shadow variables from the anchor's
        # current weights; sync the target to them right away.
        self.ema.apply(self._anchor_vars)
        self._copy_averages_to_target()

    def on_train_batch_end(self, batch, logs=None):
        # In TF2 eager execution apply() performs the moving-average update
        # immediately. It must run every step: calling it only once (at train
        # start) would leave the averages frozen at the initial weights.
        self.ema.apply(self._anchor_vars)

    def on_epoch_end(self, epoch, logs=None):
        # Assign the exponential moving average of the anchor weights to the target.
        self._copy_averages_to_target()

    def _check_compatible(self):
        """Raise ValueError unless anchor and target variables pair up by shape."""
        anchor_shapes = [tuple(v.shape) for v in self._anchor_vars]
        target_shapes = [tuple(v.shape) for v in self._target_vars]
        if anchor_shapes != target_shapes:
            raise ValueError(
                f"Anchor and target variables do not match: "
                f"{anchor_shapes} vs {target_shapes}"
            )

    def _copy_averages_to_target(self):
        """Assign each anchor variable's moving average to its paired target variable."""
        for target_var, anchor_var in zip(self._target_vars, self._anchor_vars):
            target_var.assign(self.ema.average(anchor_var))

In [75]:
# Rebuild the model so this experiment starts from fresh weights instead of
# continuing from the anchor already trained by the baseline cell above.
model = siamese_network()

model.compile(
    optimizer='adam',
    loss={
        'anchor_model': 'sparse_categorical_crossentropy',
        'target_model': 'sparse_categorical_crossentropy',
    },
)

# The EMA callback moves the frozen target towards the anchor during training.
model.fit(x, y, epochs=10, callbacks=[EMA(decay=0.996)])

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7f0d7a32c750>