<a href="https://colab.research.google.com/github/ayulockin/TF-MSN/blob/main/notebooks/EMA_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -qq wandb

In [None]:
import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import tensorflow as tf
print(tf.__version__)
from tensorflow.keras import layers
from tensorflow.keras import models

import wandb
from wandb.keras import WandbCallback

wandb.login()

In [None]:
# Load MNIST. Only the training split is kept; the test split is discarded.
(x_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()

# Sizes of the small training / validation subsets carved out of the training split.
NUM_TRAIN = 1000
NUM_VALID = 100

# The validation subset is the NUM_VALID samples that directly follow the training subset,
# so the two subsets never overlap.
train_imgs, train_labels = x_train[:NUM_TRAIN], y_train[:NUM_TRAIN]
valid_imgs, valid_labels = (
    x_train[NUM_TRAIN:NUM_TRAIN + NUM_VALID],
    y_train[NUM_TRAIN:NUM_TRAIN + NUM_VALID],
)

# Fail early if the slices came back shorter than requested.
assert len(train_imgs) == len(train_labels) == NUM_TRAIN
assert len(valid_imgs) == len(valid_labels) == NUM_VALID

# Let tf.data pick parallelism / prefetch depth for the input pipelines.
AUTO = tf.data.AUTOTUNE
BATCH_SIZE = 32

def preprocess_image(image, label):
    """Convert one image to float32 and rescale it by 1/255.

    Args:
        image: Image tensor to convert (presumably 8-bit pixel values, so the
            result lands in [0, 1] - confirm against the data source).
        label: Label belonging to ``image``; passed through untouched.

    Returns:
        A ``(scaled_image, label)`` tuple. The image keeps its input shape;
        no channel axis is added here.
    """
    scaled = tf.cast(image, tf.float32) / 255.
    return scaled, label

# Training pipeline: slice into samples, shuffle, preprocess in parallel,
# batch, and prefetch so input preparation overlaps with training.
trainloader = (
    tf.data.Dataset.from_tensor_slices((train_imgs, train_labels))
    .shuffle(1024)
    .map(preprocess_image, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

# Validation pipeline: identical steps except that the order is left unshuffled.
validloader = (
    tf.data.Dataset.from_tensor_slices((valid_imgs, valid_labels))
    .map(preprocess_image, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

In [None]:
def _build_classifier(name):
    """Build the small CNN classifier shared by the anchor and target branches.

    Both branches must have exactly the same architecture so that their weight
    lists line up one-to-one; keeping a single definition guarantees that.

    Args:
        name: Name given to the resulting Keras model.

    Returns:
        A Keras model mapping a ``(28, 28, 1)`` input to a 10-way softmax.
    """
    inputs = layers.Input(shape=(28, 28, 1))

    x = inputs
    # Two identical stages: conv -> conv -> max-pool (default pool settings).
    for _ in range(2):
        x = layers.Conv2D(3, 3, activation="relu")(x)
        x = layers.Conv2D(3, 3, activation="relu")(x)
        x = layers.MaxPooling2D()(x)

    x = layers.GlobalAvgPool2D()(x)
    # Linear (no activation) 64-unit layer ahead of the softmax head.
    x = layers.Dense(64)(x)
    classifier = layers.Dense(10, activation="softmax")(x)

    return models.Model(inputs, classifier, name=name)


def build_anchor_model():
    """Return a freshly initialised classifier named ``"anchor_model"``."""
    return _build_classifier("anchor_model")


def build_target_model():
    """Return a freshly initialised classifier named ``"target_model"``.

    The architecture is identical to the one from ``build_anchor_model``.
    """
    return _build_classifier("target_model")

In [None]:
# Clear Keras' global state, then build a standalone anchor model to inspect its
# layer summary. (siamese_network() constructs its own anchor instance.)
tf.keras.backend.clear_session()
anchor_model = build_anchor_model()
anchor_model.summary()

In [None]:
# Clear Keras' global state, then build a standalone target model to inspect its
# layer summary. (siamese_network() constructs its own target instance.)
tf.keras.backend.clear_session()
target_model = build_target_model()
target_model.summary()

In [None]:
def siamese_network():
    """Assemble the two-branch model that feeds one input to both sub-models.

    The anchor branch stays trainable; the target branch is frozen before it is
    wired in, so every one of its variables is reported as non-trainable.

    Returns:
        A Keras model with a single ``(28, 28, 1)`` input and two outputs, in
        the order ``[anchor prediction, target prediction]``.
    """
    shared_input = layers.Input(shape=(28,28,1))

    # Trainable branch.
    anchor = build_anchor_model()

    # Frozen branch: mark it non-trainable before calling it on the input.
    target = build_target_model()
    target.trainable = False

    branch_outputs = [anchor(shared_input), target(shared_input)]
    return models.Model(shared_input, outputs=branch_outputs)

### Without EMA

In [None]:
# Baseline experiment: train the siamese model without the EMA callback, so the
# frozen target branch is never updated.
#
# The run is used as a context manager so it is finished on exit - including
# when building or fitting the model raises - instead of being left open.
with wandb.init() as run:
    tf.keras.backend.clear_session()
    model = siamese_network()
    model.summary(expand_nested=False)

    # The loss keys match the names of the two sub-models that produce the outputs.
    model.compile(
        optimizer='adam',
        loss={
            'anchor_model': 'sparse_categorical_crossentropy',
            'target_model': 'sparse_categorical_crossentropy'
        },
        metrics=["accuracy"]
    )

    model.fit(
        trainloader,
        validation_data=validloader,
        epochs=100,
        callbacks=[WandbCallback(save_model=False)],
    )

### With EMA

In [None]:
class EMA(tf.keras.callbacks.Callback):
    """Keep the target sub-model at an exponential moving average of the anchor.

    At the end of every epoch the moving average of the anchor sub-model's
    trainable variables is updated with the current weights and the result is
    copied into the (frozen) target sub-model.

    Args:
        decay: Decay rate passed to ``tf.train.ExponentialMovingAverage``.
        anchor_name: Name of the sub-model whose trainable variables are averaged.
        target_name: Name of the sub-model that receives the averaged weights.
            All of its variables are expected to be non-trainable and to line
            up one-to-one with the anchor's trainable variables.
    """

    def __init__(self, decay=0.999, anchor_name='anchor_model', target_name='target_model'):
        super(EMA, self).__init__()
        self.decay = decay
        self.anchor_name = anchor_name
        self.target_name = target_name

        # Create an ExponentialMovingAverage object
        self.ema = tf.train.ExponentialMovingAverage(decay=self.decay)

    def _anchor_variables(self):
        """Return the trainable variables of the anchor sub-model."""
        return self.model.get_layer(self.anchor_name).trainable_variables

    def on_train_begin(self, logs=None):
        """Start tracking the anchor weights before the first epoch."""
        # The first apply() call creates the shadow variables that hold the averages.
        self.ema.apply(self._anchor_variables())

    def on_epoch_end(self, epoch, logs=None):
        """Update the moving average and copy it into the target sub-model.

        Raises:
            ValueError: If the target's variables do not line up with the
                anchor's, or if an anchor variable has no tracked average.
        """
        anchor_vars = self._anchor_variables()

        # Fold this epoch's weights into the average BEFORE copying. Copying
        # first would leave the target one epoch behind and would never expose
        # the weights reached in the final epoch.
        self.ema.apply(anchor_vars)
        averages = [self.ema.average(var) for var in anchor_vars]

        # Assign the average weights to target model
        target_vars = self.model.get_layer(self.target_name).non_trainable_variables
        if len(target_vars) != len(averages):
            raise ValueError(
                "Cannot copy EMA weights: '{}' has {} non-trainable variables but "
                "'{}' has {} trainable variables.".format(
                    self.target_name, len(target_vars), self.anchor_name, len(averages)
                )
            )
        for target_var, average in zip(target_vars, averages):
            if average is None:
                raise ValueError(
                    "No moving average is tracked for an anchor variable; the "
                    "anchor's trainable variables changed after training began."
                )
            target_var.assign(average)

In [None]:
# EMA experiment: same model and training setup as the baseline, plus the EMA
# callback that writes averaged anchor weights into the frozen target branch.
#
# The run is used as a context manager so it is finished on exit - including
# when building or fitting the model raises - instead of being left open.
with wandb.init() as run:
    tf.keras.backend.clear_session()
    model = siamese_network()
    model.summary()

    # The loss keys match the names of the two sub-models that produce the outputs.
    model.compile(
        optimizer='adam',
        loss={
            'anchor_model': 'sparse_categorical_crossentropy',
            'target_model': 'sparse_categorical_crossentropy'
        },
        metrics=["accuracy"]
    )

    model.fit(
        trainloader,
        validation_data=validloader,
        epochs=100,
        callbacks=[EMA(), WandbCallback(save_model=False)],
    )