In [1]:
%matplotlib ipympl

import os
import pickle
import random
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union, List

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import mpl_interactions.ipyplot as iplt
import numpy as np
from numba import jit
from scipy.integrate import solve_ivp
from tqdm.notebook import tqdm

from mpl2latex import mpl2latex, latex_figsize

In [2]:
class BioNeurons():
    """
    Simulate synaptic plasticity rule and dynamic equations from [1].
    
    [1]: "Unsupervised learning by competing hidden units", D. Krotov, J. J. Hopfield, 2019, 
         https://www.pnas.org/content/116/16/7723
    """
    
    def __init__(self,
                 input_dim : int = 2,
                 output_dim : int = 4,
                 w_inh : float = .63,
                 delta : float = .4,
                 h_star : float = 0.1,
                 tau : float = 1,
                 tau_L : float = 100,
                 lebesgue_p : float = 4.):
        """
        Set all the parameters for the simulation.
        
        Parameters
        ----------
        input_dim : int
            Number of visible units
        output_dim : int
            Number of hidden units
        w_inh : float
            Strength of global inhibition (from eq. 8 in [1]). Should be >= 0, and high enough so that
            only a small fraction of activations are positive in the steady-state (i.e. self.stationary_activations returns
            a small percentage of positive values).
        delta : float
            Strength of anti-Hebbian learning (from eq. 9 in [1]). 
        h_star : float
            Threshold for activation (from eq. 9 in [1])
        tau : float
            Dynamical time scale of individual neurons (from eq. 8 in [1])
        tau_L : float
            Time scale of learning dynamics (from eq. 3 in [1]). Should be >> tau.
        lebesgue_p : float
            Parameter for Lebesgue measure, used for defining an inner product (from eq. 2 in [1]).
        """
        
        #Store parameters
        self.input_dim  = input_dim
        self.output_dim = output_dim
        self.w_inh      = w_inh
        self.delta      = delta
        self.h_star     = h_star
        self.tau        = tau
        self.tau_L      = tau_L
        self.lebesgue_p = lebesgue_p
        
        #Set initial state
        self.activations = np.zeros(output_dim, dtype=float)
        self.weights     = np.random.normal(size=(output_dim, input_dim)) #normal size
        
        print(str(self))
        
    def __str__(self):
        return f"BioNeurons(input_dim={self.input_dim}, output_dim={self.output_dim}, w_inh={self.w_inh}, delta={self.delta}, lebesgue_p={self.lebesgue_p}, h_star={self.h_star}, tau={self.tau}, tau_L={self.tau_L})"
    
    def __repr__(self):
        return str(self)
    
    @jit
    def forward(self, inputs : np.ndarray) -> np.ndarray:
        """
        Sets the visible units to `inputs`, and propagates this signal through the weights to compute the network's output.
        
        Parameters
        ----------
        inputs : np.ndarray of shape (self.input_dim,)
            Vector with values for the input neurons
        
        Returns
        -------
        out : np.ndarray of shape (self.output_dim,)
            Vector with values for the hidden units (before activation).
        """
        
        return np.dot(self.weights, inputs)
    
    @staticmethod
    def currents(inputs : np.ndarray, weights : np.ndarray, lebesgue_p : float) -> np.ndarray:
        """
        Computes currents at each hidden neuron (eq. 8 from [1]). 
        
        Formula is:
        $$ I_\mu = <W, v>_\mu = \sum_i sgn(W_{\mu i}) |W_{\mu i}|^{p-1} v_i $$
        where p is self.lebesgue_p
        
        Parameters
        ----------
        inputs : np.ndarray of shape (self.input_dim,)
            Vector with values for the input neurons
        weights : np.ndarray of shape (self.output_dim, self.input_dim)
            Weights
        
        Returns
        -------
        currents : np.ndarray of shape (self.output_dim,)
            Vector with currents at each of the hidden neurons
        """
        
        return np.dot( np.sign(weights) * np.abs(weights) ** (lebesgue_p - 1), inputs )
    
    def dh_dt(self,
              t : float,
              y : np.ndarray,
              inputs : np.ndarray,
              weights : np.ndarray) -> np.ndarray:
        """
        Computes the right hand side of eq. 8 from [1], i.e. the time derivative of all neuron activations.
        
        Formula is:
            $$dh_\mu/dt = I_\mu - w_{inh} \sum_{\nu \neq \mu} \max(h_\nu, 0) - h_\mu $$ 
        where $I_\mu$ are the currents at the hidden neurons, and $h_\mu$ are their activations.
        
        Parameters
        ----------
        t : float
            Time instant, needed as interface to scipy.solve_ivp. 
        y : np.ndarray of shape (self.output_dim,)
            Value of hidden neurons
        inputs : np.ndarray of shape (self.input_dim,)
            Value of input neurons
        weights : np.ndarray of shape (self.output_dim, self.input_dim)
            Weights for the network
        
        Returns
        -------
        dh_dt : np.ndarray of shape (self.output_dim,)
            Vector containing the time derivative of each hidden neuron activation.
        """
        
        activations = y
        currents = self.currents(inputs, weights, lebesgue_p=self.lebesgue_p)
        
        positive_activations = activations * (activations > 0) #Set to 0 all the non-positive activations
        global_inhibition = np.sum(positive_activations) - positive_activations #Remove "self" activation by each term
        
        return (currents - self.w_inh * global_inhibition - activations) / self.tau
    
    def stationary_activations(self,
                               inputs : np.ndarray,
                               weights : np.ndarray) -> np.ndarray:
        """
        Computes the hidden neuron activations at stationarity, for a given value of the input neurons and weights of connections.
        For simplicity, this is done by numerically solving the differential equation (8 in [1]) for a some large time.
        
        Parameters
        ----------
        inputs : np.ndarray of shape (self.input_dim,)
            Values for the input neurons
        weights : np.ndarray of shape (self.output_dim, self.input_dim)
            Weights
        
        Returns
        -------
        h* : np.ndarray of shape (self.output_dim,)
            Vector with the activations of the hidden neurons at stationarity
        """
    
        large_time = 5 * self.tau
        
        sol = solve_ivp(neurons.dh_dt, [0, large_time], np.zeros(self.output_dim), args=(inputs, weights), t_eval = [large_time])
        #Stationary solution does not depend on the initual condition, which is here set to 0.
        
        return sol.y.flatten()
    
    @staticmethod
    def g(activations : np.ndarray, h_star : float = 0.8, delta : float = 0.4) -> np.ndarray:
        """
        Activation function for training, implementing temporal competition between the patterns (eq. 9 from [1]). 
        
        Parameters
        ----------
        activations : np.ndarray of shape (self.output_dim,)
            Value of post-synaptic (hidden) neurons
        
        Returns
        -------
        post_act : np.ndarray of shape (self.output_dim,)
            Post-activation values for the hidden neurons.
        """
        
        return np.where(activations < 0, 0, np.where(activations <= h_star, -delta, 1.))
    
    def dW_dt(self,
              t : float,
              y : np.ndarray,
              inputs : np.ndarray,
              flatten : bool = False) -> np.ndarray:
        """
        Computes the right hand side of eq. 3 from [1], i.e. the time derivative of the weights. 
        It is computed in the quasi-stationary approximation, i.e. by setting the hidden units activations to their 
        stationary value at each step (which is a good approximation since the timescale for weight evolution is
        much larger than that of individual neuron dynamics).
        
        Formula is:
            $$dW_{\mu i}/dt = [g(h_\mu) (v_i - I_\mu W_{\mu i})] / \tau_L$$
        where $v_i$ are the visible units, $h_\mu$ the hidden ones, $W_{\mu i}$ are the weights, and $I_\mu$ are the currents.
        
        Parameters
        ----------
        t : float
            Time instant, needed as interface to scipy.solve_ivp. 
        y : np.ndarray of shape (self.output_dim, self.input_dim), or (self.output_dim * self.input_dim,) if flatten=True
            Weights
        inputs : np.ndarray of shape (self.input_dim,)
            Value of input neurons
        flatten : bool
            If True, the output will be flattened to a 1D vector, which is compatible with scipy.solve_ivp.
        
        Returns
        -------
        dW_dt : np.ndarray of shape (self.output_dim, self.input_dim), or (self.output_dim * self.input_dim,) if flatten=True
            Time derivatives of weights
        """
        
        if flatten:
            weights = y.reshape(*self.weights.shape)
        else:
            weights = y
            
        #currents = self.currents(inputs, weights, lebesgue_p=self.lebesgue_p)
        
        activations = self.stationary_activations(inputs, weights)
        #Quasi-stationary approximation: use stationary value of activations, since their evolution is much faster
        
        post_activations = self.g(activations, h_star=self.h_star, delta=self.delta)
        
        result = (np.outer(post_activations, inputs) - ((post_activations * activations).reshape(-1, 1) * weights)) / self.tau_L
        
        return result.flatten() if flatten else result

