In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import math, sys, os
import numpy as np
from numpy.linalg import norm

In [None]:
from PIL import Image
from matplotlib import pyplot as plt, rcParams, rc

In [None]:
# NOTE(review): `imread` is not used in any cell of this notebook, and
# scipy.ndimage.imread has been removed from recent SciPy releases, so this line
# is likely to raise ImportError there — confirm the SciPy version in use.
from scipy.ndimage import imread
from skimage.measure import block_reduce
# NOTE(review): the `scipy.ndimage.filters` namespace is deprecated in recent
# SciPy; `correlate` and `convolve` are also exposed directly by scipy.ndimage.
from scipy.ndimage.filters import correlate, convolve

In [None]:
import pickle
from ipywidgets import interact, interactive, fixed
from ipywidgets.widgets import *

In [None]:
rc('animation', html='html5')
rcParams['figure.figsize'] = 3, 6
%precision 4
np.set_printoptions(precision=4, linewidth=100)

In [None]:
def plots(imgs, interp=False, titles=None, figsize=(12, 24)):
    """Draw several images side by side in one row, on a shared intensity scale.

    imgs    -- sequence of 2-D images; stacked together with ``np.array``.
    interp  -- if False, interpolation is disabled so individual pixels stay crisp.
    titles  -- optional sequence holding one title per image.
    figsize -- figure size in inches as (width, height).
    """
    imgs = np.array(imgs)
    if len(imgs) == 0:
        return  # nothing to draw; min()/max() would raise on an empty array
    # One vmin/vmax for every panel keeps intensities comparable across panels.
    lo, hi = imgs.min(), imgs.max()
    fig = plt.figure(figsize=figsize)
    for i, img in enumerate(imgs):
        ax = fig.add_subplot(1, len(imgs), i + 1)
        if titles is not None:
            ax.set_title(titles[i], fontsize=18)
        ax.imshow(img, interpolation=None if interp else 'none', vmin=lo, vmax=hi)

def plot(img, interp=False):
    """Draw a single image in its own small figure (no interpolation by default)."""
    plt.figure(figsize=(3, 6), frameon=True)
    plt.imshow(img, interpolation=None if interp else 'none')

# Make grey the default colormap. NOTE(review): plt.gray() appears to leave an
# empty figure open in this cell; plt.close() discards it so no blank plot shows.
plt.gray()
plt.close()

In [None]:
from keras.datasets import mnist
# Unpack the training and test splits, each as an (images, labels) pair.
(X_train, y_train), (X_test, y_test) = mnist.load_data()

In [None]:
# Keras documents the MNIST arrays as uint8 (0-255). scipy.ndimage's
# correlate/convolve write their result in the input's dtype by default, so the
# edge kernels used below (which have negative and fractional weights) would
# wrap around / truncate on uint8 images instead of producing signed responses.
# Work in float32 with pixels scaled to [0, 1].
images = X_train.astype(np.float32) / 255
labels = y_train
N = len(images)
images.shape

In [None]:
# Look at the first training image on its own.
plot(images[0])

In [None]:
# ...and its label.
labels[0]

In [None]:
# The first five training images, each titled with its label.
plots(images[:5], titles=labels[:5])

In [None]:
# Horizontal-edge kernel: -1 across the row above the centre, +1 across the
# centre row and 0 below. Under correlation it responds where a row of pixels
# is brighter than the row above it.
top = [[-1] * 3,
       [1] * 3,
       [0] * 3]
plot(top)

In [None]:
r = (0, 28)  # (min, max) range shared by every slider

def zoomim(x1=0, x2=28, y1=0, y2=28):
    """Plot rows y1:y2 and columns x1:x2 of the first training image."""
    crop = images[0][y1:y2, x1:x2]
    plot(crop)

w = interactive(zoomim, x1=r, x2=r, y1=r, y2=r)
w

In [None]:
# Read the current slider values of the zoom widget.
# NOTE(review): everything that uses `dims` below depends on where the sliders
# were left interactively, so a fresh Run All may not reproduce the same crop.
k = w.kwargs
k

In [None]:
# Package the selected (rows, columns) crop as a reusable index tuple.
dims = np.index_exp[k['y1']:k['y2']:1, k['x1']:k['x2']]
images[0][dims]

In [None]:
# Correlate the first image with the `top` kernel and inspect the zoomed region.
corrtop = correlate(images[0], top)
corrtop[dims]

In [None]:
plot(corrtop[dims])

In [None]:
# The full correlation map: the kernel's response at every pixel.
plot(corrtop)

In [None]:
# np.rot90 with k=1 turns the kernel by a single quarter turn.
np.rot90(top, 1)

