In [3]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import random
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

print("TensorFlow Version:", tf.__version__)

# IMAGE PATH LOADING (SAFE)
def load_unlabeled_image_paths(root_folder):
    """
    Returns a list of image file paths.
    """
    image_paths = []
    for fn in os.listdir(root_folder):
        if fn.lower().endswith((".jpg", ".jpeg", ".png")):
            image_paths.append(os.path.join(root_folder, fn))
    if len(image_paths) == 0:
        raise ValueError("No images found in folder.")
    return image_paths


# SAFE AUGMENTATION FUNCTION
def random_augment(image):
    """
    Applies a series of random augmentations for contrastive learning.
    """
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.4)
    image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
    image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
    image = tf.image.central_crop(image, central_fraction=0.8)
    image = tf.image.resize(image, (224, 224))
    return image


# CREATE DATASET OF SIMCLR PAIRS
def create_augmented_dataset(image_paths, batch_size):
    """
    Creates a dataset of augmented image pairs for SimCLR training.
    """
    def decode_and_preprocess(path):
        image = tf.io.read_file(path)
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.image.resize(image, [224, 224])
        return image

    dataset = tf.data.Dataset.from_tensor_slices(image_paths)
    dataset = dataset.map(decode_and_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.map(lambda x: (random_augment(x), random_augment(x)), num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.shuffle(500).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return dataset


# ENCODER EXTRACTION
def extract_encoder(mae_model):
    """
    Extracts the encoder part from a trained MAE teacher model.
    """
    encoder_output = mae_model.get_layer("dense_2").output
    encoder = Model(inputs=mae_model.input, outputs=encoder_output)
    return encoder


# NT-XENT LOSS FUNCTION
def nt_xent_loss(z_i, z_j, temperature=0.5):
    """
    Computes the NT-Xent contrastive loss.
    """
    z_i = tf.reshape(z_i, [tf.shape(z_i)[0], -1])
    z_j = tf.reshape(z_j, [tf.shape(z_j)[0], -1])
    z_i = tf.math.l2_normalize(z_i, axis=1)
    z_j = tf.math.l2_normalize(z_j, axis=1)

    logits = tf.matmul(z_i, tf.transpose(z_j)) / temperature
    labels = tf.range(tf.shape(z_i)[0])
    loss = tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
    return loss


# SIMCLR FINE-TUNING FUNCTION
def fine_tune_with_simclr(mae_model_path, root_folder, batch_size=32, epochs=5):
    """
    Fine-tunes the MAE encoder using SimCLR contrastive learning.
    """
    print("Loading pre-trained MAE model...")
    mae_model = tf.keras.models.load_model(mae_model_path)
    encoder = extract_encoder(mae_model)
    print("Encoder successfully extracted!")

    print("Loading image paths...")
    image_paths = load_unlabeled_image_paths(root_folder)

    print("Creating dataset...")
    dataset = create_augmented_dataset(image_paths, batch_size)

    optimizer = Adam(learning_rate=0.001)
    max_steps_per_epoch = 300  # <-- Limit to 300 steps per epoch

    print("Starting SimCLR fine-tuning...")
    for epoch in range(epochs):
        epoch_loss = 0
        steps = 0
        for step, (x_i, x_j) in enumerate(dataset):
            if step >= max_steps_per_epoch:
                break
            with tf.GradientTape() as tape:
                z_i = encoder(x_i, training=True)
                z_j = encoder(x_j, training=True)
                loss = nt_xent_loss(z_i, z_j)
            gradients = tape.gradient(loss, encoder.trainable_variables)
            optimizer.apply_gradients(zip(gradients, encoder.trainable_variables))
            epoch_loss += tf.reduce_mean(loss).numpy()
            steps += 1
            if step % 10 == 0:
                print(f"Epoch {epoch+1}/{epochs}, Step {step}, Loss: {tf.reduce_mean(loss).numpy():.4f}")
        avg_loss = epoch_loss / steps
        print(f"Epoch {epoch+1}/{epochs} completed. Avg SimCLR Loss: {avg_loss:.4f}")

    encoder.save("ssl_teacher_model.keras")
    print("Fine-tuning complete! Encoder saved as 'ssl_teacher_model.keras'.")
    return encoder


# RUN SCRIPT
if __name__ == "__main__":
    root_folder = "/Users/morgan/Desktop/DissCodeBTC/Datasets/KaggleUnlabelled"  
    mae_model_path = "SavedModels/mae_model.h5"  
    ssl_teacher = fine_tune_with_simclr(mae_model_path, root_folder, batch_size=16, epochs=5)


TensorFlow Version: 2.18.0
Loading pre-trained MAE model...




Encoder successfully extracted!
Loading image paths...
Creating dataset...
Starting SimCLR fine-tuning...
Epoch 1/5, Step 0, Loss: 2.2546
Epoch 1/5, Step 10, Loss: 1.5369
Epoch 1/5, Step 20, Loss: 1.3956
Epoch 1/5, Step 30, Loss: 1.3800
Epoch 1/5, Step 40, Loss: 1.3115
Epoch 1/5, Step 50, Loss: 1.3469
Epoch 1/5, Step 60, Loss: 1.3125
Epoch 1/5, Step 70, Loss: 1.2954
Epoch 1/5, Step 80, Loss: 1.2756
Epoch 1/5, Step 90, Loss: 1.4639
Epoch 1/5, Step 100, Loss: 1.2647
Epoch 1/5, Step 110, Loss: 1.3916
Epoch 1/5, Step 120, Loss: 1.2931
Epoch 1/5, Step 130, Loss: 1.3059
Epoch 1/5, Step 140, Loss: 1.3207
Epoch 1/5, Step 150, Loss: 1.4004
Epoch 1/5, Step 160, Loss: 1.2928
Epoch 1/5, Step 170, Loss: 1.2831
Epoch 1/5, Step 180, Loss: 1.3366
Epoch 1/5, Step 190, Loss: 1.3151
Epoch 1/5, Step 200, Loss: 1.2208
Epoch 1/5, Step 210, Loss: 1.3464
Epoch 1/5, Step 220, Loss: 1.2441
Epoch 1/5, Step 230, Loss: 1.1976
Epoch 1/5, Step 240, Loss: 1.2773
Epoch 1/5, Step 250, Loss: 1.4197
Epoch 1/5, Step 260, 