In [1]:
import tasks
import models

import learning_rules
import learning_utils
from jax import random, numpy as jnp
from optax import losses
from typing import (
  Any,
  Callable,
  Dict,
  List,
  Optional,
  Sequence,
  Tuple,
  Iterable  
 )
from flax.typing import (PRNGKey)
import optax
from flax.training import train_state, orbax_utils
Array = jnp.ndarray
TrainState = train_state.TrainState

In [2]:
# Task seed and network size used by the cells below.
seed_task = 2            # seed forwarded to tasks.pattern_generation
n_ALIF = 0               # ALIF neuron count passed to models.LSSN
n_LIF = 100              # LIF neuron count passed to models.LSSN
n_rec = n_ALIF + n_LIF   # total neuron count (ALIF + LIF)

In [3]:
# Network under test; the model itself is configured with the autodiff e-prop rule.
model_1 = models.LSSN(
    n_ALIF=n_ALIF,
    n_LIF=n_LIF,
    n_out=1,
    thr=0.01,
    local_connectivity=True,
    learning_rule="e_prop_autodiff",
    sparse_connectivity=False,
    refractory_period=1,
)


In [4]:
# Materialise the pattern-generation batches into a list so they can be indexed.
# `frequencies` and `weights` carry five entries each.
task_batches = list(
    tasks.pattern_generation(
        n_batches=1,
        batch_size=1,
        seed=seed_task,
        frequencies=[0.5, 1., 2., 3., 4.],
        weights=[0.2, 0.2, 0.2, 0.2, 0.2],
        n_population=100,
        f_input=10,
        trial_dur=2000,
    )
)


In [5]:
# Only one batch was requested above (n_batches=1); use it for every gradient computation below.
batch = task_batches[0]

In [6]:
def optimization_loss(logits, labels, z, c_reg, f_target, trial_length):
  """Return the summed half squared-error task loss plus a firing-rate penalty.

  Args:
    logits: Model predictions compared against `labels`.
    labels: Regression targets (called "labels" here for uniformity). A 2-D
      array gets a trailing singleton axis appended before the comparison.
    z: Forwarded to `learning_utils.compute_firing_rate` (presumably the
      recorded spikes -- confirm against that helper).
    c_reg: Weight of the firing-rate regularization term.
    f_target: Target firing rate in Hz; divided by 1000 before use.
    trial_length: Forwarded to `learning_utils.compute_firing_rate`.

  Returns:
    Task loss plus regularization loss.
  """
  # Targets may arrive as (n_batch, n_t) while predictions are
  # (n_batch, n_t, n_out=1), so add the missing output axis when needed.
  targets = labels if labels.ndim != 2 else jnp.expand_dims(labels, axis=-1)

  # Half squared error, summed over batches and time.
  half_sq_err = 0.5 * losses.squared_error(targets=targets, predictions=logits)
  task_loss = jnp.sum(half_sq_err)

  # f_target is given in Hz, but the average rate is in spikes/ms
  # (Bellec 2020 also used the regularization target in spikes/ms).
  av_f_rate = learning_utils.compute_firing_rate(z=z, trial_length=trial_length)
  rate_gap = av_f_rate - f_target / 1000
  regularization_loss = 0.5 * c_reg * jnp.sum(jnp.square(rate_gap))

  return task_loss + regularization_loss
 

In [7]:
def get_initial_params(rng, model, input_shape):
  """Initialize `model` on an all-ones input and pick out three variable collections.

  Args:
    rng: Key handed to `model.init`.
    model: Object exposing `init(rng, x)` that returns a mapping of collections.
    input_shape: Shape of the all-ones dummy input used for initialization.

  Returns:
    A 3-tuple holding the 'params', 'eligibility params' and 'spatial params'
    collections, in that order.
  """
  collections = model.init(rng, jnp.ones(input_shape))
  wanted = ('params', 'eligibility params', 'spatial params')
  return tuple(collections[name] for name in wanted)
    

