# Imports

In [None]:
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import numpy as np
from PIL import Image
import random
import io
import imageio

from tensorflow import keras
from keras.layers import*
from keras.models import Sequential, Model
from keras.datasets import mnist
from keras.optimizers import Adam, RMSprop
from keras.applications import VGG19

import warnings
# NOTE(review): "ignore" silences every Python warning for the rest of the session, including any
# emitted by TensorFlow/Keras while building and training models; consider narrowing this filter.
warnings.filterwarnings("ignore")

# Mount Google Drive at /content/drive. Checkpoints below are written to the relative path
# 'drive/MyDrive/...', which assumes the working directory is /content (the Colab default) — verify.
from google.colab import drive
drive.mount('/content/drive')

# Prepare dataset

## Helper functions

In [None]:
# function which adds one step of noise to every image in a dataset
def add_noise(dataset, max_disp):
    """Apply one forward noising step to every image in ``dataset``.

    Each pixel becomes ``(x + u) / (1 + max_disp)`` where ``u ~ Uniform(0, max_disp)`` is drawn
    independently per pixel, so inputs in [0, 1] stay in [0, 1].

    Args:
        dataset: array of images of any shape, e.g. ``(n, 28, 28, 1)``.
        max_disp: upper bound of the uniform noise added to each pixel.

    Returns:
        A new float array with the same shape as ``dataset``.
    """
    # Size the noise from the input. A hard-coded (40000, 28, 28, 1) only matched one dataset
    # size: any other size raised a broadcasting error (or, for n == 1, silently broadcast).
    noise = np.random.uniform(0, max_disp, np.shape(dataset))
    return (dataset + noise) / (1 + max_disp)

# function to show images
def show_images(examples, n):
    """Show the first ``n * n`` images of ``examples`` in an ``n`` x ``n`` grid.

    Args:
        examples: sequence/array of images. Single-channel images of shape ``(H, W, 1)`` are
            drawn as 2-D grayscale images.
        n: number of rows and columns of the grid.

    If ``examples`` holds fewer than ``n * n`` images, the remaining cells are left blank
    instead of raising ``IndexError``.
    """
    _, axes = plt.subplots(n, n, squeeze=False)
    for idx, ax in enumerate(axes.flat):
        ax.axis('off')
        if idx >= len(examples):
            continue
        img = np.asarray(examples[idx])
        if img.ndim == 3 and img.shape[-1] == 1:
            img = img[..., 0]  # (H, W, 1) -> (H, W) so the colormap is applied
        # MNIST is grayscale; the default colormap would render it in false colour
        ax.imshow(img, cmap='gray')
    plt.show()

## Load and prepare data

In [None]:
# load x_train mnist
(x_train, _), (_, _) = mnist.load_data()
n_samples = 40000  # the maximum dataset size that free Google Colab can work with
x_train = x_train[:n_samples].reshape(-1, 28, 28, 1).astype('float32') / 255.0

max_disp = 0.3  # upper bound of the uniform noise added at each step
steps = 12      # number of noising steps (and of denoisers)

# Build the noise ladder in one preallocated float32 array instead of growing it with
# np.concatenate inside the loop (which re-copied everything at every step, and whose float64
# noise upcast the whole stack to float64, doubling memory). datasets[0] holds the clean images;
# datasets[k] is datasets[k - 1] after one more add_noise step.
datasets = np.empty((steps + 1,) + x_train.shape, dtype='float32')
datasets[0] = x_train
for i in range(steps):
    datasets[i + 1] = add_noise(datasets[i], max_disp)

# Preview a 5x5 grid for every noise level, from clean (0) to noisiest (steps).
print(f'len dataset: {len(datasets[0])}')
for level, images in enumerate(datasets[:steps + 1]):
    print(f'dataset number {level}')
    show_images(images, 5)

# Denoiser model

In [None]:
# define denoiser
def create_denoiser():
    """Build and compile one denoising step: a (28, 28, 1) -> (28, 28, 1) conv encoder/decoder.

    Encoder: two conv + max-pool + batch-norm stages (28 -> 14 -> 7), then a 64-filter conv.
    Decoder: two conv + stride-2 transposed-conv + dropout stages (7 -> 14 -> 28), then a
    single-channel sigmoid conv. Compiled with the Adam optimizer and MSE loss.
    """
    def encoder_stage(filters):
        # conv -> 2x downsample -> batch-norm
        return [
            Conv2D(filters, (3, 3), activation='relu', padding='same'),
            MaxPool2D((2, 2), padding='same'),
            BatchNormalization(trainable=True),
        ]

    def decoder_stage():
        # conv -> 2x upsample -> dropout
        return [
            Conv2D(128, (3, 3), activation='relu', padding='same'),
            Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'),
            Dropout(0.5),
        ]

    # list elements are evaluated left to right, so layers are created in the original order
    model = Sequential([
        Input(shape=(28, 28, 1)),
        *encoder_stage(256),
        *encoder_stage(128),
        Conv2D(64, (3, 3), activation='relu', padding='same'),
        *decoder_stage(),
        *decoder_stage(),
        Conv2D(1, (3, 3), padding='same', activation='sigmoid'),
    ])
    model.compile(optimizer='adam', loss='mse')
    return model

