## Setup

In [None]:
from IPython import display

import os
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import PIL
from PIL import Image
from pathlib import Path
import tensorflow as tf

import time
from sklearn.model_selection import train_test_split

# Output locations: per-epoch preview PNGs and saved model files.
image_folder = "4.1/images"
weights_folder = "4.1/weights"

# Directory that is listed to build the file pairs, and the path of the GIF
# assembled from the preview PNGs.
train_folder = "148color.train"
anim_file = "4.1.training.gif"

# Number of training epochs, and the interval (in epochs) between model saves.
epochs = 500
save_every_n_epoch = 150

In [None]:
# Create the output directories up front so the later savefig()/model.save()
# calls do not fail on a missing folder. Both calls pass parents=True so that
# neither depends on the other having created a shared parent directory first.
Path(image_folder).mkdir(exist_ok=True, parents=True)
Path(weights_folder).mkdir(exist_ok=True, parents=True)

## Load the dataset

In [None]:
# Sort the directory listing so the two files of each pair end up adjacent.
# NOTE(review): this assumes each image sorts directly before its watermarked
# counterpart - confirm against the dataset's naming scheme.
files = sorted(os.listdir(train_folder))

print("Number of files:", len(files))

# The listing is consumed two entries at a time, so a stray file would
# otherwise only surface as an opaque IndexError on the last pair.
if len(files) % 2 != 0:
    raise ValueError(
        f"Expected an even number of files in {train_folder!r} "
        f"(they are consumed as adjacent pairs), found {len(files)}"
    )

file_pairs = [
    (os.path.join(train_folder, files[i]), os.path.join(train_folder, files[i + 1]))
    for i in range(0, len(files), 2)
]


# ######
# Split the pairs into training and validation sets
# ######

# 20% of the pairs are held out; the fixed random_state keeps the split
# reproducible between runs.
train_pairs, validation_pairs = train_test_split(file_pairs, test_size=0.2, random_state=42)

print("Train pairs:", len(train_pairs))
print("Validation pairs:", len(validation_pairs))

# ######
# x - images (first path of each pair)
# y - watermarked images (second path of each pair)

train_images = [image_path for image_path, _ in train_pairs]
train_waters = [watermarked_image_path for _, watermarked_image_path in train_pairs]
val_images = [image_path for image_path, _ in validation_pairs]
val_waters = [watermarked_image_path for _, watermarked_image_path in validation_pairs]

# ######
# generator

