#### Parameters

In [1]:
# PARAMETERS

# -- Masking --
mask_minsize = 8   # Min size of the mask (random masking)
mask_maxsize = 14  # Max size of the mask (random masking)
mask_size = 9      # Mask size for the fixed-size center mask (currently disabled further down)

# -- Data --
ratio_of_dataset = 0.7   # Fraction of the train/test images that is kept
num_of_testimages = 100  # Small number, just to get some images to look at the result

#### Loading the dataset and cropping it

In [None]:
from keras.datasets import mnist,fashion_mnist
import matplotlib.pyplot as plt
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import backend as K
import importlib
from tensorflow.keras.callbacks import EarlyStopping


import importlib
import expansion
importlib.reload(expansion)

# Seed NumPy's global RNG so the shuffle below selects the same test and
# validation images on every "Restart Kernel -> Run All".
# NOTE(review): expansion.random_mask presumably draws from np.random too, in
# which case the masks become reproducible as well - confirm in expansion.py.
random_seed = 42
np.random.seed(random_seed)

# Alternative dataset: (train_X, train_y), (test_X, test_y) = mnist.load_data()
(train_X, train_y), (test_X, test_y) = fashion_mnist.load_data()

# Keep only the first `ratio_of_dataset` fraction of each split
train_X = train_X[:int(ratio_of_dataset * len(train_X))]
test_X = test_X[:int(ratio_of_dataset * len(test_X))]

# Normalize pixel values to the range 0..1
train_X = train_X / 255
test_X = test_X / 255

np.random.shuffle(test_X)  # Randomize which images end up as test images later

# Add a trailing channel axis -> (N, 28, 28, 1); N depends on ratio_of_dataset.
train_X = np.expand_dims(train_X, axis=-1)
# Everything after the first `num_of_testimages` images is used for validation.
# The slice is open-ended so it follows the actual (ratio-dependent) length
# instead of a hard-coded 10000.
valid_X = np.expand_dims(test_X[num_of_testimages:], axis=-1)
# The first `num_of_testimages` images are the test images, visualized later.
test_X = np.expand_dims(test_X[:num_of_testimages], axis=-1)

print('X_train: ' + str(train_X.shape))
print('valid_X:  '  + str(valid_X.shape))
print('test_X '  + str(test_X.shape))

expansion.image_grid(test_X, name="Images before crop")

# Alternative: center mask with fixed size (disabled)
#train_X_crop = center_mask(train_X, mask_size)
#valid_X_crop = center_mask(valid_X, mask_size)
#test_X_crop = center_mask(test_X, mask_size)

# Mask with random mask position and size. Each call returns the masked images
# together with the mask that was applied.
train_X_crop, train_mask = expansion.random_mask(train_X, mask_minsize, mask_maxsize)
valid_X_crop, valid_mask = expansion.random_mask(valid_X, mask_minsize, mask_maxsize)
test_X_crop, test_mask = expansion.random_mask(test_X, mask_minsize, mask_maxsize)

expansion.image_grid(test_X_crop, name="After crop")

In [None]:
# TRAINING PARAMETERS
# Number of epochs for all the models, so they use the same number of epochs;
# early stopping ends training sooner if this is too much.
# (The original comment carried a trailing "30" - presumably an earlier value.)
epochs = 40
# Number of runs to calculate the average over.
# (The original comment carried a trailing "5" - presumably an earlier value.)
runs = 4

# Models

### Baseline autoencoder

In [None]:
def _build_baseline_autoencoder(input_shape):
  """Build the baseline convolutional autoencoder.

  The network takes only the masked image as input. Two stride-2 convolutions
  downsample by a factor of 2 each, a two-layer bottleneck follows, and two
  UpSampling2D stages upsample by a factor of 2 each before a 1x1 sigmoid
  convolution produces a single-channel output.

  Args:
    input_shape: Shape of one input image without the batch axis.

  Returns:
    An uncompiled keras Model mapping the masked image to the reconstruction.
  """
  image_input = layers.Input(shape=input_shape, name="image_input")

  # Encoder
  x = layers.Conv2D(8, (3, 3), activation='relu', padding='same', strides=(2, 2))(image_input)
  x = layers.BatchNormalization()(x)
  x = layers.Conv2D(16, (3, 3), activation='relu', padding='same', strides=(2, 2))(x)
  x = layers.BatchNormalization()(x)

  # Bottleneck
  x = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x)
  x = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x)

  # Decoder
  x = layers.UpSampling2D((2, 2))(x)
  x = layers.Conv2D(16, (3, 3), activation='relu', padding='same')(x)
  x = layers.UpSampling2D((2, 2))(x)
  x = layers.Conv2D(8, (3, 3), activation='relu', padding='same')(x)

  # Output layer: sigmoid for [0, 1] pixel values
  outputs = layers.Conv2D(1, (1, 1), activation='sigmoid')(x)

  return keras.models.Model(inputs=image_input, outputs=outputs)