# one independent, freshly compiled denoiser per noise step
models = []
for _ in range(steps):
    models.append(create_denoiser())
models[0].summary()

In [None]:
# define connection layers
def create_connection_layers():
    """Build the small, uncompiled bridge network placed after a denoiser in the chain.

    Maps (28, 28, 1) -> (28, 28, 1): one conv + max-pool + batch-norm stage (28 -> 14), one
    conv + stride-2 transposed-conv + dropout stage (14 -> 28), and a single-channel sigmoid conv.
    It is not compiled here; it is trained as part of a larger Sequential model.
    """
    stack = [Input(shape=(28, 28, 1))]
    # downsample 28 -> 14
    stack += [
        Conv2D(128, (3, 3), activation='relu', padding='same'),
        MaxPool2D((2, 2), padding='same'),
        BatchNormalization(trainable=True),
    ]
    # upsample 14 -> 28
    stack += [
        Conv2D(128, (3, 3), activation='relu', padding='same'),
        Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'),
        Dropout(0.5),
    ]
    stack.append(Conv2D(1, (3, 3), padding='same', activation='sigmoid'))
    return Sequential(stack)

# one connection network per denoiser
models_con = []
for _ in range(steps):
    models_con.append(create_connection_layers())
models_con[0].summary()

# Train denoisers and gan_model

In [None]:
# define gan_model; it grows as denoiser 0 -> connection 0 -> denoiser 2 -> connection 2 -> ...
gan_model = Sequential([
    Input(shape=(28, 28, 1)),
])

epochs = 20
batch_size = 64

for i in range(steps):
    # denoiser i maps noise level (steps - i) to the next cleaner level (steps - i - 1)
    data = datasets[steps - i]
    target = datasets[steps - i - 1]

    print(f'############### train denoiser_{i + 1}/{len(models)} ###############')
    models[i].fit(data, target, epochs=epochs, batch_size=batch_size)

    if i % 2 == 0:
        # even-indexed denoisers are only trained here; they are paired with the next one below
        continue

    # train layers which will connect 2 denoiser models
    print(f'############### train model to connect denoiser_{i + 1} and denoiser_{i} ###############')

    # Freeze both denoisers so only models_con[i - 1] is updated; `trainable` must be set
    # before the pair model is compiled for the freeze to apply.
    models[i - 1].trainable = False
    models[i].trainable = False
    gan_model_per = Sequential([
        Input(shape=(28, 28, 1)),
        models[i - 1],
        models_con[i - 1],
        models[i],
    ])
    gan_model_per.compile(loss='mse', optimizer='adam')
    # The pair consumes denoiser (i - 1)'s input level (steps - i + 1) and must reproduce
    # denoiser i's target level (steps - i - 1), i.e. it learns to remove noise, not add it.
    gan_model_per.fit(datasets[steps - i + 1], datasets[steps - i - 1],
                      epochs=epochs, batch_size=batch_size)

    # append the frozen earlier denoiser and its freshly trained connection layer to gan_model
    gan_model.add(models[i - 1])
    gan_model.add(models_con[i - 1])
    gan_model.save_weights('drive/MyDrive/gan_denoiser_1.h5')

    # show the architecture of the growing chain and of the pair just trained
    # (Model.summary() prints directly and returns None)
    print('gan_model summary:')
    gan_model.summary()
    print('gan_model_per summary:')
    gan_model_per.summary()

# compile
gan_model.compile(loss='mse', optimizer='adam')

gan_model.save_weights('drive/MyDrive/gan_denoiser_1.h5')

In [None]:
# make denoisers non-trainable so only the connection layers are updated in this phase
for denoiser in models:
    denoiser.trainable = False
# A change to `trainable` only takes effect once the model containing those layers is
# compiled again, so re-compile gan_model before fitting.
gan_model.compile(loss='mse', optimizer='adam')

# train connection layers end to end: noisiest level -> clean images
print('############### train connection layers ###############')
gan_model.fit(datasets[-1], datasets[0], epochs=(epochs * 2), batch_size=batch_size)

gan_model.save_weights('drive/MyDrive/gan_denoiser_2.h5')

In [None]:
# Optional final stage: unfreeze the denoisers and fine-tune the whole chain end to end.
FINE_TUNE_FULL_MODEL = False  # disabled by default; set True to run the fine-tuning fit below

# make denoisers trainable
for denoiser in models:
    denoiser.trainable = True

