In [None]:
import math 
import math, sys, os , numpy as np
from numpy.linalg import norm
from matplotlib import pyplot as plt
import torch, torchvision
from torchvision import models, transforms, datasets

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Directory the MNIST files are stored in (relative to the notebook).
root_dir = './data/MNIST'
# Fetch MNIST into root_dir (download=True); as the cell's last expression the
# dataset object is also displayed.
torchvision.datasets.MNIST(root= root_dir, download=True)

In [None]:
## Creating training dataset
# ToTensor() converts each PIL image to a tensor when a sample is fetched; without a
# transform the DataLoader below cannot collate the PIL images into batches.
# The raw `train_set.data` / `train_set.targets` tensors are not affected by the transform.
train_set = torchvision.datasets.MNIST(root=root_dir, train=True, download=True,
                                       transform=transforms.ToTensor())
## Loading them into dataloader
MNIST_dataset = torch.utils.data.DataLoader(train_set, batch_size=1, shuffle=True, num_workers=4)




In [None]:
## Normalizing the images
# Convert the raw pixel tensor to float32 and divide by 255 (presumably mapping 0-255 to [0, 1]).
images = train_set.data.numpy().astype(np.float32)/255
# One label per image.
labels = train_set.targets.numpy()

## Lets check them out
print(images.shape, labels.shape) ## (num_imgs, H, W) (num_imgs,)





In [None]:
## OK Lets move onto the DataVisualization part
def plots(ims, interp=False, titles=None):
    """Show several images side by side in a single row, on a shared colour scale.

    ims: sequence (or array) of equally-shaped 2-D images.
    interp: when False (default) images are drawn with interpolation='none';
        when True, interpolation=None is passed so matplotlib uses its default.
    titles: optional sequence with one title per image.
    """
    ims = np.array(ims)
    # Shared vmin/vmax so intensities are comparable across the panels.
    mn, mx = ims.min(), ims.max()
    f = plt.figure(figsize=(12, 24))

    n_ims = len(ims)
    for i in range(n_ims):
        # add_subplot takes (nrows, ncols, index) with a 1-based index.
        sp = f.add_subplot(1, n_ims, i + 1)
        if titles is not None:
            sp.set_title(titles[i], fontsize=18)
        # Draw every image, whether or not titles were supplied.
        sp.imshow(ims[i], interpolation=None if interp else 'none', vmin=mn, vmax=mx)
        
        

def plot(im, interp=False):
    """Display a single image `im` in its own small figure.

    interp: when False (default) the image is drawn with interpolation='none';
        when True, interpolation=None is passed so matplotlib uses its default.
    """
    plt.figure(figsize=(3, 6), frameon=True)
    interpolation = None if interp else 'none'
    plt.imshow(im, interpolation=interpolation)


    
# Use the gray colormap for the image plots that follow.
plt.gray()
# NOTE(review): presumably closes a figure opened as a side effect of the call above — confirm.
plt.close()

    
    




    

In [None]:
# Show one image; the trailing tuple element makes the cell also display its label.
plot(images[5000]), labels[5000]

In [None]:
# Plot five consecutive images, passing their labels as titles.
plots(images[5000:5005], titles=labels[5000:5005])


In [None]:
## Lets build a simple classifier
n = len(images)
print(n)

# Split out the images labelled 8 and the images labelled 1 (order preserved).
eights = [img for img, digit in zip(images, labels) if digit == 8]
ones = [img for img, digit in zip(images, labels) if digit == 1]

print(len(eights), len(ones))



In [None]:
# Plot the first five examples of each class.
plots(eights[:5])
plots(ones[:5])

In [None]:
# Pixel-wise mean image of the eights from index 1000 onwards.
raws8 = np.mean(eights[1000:], axis=0) ## Keeping first 1000 for test purpose and averaging the rest
plot(raws8)

In [None]:
# Pixel-wise mean image of the ones from index 1000 onwards.
raws1 = np.mean(ones[1000:], axis=0) ## Keeping first 1000 for test purpose and averaging the rest
plot(raws1)

In [None]:
## sum of squared distance
def sse(a,b):
    """Return the sum of squared element-wise differences between `a` and `b`."""
    delta = a - b
    squared = delta ** 2
    return squared.sum()

## return 1 if closest to 8 and 0 otherwise
def is8_raw_n2(im):
    """Classify `im` against the raw mean images.

    Returns 1 when `im` is strictly nearer (by SSE) to `raws8` than to `raws1`, else 0.
    """
    dist_to_one = sse(im, raws1)
    dist_to_eight = sse(im, raws8)
    if dist_to_one > dist_to_eight:
        return 1
    return 0


