Loading Dataset (Cr: Enrico Deccen Aristan)

In [1]:
import numpy as np, matplotlib.pyplot as plt, tensorflow as tf
from tensorflow.keras import layers as lyr, models as mods

# Load CIFAR-10 dataset
(xtr, ytr), (xts, yts) = tf.keras.datasets.cifar10.load_data()

# Input Normalization
xtr = xtr.astype("float32") / 255.0
xts = xts.astype("float32") / 255.0

# Label 1-Hot Encoding
ytr_1h = tf.keras.utils.to_categorical(ytr, 10)
yts_1h = tf.keras.utils.to_categorical(yts, 10)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 0us/step


Data Augmenter (Cr: Enrico Deccen Aristan)

In [8]:
data_aug = tf.keras.Sequential([
    lyr.RandomFlip("horizontal"),
    lyr.RandomRotation(0.25),
    lyr.RandomZoom(0.15)
])

SimCLR Components (Cr: Enrico Deccen Aristan)

In [9]:
# ResNet-like base encoder
def create_encoder():
    ins = lyr.Input(shape=(32, 32, 3))
    x = lyr.Conv2D(32, (3, 3), activation="relu", padding="same")(ins)
    x = lyr.MaxPooling2D((2, 2))(x)
    x = lyr.Conv2D(64, (3, 3), activation="relu", padding="same")(x)
    x = lyr.MaxPooling2D((2, 2))(x)
    x = lyr.Flatten()(x)
    x = lyr.Dense(128, activation="relu")(x)  # Feature embeddings
    return mods.Model(ins, x)

def create_projection_head(base_mod):
    ins = lyr.Input(shape=(128,))
    x = lyr.Dense(64, activation="relu")(ins)
    x = lyr.Dense(64)(x)  # Embedding space for contrastive loss
    return mods.Model(ins, x)

Contrastive Loss Function (Cr: Enrico Deccen Aristan)

In [10]:
def contrastive_loss(projs, temp=0.5):
    projs = tf.nn.l2_normalize(projs, axis=1)  # Normalize projections
    sim_mtx = tf.matmul(projs, projs, transpose_b=True)  # Cosine similarity
    
    batch_size = tf.shape(projs)[0] // 2  # Halve the batch size given 2 augmentations
    
    # Create labels for positive pairs (diagonal of the similarity matrix)
    lbls = tf.one_hot(tf.range(batch_size), batch_size * 2)  # Labels for positive pairs
    lbls = tf.concat([lbls, lbls], axis=0)  # Duplicate for both augmented views
    
    # logits: similarity matrix scaled by temperature
    lgts = sim_mtx / temp

    loss = tf.nn.softmax_cross_entropy_with_logits(lbls, lgts)
    
    return tf.reduce_mean(loss)

SimCLR Training Loop (Cr: Enrico Deccen Aristan)

In [None]:
# Instantiate SimCLR Models
encoder = create_encoder()
projection_head = create_projection_head(encoder)

# Hyperparameters
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
batch_size = 256
epochs = 10

# Training loop
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    for i in range(0, len(xtr), batch_size):
        batch_images = xtr[i:i + batch_size]
        augmented_1 = data_aug(batch_images)
        augmented_2 = data_aug(batch_images)

        # Compute projections
        with tf.GradientTape() as tape:
            proj_1 = projection_head(encoder(augmented_1), training=True)
            proj_2 = projection_head(encoder(augmented_2), training=True)

            # Concatenate for NT-Xent Loss
            projections = tf.concat([proj_1, proj_2], axis=0)
            loss = contrastive_loss(projections)

        # Apply gradients
        gradients = tape.gradient(loss, encoder.trainable_variables + projection_head.trainable_variables)
        optimizer.apply_gradients(zip(gradients, encoder.trainable_variables + projection_head.trainable_variables))

    print(f"Loss: {loss.numpy()}")

