In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, optimizers
import numpy as np


Define the encoder network: a small CNN that maps a 28×28×1 image to a 128-dimensional embedding.

In [None]:

# CNN encoder: maps a (28, 28, 1) image to a 128-dimensional ReLU embedding.
# The input shape is declared with a leading `keras.Input` rather than the
# `input_shape=` layer kwarg, which Keras 3 deprecates for Sequential models.
encoder = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(32, kernel_size=3, activation='relu'),
        layers.MaxPooling2D(pool_size=2),
        layers.Conv2D(64, kernel_size=3, activation='relu'),
        layers.MaxPooling2D(pool_size=2),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
    ],
    name="encoder",
)

Define the contrastive loss, which pulls the embeddings of similar pairs together and pushes dissimilar pairs at least `margin` apart.

In [None]:

class ContrastiveLoss(tf.keras.losses.Loss):
    """Pairwise contrastive loss over two batches of embeddings.

    For each pair, let ``d`` be the Euclidean distance between its two embeddings:

    * similar pair (``target == 1``): loss = d**2, which pulls the pair together;
    * dissimilar pair (``target == 0``): loss = max(margin - d, 0)**2, which pushes
      the pair apart until it is at least ``margin`` away.

    The per-pair losses are averaged into a scalar.

    Args:
        margin: Minimum distance the loss enforces between dissimilar embeddings.
        epsilon: Small constant added under the square root so the gradient of
            ``d`` stays finite when the two embeddings coincide.
        name: Name of the loss object.
    """

    def __init__(self, margin=2.0, epsilon=1e-9, name="contrastive_loss"):
        super().__init__(name=name)
        self.margin = margin
        self.epsilon = epsilon

    def __call__(self, output1, output2, target):
        """Return the scalar contrastive loss for a batch of pairs.

        ``tf.keras.losses.Loss.__call__`` has the signature
        ``(y_true, y_pred, sample_weight=None)`` and forwards only two tensors to
        ``call``. With the three-argument ``call`` below, that would raise
        ``TypeError``. Overriding ``__call__`` keeps the
        ``loss(output1, output2, target)`` calling convention.
        """
        return self.call(output1, output2, target)

    def call(self, output1, output2, target):
        """Compute the mean contrastive loss.

        Args:
            output1: Embeddings of the first element of each pair; assumed
                (batch, dim).
            output2: Embeddings of the second element of each pair; same shape
                as ``output1``.
            target: One label per pair, 1 for a similar pair and 0 for a
                dissimilar one. It is flattened to (batch,).

        Returns:
            A scalar tensor with the dtype of ``output1``.
        """
        output1 = tf.convert_to_tensor(output1)
        output2 = tf.convert_to_tensor(output2, dtype=output1.dtype)
        target = tf.reshape(tf.cast(target, output1.dtype), [-1])

        # Sum (not mean) over the embedding axis, then sqrt: a true Euclidean distance.
        # No keepdims, so `distance` is (batch,) and lines up element-wise with
        # `target` instead of broadcasting into a (batch, batch) matrix.
        distance = tf.sqrt(
            tf.reduce_sum(tf.square(output1 - output2), axis=-1) + self.epsilon
        )

        similar_term = target * tf.square(distance)
        dissimilar_term = (1.0 - target) * tf.square(
            tf.maximum(self.margin - distance, 0.0)
        )
        return tf.reduce_mean(similar_term + dissimilar_term)

Load the MNIST dataset, add a channel axis, and scale pixel values to [0, 1].

In [None]:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Append a trailing channel axis to match the encoder's (28, 28, 1) input, convert
# to float32, and rescale the 0-255 pixel values into [0, 1].
x_train = x_train[..., np.newaxis].astype(np.float32) / 255.0
x_test = x_test[..., np.newaxis].astype(np.float32) / 255.0

Build the training input pipeline (a shuffled, batched `tf.data.Dataset` of image/label pairs).

In [None]:

batch_size = 64
# (image, label) examples, shuffled through a 10,000-element buffer, then grouped
# into batches of `batch_size`.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=10000)
    .batch(batch_size)
)

Configure contrastive training: number of epochs, loss margin, optimizer, and loss object.

In [None]:
# Training hyperparameters.
n_epochs = 10
margin = 2.0  # minimum distance the loss enforces between dissimilar embeddings
# `learning_rate` replaces the deprecated `lr` alias, which newer Keras optimizers reject.
optimizer = optimizers.Adam(learning_rate=0.001)
contrastive_loss = ContrastiveLoss(margin=margin)

Run the training iterations

In [None]:
for epoch in range(1, n_epochs + 1):
    total_loss = 0.0
    n_steps = 0
    for (data, target) in train_dataset:
        # Pair the first half of the batch with the second half. For an odd-sized
        # batch, the last example is dropped so both halves have the same length.
        half_batch = data.shape[0] // 2
        if half_batch == 0:
            continue  # a single-example batch cannot form a pair
        n_paired = 2 * half_batch
        target1, target2 = target[:half_batch], target[half_batch:n_paired]
        # 1.0 where both images carry the same label, else 0.0.
        # NOTE(review): with randomly shuffled batches, most pairs are presumably
        # dissimilar; consider balanced pair sampling if training stalls.
        target_contrastive = tf.cast(tf.equal(target1, target2), dtype=tf.float32)
        # The forward pass must be recorded on a GradientTape for tape.gradient to work.
        with tf.GradientTape() as tape:
            # One encoder call for both halves, then split the embeddings.
            embeddings = encoder(data[:n_paired], training=True)
            encoding1, encoding2 = embeddings[:half_batch], embeddings[half_batch:]
            loss_contrastive = contrastive_loss(encoding1, encoding2, target_contrastive)
        # Backpropagate and update the weights.
        gradients = tape.gradient(loss_contrastive, encoder.trainable_weights)
        optimizer.apply_gradients(zip(gradients, encoder.trainable_weights))
        total_loss += float(loss_contrastive)
        n_steps += 1
    # Report every epoch, not just the last one.
    print('Epoch: {}, Loss: {:.4f}'.format(epoch, total_loss / max(n_steps, 1)))

Report the average per-batch loss of the final training epoch.

In [None]:
# `epoch` and `total_loss` still hold the values from the last pass of the training loop.
mean_loss = total_loss / len(train_dataset)
print('Epoch: {}, Loss: {:.4f}'.format(epoch, mean_loss))