class ImagePairDataGenerator(tf.keras.utils.PyDataset):
    """Keras dataset serving batches of image pairs read from two parallel file lists.

    Each batch is a tuple ``(X, y)`` of float32 arrays: ``X`` is built from
    ``input_files`` and ``y`` from the ``output_files`` entries at the same
    positions. Only full batches are served; a trailing remainder smaller than
    ``batch_size`` is dropped.

    Args:
        input_files: Paths of the images that fill ``X``.
        output_files: Paths of the images that fill ``y``; must have the same
            length as ``input_files``.
        batch_size: Number of pairs per batch.
        img_shape: Shape of one preprocessed image, used to size the batch
            buffers.
        **kwargs: Forwarded unchanged to ``PyDataset.__init__``.

    Raises:
        ValueError: If the two file lists differ in length.
    """

    def __init__(self, input_files, output_files, batch_size, img_shape, **kwargs):
        # The base class expects its constructor to be called by subclasses.
        super().__init__(**kwargs)
        if len(input_files) != len(output_files):
            # zip() would otherwise silently drop the unmatched tail.
            raise ValueError(
                f"input_files and output_files must have the same length, "
                f"got {len(input_files)} and {len(output_files)}"
            )
        self.input_files = input_files
        self.output_files = output_files
        self.batch_size = batch_size
        self.img_shape = img_shape
        # Start from a shuffled order.
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches available."""
        return len(self.input_files) // self.batch_size

    def __getitem__(self, index):
        """Load and return batch ``index`` as an ``(X, y)`` tuple.

        Raises:
            IndexError: If ``index`` is outside ``[0, len(self))``. Without
                this check an out-of-range index would slice zero files and
                hand back uninitialised buffers, and plain ``for`` iteration
                over the dataset would never terminate.
        """
        if not 0 <= index < len(self):
            raise IndexError(
                f"batch index {index} out of range for {len(self)} batches"
            )
        start = index * self.batch_size
        stop = start + self.batch_size
        batch_input_files = self.input_files[start:stop]
        batch_output_files = self.output_files[start:stop]
        X, y = self.__data_generation(batch_input_files, batch_output_files)
        return X, y

    def on_epoch_end(self):
        """Reshuffle both file lists with one shared permutation so pairs stay aligned."""
        self.indices = np.arange(len(self.input_files))
        np.random.shuffle(self.indices)
        self.input_files = [self.input_files[i] for i in self.indices]
        self.output_files = [self.output_files[i] for i in self.indices]

    def __data_generation(self, batch_input_files, batch_output_files):
        """Read the given files into two float32 arrays of shape ``(n, *img_shape)``."""
        # Size the buffers by the number of files actually supplied so that no
        # row is ever left uninitialised.
        n_items = len(batch_input_files)
        X = np.empty((n_items, *self.img_shape), dtype=np.float32)
        y = np.empty((n_items, *self.img_shape), dtype=np.float32)

        for i, (input_path, output_path) in enumerate(zip(batch_input_files, batch_output_files)):
            X[i,] = ImagePairDataGenerator.load_and_pp_image(input_path)
            y[i,] = ImagePairDataGenerator.load_and_pp_image(output_path)

        return X, y

    @staticmethod
    def load_and_pp_image(img_path: str):
        """Load one image file as a float32 array of shape (148, 148, 3) scaled to [0, 1].

        Args:
            img_path: Path of the image file to read.

        Returns:
            The pixel data divided by 255 as a ``float32`` array.

        Raises:
            ValueError: If the pixel data cannot be reshaped to (148, 148, 3).
        """
        # The context manager closes the underlying file handle once the
        # pixel data has been copied into the array.
        with Image.open(img_path) as img:
            # Normalise palette / grayscale / RGBA files to three channels;
            # files that are already RGB are unaffected.
            pixels = np.array(img.convert("RGB")) / 255.0
        # For a 148x148 RGB image this reshape is a no-op; it raises for any
        # image whose element count does not match.
        return pixels.reshape(148, 148, 3).astype('float32')


# Image generation

In [None]:
batch_size = 4
# Load one file once to obtain the per-image array shape that sizes the batch buffers.
image_shape = ImagePairDataGenerator.load_and_pp_image(file_pairs[0][0]).shape

# One generator per split; each serves (images, watermarked images) batches.
train_generator = ImagePairDataGenerator(train_images, train_waters, batch_size, image_shape)
val_generator = ImagePairDataGenerator(val_images, val_waters, batch_size, image_shape)

# Number of examples taken from one validation batch for the preview figures.
num_examples_to_generate = 4


def generate_and_save_images(model, epoch, test_sample):
    """Run ``model`` on ``test_sample`` and save/show the outputs as a 4x4 grid.

    The figure is written to ``{image_folder}/image_at_epoch_{epoch:04d}.png``,
    displayed, and then closed.

    Args:
        model: Callable model; invoked as ``model(test_sample, training=False)``.
        epoch: Epoch number used in the output file name.
        test_sample: Batch of inputs to run through the model. At most 16
            outputs fit in the grid.
    """
    predictions = np.asarray(model(test_sample, training=False))

    fig = plt.figure(figsize=(4, 4))

    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        image = predictions[i]
        # Draw every channel so colour outputs render in colour. A
        # single-channel output is squeezed to 2-D, which imshow draws with
        # its colormap.
        if image.shape[-1] == 1:
            image = image[..., 0]
        plt.imshow(image)
        plt.axis('off')

    plt.savefig(f'{image_folder}/image_at_epoch_{epoch:04d}.png')
    plt.show()
    # This function runs once per epoch; close the figure explicitly so
    # figures do not pile up on backends where show() leaves them open.
    plt.close(fig)


# Pick a sample of the test set for generating output images
# The preview examples are sliced out of a single validation batch, so one
# batch must contain at least that many images.
assert batch_size >= num_examples_to_generate

plt.close('all')

# Fetch the first validation batch and keep the leading entries of its second
# array (the one built from val_waters) as the fixed input for every preview.
test_batch_imgs, test_batch_waters = val_generator[0]
test_sample = test_batch_waters[:num_examples_to_generate]


# Autoencoder

In [None]:
from tensorflow import keras
from keras.models import Model
from keras.layers import Conv2D, MaxPooling2D, Dense, Input, Conv2D, UpSampling2D, BatchNormalization


class AEModel(Model):
    """Convolutional autoencoder with its own optimizer, metrics and training step.

    Instances are built from the input/output tensors returned by
    ``get_model()``. ``train_step`` expects a pair ``(images, waters)``: the
    model is run on ``waters`` and its output is compared against ``images``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Running aggregates updated by every train_step() call.
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
        self.optimizer = tf.keras.optimizers.Adam(1e-4)

    @tf.function
    def train_step(self, x):
        """Apply one optimizer update for the batch ``x = (images, waters)``.

        Returns:
            Dict with the current values of the "loss" and "mae" metrics.
        """
        images, waters = x

        # Forward pass under the tape so the loss can be differentiated.
        with tf.GradientTape() as tape:
            reconstruction = self(waters, training=True)
            loss = keras.losses.mse(images, reconstruction)

        # Backward pass and weight update.
        weights = self.trainable_variables
        grads = tape.gradient(loss, weights)
        self.optimizer.apply_gradients(zip(grads, weights))

        # Fold this batch into the running metrics.
        self.loss_tracker.update_state(loss)
        self.mae_metric.update_state(images, reconstruction)
        return {
            "loss": self.loss_tracker.result(),
            "mae": self.mae_metric.result(),
        }

    @property
    def metrics(self):
        """The metric objects owned by this model: loss tracker first, then MAE."""
        return [self.loss_tracker, self.mae_metric]

    @staticmethod
    def get_model():
        """Build the autoencoder graph for 148x148x3 inputs.

        Returns:
            Tuple ``(inputs, outputs)`` of symbolic tensors, suitable for
            ``AEModel(inputs, outputs)``.
        """
        inputs = Input(shape=(148, 148, 3))

        # Encoder: two conv -> pool -> batch-norm stages, then conv -> pool.
        net = inputs
        for filters in (128, 64):
            net = Conv2D(filters, (3, 3), activation='relu', padding='same')(net)
            net = MaxPooling2D((2, 2), padding='same')(net)
            net = BatchNormalization()(net)
        net = Conv2D(32, (3, 3), activation='relu', padding='same')(net)
        net = MaxPooling2D((2, 2), padding='same')(net)

        # Decoder: three conv -> upsample stages.
        for filters in (128, 64):
            net = Conv2D(filters, (3, 3), activation='relu', padding='same')(net)
            net = UpSampling2D((2, 2))(net)
        # No padding='same' on this conv: with the default padding the 3x3
        # kernel trims the 76x76 map to 74x74, so the last upsampling lands
        # back on 148x148.
        net = Conv2D(32, (3, 3), activation='relu')(net)
        net = UpSampling2D((2, 2))(net)

        # Sigmoid keeps the three output channels in [0, 1].
        outputs = Conv2D(3, (3, 3), activation='sigmoid', padding='same')(net)

        return inputs, outputs

