In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

from keras.layers import Input, Dense, Lambda, Flatten, Reshape
from keras.layers import Conv2D, Conv2DTranspose
from keras.models import Model
from keras import backend as K
from keras import metrics

import sys
from keras.callbacks import ModelCheckpoint
import os.path
from keras.preprocessing.image import ImageDataGenerator

def fixed_generator(generator):
    """Turn a stream of unlabeled batches into autoencoder training pairs.

    Keras' ``fit_generator`` consumes ``(inputs, targets)`` tuples. An
    autoencoder reconstructs its own input, so every batch is used as both
    the input and the target.

    Args:
        generator: any iterable of batches (e.g. a ``flow_from_directory``
            iterator created with ``class_mode=None``).

    Yields:
        ``(batch, batch)`` for each batch drawn from *generator*, in order.
    """
    yield from ((images, images) for images in generator)

# --- Configuration ----------------------------------------------------------
# Images are resized to img_rows x img_cols with img_chns (RGB) channels.
img_rows, img_cols, img_chns = 256, 256, 3
# Filter count for conv_2..conv_4 and for the decoder's transposed convs.
filters = 64
# Used as the convolution KERNEL SIZE (3x3) by conv_3/conv_4 and the first
# two decoder layers -- despite the name it is not a number of conv layers.
num_conv = 3

# Root directories handed to flow_from_directory(class_mode=None) below.
# NOTE(review): hard-coded, machine-specific paths.
train_data_dir = '/Users/pavelgulaev/Desktop/Диплом/Шок-картинки/train'
validation_data_dir = '/Users/pavelgulaev/Desktop/Диплом/Шок-картинки/valid'
# Only used to derive steps_per_epoch / validation_steps for training.
# NOTE(review): hard-coded counts -- confirm they match the directory contents.
nb_train_samples = 1514
nb_validation_samples = 283
epochs = 5
# Fixed batch size; it is also baked into the encoder Input's batch_shape.
batch_size = 70

# Order the image tensor's axes to match the backend's data format.
if K.image_data_format() == 'channels_first':
    original_img_size = (img_chns, img_rows, img_cols)
else:
    original_img_size = (img_rows, img_cols, img_chns)
# Size of the latent code z.
latent_dim = 2
# Width of the dense layer between the conv stack and the latent heads
# (the decoder uses a dense layer of the same width).
intermediate_dim = 128
# Standard deviation of the noise drawn in sampling().
epsilon_std = 1.0

# --- Encoder ----------------------------------------------------------------
# The batch dimension is fixed to batch_size, so every batch fed to the model
# must contain exactly batch_size images.
x = Input(batch_shape=(batch_size,) + original_img_size)
# 2x2 conv, stride 1, 'same' padding: spatial size stays 256 x 256, img_chns maps.
conv_1 = Conv2D(img_chns, kernel_size=(2, 2), padding='same', activation='relu')(x)
print (conv_1.shape)
# 2x2 conv, stride 2, 'same' padding: downsamples 256 -> 128, filters maps.
conv_2 = Conv2D(filters, kernel_size=(2, 2), padding='same', activation='relu', strides=(2, 2))(conv_1)
print (conv_2.shape)
# Two num_conv x num_conv convs, stride 1, 'same' padding: stay 128 x 128 x filters.
conv_3 = Conv2D(filters, kernel_size=num_conv, padding='same', activation='relu', strides=1)(conv_2)
print (conv_3.shape)
conv_4 = Conv2D(filters, kernel_size=num_conv, padding='same', activation='relu', strides=1)(conv_3)
print (conv_4.shape)
# 128 * 128 * filters features per image.
flat = Flatten()(conv_4)
print (flat.shape)
hidden = Dense(intermediate_dim, activation='relu')(flat)

# Two linear heads parameterizing the approximate posterior q(z|x):
# its mean and its log-variance (vae_loss uses exp(z_log_var) as a variance).
z_mean = Dense(latent_dim)(hidden)
z_log_var = Dense(latent_dim)(hidden)