Epoch 1/10
Loss: 3.302361249923706
Epoch 2/10


MoCo Components

In [None]:
# ResNet-like base encoder
def moco_create_encoder():
    ins = lyr.Input(shape=(32, 32, 3))
    x = lyr.Conv2D(32, (3, 3), activation="relu", padding="same")(ins)
    x = lyr.MaxPooling2D((2, 2))(x)
    x = lyr.Conv2D(64, (3, 3), activation="relu", padding="same")(x)
    x = lyr.MaxPooling2D((2, 2))(x)
    x = lyr.Flatten()(x)
    x = lyr.Dense(128, activation="relu")(x)  # Feature embeddings
    return mods.Model(ins, x)

def moco_create_projection_head(base_mod):
    ins = lyr.Input(shape=(128,))
    x = lyr.Dense(64, activation="relu")(ins)
    x = lyr.Dense(64)(x)  # Embedding space for contrastive loss
    return mods.Model(ins, x)

In [None]:
# Contrastive Loss for MoCo
def moco_contrastive_loss(projs, memory, temp=0.5):
    projs = tf.nn.l2_normalize(projs, axis=1) 
    memory = tf.nn.l2_normalize(memory, axis=1)

    sim_mtx = tf.matmul(projs, memory, transpose_b=True) 
    
    batch_size = tf.shape(projs)[0] // 2 
    lbls = tf.one_hot(tf.range(batch_size), batch_size * 2)
    lbls = tf.concat([lbls, lbls], axis=0)
    
    lgts = sim_mtx / temp
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(lbls, lgts))

    return loss

# MoCo-specific momentum encoder update function
def moco_update_momentum_encoder(encoder, momentum_encoder, momentum=0.99):
    for var, momentum_var in zip(encoder.trainable_variables, momentum_encoder.trainable_variables):
        momentum_var.assign(momentum * momentum_var + (1.0 - momentum) * var)

MoCo Training Loop

In [None]:
# Instantiate MoCo Models
encoder = moco_create_encoder()
projection_head = moco_create_projection_head(encoder)

# Initialize the momentum encoder (target encoder)
momentum_encoder = moco_create_encoder()
momentum_projection_head = moco_create_projection_head(momentum_encoder)

# Hyperparameters
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
batch_size = 256
epochs = 10
momentum = 0.99  # Momentum for the momentum encoder

# Memory bank (used to store the keys from the momentum encoder)
memory_bank = tf.Variable(tf.zeros((batch_size, 64)), trainable=False)  # Adjust dimensions as necessary

# MoCo Training Loop
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    for i in range(0, len(xtr), batch_size):
        batch_images = xtr[i:i + batch_size]
        augmented_1 = data_aug(batch_images)
        augmented_2 = data_aug(batch_images)

        # Compute projections
        with tf.GradientTape() as tape:
            proj_1 = projection_head(encoder(augmented_1), training=True)
            proj_2 = projection_head(encoder(augmented_2), training=True)

            # Concatenate projections from two augmentations
            projections = tf.concat([proj_1, proj_2], axis=0)

            # Get memory bank (keys from momentum encoder)
            momentum_proj_1 = momentum_projection_head(momentum_encoder(augmented_1), training=False)
            momentum_proj_2 = momentum_projection_head(momentum_encoder(augmented_2), training=False)
            memory_proj = tf.concat([momentum_proj_1, momentum_proj_2], axis=0)
            
            # Compute MoCo contrastive loss
            loss = moco_contrastive_loss(projections, memory_proj)

        # Apply gradients
        gradients = tape.gradient(loss, encoder.trainable_variables + projection_head.trainable_variables)
        optimizer.apply_gradients(zip(gradients, encoder.trainable_variables + projection_head.trainable_variables))

        # Update momentum encoder
        moco_update_momentum_encoder(encoder, momentum_encoder, momentum=momentum)

    print(f"Loss: {loss.numpy()}")