<h1 style="font-size: 35px;"> <b>Hebbian Learning</b></h1>
Recent research [1] suggests that the brain does not learn through a global update rule, as gradient descent does, but through "simple" local update rules. This motivates the search for new, more biologically plausible training mechanisms. 
<br><br>

One approach, which I will implement in this notebook, is based on a postulate by Donald Hebb from his book **The Organization of Behavior**, published in 1949. The following paper, on which this notebook is based, explores this approach in the area of meta-learning:
<br><br>
<div style="height:250px">
    <img src="../assets/paper-preview.png" width="150" border="1px" style="vertical-align: middle;float:left;">
    
    <h3 style="padding: 66px;float:left"><b>Meta-Learning through Hebbian Plasticity in Random Networks</b><br><i>Elias Najarro, Sebastian Risi - IT University of Copenhagen</i></h3>
</div>
<div>

This notebook contains a simple PyTorch implementation of a multilayer perceptron with the Hebbian weight update rule, the evolution strategies (ES) algorithm needed to train it, and a set of custom-made graphics and texts explaining the mechanisms.
</div>

## 1.) **Reinforcement Learning vs Hebbian Learning**
Unlike in classical reinforcement learning, our goal is not to learn a fixed set of policy network weights, but a Hebbian update rule that adjusts the network's weights at runtime, based on the activations its inputs produce.

<br>
<div style="text-align:center"><img src="../assets/rlvshl.png"></div>
<br>

In [1]:
# some of this code is inspired or straight up copied from the official implementation
# https://github.com/enajx/HebbianMetaLearning
# Star it. Now. >:-) *stares menacingly*

# imports
import torch
import torch.nn as nn

import matplotlib.pyplot as plt

import numpy as np

import time

from tqdm.notebook import tqdm

import gym
import pybullet_envs

from gym import wrappers as w
from gym.spaces import Discrete, Box

from typing import List, Any

from numba import jit # I AM SPEEEEED (JIT-Compilation)
import multiprocessing as mp # I AM SPEEEED Part 2 (Multiprocessing)

from os.path import exists
from os import mkdir

## 2.) **A simple MLP with the Hebbian Update Rule**
As in RL, we start by building a simple policy network: a fully connected multilayer perceptron (MLP) without biases.

<br>
<div style="text-align:center"><img src="../assets/hebbiannetwork.png"></div>
<br>
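To make the architecture concrete, here is a minimal NumPy sketch (not the notebook's PyTorch class) of the same bias-free 128-64 MLP: weight matrices in PyTorch's `(out_features, in_features)` layout, tanh on the hidden layers, and the raw output returned together with all layer activations. The helper names `init_mlp` and `forward` are hypothetical, and the 28-dimensional observation and 8-dimensional action in the usage lines are assumptions chosen to match PyBullet's `AntBulletEnv-v0`.

```python
import numpy as np

def init_mlp(input_dim, action_dim, hidden=(128, 64), scale=0.1, seed=0):
    """Bias-free weight matrices in PyTorch's (out_features, in_features) layout."""
    rng = np.random.default_rng(seed)
    sizes = [input_dim, *hidden, action_dim]
    # uniform(-0.1, 0.1), like the notebook's 'uni' initialization
    return [rng.uniform(-scale, scale, (n_out, n_in))
            for n_in, n_out in zip(sizes[:-1], sizes[1:])]

def forward(weights, state):
    """Return the activations of every layer, as HebbianNetwork.forward does."""
    w1, w2, w3 = weights
    x1 = np.tanh(w1 @ state)
    x2 = np.tanh(w2 @ x1)
    o = w3 @ x2  # no tanh here; it is applied later to obtain the action
    return state, x1, x2, o

weights = init_mlp(input_dim=28, action_dim=8)
activations = forward(weights, np.ones(28))
n_synapses = sum(w.size for w in weights)  # 128*28 + 64*128 + 8*64
print([a.shape for a in activations], n_synapses)  # [(28,), (128,), (64,), (8,)] 12288
```

Every one of these synapses later gets its own set of Hebbian coefficients, so with six coefficients per synapse the ES below has to optimize 6 × 12288 = 73728 numbers for this network size.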

The key difference is the way we update its weights. Gradient descent, a global update rule and therefore not biologically plausible, is replaced by the Hebbian update rule:

$$\Delta w_{ij} = \eta_w \cdot o_i o_j$$

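Here $w_{ij}$ is the weight of the connection between neurons $i$ and $j$, $o_i$ and $o_j$ are their activations, and $\eta_w$ is a learning rate. For $\eta_w > 0$ the weight grows when both activations have the same sign and shrinks when their signs differ, so correlated activity is reinforced. In this notebook, an evolved extension of this rule updates the weights at every time step of an episode.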
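
For a whole layer, the plain rule can be written as an outer product: with the weight matrix in PyTorch's `(out_features, in_features)` layout, $\Delta W = \eta_w \, o_{\text{post}} o_{\text{pre}}^\top$. The sketch below (the function name `plain_hebbian_step` is hypothetical) applies it repeatedly to one fixed pair of activation vectors, which also shows that the plain rule simply keeps strengthening the same connections, with no notion of a task.

```python
import numpy as np

def plain_hebbian_step(W, pre, post, eta=0.01):
    """Delta w_ij = eta * o_i * o_j for every synapse, written as an outer product.

    W has shape (n_post, n_pre), so W[j, i] connects presynaptic i to postsynaptic j.
    """
    return W + eta * np.outer(post, pre)

pre = np.array([1.0, -0.5, 0.0])  # presynaptic activations o_i
post = np.array([0.8, -0.2])      # postsynaptic activations o_j
W = np.zeros((2, 3))
for _ in range(100):
    W = plain_hebbian_step(W, pre, post)

# W has grown linearly to 100 * eta * outer(post, pre); an inactive
# presynaptic neuron (pre[2] == 0) leaves its column unchanged
print(W)
```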

In [2]:
class HebbianNetwork(nn.Module):
    "A simple MLP without bias"
    def __init__(self, input_space, action_space):
        super(HebbianNetwork, self).__init__()

        self.l1 = nn.Linear(input_space, 128, bias=False)
        self.l2 = nn.Linear(128, 64, bias=False)
        self.l3 = nn.Linear(64, action_space, bias=False)

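    def forward(self, x):
        # convert the observation to a float tensor (no gradients are needed)
        state = torch.as_tensor(x).float().detach()

        x1 = torch.tanh(self.l1(state))
        x2 = torch.tanh(self.l2(x1))
        # raw output; tanh is applied to it in fitness_hebb to obtain the action
        o = self.l3(x2)

        # return the activations of every layer, which the Hebbian rule needs
        return state, x1, x2, o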
    
# NUMBA JIT-Compilation goes brrrrrrrrrrr
@jit(nopython=True)
def hebbian_update_rule(heb_coeffs, weights1_2, weights2_3, weights3_4, o0, o1, o2, o3):
    """
    Update the three weight matrices in place with the generalized Hebbian rule.

    The matrices use PyTorch's (out_features, in_features) layout, so w[j, i]
    connects presynaptic neuron i to postsynaptic neuron j. Each synapse owns one
    row heb_coeffs[idx] = [A, B, C, eta, D, extra offset].
    """
    heb_offset = 0
    # Layer 1: o0 (input) -> o1
    for i in range(weights1_2.shape[1]):
        for j in range(weights1_2.shape[0]):
            idx = heb_offset + weights1_2.shape[0] * i + j
            weights1_2[j, i] += heb_coeffs[idx][3] * (heb_coeffs[idx][0] * o0[i] * o1[j]
                                                      + heb_coeffs[idx][1] * o0[i]
                                                      + heb_coeffs[idx][2] * o1[j]
                                                      + heb_coeffs[idx][4]) + heb_coeffs[idx][5]

    heb_offset += weights1_2.shape[1] * weights1_2.shape[0]
    # Layer 2: o1 -> o2
    for i in range(weights2_3.shape[1]):
        for j in range(weights2_3.shape[0]):
            idx = heb_offset + weights2_3.shape[0] * i + j
            weights2_3[j, i] += heb_coeffs[idx][3] * (heb_coeffs[idx][0] * o1[i] * o2[j]
                                                      + heb_coeffs[idx][1] * o1[i]
                                                      + heb_coeffs[idx][2] * o2[j]
                                                      + heb_coeffs[idx][4]) + heb_coeffs[idx][5]

    heb_offset += weights2_3.shape[1] * weights2_3.shape[0]
    # Layer 3: o2 -> o3 (output)
    for i in range(weights3_4.shape[1]):
        for j in range(weights3_4.shape[0]):
            idx = heb_offset + weights3_4.shape[0] * i + j
            weights3_4[j, i] += heb_coeffs[idx][3] * (heb_coeffs[idx][0] * o2[i] * o3[j]
                                                      + heb_coeffs[idx][1] * o2[i]
                                                      + heb_coeffs[idx][2] * o3[j]
                                                      + heb_coeffs[idx][4]) + heb_coeffs[idx][5]

    return weights1_2, weights2_3, weights3_4

Since the simple Hebbian update rule 

$$\Delta w_{ij} = \eta_w \cdot o_i o_j$$

is purely local and has no access to any error or reward signal, it cannot learn a task on its own, so we need to extend it.

The paper extends this rule with evolvable, synapse-specific parameters:

- correlation term $A_w$
- presynaptic term $B_w$
- postsynaptic term $C_w$
- bias $D_w$
- learning rate $\eta_w$

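This leads to the generalized update rule below. The coefficients $A_w$, $B_w$ and $C_w$ determine the local dynamics of the network weights, while $D_w$ can be interpreted as an individual excitatory/inhibitory bias of each connection: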

<br>
<div style="text-align:center"><img src="../assets/hebbianrule.png"></div>
<br>
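
The loop in `hebbian_update_rule` applies a six-coefficient variant of this rule to every synapse:

$$\Delta w_{ij} = \eta_w \cdot (A_w o_i o_j + B_w o_i + C_w o_j + D_w) + E_w$$

where `heb_coeffs[idx]` holds $[A_w, B_w, C_w, \eta_w, D_w, E_w]$, and $E_w$ is only a name introduced here for the sixth coefficient, a constant offset added outside the learning rate. As a sketch, assuming the presynaptic-major indexing `idx = n_post * i + j` used by that loop, the update of one layer can also be vectorized with NumPy broadcasting. Both function names are hypothetical; the loop version exists only to check the result.

```python
import numpy as np

def hebbian_layer_update(W, coeffs, pre, post):
    """Vectorized six-coefficient Hebbian update of one layer (in place).

    W:      (n_post, n_pre) weights, W[j, i] connects pre i -> post j
    coeffs: (n_pre * n_post, 6) rows [A, B, C, eta, D, E], ordered idx = n_post * i + j
    """
    n_post, n_pre = W.shape
    # reorder the coefficient rows so that c[j, i] belongs to W[j, i]
    c = coeffs.reshape(n_pre, n_post, 6).transpose(1, 0, 2)
    A, B, C, eta, D, E = np.moveaxis(c, -1, 0)
    W += eta * (A * np.outer(post, pre)       # correlation term o_i * o_j
                + B * pre[np.newaxis, :]      # presynaptic term o_i
                + C * post[:, np.newaxis]     # postsynaptic term o_j
                + D) + E
    return W

def hebbian_layer_update_loop(W, coeffs, pre, post):
    """Reference: same indexing and formula as one layer of hebbian_update_rule."""
    for i in range(W.shape[1]):
        for j in range(W.shape[0]):
            h = coeffs[W.shape[0] * i + j]
            W[j, i] += h[3] * (h[0] * pre[i] * post[j] + h[1] * pre[i] + h[2] * post[j] + h[4]) + h[5]
    return W

rng = np.random.default_rng(0)
n_pre, n_post = 5, 3
coeffs = rng.standard_normal((n_pre * n_post, 6))
pre = np.tanh(rng.standard_normal(n_pre))
post = np.tanh(rng.standard_normal(n_post))
W0 = rng.uniform(-0.1, 0.1, (n_post, n_pre))

W_vec = hebbian_layer_update(W0.copy(), coeffs, pre, post)
W_loop = hebbian_layer_update_loop(W0.copy(), coeffs, pre, post)
print(np.allclose(W_vec, W_loop))  # True
```

The vectorized form avoids the Python-level double loop; the notebook instead keeps the explicit loops and relies on numba to compile them.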

These coefficients are evolved with a basic evolution strategies (ES) algorithm that maximizes the cumulative reward of an episode.
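
The core of the ES loop used by the `EvolutionStrategies` class below fits in a few lines: sample Gaussian noise in mirrored (antithetic) pairs, evaluate the perturbed parameters, replace the raw fitness values by standardized centered ranks, and move the parameters along the reward-weighted sum of the noise. The sketch below applies this to a toy quadratic fitness instead of an ant episode. The function names, the target vector and all hyperparameters are illustrative assumptions, it uses a simple learning-rate decay, and it omits the sigma decay of the class.

```python
import numpy as np

def centered_ranks(x):
    """Ranks of x rescaled to [-0.5, 0.5], like compute_centered_ranks."""
    ranks = np.empty(len(x))
    ranks[x.argsort()] = np.arange(len(x))
    return ranks / (len(x) - 1) - 0.5

def es_maximize(fitness, theta, iterations=300, half_pop=25, sigma=0.1,
                lr=0.02, lr_decay=0.99, seed=0):
    """Minimal ES loop: mirrored Gaussian noise, rank-shaped rewards, noise-weighted update."""
    rng = np.random.default_rng(seed)
    for _ in range(iterations):
        half = rng.standard_normal((half_pop, theta.size))
        noise = np.concatenate([half, -half])             # antithetic (mirrored) sampling
        rewards = np.array([fitness(theta + sigma * eps) for eps in noise])
        shaped = centered_ranks(rewards)
        shaped = (shaped - shaped.mean()) / shaped.std()  # standardize, as in _update_coeffs
        # reward-weighted sum of the noise, scaled like update_factor in the class
        theta = theta + lr / (len(noise) * sigma) * (noise.T @ shaped)
        lr *= lr_decay                                     # simple learning-rate decay
    return theta

target = np.array([0.5, -1.0, 2.0])

def fitness(w):
    # toy stand-in for an episode's reward: maximal (0) at w == target
    return -np.sum((w - target) ** 2)

theta = es_maximize(fitness, np.zeros(3))
print(theta)
```

Starting from zeros, the parameters should end up near `target`. Only fitness evaluations are needed, no gradients, which is what makes ES applicable to a non-differentiable episode rollout.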

## 3.) **The Task**
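<img src="../assets/ant.gif">

The agent controls the four-legged ant robot of PyBullet's `AntBulletEnv-v0` environment through an 8-dimensional continuous action, and its fitness is the distance it walks during an episode.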

## 4.) **Evolution strategies**

In [3]:
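def fitness_hebb(environment: str, init_weights: str, evolved_parameters: np.ndarray) -> float:
    """
    Run one episode of `environment` with a randomly initialized network whose weights
    are updated at every step with the Hebbian coefficients `evolved_parameters`,
    and return the cumulative reward of the episode.
    """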
    def weights_init(m):
        if isinstance(m, torch.nn.Linear):
            if init_weights == 'xa_uni':  
                torch.nn.init.xavier_uniform_(m.weight.data, 0.3)
            elif init_weights == 'sparse':  
                torch.nn.init.sparse_(m.weight.data, 0.8)
            elif init_weights == 'uni':  
                torch.nn.init.uniform_(m.weight.data, -0.1, 0.1)
            elif init_weights == 'normal':  
                torch.nn.init.normal_(m.weight.data, 0, 0.024)
            elif init_weights == 'ka_uni':  
                torch.nn.init.kaiming_uniform_(m.weight.data, 3)
            elif init_weights == 'uni_big':
                torch.nn.init.uniform_(m.weight.data, -1, 1)
            elif init_weights == 'xa_uni_big':
                torch.nn.init.xavier_uniform_(m.weight.data)
            elif init_weights == 'ones':
                torch.nn.init.ones_(m.weight.data)
            elif init_weights == 'zeros':
                torch.nn.init.zeros_(m.weight.data)
            elif init_weights == 'default':
                pass
            
    # Unpack evolved parameters
    hebb_coeffs = evolved_parameters

    
    # disable the autograd system
    with torch.no_grad():
                    
        # Load environment (some envs accept a `verbose` argument, others do not)
        try:
            env = gym.make(environment, verbose = 0)
        except Exception:
            env = gym.make(environment)
            
        # env.render()  # render bullet envs

        # get the input dimensions of the environment
        input_dim = env.observation_space.shape[0]
            
        # Determine action space dimension
        action_dim = env.action_space.shape[0]
        
        # Initialize policy network
        p = HebbianNetwork(input_dim, action_dim)          
          
        # Randomly sample initial weights from chosen distribution
        p.apply(weights_init)
        p = p.float()
        
        # Unpack network's weights
        weights1_2, weights2_3, weights3_4 = list(p.parameters())
            
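        # numba needs NumPy arrays; .numpy() shares memory with the parameters,
        # so the in-place Hebbian updates below also change the weights of p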
        weights1_2 = weights1_2.detach().numpy()
        weights2_3 = weights2_3.detach().numpy()
        weights3_4 = weights3_4.detach().numpy()
        
        # reset the environment
        observation = env.reset() 

        # Burn-in phase for the bullet ant, so that it starts off from the floor
        if environment == 'AntBulletEnv-v0':
            action = np.zeros(8)
            for _ in range(40):
                __ = env.step(action)        
        
        # Normalize weights flag for non-bullet envs
        normalised_weights = False if environment[-12:-6] == 'Bullet' else True


        # Main loop
        neg_count = 0 # count the amount of times we receive a negative reward
        rew_ep = 0 # cumulative reward over an episode
        t = 0 # timestep
        
        while True:
            
            # If the observation space is gym.spaces.Discrete, one-hot encode the observation
            if isinstance(env.observation_space, Discrete): 
                observation = (observation == torch.arange(env.observation_space.n)).float()
            
            o0, o1, o2, o3 = p([observation])
            
            # numba needs flat NumPy arrays
            o0 = o0.numpy().flatten()
            o1 = o1.numpy().flatten()
            o2 = o2.numpy().flatten()
            
            # squash the network output into [-1, 1]; this is the action
            o3 = torch.tanh(o3).numpy().flatten()
            action = o3
            
            
            # Environment simulation step
            observation, reward, done, info = env.step(action)  
            if environment == 'AntBulletEnv-v0': reward = env.unwrapped.rewards[1] # Distance walked
            rew_ep += reward
            
            # env.render('human') # Gym envs
                                       
            # Early stopping conditions
            if environment[-12:-6] == 'Bullet':
                ## Special stopping condition for bullet envs
                # always play at least 200 time steps
                if t > 200:
                    # after 200 time steps: count the number of negative
                    # rewards we receive in a row
                    neg_count = neg_count+1 if reward < 0.0 else 0
                    
                    # stop if the episode is done or we received a negative reward more than 30 times in a row
                    if (done or neg_count > 30):
                        break
            else:
                if done:
                    break
            
            t += 1
            
            #### Episodic/Intra-life hebbian update of the weights
            weights1_2, weights2_3, weights3_4 = hebbian_update_rule(hebb_coeffs, weights1_2, weights2_3, weights3_4, o0, o1, o2, o3)
            

            # Normalise weights per layer (non-bullet envs only)
            if normalised_weights:
                # this MLP has no convolutional layers, so all of its parameters are the three weight matrices
                for param in p.parameters():
                    param.data /= param.abs().max()
        
        # close the environment
        env.close()

    return rew_ep
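
The early-stopping rule used for the Bullet environments is easy to misread, so here it is isolated as a pure function (the names `bullet_should_stop` and `run_episode` are hypothetical). During the first 201 time steps (t = 0 to 200) the episode never stops, even if `done` is set; afterwards it stops as soon as `done` is set or more than 30 consecutive steps have produced a negative reward.

```python
def bullet_should_stop(t, reward, done, neg_count, min_steps=200, patience=30):
    """Return (stop, neg_count), following the Bullet branch of fitness_hebb."""
    if t <= min_steps:
        return False, neg_count  # never stop during the first steps
    neg_count = neg_count + 1 if reward < 0.0 else 0
    return (done or neg_count > patience), neg_count

def run_episode(rewards, dones):
    """Replay a reward sequence and return how many steps are taken before stopping."""
    neg_count = 0
    for t, (r, d) in enumerate(zip(rewards, dones)):
        stop, neg_count = bullet_should_stop(t, r, d, neg_count)
        if stop:
            return t + 1
    return len(rewards)

# an ant that walks forward for 250 steps and then keeps slipping backwards
rewards = [1.0] * 250 + [-0.5] * 100
print(run_episode(rewards, [False] * len(rewards)))  # 281
```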


In [12]:
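def agent_worker(arg):
    """Evaluate one candidate set of Hebbian coefficients (runs in a worker process)."""
    get_reward_func, env, init_weights, coeffs = arg

    # L2 penalty on the coefficients discourages them from growing too large
    wp = np.array(coeffs)
    decay = - 0.01 * np.mean(wp**2)

    r = get_reward_func(env, init_weights, coeffs) + decay

    return r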
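
`agent_worker` adds a penalty of minus 0.01 times the mean squared coefficient to each episode's reward. A quick back-of-the-envelope check shows how small this term is at the start of training. It assumes standard-normal coefficients, as produced by the default `distribution='normal'`, and a `(12288, 6)` coefficient array as for the Ant network; both are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
coeffs = rng.standard_normal((12288, 6)).astype(np.float32)

decay = -0.01 * np.mean(coeffs ** 2)  # same formula as in agent_worker
# E[c^2] = 1 for standard-normal coefficients, so decay is close to -0.01
print(decay)
```

Compared with the rewards printed during training, this is a weak regularizer. Because the fitness values are converted to ranks before the update, it only matters when it changes the ordering of two candidates.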

In [13]:
def compute_ranks(x):
    """
    Returns the rank of every element of x, as integers from 0 to len(x) - 1
    """
    assert x.ndim == 1
    ranks = np.empty(len(x), dtype=int)
    ranks[x.argsort()] = np.arange(len(x))
    return ranks

def compute_centered_ranks(x):
    """
    Returns the ranks of x, linearly rescaled to the interval [-0.5, 0.5]
    """
    y = compute_ranks(x.ravel()).reshape(x.shape).astype(np.float32)
    y /= (x.size - 1)
    y -= .5
    return y

class EvolutionStrategies(object):
    def __init__(self, environment, init_weights="uni", population_size=100, sigma=0.1, learning_rate=0.2, decay=0.995, num_threads=-1, distribution = 'normal'):
                             
        self.environment = environment                         
        self.init_weights = init_weights              
        self.POPULATION_SIZE = population_size
        self.SIGMA = sigma
        self.learning_rate = learning_rate            
        self.decay = decay
        self.num_threads = mp.cpu_count() if num_threads == -1 else num_threads
        print(f"[INFO] Using upto {self.num_threads} threads.")
        self.update_factor = self.learning_rate / (self.POPULATION_SIZE * self.SIGMA)
        self.distribution = distribution
        
        self.coefficients_per_synapse = 6
        
        # make the environment (only needed to read the input/output dimensions)
        env = gym.make(environment)

        # get the input dimensions of the environment
        input_dim = env.observation_space.shape[0]

        # Determine action space dimension
        action_dim = env.action_space.shape[0]
        env.close()

        # number of plastic synapses of the 128-64 MLP; each gets its own Hebbian coefficients
        plastic_weights = (128*input_dim) + (64*128) + (action_dim*64)

        if self.distribution == 'uniform':
            self.coeffs = np.random.uniform(-1, 1, (plastic_weights, self.coefficients_per_synapse))
        elif self.distribution == 'normal':
            self.coeffs = torch.randn(plastic_weights, self.coefficients_per_synapse).detach().numpy().squeeze()
        else:
            raise ValueError(f"Unknown distribution: {self.distribution}")
        
        # set the reward function
        self.get_reward = fitness_hebb
        
    def _get_params_try(self, w, p):
        # perturb the current coefficients w with the noise p, scaled by sigma
        return (w + self.SIGMA * p).astype(np.float32)
    
    def get_coeffs(self):
        return self.coeffs.astype(np.float32)
    
    def _get_population(self, coevolved_param = False):
        """
        Sample POPULATION_SIZE/2 Gaussian noise arrays and their negatives
        (antithetic / mirrored sampling).
        Shape: (population size, number of synapses, coefficients_per_synapse)
        """
        if coevolved_param:
            # co-evolving the initial weights is not implemented in this notebook
            raise NotImplementedError("coevolved_param=True is not supported")

        half = np.random.randn(int(self.POPULATION_SIZE/2), *self.coeffs.shape)
        population = np.concatenate((half, -half))

        return population.astype(np.float32)
    
    def _get_rewards(self, pool, population):
        # one (reward function, env, init, perturbed coefficients) tuple per individual
        worker_args = [(self.get_reward, self.environment, self.init_weights, self._get_params_try(self.coeffs, p))
                       for p in population]

        if pool is not None:
            rewards = pool.map(agent_worker, worker_args)
        else:
            # evaluate serially, with the same L2 penalty as in the parallel case
            rewards = [agent_worker(args) for args in worker_args]

        return np.array(rewards).astype(np.float32)
    
    def _update_coeffs(self, rewards, population):
        # rank-based fitness shaping, then standardization
        rewards = compute_centered_ranks(rewards)

        std = rewards.std()
        if std == 0:
            raise ValueError('Variance should not be zero')

        rewards = (rewards - rewards.mean()) / std

        # ES gradient estimate: reward-weighted sum of the noise, for all synapses at once
        self.update_factor = self.learning_rate / (self.POPULATION_SIZE * self.SIGMA)
        self.coeffs += self.update_factor * np.tensordot(rewards, population, axes=1)

        # decay the learning rate
        if self.learning_rate > 0.001:
            self.learning_rate *= self.decay

        # decay sigma
        if self.SIGMA > 0.01:
            self.SIGMA *= 0.999
            
            
    def run(self, iterations, print_step=10, path='heb_coeffs'):

        id_ = str(int(time.time()))
        # create the output folder and a sub-folder for this run
        if not exists(path):
            mkdir(path)
        if not exists(path + '/' + id_):
            mkdir(path + '/' + id_)

        print('Run: ' + id_ + '\n\n........................................................................\n')
            
        pool = mp.Pool(self.num_threads) if self.num_threads > 1 else None
        
        generations_rewards = []

        for iteration in range(iterations):                                                                     # Algorithm 2. Salimans, 2017: https://arxiv.org/abs/1703.03864
            population = self._get_population()                                                                 # Sample normal noise:         Step 5
            rewards = self._get_rewards(pool, population)                                                       # Compute population fitness:  Step 6
            self._update_coeffs(rewards, population)                                                            # Update coefficients:         Steps 8->12
                
                
            # Print fitness and save Hebbian coefficients
            if (iteration + 1) % print_step == 0:
                rew_ = rewards.mean()
                print('iter %4i | reward: %3i |  update_factor: %f  lr: %f | sum_coeffs: %i sum_abs_coeffs: %4i' % (iteration + 1, rew_ , self.update_factor, self.learning_rate, int(np.sum(self.coeffs)), int(np.sum(abs(self.coeffs)))), flush=True)
                
                if rew_ > 100:
                    torch.save(self.get_coeffs(),  path + "/"+ id_ + '/HEBcoeffs__' + self.environment + "__rew_" + str(int(rew_)) + "__init_" + str(self.init_weights) + "__pop_" + str(self.POPULATION_SIZE) + '__coeffs' + "__{}.dat".format(iteration))
                generations_rewards.append(rew_)
                np.save(path + "/"+ id_ + '/Fitness_values_' + id_ + '_' + self.environment + '.npy', np.array(generations_rewards))
       
        if pool is not None:
            pool.close()
            pool.join()
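
Because `_update_coeffs` passes the rewards through `compute_centered_ranks` before using them, only the order of the fitness values matters, not their scale, so a single unusually lucky episode cannot dominate an update. A small check, with the helper re-implemented in a simplified 1-D form so that the snippet runs on its own:

```python
import numpy as np

def compute_centered_ranks(x):
    # 1-D version of the notebook's helper: ranks 0..n-1 rescaled to [-0.5, 0.5]
    ranks = np.empty(len(x), dtype=int)
    ranks[x.argsort()] = np.arange(len(x))
    return ranks.astype(np.float32) / (len(x) - 1) - 0.5

rewards = np.array([10.0, -3.0, 7.0, 0.0])
with_outlier = np.array([1000.0, -3.0, 7.0, 0.0])  # one unusually lucky episode

print(compute_centered_ranks(rewards))       # ranks 3, 0, 2, 1 -> 0.5, -0.5, 1/6, -1/6
print(compute_centered_ranks(with_outlier))  # identical: only the order matters
```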

In [14]:
es = EvolutionStrategies("AntBulletEnv-v0")

[INFO] Using upto 8 threads.


In [15]:
es.run(100, print_step=1)

Run: 1621328743

........................................................................





iter    1 | reward:   0 |  update_factor: 0.020000  lr: 0.199000 | sum_coeffs: 336 sum_abs_coeffs: 59777
iter    2 | reward:   5 |  update_factor: 0.019920  lr: 0.198005 | sum_coeffs: 406 sum_abs_coeffs: 61177
iter    3 | reward:   7 |  update_factor: 0.019840  lr: 0.197015 | sum_coeffs: 337 sum_abs_coeffs: 62113
iter    4 | reward:  16 |  update_factor: 0.019761  lr: 0.196030 | sum_coeffs: 382 sum_abs_coeffs: 63402
iter    5 | reward:   7 |  update_factor: 0.019682  lr: 0.195050 | sum_coeffs: 389 sum_abs_coeffs: 64508
iter    6 | reward:  13 |  update_factor: 0.019603  lr: 0.194075 | sum_coeffs: 330 sum_abs_coeffs: 65751
iter    7 | reward:  19 |  update_factor: 0.019524  lr: 0.193104 | sum_coeffs: 331 sum_abs_coeffs: 66782
iter    8 | reward:  18 |  update_factor: 0.019446  lr: 0.192139 | sum_coeffs: 334 sum_abs_coeffs: 67518
iter    9 | reward:  14 |  update_factor: 0.019368  lr: 0.191178 | sum_coeffs: 432 sum_abs_coeffs: 68514
iter   10 | reward:  15 |  update_factor: 0.019291  lr:



In [None]:
gym.make("AntBulletEnv-v0").action_space.sample().shape

# References