def sampling(args):
    """Sample z ~ N(z_mean, exp(z_log_var)) with the reparameterization trick.

    Writing z as a deterministic function of the encoder outputs plus
    independent noise keeps the sample differentiable with respect to
    ``z_mean`` and ``z_log_var``.

    Args:
        args: pair ``[z_mean, z_log_var]`` of encoder output tensors.
            ``z_log_var`` is the log of the posterior *variance*.

    Returns:
        Tensor shaped like ``z_mean``: ``z_mean + sigma * epsilon``, where
        ``sigma = exp(z_log_var / 2)`` and ``epsilon`` is Gaussian noise with
        mean 0 and standard deviation ``epsilon_std``.
    """
    z_mean, z_log_var = args
    # Read the batch size from the tensor rather than the global batch_size,
    # so the layer also works for batches of any size.
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    epsilon = K.random_normal(shape=(batch, dim), mean=0., stddev=epsilon_std)
    # z_log_var = log(sigma^2), hence sigma = exp(z_log_var / 2). Using
    # exp(z_log_var) would scale the noise by the variance instead of the
    # standard deviation, inconsistent with the KL term in vae_loss.
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

# Draw the latent code z from q(z|x) via sampling() (reparameterization trick).
z = Lambda(sampling)([z_mean, z_log_var])

# --- Decoder ----------------------------------------------------------------
# Decoder layers are instantiated first and applied to z further below.
decoder_hid = Dense(intermediate_dim, activation='relu')
# Project back up to a 128 x 128 x filters volume, the same shape as conv_4.
decoder_upsample = Dense(filters * 128 * 128, activation='relu')

# Target volume shape, ordered for the backend's data format.
if K.image_data_format() == 'channels_first':
    output_shape = (batch_size, filters, 128, 128)
else:
    output_shape = (batch_size, 128, 128, filters)

# Reshape takes the per-sample shape, so drop the batch dimension.
decoder_reshape = Reshape(output_shape[1:])
# Two stride-1 'same' transposed convs: shape stays 128 x 128 x filters.
decoder_deconv_1 = Conv2DTranspose(filters, kernel_size=num_conv, padding='same', strides=1, activation='relu')
decoder_deconv_2 = Conv2DTranspose(filters, num_conv, padding='same', strides=1, activation='relu')
# NOTE(review): this reassignment of output_shape (257 x 257) is not used by
# any code visible here; it only records the size produced by the next layer.
if K.image_data_format() == 'channels_first':
    output_shape = (batch_size, filters, 257, 257)
else:
    output_shape = (batch_size, 257, 257, filters)
# 3x3, stride-2, 'valid' transposed conv: (128 - 1) * 2 + 3 = 257 per side.
decoder_deconv_3_upsamp = Conv2DTranspose(filters,kernel_size=(3, 3),strides=(2, 2),padding='valid',activation='relu')
# 2x2 'valid' conv trims 257 -> 256 and maps back to img_chns channels; the
# sigmoid keeps outputs in (0, 1), matching inputs rescaled by 1/255.
decoder_mean_squash = Conv2D(img_chns, kernel_size=2, padding='valid', activation='sigmoid')

# Wire the decoder: z -> dense -> dense -> reshape -> deconvs -> image.
hid_decoded = decoder_hid(z)
up_decoded = decoder_upsample(hid_decoded)
reshape_decoded = decoder_reshape(up_decoded)
deconv_1_decoded = decoder_deconv_1(reshape_decoded)
deconv_2_decoded = decoder_deconv_2(deconv_1_decoded)
x_decoded_relu = decoder_deconv_3_upsamp(deconv_2_decoded)
x_decoded_mean_squash = decoder_mean_squash(x_decoded_relu)