In [None]:
# Count the predictions on the first 1000 images of each class.
# nb_X_predicted_Y = number of true-X images classified as Y.
nb_8_predicted_8 = np.array([is8_raw_n2(im) for im in eights[:1000]]).sum()
nb_1_predicted_8 = np.array([is8_raw_n2(im) for im in ones[:1000]]).sum()

nb_8_predicted_1 = np.array([1 - is8_raw_n2(im) for im in eights[:1000]]).sum()
nb_1_predicted_1 = np.array([1 - is8_raw_n2(im) for im in ones[:1000]]).sum()

# just to check
print(nb_8_predicted_1+nb_8_predicted_8, nb_1_predicted_1+nb_1_predicted_8)

In [None]:
def compute_scores(nb_8_predicted_8,nb_8_predicted_1,nb_1_predicted_1,nb_1_predicted_8):
    """Return (precision_8, recall_8, precision_1, recall_1) from four confusion counts.

    Each nb_X_predicted_Y argument is the number of true-X samples predicted as Y.
    """
    predicted_as_8 = nb_8_predicted_8 + nb_1_predicted_8
    actual_8 = nb_8_predicted_1 + nb_8_predicted_8
    predicted_as_1 = nb_1_predicted_1 + nb_8_predicted_1
    actual_1 = nb_1_predicted_1 + nb_1_predicted_8
    return (
        nb_8_predicted_8 / predicted_as_8,
        nb_8_predicted_8 / actual_8,
        nb_1_predicted_1 / predicted_as_1,
        nb_1_predicted_1 / actual_1,
    )

# Scores of the raw-mean-image classifier; kept under these names as the baseline.
Precision_8, Recall_8, Precision_1, Recall_1 = compute_scores(nb_8_predicted_8,nb_8_predicted_1,nb_1_predicted_1,nb_1_predicted_8)

print('precision 8:', Precision_8, 'recall 8:', Recall_8)
print('precision 1:', Precision_1, 'recall 1:', Recall_1)
# NOTE(review): this is the mean of the two recalls; it equals plain accuracy only
# when both test sets have the same size — confirm each slice holds 1000 images.
print('accuracy :', (Recall_1+Recall_8)/2)

In [None]:
## Filters and Convolutions
# 3x3 filter: a row of -1, a row of +1, then a row of 0.
top = [[-1,-1,-1],[1,1,1],[0,0,0]]
plot(top)


# 28x28 test image: main diagonal plus a 4-pixel-wide horizontal and vertical bar.
cross = np.zeros((28,28))

cross += np.eye(28)
for i in range(4):
    # Rows and columns 12..15 are set to ones.
    cross[12+i,:] = np.ones(28)
    cross[:, 12+i] = np.ones(28)

plot(cross)















In [None]:
# Import from `scipy.ndimage` directly: the `scipy.ndimage.filters` namespace is deprecated.
from scipy.ndimage import convolve, correlate
# Correlate the test image with the `top` filter, using the default border handling.
corr_cross = correlate(cross, top)
plot(corr_cross)

In [None]:
# NOTE(review): leftover interactive help lookup for `correlate`; consider removing it from the final notebook.
?correlate

In [None]:
# Same correlation, but with mode='constant' for the border handling
# (per the SciPy docs this pads with `cval`, 0.0 by default, instead of reflecting).
corr_cross = correlate(cross, top, mode='constant')
plot(corr_cross)

In [None]:
# The `top` filter rotated by one quarter turn; displayed as the cell output.
np.rot90(top,1)

In [None]:
# Correlate one digit image with the `top` filter.
corrtop = correlate(images[5000], top)
plot(corrtop)

# Convolve the same image with the filter rotated by two quarter turns (180 degrees);
# the allclose below checks whether this matches the correlation result.
convtop = convolve(images[5000], np.rot90(top, 2))
plot(convtop)
np.allclose(convtop, corrtop)

In [None]:
# The `top` filter at 0, 1, 2 and 3 quarter turns.
straights=[np.rot90(top,i) for i in range(4)]
plots(straights)

In [None]:
# 3x3 filter with positive weights on the anti-diagonal and -1.5 just below/right of it.
br=[[ 0, 0, 1],
    [ 0, 1,-1.5],
    [ 1,-1.5, 0]]

# The `br` filter at 0, 1, 2 and 3 quarter turns.
diags = [np.rot90(br,i) for i in range(4)]
plots(diags)

In [None]:
# Full filter bank: the four `straights` followed by the four `diags` (list concatenation).
rots = straights + diags
# Response of the test image to each filter in the bank.
corrs_cross = [correlate(cross, rot) for rot in rots]
plots(corrs_cross)

In [None]:
# Same filter bank as built from `straights` and `diags` (re-assigned here).
rots = straights + diags
# Response of one digit image to each filter in the bank.
corrs = [correlate(images[5000], rot) for rot in rots]
plots(corrs)

