In [1]:

import os

import numpy as np
from keras.callbacks import TensorBoard, ModelCheckpoint
from keras.datasets import cifar10
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D
from keras.models import Model


Using TensorFlow backend.
  return f(*args, **kwds)


In [2]:

# Output locations used by the training callbacks further down.
# LOG_DIR is handed to the TensorBoard callback as its `log_dir`.
LOG_DIR = "logs/keras_autoencoders"
# KERAS_WEIGHTS is handed to the ModelCheckpoint callback as its `filepath`.
KERAS_WEIGHTS = "keras_weights/weights"


In [3]:

# Fetch the CIFAR-10 train/test splits, then keep only the images whose
# label equals 1 (the row indices of the matching labels select the images).
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train = x_train[np.nonzero(y_train == 1)[0]]
x_test = x_test[np.nonzero(y_test == 1)[0]]


In [4]:

# Normalize and convert the data: cast to float32 and divide by 255
# (pixel values are expected to be in 0-255 -- TODO confirm against the loader).
# NOTE: this rebinds x_train / x_test, so re-running this cell without first
# re-running the data-loading cell would divide by 255 a second time.
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

# Add Gaussian noise.
# The samples are drawn from a dedicated, seeded generator so that
# "Restart & Run All" reproduces exactly the same noisy inputs; the global
# NumPy RNG state is left untouched.
NOISE_SEED = 42
NOISE_FACTOR = 0.5   # multiplier applied to the sampled noise
NOISE_STD = 0.4      # standard deviation of the sampled noise, before scaling
noise_rng = np.random.RandomState(NOISE_SEED)

x_train_n = x_train + NOISE_FACTOR * noise_rng.normal(loc=0.0, scale=NOISE_STD, size=x_train.shape)
x_test_n = x_test + NOISE_FACTOR * noise_rng.normal(loc=0.0, scale=NOISE_STD, size=x_test.shape)

# Clip the noisy images back into the [0, 1] range.
x_train_n = np.clip(x_train_n, 0., 1.)
x_test_n = np.clip(x_test_n, 0., 1.)


In [5]:

# Network input: a 32 x 32 pixel image with 3 RGB channels.
inp_img = Input(shape=(32, 32, 3))

# Down-sampling half: 3x3 convolution (32 filters) followed by 2x2 max-pooling.
img = Conv2D(filters=32, kernel_size=(3, 3), padding='same', activation='relu')(inp_img)
img = MaxPooling2D(pool_size=(2, 2), padding='same')(img)

# Up-sampling half: 3x3 convolution (32 filters), 2x2 up-sampling, then a
# 3-filter sigmoid convolution that produces the reconstructed image.
img = Conv2D(filters=32, kernel_size=(3, 3), padding='same', activation='relu')(img)
img = UpSampling2D(size=(2, 2))(img)
decoded = Conv2D(filters=3, kernel_size=(3, 3), padding='same', activation='sigmoid')(img)

# Wire input -> reconstruction into a trainable model.
autoencoder = Model(inputs=inp_img, outputs=decoded)
autoencoder.compile(loss='binary_crossentropy', optimizer='adadelta')


In [6]:

# The checkpoint callback writes to KERAS_WEIGHTS; saving can fail on a fresh
# checkout if the parent directory is missing, so create it up front.
weights_dir = os.path.dirname(KERAS_WEIGHTS)
if weights_dir:  # os.makedirs('') raises, so skip when the path has no directory part
    os.makedirs(weights_dir, exist_ok=True)

# Log the graph and image summaries for TensorBoard under LOG_DIR.
tensorboard = TensorBoard(log_dir=LOG_DIR, histogram_freq=0,
                          write_graph=True, write_images=True)

# Checkpoint to KERAS_WEIGHTS with an interval of 2 epochs between saves.
model_saver = ModelCheckpoint(filepath=KERAS_WEIGHTS, verbose=0, period=2)

# Train to map noisy images (inputs) onto their clean originals (targets),
# validating on the noisy/clean test pairs.  The returned History object is
# kept in `history` so later cells can inspect it without relying on Out[N].
history = autoencoder.fit(x_train_n, x_train,
                          epochs=10,
                          batch_size=64,
                          shuffle=True,
                          validation_data=(x_test_n, x_test),
                          callbacks=[tensorboard, model_saver])
history


Train on 5000 samples, validate on 1000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7f79455f1fd0>