In [1]:
import scipy.io
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
# Load the contents of 'mnist_all.mat'; later cells index the result with the keys 'train0' ... 'train9'.
mat = scipy.io.loadmat('mnist_all.mat')

In [2]:
# Disabled draft of an inner_product helper. It is wrapped in a bare string
# literal, so none of it is executed; the cell only echoes the string back.
""" 
def inner_product(X, Y, W, p):
    if X.shape[1] != Y.shape[0]:
        raise Exception('Mismatched shape: \n Shape of X = ' + str(X.shape) + '\n' + 'Shape of Y = ' + str(Y.shape))
    
    if(p==2):
        return np.dot(X, Y)

    s = 0

    for i in range(X.shape[0]):
        n_ij = np.power(np.abs(W[i]), p-2)
        s += n_ij*X[i]*Y[i]
    return s """

" \ndef inner_product(X, Y, W, p):\n    if X.shape[1] != Y.shape[0]:\n        raise Exception('Mismatched shape: \n Shape of X = ' + str(X.shape) + '\n' + 'Shape of Y = ' + str(Y.shape))\n    \n    if(p==2):\n        return np.dot(X, Y)\n\n    s = 0\n\n    for i in range(X.shape[0]):\n        n_ij = np.power(np.abs(W[i]), p-2)\n        s += n_ij*X[i]*Y[i]\n    return s "

In [3]:
# Layer sizes used for the first weight matrix: 784 inputs, 2000 hidden units.
hidden = 2000
inp = 784

# NOTE: the train_batch(ex[:], W) call further down uses the function's default
# arguments rather than these two globals; a later cell reassigns both.
lr = 1e-1
delta = 0.4

# The next two expression statements are not the last line of the cell, so
# their values are discarded and nothing is displayed for them.
list(mat.keys())
mat['train0']
# Scale the 'train0' array by 1/255.
ex = (mat['train0']/255.0)
print(ex.shape)

# Initial weights drawn from a normal distribution (mean 0, std 1), shape
# (inp, hidden). No random seed is set, so the values differ between runs.
W = np.random.normal(0.0, 1.0, (inp, hidden))



(5923, 784)


In [4]:
# Figure handle that later cells pass to draw_weights as its `fig` argument.
fig=plt.figure(figsize=(12.9,10))


def draw_weights(synapses, Kx, Ky, scale_restriction, fig):
    """Show the first ``Kx * Ky`` rows of ``synapses`` as a grid of 28x28 tiles.

    Parameters
    ----------
    synapses : 2-D array
        Each row holds 784 values. Row ``y * Kx + x`` is reshaped to 28x28 and
        placed at grid row ``y``, grid column ``x``. At least ``Kx * Ky`` rows
        are required, otherwise indexing raises ``IndexError``.
    Kx, Ky : int
        Number of tiles across and down.
    scale_restriction : float
        The colour map is limited to ``[-scale_restriction, scale_restriction]``.
    fig : matplotlib figure
        Figure to draw into. It is used when pyplot still manages it; if it
        has been closed in the meantime, pyplot's current figure (created on
        demand) is used instead.

    The image, the colorbar and the canvas redraw all use the same figure, so
    the colorbar is never attached to a figure other than the one holding the
    image.
    """
    yy=0
    HM=np.zeros((28*Ky,28*Kx))
    for y in range(Ky):
        for x in range(Kx):
            HM[y*28:(y+1)*28,x*28:(x+1)*28]=synapses[yy,:].reshape(28,28)
            yy += 1

    # Pick ONE figure for every drawing step. The passed-in figure is used
    # only if pyplot still tracks it under its own number; a closed figure's
    # number may have been reused, hence the identity check.
    fig_num = getattr(fig, 'number', None)
    if fig_num is not None and plt.fignum_exists(fig_num) and plt.figure(fig_num) is fig:
        target = fig
    else:
        target = plt.gcf()

    target.clf()
    ax = target.add_subplot(1, 1, 1)
    im = ax.imshow(HM,cmap='bwr',vmin=-scale_restriction,vmax=scale_restriction)
    # Ticks mark the actual extremes of the tiled weights plus zero.
    target.colorbar(im,ax=ax,ticks=[np.amin(HM), 0, np.amax(HM)])
    ax.axis('off')
    target.canvas.draw()
    plt.show()

    

<Figure size 928.8x720 with 0 Axes>

In [5]:
def inner_product_W(X, W, p):
    """Return the hidden-unit input currents for a batch.

    For ``p == 2`` this is the plain matrix product ``X @ W``. For any other
    ``p`` the weights are first replaced by ``sign(W) * |W| ** (p - 1)``.

    X is an (m x n) matrix of inputs and W an (n x hidden) matrix of weights,
    so the result is (m x hidden).
    """
    if(p==2):
        return np.dot(X, W)
    else:
        return np.dot(X, np.sign(W)*np.power(np.absolute(W), p-1))