def get_init_eligibility_carries(rng, model, input_shape):
  """Delegate to `model.initialize_eligibility_carry` to build the initial carries.

  Per the original note, the default mode yields all-zero arrays; that
  behaviour lives in the model and is not enforced here.
  """
  initial_carries = model.initialize_eligibility_carry(rng, input_shape)
  return initial_carries

def get_init_error_grid(rng, model, input_shape):
  """Delegate to `model.initialize_grid` to build the initial error grid.

  Per the original note the grid starts out as zeros; that behaviour lives
  in the model and is not enforced here.
  """
  initial_grid = model.initialize_grid(rng=rng, input_shape=input_shape)
  return initial_grid

# Create a custom TrainState to include both params and other variable collections
class TrainStateEProp(TrainState):
  """ Personalized TrainState for e-prop with local connectivity

  Adds the extra collections and initial values that `create_train_state`
  obtains from the model, on top of the fields inherited from `TrainState`.
  """
  # 'eligibility params' collection returned by model.init.
  # NOTE(review): accessed elsewhere in this notebook as a two-level mapping
  # (e.g. ['ReadOut_0']['feedback_weights']), so the annotation is looser than the data.
  eligibility_params: Dict[str, Array]
  # 'spatial params' collection returned by model.init
  # (accessed as ['ALIFCell_0']['cells_loc'] elsewhere in this notebook).
  spatial_params: Dict[str, Array]
  # Result of model.initialize_eligibility_carry.
  init_eligibility_carries: Dict[str, Array]
  # Result of model.initialize_grid.
  init_error_grid: Array
  
def create_train_state(rng:PRNGKey, learning_rate:float, model, input_shape:Tuple[int,...])->train_state.TrainState:
  """Assemble the initial `TrainStateEProp` for `model`.

  Args:
    rng: Key that is split into three sub-keys, one per initializer.
    learning_rate: Learning rate of the Adam optimizer.
    model: Model providing `init`, `apply`, `initialize_eligibility_carry`
      and `initialize_grid`.
    input_shape: Shape forwarded to every initializer.

  Returns:
    A `TrainStateEProp` holding the parameters, the Adam optimizer and the
    extra e-prop collections / initial values.
  """
  params_key, carries_key, grid_key = random.split(rng, 3)

  params, eligibility_params, spatial_params = get_initial_params(params_key, model, input_shape)
  initial_carries = get_init_eligibility_carries(carries_key, model, input_shape)
  initial_grid = get_init_error_grid(grid_key, model, input_shape)

  optimizer = optax.adam(learning_rate=learning_rate)

  return TrainStateEProp.create(
      apply_fn=model.apply,
      params=params,
      tx=optimizer,
      eligibility_params=eligibility_params,
      spatial_params=spatial_params,
      init_eligibility_carries=initial_carries,
      init_error_grid=initial_grid,
  )

In [8]:
# Initial train state for model_1.
# NOTE(review): the trailing 100 in input_shape matches n_population of the task
# inputs -- presumably (batch, n_in); confirm.
state_1 = create_train_state(
    random.key(0),
    learning_rate=0.01,
    model=model_1,
    input_shape=(1, 100),
)


In [9]:
# Display the 'cells_loc' entry of the ALIFCell_0 spatial parameters.
state_1.spatial_params["ALIFCell_0"]["cells_loc"]

