In [1]:
import tasks
import models

import learning_rules
import learning_utils
from jax import random, numpy as jnp
from optax import losses
from typing import (
  Any,
  Callable,
  Dict,
  List,
  Optional,
  Sequence,
  Tuple,
  Iterable  
 )
from flax.typing import (PRNGKey)
import optax
from flax.training import train_state, orbax_utils
Array = jnp.ndarray
TrainState = train_state.TrainState

In [2]:
# Seed forwarded to tasks.pattern_generation when the task batches are built.
seed_task = 2

# Population sizes. n_ALIF is zero here, so every recurrent unit is a LIF unit.
n_LIF = 100
n_ALIF = 0
n_rec = n_LIF + n_ALIF  # total recurrent population size

In [3]:
# Network under test; population sizes come from the configuration cell.
model_1 = models.LSSN(
    n_ALIF=n_ALIF,
    n_LIF=n_LIF,
    n_out=1,
    thr=0.01,
    local_connectivity=True,
    learning_rule="e_prop_autodiff",
    sparse_connectivity=False,
    refractory_period=1,
)


In [4]:
# Materialise the iterable returned by the task so the single batch can be indexed and reused.
task_batches = list(
    tasks.pattern_generation(
        n_batches=1,
        batch_size=1,
        seed=seed_task,
        frequencies=[0.5, 1., 2., 3., 4.],
        weights=[0.2, 0.2, 0.2, 0.2, 0.2],
        n_population=100,
        f_input=10,
        trial_dur=2000,
    )
)


In [5]:
# First (and, with n_batches=1, only) batch; both gradient computations below receive it.
batch = task_batches[0]

In [6]:
def optimization_loss(logits, labels, z, c_reg, f_target, trial_length):
  """Return the summed squared-error task loss plus a firing-rate regularization term.

  Args:
    logits: model predictions.
    labels: regression targets (named `labels` here); a 2-D array gets a trailing
      singleton axis appended so that it lines up with predictions carrying an output axis.
    z: forwarded to `learning_utils.compute_firing_rate`.
    c_reg: weight of the firing-rate regularization term.
    f_target: target firing rate; the original note says it is given in Hz.
    trial_length: forwarded to `learning_utils.compute_firing_rate`.

  Returns:
    `task_loss + regularization_loss`.
  """
  # Targets may arrive as (n_batch, n_t) while predictions are (n_batch, n_t, n_out=1).
  targets = jnp.expand_dims(labels, axis=-1) if labels.ndim == 2 else labels

  # Half squared error, summed over every element (batches and time).
  squared_errors = losses.squared_error(targets=targets, predictions=logits)
  task_loss = jnp.sum(0.5 * squared_errors)

  # Per the original note: f_target is in Hz while the computed rate is in spikes/ms
  # (Bellec 2020 also expressed the regularization target in spikes/ms), hence the /1000.
  target_rate = f_target / 1000
  firing_rate = learning_utils.compute_firing_rate(z=z, trial_length=trial_length)
  regularization_loss = 0.5 * c_reg * jnp.sum(jnp.square(firing_rate - target_rate))

  return task_loss + regularization_loss
 

In [7]:
def get_initial_params(rng, model, input_shape):
  """Initialise the model on an all-ones dummy input and return three of its collections.

  Args:
    rng: key passed straight to `model.init`.
    model: object exposing `init(rng, x)` that returns a mapping of collections.
    input_shape: shape of the dummy input.

  Returns:
    Tuple of the 'params', 'eligibility params' and 'spatial params' collections, in that order.
  """
  collections = model.init(rng, jnp.ones(input_shape))
  wanted = ('params', 'eligibility params', 'spatial params')
  return tuple(collections[name] for name in wanted)
    

def get_init_eligibility_carries(rng, model, input_shape):
  """Return the initial eligibility carries produced by the model.

  Thin wrapper over `model.initialize_eligibility_carry(rng, input_shape)`.
  The original note states that in the default mode the carries are all-zero arrays;
  that behaviour lives in the model, not here.
  """
  initial_carries = model.initialize_eligibility_carry(rng, input_shape)
  return initial_carries

def get_init_error_grid(rng, model, input_shape):
  """Return the initial error grid produced by the model.

  Thin wrapper over `model.initialize_grid(rng=..., input_shape=...)`.
  The original note states the grid is initialised with zeros; that behaviour
  lives in the model, not here.
  """
  initial_grid = model.initialize_grid(rng=rng, input_shape=input_shape)
  return initial_grid