In [3]:
#Smoke test: build a small default network (with weaker inhibition) and evaluate dh/dt once,
#at the initial (all-zero) state, for a random input vector.
neurons = BioNeurons(w_inh=.3)
random_input = np.random.normal(size=2)
neurons.dh_dt(None, neurons.activations, random_input, neurons.weights)

BioNeurons(input_dim=2, output_dim=4, w_inh=0.3, delta=0.4, lebesgue_p=4.0, h_star=0.1, tau=1, tau_L=100)


array([ 0.15526475,  1.64196229, -0.17358915, -5.43402091])

In [4]:
#Steady-state activations of the same network for a fresh random input
random_input = np.random.normal(size=2)
neurons.stationary_activations(random_input, neurons.weights)

array([-0.73295296, -1.63071772, -0.57035217,  2.5805892 ])

In [5]:
#Visualize effect of w_inh
#Integrate eq. 8 for a fixed random input and show how the hidden activations evolve in time,
#with sliders for the inhibition strength w_inh and the exponent p.
neurons = BioNeurons(output_dim=10)

t_end = 5
ts = np.linspace(0, t_end, 100)

x = np.random.rand(2)
def activation_evolution(t, w_inh, lebesgue_p):
    """
    Slider callback for mpl_interactions: return the activation trajectories, shape (len(t), output_dim).
    
    NOTE: this mutates the global `neurons` (w_inh, lebesgue_p), because neurons.dh_dt reads them from the instance.
    """
    neurons.w_inh = w_inh
    neurons.lebesgue_p = lebesgue_p
    
    sol = solve_ivp(neurons.dh_dt, [0, t_end], neurons.activations, args=(x, neurons.weights), t_eval = t)
    
    return sol.y.T


