# Encoder

This notebook trains an autoencoder on all the cropped cells generated in `crop_cell_segmentation.ipynb`. The encoder half of the trained model is saved and later reused in `supervised.ipynb`.

### Imports

In [None]:
import os
import sys
import random
import itertools

import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras import layers
from keras.callbacks import EarlyStopping
from tensorflow.keras.models import Sequential, Model, clone_model
from tensorflow.keras.optimizers import Adam
from keras.optimizers import Adam

### Definitions

In [None]:
# Put the directory two levels up on the import path so that `config` resolves.
sys.path.insert(0, "../../")
from config import CROPPED_PATH, MODELS_PATH

# Paths
CROPS_PATH = os.path.join(CROPPED_PATH, 'ina', 'images')  # folder holding the cropped cell images
ENCODER_PATH = os.path.join(MODELS_PATH, 'encoder', 'encoder_SSIM_MAE_Bparams.keras')  # where the trained encoder is saved

# Configuration
SHAPE = (128, 128, 1)   # image shape every crop is resized to (height, width, channels)
BATCH_SIZE = 300        # images per batch
VALIDATION_SPLIT = 0.2  # fraction of the crops held out for validation
# Misspelled name kept as an alias because other cells refer to it.
VALIDITAION_SPLIT = VALIDATION_SPLIT
ON_RAM = False          # True: load every crop into memory first; False: stream from disk


### Load dataset

In [None]:

SPLIT_SEED = 42         # seed of the train/validation split
SHUFFLE_BUFFER = 10000  # shuffle-buffer cap used when whole images (not paths) are shuffled


def list_image_paths(directory):
    """Return the full path of every entry in `directory`, sorted by name.

    `os.listdir` yields bare file names in arbitrary order. Joining them with
    `directory` makes them readable from any working directory, and sorting
    makes the subsequent split reproducible.
    """
    return [os.path.join(directory, name) for name in sorted(os.listdir(directory))]


def split_paths(paths, val_fraction, seed=SPLIT_SEED):
    """Shuffle `paths` with a fixed seed and return `(train_paths, val_paths)`.

    The split index is counted from the front, so a validation size of zero
    yields an empty validation list instead of an empty training list.
    A private `random.Random` is used, so the global random state is untouched.
    """
    shuffled = list(paths)
    random.Random(seed).shuffle(shuffled)
    n_train = len(shuffled) - int(val_fraction * len(shuffled))
    return shuffled[:n_train], shuffled[n_train:]


def decode_image(path):
    """Read one PNG file as a float32 grayscale image in [0, 1], resized to SHAPE."""
    raw = tf.io.read_file(path)
    image = tf.image.decode_png(raw, channels=1)  # grayscale
    image = tf.image.resize(image, (SHAPE[0], SHAPE[1]))
    return image / 255.0


def read_images_to_ram(paths):
    """Load every file in `paths` with OpenCV into one uint8 array of shape (N, H, W).

    Raises:
        ValueError: if OpenCV cannot read one of the files.
    """
    stack = np.zeros((len(paths), SHAPE[0], SHAPE[1]), dtype=np.uint8)
    for idx, path in enumerate(tqdm(paths)):
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise ValueError(f"Could not read image: {path}")
        # cv2.resize takes (width, height)
        stack[idx] = cv2.resize(image, (SHAPE[1], SHAPE[0]))
    return stack


def augment_image(image):
    """Randomly flip, jitter brightness/contrast and zoom a [0, 1] image.

    The image must already be scaled to [0, 1]: `max_delta=0.1` is an absolute
    offset, so on a 0-255 image it would be negligible. The result is clipped
    back to [0, 1] because brightness/contrast jitter can leave that range.
    """
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    image = tf.image.random_brightness(image, max_delta=0.1)
    image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
    # Random zoom: crop a square of 80-100% of the side and resize it back.
    scale = tf.random.uniform([], 0.8, 1.0)
    crop_size = tf.cast(scale * SHAPE[0], tf.int32)
    image = tf.image.random_crop(image, size=[crop_size, crop_size, 1])
    image = tf.image.resize(image, (SHAPE[0], SHAPE[1]))
    return tf.clip_by_value(image, 0.0, 1.0)


