<h1 style="font-size: 35px;"> <b>Hebbian Learning</b></h1>
Recent research [1] suggests that the brain does not learn through a global update rule, such as the one used by gradient descent, but through "simple" local update rules. This motivates the search for new, biologically more plausible training mechanisms.
<br><br>

One approach, which I will implement in this notebook, is based on a postulate by Donald Hebb from his book **The Organization of Behavior**, released in 1949. The following paper, on which this notebook is based, explores this approach in the area of meta-learning.
<br><br>
<div style="height:250px">
    <img src="../assets/paper-preview.png" width="150" border="1px" style="vertical-align: middle;float:left;">
    
    <h3 style="padding: 66px;float:left"><b>Meta-Learning through Hebbian Plasticity in Random Networks</b><br><i>Elias Najarro, Sebastian Risi - IT University of Copenhagen</i></h3>
</div>
<div>

This notebook contains a simple PyTorch implementation of a multilayer perceptron with the Hebbian weight update rule, the ES algorithm needed for training, and a number of custom-made graphics and texts explaining the mechanisms.
</div>

## 1.) **Reinforcement Learning vs Hebbian Learning**
Unlike in classical reinforcement learning, our goal is not to learn a policy network with static weights, but a Hebbian update rule that adjusts the network at runtime based on its inputs.

<br>
<div style="text-align:center"><img src="../assets/rlvshl.png"></div>
<br>

In [7]:
# some of this code is inspired or straight up copied from the official implementation
# https://github.com/enajx/HebbianMetaLearning
# Star it. Now. >:-)

# imports
import torch
import torch.nn as nn

import matplotlib.pyplot as plt

import numpy as np

from tqdm.notebook import tqdm

import gym
import pybullet_envs

from gym import wrappers as w
from gym.spaces import Discrete, Box

from typing import List, Any

from numba import njit # I AM SPEEEEED (JIT-Compilation)

## 2.) **A simple MLP with the Hebbian Update Rule**
As in RL, we start out by building a simple policy network. It is a fully connected Multi Layer Perceptron without bias.

<br>
<div style="text-align:center"><img src="../assets/hebbiannetwork.png"></div>
<br>

The key difference is the way we update its weights. Gradient descent, which is a global update rule and therefore not biologically accurate, is replaced by the Hebbian update rule.

$$\Delta w_{ij} = \eta_w \cdot o_i o_j$$

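Here $o_i$ and $o_j$ are the activations of the presynaptic and postsynaptic neuron, and $\eta_w$ is a learning rate. The rule is applied at every timestep of an episode, so the weights change dynamically while the agent acts. The terms of the rule are evolved rather than set by hand.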
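
To make the rule concrete, here is a minimal NumPy sketch of a single plain Hebbian step for one layer. The function name `plain_hebbian_step` and the value of `eta` are assumptions chosen for illustration. The weight matrix uses the PyTorch `nn.Linear` layout `(n_post, n_pre)`, so the update for all synapses at once is an outer product of the two activation vectors.

```python
import numpy as np

def plain_hebbian_step(weights, pre, post, eta=0.01):
    """Return the weights after one step of delta_w[j, i] = eta * pre[i] * post[j]."""
    return weights + eta * np.outer(post, pre)

w = np.zeros((2, 3))                # 3 presynaptic and 2 postsynaptic neurons
pre = np.array([1.0, 0.0, -1.0])    # presynaptic activations o_i
post = np.array([0.5, -0.5])        # postsynaptic activations o_j
w = plain_hebbian_step(w, pre, post, eta=0.1)
```

Synapses whose two neurons are active with the same sign are strengthened, those with opposite signs are weakened, and a silent presynaptic neuron leaves its synapses unchanged.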

In [8]:
class HebbianNetwork(nn.Module):
    "A simple MLP without bias"
    def __init__(self, input_space, action_space):
        super(HebbianNetwork, self).__init__()

        self.l1 = nn.Linear(input_space, 128, bias=False)
        self.l2 = nn.Linear(128, 64, bias=False)
        self.l3 = nn.Linear(64, action_space, bias=False)

    def forward(self, x):
        state = torch.as_tensor(x).float().detach()
        
        x1 = torch.tanh(self.l1(state))   
        x2 = torch.tanh(self.l2(x1))
        o = self.l3(x2)  
         
        return state, x1, x2, o
    
# NUMBA JIT-Compilation goes brrrrrrrrrrr
@njit
def hebbian_update_rule(heb_coeffs, weights1_2, weights2_3, weights3_4, o0, o1, o2, o3):
       
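        # heb_coeffs holds one row per synapse: columns 0-2 are the correlation,
        # presynaptic and postsynaptic coefficients, column 3 is the learning rate,
        # column 4 the bias inside the learning-rate term, column 5 an additive offset.
        # The weight matrices use the PyTorch layout (n_post, n_pre).
        heb_offset = 0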
        # Layer 1         
        for i in range(weights1_2.shape[1]): 
            for j in range(weights1_2.shape[0]):  
                idx = (weights1_2.shape[0]-1)*i + i + j
                weights1_2[:,i][j] += heb_coeffs[idx][3] * ( heb_coeffs[idx][0] * o0[i] * o1[j]
                                                           + heb_coeffs[idx][1] * o0[i] 
                                                           + heb_coeffs[idx][2]         * o1[j]  + heb_coeffs[idx][4]) + heb_coeffs[idx][5]

        heb_offset += weights1_2.shape[1] * weights1_2.shape[0]
        # Layer 2
        for i in range(weights2_3.shape[1]): 
            for j in range(weights2_3.shape[0]):  
                idx = heb_offset + (weights2_3.shape[0]-1)*i + i+j
                weights2_3[:,i][j] += heb_coeffs[idx][3] * ( heb_coeffs[idx][0] * o1[i] * o2[j]
                                                           + heb_coeffs[idx][1] * o1[i] 
                                                           + heb_coeffs[idx][2]         * o2[j]  + heb_coeffs[idx][4]) + heb_coeffs[idx][5]
    
        heb_offset += weights2_3.shape[1] * weights2_3.shape[0]
        # Layer 3
        for i in range(weights3_4.shape[1]): 
            for j in range(weights3_4.shape[0]):  
                idx = heb_offset + (weights3_4.shape[0]-1)*i + i+j 
                weights3_4[:,i][j] += heb_coeffs[idx][3] * ( heb_coeffs[idx][0] * o2[i] * o3[j]
                                                           + heb_coeffs[idx][1] * o2[i] 
                                                           + heb_coeffs[idx][2]         * o3[j]  + heb_coeffs[idx][4]) + heb_coeffs[idx][5]

        return weights1_2, weights2_3, weights3_4

Since the simplified hebbian update rule 

$$\Delta w_{ij} = \eta_w \cdot o_i o_j$$

cannot be used in supervised learning tasks because it is only a local update rule, we need to extend it.

The paper introduces four evolvable parameters for the Hebbian update rule:

- correlation term $A_w$
- presynaptic term $B_w$
- postsynaptic term $C_w$
- bias $D_w$

This leads to a modified update rule in which the coefficients $A_w$, $B_w$, $C_w$ and $D_w$, together with the learning rate $\eta_w$, define the update dynamics of the network weights:

<br>
<div style="text-align:center"><img src="../assets/hebbianrule.png"></div>
<br>

These coefficients can be evolved using a basic ES algorithm that maximizes the cumulative reward.
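
The loop-based `hebbian_update_rule` above can also be written in vectorised form, which may be easier to compare against the formula. The following is a sketch for a single layer, not part of the original implementation. It assumes the same conventions as the loops: six coefficient columns per synapse, row index `n_post * i + j`, and weights in the `(n_post, n_pre)` layout. The name `hebbian_update_vectorised` is hypothetical.

```python
import numpy as np

def hebbian_update_vectorised(coeffs, weights, pre, post):
    """Apply the generalised Hebbian rule to one layer without Python loops.

    coeffs  : (n_pre * n_post, 6) array, row n_post * i + j belongs to synapse i -> j
    weights : (n_post, n_pre) array
    pre     : (n_pre,) presynaptic activations o_i
    post    : (n_post,) postsynaptic activations o_j
    """
    n_post, n_pre = weights.shape
    # Reorder so that each coefficient plane has the same (n_post, n_pre) shape as weights
    planes = coeffs.reshape(n_pre, n_post, 6).transpose(2, 1, 0)
    A, B, C, eta, D, offset = planes
    delta = eta * (A * np.outer(post, pre) + B * pre[None, :] + C * post[:, None] + D) + offset
    return weights + delta
```

Unlike the in-place loops, this version returns a new array, so it does not modify the network parameters that share memory with `weights`.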

## 3.) **Evolution strategies**
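
The fitness function below scores one set of Hebbian coefficients. An evolution strategy then repeatedly perturbs the coefficients, evaluates each perturbation, and moves towards the better-scoring ones. The training loop itself is not shown in this section, so the following is only a generic sketch of such a loop. The names `evolve` and `toy_fitness`, the toy objective and all hyperparameters are assumptions for illustration, not the settings used in the paper.

```python
import numpy as np

def evolve(fitness_fn, n_params, pop_size=50, sigma=0.1, lr=0.02, generations=200, seed=0):
    """Maximise fitness_fn with a basic Gaussian-perturbation evolution strategy."""
    rng = np.random.default_rng(seed)
    theta = np.zeros(n_params)  # current parameter estimate
    for _ in range(generations):
        # Sample one Gaussian perturbation per population member
        noise = rng.standard_normal((pop_size, n_params))
        rewards = np.array([fitness_fn(theta + sigma * eps) for eps in noise])
        # Standardise rewards so the step size does not depend on the reward scale
        spread = rewards.std()
        if spread == 0.0:
            continue  # all candidates scored the same, so there is no search direction
        advantages = (rewards - rewards.mean()) / spread
        # Move theta along the reward-weighted average of the perturbations
        theta = theta + lr / (pop_size * sigma) * noise.T @ advantages
    return theta

# Toy stand-in for fitness_hebb: higher is better, maximum at `target`
target = np.array([1.0, -2.0, 0.5])

def toy_fitness(params):
    return -float(np.sum((params - target) ** 2))

best = evolve(toy_fitness, n_params=3)
```

In this notebook, `fitness_fn` would wrap `fitness_hebb`, reshaping the flat parameter vector into the per-synapse coefficient matrix before each evaluation.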

In [18]:
def fitness_hebb(environment: str, init_weights: str, evolved_parameters: np.ndarray) -> float:
    """
    Evaluate the policy network using some evolved parameters and environment.
    """
    def weights_init(m):
        if isinstance(m, torch.nn.Linear):
            if init_weights == 'xa_uni':  
                torch.nn.init.xavier_uniform_(m.weight.data, 0.3)
            elif init_weights == 'sparse':  
                torch.nn.init.sparse_(m.weight.data, 0.8)
            elif init_weights == 'uni':  
                torch.nn.init.uniform_(m.weight.data, -0.1, 0.1)
            elif init_weights == 'normal':  
                torch.nn.init.normal_(m.weight.data, 0, 0.024)
            elif init_weights == 'ka_uni':  
                torch.nn.init.kaiming_uniform_(m.weight.data, 3)
            elif init_weights == 'uni_big':
                torch.nn.init.uniform_(m.weight.data, -1, 1)
            elif init_weights == 'xa_uni_big':
                torch.nn.init.xavier_uniform_(m.weight.data)
            elif init_weights == 'ones':
                torch.nn.init.ones_(m.weight.data)
            elif init_weights == 'zeros':
                torch.nn.init.zeros_(m.weight.data)
            elif init_weights == 'default':
                pass
            
    # Unpack evolved parameters
    hebb_coeffs = evolved_parameters

    
    # disable the autograd system
    with torch.no_grad():
                    
        # Load environment
        try:
            env = gym.make(environment, verbose = 0)
        except Exception:
            # Not every environment accepts the `verbose` keyword
            env = gym.make(environment)
            
        # env.render()  # render bullet envs

        # get the input dimensions of the environment
        input_dim = env.observation_space.shape[0]
            
        # Determine action space dimension
        action_dim = env.action_space.shape[0]
        
        # Initialize policy network
        p = HebbianNetwork(input_dim, action_dim)          
          
        # Randomly sample initial weights from chosen distribution
        p.apply(weights_init)
        p = p.float()
        
        # Unpack network's weights
        weights1_2, weights2_3, weights3_4 = list(p.parameters())
            
        # JIT
        weights1_2 = weights1_2.detach().numpy()
        weights2_3 = weights2_3.detach().numpy()
        weights3_4 = weights3_4.detach().numpy()
        
        # reset the environment
        observation = env.reset() 

        # Burn-in phase for the bullet ant so it starts off from the floor
        if environment == 'AntBulletEnv-v0':
            action = np.zeros(8)
            for _ in range(40):
                __ = env.step(action)        
        
        # Normalize weights flag for non-bullet envs
        normalised_weights = environment[-12:-6] != 'Bullet'


        # Main loop
        neg_count = 0 # number of consecutive negative rewards received
        rew_ep = 0 # cumulative reward over an episode
        t = 0 # timestep
        
        while True:
            
            # For observation ∈ gym.spaces.Discrete, we one-hot encode the observation
            if isinstance(env.observation_space, Discrete): 
                observation = (observation == torch.arange(env.observation_space.n)).float()
            
            # Pass the observation as a 1-D vector so the activations can be indexed per neuron
            o0, o1, o2, o3 = p(observation)
            
            # JIT
            o0 = o0.numpy()
            o1 = o1.numpy()
            o2 = o2.numpy()
            
            # Convert the network output into an action
            if environment[-12:-6] == 'Bullet':
                o3 = torch.tanh(o3).numpy()
                action = o3
            else: 
                if isinstance(env.action_space, Box):
                    action = o3.numpy()                        
                    action = np.clip(action, env.action_space.low, env.action_space.high)  
                elif isinstance(env.action_space, Discrete):
                    action = torch.argmax(o3).item()
                o3 = o3.numpy()

            
            # Environment simulation step
            observation, reward, done, info = env.step(action)  
            if environment == 'AntBulletEnv-v0': reward = env.unwrapped.rewards[1] # Distance walked
            rew_ep += reward
            
            # env.render('human') # Gym envs
                                       
            # Early stopping conditions
            if environment[-12:-6] == 'Bullet':
                ## Special stopping condition for bullet envs
                # always run at least 200 timesteps
                if t > 200:
                    # after 200 timesteps: count how many negative
                    # rewards we receive in a row
                    neg_count = neg_count+1 if reward < 0.0 else 0
                    
                    # stop if the episode is done or we received more than 30 negative rewards in a row
                    if (done or neg_count > 30):
                        break
            else:
                if done:
                    break
            
            t += 1
            
            #### Episodic/Intra-life hebbian update of the weights
            weights1_2, weights2_3, weights3_4 = hebbian_update_rule(hebb_coeffs, weights1_2, weights2_3, weights3_4, o0, o1, o2, o3)
            

            # Normalise weights per layer
            if normalised_weights:
                # The NumPy weight arrays share memory with these parameters,
                # so the in-place division also rescales weights1_2, weights2_3 and weights3_4
                for param in p.parameters():
                    param.data /= param.data.abs().max()
        
        # close the environment
        env.close()

    return rew_ep
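
The indexing in `hebbian_update_rule` implies that `evolved_parameters` needs one row per synapse and six columns. As a rough sketch of how such an array could be created for a first evaluation, the helper below counts the synapses of a bias-free MLP from its layer sizes. The name `init_hebbian_coeffs`, the normal initialisation and its `scale` are assumptions, and the 4-dimensional observation and 2-dimensional action space are a hypothetical environment.

```python
import numpy as np

def init_hebbian_coeffs(layer_sizes, n_coeffs=6, scale=0.1, seed=0):
    """Random coefficient matrix with one row per synapse of a bias-free MLP."""
    n_synapses = sum(n_in * n_out for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]))
    rng = np.random.default_rng(seed)
    return rng.normal(0.0, scale, size=(n_synapses, n_coeffs))

# Hypothetical environment with 4 observation and 2 action dimensions;
# the hidden sizes 128 and 64 match HebbianNetwork.
coeffs = init_hebbian_coeffs([4, 128, 64, 2])
print(coeffs.shape)  # (8832, 6)
```

An array of this shape could then be passed as the third argument of `fitness_hebb`, provided the layer sizes match the chosen environment.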

# References
