# Intro to Convolutions

## Set up

In [4]:
%matplotlib inline
import math,sys,os,numpy as np
from numpy.linalg import norm
from PIL import Image
from matplotlib import pyplot as plt, rcParams, rc
from imageio import imread
#from scipy.ndimage import imread
from skimage.measure import block_reduce
import pickle as pickle
from scipy.ndimage.filters import correlate, convolve
# Render matplotlib animations as HTML5 video when they are displayed.
rc('animation', html='html5')
# Default figure size in inches, as (width, height).
rcParams['figure.figsize'] = 3, 6
%precision 4
np.set_printoptions(precision=4, linewidth=100)

In [5]:
def plots(ims, interp=False, titles=None):
    """Draw several images in one row, all on a shared intensity scale.

    ims    -- sequence of images; converted with np.array, so the images are
              assumed to share one shape.
    interp -- when True, leave matplotlib's default interpolation in place;
              otherwise draw raw pixels (interpolation='none').
    titles -- optional per-image titles, indexed by image position.
    """
    stack = np.array(ims)
    # One global min/max so brightness is comparable across the panels.
    lo, hi = stack.min(), stack.max()
    interpolation = None if interp else 'none'
    count = len(stack)
    fig = plt.figure(figsize=(12, 24))
    for pos, im in enumerate(stack):
        ax = fig.add_subplot(1, count, pos + 1)
        if titles is not None:
            ax.set_title(titles[pos], fontsize=18)
        # The subplot just added is the current axes, so this draws into it.
        plt.imshow(im, interpolation=interpolation, vmin=lo, vmax=hi)

def plot(im, interp=False):
    """Show a single image in a new 3x6-inch figure.

    interp -- when True, leave matplotlib's default interpolation in place;
              otherwise draw raw pixels (interpolation='none').
    """
    plt.figure(figsize=(3, 6), frameon=True)
    interpolation = None if interp else 'none'
    plt.imshow(im, interpolation=interpolation)

# Make grayscale the default colormap for the images drawn below.
plt.gray()
# Close the current figure so this cell renders no plot
# (presumably plt.gray() implicitly creates one -- TODO confirm).
plt.close()

## MNIST Data

In [7]:
from sklearn.datasets import fetch_openml
# Earlier attempt, kept for reference: the 'MNIST original' name did not work here.
#mnist = fetch_openml('MNIST original')
# Fetch version 1 of the OpenML dataset 'mnist_784'; cache=True keeps the
# download on disk so later runs can reuse it.
mnist = fetch_openml('mnist_784',version=1, cache = True)

In [None]:
# List the fields the fetched dataset object exposes.
mnist.keys()

In [None]:
# Shapes of the 'data' and 'target' fields before any reshaping.
mnist['data'].shape, mnist['target'].shape

In [None]:
# fetch_openml may return pandas objects (newer scikit-learn) or NumPy arrays
# (older releases). np.asarray yields a plain ndarray either way; calling
# np.reshape directly on a DataFrame fails because the 3-D result cannot be
# wrapped back into a 2-D frame.
# -1 lets NumPy infer the number of images instead of hard-coding 70000.
images = np.asarray(mnist['data']).reshape(-1, 28, 28)
# Targets arrive as strings/categories; convert them to integer class labels.
labels = np.asarray(mnist['target']).astype(int)
n=len(images)
images.shape, labels.shape

In [None]:
# Scale pixel intensities by 1/255 (assumes 8-bit pixel values, 0-255).
# The guard keeps this cell idempotent: re-running it must not divide
# already-scaled data by 255 a second time.
if images.max() > 1:
    images = images/255

In [None]:
# Show the first image.
plot(images[0])

In [None]:
# Label of the first image.
labels[0]

In [None]:
# First five images, each titled with its label.
plots(images[:5], titles=labels[:5])

We can zoom in on part of the image:

In [None]:
# Crop of the first image: rows 0-13, columns 8-21.
plot(images[0,0:14, 8:22])

## Edge Detection

We will look at how to create an edge detector:

In [None]:
# 3x3 kernel for top edges. Under correlation each output pixel is the sum of
# the three pixels in its own row minus the sum of the three pixels in the row
# above, so the response is largest where a dark row sits directly above a
# bright one. The bottom row of the kernel contributes nothing.
top=[[-1,-1,-1],
     [ 1, 1, 1],
     [ 0, 0, 0]]

plot(top)

In [None]:
# Reusable index for rows 10-27, columns 3-12, so the same window can be
# inspected in the raw image and in its filtered version.
dims = np.index_exp[10:28:1,3:13]
images[0][dims]

In [None]:
# Cross-correlate the first image with the `top` kernel (the kernel is not flipped).
corrtop = correlate(images[0], top)

In [None]:
# The `dims` window of the filtered image.
corrtop[dims]

In [None]:
# Filter response over the whole image.
plot(corrtop)

In [None]:
# np.rot90 with k=1 rotates the kernel by 90 degrees counter-clockwise.
np.rot90(top, 1)

In [None]:
# Convolution flips the kernel before sliding it, so convolving with the
# kernel rotated by 180 degrees should reproduce the correlation result.
convtop = convolve(images[0], np.rot90(top,2))
plot(convtop)
# Numerical check of that equivalence; shown as the cell's output.
np.allclose(convtop, corrtop)