def train_batch(inputs: np.ndarray, W: np.ndarray, p = 2, k = 2 , delta = 0.4, lr = 0.01):
    """Apply one weight update to ``W`` from a batch of inputs.

    Parameters
    ----------
    inputs : (m x n) matrix, one training example per row.
    W : (n x hidden) weight matrix. It is updated IN PLACE and also returned.
    p : Lebesgue-p norm number used by ``inner_product_W``.
    k : rank (counted from the largest) of the hidden state that is
        discouraged for each example. It must not exceed ``hidden``.
    delta : magnitude of the negative activation given to the rank-k unit.
    lr : step size; the largest absolute entry of the applied change is ``lr``.

    Returns
    -------
    The same array object that was passed in as ``W``.

    If the raw update is identically zero (for example an all-zero or empty
    batch), ``W`` is left unchanged instead of being filled with NaN by a
    0/0 division.
    """
    m = inputs.shape[0] # Number of training examples in batch.
    hidden = W.shape[1] # W is a (n x hidden) matrix.

    # (m x hidden) hidden states before the activation function is applied.
    h1 = inner_product_W(inputs, W, p)

    # Per-example indexes of the hidden states in ascending order.
    h1_sorted = np.argsort(h1, axis=1)

    rows = np.arange(m)
    a1 = np.zeros_like(h1) # Activations, (m x hidden).
    a1[rows, h1_sorted[:, -1]] = 1.0 # Strongest unit of each example.
    a1[rows, h1_sorted[:, -k]] = -delta # Rank-k unit of each example (overwrites the line above when k == 1).

    gQv = np.dot(np.transpose(inputs), a1) # First term of equation 3, (n x hidden).

    gWv = np.reshape(np.sum(np.multiply(h1, a1), axis=0), (1, hidden)) # (1 x hidden)

    gWvW = np.multiply(W, gWv) # Second term in equation 3.

    dW = gQv - gWvW # Raw change in parameters.

    # Normalise by the largest absolute entry so the step size is set by lr
    # alone. The floor keeps an all-zero update from dividing 0 by 0.
    prec = 1e-30
    nc = np.max(np.absolute(dW))
    if nc < prec:
        nc = prec

    W += lr*np.divide(dW, nc)

    return W

# Smoke test: run one update that uses every row of `ex` as a single batch with
# train_batch's default hyper-parameters. train_batch updates W in place, so
# the discarded return value is not needed. The shape is printed before and
# after the call.
print(W.shape)
train_batch(ex[:], W)
print(W.shape)

(784, 2000)
(784, 2000)


In [30]:
# Training configuration.
batch_size = 50
epochs = 20
reset = True   # When True, W is re-initialised below instead of continuing from its current value.

hidden = 200
inp = 784

# Arguments forwarded to train_batch.
k=2
lr = 1e-2
delta = 0.4
p=2

# Grid size used when drawing the weights after each epoch.
Kx=4
Ky=4


if reset:
    W = np.random.normal(0.0, 1.0, (inp, hidden))

# Transposed copies of W, one recorded every 30 batches.
synapse_history = []

# Stack the ten arrays 'train0' ... 'train9' into one matrix and apply the
# same 1/255 scaling the notebook uses when it builds `ex`.
M = np.concatenate([mat['train'+str(i)] for i in range(10)], axis=0)/255.0

print(M.shape)

n_batches = M.shape[0]//batch_size  # Trailing rows that do not fill a batch are skipped.

for epoch in range(epochs):
    print('Epoch: '+str(epoch + 1))
    # M is ordered key by key ('train0' rows first, then 'train1', ...), so
    # slicing it sequentially would make almost every batch come from a
    # single key. Draw a fresh row order each epoch to mix them.
    order = np.random.permutation(M.shape[0])
    for i in range(n_batches):
        batch = M[order[batch_size*i: batch_size*(i+1)], :]
        W = train_batch(batch, W, p=p, k=k, delta=delta, lr=lr)
        if i%30==0:
            synapse_history.append(np.transpose(np.copy(W)))
    draw_weights(np.transpose(W), Kx, Ky, 0.5, fig=fig)
    
        
        

(60000, 784)
Epoch: 1
Epoch: 2
Epoch: 3
Epoch: 4
Epoch: 5
Epoch: 6
Epoch: 7
Epoch: 8
Epoch: 9
Epoch: 10
Epoch: 11
Epoch: 12
Epoch: 13
Epoch: 14
Epoch: 15
Epoch: 16
Epoch: 17
Epoch: 18
Epoch: 19
Epoch: 20


In [32]:
%matplotlib notebook
# Redraw the current weights: the first Kx*Ky rows of W transposed are shown
# as a Ky-by-Kx grid, with the colour scale limited to [-0.5, 0.5].
Kx=4
Ky=4
draw_weights(np.transpose(W), Kx, Ky, 0.5, fig)

<IPython.core.display.Javascript object>

  fig.colorbar(im,ticks=[np.amin(HM), 0, np.amax(HM)])


In [33]:
# Create a figure and axis
fig, ax = plt.subplots()


def _history_mosaic(frame):
    """Tile the first Kx*Ky rows of synapse_history[frame] into one 2-D image."""
    yy=0
    HM=np.zeros((28*Ky,28*Kx))
    for y in range(Ky):
        for x in range(Kx):
            HM[y*28:(y+1)*28,x*28:(x+1)*28]=synapse_history[frame][yy,:].reshape(28,28)
            yy += 1
    return HM


# A single image artist is created on `ax` and reused for every frame, so the
# animation always draws into this figure and does not rebuild its axes.
im = ax.imshow(np.zeros((28*Ky,28*Kx)),cmap='bwr',vmin=-0.5,vmax=0.5)


def update_image(frame):
    """FuncAnimation callback: show snapshot number `frame` of synapse_history."""
    im.set_data(_history_mosaic(frame))
    return [im]


# Create the animation (the reference is kept in a variable so it stays alive).
animation = FuncAnimation(fig, update_image, frames=len(synapse_history), interval=2)
print(len(synapse_history))
# Display the animation

plt.show()

<IPython.core.display.Javascript object>

1200