def vae_loss(x, x_decoded_mean):
    """Negative ELBO: reconstruction cross-entropy plus the KL regularizer.

    Used as a Keras loss, so it is called as ``loss(y_true, y_pred)``.
    It also reads the module-level encoder tensors ``z_mean`` and
    ``z_log_var``.

    Args:
        x: target image batch (the training generator feeds each input
            batch back as its own target), values in [0, 1].
        x_decoded_mean: sigmoid reconstruction of the same shape.

    Returns:
        Loss tensor equal to the reconstruction term plus each sample's KL
        term. Because ``K.flatten`` also flattens the batch axis, the
        reconstruction term is a single batch-averaged value broadcast
        across samples.
    """
    x = K.flatten(x)
    x_decoded_mean = K.flatten(x_decoded_mean)
    # binary_crossentropy averages over every element of the flattened
    # tensors. Multiply by the number of elements per image
    # (rows * cols * channels) to get a per-image summed reconstruction
    # error. Omitting img_chns would under-weight this term by a factor of
    # img_chns relative to the KL term.
    xent_loss = img_rows * img_cols * img_chns * metrics.binary_crossentropy(x, x_decoded_mean)
    # Closed-form KL(N(z_mean, exp(z_log_var)) || N(0, I)). For a diagonal
    # Gaussian this is a SUM over latent dimensions; a mean would divide it
    # by latent_dim.
    kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    return xent_loss + kl_loss

import contextlib

# Assemble the end-to-end VAE: encoder input x -> sigmoid reconstruction.
vae = Model(x, x_decoded_mean_squash)
# NOTE(review): the checkpoint lives under an NFS home directory, while the
# data and log paths are local macOS paths. Confirm this directory exists on
# the machine that runs training, or saving the first checkpoint will fail.
weights_file = "/nfs/home/pgulyaev/vae.best.hdf5"
if os.path.isfile(weights_file):
    # Resume from the best weights saved by a previous run.
    vae.load_weights(weights_file)
vae.compile(optimizer='adadelta', loss=vae_loss)

# Overwrite weights_file only when the monitored metric improves
# (ModelCheckpoint monitors 'val_loss' by default).
checkpoint = ModelCheckpoint(weights_file, verbose=1, save_best_only=True)
callbacks_list = [checkpoint]

# Training images: rescale to [0, 1] (the range of the sigmoid output and the
# binary cross-entropy targets) and apply random geometric augmentation.
train_datagen = ImageDataGenerator(
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest')

# Validation images: rescale only, no augmentation.
test_datagen = ImageDataGenerator(rescale=1. / 255)

# class_mode=None yields bare image batches without labels; fixed_generator
# turns each batch into an (input, target) pair for reconstruction.
train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_rows, img_cols),
    batch_size=batch_size,
    class_mode=None)

validation_generator = test_datagen.flow_from_directory(
    validation_data_dir,
    target_size=(img_rows, img_cols),
    batch_size=batch_size,
    class_mode=None)

# The encoder Input has a fixed batch dimension (batch_size), but
# flow_from_directory emits a smaller final batch whenever the image count
# is not a multiple of batch_size. 1514 and 283 are not multiples of 70, so
# that partial batch would be fed at the start of the next epoch and fail the
# fixed batch shape. Keep only full batches.
train_batches = (batch for batch in train_generator if len(batch) == batch_size)
validation_batches = (batch for batch in validation_generator if len(batch) == batch_size)

log_path = '/Users/pavelgulaev/Desktop/Диплом/vae_logFile'
# Send Keras' console output to the log file. The context managers restore
# sys.stdout and close the file even if training raises; a manual swap would
# leave stdout redirected and the file open on any exception.
with open(log_path, 'w', encoding='utf-8') as fileLog:
    with contextlib.redirect_stdout(fileLog):
        vae.fit_generator(
                fixed_generator(train_batches),
                steps_per_epoch=nb_train_samples // batch_size,
                epochs=epochs,
                validation_data=fixed_generator(validation_batches),
                validation_steps=nb_validation_samples // batch_size,
                callbacks=callbacks_list)