batch_size = 128  # Same for every run, so it is set once outside the loop

SSIM_results = []
mse_results = []
for run in range(runs):
  # Drop the previous run's Keras state so every run starts from a fresh model
  K.clear_session()

  model = _build_baseline_autoencoder(train_X_crop.shape[1:])
  if run == 0:
    # The architecture is the same on every run; print the summary only once
    model.summary()

  model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['mse'])

  early_stopping = EarlyStopping(monitor='val_loss', patience=6, restore_best_weights=True)

  # Train the model: masked images in, original (unmasked) images as targets
  history = model.fit(
      train_X_crop,
      train_X,
      epochs=epochs,
      batch_size=batch_size,
      validation_data=(valid_X_crop, valid_X),
      callbacks=[early_stopping]
      )

  # Predict on the held-out test images
  predicted_images = model.predict(test_X_crop)

  # Calculate SSIM between the originals and the reconstructions
  ssim_score = expansion.calculate_ssim(test_X, predicted_images)
  print("SSIM Score:", ssim_score)
  SSIM_results.append(ssim_score)

  # Calculate MSE over all pixels of all test images
  mse_test_images = np.mean((test_X - predicted_images) ** 2)
  print("MSE Score:", mse_test_images)
  mse_results.append(mse_test_images)

  # Visualize input, ground truth and prediction
  expansion.image_grid(test_X_crop, name="Cropped images")
  expansion.image_grid(test_X, name="Original images")
  expansion.image_grid(predicted_images, name="Predicted images")

print(SSIM_results)
print(mse_results)

print("Average SSIM Score over ",runs," runs: ", np.mean(SSIM_results))
print("Average MSE over ",runs," runs: ", np.mean(mse_results))

### Improved autoencoder

In [None]:
def _build_masked_autoencoder(image_shape, mask_shape):
  """Build the improved autoencoder that also receives the mask as input.

  The masked image and its mask are concatenated along the channel axis. Each
  encoder stage is a convolution followed by a stride-2 convolution and batch
  normalization; the decoder mirrors it with UpSampling2D followed by two
  convolutions. There are no skip connections.

  Args:
    image_shape: Shape of one masked image without the batch axis.
    mask_shape: Shape of one mask without the batch axis.

  Returns:
    An uncompiled keras Model with inputs [image_input, mask_input].
  """
  image_input = layers.Input(shape=image_shape, name="image_input")
  mask_input = layers.Input(shape=mask_shape, name="mask_input")
  x = layers.concatenate([image_input, mask_input], axis=-1)

  # Encoder
  x = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x)
  x = layers.Conv2D(64, (3, 3), activation='relu', padding='same', strides=(2, 2))(x)
  x = layers.BatchNormalization()(x)

  x = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
  x = layers.Conv2D(128, (3, 3), activation='relu', padding='same', strides=(2, 2))(x)
  x = layers.BatchNormalization()(x)

  # Bottleneck
  x = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(x)
  x = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(x)

  # Decoder
  x = layers.UpSampling2D((2, 2))(x)
  x = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
  x = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)

  x = layers.UpSampling2D((2, 2))(x)
  x = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x)
  x = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x)

  # Output layer: sigmoid for [0, 1] pixel values
  outputs = layers.Conv2D(1, (1, 1), activation='sigmoid')(x)

  return keras.models.Model(inputs=[image_input, mask_input], outputs=outputs)


batch_size = 128  # Same for every run, so it is set once outside the loop

SSIM_results = []
mse_results = []
for run in range(runs):
  # Drop the previous run's Keras state so every run starts from a fresh model
  K.clear_session()

  model = _build_masked_autoencoder(train_X_crop.shape[1:], train_mask.shape[1:])
  if run == 0:
    # The architecture is the same on every run; print the summary only once
    model.summary()

  model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['mse'])

  early_stopping = EarlyStopping(monitor='val_loss', patience=6, restore_best_weights=True)

  # Train the model: (masked image, mask) in, original images as targets
  history = model.fit(
      [train_X_crop, train_mask],
      train_X,
      epochs=epochs,
      batch_size=batch_size,
      validation_data=([valid_X_crop, valid_mask], valid_X),
      callbacks=[early_stopping]
      )

  # Predict on the held-out test images
  predicted_images = model.predict([test_X_crop, test_mask])

  # Calculate SSIM between the originals and the reconstructions
  ssim_score = expansion.calculate_ssim(test_X, predicted_images)
  print("SSIM Score:", ssim_score)
  SSIM_results.append(ssim_score)

  # Calculate MSE over all pixels of all test images
  mse_test_images = np.mean((test_X - predicted_images) ** 2)
  print("MSE Score:", mse_test_images)
  mse_results.append(mse_test_images)

  # Visualize input, ground truth and prediction
  expansion.image_grid(test_X_crop, name="Cropped images")
  expansion.image_grid(test_X, name="Original images")
  expansion.image_grid(predicted_images, name="Predicted images")

