## Setup

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import tensorflow as tf


# Relative paths used by this notebook; weights_folder holds "final.keras".
image_folder, weights_folder = "4.1/images", "4.1/weights"
test_folder = "148color.test"  # directory of test image pairs

### Image-pair data generator

In [None]:
class ImagePairDataGenerator(tf.keras.utils.PyDataset):
    """Keras ``PyDataset`` yielding batches of paired images ``(X, y)``.

    ``input_files[k]`` and ``output_files[k]`` form one pair: the first is
    loaded into ``X`` and the second into ``y``.  Both lists are permuted in
    lockstep when the dataset is created and after every epoch (pass
    ``shuffle=False`` to keep the given order), so pairs never get separated.
    Only full batches are produced; a trailing partial batch is dropped.

    Args:
        input_files: Paths of the images that go into ``X``.
        output_files: Paths of the images that go into ``y``; must have the
            same length as ``input_files``.
        batch_size: Number of pairs per batch (>= 1).
        img_shape: Shape of one loaded image, e.g. ``(148, 148, 3)``.
        shuffle: Whether to shuffle the pairs on construction and per epoch.
        **kwargs: Forwarded to ``PyDataset.__init__`` (``workers``,
            ``use_multiprocessing``, ``max_queue_size``).

    Raises:
        ValueError: if the file lists differ in length or ``batch_size < 1``.
    """

    def __init__(self, input_files, output_files, batch_size, img_shape, shuffle=True, **kwargs):
        # PyDataset.__init__ records the worker/queue settings; Keras warns
        # when a subclass skips it.
        super().__init__(**kwargs)
        if len(input_files) != len(output_files):
            raise ValueError(
                "input_files and output_files must have the same length, "
                f"got {len(input_files)} and {len(output_files)}"
            )
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.input_files = list(input_files)
        self.output_files = list(output_files)
        self.batch_size = batch_size
        self.img_shape = img_shape
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches (a partial last batch is dropped)."""
        return len(self.input_files) // self.batch_size

    def __getitem__(self, index):
        """Return batch ``index`` as ``(X, y)`` float32 arrays.

        Raises:
            IndexError: if ``index`` is not in ``range(len(self))``.  Without
                this check an out-of-range index gives an empty slice and the
                uninitialised ``np.empty`` buffers would be returned.
        """
        if not 0 <= index < len(self):
            raise IndexError(f"batch index {index} out of range for {len(self)} batches")
        start = index * self.batch_size
        stop = start + self.batch_size
        return self.__data_generation(self.input_files[start:stop], self.output_files[start:stop])

    def on_epoch_end(self):
        """Re-order the pairs (randomly when ``self.shuffle`` is true)."""
        self.indices = np.arange(len(self.input_files))
        if self.shuffle:
            np.random.shuffle(self.indices)
        # The same permutation is applied to both lists so pairs stay aligned.
        self.input_files = [self.input_files[i] for i in self.indices]
        self.output_files = [self.output_files[i] for i in self.indices]

    def __data_generation(self, batch_input_files, batch_output_files):
        """Load one batch of file paths into stacked float32 arrays ``(X, y)``."""
        n = len(batch_input_files)
        X = np.empty((n, *self.img_shape), dtype=np.float32)
        y = np.empty((n, *self.img_shape), dtype=np.float32)

        for i, (input_path, output_path) in enumerate(zip(batch_input_files, batch_output_files)):
            X[i] = ImagePairDataGenerator.load_and_pp_image(input_path)
            y[i] = ImagePairDataGenerator.load_and_pp_image(output_path)

        return X, y

    @staticmethod
    def load_and_pp_image(img_path: str):
        """Load an image as a ``(148, 148, 3)`` float32 array scaled to [0, 1].

        The image is converted to RGB first, so palette, RGBA and grayscale
        files are accepted.

        Raises:
            ValueError: if the decoded image is not 148x148 pixels.
        """
        # The context manager closes the underlying file handle; convert()
        # returns an in-memory copy that stays valid after the file is closed.
        with Image.open(img_path) as img:
            rgb = img.convert("RGB")
        pixels = np.asarray(rgb, dtype=np.float32) / 255.0
        # Check the shape explicitly instead of reshape(): a reshape would
        # silently scramble pixels of e.g. a 74x296 RGB image.
        if pixels.shape != (148, 148, 3):
            raise ValueError(f"{img_path}: expected a 148x148 image, got array of shape {pixels.shape}")
        return pixels


# Model

In [None]:
import tensorflow as tf

from tensorflow import keras
from keras.models import Model
from keras.layers import Conv2D, MaxPooling2D, Dense, Input, Conv2D, UpSampling2D, BatchNormalization


class AEModel(Model):
    """Convolutional autoencoder ``Model`` with a hand-written training step.

    Each training batch is expected to be an ``(images, waters)`` pair: the
    network receives ``waters`` as input and is trained to reproduce
    ``images`` under a mean-squared-error loss.  Loss and MAE are tracked by
    the model's own metric objects (see ``metrics``).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
        self.optimizer = tf.keras.optimizers.Adam(1e-4)

    # No @tf.function here: Keras already compiles train_step inside fit(),
    # and an explicit decorator would force graph mode even when the model is
    # compiled with run_eagerly=True, defeating eager debugging.
    def train_step(self, x):
        """Run one optimisation step on a batch ``x = (images, waters)``.

        Returns:
            dict with the running ``"loss"`` and ``"mae"`` values.
        """
        images, waters = x

        with tf.GradientTape() as tape:
            y_pred = self(waters, training=True)
            # keras.losses.mse averages only over the last (channel) axis and
            # returns one value per pixel.  Reduce to a scalar so the gradient
            # is that of the mean loss, not of its sum over batch * H * W.
            loss = tf.reduce_mean(keras.losses.mse(images, y_pred))

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Compute our own metrics
        self.loss_tracker.update_state(loss)
        self.mae_metric.update_state(images, y_pred)
        return {"loss": self.loss_tracker.result(), "mae": self.mae_metric.result()}

    @property
    def metrics(self):
        """Metric objects Keras should reset at the start of each epoch."""
        return [self.loss_tracker, self.mae_metric]

    @staticmethod
    def get_model():
        """Build the autoencoder graph and return its ``(inputs, outputs)`` tensors.

        NOTE(review): presumably consumed as ``AEModel(inputs=x, outputs=r)``;
        the call site is not visible here.
        """
        x = Input(shape=(148, 148, 3))

        # Encoder: 148 -> 74 -> 37 -> 19 ('same' pooling rounds 37/2 up).
        e_conv1 = Conv2D(128, (3, 3), activation='relu', padding='same')(x)
        pool1 = MaxPooling2D((2, 2), padding='same')(e_conv1)
        batchnorm_1 = BatchNormalization()(pool1)

        e_conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(batchnorm_1)
        pool2 = MaxPooling2D((2, 2), padding='same')(e_conv2)
        batchnorm_2 = BatchNormalization()(pool2)

        e_conv3 = Conv2D(32, (3, 3), activation='relu', padding='same')(batchnorm_2)
        h = MaxPooling2D((2, 2), padding='same')(e_conv3)

        # Decoder: 19 -> 38 -> 76 -> (valid 3x3 conv) 74 -> 148.
        d_conv1 = Conv2D(128, (3, 3), activation='relu', padding='same')(h)
        up1 = UpSampling2D((2, 2))(d_conv1)

        d_conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(up1)
        up2 = UpSampling2D((2, 2))(d_conv2)

        # Deliberately no padding='same': the valid convolution trims 76 -> 74
        # so the final upsampling lands exactly on 148.
        d_conv3 = Conv2D(32, (3, 3), activation='relu')(up2)
        up3 = UpSampling2D((2, 2))(d_conv3)

        # Sigmoid keeps outputs in [0, 1], matching inputs scaled by 1/255.
        r = Conv2D(3, (3, 3), activation='sigmoid', padding='same')(up3)

        return x, r

