Importing Libraries

In [None]:
import numpy as np
import tensorflow as tf
import os
from tensorflow.keras import layers,models
import matplotlib.pyplot as plt
from tqdm import tqdm

Data Augmentation

In [None]:
# Augmentation pipeline used to turn an image batch into one SimCLR "view".
# Every call draws fresh random transforms, so calling it twice on the same
# batch gives two different views.
data_augmentation = tf.keras.Sequential()
data_augmentation.add(layers.Rescaling(1./255))         # multiply pixel values by 1/255
data_augmentation.add(layers.RandomFlip("horizontal"))  # random left-right mirror
data_augmentation.add(layers.RandomRotation(0.1))       # random rotation, factor 0.1
data_augmentation.add(layers.RandomZoom(0.1))           # random zoom, factor 0.1
data_augmentation.add(layers.RandomContrast(0.1))       # random contrast change, factor 0.1

Load Training Datasets (No Labels)

In [None]:
import os
# Image folders train.X1 ... train.X4 that together form the SSL training set.
train_dirs = [f"ssl_dataset/train.X{shard}" for shard in range(1, 5)]

def load_ssl_dataset(directories=None):
    """Build one two-view SimCLR dataset per image directory.

    Each directory is read without labels (``label_mode=None``) as batches of
    64 images resized to 224x224. Every batch is then mapped to a pair
    ``(view_1, view_2)`` made by two independent passes through
    ``data_augmentation``.

    Args:
        directories: Iterable of directory paths to load. Defaults to the
            module-level ``train_dirs`` when omitted.

    Returns:
        A list of ``tf.data.Dataset`` objects, one per directory, each
        yielding ``(view_1, view_2)`` tuples.
    """
    if directories is None:
        directories = train_dirs

    def make_views(images):
        # Random* preprocessing layers only transform their input in training
        # mode. Inside Dataset.map there is no fit() call context to supply
        # that flag, so pass it explicitly instead of relying on how a given
        # Keras version resolves the default.
        return (
            data_augmentation(images, training=True),
            data_augmentation(images, training=True),
        )

    all_datasets = []
    for dir_path in directories:
        ds = tf.keras.preprocessing.image_dataset_from_directory(
            dir_path,
            label_mode=None,
            image_size=(224, 224),
            batch_size=64
        )
        # Two views for SimCLR
        all_datasets.append(ds.map(make_views))
    return all_datasets

# One dataset per training directory; each element is already a pair of
# augmented views (view_1, view_2) produced inside load_ssl_dataset().
ssl_datasets = load_ssl_dataset()

# load_ssl_dataset() returns a Python list, which has no .map(), and the views
# are already generated there. Chain the per-directory datasets into a single
# stream instead of augmenting a second time.
train_ds = ssl_datasets[0]
for extra_ds in ssl_datasets[1:]:
    train_ds = train_ds.concatenate(extra_ds)


SimCLR Encoder + Projection Head

In [None]:
def build_simclr_model():
    """Assemble the SimCLR network: ResNet50 backbone plus projection head.

    The backbone is a ResNet50 without its classification top, with randomly
    initialised weights (``weights=None``) and global average pooling. The
    projection head is ``Dense(512, relu)`` followed by ``Dense(128)``.

    Returns:
        A ``tf.keras.Model`` taking 224x224x3 images and producing
        128-dimensional projections.
    """
    image_shape = (224, 224, 3)

    backbone = tf.keras.applications.ResNet50(
        include_top=False,
        weights=None,
        pooling='avg',
        input_shape=image_shape,
    )

    image_in = tf.keras.Input(shape=image_shape)
    pooled_features = backbone(image_in)

    # Projection head: the contrastive loss is computed on its output.
    hidden = layers.Dense(512, activation='relu')(pooled_features)
    projection = layers.Dense(128)(hidden)

    return tf.keras.Model(image_in, projection)


NT-Xent Contrastive Loss