In [None]:
# Convolution flips the kernel, so convolving with `top` rotated by 180 degrees
# is expected to reproduce the correlation above; np.allclose checks that.
convtop = convolve(images[0], np.rot90(top, 2))
plot(convtop)
np.allclose(convtop, corrtop)

In [None]:
# The four quarter-turn rotations (0, 90, 180, 270 degrees) of the `top` kernel.
straights = [np.rot90(top, k=turns) for turns in range(4)]
plots(straights)

In [None]:
# Diagonal kernel (+1 on the anti-diagonal, -1.5 just below it) and its four
# quarter-turn rotations.
br = [[0, 0, 1],
      [0, 1, -1.5],
      [1, -1.5, 0]]
diags = [np.rot90(br, k=turns) for turns in range(4)]
plots(diags)

In [None]:
# All eight kernels (four straight, then four diagonal), each correlated with
# the first training image.
rots = [*straights, *diags]
corrs = [correlate(images[0], kernel) for kernel in rots]
plots(corrs)

In [None]:
def pool(img):
    """Downsample `img` by taking the maximum of each 7x7 block."""
    return block_reduce(img, block_size=(7, 7), func=np.max)

plots(list(map(pool, corrs)))

In [None]:
# Select every training image of a class with a boolean mask: one vectorized
# indexing step instead of a Python loop over all N indices.
eights = images[np.asarray(labels) == 8]
ones = images[np.asarray(labels) == 1]

In [None]:
# A few example eights and ones.
plots(eights[:5])
plots(ones[:5])

In [None]:
# For each kernel in `rots`: correlate every eight with it and max-pool the
# result; pool8[i] stacks the pooled maps of all eights for kernel i.
pool8 = [np.array([pool(correlate(img, rot)) for img in eights]) for rot in rots]

In [None]:
# Number of kernels, and the shape of kernel 0's stack of pooled maps.
len(pool8), pool8[0].shape

In [None]:
# Pooled responses of the first five eights to the first kernel.
plots(pool8[0][0:5])

In [None]:
def normalize(arr):
    """Standardize `arr` to zero mean and unit standard deviation.

    Mean and standard deviation are taken over all elements jointly, so a stack
    of filters keeps its members' relative scale. Array-likes are accepted.
    A constant input has zero spread; it is returned centred (all zeros)
    instead of being divided by zero into NaNs.
    """
    arr = np.asarray(arr)
    centred = arr - arr.mean()
    spread = arr.std()
    return centred / spread if spread > 0 else centred

In [None]:
# Average each kernel's pooled maps over all eights, then normalize the stack.
filts8 = normalize(np.array([pooled.mean(axis=0) for pooled in pool8]))

In [None]:
# One averaged, normalized template per kernel for the digit 8.
plots(filts8)

In [None]:
# Same pipeline as for the eights: for each kernel, the pooled correlation maps
# of every "1" (one image per `img`), averaged per kernel and normalized jointly.
pool1 = [np.array([pool(correlate(img, rot)) for img in ones]) for rot in rots]
filts1 = normalize(np.array([pooled.mean(axis=0) for pooled in pool1]))

In [None]:
# One averaged, normalized template per kernel for the digit 1.
plots(filts1)

In [None]:
def pool_corr(img):
    """Return the max-pooled correlation map of `img` for every kernel in `rots`,
    stacked into one array with the kernel index first."""
    per_kernel = [pool(correlate(img, kernel)) for kernel in rots]
    return np.array(per_kernel)

In [None]:
# Pooled kernel responses of a single eight, to compare with the templates.
plots(pool_corr(eights[0]))

In [None]:
def sse(a, b):
    """Sum of squared element-wise differences between two array-likes."""
    diff = np.asarray(a) - np.asarray(b)
    return (diff ** 2).sum()

def is8_n2(img):
    """Classify `img` as 8 (returns 1) or not (returns 0).

    Returns 1 when the image's pooled kernel responses are strictly closer, in
    sum of squared errors, to the 8 templates (`filts8`) than to the 1
    templates (`filts1`); ties count as 0.
    """
    # pool_corr runs one correlation + pool per kernel; compute it only once.
    features = pool_corr(img)
    return int(sse(features, filts1) > sse(features, filts8))

In [None]:
# Distance of the first eight's pooled features to each class template; it is
# expected to be closer to (smaller SSE against) the 8 templates.
sse(pool_corr(eights[0]), filts8), sse(pool_corr(eights[0]), filts1)

In [None]:
# How many eights, and how many ones, the classifier labels as 8.
[np.array([is8_n2(img) for img in imgs]).sum() for imgs in [eights, ones]]

In [None]:
# How many of each are labelled "not 8". NOTE(review): this re-runs the
# classifier on every image; it equals len(imgs) minus the previous cell's counts.
[np.array([1-is8_n2(img) for img in imgs]).sum() for imgs in [eights, ones]]