# Restore the trained network; AEModel is passed via custom_objects so Keras
# can rebuild the subclass from the saved file.
filepath = os.path.join(weights_folder, "final.keras")
model = tf.keras.models.load_model(
    filepath,
    custom_objects=dict(AEModel=AEModel),
)
model.summary()

# Testing

In [None]:
from icecream import ic


files = sorted(os.listdir(test_folder))

ic("Number of files:", len(files))

# Files are paired by sorted position: entries 0, 2, 4, ... are the clean
# images and entries 1, 3, 5, ... their counterparts.  An odd or zero count
# means the folder cannot be split into pairs; fail clearly rather than with
# a bare IndexError halfway through.
if not files or len(files) % 2:
    raise ValueError(
        f"Expected a non-zero, even number of files in {test_folder!r} "
        f"(image/water pairs), found {len(files)}"
    )

test_images = [os.path.join(test_folder, name) for name in files[0::2]]
test_waters = [os.path.join(test_folder, name) for name in files[1::2]]

for test_image, test_water in zip(test_images[:5], test_waters[:5]):
    ic(test_image, test_water)

image_shape = ImagePairDataGenerator.load_and_pp_image(test_images[0]).shape
test_generator = ImagePairDataGenerator(test_images, test_waters, 1, image_shape)

print("Size of test generator", len(test_generator))

In [None]:
# Show up to 40 pairs, but never request more batches than the generator has.
# Otherwise out-of-range indices would return garbage or raise.
images_to_try = min(40, len(test_generator))

plt.close('all')

if images_to_try:
    # squeeze=False keeps `ax` 2-D even when only one row is drawn; the
    # height scales per row (1.25 in/row reproduces the original 50 in at 40).
    fig, ax = plt.subplots(
        images_to_try, 2, figsize=(8, 1.25 * images_to_try), squeeze=False
    )
    ax[0][0].set_title("Input")
    ax[0][1].set_title("Prediction")
    for i in range(images_to_try):
        test_image, test_water = test_generator[i]

        prediction = model(test_water, training=False)
        # Drop the batch dimension (batch_size is 1) for imshow.
        test_water = np.squeeze(test_water)
        prediction = np.squeeze(np.asarray(prediction))

        ax[i][0].imshow(test_water)
        ax[i][1].imshow(prediction)
        ax[i][0].axis('off')
        ax[i][1].axis('off')
else:
    print("No test pairs to display")