# Build the layer graph, wrap its input/output tensors in the custom model
# class, and print the layer summary.
x, r = AEModel.get_model()

model = AEModel(x, r)
model.summary()

# Training

In [None]:
# Save and show the model's output for the fixed sample before any training
# has happened (stored as epoch 0).
generate_and_save_images(model, 0, test_sample)

# Show the sample inputs themselves for comparison. Each image is passed to
# imshow with all of its channels so it renders in colour instead of as a
# colormapped view of channel 0 only.
fig = plt.figure(figsize=(4, 4))
for i, img in enumerate(test_sample):
    plt.subplot(4, 4, i + 1)
    plt.imshow(img)
    plt.axis('off')
plt.show()

In [None]:
# history is only assigned inside the batch loop, so an empty generator would
# otherwise fail later with a NameError.
if len(train_generator) == 0:
    raise ValueError("train_generator yields no batches; nothing to train on")

for epoch in range(1, epochs + 1):
    # train_step() is called directly here, so nothing else clears the metric
    # accumulators. Reset them so the printed MAE covers this epoch only
    # rather than a running average over all epochs so far.
    for metric in model.metrics:
        metric.reset_state()

    start_time = time.time()
    for batch in range(len(train_generator)):
        train_x = train_generator[batch]
        history = model.train_step(train_x)

    end_time = time.time()

    # Likewise the generator's epoch hook is not triggered by this manual
    # loop; call it so the pairs are reshuffled for the next epoch.
    train_generator.on_epoch_end()

    display.clear_output(wait=False)

    # float() prints the metric as a plain number instead of a tensor repr.
    print('Epoch: {}, MAE: {}, Time: {}'.format(epoch, float(history['mae']), end_time - start_time))

    generate_and_save_images(model, epoch, test_sample)
    if epoch % save_every_n_epoch == 0:
        model.save(f'{weights_folder}/epoch_{epoch:03d}.keras')