In [None]:
def contrastive_loss(z_i, z_j, temperature=0.5):
    """NT-Xent (normalized temperature-scaled cross-entropy) loss.

    Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` are treated as a positive
    pair; every other row in the combined batch acts as a negative.

    Args:
        z_i: 2-D tensor ``(N, dim)`` of projections for the first view.
        z_j: 2-D tensor ``(N, dim)`` of projections for the second view.
        temperature: Divisor applied to the cosine similarities.

    Returns:
        Scalar tensor: mean cross-entropy over all ``2N`` anchors.
    """
    # Unit-normalise so that dot products are cosine similarities.
    z_i = tf.math.l2_normalize(z_i, axis=1)
    z_j = tf.math.l2_normalize(z_j, axis=1)

    batch_size = tf.shape(z_i)[0]
    z = tf.concat([z_i, z_j], axis=0)

    # (2N, 2N) temperature-scaled cosine similarity between all embeddings.
    sim_matrix = tf.matmul(z, z, transpose_b=True) / temperature

    # Drop each row's self-similarity (the diagonal), leaving 2N - 1 candidates.
    off_diagonal = tf.logical_not(tf.cast(tf.eye(2 * batch_size), tf.bool))
    logits = tf.reshape(
        tf.boolean_mask(sim_matrix, off_diagonal),
        [2 * batch_size, 2 * batch_size - 1],
    )

    # Column index of the positive in the diagonal-free matrix:
    #   - row k (first view): positive was at column k + N, which lies to the
    #     right of the removed diagonal entry, so it shifts left to k + N - 1;
    #   - row k + N (second view): positive was at column k, which lies to the
    #     left of the removed diagonal entry, so it stays at k.
    idx = tf.range(batch_size)
    labels = tf.concat([idx + batch_size - 1, idx], axis=0)

    loss = tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
    return tf.reduce_mean(loss)


Training Loop

In [None]:
# Network to pre-train and the optimizer that updates it (Adam, learning rate 1e-3).
model = build_simclr_model()
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

@tf.function
def train_step(x1, x2):
    """Run one optimisation step on a pair of augmented batches.

    Args:
        x1: Batch holding the first view of each image.
        x2: Batch holding the second view of the same images.

    Returns:
        The contrastive loss computed for this batch.
    """
    with tf.GradientTape() as tape:
        # Both views go through the same network (shared weights), in training mode.
        proj_first = model(x1, training=True)
        proj_second = model(x2, training=True)
        loss = contrastive_loss(proj_first, proj_second)

    weights = model.trainable_variables
    gradients = tape.gradient(loss, weights)
    optimizer.apply_gradients(zip(gradients, weights))
    return loss

# Training
EPOCHS = 3

for epoch in range(EPOCHS):
    # Running mean of the batch losses, so each epoch reports whether training
    # is making progress (the per-step loss was previously discarded).
    epoch_loss = tf.keras.metrics.Mean()

    # Each epoch visits every per-directory dataset in turn.
    for ssl_ds in ssl_datasets:
        for x1, x2 in tqdm(ssl_ds, desc=f"Epoch {epoch+1}/{EPOCHS}"):
            loss = train_step(x1, x2)
            epoch_loss.update_state(loss)

    print(f"Epoch {epoch+1}/{EPOCHS} - mean loss: {float(epoch_loss.result()):.4f}")



Saving Encoder for Linear Probing

In [None]:
# model.layers is [InputLayer, ResNet50 backbone, Dense(512), Dense(128)], so
# index -3 is the backbone itself. Its `.output` tensor belongs to the
# backbone's own internal graph, not to `model.input`, so wrapping it in
# tf.keras.Model(inputs=model.input, ...) fails with a "Graph disconnected"
# error. The backbone sub-model already maps images to pooled features and
# shares the trained weights, so export it directly.
encoder = model.layers[-3]
encoder.save('simclr_encoder.h5')