if FINE_TUNE_FULL_MODEL:
    print('############### train full gan_model ###############')
    # `trainable` changes only take effect after re-compiling the containing model
    gan_model.compile(loss='mse', optimizer='adam')
    gan_model.fit(datasets[-1], datasets[0], epochs=epochs // 2, batch_size=batch_size)

gan_model.save_weights('drive/MyDrive/gan_denoiser_3.h5')

# Look at results

In [None]:
# Generate images from pure noise. Every training input was produced by repeated steps of
# x <- (x + U(0, max_disp)) / (1 + max_disp), which keeps values in [0, 1]; standard-normal
# samples (negative values, values > 1) lie far outside that range. Push blank images through
# the same forward noising process instead, so the starting noise matches the training inputs.
grid = 6
b = np.zeros((grid * grid, 28, 28, 1), dtype='float32')
for _ in range(steps):
    b = (b + np.random.uniform(0, max_disp, b.shape)) / (1 + max_disp)
pred = gan_model.predict(b)
print('noise prediction')
show_images(pred, grid)

## Saving model weights

In [None]:
gan_model.save_weights(filepath='./denoiser_gan_weights.h5')  # local runtime copy of the current weights

# Fast model creation and weights loading

## Fast model creation

In [None]:
def create_denoiser():
    """Build and compile one denoising step: a (28, 28, 1) -> (28, 28, 1) conv encoder/decoder.

    Re-declared so this section can rebuild the models without the training cells. It creates
    the same layers in the same order as the training-time definition, which presumably is
    what lets the saved weights load into the rebuilt model — keep the two in sync.

    Encoder: two conv + max-pool + batch-norm stages (28 -> 14 -> 7), then a 64-filter conv.
    Decoder: two conv + stride-2 transposed-conv + dropout stages (7 -> 14 -> 28), then a
    single-channel sigmoid conv. Compiled with the Adam optimizer and MSE loss.
    """
    def encoder_stage(filters):
        # conv -> 2x downsample -> batch-norm
        return [
            Conv2D(filters, (3, 3), activation='relu', padding='same'),
            MaxPool2D((2, 2), padding='same'),
            BatchNormalization(trainable=True),
        ]

    def decoder_stage():
        # conv -> 2x upsample -> dropout
        return [
            Conv2D(128, (3, 3), activation='relu', padding='same'),
            Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'),
            Dropout(0.5),
        ]

    model = Sequential([
        Input(shape=(28, 28, 1)),
        *encoder_stage(256),
        *encoder_stage(128),
        Conv2D(64, (3, 3), activation='relu', padding='same'),
        *decoder_stage(),
        *decoder_stage(),
        Conv2D(1, (3, 3), padding='same', activation='sigmoid'),
    ])
    model.compile(optimizer='adam', loss='mse')
    return model

# Number of denoisers. It is repeated here because otherwise `steps` is only defined in the
# data-preparation cell, and this fast path is meant to run without it on a fresh kernel.
# It must match the value used when the weights were saved.
steps = 12
models = [create_denoiser() for _ in range(steps)]

In [None]:
def create_connection_layers():
    """Build the small, uncompiled bridge network placed after a denoiser in the chain.

    Re-declared so this section can rebuild the model without the training cells; it builds
    the same layers in the same order as the training-time definition.

    Maps (28, 28, 1) -> (28, 28, 1): one conv + max-pool + batch-norm stage (28 -> 14), one
    conv + stride-2 transposed-conv + dropout stage (14 -> 28), and a single-channel sigmoid conv.
    """
    stack = [Input(shape=(28, 28, 1))]
    # downsample 28 -> 14
    stack += [
        Conv2D(128, (3, 3), activation='relu', padding='same'),
        MaxPool2D((2, 2), padding='same'),
        BatchNormalization(trainable=True),
    ]
    # upsample 14 -> 28
    stack += [
        Conv2D(128, (3, 3), activation='relu', padding='same'),
        Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'),
        Dropout(0.5),
    ]
    stack.append(Conv2D(1, (3, 3), padding='same', activation='sigmoid'))
    return Sequential(stack)

models_con = []
for _ in range(steps):
    models_con.append(create_connection_layers())

In [None]:
# training hyper-parameters (same values as the training section)
epochs, batch_size = 20, 64

In [None]:
# Freeze all denoisers; only the connection layers remain trainable in the assembled chain.
for denoiser in models:
    denoiser.trainable = False

# Rebuild the trained chain: denoiser 0 -> connection 0 -> denoiser 2 -> connection 2 -> ...
chain = [Input(shape=(28, 28, 1))]
for k in range(0, steps, 2):
    chain += [models[k], models_con[k]]
gan_model = Sequential(chain)

gan_model.compile(loss='mse', optimizer='adam')

## Loading weights

In [None]:
import os

# Load the most-trained checkpoint that exists:
#   _3: saved after the (optional) full fine-tuning stage
#   _2: saved after end-to-end training of the connection layers
#   _1: saved after the pairwise training loop only
checkpoints = [
    'drive/MyDrive/gan_denoiser_3.h5',
    'drive/MyDrive/gan_denoiser_2.h5',
    'drive/MyDrive/gan_denoiser_1.h5',
]
weights_path = next((path for path in checkpoints if os.path.exists(path)), checkpoints[-1])
print(f'loading weights from {weights_path}')
gan_model.load_weights(weights_path)