In [None]:
#https://qiita.com/shinmura0/items/811d01384e20bfd1e035

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from keras.layers import Lambda, Input, Dense, Reshape
from keras.models import Model
from keras.datasets import mnist
from keras.datasets import fashion_mnist
from keras.losses import mse
from keras.utils import plot_model
from keras import backend as K
from keras.layers import BatchNormalization, Activation, Flatten
from keras.layers.convolutional import Conv2DTranspose, Conv2D

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import os, datetime
# Enumerate GPUs in PCI bus order and expose only the device with index 1
# to CUDA for this process.
# NOTE(review): the GPU index is hard-coded; on a machine with a single GPU
# no device may be visible -- confirm this matches the target machine.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [None]:
# Draw the heatmaps
def save_img(x_normal, x_anomaly, img_normal, img_anomaly, name):
    """Show the input images next to their score heatmaps in a 2x2 grid.

    The top row shows the normal sample and the bottom row the anomalous one.
    The left column is the input image (gray colormap), the right column the
    score heatmap (blue colormap, logarithmic colour scale).

    Args:
        x_normal: normal input; only ``x_normal[0, :, :, 0]`` is drawn.
        x_anomaly: anomalous input; only ``x_anomaly[0, :, :, 0]`` is drawn.
        img_normal: score map of the normal input, indexed the same way.
        img_anomaly: score map of the anomalous input, indexed the same way.
        name: prefix prepended to the heatmap titles ("normal" / "anomaly").

    Side effects:
        Creates the ``images/`` directory if it does not exist and displays
        the figure. Writing the figure to disk is currently disabled.
    """
    path = 'images/'
    if not os.path.exists(path):
        os.mkdir(path)

    # NOTE: both heatmaps are rescaled *jointly* to the range 1..10, so the
    # colours are comparable between the two panels but not between figures.
    img_min = min(np.min(img_normal), np.min(img_anomaly))
    img_max = max(np.max(img_normal), np.max(img_anomaly))
    span = img_max - img_min
    if span == 0:
        # Completely flat score maps would divide by zero and produce NaNs;
        # map them to the lower bound of the colour scale instead.
        img_normal = np.ones_like(img_normal, dtype=float)
        img_anomaly = np.ones_like(img_anomaly, dtype=float)
    else:
        img_normal = (img_normal - img_min) / span * 9 + 1
        img_anomaly = (img_anomaly - img_min) / span * 9 + 1

    panels = (
        (x_normal, img_normal, "normal"),
        (x_anomaly, img_anomaly, "anomaly"),
    )

    plt.figure()
    for row, (x_in, heat, label) in enumerate(panels):
        # Left column: the input image.
        plt.subplot(2, 2, 2 * row + 1)
        plt.imshow(x_in[0, :, :, 0], cmap='gray')
        plt.axis('off')
        plt.colorbar()

        # Right column: the heatmap. The colour range is fixed to the 1..10
        # normalisation range up front so the colorbar never has to
        # autoscale from the data.
        plt.subplot(2, 2, 2 * row + 2)
        plt.imshow(heat[0, :, :, 0], cmap='Blues',
                   norm=colors.LogNorm(vmin=1, vmax=10))
        plt.axis('off')
        plt.colorbar()
        plt.title(name + label)

    #plt.savefig(path + name +".png")
    plt.show()
    plt.close()

In [None]:
# Compute the heatmaps
def _window_score(model, patch, name, height, width):
    """Return the anomaly score of a single (1, height, width, 1) patch."""
    if name == "old_":
        # Conventional method: the model's own evaluation loss is the score.
        return model.evaluate(patch, batch_size=1, verbose=0)

    # Proposed method: squared error weighted by the model's second output.
    mu, sigma = model.predict(patch, batch_size=1, verbose=0)
    diff = patch[0, :height, :width, 0] - mu[0, :height, :width, 0]
    return 0.5 * np.sum(diff ** 2 / sigma[0, :height, :width, 0])


def _score_heatmap(model, image, name, height, width, move):
    """Accumulate sliding-window scores of ``image`` into a same-shaped map."""
    scores = np.zeros(image.shape)
    # "+ 1" so that the last window, which ends flush with the image border,
    # is included; without it the bottom/right border is never scored.
    n_rows = (image.shape[1] - height) // move + 1
    n_cols = (image.shape[2] - width) // move + 1
    for i in range(n_rows):
        top = i * move
        for j in range(n_cols):
            left = j * move
            patch = image[0, top:top + height, left:left + width, 0]
            patch = patch.reshape(1, height, width, 1)
            # Every pixel covered by the window receives the window's score.
            scores[0, top:top + height, left:left + width, 0] += \
                _window_score(model, patch, name, height, width)
    return scores


def evaluate_img(model, arg_x_normal, arg_x_anomaly, name, height=8, width=8, move=2):
    """Plot sliding-window score heatmaps for each anomalous image.

    For every image in ``arg_x_anomaly`` a figure is produced (via
    ``save_img``) that compares it against the FIRST image of
    ``arg_x_normal``; the remaining normal images are not used.

    Args:
        model: object providing ``evaluate`` (used when ``name == "old_"``)
            or ``predict`` returning a ``(mu, sigma)`` pair (any other name).
        arg_x_normal: sequence of normal images, each indexed as (H, W, C).
        arg_x_anomaly: sequence of anomalous images, same layout.
        name: ``"old_"`` selects the conventional loss-based score; any other
            value selects the proposed score. Also used as the title prefix.
        height, width: size of the sliding window.
        move: stride of the sliding window in pixels.
    """
    for x_anomaly in arg_x_anomaly:
        # Add a leading batch axis: (H, W, C) -> (1, H, W, C).
        x_normal = np.asarray([arg_x_normal[0]])
        x_anomaly = np.asarray([x_anomaly])

        # The normal reference is re-scored for every figure because the
        # model output may be stochastic (e.g. a sampling layer).
        img_normal = _score_heatmap(model, x_normal, name, height, width, move)
        img_anomaly = _score_heatmap(model, x_anomaly, name, height, width, move)

        save_img(x_normal, x_anomaly, img_normal, img_anomaly, name)

In [None]:
# Cut out random patches (8x8 by default)
def cut_img(x, number, height=8, width=8):
    """Sample ``number`` random patches from a batch of images.

    Each patch is taken from a uniformly chosen image at a uniformly chosen
    position; every position at which the patch fits inside the image can be
    drawn, including the one flush with the bottom/right border.

    Args:
        x: array of shape (n_images, rows, cols, channels).
        number: how many patches to sample.
        height, width: patch size.

    Returns:
        Array of shape (number, height, width, channels) with the dtype of ``x``.

    Raises:
        ValueError: if the patch is larger than the images.
    """
    print("cutting images ...")
    n_images, n_rows, n_cols, n_channels = x.shape
    if height > n_rows or width > n_cols:
        raise ValueError("patch size exceeds image size")

    x_out = np.empty((number, height, width, n_channels), dtype=x.dtype)
    for i in range(number):
        img_idx = np.random.randint(0, n_images)
        # randint's upper bound is exclusive, hence the "+ 1": the largest
        # valid offset is (image size - patch size).
        top = np.random.randint(0, n_rows - height + 1)
        left = np.random.randint(0, n_cols - width + 1)
        # Keep all channels instead of assuming a single one.
        x_out[i] = x[img_idx, top:top + height, left:left + width, :]

    print("Complete.")

    return x_out

In [None]:
# Reparameterization trick:
# rather than sampling z from Q(z|X) directly, draw eps ~ N(0, I) and compute
# z = z_mean + sqrt(var) * eps
def sampling(args):
    """Draw a latent sample from the given mean and log-variance tensors.

    Args:
        args: pair ``(z_mean, z_log_var)`` of backend tensors.

    Returns:
        Tensor ``z_mean + exp(0.5 * z_log_var) * eps`` where ``eps`` is
        standard-normal noise of shape (batch, dim).
    """
    mean, log_var = args
    n_samples = K.shape(mean)[0]        # dynamic batch size
    latent_size = K.int_shape(mean)[1]  # static latent dimensionality
    # random_normal defaults to mean=0 and std=1.0
    noise = K.random_normal(shape=(n_samples, latent_size))
    std = K.exp(0.5 * log_var)
    return mean + std * noise

In [None]:
# dataset
#(x_train, y_train), (x_test, y_test) = mnist.load_data()
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

# Add a trailing channel axis: (n, 28, 28) -> (n, 28, 28, 1).
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)

# Convert to float32 and divide by 255.
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

x_train_shape = x_train.shape

# Training data: label 1 only, used as the "normal" class
# (original note: sneakers would be label 7).
# Boolean-mask indexing keeps the original sample order.
x_train_1 = x_train[y_train == 1]
x_train_1 = cut_img(x_train_1, 100000)  # random patches used for training
print("train data:",len(x_train_1))

# Test data, split per label. Label 1 is the normal class; labels 2..9 are
# used as anomalies in the evaluation cells.
x_test_1 = x_test[y_test == 1]
x_test_2 = x_test[y_test == 2]
x_test_3 = x_test[y_test == 3]
x_test_4 = x_test[y_test == 4]
x_test_5 = x_test[y_test == 5]
x_test_6 = x_test[y_test == 6]
x_test_7 = x_test[y_test == 7]
x_test_8 = x_test[y_test == 8]
x_test_9 = x_test[y_test == 9]

In [None]:
# network parameters
input_shape=(8, 8, 1)   # the model works on 8x8 single-channel patches
batch_size = 1024
latent_dim = 2
epochs = 10
Nc = 16                 # base number of convolution channels

# build encoder model
# Two stride-2 convolutions reduce the 8x8 patch to 4x4 and then 2x2.
inputs = Input(shape=input_shape, name='encoder_input')
x = Conv2D(Nc, kernel_size=2, strides=2)(inputs)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Conv2D(2*Nc, kernel_size=2, strides=2)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Flatten()(x)

z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x)
z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])

encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
#encoder.summary()

# build decoder model
# Mirrors the encoder: 2x2 -> 4x4 -> 8x8 via two stride-2 transposed convolutions.
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
x = Dense(2*2)(latent_inputs)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Reshape((2,2,1))(x)
x = Conv2DTranspose(2*Nc, kernel_size=2, strides=2, padding='same')(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Conv2DTranspose(Nc, kernel_size=2, strides=2, padding='same')(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)

# Head 1: per-pixel mean of the reconstruction.
x1 = Conv2DTranspose(1, kernel_size=4, padding='same')(x)
x1 = BatchNormalization()(x1)
out1 = Activation('sigmoid')(x1)  # out1.shape = (n, 8, 8, 1) for the 8x8 input above

# Head 2: per-pixel variance (sigma^2) of the reconstruction; the sigmoid
# keeps it positive.
x2 = Conv2DTranspose(1, kernel_size=4, padding='same')(x)
x2 = BatchNormalization()(x2)
out2 = Activation('sigmoid')(x2)  # out2.shape = (n, 8, 8, 1) for the 8x8 input above

decoder = Model(latent_inputs, [out1, out2], name='decoder')
#decoder.summary()

# build VAE model
# Only the sampled z (third encoder output) is fed to the decoder.
outputs_mu, outputs_sigma_2 = decoder(encoder(inputs)[2])
vae = Model(inputs, [outputs_mu, outputs_sigma_2], name='vae_mlp')

# VAE loss
# Gaussian negative log-likelihood, split into the error term ...
m_vae_loss = (K.flatten(inputs) - K.flatten(outputs_mu))**2 / K.flatten(outputs_sigma_2)
m_vae_loss = 0.5 * K.sum(m_vae_loss)

# ... and the log-normaliser term (np.pi instead of the truncated 3.14).
a_vae_loss = K.log(2 * np.pi * K.flatten(outputs_sigma_2))
a_vae_loss = 0.5 * K.sum(a_vae_loss)

# KL divergence between the approximate posterior and N(0, I), per sample.
kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5

# NOTE(review): m_vae_loss and a_vae_loss are scalars summed over the WHOLE
# batch (K.flatten + K.sum without axis), while kl_loss has one value per
# sample. After K.mean the reconstruction terms are therefore batch totals
# and the KL term is a batch mean. Left unchanged because altering the
# weighting changes training -- confirm whether this is intended.
vae_loss = K.mean(kl_loss + m_vae_loss + a_vae_loss)
vae.add_loss(vae_loss)
vae.compile(optimizer='adam')

In [None]:
# Train the autoencoder on the normal-class patches (no validation data is
# passed), then store the weights in a file named after the current time.
vae.fit(x_train_1,
        batch_size=batch_size,
        epochs=epochs)

now = datetime.datetime.now()
weights_file = "vae_mnist_{0:%Y%m%d_%H%M%S}.h5".format(now)
vae.save_weights(weights_file)

In [None]:
# Test data: one random normal image (label 1) and one random anomalous
# image (label 9).
idx1 = np.random.randint(len(x_test_1))
idx2 = np.random.randint(len(x_test_9))

# Add a leading batch axis to the normal image; the anomalous image is
# reshaped to exactly the same shape.
test_normal = np.expand_dims(x_test_1[idx1], axis=0)
test_anomaly = np.reshape(x_test_9[idx2], test_normal.shape)

In [None]:
# Visualize the conventional method
#evaluate_img(vae, test_normal, test_anomaly, "old_")

# Visualize the proposed method
# evaluate_img(vae, test_normal, test_anomaly, "new_")

In [None]:
# Proposed-method heatmaps for up to 50 test images each of labels 2 and 3,
# compared against the normal class (label 1).
for anomaly_images in (x_test_2, x_test_3):
    evaluate_img(vae, x_test_1[:50], anomaly_images[:50], "new_")

In [None]:
# Proposed-method heatmaps for up to 50 test images each of labels 4 and 5,
# compared against the normal class (label 1).
for anomaly_images in (x_test_4, x_test_5):
    evaluate_img(vae, x_test_1[:50], anomaly_images[:50], "new_")

In [None]:
# Proposed-method heatmaps for up to 50 test images each of labels 6 and 7,
# compared against the normal class (label 1).
for anomaly_images in (x_test_6, x_test_7):
    evaluate_img(vae, x_test_1[:50], anomaly_images[:50], "new_")

In [None]:
# Proposed-method heatmaps for up to 50 test images each of labels 8 and 9,
# compared against the normal class (label 1).
for anomaly_images in (x_test_8, x_test_9):
    evaluate_img(vae, x_test_1[:50], anomaly_images[:50], "new_")