# Create a custom TrainState to include both params and other variable collections
class TrainStateEProp(TrainState):
  """TrainState extended with the extra collections used by e-prop with local connectivity.

  On top of the fields inherited from `TrainState`, it carries the non-trainable
  collections and the initial recurrent state that `create_train_state` builds.
  """
  # 'eligibility params' collection returned by model.init
  # (accessed in this notebook as eligibility_params['ReadOut_0']['feedback_weights']).
  eligibility_params: Dict[str, Array]
  # 'spatial params' collection returned by model.init
  # (accessed in this notebook as spatial_params["ALIFCell_0"]["cells_loc"]).
  spatial_params: Dict[str, Array]
  # Result of model.initialize_eligibility_carry.
  init_eligibility_carries: Dict[str, Array]
  # Result of model.initialize_grid.
  init_error_grid: Array
  
def create_train_state(rng:PRNGKey, learning_rate:float, model, input_shape:Tuple[int,...])->train_state.TrainState:
  """Build the initial `TrainStateEProp` for `model`.

  Args:
    rng: key that is split three ways: parameter init, eligibility carries, error grid.
    learning_rate: learning rate handed to the Adam optimizer.
    model: model providing `init`, `apply`, `initialize_eligibility_carry` and `initialize_grid`.
    input_shape: shape of the dummy input used for every initialisation.

  Returns:
    A `TrainStateEProp` holding the parameters, the Adam optimizer and the extra e-prop collections.
  """
  params_key, carries_key, grid_key = random.split(rng, 3)
  params, eligibility_params, spatial_params = get_initial_params(params_key, model, input_shape)

  return TrainStateEProp.create(
      apply_fn=model.apply,
      params=params,
      tx=optax.adam(learning_rate=learning_rate),
      eligibility_params=eligibility_params,
      spatial_params=spatial_params,
      init_eligibility_carries=get_init_eligibility_carries(carries_key, model, input_shape),
      init_error_grid=get_init_error_grid(grid_key, model, input_shape),
  )

In [8]:
# Fixed key so the initial state (and therefore both gradient computations) is reproducible.
state_1 = create_train_state(
    rng=random.key(0),
    learning_rate=0.01,
    model=model_1,
    input_shape=(1, 100),
)


In [9]:
# Inspect the 'cells_loc' entry of the spatial collection (displays the full array).
state_1.spatial_params["ALIFCell_0"]["cells_loc"]