def make_pair(image, augment):
    """Turn a [0, 1] image into the `(input, target)` pair of the autoencoder.

    The input is scaled back to [0, 255] because the encoder's first layer is
    `Rescaling(1/255)`; the target stays in [0, 1] to match the sigmoid output
    and the `max_val=1.0` used by the SSIM term of the loss.
    """
    if augment:
        image = augment_image(image)
    return image * 255.0, image


def load_image(path, augment=True):
    """Load the file at `path` and return its `(input, target)` pair."""
    return make_pair(decode_image(path), augment)


def load_ram_image(pixels, augment=True):
    """Convert one preloaded uint8 (H, W) image into its `(input, target)` pair."""
    image = tf.cast(pixels[..., tf.newaxis], tf.float32) / 255.0
    return make_pair(image, augment)


def make_dataset(source, example_fn, shuffle_buffer=0):
    """Build a batched, endlessly repeating dataset from `source`.

    Args:
        source: list of paths or array of images to slice element-wise.
        example_fn: maps one element to an `(input, target)` pair.
        shuffle_buffer: buffer size for shuffling the raw elements; 0 disables it.
    """
    dataset = tf.data.Dataset.from_tensor_slices(source)
    if shuffle_buffer:
        # Shuffle before decoding so the buffer holds cheap raw elements.
        dataset = dataset.shuffle(buffer_size=shuffle_buffer)
    dataset = dataset.map(example_fn, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.batch(BATCH_SIZE).repeat().prefetch(tf.data.AUTOTUNE)


# Train / validation split (shared by both loading modes).
image_paths = list_image_paths(CROPS_PATH)
train_paths, val_paths = split_paths(image_paths, VALIDITAION_SPLIT)
if not train_paths or not val_paths:
    raise ValueError(
        f"Need images for both splits, got {len(train_paths)} train / "
        f"{len(val_paths)} validation files in {CROPS_PATH}"
    )

# Only the training data is augmented; validation stays deterministic so that
# val_loss is comparable between epochs and between runs.
if ON_RAM:
    x_train = read_images_to_ram(train_paths)
    x_test = read_images_to_ram(val_paths)
    train_dataset = make_dataset(
        x_train,
        lambda pixels: load_ram_image(pixels, augment=True),
        shuffle_buffer=min(len(x_train), SHUFFLE_BUFFER),
    )
    validation_dataset = make_dataset(
        x_test,
        lambda pixels: load_ram_image(pixels, augment=False),
    )
else:
    train_dataset = make_dataset(
        train_paths,
        lambda path: load_image(path, augment=True),
        shuffle_buffer=len(train_paths),  # paths are small, so shuffle all of them
    )
    validation_dataset = make_dataset(
        val_paths,
        lambda path: load_image(path, augment=False),
    )



In [None]:
# Number of training files; divided by BATCH_SIZE further down to get the steps per epoch.
len(train_paths)

42915

### Weighted Loss: SSIM + MAE

In [None]:
class CustomLoss(tf.keras.losses.Loss):
    """Weighted sum of a mean-absolute-error term and an SSIM dissimilarity term.

    The value returned by `call` is `y * MAE + z * (1 - mean SSIM) / 2`.

    Args:
        y: weight of the MAE term.
        z: weight of the SSIM term.
        name: name given to the loss.
    """

    def __init__(self, y, z, name="custom_loss"):
        super().__init__(name=name)
        # Term weights (y -> MAE, z -> SSIM)
        self.y, self.z = y, z
        # MAE is delegated to the built-in Keras loss
        self.mae_loss_fn = tf.keras.losses.MeanAbsoluteError()

    def ssim_loss(self, y_true, y_pred):
        """Return `(1 - mean SSIM) / 2`, with SSIM computed for `max_val=1.0`."""
        mean_ssim = tf.reduce_mean(tf.image.ssim(y_true, y_pred, max_val=1.0))
        return (1 - mean_ssim) / 2

    def call(self, y_true, y_pred):
        """Combine both terms using the weights given at construction."""
        weighted_mae = self.y * self.mae_loss_fn(y_true, y_pred)
        weighted_ssim = self.z * self.ssim_loss(y_true, y_pred)
        return weighted_mae + weighted_ssim

### AutoEncoder

In [None]:
kernel_size = (4, 4)
base_filters = 32  # filters of the outermost convolution (named so the builtin `filter` is not shadowed)
n_stages = 5       # number of conv + pooling stages; each one halves the spatial size
# Filters per encoder stage: 32, 16, 8, 4, 2. The decoder uses them in reverse.
stage_filters = [base_filters // 2 ** stage for stage in range(n_stages)]


def conv_stage(x, n_filters, resample):
    """Apply Conv2D(relu, same) -> BatchNormalization -> `resample` to `x`."""
    x = layers.Conv2D(n_filters, kernel_size, activation='relu', padding='same')(x)
    x = layers.BatchNormalization()(x)
    return resample(x)


# Encoder: raw 0-255 image -> flat latent vector
encoder_input = layers.Input(shape=SHAPE)
x = layers.Rescaling(1.0 / 255.0)(encoder_input)
for n_filters in stage_filters:
    x = conv_stage(x, n_filters, layers.MaxPool2D((2, 2), padding='same'))

# Shape of the last feature map. The decoder reshapes the latent vector back
# to exactly this, instead of guessing it from the filter count.
latent_map_shape = tuple(x.shape[1:])
encoder_output = layers.Flatten()(x)

encoder = Model(encoder_input, encoder_output)

# Decoder: flat latent vector -> image in [0, 1] (sigmoid output)
# NOTE: the output only has the input's spatial size when SHAPE's height and
# width are divisible by 2 ** n_stages ('same' pooling rounds up otherwise).
decoder_input = layers.Input(shape=[encoder.output_shape[-1]])
x = layers.Reshape(latent_map_shape)(decoder_input)
for n_filters in reversed(stage_filters):
    x = conv_stage(x, n_filters, layers.UpSampling2D((2, 2)))
decoder_output = layers.Conv2D(1, kernel_size, activation='sigmoid', padding='same')(x)

decoder = Model(decoder_input, decoder_output)

# Build the autoencoder model by chaining both halves, so that the encoder can
# be saved and reused on its own.
autoencoder = Sequential([
    encoder,
    decoder
])

In [None]:
# Layer-by-layer summary of the encoder model
encoder.summary()

In [None]:
# Layer-by-layer summary of the decoder model
decoder.summary()

In [None]:
# Summary of the full autoencoder (the encoder and decoder models stacked in a Sequential)
autoencoder.summary()

### Compile & fit - Grid search

In [None]:
# Weight range from 0.2 to 1 in steps of 0.2 (linspace avoids the float-step
# end-point ambiguity of arange)
y_values = np.linspace(0.2, 1.0, 5)
z_values = np.linspace(0.2, 1.0, 5)


# Normalize weights ---> y + z = 1
def normalize_weights(weights):
    """Scale `weights` so that they sum to 1 and return them as a tuple.

    Raises:
        ValueError: if the weights sum to zero.
    """
    total_sum = sum(weights)
    if total_sum == 0:
        raise ValueError("weights must not sum to zero")
    return tuple(weight / total_sum for weight in weights)


# Combinations of weights, normalized. Rounding to 6 decimals makes pairs with
# the same ratio (e.g. 0.2/0.4 and 0.4/0.8) compare equal even when their
# floating-point quotients differ in the last bits.
normalized_pairs = [
    tuple(round(float(weight), 6) for weight in normalize_weights(pair))
    for pair in itertools.product(y_values, z_values)
]

# Delete duplicates while keeping a well-defined order, then shuffle. Shuffling
# after the de-duplication is what makes the seeded order take effect.
param_combinations = list(dict.fromkeys(normalized_pairs))
random.seed(34)
random.shuffle(param_combinations)

print("Total combinations: ", len(param_combinations))
for i, combination in enumerate(param_combinations):
    print(f"Combination {i+1}: ", tuple(map(float, combination)))

Total combinations:  19
Combination 1:  (0.25, 0.7500000000000001)
Combination 2:  (0.37500000000000006, 0.625)
Combination 3:  (0.3333333333333333, 0.6666666666666666)
Combination 4:  (0.7142857142857143, 0.28571428571428575)
Combination 5:  (0.4285714285714286, 0.5714285714285714)
Combination 6:  (0.4444444444444445, 0.5555555555555556)
Combination 7:  (0.5555555555555556, 0.4444444444444445)
Combination 8:  (0.7500000000000001, 0.25)
Combination 9:  (0.5, 0.5)
Combination 10:  (0.28571428571428575, 0.7142857142857143)
Combination 11:  (0.5714285714285714, 0.4285714285714286)
Combination 12:  (0.625, 0.37500000000000006)
Combination 13:  (0.8333333333333334, 0.16666666666666669)
Combination 14:  (0.6000000000000001, 0.4)
Combination 15:  (0.6666666666666666, 0.3333333333333333)
Combination 16:  (0.2, 0.8)
Combination 17:  (0.4, 0.6000000000000001)
Combination 18:  (0.8, 0.2)
Combination 19:  (0.16666666666666669, 0.8333333333333334)


In [None]:
def evaluate_model(autoencoder, train_gen, val_gen, params, steps_per_epoch, val_steps,
                   epochs=20, patience=5, learning_rate=1e-3):
    """Train a fresh copy of `autoencoder` with one weight pair and return its best val_loss.

    Args:
        autoencoder: template model; it is cloned, so the template itself is not trained.
        train_gen: training dataset yielding `(input, target)` batches.
        val_gen: validation dataset yielding `(input, target)` batches.
        params: `(y, z)` weights of the MAE and SSIM terms; rounded to 2 decimals.
        steps_per_epoch: training batches per epoch.
        val_steps: validation batches per evaluation.
        epochs: maximum number of epochs.
        patience: early-stopping patience on `val_loss`.
        learning_rate: Adam learning rate.

    Returns:
        The lowest `val_loss` recorded during training.
    """
    y, z = [np.round(arr, 2) for arr in params]

    model = clone_model(autoencoder)
    model.compile(
        loss=CustomLoss(y=y, z=z),
        optimizer=Adam(learning_rate=learning_rate)
    )

    early_stop = EarlyStopping(
        monitor='val_loss',
        patience=patience,
        restore_best_weights=True,
        verbose=0
    )

    history = model.fit(
        train_gen,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        validation_data=val_gen,
        validation_steps=val_steps,
        callbacks=[early_stop],
        verbose=1
    )

    return min(history.history['val_loss'])


best_params = None
best_loss = float('inf')
grid_results = []   # (params, val_loss) of every run that finished
failed_params = []  # combinations whose training raised

# At least one step, so a split smaller than one batch still trains/validates.
steps_per_epoch = max(1, len(train_paths) // BATCH_SIZE)
steps_per_epoch_val = max(1, len(val_paths) // BATCH_SIZE)

# Evaluate Autoencoder for all the combinations of weights
# NOTE(review): each run is ranked by its own weighted loss, and losses computed
# with different weights are not directly comparable -- consider ranking the
# runs by a metric that does not depend on (y, z).
for params in param_combinations:
    try:
        val_loss = evaluate_model(autoencoder, train_dataset, validation_dataset, params, steps_per_epoch, steps_per_epoch_val)
    except Exception as e:
        # Best effort: keep searching, but report which combination failed and why.
        failed_params.append(params)
        print(f"Params {params} failed with {type(e).__name__}: {e}")
        continue

    print(f"Params {params} -> Val Loss: {val_loss}")
    grid_results.append((params, val_loss))

    if val_loss < best_loss:
        best_loss = val_loss
        best_params = params



Epoch 1/20
[1m137/137[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m829s[0m 527ms/step - loss: 0.2713 - val_loss: 0.2570
Epoch 2/20
[1m137/137[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m85s[0m 515ms/step - loss: 0.2291 - val_loss: 0.2496
Epoch 3/20
[1m137/137[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m72s[0m 528ms/step - loss: 0.2255 - val_loss: 0.2476
Epoch 4/20
[1m137/137[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m72s[0m 528ms/step - loss: 0.2224 - val_loss: 0.2387
Epoch 5/20
[1m137/137[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m71s[0m 524ms/step - loss: 0.2218 - val_loss: 0.2141
Epoch 6/20
[1m137/137[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m71s[0m 521ms/step - loss: 0.2204 - val_loss: 0.2028
Epoch 7/20
[1m137/137[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m71s[0m 519ms/step - loss: 0.2196 - val_loss: 0.1983
Epoch 8/20
[1m137/137[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m70s[0m 513ms/step - loss: 0.2189 - val_loss: 0.2089
Epoch 9/20
[1m

In [None]:
# Final results: best weight combination found by the grid search and its validation loss
print(f"Best Params: {best_params} -> Best Val Loss: {best_loss}")

Best Params: (np.float64(0.8333333333333334), np.float64(0.16666666666666669)) -> Best Val Loss: 0.10360830277204514


### Compile & fit: Best Params

In [None]:
# Weights used when the grid search has not been run in this session.
DEFAULT_LOSS_WEIGHTS = (0.83, 0.17)

# Prefer the grid-search winner when it is available, so a re-run of the search
# is never silently ignored. Rounded to 2 decimals like in the search itself.
selected_weights = globals().get('best_params') or DEFAULT_LOSS_WEIGHTS
mae_weight, ssim_weight = (float(np.round(weight, 2)) for weight in selected_weights)
print(f"Training with loss weights: MAE={mae_weight}, SSIM={ssim_weight}")

loss_fn = CustomLoss(y=mae_weight, z=ssim_weight)

# At least one step, so a split smaller than one batch still trains/validates.
steps_per_epoch = max(1, len(train_paths) // BATCH_SIZE)
steps_per_epoch_val = max(1, len(val_paths) // BATCH_SIZE)

autoencoder.compile(
    loss=loss_fn,
    optimizer=Adam(learning_rate=1e-3)
)
# Stop once val_loss has not improved for 15 epochs and roll back to the best weights.
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=15,
    restore_best_weights=True,
    verbose=0
)

# Keeping the returned History also stops the cell from echoing its repr.
history = autoencoder.fit(
    train_dataset,
    steps_per_epoch=steps_per_epoch,
    epochs=100,
    validation_data=validation_dataset,
    validation_steps=steps_per_epoch_val,
    callbacks=[early_stop],
    verbose=1
)

In [None]:
# Only the encoder half is saved. Create the target folder first, because
# saving fails when the directory does not exist.
os.makedirs(os.path.dirname(ENCODER_PATH), exist_ok=True)
encoder.save(ENCODER_PATH)

### Results

In [None]:
# Training curves of the final fit (one row per epoch). `pd` is imported in the
# imports cell at the top of the notebook.
metrics = pd.DataFrame(autoencoder.history.history)
ax = metrics[['loss', 'val_loss']].plot(title="Autoencoder training")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss");

In [None]:
# Take one validation batch and keep up to three randomly chosen images from it.
validation_images = validation_dataset.take(1)
for images, _ in validation_images:
    first_three_images = tf.random.shuffle(images)[:3]

# Reconstruct all selected images in a single forward pass (inference mode).
reconstructed_batch = autoencoder(first_three_images, training=False).numpy()
processed_images = [image.squeeze() for image in reconstructed_batch]

# Plot: one row per image, original on the left and reconstruction on the right.
originals = [np.asarray(image).squeeze() for image in first_three_images]
fig, axes = plt.subplots(len(originals), 2, dpi=150, squeeze=False)

for n, (original, processed) in enumerate(zip(originals, processed_images)):
    panels = (
        (axes[n, 0], original, f"Original Image {n+1}"),
        (axes[n, 1], processed, f"Processed Image {n+1}"),
    )
    for ax, picture, title in panels:
        ax.imshow(picture, cmap='gray')
        ax.axis('off')
        ax.set_title(title)

fig.tight_layout()
plt.show()