p_values = np.arange(2, 8)

fig, ax = plt.subplots()
#Braces around the index so that multi-digit indices (h_10, h_11, ...) are fully subscripted in LaTeX
controls = iplt.plot(ts, activation_evolution, w_inh=(0, 1, 100), lebesgue_p=p_values,
                     label=[f"$h_{{{i}}}$" for i in range(neurons.output_dim)])
_ = plt.legend()
plt.xlabel("Time $t$")
plt.ylabel(r"Activations $h_\mu$")
plt.show()

BioNeurons(input_dim=2, output_dim=10, w_inh=0.63, delta=0.4, lebesgue_p=4.0, h_star=0.1, tau=1, tau_L=100)


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

VBox(children=(HBox(children=(IntSlider(value=0, description='w_inh', max=99, readout=False), Label(value='0.0…

In [6]:
#Save plot
#Compare the dynamics of eq. 8 with lateral inhibition (left panel) and without it (right panel),
#for the same random input and weights, and save the figure as a PDF.
n_neurons = 15
neurons = BioNeurons(output_dim=n_neurons, w_inh=.3, h_star=.8, lebesgue_p=2.)

t_end = 5
ts = np.linspace(0, t_end, 100)

x = np.random.rand(2)

sol = solve_ivp(neurons.dh_dt, [0, t_end], neurons.activations, args=(x, neurons.weights), t_eval = ts)

w_inh = neurons.w_inh

#Second run without inhibition. dh_dt reads self.w_inh, so the attribute is changed temporarily
#and restored afterwards, so that `neurons` is not left modified for later cells.
neurons.w_inh = 0
try:
    sol2 = solve_ivp(neurons.dh_dt, [0, t_end], neurons.activations, args=(x, neurons.weights), t_eval = ts)

    #Curves are ordered by input current and coloured by their stationary activation.
    #Note: the stationary activations are computed here while w_inh = 0.
    currents = neurons.currents(x, weights=neurons.weights, lebesgue_p=neurons.lebesgue_p)
    sort_currents = np.argsort(currents)
    stat_activations = neurons.stationary_activations(x, neurons.weights)[sort_currents]
finally:
    neurons.w_inh = w_inh

#Map activations to [0, 1] for the colormap (guard against a zero range when all values coincide)
act_min = np.min(stat_activations)
act_range = np.max(stat_activations) - act_min
colors = plt.cm.plasma((stat_activations - act_min) / (act_range if act_range > 0 else 1.))

with mpl2latex(True):
    fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, sharey=True, figsize=latex_figsize(wf=1.2, hf=.5, columnwidth=318.67))
    fig.patch.set_facecolor('none')

    for i, solution in enumerate(sol.y[sort_currents]):
        ax1.plot(ts, solution, color=colors[i], lw=1)
    
    for i, solution in enumerate(sol2.y[sort_currents]):
        ax2.plot(ts, solution, color=colors[i], lw=1)

    ax1.set_xlabel('Time $t$')
    ax2.set_xlabel('Time $t$')
    ax1.set_ylabel(r'Activation $h_\mu$')
    ax1.set_title(f"(w\\_inh={w_inh}, p={neurons.lebesgue_p})", fontsize=8, y=.97)
    ax2.set_title(f"(w\\_inh=0, p={neurons.lebesgue_p})", fontsize=8, y=.97)
    plt.suptitle("Neuron Dynamics", x=.51, y=.99)
    
    ax1.patch.set_facecolor('white')
    ax2.patch.set_facecolor('white')
    
    #Save BEFORE showing: with non-interactive backends plt.show() may close the figure,
    #and a later savefig would write an empty file.
    os.makedirs("Plots", exist_ok=True)
    fig.savefig("Plots/neuron_dynamics.pdf", transparent=True, bbox_inches='tight')
    plt.show()

BioNeurons(input_dim=2, output_dim=15, w_inh=0.3, delta=0.4, lebesgue_p=2.0, h_star=0.8, tau=1, tau_L=100)


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [7]:
#Test with MNIST dataset

try: #Load MNIST dataset from the local .npy cache
    #NOTE: allow_pickle=True is needed for the (object-dtype) labels; only load cache files you created yourself.
    X = np.load("MNIST_features.npy", allow_pickle=True)
    y = np.load("MNIST_labels.npy", allow_pickle=True)
except IOError: #If not present, download it from the net
    #Imported lazily: scikit-learn is only needed the first time, to download the dataset
    from sklearn.datasets import fetch_openml

    X, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False) #Return tuple (X=features, y=labels) as numpy array
    #(as_frame=False => do not use Pandas DataFrame)

    np.save("MNIST_features.npy", X)
    np.save("MNIST_labels.npy", y)

X = X/255. #Apply normalization: pixel values from [0, 255] to [0, 1]

class RandomPicker():
    """
    Callable returning a uniformly random element of `dataset` at each call
    (for a 2D NumPy array, a random row).
    
    Parameters
    ----------
    dataset : sequence
        Anything supported by `random.choice` (list, tuple, NumPy array, ...).
    seed : int or None
        If given, samples are drawn from a private `random.Random(seed)`, so the sequence is reproducible
        and independent of the global `random` state. If None (default), the global `random` module is
        used, as before.
    """
    def __init__(self, dataset, seed=None):
        self.dataset = dataset
        self._rng = None if seed is None else random.Random(seed)
        
    def __call__(self):
        rng = random if self._rng is None else self._rng
        return rng.choice(self.dataset)
    
#Callable returning a random MNIST image (one row of X) each time it is called
mnist_sample = RandomPicker(dataset=X)

def draw_mnist_sample(data):
    """
    Display one flattened MNIST image as a 28x28 grayscale picture in a new figure
    (colour scale fixed to [0, 1], matching the normalized dataset).
    """
    
    image = data.reshape((28, 28))
    plt.figure()
    plt.imshow(image, cmap='gray', vmin=0, vmax=1)
    plt.show()

In [8]:
#Weights evolution
#One visible unit per MNIST pixel (28 x 28), 15 hidden units, p = 2, strong lateral inhibition.
neurons = BioNeurons(input_dim=28 * 28,
                     output_dim=15,
                     w_inh=15,
                     delta=0.4,
                     h_star=.8,
                     lebesgue_p=2.)

BioNeurons(input_dim=784, output_dim=15, w_inh=15, delta=0.4, lebesgue_p=2.0, h_star=0.8, tau=1, tau_L=100)


In [9]:
eval_time = 5
x = mnist_sample()

#In general, not all weights are updated at every iteration: dW_dt multiplies each row by the
#post-activation g(h_mu), so all the weights of a hidden unit with g(h_mu) = 0 stay fixed.
#For the given hyperparameters (e.g. delta, w_inh, lebesgue_p...) the percentage of weights that are changed
#in the first iteration is:
first_deltas = neurons.dW_dt(0., neurons.weights, x, flatten=True)
non_zero_deltas_percentage = np.count_nonzero(first_deltas) / first_deltas.size * 100
#Given the formula of dW_dt, this is (up to accidental cancellations) the percentage of non-zero post-activations
print("Percentage of non-zero deltas: {:.2f}%".format(non_zero_deltas_percentage))

#Empirically, the model is most likely to converge (on the MNIST dataset) if this number is around 5-15%.
#A higher number means that there are "too many neurons active at once", and w_inh should be increased to
#strengthen the lateral inhibition; otherwise, neurons "won't diversify much" and are likely to converge to very similar patterns.
#A lower number means that inhibition is too strong, so the model will take long to converge, or not converge at all.

Percentange of non-zero deltas: 6.67%


In [10]:
#Measure the wall time of one training iteration: the sample `x` is presented for `eval_time`,
#integrating dW/dt with solve_ivp. `sol` is only used for timing: neurons.weights is not updated by this cell.
%time sol = solve_ivp(neurons.dW_dt, [0, eval_time], neurons.weights.flatten(), args=(x, True), t_eval = [eval_time]) 

Wall time: 286 ms


In [11]:
def _tile_weights(weights : np.ndarray,
                  reshape_dim : Tuple[int, int],
                  max_per_row : int,
                  max_rows : int) -> np.ndarray:
    """
    Arrange the first rows of `weights`, each reshaped to `reshape_dim`, into a single 2D image laid out
    row-major on a grid of at most `max_rows` rows and `max_per_row` columns. Unused grid cells stay at 0.
    
    Raises
    ------
    ValueError
        If `weights` has no rows.
    """
    size_x, size_y = reshape_dim
    
    n_neurons = min(weights.shape[0], max_rows * max_per_row) #Only the first few neurons fit in the grid
    if n_neurons == 0:
        raise ValueError("`weights` must contain at least one row")
    
    n_columns = min(n_neurons, max_per_row)
    n_rows = -(-n_neurons // max_per_row) #Ceiling division: a partially filled last row is still drawn
    
    whole_image = np.zeros((n_rows * size_x, n_columns * size_y))
    for index_neuron in range(n_neurons):
        i_row, i_col = divmod(index_neuron, n_columns)
        whole_image[i_row * size_x:(i_row + 1) * size_x,
                    i_col * size_y:(i_col + 1) * size_y] = weights[index_neuron, ...].reshape(reshape_dim)
    return whole_image


def draw_weights(weights : np.ndarray,
                 reshape_dim : Tuple[int, int] = (28, 28),
                 max_per_row : int = 5,
                 max_rows : int = 5,
                 fig = None): 
    """
    Plot the first few weights as matrices. `weights` should be an array of shape (output_dim, input_dim), i.e.
    `weights[i,j]` is the weight connecting the $j$-th neuron of a layer $n$ to the $i$-th neuron of the $n+1$ layer.
    Namely, all the weights connected to the $i$-th output neuron are the ones in the $i$-th row of `weights`.
    These weights are reshaped according to `reshape_dim` to construct a matrix. The weight matrices of the first neurons
    are then plotted in a grid of up to `max_rows` rows and `max_per_row` columns (a partially filled last row is shown).
    
    Parameters
    ----------
    weights : np.ndarray of shape (output_dim, input_dim)
        Weights to draw; each row must have prod(reshape_dim) elements.
    reshape_dim : tuple of two ints
        Shape of the image built from each row.
    max_per_row, max_rows : int
        Maximum number of columns / rows of the grid.
    fig : matplotlib Figure or None
        Figure to draw into (it is cleared first, e.g. to redraw inside a training loop). A new figure is created if None.
    
    Raises
    ------
    ValueError
        If `weights` has no rows.
    """
    
    whole_image = _tile_weights(weights, reshape_dim, max_per_row, max_rows)
    nc = np.max(np.abs(weights)) #Symmetric colour range [-nc, nc], so that 0 maps to the white centre of 'bwr'
    
    if fig is None:
        fig = plt.figure()
    else:
        fig.clf() #Clear the given figure itself (plt.clf() would clear whichever figure is current)
    ax = fig.add_subplot(1, 1, 1)
    
    img_plotted = ax.imshow(whole_image, cmap='bwr', vmin=-nc, vmax=nc, interpolation=None)
    fig.colorbar(img_plotted, ax=ax, ticks=[np.amin(whole_image), 0, np.amax(whole_image)])
    plt.show()

In [None]:
#Training loop. ATTENTION! This cell takes VERY LONG to execute (30 min on my PC). 
#To just see the results, skip this cell and load the saved weights.

num_trials = 300

#Each sample is "shown" to the model for a time `eval_time`, which starts at `eval_time_max` and decreases
#linearly to `eval_time_min`. The intuition is that, at first, the model needs to "learn from scratch", and
#so we go "more slowly". Recall that the timescale of plasticity is, by default, 100. Intuitively, the evolution of each weight
#takes into consideration mostly the samples falling inside this timescale. At the start, we want these to be few, 
#so that the neurons "don't get too confused" and can stick to some specific sample.
#Then, when weights are mostly fixed, we can "go faster" and have a lower `eval_time`. 

#This procedure is motivated by numerical experiments, and is roughly analogous to "lowering the learning rate"
#during training of a supervised model.

eval_time_max = 25. 
eval_time_min = 5

#Presentation times: linear schedule from eval_time_max (first sample) to eval_time_min (last sample), both included
eval_times = np.linspace(eval_time_max, eval_time_min, num_trials)

os.makedirs("figs", exist_ok=True) #Frames for the animation are saved here

all_norms = [] #For each trial: sum_i |W_{mu i}|^p for every hidden unit mu
fig = plt.figure()

for i, eval_time in enumerate(tqdm(eval_times)):
    x = mnist_sample()
    
    sol = solve_ivp(neurons.dW_dt, [0, eval_time], neurons.weights.flatten(), args=(x, True), t_eval = [eval_time])
    if not sol.success:
        raise RuntimeError(f"Integration of the weights failed at trial {i}: {sol.message}")
    
    neurons.weights = sol.y[:, -1].reshape(neurons.weights.shape)
    norms = np.sum(np.abs(neurons.weights) ** neurons.lebesgue_p, axis=1)
    
    all_norms.append(norms)
    
    draw_weights(neurons.weights, fig=fig)
    
    plt.savefig(f"figs/weights{i}.png", transparent=True, bbox_inches='tight') #Save images for animation

#Save weights
with open('bio_diffeq_converging', 'wb') as file:
    pickle.dump(neurons, file)

In [12]:
#Restore the trained model saved by the training cell, then display its weights.
#NOTE: pickle.load can execute arbitrary code: only load files you created yourself.
with open('bio_diffeq_converging', 'rb') as saved_model:
    neurons = pickle.load(saved_model)

draw_weights(weights=neurons.weights)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …