In [None]:
import keras.optimizers
import scipy.signal
import matplotlib.pyplot as plt
import numpy as np

from keras.datasets import fashion_mnist
from keras.models import Sequential
from keras.layers import Conv2D

%matplotlib inline

In [None]:
# Load Fashion-MNIST. Only the training images are used further down; they are
# cast to float32 and divided by 255 so pixel values land in [0, 1] (assuming
# 8-bit source pixels). x_test / y_test are not used in this notebook, and the
# label array y_train is later replaced by the filtered training images.
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
x_train = x_train.astype(np.float32) / 255.0

In [None]:
# Preview the first n_rows * n_cols training images in a grid.
# Explicit Axes are used instead of plt.clf()/plt.subplot(), so no previously
# current figure gets cleared as a side effect.
n_rows, n_cols = 2, 10
fig, axes = plt.subplots(n_rows, n_cols,
                         figsize=(10 * n_cols, 10 * n_rows), dpi=20)

# axes.flat walks the grid row by row, i.e. the same order as subplot index i + 1.
for ax, image in zip(axes.flat, x_train[:n_rows * n_cols]):
    ax.imshow(image, aspect='auto', cmap='gray', vmin=0.0, vmax=1.0)
    ax.axis('off')  # tick labels are unreadable at dpi=20 and add nothing here

![Animation of a convolution filter being applied to an image](./assets/005_filter.gif)

In [None]:
# The kernel the network should learn: a unit centre, -0.10 on the four edge
# neighbours and -0.01 on the four diagonals. It is symmetric under a
# 180-degree rotation. Its weights sum to 0.56, so flat image regions come out
# darker after filtering.
sharp_filter = np.full((3, 3), -0.10)  # start with every tap at the edge weight
sharp_filter[::2, ::2] = -0.01         # the four corners
sharp_filter[1, 1] = +1.00             # the centre tap

plt.clf()
plt.imshow(sharp_filter, vmin=-1.0, vmax=1.0, cmap='gray')

In [None]:
# Build the regression targets: every training image filtered with sharp_filter.
#
# Keras' Conv2D computes a cross-correlation (the kernel is not flipped), so
# correlate2d is used rather than convolve2d. The learnt kernel can then be
# compared with sharp_filter entry-for-entry even for a non-symmetric filter.
# For the symmetric kernel above, both functions give the same result.
# Zero fill matches the zero padding of Conv2D(padding='same').
#
# Take an (N, H, W) view so this cell still works if the later reshape cell has
# already added a trailing channel axis to x_train (correlate2d needs 2-D input).
images = x_train.reshape(x_train.shape[:3])

# y_train now holds the filtered images. The original class labels are replaced.
y_train = np.empty_like(images)
for i, image in enumerate(images):
    y_train[i] = scipy.signal.correlate2d(
        image, sharp_filter, mode='same', boundary='fill', fillvalue=0.0)

plt.clf()
plt.subplot(1, 2, 1)
plt.title('input image')
plt.imshow(images[0], cmap='gray', vmin=0.0, vmax=1.0)
plt.subplot(1, 2, 2)
plt.title('target (filtered) image')
plt.imshow(y_train[0], cmap='gray', vmin=0.0, vmax=1.0);

In [None]:
def build_model(learning_rate=0.001, decay=0.001):
    """Build and compile a model made of a single learnable 3x3 kernel.

    The model is one Conv2D layer with the following setup:

    * one 3x3 filter, no bias, 'same' padding;
    * inputs of shape (28, 28, 1);
    * no activation argument, so the layer is linear.

    Fitting it to (image, filtered image) pairs with mean squared error
    therefore pulls its kernel towards the filter that produced the targets.

    Args:
        learning_rate: Adam step size. Defaults to 0.001.
        decay: Adam learning-rate decay over each update. Defaults to 0.001.

    Returns:
        A compiled keras ``Sequential`` model (Adam optimizer, MSE loss).
    """
    model = Sequential()
    model.add(Conv2D(1, (3, 3), padding='same', use_bias=False,
                     input_shape=(28, 28, 1)))

    # NOTE(review): `lr` / `decay` are the Keras 2.x argument names. Newer Keras
    # releases prefer `learning_rate` and replace `decay` with learning-rate
    # schedules, so revisit this call if the Keras version is upgraded.
    adam = keras.optimizers.Adam(lr=learning_rate, decay=decay)

    model.compile(optimizer=adam, loss='mean_squared_error')
    return model

In [None]:
# Add a trailing channel axis so inputs and targets match the model's
# input_shape=(28, 28, 1).
x_train = x_train.reshape(-1, 28, 28, 1)
y_train = y_train.reshape(-1, 28, 28, 1)

model = build_model()

# The layer's kernel before training, reshaped to 3x3 for display
# (presumably stored as (3, 3, 1, 1): height, width, in/out channels).
initial_filter = model.layers[0].get_weights()[0].reshape(3, 3)

plt.clf()
plt.subplot(1, 2, 1)
plt.title('target filter')
plt.imshow(sharp_filter, cmap='gray', vmin=-1.0, vmax=1.0)
plt.subplot(1, 2, 2)
plt.title('initial filter')
plt.imshow(initial_filter, cmap='gray', vmin=-1.0, vmax=1.0);

In [None]:
# Train the single-kernel model to map each image onto its filtered version;
# verbose=2 prints one summary line per epoch.
model.fit(x_train, y_train,
          epochs=40,
          batch_size=128,
          verbose=2)

In [None]:
# Numeric comparison: the target kernel, then the kernel the layer learnt.
# print(...) with a single argument behaves the same on Python 2 and 3,
# whereas the bare `print x` statement is a SyntaxError on Python 3.
print(sharp_filter)
print(model.layers[0].get_weights()[0].reshape(3, 3))

In [None]:
# The kernel after training, reshaped to 3x3 for display.
learnt_filter = model.layers[0].get_weights()[0].reshape(3, 3)

# Side-by-side view on the same [-1, 1] colour scale, so the shades are comparable.
plt.clf()
plt.subplot(1, 2, 1)
plt.title('target filter')
plt.imshow(sharp_filter, cmap='gray', vmin=-1.0, vmax=1.0)
plt.subplot(1, 2, 2)
plt.title('learnt filter')
plt.imshow(learnt_filter, cmap='gray', vmin=-1.0, vmax=1.0);