In [None]:

import skimage

from skimage.measure import block_reduce

def pool(im):
    """Reduce `im` block-wise with skimage's block_reduce, taking the max of each 7x7 block."""
    return block_reduce(im, block_size=(7, 7), func=np.max)

# Pool each filter response of the sample digit and plot the results.
plots([pool(im) for im in corrs])

In [None]:
# For each filter in `rots`: an array holding the pooled response of every eight in eights[1000:].
pool8 = [np.array([pool(correlate(im, rot)) for im in eights[1000:]]) for rot in rots]


In [None]:
# Number of filters, and the shape of one filter's stack of pooled maps.
len(pool8), pool8[0].shape


In [None]:
# Pooled responses of the first filter for the first five images in pool8.
plots(pool8[0][0:5])


In [None]:
# One row per image (the first four in pool8): its pooled response under each of the 8 filters.
plots([pool8[i][0] for i in range(8)])
plots([pool8[i][1] for i in range(8)])
plots([pool8[i][2] for i in range(8)])
plots([pool8[i][3] for i in range(8)])

In [None]:
def normalize(arr):
    """Standardise `arr`: subtract its overall mean and divide by its overall standard deviation."""
    centered = arr - arr.mean()
    return centered / arr.std()


In [None]:
# Average each filter's pooled responses over the images in pool8, then standardise the stack.
filts8 = np.array([ims.mean(axis=0) for ims in pool8])
filts8 = normalize(filts8)
# filts8 holds one map per filter, so draw it with plots() (one panel per map);
# plot() would hand the whole stack to a single imshow call.
plots(filts8)

In [None]:
# Same pipeline for the ones: pooled response of every image in ones[1000:] for each filter,
# averaged per filter and then standardised.
pool1 = [np.array([pool(correlate(im, rot)) for im in ones[1000:]]) for rot in rots]
filts1 = np.array([ims.mean(axis=0) for ims in pool1])
filts1 = normalize(filts1)

In [None]:
# Plot the averaged, standardised map of each filter for the ones.
plots(filts1)
def pool_corr(im):
    """Correlate `im` with every filter in `rots`, pool each response, and stack them into one array."""
    pooled = []
    for rot in rots:
        pooled.append(pool(correlate(im, rot)))
    return np.array(pooled)
# Pooled filter responses of a single image, eights[1000].
plots(pool_corr(eights[1000]))
# check: pool8[i][0] was computed from the same image (pool8 starts at eights[1000]),
# so it is compared with the pool_corr output
plots([pool8[i][0] for i in range(8)])
np.allclose(pool_corr(eights[1000]),[pool8[i][0] for i in range(8)])

In [None]:
def is8_n2(im):
    """Classify `im` from its pooled filter responses.

    Returns 1 when the responses are strictly nearer (by SSE) to `filts8` than to `filts1`, else 0.
    """
    # Compute the filter-bank features once; they are needed for both distances.
    features = pool_corr(im)
    return 1 if sse(features, filts1) > sse(features, filts8) else 0


In [None]:
# Distances of one eight (eights[0]) to the 8-template and to the 1-template.
sse(pool_corr(eights[0]), filts8), sse(pool_corr(eights[0]), filts1)


In [None]:
# Show the image used in the distance comparison.
plot(eights[0])


In [None]:
# Count the filter-based predictions on the first 1000 images of each class.
# nb_X_predicted_Y = number of true-X images classified as Y.
nb_8_predicted_8 = np.array([is8_n2(im) for im in eights[:1000]]).sum()
nb_1_predicted_8 = np.array([is8_n2(im) for im in ones[:1000]]).sum()

nb_8_predicted_1 = np.array([1 - is8_n2(im) for im in eights[:1000]]).sum()
nb_1_predicted_1 = np.array([1 - is8_n2(im) for im in ones[:1000]]).sum()

In [None]:
# Scores of the filter-based classifier, stored under separate names so the baseline
# values (Recall_1, Recall_8) remain available for the comparison below.
Precisionf_8, Recallf_8, Precisionf_1, Recallf_1 = compute_scores(nb_8_predicted_8,nb_8_predicted_1,nb_1_predicted_1,nb_1_predicted_8)

print('precision 8:', Precisionf_8, 'recall 8:', Recallf_8)
print('precision 1:', Precisionf_1, 'recall 1:', Recallf_1)
# NOTE(review): both "accuracy" figures are the mean of the two recalls; this equals plain
# accuracy only when both test sets have the same size — confirm.
print('accuracy :', (Recallf_1+Recallf_8)/2)
print('accuracy baseline:', (Recall_1+Recall_8)/2)