# In [None]:
import os
import numpy as np
import tensorflow as tf
import cv2
import matplotlib.pyplot as plt
import tensorflow.image as tf_image
from tensorflow.keras.layers import Dense, Flatten, Reshape, Conv2D, UpSampling2D, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.applications import EfficientNetV2B0


# Report the installed TensorFlow version at startup so run logs record it.
print("TensorFlow Version:", tf.__version__)


# DATA LOADING AND PREPROCESSING


def load_unlabeled_images(root_folder, image_size=(224, 224)):
   """Load all images found under ``root_folder/Train`` and ``root_folder/Test``.

   Every file whose name ends in .jpg, .jpeg or .png (case-insensitive) is read
   with OpenCV, converted from BGR to RGB, resized to ``image_size`` and scaled
   to float32 values in [0, 1]. Files that cannot be read or processed are
   reported on stdout and skipped.

   Args:
       root_folder: Directory expected to contain a "Train" and/or "Test"
           sub-directory. A missing sub-directory only produces a warning.
       image_size: Target size passed to ``cv2.resize`` as ``dsize``; OpenCV
           interprets it as (width, height).

   Returns:
       One float32 array stacking all loaded images along a new first axis.

   Raises:
       ValueError: If not a single image could be loaded.
   """
   valid_extensions = (".jpg", ".jpeg", ".png")
   images_list = []
   for split in ["Train", "Test"]:
       split_dir = os.path.join(root_folder, split)
       # isdir (not exists) so that a stray *file* named Train/Test is skipped
       # with the same warning instead of crashing inside os.listdir.
       if not os.path.isdir(split_dir):
           print(f"Warning: Directory {split_dir} not found. Skipping...")
           continue
       # os.listdir order is arbitrary; sort so the sample order is reproducible.
       for fn in sorted(os.listdir(split_dir)):
           if not fn.lower().endswith(valid_extensions):
               continue
           img_path = os.path.join(split_dir, fn)
           try:
               image = cv2.imread(img_path)
               # cv2.imread signals failure by returning None, not by raising.
               if image is None:
                   print(f"[Skip] {img_path}: file could not be read or decoded")
                   continue
               image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
               image = cv2.resize(image, image_size)
               image = image.astype(np.float32) / 255.0  # Normalize
           except Exception as e:
               # Deliberate best-effort: one bad file must not abort the load.
               print(f"[Skip] {img_path}: {e}")
           else:
               images_list.append(image)
   if not images_list:
       raise ValueError("No images were loaded. Check dataset structure.")
   return np.stack(images_list, axis=0)


# MASKING FUNCTION


def random_mask(image, mask_ratio=0.25):
   """Zero out a random subset of the entries of ``image``.

   One uniform sample in [0, 1) is drawn for every element of ``image``
   (the noise tensor has exactly the shape of ``image``). Elements whose
   sample is greater than ``mask_ratio`` are kept, all others are set to 0.
   """
   noise = tf.random.uniform(shape=tf.shape(image), minval=0, maxval=1)
   keep = tf.cast(noise > mask_ratio, tf.float32)
   return image * keep


# EFFICIENTNETV2B0-BASED MAE WITH UNET-STYLE DECODER


