### This code illustrates the fast AI implementation of the unsupervised "biological" learning algorithm from [Unsupervised Learning by Competing Hidden Units](https://doi.org/10.1073/pnas.1820458116) on MNIST data set. 
If you want to learn more about this work you can also check out this [lecture](https://www.youtube.com/watch?v=4lY-oAY0aQU) from MIT's [6.S191 course](http://introtodeeplearning.com/). 

This cell loads the data and normalizes it to the [0,1] range

In [16]:
import scipy.io
import numpy as np
import matplotlib.pyplot as plt
mat = scipy.io.loadmat('mnist_all.mat')

Nc=10
N=784
Ns=60000

mat = scipy.io.loadmat('mnist_all.mat')

def read_mnist(mat, label, num_classes=10):
    X = np.zeros((0, 784))
    Y = np.zeros((0, num_classes))
    for i in range(num_classes):
        data = mat[label + str(i)] # [:500]

        X = np.concatenate((X, data), axis=0)

        Yi = np.zeros((data.shape[0], num_classes))
        Yi[:, i] = 1.0
        Y = np.concatenate((Y, Yi), axis=0)

    X = X / 255.0
    return X.astype(np.float32), Y.astype(np.float32)

M, MY = read_mnist(mat, "train", Nc)
Mt, MYt = read_mnist(mat, "test", Nc)

In [17]:
import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, callbacks

def train_nn(
    X, 
    Y, 
    Xt, 
    Yt, 
    hidden_layer_size=None, 
    epochs=300,
    init_learning_rate=0.02,
    verbose=0
):
    def read_mnist_tf(X, Y):
        return (
            tf.data.Dataset.from_tensor_slices((X,Y))
                .shuffle(X.shape[0])
                .batch(100)
        )
    
    def scheduler(epoch):
        return init_learning_rate * (1.0 - epoch / epochs)
        
    train_dataset = read_mnist_tf(X, Y)
    test_dataset = read_mnist_tf(Xt, Yt)
    
    if hidden_layer_size is not None:
        model = Sequential([
            layers.Dense(hidden_layer_size, activation='relu'), 
            layers.Dense(Y.shape[1])
        ])
    else:
        model = Sequential([
            layers.Dense(Y.shape[1])
        ])

    model.compile(
        optimizer=optimizers.SGD(init_learning_rate),
        loss=tf.losses.CategoricalCrossentropy(from_logits=True),
        metrics=['accuracy']
    )
    model.fit(
        train_dataset, epochs=epochs,
        validation_data=test_dataset,
        validation_steps=2,
        callbacks=[
            callbacks.EarlyStopping(monitor='val_loss', patience=10),
            callbacks.LearningRateScheduler(scheduler)
        ],
        verbose=verbose
    )    
    return max(model.history.history["val_accuracy"])

def draw_weights(synapses, Kx, Ky):
    yy=0
    HM=np.zeros((28*Ky,28*Kx))
    for y in range(Ky):
        for x in range(Kx):
            HM[y*28:(y+1)*28,x*28:(x+1)*28]=synapses[yy,:].reshape(28,28)
            yy += 1
    plt.clf()
    nc=np.amax(np.absolute(HM))
    im=plt.imshow(HM,cmap='bwr',vmin=-nc,vmax=nc)
    fig.colorbar(im,ticks=[np.amin(HM), 0, np.amax(HM)])
    plt.axis('off')
    fig.canvas.draw()   
    
def weights_inner_product(W, p):
    return np.asarray([np.linalg.norm(W[:, ni], ord=p) for ni in range(W.shape[1])])


This cell defines paramaters of the algorithm: `eps0` - initial learning rate that is linearly annealed during training; `hid` - number of hidden units that are displayed as an `Ky` by `Kx` array by the helper function defined above; `mu` - the mean of the gaussian distribution that initializes the weights; `sigma` - the standard deviation of that gaussian; `Nep` - number of epochs; `Num` - size of the minibatch; `prec` - parameter that controls numerical precision of the weight updates; `delta` - the strength of the anti-hebbian learning; `p` - Lebesgue norm of the weights; `k` - ranking parameter. 

In [18]:
eps0=2e-2    # learning rate
Kx=45        # so number of hidden units around 2000 (2025)
Ky=45
hid=Kx*Ky    # number of hidden units that are displayed in Ky by Kx array
mu=0.0
sigma=1.0
Nep=200      # number of epochs
Num=101      # size of the minibatch
prec=1e-30
delta=0.4    # Strength of the anti-hebbian learning
p=2.0        # Lebesgue norm of the weights
k=2          # ranking parameter, must be integer that is bigger or equal than 2

### Baseline performance of 1-layer NN configuration

Let's learn NN with MNIST data with just 1-layer NN configration to see the peformance of a network w/o any additional transformations

In [19]:
train_nn(M, MY, Mt, MYt, hidden_layer_size=None, epochs=300, init_learning_rate=0.02, verbose=1)

Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 31/300
Epoch 32/300
Epoch 33/300
Epoch 34/300
Epoch 35/300
Epoch 36/300
Epoch 37/300
Epoch 38/300
Epoch 39/300
Epoch 40/300
Epoch 41/300


0.96

### Baseline performance of multilayered NN configuration, 2025 hidden units

In [20]:
train_nn(M, MY, Mt, MYt, hidden_layer_size=hid, epochs=300, init_learning_rate=0.02, verbose=1)

Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 31/300
Epoch 32/300
Epoch 33/300
Epoch 34/300
Epoch 35/300
Epoch 36/300


0.985

### Baseline of BCM network without weight changes

In [23]:
synapses = np.random.normal(mu, sigma, (hid, N))

In [24]:
Ma = np.zeros((M.shape[0], hid), dtype=np.float32)
Mta = np.zeros((Mt.shape[0], hid), dtype=np.float32)

for i in range(Ns // Num):
    inputs=np.transpose(M[i*Num:(i+1)*Num,:])
    sig=np.sign(synapses)
    tot_input=np.dot(sig*np.absolute(synapses)**(p-1),inputs)
    Ma[i*Num:(i+1)*Num,:] = tot_input[:].T

for i in range(Mt.shape[0] // Num):
    inputs=np.transpose(Mt[i*Num:(i+1)*Num,:])
    sig=np.sign(synapses)
    tot_input=np.dot(sig*np.absolute(synapses)**(p-1),inputs)
    Mta[i*Num:(i+1)*Num,:] = tot_input[:].T

train_nn(Ma, MY, Mta, MYt, verbose=1)  # no hidden layer by default

Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300


0.93

This cell defines the main code. The external loop runs over epochs `nep`, the internal loop runs over minibatches. For every minibatch the overlap with the data `tot_input` is calculated for each data point and each hidden unit. The sorted strengths of the activations are stored in `y`. The variable `yl` stores the activations of the post synaptic cells - it is denoted by g(Q) in Eq 3 of [Unsupervised Learning by Competing Hidden Units](https://doi.org/10.1073/pnas.1820458116), see also Eq 9 and Eq 10. The variable `ds` is the right hand side of Eq 3. The weights are updated after each minibatch in a way so that the largest update is equal to the learning rate `eps` at that epoch. The weights are displayed by the helper function after each epoch. 

In [25]:
for nep in range(Nep):
    eps=eps0*(1-nep/Nep)
    perm = np.random.permutation(Ns)
    M=M[perm,:]
    MY=MY[perm,:]
    
    for i in range(Ns // Num):
        inputs=np.transpose(M[i*Num:(i+1)*Num,:])
        sig=np.sign(synapses)
        tot_input=np.dot(sig*np.absolute(synapses)**(p-1),inputs)
        
        y=np.argsort(tot_input,axis=0)
        yl=np.zeros((hid,Num))
        yl[y[hid-1,:],np.arange(Num)]=1.0
        yl[y[hid-k],np.arange(Num)]=-delta
        
        xx=np.sum(np.multiply(yl,tot_input),1)
        ds=np.dot(yl,np.transpose(inputs)) - np.multiply(np.tile(xx.reshape(xx.shape[0],1),(1,N)),synapses)
        
        nc=np.amax(np.absolute(ds))
        if nc<prec:
            nc=prec
        synapses += eps*np.true_divide(ds,nc)
        Ma[i*Num:(i+1)*Num,:] = tot_input[:].T
        
        
    for i in range(Mt.shape[0] // Num):
        inputs=np.transpose(Mt[i*Num:(i+1)*Num,:])
        sig=np.sign(synapses)
        tot_input=np.dot(sig*np.absolute(synapses)**(p-1),inputs)
        Mta[i*Num:(i+1)*Num,:] = tot_input[:].T

    if nep % (Nep // 100) == 0 or nep == Nep-1:        
        res = train_nn(Ma, MY, Mta, MYt)
        print("Epoch {}, min(W) = {:.4f}, max(W) = {:.4f}, |W| = {:.4f}, acc = {:.4f}".format(
            nep, 
            np.min(synapses),
            np.max(synapses),
            np.mean(weights_inner_product(synapses.T, 2)),
            res
        ))
        

Epoch 0, min(W) = -5.1240, max(W) = 4.6549, |W| = 26.9375, acc = 0.9250
Epoch 2, min(W) = -5.1285, max(W) = 4.6594, |W| = 24.3657, acc = 0.9350
Epoch 4, min(W) = -5.0237, max(W) = 4.6204, |W| = 21.6753, acc = 0.9150
Epoch 6, min(W) = -4.6638, max(W) = 4.5443, |W| = 19.0511, acc = 0.9300
Epoch 8, min(W) = -4.5552, max(W) = 4.2885, |W| = 16.5440, acc = 0.9250
Epoch 10, min(W) = -4.1431, max(W) = 3.8233, |W| = 14.1823, acc = 0.9250


KeyboardInterrupt: 