# Biological Learning 

## Unsupervised learning part
### This cell loads the MNIST training images from `mnist_all.mat` and normalizes the pixel values to the [0, 1] range

In [None]:
import scipy.io
import numpy as np
import matplotlib.pyplot as plt
mat = scipy.io.loadmat('mnist_all.mat')

Nc=10       # number of digit classes; the file holds one array per digit: 'train0' ... 'train9'
N=784       # pixels per flattened 28 x 28 image

# Stack all per-digit training arrays with a single concatenate; growing M inside a
# loop re-copies every row gathered so far on each iteration.
M = np.concatenate([mat['train'+str(i)] for i in range(Nc)], axis=0)/255.0
# Number of training images, taken from the data instead of hard-coded
# (60000 for the standard MNIST training split).
Ns = M.shape[0]
assert M.shape[1] == N, "expected flattened 28 x 28 images"

To draw the weights of the hidden units as a heatmap, we define a helper function:

In [None]:
def draw_weights(synapses, Kx, Ky, yy=None, fig=None):
    """Draw the weights of Kx*Ky consecutive hidden units as one heatmap.

    Each row of `synapses` is reshaped to a square image (28 x 28 for MNIST). The
    images of units yy, yy+1, ..., yy+Kx*Ky-1 are tiled row by row into a Ky-by-Kx
    grid and shown with a symmetric blue-white-red colormap.

    Parameters
    ----------
    synapses : ndarray, shape (hid, N)
        Weight matrix; N must be a perfect square.
    Kx, Ky : int
        Number of tiles per row / per column of the grid.
    yy : int, optional
        Index of the first unit to draw. When omitted, a random valid start is
        drawn on every call.
    fig : matplotlib Figure, optional
        Figure to draw into; defaults to the current figure (``plt.gcf()``).

    Raises
    ------
    ValueError
        If rows are not square images, or the requested block does not fit in
        ``synapses``.
    """
    n_units, n_pixels = synapses.shape
    side = int(round(np.sqrt(n_pixels)))
    if side*side != n_pixels:
        raise ValueError("each row of synapses must be a flattened square image")
    n_show = Kx*Ky
    if n_show > n_units:
        raise ValueError("cannot show %d units, synapses has only %d" % (n_show, n_units))
    if yy is None:
        # Drawn at call time (not once at definition time), and bounded so the whole
        # block of n_show units lies inside synapses.
        yy = np.random.randint(n_units - n_show + 1)
    if not 0 <= yy <= n_units - n_show:
        raise ValueError("yy=%d out of range [0, %d]" % (yy, n_units - n_show))
    if fig is None:
        fig = plt.gcf()

    # (Ky*Kx, side*side) -> (Ky, Kx, side, side) -> (Ky, side, Kx, side) -> mosaic:
    # unit yy + y*Kx + x ends up in tile row y, tile column x.
    HM = (synapses[yy:yy + n_show]
          .reshape(Ky, Kx, side, side)
          .transpose(0, 2, 1, 3)
          .reshape(Ky*side, Kx*side))

    fig.clf()
    ax = fig.add_subplot(1, 1, 1)
    nc = np.amax(np.absolute(HM))
    # Symmetric color limits, so that a zero weight is white in the 'bwr' colormap.
    im = ax.imshow(HM, cmap='bwr', vmin=-nc, vmax=nc)
    fig.colorbar(im, ax=ax, ticks=[np.amin(HM), 0, np.amax(HM)])
    ax.axis('off')
    fig.canvas.draw()
    

This cell defines the parameters of the algorithm: `eps0` - initial learning rate, which is linearly annealed during training; `Kx`, `Ky` - the helper function defined above displays hidden units as a `Ky` by `Kx` array; `hid` - total number of hidden units (only `Kx*Ky` of them are displayed at a time); `mu` - the mean of the Gaussian distribution that initializes the weights; `sigma` - the standard deviation of that Gaussian; `Nep` - number of epochs; `Num` - size of the minibatch; `prec` - parameter that controls the numerical precision of the weight updates (a lower bound on the factor the updates are divided by); `delta` - the strength of the anti-Hebbian learning; `p` - the order of the Lebesgue norm of the weights; `k` - ranking parameter: the hidden unit with the k-th strongest activation receives the anti-Hebbian update.

In [None]:
eps0=0.04    # initial learning rate, linearly annealed towards 0 over Nep epochs
Kx=10        # columns of the grid of hidden units shown by draw_weights
Ky=10        # rows of that grid; Kx*Ky units are shown, not all hid of them
hid=2000     # total number of hidden units
mu=0.0       # mean of the Gaussian weight initialization
sigma=1.0    # standard deviation of the Gaussian weight initialization
Nep=1000     # number of epochs
Num=100      # size of the minibatch
prec=1e-30   # lower bound on the update normalizer (avoids division by zero)
delta=0.4    # strength of the anti-Hebbian learning
p=3.0        # order of the Lebesgue norm of the weights
k=7          # ranking parameter: the k-th most active unit gets the anti-Hebbian update

seed=0       # random seed, so weight initialization and minibatch order are reproducible
np.random.seed(seed)

# The anti-Hebbian assignment is made after the Hebbian one; with k=1 it would overwrite
# the winning unit, so k must be an integer in [2, hid].
assert isinstance(k, int) and 2 <= k <= hid, "k must be an integer with 2 <= k <= hid"
assert Kx*Ky <= hid, "cannot display more hidden units than there are"

This cell defines the main code. The outer loop runs over epochs `nep`, and the inner loop runs over minibatches. For every minibatch, the overlap with the data, `tot_input`, is calculated for each data point and each hidden unit. The indices of the hidden units ranked by activation strength are stored in `y`. The variable `yl` stores the activations of the postsynaptic cells - it is denoted by g(Q) in Eq 3 of [Unsupervised Learning by Competing Hidden Units](https://doi.org/10.1073/pnas.1820458116), see also Eq 9 and Eq 10. The variable `ds` is the right-hand side of Eq 3. After each minibatch, the weights are updated so that the largest update equals the learning rate `eps` of that epoch. The weights are displayed by the helper function after each epoch.

In [None]:
%matplotlib inline
%matplotlib notebook
fig=plt.figure(figsize=(12.9,10))

synapses = np.random.normal(mu, sigma, (hid, N))
for nep in range(Nep):
    print(nep)
    eps=eps0*(1-nep/Nep)
    M=M[np.random.permutation(Ns),:]
    for i in range(Ns//Num):
        inputs=np.transpose(M[i*Num:(i+1)*Num,:])
        sig=np.sign(synapses)
        tot_input=np.dot(sig*np.absolute(synapses)**(p-1),inputs)
        
        y=np.argsort(tot_input,axis=0)
        yl=np.zeros((hid,Num))
        yl[y[hid-1,:],np.arange(Num)]=1.0
        yl[y[hid-k],np.arange(Num)]=-delta
        
        xx=np.sum(np.multiply(yl,tot_input),1)
        ds=np.dot(yl,np.transpose(inputs)) - np.multiply(np.tile(xx.reshape(xx.shape[0],1),(1,N)),synapses)
        
        nc=np.amax(np.absolute(ds))
        if nc<prec:
            nc=prec
        synapses += eps*np.true_divide(ds,nc)
        
    draw_weights(synapses, Kx, Ky)
        

In [None]:
# np.save appends the ".npy" extension, so this writes ./synapses.npy.
np.save(file="./synapses", arr=synapses)

Visualize the learned weights of a randomly chosen block of hidden units

In [None]:
#%matplotlib inline
%matplotlib notebook
fig=plt.figure(figsize=(10,10))
synapses = np.load("./synapses.npy")
draw_weights(synapses, 10, 10, yy=np.random.randint(1800))

## Supervised learning part

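### Baseline Model

For comparison, a network with one hidden layer of 2000 ReLU units and a softmax output layer is trained end to end with backpropagation.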

In [None]:
import tensorflow as tf
from tensorflow.keras.datasets import mnist
import tensorflow.keras as keras 
from tensorflow.keras import backend as K
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.callbacks import ReduceLROnPlateau



# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Flatten the 28 x 28 images to 784-vectors and scale pixels to [0, 1].
x_train = x_train.reshape(x_train.shape[0], 784).astype('float32') / 255.0
x_test = x_test.reshape(x_test.shape[0], 784).astype('float32') / 255.0

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

# Baseline: the hidden layer is trained end to end with backpropagation.
model = Sequential()
model.add(Dense(2000, input_dim=784, activation='relu'))
model.add(Dense(10, activation="softmax"))

# `learning_rate` replaces the deprecated `lr` keyword, which newer Keras optimizers
# warn about and ignore, or reject outright.
optimizer = keras.optimizers.Adam(learning_rate=0.0002)
model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=optimizer,
              metrics=['accuracy'])

# NOTE(review): val_loss is computed on the test set, so ReduceLROnPlateau tunes the
# learning-rate schedule on test data; a held-out part of the training set would avoid this.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1,
                              patience=15, min_lr=0.000001)
logdir = "logs/end_to_end/"
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)

# batch_size=100: 600 steps per epoch over the 60000 training images, and all 10000
# test images are evaluated in batches of 100.
model.fit(x_train, y_train,
          batch_size=100,
          epochs=300,
          validation_data=(x_test, y_test),
          verbose=1,
          callbacks=[reduce_lr, tensorboard_callback])

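### Bio learning model - training output layer

The hidden layer is initialized with the synapses learned above and kept frozen; only the softmax output layer is trained.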

In [None]:
import tensorflow as tf
from tensorflow.keras.datasets import mnist
import tensorflow.keras as keras 
from tensorflow.keras import backend as K
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.callbacks import ReduceLROnPlateau

synapses = np.load("./synapses.npy")
# The frozen hidden layer below needs one 784-pixel weight vector per hidden unit.
assert synapses.ndim == 2 and synapses.shape[1] == 784, synapses.shape

@tf.custom_gradient
def custom_activation(x):
    """Rectified linear activation, max(x, 0), with an explicit gradient.

    Used as the activation of the frozen hidden layer that holds the bio-learned
    synapses.

    Args:
        x: pre-activation tensor (floating dtype, any shape).

    Returns:
        The pair (output, grad_fn) required by ``tf.custom_gradient``. The gradient
        is ``dy`` where ``x > 0`` and 0 elsewhere, i.e. the derivative of the forward pass.
    """
    positive = x > 0.0
    # tf.where instead of keras.backend.switch, which is not available in Keras 3.
    out = tf.where(positive, x, tf.zeros_like(x))

    def grad(dy):
        # Masked so the gradient matches the forward pass. With the hidden layer frozen
        # (trainable=False), this gradient is not needed to train the output layer.
        return dy * tf.cast(positive, dy.dtype)

    return out, grad

def loss(labels, pred, m=6.):
    """Power loss: the mean over all entries of |pred - labels| ** m.

    Keras calls this as loss(y_true, y_pred). The exponent defaults to 6; larger
    values of m put more weight on the largest errors.
    """
    return tf.math.reduce_mean(tf.math.pow(tf.abs(pred - labels), m))

def one_hot(arr, nb_classes=10):
    """Encode integer labels as float rows with +1 at the true class and -1 elsewhere.

    `arr` is flattened first, so any shape of label array gives (n_labels, nb_classes).
    """
    labels = np.asarray(arr).reshape(-1)
    # Identity rows are 0/1; the affine map 2*e - 1 turns them into -1/+1.
    return 2.0 * np.eye(nb_classes)[labels] - 1.0

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Flatten the 28 x 28 images to 784-vectors and scale pixels to [0, 1], the same
# scaling as the data used for the unsupervised training.
x_train = x_train.reshape(x_train.shape[0], 784).astype('float32') / 255.0
x_test = x_test.reshape(x_test.shape[0], 784).astype('float32') / 255.0

# convert class vectors to binary class matrices (0/1 targets for the softmax output)
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

# Hidden layer: the bio-learned synapses (one row per unit, transposed into the Keras
# kernel layout) with zero bias, frozen so that backpropagation never changes them.
model = Sequential()
model.add(Dense(synapses.shape[0], input_dim=synapses.shape[1],
                activation=custom_activation, trainable=False))
model.layers[0].set_weights([np.transpose(synapses), np.zeros(synapses.shape[0])])
# Output layer: the only trained part of the network.
model.add(Dense(10, activation='softmax', trainable=True))

# `learning_rate` replaces the deprecated `lr` keyword, which newer Keras optimizers
# warn about and ignore, or reject outright.
optimizer = keras.optimizers.Adam(learning_rate=0.0002)
model.compile(loss=loss,
              optimizer=optimizer,
              metrics=['accuracy'])

# NOTE(review): val_loss is computed on the test set, so ReduceLROnPlateau tunes the
# learning-rate schedule on test data; a held-out part of the training set would avoid this.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1,
                              patience=15, min_lr=0.000001)
logdir = "logs/bio/"
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)

# batch_size=100: 600 steps per epoch over the 60000 training images, and all 10000
# test images are evaluated in batches of 100.
model.fit(x_train, y_train,
          batch_size=100,
          epochs=300,
          validation_data=(x_test, y_test),
          verbose=1,
          callbacks=[tensorboard_callback, reduce_lr])