In [None]:
# Import libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, BatchNormalization, Add
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from skimage.util import random_noise
from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim
from tensorflow.keras import backend as K

In [None]:
# Create output directories

In [None]:
# Make sure the output folders exist; exist_ok=True keeps re-runs of this cell harmless.
for output_dir in ("models", "results"):
    os.makedirs(output_dir, exist_ok=True)

In [None]:
# Only the images are needed for denoising, so the class labels are discarded.
(x_train, _), (x_test, _) = cifar10.load_data()
# Scale pixel values from the 0-255 range to float32 in [0, 1].
x_train, x_test = (split.astype('float32') / 255. for split in (x_train, x_test))

In [None]:
# Visualize sample clean images

In [None]:
# Show the first ten clean training images side by side in a single row.
fig, axes = plt.subplots(1, 10, figsize=(10, 2))
for ax, image in zip(axes, x_train[:10]):
    ax.imshow(image)
    ax.axis('off')
fig.suptitle("Sample CIFAR-10 Images (Clean)")
plt.show()

In [None]:
# Add noise to create noisy versions of the images

In [None]:
def add_noise(data, mode='gaussian', var=0.01, amount=0.02):
    """Return a noisy copy of ``data`` with values clipped to [0, 1].

    Args:
        data: Image array. Presumably float in [0, 1], as produced by the
            loading cell (TODO confirm for other callers).
        mode: ``'gaussian'`` for additive Gaussian noise with variance ``var``,
            ``'s&p'`` for salt-and-pepper noise affecting a fraction ``amount``
            of values, or ``None`` / ``'none'`` to return ``data`` unchanged
            (only clipped).
        var: Gaussian noise variance (used only when ``mode == 'gaussian'``).
        amount: Salt-and-pepper fraction (used only when ``mode == 's&p'``).

    Returns:
        Array with the same shape as ``data``, clipped to [0, 1]. Floating-point
        input keeps its dtype (e.g. float32 stays float32).

    Raises:
        ValueError: If ``mode`` is not one of the values listed above. Previously
            a typo such as ``'gausian'`` silently produced noise-free data.
    """
    data = np.asarray(data)
    if mode == 'gaussian':
        noisy = random_noise(data, mode='gaussian', var=var)
    elif mode == 's&p':
        noisy = random_noise(data, mode='s&p', amount=amount)
    elif mode is None or mode == 'none':
        noisy = data
    else:
        raise ValueError(
            f"Unsupported noise mode {mode!r}; expected 'gaussian', 's&p' or 'none'."
        )
    noisy = np.clip(noisy, 0., 1.)
    # random_noise may upcast float32 input to float64, doubling the memory of the
    # noisy dataset; cast floating input back to its own dtype. Integer input is
    # left alone because its noisy values are fractions in [0, 1].
    if np.issubdtype(data.dtype, np.floating):
        noisy = noisy.astype(data.dtype, copy=False)
    return noisy

# Corrupt both splits with the same Gaussian noise level; the clean arrays are left untouched.
x_train_noisy, x_test_noisy = (
    add_noise(images, mode='gaussian', var=0.01) for images in (x_train, x_test)
)

In [None]:
# Visualize noisy vs clean

In [None]:
# Top row: noisy training inputs; bottom row: the corresponding clean images.
n_show = 10
rows = (("Noisy", x_train_noisy), ("Clean", x_train))
fig, axes = plt.subplots(len(rows), n_show, figsize=(20, 4))
for row, (label, images) in enumerate(rows):
    for col in range(n_show):
        ax = axes[row, col]
        ax.imshow(images[col])
        # Hide ticks and frame instead of calling axis('off'): turning the axis off
        # also hides the y-label, so the row labels would never be drawn.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    axes[row, 0].set_ylabel(label, fontsize=14)
fig.suptitle("Noisy (top) vs clean (bottom) training images")
plt.show()

In [None]:
# Autoencoder

In [None]:
def residual_block(x, filters):
    """Two 3x3 Conv2D+BatchNorm layers with an additive skip connection.

    Args:
        x: Input Keras tensor (channels-last, matching the (32, 32, 3) model input).
        filters: Number of filters in each 3x3 convolution; also the number of
            output channels.

    Returns:
        Keras tensor ``shortcut + F(x)`` with ``filters`` channels and the same
        spatial size as ``x`` ('same' padding, stride 1).
    """
    conv = Conv2D(filters, (3, 3), padding='same', activation='relu')(x)
    conv = BatchNormalization()(conv)
    conv = Conv2D(filters, (3, 3), padding='same', activation='relu')(conv)
    conv = BatchNormalization()(conv)
    shortcut = x
    # Add() requires identical shapes. If the input channel count differs from
    # `filters`, project the shortcut with a 1x1 convolution (ResNet-style
    # projection shortcut). When the counts already match, no extra layer is
    # created and the graph is a plain identity skip, as before.
    if x.shape[-1] != filters:
        shortcut = Conv2D(filters, (1, 1), padding='same')(x)
    return Add()([shortcut, conv])

input_img = Input(shape=(32,32,3))

In [None]:
# Encoder

In [None]:
# Encoder: two conv stages, each followed by 2x2 max-pooling (32x32 -> 16x16 -> 8x8),
# with a 64-filter residual block at 16x16 between them.
x = input_img
for layer in (Conv2D(64, (3, 3), activation='relu', padding='same'),
              BatchNormalization(),
              MaxPooling2D((2, 2), padding='same')):
    x = layer(x)
x = residual_block(x, 64)
for layer in (Conv2D(128, (3, 3), activation='relu', padding='same'),
              BatchNormalization()):
    x = layer(x)
encoded = MaxPooling2D((2, 2), padding='same')(x)

In [None]:
# Decoder

In [None]:
# Decoder: mirror of the encoder. Two conv stages each followed by 2x2 upsampling
# (8x8 -> 16x16 -> 32x32), with a 128-filter residual block at 16x16, then a sigmoid
# 3-channel conv so outputs lie in [0, 1] like the scaled inputs.
x = encoded
for layer in (Conv2D(128, (3, 3), activation='relu', padding='same'),
              BatchNormalization(),
              UpSampling2D((2, 2))):
    x = layer(x)
x = residual_block(x, 128)
for layer in (Conv2D(64, (3, 3), activation='relu', padding='same'),
              BatchNormalization(),
              UpSampling2D((2, 2))):
    x = layer(x)
decoded = Conv2D(3, (3, 3), activation='sigmoid', padding='same')(x)

In [None]:
# Wire input to reconstruction and configure Adam (learning rate 1e-3) with a per-pixel MSE loss.
autoencoder = Model(inputs=input_img, outputs=decoded)
autoencoder.compile(optimizer=Adam(learning_rate=0.001), loss='mse')
autoencoder.summary()