# Always keep the final state, whether or not the last epoch hit a save interval.
model.save(f'{weights_folder}/final.keras')

### Display a generated image from the last training epoch

In [None]:
def display_image(epoch_no):
    """Open the preview image saved for epoch ``epoch_no``.

    Args:
        epoch_no: Epoch number; formatted as a zero-padded 4-digit field in
            the file name.

    Returns:
        The ``PIL.Image`` object for the matching file in ``image_folder``.
    """
    image_path = f'{image_folder}/image_at_epoch_{epoch_no:04d}.png'
    return PIL.Image.open(image_path)

In [None]:
# `epoch` still holds the value from the last iteration of the training loop,
# so this shows the preview figure saved for the final epoch.
plt.imshow(display_image(epoch))
plt.axis('off')  # Display images

# Show the sample inputs for comparison. Each image is passed to imshow with
# all of its channels so it renders in colour instead of as a colormapped
# view of channel 0 only.
fig = plt.figure(figsize=(4, 4))
for i, img in enumerate(test_sample):
    plt.subplot(4, 4, i + 1)
    plt.imshow(img)
    plt.axis('off')
plt.show()

### Display an animated GIF of all the saved images

In [None]:
# Collect the frames before opening the writer, so that a missing or empty
# folder is reported clearly instead of failing with a NameError after the
# output file has already been opened. Sorting by name puts the
# image_at_epoch_NNNN.png files in epoch order because the number is
# zero-padded.
filenames = sorted(glob.glob(f'{image_folder}/*.png'))
if not filenames:
    raise FileNotFoundError(f'No PNG frames found in {image_folder!r}')

with imageio.get_writer(anim_file, mode='I') as writer:
    for filename in filenames:
        image = imageio.imread(filename)
        writer.append_data(image)
    # The last frame is written a second time; the array read in the final
    # loop iteration is reused rather than reading the file again.
    writer.append_data(image)

In [None]:
import tensorflow_docs.vis.embed as embed

# Hand the GIF written above to tensorflow_docs' embed helper
# (presumably renders it inline in the notebook output - verify).
embed.embed_file(anim_file)