In [None]:
# The four 90-degree rotations of `top` (k = 0, 1, 2, 3).
straights=[np.rot90(top,i) for i in range(4)]
plots(straights)

In [None]:
# Diagonal edge kernel (the name presumably stands for "bottom-right" -- TODO confirm).
br=[[ 0, 0, 1],
    [ 0, 1,-1.5],
    [ 1,-1.5, 0]]

# The four 90-degree rotations of `br` (k = 0, 1, 2, 3).
diags = [np.rot90(br,i) for i in range(4)]
plots(diags)

In [None]:
# All eight kernels: the four straight-edge ones followed by the four diagonal ones.
rots = straights + diags
# Response of the first image to each kernel.
corrs = [correlate(images[0], rot) for rot in rots]
plots(corrs)

In [None]:
# Collect the images labelled 8 and the images labelled 1 as two Python lists.
eights = [im for im, label in zip(images, labels) if label == 8]
ones = [im for im, label in zip(images, labels) if label == 1]

In [None]:
# First five examples of each of the two digit classes.
plots(eights[:5])
plots(ones[:5])

In [None]:
def normalize(arr):
    """Standardize `arr`: subtract its overall mean and divide by its overall std."""
    centred = arr - arr.mean()
    return centred / arr.std()

In [None]:
def pool(im):
    """Max-pool `im` over non-overlapping 7x7 blocks (a 28x28 image becomes 4x4).

    NOTE(review): this helper and `pool8` were referenced but never defined,
    so a fresh Restart & Run All stopped here with a NameError. The 7x7
    max-pool is reconstructed from the otherwise unused `block_reduce`
    import -- confirm the block size.
    """
    return block_reduce(im, (7,7), np.max)

# Pooled response of every "8" image to each kernel: one array per kernel.
pool8 = [np.array([pool(correlate(im, rot)) for im in eights]) for rot in rots]
# Average each kernel's pooled responses over all eights, giving one template per kernel.
filts8 = np.array([ims.mean(axis=0) for ims in pool8])
# Standardize the whole stack of templates (overall mean 0, overall std 1).
filts8 = normalize(filts8)

In [None]:
# The normalized per-kernel mean maps for the eights.
plots(filts8)

In [None]:
# Pooled response of every "1" image to each kernel: one array per kernel.
pool1 = [np.array([pool(correlate(one, kern)) for one in ones]) for kern in rots]
# Average each kernel's pooled responses over all ones, then standardize the stack.
filts1 = normalize(np.array([per_kernel.mean(axis=0) for per_kernel in pool1]))

In [None]:
# The normalized per-kernel mean maps for the ones.
plots(filts1)

In [None]:
def pool_corr(im):
    """Return the pooled response of `im` to every kernel in `rots`, as one array."""
    responses = []
    for kern in rots:
        responses.append(pool(correlate(im, kern)))
    return np.array(responses)

In [None]:
# Pooled kernel responses for the first eight, one panel per kernel.
plots(pool_corr(eights[0]))

In [None]:
def sse(a,b):
    """Sum of squared element-wise differences between `a` and `b`."""
    diff = a - b
    return (diff**2).sum()
def is8_n2(im):
    """Return 1 if `im` is closer to the 8 templates than to the 1 templates, else 0.

    Closeness is the sum of squared errors between the image's pooled kernel
    responses and `filts8` / `filts1`. A tie counts as "not 8".
    """
    # Compute the features once; the original evaluated pool_corr(im) twice.
    feats = pool_corr(im)
    return 1 if sse(feats, filts1) > sse(feats, filts8) else 0

In [None]:
# Squared-error distance of the first eight to the 8 templates and to the
# 1 templates; the first value should be the smaller one.
sse(pool_corr(eights[0]), filts8), sse(pool_corr(eights[0]), filts1)

In [None]:
# Images classified as 8: [among the eights (true positives), among the ones (false positives)].
[np.array([is8_n2(im) for im in ims]).sum() for ims in [eights,ones]]

In [None]:
# Images classified as not 8: [among the eights (false negatives), among the ones (true negatives)].
[np.array([(1-is8_n2(im)) for im in ims]).sum() for ims in [eights,ones]]

In [None]:
def n1(a,b):
    """Sum of absolute element-wise differences between `a` and `b`."""
    diff = a - b
    return np.fabs(diff).sum()
def is8_n1(im):
    """Return 1 if `im` is closer to the 8 templates than to the 1 templates, else 0.

    Closeness is the sum of absolute errors between the image's pooled kernel
    responses and `filts8` / `filts1`. A tie counts as "not 8".
    """
    # Compute the features once; the original evaluated pool_corr(im) twice.
    feats = pool_corr(im)
    return 1 if n1(feats, filts1) > n1(feats, filts8) else 0

In [None]:
# Images classified as 8 by the absolute-error rule: [among the eights, among the ones].
[np.array([is8_n1(im) for im in ims]).sum() for ims in [eights,ones]]

In [None]:
# Images classified as not 8 by the absolute-error rule: [among the eights, among the ones].
[np.array([(1-is8_n1(im)) for im in ims]).sum() for ims in [eights,ones]]