def build_mae_teacher_efficientnet_unet(input_shape=(224, 224, 3), latent_dim=512):
   """
   Builds a Masked Autoencoder (MAE) using EfficientNetV2B0 as the encoder and a
   UNet-style decoder with skip connections.

   The returned model takes images scaled to [0, 1] and ends in a sigmoid, so
   its reconstruction is in [0, 1] as well.

   Args:
       input_shape: (height, width, channels) of the input images. Height and
           width must be multiples of 32: the encoder output is treated as
           1/32 of the input size and is upsampled five times by 2 in the
           decoder.
       latent_dim: Number of units in the dense bottleneck layer.

   Returns:
       An un-compiled Keras ``Model`` mapping a (masked) image batch to its
       reconstruction.

   Raises:
       ValueError: If the height or width of ``input_shape`` is missing or not
           a multiple of 32.
   """
   height, width = input_shape[0], input_shape[1]
   if height is None or width is None or height % 32 or width % 32:
       raise ValueError(
           "input_shape height and width must be multiples of 32, "
           f"got {tuple(input_shape)}"
       )
   # Spatial size of the encoder output (7x7 for the default 224x224 input).
   bottleneck_h, bottleneck_w = height // 32, width // 32

   # The rest of this file feeds images scaled to [0, 1], but EfficientNetV2's
   # built-in preprocessing (enabled by default) expects pixel values in
   # [0, 255]. Rescale inside the graph so the pretrained encoder sees the
   # value range it was trained on while the model's input contract stays [0, 1].
   inputs = tf.keras.Input(shape=input_shape)
   encoder_input = tf.keras.layers.Rescaling(255.0, name="rescale_to_0_255")(inputs)

   # Load EfficientNetV2B0 without the classification head
   base_model = EfficientNetV2B0(
       include_top=False,
       input_shape=input_shape,
       input_tensor=encoder_input,
       weights="imagenet",
   )
   base_model.trainable = True

   # Skip connections (adjust layer names as needed). Resolutions below are for
   # the default 224x224 input and are implied by where each tensor is
   # concatenated in the decoder.
   skip1 = base_model.get_layer("block2a_expand_activation").output  # 56x56
   skip2 = base_model.get_layer("block3a_expand_activation").output  # 28x28
   # NOTE(review): concatenated at the 28x28 decoder stage, so this tensor is
   # 28x28 (presumably taken before block 4 downsamples) - confirm in summary().
   skip3 = base_model.get_layer("block4a_expand_activation").output

   # Upsample skip connections to match target decoder sizes
   skip2_up = UpSampling2D((2, 2), name="skip2_upsampled")(skip2)   # 28x28 -> 56x56
   skip1_up = UpSampling2D((2, 2), name="skip1_upsampled")(skip1)   # 56x56 -> 112x112

   # Encoder output (bottleneck feature map)
   encoder_output = base_model.output

   # Bottleneck: flatten and reduce dimensionality
   x = Flatten()(encoder_output)
   x = Dense(latent_dim, activation="relu")(x)
   # Reproject to a feature map with 256 channels at the encoder's output size
   x = Dense(bottleneck_h * bottleneck_w * 256, activation="relu")(x)
   x = Reshape((bottleneck_h, bottleneck_w, 256))(x)

   # Decoder (sizes in comments are for a 224x224 input)
   # Upsample from 7x7 to 14x14
   x = UpSampling2D((2, 2), name="upsample1")(x)
   x = Conv2D(256, (3, 3), activation="relu", padding="same", name="conv_dec1")(x)

   # Upsample to 28x28 and concatenate with skip3
   x = UpSampling2D((2, 2), name="upsample2")(x)
   x = Concatenate(name="concat_skip3")([x, skip3])
   x = Conv2D(128, (3, 3), activation="relu", padding="same", name="conv_dec2")(x)

   # Upsample to 56x56 and concatenate with skip2_up
   x = UpSampling2D((2, 2), name="upsample3")(x)
   x = Concatenate(name="concat_skip2")([x, skip2_up])
   x = Conv2D(64, (3, 3), activation="relu", padding="same", name="conv_dec3")(x)

   # Upsample to 112x112 and concatenate with skip1_up
   x = UpSampling2D((2, 2), name="upsample4")(x)
   x = Concatenate(name="concat_skip1")([x, skip1_up])
   x = Conv2D(32, (3, 3), activation="relu", padding="same", name="conv_dec4")(x)

   # Upsample to the full input resolution (224x224)
   x = UpSampling2D((2, 2), name="upsample5")(x)
   x = Conv2D(16, (3, 3), activation="relu", padding="same", name="conv_dec5")(x)
   outputs = Conv2D(3, (3, 3), activation="sigmoid", padding="same", name="decoder_output")(x)

   return Model(inputs=inputs, outputs=outputs)