print("Average SSIM Score over ",runs," runs: ", np.mean(SSIM_results))
print("Average MSE over ",runs," runs: ", np.mean(mse_results))

### U-Net autoencoder

In [None]:
def _build_unet_autoencoder(image_shape, mask_shape):
  """Build the U-Net style autoencoder with skip connections.

  The masked image and its mask are concatenated along the channel axis. The
  encoder downsamples twice with stride-2 convolutions. In the decoder, each
  upsampled feature map is concatenated with the encoder feature map taken
  just before the matching downsampling step (skip connection).

  Args:
    image_shape: Shape of one masked image without the batch axis.
    mask_shape: Shape of one mask without the batch axis.

  Returns:
    An uncompiled keras Model with inputs [image_input, mask_input].
  """
  image_input = layers.Input(shape=image_shape, name="image_input")
  mask_input = layers.Input(shape=mask_shape, name="mask_input")
  combined_input = layers.concatenate([image_input, mask_input], axis=-1)

  # Encoder; skip1 / skip2 are kept for the skip connections in the decoder
  skip1 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(combined_input)
  down1 = layers.Conv2D(64, (3, 3), activation='relu', padding='same', strides=(2, 2))(skip1)
  down1 = layers.BatchNormalization()(down1)

  skip2 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(down1)
  down2 = layers.Conv2D(128, (3, 3), activation='relu', padding='same', strides=(2, 2))(skip2)
  down2 = layers.BatchNormalization()(down2)

  # Bottleneck
  bottleneck = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(down2)
  bottleneck = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(bottleneck)

  # Decoder
  up1 = layers.UpSampling2D((2, 2))(bottleneck)
  merged1 = layers.concatenate([up1, skip2], axis=-1)  # Skip connection
  dec1 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(merged1)
  dec1 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(dec1)

  up2 = layers.UpSampling2D((2, 2))(dec1)
  merged2 = layers.concatenate([up2, skip1], axis=-1)  # Skip connection
  dec2 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(merged2)
  dec2 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(dec2)

  # Output layer: sigmoid for [0, 1] pixel values
  outputs = layers.Conv2D(1, (1, 1), activation='sigmoid')(dec2)

  return keras.models.Model(inputs=[image_input, mask_input], outputs=outputs)


batch_size = 128  # Same for every run, so it is set once outside the loop

SSIM_results = []
mse_results = []
for run in range(runs):
  # Drop the previous run's Keras state so every run starts from a fresh model
  K.clear_session()

  model = _build_unet_autoencoder(train_X_crop.shape[1:], train_mask.shape[1:])
  if run == 0:
    # The architecture is the same on every run; print the summary only once
    model.summary()

  model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['mse'])

  early_stopping = EarlyStopping(monitor='val_loss', patience=6, restore_best_weights=True)

  # Train the model: (masked image, mask) in, original images as targets
  history = model.fit(
      [train_X_crop, train_mask],
      train_X,
      epochs=epochs,
      batch_size=batch_size,
      validation_data=([valid_X_crop, valid_mask], valid_X),
      callbacks=[early_stopping]
      )

  # Predict on the held-out test images
  predicted_images = model.predict([test_X_crop, test_mask])

  # Calculate SSIM between the originals and the reconstructions
  ssim_score = expansion.calculate_ssim(test_X, predicted_images)
  print("SSIM Score:", ssim_score)
  SSIM_results.append(ssim_score)

  # Calculate MSE over all pixels of all test images
  mse_test_images = np.mean((test_X - predicted_images) ** 2)
  print("MSE Score:", mse_test_images)
  mse_results.append(mse_test_images)

  # Visualize input, ground truth and prediction
  expansion.image_grid(test_X_crop, name="Cropped images")
  expansion.image_grid(test_X, name="Original images")
  expansion.image_grid(predicted_images, name="Predicted images")

print("Average SSIM Score over ",runs," runs: ", np.mean(SSIM_results))
print("Average MSE over ",runs," runs: ", np.mean(mse_results))