Array([[2, 3],
       [6, 0],
       [4, 3],
       [7, 8],
       [8, 8],
       [8, 5],
       [9, 4],
       [3, 1],
       [5, 0],
       [8, 7],
       [7, 2],
       [4, 6],
       [0, 4],
       [2, 1],
       [4, 9],
       [8, 2],
       [5, 1],
       [0, 6],
       [1, 2],
       [7, 4],
       [1, 6],
       [6, 2],
       [3, 9],
       [6, 6],
       [0, 0],
       [9, 3],
       [1, 7],
       [8, 0],
       [7, 6],
       [6, 5],
       [6, 8],
       [0, 9],
       [6, 1],
       [2, 7],
       [8, 9],
       [2, 6],
       [9, 5],
       [5, 5],
       [6, 3],
       [7, 3],
       [4, 7],
       [3, 3],
       [4, 5],
       [1, 9],
       [8, 3],
       [2, 9],
       [9, 1],
       [7, 9],
       [0, 3],
       [3, 2],
       [9, 8],
       [1, 5],
       [7, 0],
       [0, 2],
       [0, 1],
       [4, 0],
       [1, 3],
       [6, 9],
       [2, 2],
       [3, 4],
       [4, 1],
       [9, 6],
       [5, 2],
       [5, 3],
       [0, 7],
       [9, 7],
       [1,

In [10]:
# Inspect the feedback weights stored in the eligibility collection (displays the full array).
state_1.eligibility_params['ReadOut_0']['feedback_weights']

Array([[-0.04560519],
       [-0.03792412],
       [ 0.02599526],
       [ 0.05817047],
       [ 0.06815916],
       [-0.06558562],
       [ 0.14032502],
       [-0.06831602],
       [-0.00951583],
       [ 0.03664244],
       [-0.04163158],
       [-0.0113758 ],
       [-0.10665695],
       [ 0.02838454],
       [ 0.01483625],
       [-0.07403877],
       [ 0.00496671],
       [-0.03907715],
       [ 0.05302151],
       [ 0.01195987],
       [ 0.06727689],
       [ 0.11256142],
       [-0.03522432],
       [-0.11596951],
       [ 0.06459189],
       [ 0.01746346],
       [-0.06061827],
       [ 0.11847678],
       [-0.07563414],
       [-0.07403923],
       [-0.08067524],
       [-0.14688464],
       [ 0.00432749],
       [-0.02054447],
       [-0.00952856],
       [ 0.00368338],
       [-0.10358795],
       [ 0.03402175],
       [-0.08424084],
       [-0.01010304],
       [ 0.04160201],
       [ 0.0536657 ],
       [-0.12492884],
       [-0.02599308],
       [ 0.02159717],
       [ 0

In [11]:
# Inspect the trainable readout weights, for comparison with the feedback weights shown in the previous cell.
state_1.params['ReadOut_0']['readout_weights']

Array([[-0.04560519],
       [-0.03792412],
       [ 0.02599526],
       [ 0.05817047],
       [ 0.06815916],
       [-0.06558562],
       [ 0.14032502],
       [-0.06831602],
       [-0.00951583],
       [ 0.03664244],
       [-0.04163158],
       [-0.0113758 ],
       [-0.10665695],
       [ 0.02838454],
       [ 0.01483625],
       [-0.07403877],
       [ 0.00496671],
       [-0.03907715],
       [ 0.05302151],
       [ 0.01195987],
       [ 0.06727689],
       [ 0.11256142],
       [-0.03522432],
       [-0.11596951],
       [ 0.06459189],
       [ 0.01746346],
       [-0.06061827],
       [ 0.11847678],
       [-0.07563414],
       [-0.07403923],
       [-0.08067524],
       [-0.14688464],
       [ 0.00432749],
       [-0.02054447],
       [-0.00952856],
       [ 0.00368338],
       [-0.10358795],
       [ 0.03402175],
       [-0.08424084],
       [-0.01010304],
       [ 0.04160201],
       [ 0.0536657 ],
       [-0.12492884],
       [-0.02599308],
       [ 0.02159717],
       [ 0

In [12]:
# --- Task / loss settings shared by both gradient computations ---
task = "regression"
optimization_loss_fn = optimization_loss
c_reg = 0        # zero weight: the firing-rate regularization term contributes nothing
f_target = 10    # target firing rate (optimization_loss divides it by 1000)

# --- Learning-rule settings ---
learning_rule = "e_prop_hardcoded"   # used by the second compute_grads call only
local_connectivity = True
LS_avail = 1

In [13]:
# Reference gradients: the learning rule is pinned to the autodiff variant for this call.
logits_1, grads_1 = learning_rules.compute_grads(
    batch=batch,
    state=state_1,
    optimization_loss_fn=optimization_loss_fn,
    LS_avail=LS_avail,
    local_connectivity=local_connectivity,
    f_target=f_target,
    c_reg=c_reg,
    task=task,
    learning_rule="e_prop_autodiff",
)

In [14]:
# Display the autodiff gradient tree.
grads_1

{'ALIFCell_0': {'input_weights': Array([[-8.80703919e-06, -9.15153032e-06,  5.47270110e-06, ...,
          -2.63295738e-06, -5.86285751e-06,  8.87808437e-06],
         [-3.32491996e-04, -3.27977730e-04,  4.81272618e-05, ...,
          -1.44875768e-04, -1.09715409e-04,  5.06320539e-05],
         [-8.36845047e-08, -9.41241964e-08,  5.71581289e-08, ...,
          -2.97909128e-08, -5.78572603e-08,  1.13658004e-07],
         ...,
         [-7.31408090e-06, -7.39851703e-06,  4.48012770e-06, ...,
          -2.30416390e-06, -5.19476635e-06,  7.79605580e-06],
         [-7.51006592e-05, -9.59246099e-05,  2.05276319e-05, ...,
          -2.62432332e-05, -1.46001148e-05,  3.16290170e-05],
         [-1.76975853e-04, -8.76424820e-05,  8.89411604e-06, ...,
          -7.95215892e-05, -8.65990369e-05,  1.28361266e-09]],      dtype=float32),
  'recurrent_weights': Array([[-0.        , -0.        ,  0.00023938, ..., -0.        ,
          -0.        ,  0.        ],
         [-0.        , -0.        ,  0. 

In [15]:
# Gradients under test: same batch, state and settings, but with the configured
# `learning_rule` ("e_prop_hardcoded" in the configuration cell).
logits_hard_1, hard_grads_1 = learning_rules.compute_grads(
    batch=batch,
    state=state_1,
    optimization_loss_fn=optimization_loss_fn,
    LS_avail=LS_avail,
    local_connectivity=local_connectivity,
    f_target=f_target,
    c_reg=c_reg,
    task=task,
    learning_rule=learning_rule,
)
 

r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r1.0
no_refFalse
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r1.0
no_refFalse
r0.0
no_refTrue
r1.0
no_refFalse
r0.0
no_refTrue
r1.0
no_refFalse
r0.0
no_refTrue
r1.0
no_refFalse
r0.0
no_refTrue
r1.0
no_refFalse
r0.0
no_refTrue
r1.0
no_refFalse
r0.0
no_refTrue
r1.0
no_refFalse
r0.0
no_refTrue
r1.0
no_refFalse
r0.0
no_refTrue
r1.0
no_refFalse
r0.0
no_refTrue
r1.0
no_refFalse
r0.0
no_refTrue
r1.0
no_refFalse
r0.0
no_refTrue
r1.0
no_refFalse
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r0.0
no_refTrue
r1.0
no_refFalse
r0.0
no_refTrue
r1.0
no_refFalse
r0.0
no_refTrue
r1.0
no_refFalse
r0.0
no_refTrue
r1.0
no_refFalse
r0.0
no_refTrue
r1.0
no_refFalse
r0.0
no_refTrue
r1.0
no_refFalse
r0.0
no_refTrue
r1.0


In [16]:
# Autodiff gradient tree shown again (same object as displayed earlier) for side-by-side reading with hard_grads_1.
grads_1

{'ALIFCell_0': {'input_weights': Array([[-8.80703919e-06, -9.15153032e-06,  5.47270110e-06, ...,
          -2.63295738e-06, -5.86285751e-06,  8.87808437e-06],
         [-3.32491996e-04, -3.27977730e-04,  4.81272618e-05, ...,
          -1.44875768e-04, -1.09715409e-04,  5.06320539e-05],
         [-8.36845047e-08, -9.41241964e-08,  5.71581289e-08, ...,
          -2.97909128e-08, -5.78572603e-08,  1.13658004e-07],
         ...,
         [-7.31408090e-06, -7.39851703e-06,  4.48012770e-06, ...,
          -2.30416390e-06, -5.19476635e-06,  7.79605580e-06],
         [-7.51006592e-05, -9.59246099e-05,  2.05276319e-05, ...,
          -2.62432332e-05, -1.46001148e-05,  3.16290170e-05],
         [-1.76975853e-04, -8.76424820e-05,  8.89411604e-06, ...,
          -7.95215892e-05, -8.65990369e-05,  1.28361266e-09]],      dtype=float32),
  'recurrent_weights': Array([[-0.        , -0.        ,  0.00023938, ..., -0.        ,
          -0.        ,  0.        ],
         [-0.        , -0.        ,  0. 

In [17]:
# Display the hard-coded e-prop gradient tree.
hard_grads_1

{'ALIFCell_0': {'input_weights': Array([[-8.80703556e-06, -9.15152668e-06,  5.47270020e-06, ...,
          -2.63295806e-06, -5.86285660e-06,  8.87808619e-06],
         [-3.32491967e-04, -3.27977701e-04,  4.81272691e-05, ...,
          -1.44875798e-04, -1.09715424e-04,  5.06320393e-05],
         [-8.36844904e-08, -9.41241893e-08,  5.71581147e-08, ...,
          -2.97909093e-08, -5.78572497e-08,  1.13657990e-07],
         ...,
         [-7.31407954e-06, -7.39851930e-06,  4.48012770e-06, ...,
          -2.30416481e-06, -5.19476498e-06,  7.79605853e-06],
         [-7.51006446e-05, -9.59245954e-05,  2.05276392e-05, ...,
          -2.62432350e-05, -1.46001175e-05,  3.16290134e-05],
         [-1.76975838e-04, -8.76424820e-05,  8.89411604e-06, ...,
          -7.95215819e-05, -8.65990296e-05,  1.28361244e-09]],      dtype=float32),
  'recurrent_weights': Array([[-0.      , -0.      ,  0.000228, ..., -0.      , -0.      ,
           0.      ],
         [-0.      , -0.      ,  0.      , ..., -0. 

In [18]:
# Readout-weight gradient from the autodiff computation.
read_out_1 = grads_1["ReadOut_0"]["readout_weights"]
read_out_1

Array([[8.8306488e-03],
       [8.3414717e-03],
       [1.0125654e-02],
       [9.4573377e-03],
       [2.3461283e-04],
       [5.0701277e-04],
       [4.4728112e-03],
       [3.4132447e-07],
       [3.5178369e-09],
       [9.6538354e-04],
       [5.5242490e-06],
       [4.1064304e-13],
       [5.6115100e-03],
       [4.8615567e-07],
       [6.5042614e-04],
       [5.8006105e-04],
       [1.1671656e-20],
       [8.5515445e-03],
       [9.4841339e-04],
       [1.0687790e-02],
       [3.0216956e-05],
       [5.3833053e-03],
       [4.1341066e-04],
       [7.4187471e-03],
       [4.4809235e-06],
       [9.7383503e-03],
       [2.2185710e-05],
       [2.9387799e-19],
       [2.1092321e-03],
       [9.4467383e-03],
       [8.6621367e-07],
       [1.2010108e-04],
       [9.1287438e-06],
       [1.0702457e-02],
       [2.0975018e-14],
       [6.4441059e-03],
       [1.4479439e-03],
       [6.3246973e-03],
       [2.7900122e-03],
       [8.4819039e-03],
       [7.0019397e-03],
       [1.043535

In [19]:
# Readout-weight gradient from the hard-coded computation.
read_out_hard_1 = hard_grads_1["ReadOut_0"]["readout_weights"]
read_out_hard_1

Array([[8.8306516e-03],
       [8.3414745e-03],
       [1.0125658e-02],
       [9.4573414e-03],
       [2.3461280e-04],
       [5.0701271e-04],
       [4.4728108e-03],
       [3.4132404e-07],
       [3.5178318e-09],
       [9.6538319e-04],
       [5.5242499e-06],
       [4.1064244e-13],
       [5.6115100e-03],
       [4.8615567e-07],
       [6.5042620e-04],
       [5.8006094e-04],
       [1.1671647e-20],
       [8.5515510e-03],
       [9.4841345e-04],
       [1.0687790e-02],
       [3.0216950e-05],
       [5.3833057e-03],
       [4.1341071e-04],
       [7.4187461e-03],
       [4.4809235e-06],
       [9.7383559e-03],
       [2.2185695e-05],
       [2.9387776e-19],
       [2.1092319e-03],
       [9.4467411e-03],
       [8.6621361e-07],
       [1.2010103e-04],
       [9.1287411e-06],
       [1.0702459e-02],
       [2.0974994e-14],
       [6.4441059e-03],
       [1.4479440e-03],
       [6.3246978e-03],
       [2.7900110e-03],
       [8.4819086e-03],
       [7.0019397e-03],
       [1.043536

In [20]:
# NOTE(review): despite the `recurrent_` name, this selects the *input_weights* gradient
# of the autodiff result — confirm whether "recurrent_weights" was intended.
recurrent_1 = grads_1['ALIFCell_0']["input_weights"]
# Index tuple of the non-zero entries; indexing with it yields a flat 1-D selection.
mask = jnp.where(recurrent_1!=0.)
recurrent_1[mask]

Array([-8.8070392e-06, -9.1515303e-06,  5.4727011e-06, ...,
       -7.9521589e-05, -8.6599037e-05,  1.2836127e-09], dtype=float32)

In [21]:
# NOTE(review): despite the `recurrent_` name, this selects the *input_weights* gradient
# of the hard-coded result — confirm whether "recurrent_weights" was intended.
recurrent_hard_1 = hard_grads_1['ALIFCell_0']["input_weights"]
# `mask` is rebound here: the cells below select entries by the non-zero pattern of the
# hard-coded gradient, not the autodiff one.
mask = jnp.where(recurrent_hard_1!=0.)
recurrent_hard_1[mask]

Array([-8.8070356e-06, -9.1515267e-06,  5.4727002e-06, ...,
       -7.9521582e-05, -8.6599030e-05,  1.2836124e-09], dtype=float32)

In [22]:
# Element-wise agreement between the hard-coded and autodiff gradients.
# A bare absolute tolerance of 1e-4 is as large as (or larger than) most of the
# gradient values being compared, so it would accept almost anything. Use a
# relative tolerance instead, with a small absolute floor for near-zero entries:
# an entry passes when |hard - autodiff| <= atol + rtol * |autodiff|.
is_correct = jnp.isclose(recurrent_hard_1[mask], recurrent_1[mask], rtol=1e-3, atol=1e-8)

In [23]:
# Number of compared entries that agree; should equal is_correct.size (next cell).
is_correct.sum()

Array(10000, dtype=int32)

In [24]:
# Total number of compared entries.
is_correct.size

10000

In [25]:
# Largest absolute deviation between the two gradients. Taking the max of the
# signed difference would ignore every entry where the hard-coded value is
# smaller than the autodiff one, however large that gap is.
jnp.max(jnp.abs(recurrent_hard_1[mask]-recurrent_1[mask]))

Array(3.4924597e-10, dtype=float32)