# PERFORMANCE METRICS


def compute_ssim(original, reconstructed):
   """Return the mean SSIM between the two inputs, using a max value of 1.0."""
   ssim_values = tf_image.ssim(original, reconstructed, max_val=1.0)
   return tf.reduce_mean(ssim_values)


def compute_psnr(original, reconstructed):
   """Return the PSNR in dB, computed as 10 * log10(1 / MSE) of the inputs."""
   mse_fn = tf.keras.losses.MeanSquaredError()
   mse = mse_fn(original, reconstructed)
   # log10(x) is written as ln(x) / ln(10).
   log10_inverse_mse = tf.math.log(1.0 / mse) / tf.math.log(10.0)
   return 10.0 * log10_inverse_mse


# HYBRID LOSS FUNCTION (MSE + SSIM)


def hybrid_loss(original, reconstructed):
   """Return 0.7 * MSE + 0.3 * (1 - mean SSIM) between the two inputs."""
   mse_term = tf.keras.losses.MeanSquaredError()(original, reconstructed)
   mean_ssim = tf.reduce_mean(tf_image.ssim(original, reconstructed, max_val=1.0))
   ssim_term = 1 - mean_ssim
   return 0.7 * mse_term + 0.3 * ssim_term


# MAE TRAINING FUNCTION


def train_mae(root_folder, batch_size=16, epochs=15):
   """Train the EfficientNetV2B0/UNet masked autoencoder on unlabeled images.

   Images are loaded with ``load_unlabeled_images``, randomly masked with
   ``random_mask`` and the model is trained to reconstruct the unmasked image
   under ``hybrid_loss`` using Adam (learning rate 1e-4). After training the
   model is saved to "mae_teacher_model_improved.h5" in the working directory.

   Args:
       root_folder: Dataset root passed to ``load_unlabeled_images``.
       batch_size: Number of images per training step.
       epochs: Number of passes over the dataset.

   Returns:
       A tuple ``(mae_model, history, images)`` where ``history`` maps
       "loss", "ssim" and "psnr" to lists holding one per-epoch average
       (mean over training steps), and ``images`` is the loaded image array.
   """
   print("Loading unlabeled images...")
   images = load_unlabeled_images(root_folder)

   # Shuffle individual images *before* batching: shuffling after batch() only
   # reorders whole batches (same images always grouped together) and would
   # buffer up to 1000 batches. The mask is applied after the shuffle so the
   # shuffle buffer holds each image only once.
   shuffle_buffer = min(len(images), 1000)
   dataset = tf.data.Dataset.from_tensor_slices(images)
   dataset = dataset.shuffle(shuffle_buffer)
   dataset = dataset.map(lambda x: (random_mask(x), x))
   dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

   print("Building MAE model (UNet decoder + EfficientNetV2B0 encoder)...")
   mae_model = build_mae_teacher_efficientnet_unet(input_shape=(224, 224, 3), latent_dim=512)
   mae_model.summary()

   optimizer = Adam(learning_rate=1e-4)

   history = {"loss": [], "ssim": [], "psnr": []}

   print("Starting MAE training...")
   for epoch in range(epochs):
       epoch_loss, epoch_ssim, epoch_psnr = 0, 0, 0
       step_count = 0
       for masked_x, original_x in dataset:
           with tf.GradientTape() as tape:
               reconstructed_x = mae_model(masked_x, training=True)
               loss_val = hybrid_loss(original_x, reconstructed_x)
           gradients = tape.gradient(loss_val, mae_model.trainable_variables)
           optimizer.apply_gradients(zip(gradients, mae_model.trainable_variables))

           # Metrics use the reconstruction from the training forward pass,
           # via the shared helpers so the formulas live in one place.
           ssim_val = compute_ssim(original_x, reconstructed_x)
           psnr_val = compute_psnr(original_x, reconstructed_x)

           epoch_loss += loss_val.numpy()
           epoch_ssim += ssim_val.numpy()
           epoch_psnr += psnr_val.numpy()
           step_count += 1

       # load_unlabeled_images raises on an empty dataset, so step_count >= 1.
       avg_loss = epoch_loss / step_count
       avg_ssim = epoch_ssim / step_count
       avg_psnr = epoch_psnr / step_count

       history["loss"].append(avg_loss)
       history["ssim"].append(avg_ssim)
       history["psnr"].append(avg_psnr)

       print(f"Epoch {epoch+1}/{epochs}, Hybrid Loss: {avg_loss:.4f}, SSIM: {avg_ssim:.4f}, PSNR: {avg_psnr:.4f}")

   mae_model.save("mae_teacher_model_improved.h5")
   print("Training complete")
   return mae_model, history, images


