# In [None]:  (Jupyter cell marker left over from the notebook export)
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import cv2
import random
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, GlobalAveragePooling2D, Reshape
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam


# Report the TensorFlow build in use so runs can be reproduced later.
print(f"TensorFlow Version: {tf.__version__}")


# DATA LOADING
def load_unlabeled_images(root_folder, image_size=(224, 224)):
   """
   Loads every image found in the 'Train' and 'Test' subdirectories of
   ``root_folder`` for self-supervised learning.

   Files are visited in sorted order so the returned array is reproducible
   across runs (``os.listdir`` order is arbitrary). Files that cannot be
   decoded or processed are skipped with a ``[Skip]`` message instead of
   aborting the whole load.

   Args:
       root_folder (str): Path to the folder containing 'Train' and 'Test' subdirectories.
       image_size (tuple): Target size handed to ``cv2.resize`` (OpenCV
           interprets it as (width, height)).

   Returns:
       np.array: float32 array of the loaded RGB images, scaled to [0, 1],
       stacked along a new first axis.

   Raises:
       ValueError: If no image could be loaded from either subdirectory.
   """
   valid_extensions = (".jpg", ".jpeg", ".png")
   images_list = []
   for split in ["Train", "Test"]:
       split_dir = os.path.join(root_folder, split)
       # isdir (rather than exists) also rejects a regular file with this name.
       if not os.path.isdir(split_dir):
           print(f"Warning: Directory {split_dir} not found. Skipping...")
           continue
       for fn in sorted(os.listdir(split_dir)):
           if not fn.lower().endswith(valid_extensions):
               continue
           img_path = os.path.join(split_dir, fn)
           try:
               image = cv2.imread(img_path)
               # cv2.imread signals failure by returning None, not by raising.
               if image is None:
                   print(f"[Skip] {img_path}: could not be read as an image")
                   continue
               image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
               image = cv2.resize(image, image_size)
               image = image.astype(np.float32) / 255.0  # Normalize to [0,1]
               images_list.append(image)
           except Exception as e:
               # Deliberate best-effort: one bad file must not abort the load.
               print(f"[Skip] {img_path}: {e}")
   if not images_list:
       raise ValueError("No images were loaded. Please check the dataset structure.")
   return np.stack(images_list, axis=0)


def random_augment(image, crop_size=180, output_size=(224, 224)):
   """
   Applies a series of random augmentations for contrastive learning.

   The pipeline is: horizontal flip, brightness, contrast and saturation
   jitter, a rotation by a random multiple of 90 degrees, a random square
   crop and a resize back to ``output_size``.

   Args:
       image (tf.Tensor): Input image tensor with 3 channels; its height and
           width must be at least ``crop_size``.
       crop_size (int): Side length of the random square crop.
       output_size (tuple): (height, width) the crop is resized to.

   Returns:
       tf.Tensor: Augmented image of size ``output_size``.
   """
   image = tf.image.random_flip_left_right(image)
   image = tf.image.random_brightness(image, max_delta=0.4)
   image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
   image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
   # Draw the rotation count with a TF op. A Python-side random.randint would
   # be evaluated only once when tf.data traces this function, freezing the
   # same rotation for every sample in the dataset.
   num_rotations = tf.random.uniform([], minval=0, maxval=4, dtype=tf.int32)
   image = tf.image.rot90(image, k=num_rotations)
   image = tf.image.random_crop(image, size=[crop_size, crop_size, 3])
   image = tf.image.resize(image, output_size)
   return image


# ------------------------------
# ENCODER EXTRACTION
# ------------------------------
def extract_encoder(mae_model, layer_name="dense"):
   """
   Extracts the encoder part from a trained MAE teacher model.

   Builds a new model that shares the teacher's input and ends at the output
   of the layer called ``layer_name``. For the MAEv5 architecture the default,
   "dense", is the layer selected as the latent feature representation.

   Args:
       mae_model (tf.keras.Model): Pre-trained MAE teacher model.
       layer_name (str): Name of the layer whose output becomes the encoder
           output. Defaults to "dense".

   Returns:
       tf.keras.Model: Encoder model sharing layers (and weights) with
       ``mae_model``.

   Raises:
       ValueError: Propagated from ``get_layer`` if no layer is named
           ``layer_name``.
   """
   encoder_output = mae_model.get_layer(layer_name).output
   encoder = Model(inputs=mae_model.input, outputs=encoder_output)
   return encoder


# CONTRASTIVE LOSS FUNCTION (NT-Xent)