Array([[2, 3],
       [6, 0],
       [4, 3],
       [7, 8],
       [8, 8],
       [8, 5],
       [9, 4],
       [3, 1],
       [5, 0],
       [8, 7],
       [7, 2],
       [4, 6],
       [0, 4],
       [2, 1],
       [4, 9],
       [8, 2],
       [5, 1],
       [0, 6],
       [1, 2],
       [7, 4],
       [1, 6],
       [6, 2],
       [3, 9],
       [6, 6],
       [0, 0],
       [9, 3],
       [1, 7],
       [8, 0],
       [7, 6],
       [6, 5],
       [6, 8],
       [0, 9],
       [6, 1],
       [2, 7],
       [8, 9],
       [2, 6],
       [9, 5],
       [5, 5],
       [6, 3],
       [7, 3],
       [4, 7],
       [3, 3],
       [4, 5],
       [1, 9],
       [8, 3],
       [2, 9],
       [9, 1],
       [7, 9],
       [0, 3],
       [3, 2],
       [9, 8],
       [1, 5],
       [7, 0],
       [0, 2],
       [0, 1],
       [4, 0],
       [1, 3],
       [6, 9],
       [2, 2],
       [3, 4],
       [4, 1],
       [9, 6],
       [5, 2],
       [5, 3],
       [0, 7],
       [9, 7],
       [1,

In [10]:
# Display the feedback weights stored in the ReadOut_0 eligibility parameters.
state_1.eligibility_params['ReadOut_0']['feedback_weights']

Array([[-0.04560519],
       [-0.03792412],
       [ 0.02599526],
       [ 0.05817047],
       [ 0.06815916],
       [-0.06558562],
       [ 0.14032502],
       [-0.06831602],
       [-0.00951583],
       [ 0.03664244],
       [-0.04163158],
       [-0.0113758 ],
       [-0.10665695],
       [ 0.02838454],
       [ 0.01483625],
       [-0.07403877],
       [ 0.00496671],
       [-0.03907715],
       [ 0.05302151],
       [ 0.01195987],
       [ 0.06727689],
       [ 0.11256142],
       [-0.03522432],
       [-0.11596951],
       [ 0.06459189],
       [ 0.01746346],
       [-0.06061827],
       [ 0.11847678],
       [-0.07563414],
       [-0.07403923],
       [-0.08067524],
       [-0.14688464],
       [ 0.00432749],
       [-0.02054447],
       [-0.00952856],
       [ 0.00368338],
       [-0.10358795],
       [ 0.03402175],
       [-0.08424084],
       [-0.01010304],
       [ 0.04160201],
       [ 0.0536657 ],
       [-0.12492884],
       [-0.02599308],
       [ 0.02159717],
       [ 0

In [11]:
# Display the ReadOut_0 readout weights held in state_1.params.
# In the saved output the visible rows equal the feedback weights displayed by
# state_1.eligibility_params['ReadOut_0']['feedback_weights'].
state_1.params['ReadOut_0']['readout_weights']

Array([[-0.04560519],
       [-0.03792412],
       [ 0.02599526],
       [ 0.05817047],
       [ 0.06815916],
       [-0.06558562],
       [ 0.14032502],
       [-0.06831602],
       [-0.00951583],
       [ 0.03664244],
       [-0.04163158],
       [-0.0113758 ],
       [-0.10665695],
       [ 0.02838454],
       [ 0.01483625],
       [-0.07403877],
       [ 0.00496671],
       [-0.03907715],
       [ 0.05302151],
       [ 0.01195987],
       [ 0.06727689],
       [ 0.11256142],
       [-0.03522432],
       [-0.11596951],
       [ 0.06459189],
       [ 0.01746346],
       [-0.06061827],
       [ 0.11847678],
       [-0.07563414],
       [-0.07403923],
       [-0.08067524],
       [-0.14688464],
       [ 0.00432749],
       [-0.02054447],
       [-0.00952856],
       [ 0.00368338],
       [-0.10358795],
       [ 0.03402175],
       [-0.08424084],
       [-0.01010304],
       [ 0.04160201],
       [ 0.0536657 ],
       [-0.12492884],
       [-0.02599308],
       [ 0.02159717],
       [ 0

In [12]:
# Settings shared by the two compute_grads calls below.
LS_avail = 1                                 # forwarded to compute_grads as-is
c_reg = 0                                    # weight of the firing-rate regularization term
f_target = 10                                # target firing rate in Hz
optimization_loss_fn = optimization_loss     # loss used by compute_grads
task = "regression"
local_connectivity = True
learning_rule = "e_prop_hardcoded"           # rule used for the second (hard-coded) gradient run

In [13]:
# Reference run: gradients obtained with the autodiff e-prop rule.
logits_1, grads_1 = learning_rules.compute_grads(
    batch=batch,
    state=state_1,
    optimization_loss_fn=optimization_loss_fn,
    LS_avail=LS_avail,
    local_connectivity=local_connectivity,
    f_target=f_target,
    c_reg=c_reg,
    task=task,
    learning_rule="e_prop_autodiff",
)

In [14]:
# Display the gradients computed with learning_rule="e_prop_autodiff".
grads_1

{'ALIFCell_0': {'input_weights': Array([[ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
           0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
         [-6.0950006e-05, -3.8079525e-05,  1.2152641e-05, ...,
          -2.2642873e-05, -4.6810277e-05,  3.2011652e-05],
         [-4.8343264e-04, -1.1599910e-03,  1.0651440e-04, ...,
          -1.2212555e-04, -2.1682866e-03,  1.7053841e-03],
         ...,
         [-2.2688034e-06, -1.7935907e-06,  6.9821738e-07, ...,
          -1.0640733e-06, -1.7942228e-06,  1.2497868e-06],
         [-3.3404133e-03, -1.0846080e-03,  5.1133125e-04, ...,
          -1.6417631e-05, -4.0256837e-03,  1.1488820e-03],
         [-2.7728104e-04, -1.6355165e-04,  6.2852770e-05, ...,
          -8.3616789e-05, -1.8865784e-04,  1.3403961e-04]], dtype=float32),
  'recurrent_weights': Array([[-0.        , -0.        ,  0.00063526, ..., -0.        ,
          -0.        ,  0.        ],
         [-0.        , -0.        ,  0.        , ..., -0.        ,
          -0. 

In [15]:
# Same batch and state, but with the rule selected by `learning_rule`
# ("e_prop_hardcoded"), to be compared against grads_1.
logits_hard_1, hard_grads_1 = learning_rules.compute_grads(
    batch=batch,
    state=state_1,
    optimization_loss_fn=optimization_loss_fn,
    LS_avail=LS_avail,
    local_connectivity=local_connectivity,
    f_target=f_target,
    c_reg=c_reg,
    learning_rule=learning_rule,
    task=task,
)
 

r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r5.0
no_refFalse
r4.0
no_refFalse
r3.0
no_refFalse
r2.0
no_refFalse
r1.0
no_refFalse
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r5.0
no_refFalse
r4.0
no_refFalse
r3.0
no_refFalse
r2.0
no_refFalse
r1.0
no_refFalse
r0.0
no_refTrue
r5.0
no_refFalse
r4.0
no_refFalse
r3.0
no_refFalse
r2.0
no_refFalse
r1.0
no_refFalse
r0.0
no_refTrue
r5.0
no_refFalse
r4.0
no_refFalse
r3.0
no_refFalse
r2.0
no_refFalse
r1.0
no_refFalse
r0.0
no_refTrue
r5.0
no_refFalse
r4.0
no_refFalse
r3.0
no_refFalse
r2.0
no_refFalse
r1.0
no_refFalse
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r5.0
no_refFalse
r4.0
no_refFalse
r3.0
no_refFalse
r2.0
no_refFalse
r1.0
no_refFalse
r0.0
no_refTrue
r5.0
no_refFalse
r4.0
no_refFalse
r3.0
no_refFalse
r2.0
no

In [16]:
# Display the autodiff gradients again, next to the hard-coded ones shown in the following cell.
grads_1

{'ALIFCell_0': {'input_weights': Array([[ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
           0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
         [-6.0950006e-05, -3.8079525e-05,  1.2152641e-05, ...,
          -2.2642873e-05, -4.6810277e-05,  3.2011652e-05],
         [-4.8343264e-04, -1.1599910e-03,  1.0651440e-04, ...,
          -1.2212555e-04, -2.1682866e-03,  1.7053841e-03],
         ...,
         [-2.2688034e-06, -1.7935907e-06,  6.9821738e-07, ...,
          -1.0640733e-06, -1.7942228e-06,  1.2497868e-06],
         [-3.3404133e-03, -1.0846080e-03,  5.1133125e-04, ...,
          -1.6417631e-05, -4.0256837e-03,  1.1488820e-03],
         [-2.7728104e-04, -1.6355165e-04,  6.2852770e-05, ...,
          -8.3616789e-05, -1.8865784e-04,  1.3403961e-04]], dtype=float32),
  'recurrent_weights': Array([[-0.        , -0.        ,  0.00063526, ..., -0.        ,
          -0.        ,  0.        ],
         [-0.        , -0.        ,  0.        , ..., -0.        ,
          -0. 

In [17]:
# Display the gradients computed with learning_rule="e_prop_hardcoded".
hard_grads_1

{'ALIFCell_0': {'input_weights': Array([[-0.0000000e+00, -0.0000000e+00,  0.0000000e+00, ...,
          -0.0000000e+00, -0.0000000e+00,  0.0000000e+00],
         [-6.0950031e-05, -3.8079535e-05,  1.2152646e-05, ...,
          -2.2642862e-05, -4.6810310e-05,  3.2011660e-05],
         [-4.8343278e-04, -1.1599913e-03,  1.0651446e-04, ...,
          -1.2212562e-04, -2.1682866e-03,  1.7053838e-03],
         ...,
         [-2.2688041e-06, -1.7935918e-06,  6.9821823e-07, ...,
          -1.0640739e-06, -1.7942234e-06,  1.2497870e-06],
         [-3.3404136e-03, -1.0846083e-03,  5.1133125e-04, ...,
          -1.6417634e-05, -4.0256847e-03,  1.1488820e-03],
         [-2.7728113e-04, -1.6355171e-04,  6.2852778e-05, ...,
          -8.3616789e-05, -1.8865800e-04,  1.3403970e-04]], dtype=float32),
  'recurrent_weights': Array([[-0.        , -0.        ,  0.00062059, ..., -0.        ,
          -0.        ,  0.        ],
         [-0.        , -0.        ,  0.        , ..., -0.        ,
          -0. 

In [18]:
# Readout-weight gradients from the autodiff run.
read_out_1 = grads_1["ReadOut_0"]["readout_weights"]
read_out_1

Array([[1.40619893e-02],
       [3.73633243e-02],
       [4.00140882e-03],
       [7.78370071e-03],
       [3.43164196e-03],
       [2.69185416e-02],
       [1.78053379e-02],
       [1.07570766e-02],
       [1.83623459e-04],
       [5.88984340e-02],
       [1.17167095e-02],
       [3.36863362e-04],
       [3.66387069e-02],
       [0.00000000e+00],
       [4.74702986e-03],
       [1.03275953e-02],
       [2.55230028e-04],
       [3.73409241e-02],
       [5.55290608e-03],
       [9.99412686e-03],
       [0.00000000e+00],
       [2.56332755e-02],
       [2.98072882e-06],
       [3.65882888e-02],
       [4.90731932e-03],
       [5.28669432e-02],
       [0.00000000e+00],
       [8.10245638e-06],
       [4.63425787e-03],
       [3.42468247e-02],
       [1.12230708e-04],
       [8.60731525e-05],
       [0.00000000e+00],
       [2.54161619e-02],
       [8.31036246e-04],
       [1.63902082e-02],
       [3.43634607e-03],
       [3.87334377e-02],
       [5.92105724e-02],
       [1.81004629e-02],


In [19]:
# Readout-weight gradients from the hard-coded run, for comparison with read_out_1.
read_out_hard_1 = hard_grads_1["ReadOut_0"]["readout_weights"]
read_out_hard_1

Array([[1.40619902e-02],
       [3.73633243e-02],
       [4.00140882e-03],
       [7.78370071e-03],
       [3.43164150e-03],
       [2.69185379e-02],
       [1.78053398e-02],
       [1.07570784e-02],
       [1.83623531e-04],
       [5.88984340e-02],
       [1.17167104e-02],
       [3.36863566e-04],
       [3.66387106e-02],
       [0.00000000e+00],
       [4.74703033e-03],
       [1.03275953e-02],
       [2.55230145e-04],
       [3.73409204e-02],
       [5.55290608e-03],
       [9.99412499e-03],
       [0.00000000e+00],
       [2.56332736e-02],
       [2.98073087e-06],
       [3.65882851e-02],
       [4.90732212e-03],
       [5.28669432e-02],
       [0.00000000e+00],
       [8.10246183e-06],
       [4.63425880e-03],
       [3.42468247e-02],
       [1.12230715e-04],
       [8.60731670e-05],
       [0.00000000e+00],
       [2.54161693e-02],
       [8.31036887e-04],
       [1.63902063e-02],
       [3.43634584e-03],
       [3.87334339e-02],
       [5.92105761e-02],
       [1.81004591e-02],


In [20]:
# NOTE(review): despite the `recurrent_` prefix this reads the "input_weights"
# gradients, not "recurrent_weights" -- confirm which matrix is meant to be compared.
# In the saved outputs of grads_1 / hard_grads_1 the first visible non-zero
# "recurrent_weights" entry differs (0.00063526 vs 0.00062059), so that matrix
# deserves its own check.
recurrent_1 = grads_1['ALIFCell_0']["input_weights"]
# Index tuple selecting the non-zero entries of the autodiff gradients.
mask = jnp.where(recurrent_1!=0.)
recurrent_1[mask]

Array([-6.0950006e-05, -3.8079525e-05,  1.2152641e-05, ...,
       -8.3616789e-05, -1.8865784e-04,  1.3403961e-04], dtype=float32)

In [21]:
# NOTE(review): as with recurrent_1, this reads "input_weights" despite the name.
recurrent_hard_1 = hard_grads_1['ALIFCell_0']["input_weights"]
# Overwrites `mask` from the previous cell: from here on the mask selects the
# non-zero entries of the hard-coded gradients, and later cells index BOTH
# gradient arrays with it.
mask = jnp.where(recurrent_hard_1!=0.)
recurrent_hard_1[mask]

Array([-6.0950031e-05, -3.8079535e-05,  1.2152646e-05, ...,
       -8.3616789e-05, -1.8865800e-04,  1.3403970e-04], dtype=float32)

In [22]:
# Element-wise agreement between the hard-coded and the autodiff gradients.
# A relative tolerance is used because the gradients span several orders of
# magnitude (about 1e-6 to 1e-3 in the saved outputs); a fixed absolute
# threshold of 1e-4 would accept almost any value for the small entries.
# `mask` selects the non-zero entries of recurrent_hard_1.
is_correct = jnp.isclose(recurrent_hard_1[mask], recurrent_1[mask], rtol=1e-4, atol=1e-8)

In [23]:
# Number of compared entries that agree; should equal is_correct.size.
is_correct.sum()

Array(8473, dtype=int32)

In [24]:
# Total number of compared entries.
is_correct.size

8473

In [25]:
# Largest absolute deviation between the two gradient estimates. The absolute
# value matters: the max of the signed difference would ignore every entry
# where the hard-coded gradient is smaller than the autodiff one.
jnp.max(jnp.absolute(recurrent_hard_1[mask] - recurrent_1[mask]))

Array(2.7939677e-09, dtype=float32)