# PLOTTING FUNCTIONS


def plot_metrics(history):
   """Show the per-epoch loss, SSIM and PSNR curves side by side.

   ``history`` must provide the keys "loss", "ssim" and "psnr", each holding
   one value per epoch. The x-axis starts at epoch 1.
   """
   epochs = range(1, len(history["loss"]) + 1)
   plt.figure(figsize=(12, 5))

   # (history key, y-axis label, panel title, extra line arguments)
   panels = [
       ("loss", "Loss", "Loss Curve (MSE+SSIM)", {"label": "Hybrid Loss"}),
       ("ssim", "SSIM", "SSIM Trend", {"label": "SSIM Score", "color": "green"}),
       ("psnr", "PSNR", "PSNR Trend", {"label": "PSNR (dB)", "color": "red"}),
   ]
   for position, (key, y_label, title, line_kwargs) in enumerate(panels, start=1):
       plt.subplot(1, 3, position)
       plt.plot(epochs, history[key], marker='o', **line_kwargs)
       plt.xlabel("Epochs")
       plt.ylabel(y_label)
       plt.title(title)
       plt.legend()
       plt.grid()

   plt.show()
   
   
def plot_reconstructions(model, images, num_samples=5):
   """Plot original, masked and reconstructed versions of the first images.

   Args:
       model: Object whose ``predict`` maps a batch of masked images to
           reconstructions.
       images: Sequence/array of images; only the first ``num_samples`` are used.
       num_samples: Maximum number of rows to draw. Fewer rows are drawn when
           ``images`` holds fewer items.

   Raises:
       ValueError: If there is no image to plot.
   """
   sample_images = np.asarray(images[:num_samples])
   # Use the number of images actually available, not the number requested,
   # so a short dataset does not cause an index error below.
   row_count = len(sample_images)
   if row_count == 0:
       raise ValueError("No images available to plot.")

   masked_images = np.array([random_mask(img) for img in sample_images])
   reconstructions = model.predict(masked_images)

   # squeeze=False keeps `axes` two-dimensional even for a single row, so the
   # axes[row, col] indexing works for num_samples == 1 as well.
   _, axes = plt.subplots(row_count, 3, figsize=(10, row_count * 3), squeeze=False)
   columns = (
       ("Original", sample_images),
       ("Masked", masked_images),
       ("Reconstructed", reconstructions),
   )
   for row in range(row_count):
       for col, (title, batch) in enumerate(columns):
           axes[row, col].imshow(batch[row])
           axes[row, col].set_title(title)

   plt.tight_layout()
   plt.show()


# RUN TRAINING & EVALUATION


if __name__ == "__main__":
   # Train on the images under ./UnlabeledDataset, then show the per-epoch
   # metric curves and the reconstructions of the first five images.
   mae_model, training_history, dataset_images = train_mae("UnlabeledDataset")
   plot_metrics(training_history)
   plot_reconstructions(mae_model, dataset_images, num_samples=5)