def nt_xent_loss(z_i, z_j, temperature=0.5):
   """
   Computes the NT-Xent (Normalized Temperature-scaled Cross-Entropy) loss.

   Both views are stacked into one set of 2N embeddings. For every embedding
   the positive is the other view of the same image and the remaining
   2N - 2 embeddings (from both views) are negatives; an embedding's
   similarity with itself is masked out. The loss is symmetric in ``z_i``
   and ``z_j``.

   Args:
       z_i (tf.Tensor): Feature vector from augmented view 1, batch on axis 0.
       z_j (tf.Tensor): Feature vector from augmented view 2, same batch size
           and feature size as ``z_i``.
       temperature (float): Temperature scaling parameter.

   Returns:
       tf.Tensor: NT-Xent loss per input pair, shape [batch] (the mean of the
       i->j and j->i terms of each pair).
   """
   batch_size = tf.shape(z_i)[0]

   # Reshape to [batch, feature_dim]
   z_i = tf.reshape(z_i, [batch_size, -1])
   z_j = tf.reshape(z_j, [batch_size, -1])

   # Stack both views and L2-normalize so dot products are cosine similarities
   z = tf.math.l2_normalize(tf.concat([z_i, z_j], axis=0), axis=1)

   # [2N, 2N] cosine similarity matrix scaled by temperature
   logits = tf.matmul(z, z, transpose_b=True) / temperature

   # An embedding must never count as its own candidate: push the diagonal to
   # the lowest finite value so it contributes ~0 to the softmax.
   mask_value = tf.constant(logits.dtype.min, dtype=logits.dtype)
   logits = tf.linalg.set_diag(logits, tf.fill(tf.shape(logits)[:1], mask_value))

   # Row k (view 1) matches row k + N (view 2) and vice versa
   idx = tf.range(batch_size)
   labels = tf.concat([idx + batch_size, idx], axis=0)

   per_view = tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)

   # Fold the two directions of each pair together to keep a [batch] result
   loss = tf.reduce_mean(tf.reshape(per_view, [2, -1]), axis=0)
   return loss


# SIMCLR FINE-TUNING (CONTRASTIVE LEARNING)


def fine_tune_with_simclr(mae_model_path, root_folder, batch_size=32, epochs=5,
                          learning_rate=0.001, save_path="ssl_teacher_model.keras"):
   """
   Fine-tunes the MAE encoder using SimCLR contrastive learning.

   Loads the teacher model, extracts its encoder, builds a dataset of two
   independently augmented views per image and minimizes the mean
   contrastive loss with Adam in a custom training loop. The fine-tuned
   encoder is written to ``save_path``.

   Args:
       mae_model_path (str): Path to the pre-trained MAE teacher model file.
       root_folder (str): Path to the dataset containing unlabeled images.
       batch_size (int): Batch size for training.
       epochs (int): Number of training epochs.
       learning_rate (float): Adam learning rate.
       save_path (str): Where the fine-tuned encoder is saved.

   Returns:
       tf.keras.Model: The fine-tuned encoder model.
   """
   print("Loading pre-trained MAE model...")
   # compile=False: only the weights/graph are needed; the teacher's optimizer
   # and loss are never used here, so they need not be restored.
   mae_model = tf.keras.models.load_model(mae_model_path, compile=False)
   encoder = extract_encoder(mae_model)
   print("Encoder successfully extracted!")

   print("Loading unlabeled images for SimCLR training...")
   images = load_unlabeled_images(root_folder)

   # Create a dataset of augmented image pairs.
   # Shuffle individual images BEFORE batching so every epoch sees different
   # batch compositions (and therefore different negatives); shuffling after
   # batch() would only reorder fixed batches. The buffer is capped at the
   # footprint the previous 500-batch buffer could reach.
   shuffle_buffer = min(len(images), 500 * batch_size)
   dataset = tf.data.Dataset.from_tensor_slices(images)
   dataset = dataset.shuffle(shuffle_buffer)
   dataset = dataset.map(
       lambda x: (random_augment(x), random_augment(x)),
       num_parallel_calls=tf.data.AUTOTUNE,
   )
   dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

   optimizer = Adam(learning_rate=learning_rate)

   print("Starting SimCLR fine-tuning...")
   for epoch in range(epochs):
       epoch_loss = 0.0
       steps = 0
       for step, (x_i, x_j) in enumerate(dataset):
           with tf.GradientTape() as tape:
               z_i = encoder(x_i, training=True)
               z_j = encoder(x_j, training=True)
               # Reduce to a scalar inside the tape: differentiating the raw
               # per-example vector would sum over the batch and make the
               # gradient scale depend on the batch size.
               loss = tf.reduce_mean(nt_xent_loss(z_i, z_j))
           gradients = tape.gradient(loss, encoder.trainable_variables)
           optimizer.apply_gradients(zip(gradients, encoder.trainable_variables))
           loss_value = loss.numpy()
           epoch_loss += loss_value
           steps += 1
           if step % 10 == 0:
               print(f"Epoch {epoch+1}/{epochs}, Step {step}, Loss: {loss_value:.4f}")
       avg_loss = epoch_loss / steps
       print(f"Epoch {epoch+1}/{epochs} completed. Avg SimCLR Loss: {avg_loss:.4f}")

   # Save the fine-tuned encoder using the native Keras format (avoid HDF5 issues)
   encoder.save(save_path)
   print(f"Fine-tuning complete! Encoder saved as '{save_path}'.")

   return encoder


# RUN FINE-TUNING
if __name__ == "__main__":
   # Script entry point: fine-tune the improved MAE teacher on the unlabeled
   # image folder and keep the resulting encoder in `ssl_teacher`.
   mae_model_path = "mae_teacher_model_improved.h5"
   root_folder = "UnlabeledDataset"
   ssl_teacher = fine_tune_with_simclr(
       mae_model_path=mae_model_path,
       root_folder=root_folder,
   )




