In [1]:
import tasks
import models

import learning_rules
import learning_utils
from jax import random, numpy as jnp
from optax import losses
from typing import (
  Any,
  Callable,
  Dict,
  List,
  Optional,
  Sequence,
  Tuple,
  Iterable  
 )
from flax.typing import (PRNGKey)
import optax
from flax.training import train_state, orbax_utils
Array = jnp.ndarray
TrainState = train_state.TrainState

In [2]:
seed_task = 2
n_ALIF=0
n_LIF=100
n_rec= n_ALIF + n_LIF

In [3]:
model_1 = models.LSSN(n_ALIF=n_ALIF, n_LIF=n_LIF, n_out=1, local_connectivity=True, learning_rule="e_prop_autodiff", sparse_readout_connectivity=True, feedback="Random")


In [4]:
task_batches = list(tasks.pattern_generation(n_batches=64, batch_size=8, seed=seed_task, frequencies=[0.5, 1., 2., 3., 4.],
                                             weights=[0.2,0.2,0.2,0.2,0.2], n_population=100,
                                             f_input=10, trial_dur=200))


In [5]:
batch = task_batches[0]

In [6]:
def optimization_loss(logits, labels, z, c_reg, f_target, trial_length):
    
  if labels.ndim==2: # calling labels what normally people call targets in regression tasks
      labels = jnp.expand_dims(labels, axis=-1) # this is necessary because target labels might have only (n_batch, n_t) and predictions (n_batch, n_t, n_out=1)

  task_loss = jnp.sum(0.5 *losses.squared_error(targets=labels, predictions=logits)) # sum over batches and time
    
  av_f_rate = learning_utils.compute_firing_rate(z=z, trial_length=trial_length)
  f_target = f_target / 1000 # f_target is given in Hz, bu av_f_rate is spikes/ms --> Bellec 2020 used the f_reg also in spikes/ms
  regularization_loss = 0.5 * c_reg * jnp.sum(jnp.square(av_f_rate - f_target))
  return task_loss + regularization_loss
 

In [7]:
def get_initial_params(rng, model, input_shape):
  """Returns randomly initialized parameters, eligibility parameters and connectivity mask."""
  dummy_x = jnp.ones(input_shape)
  variables = model.init(rng, dummy_x)
  return variables['params'], variables['eligibility params'], variables['spatial params']
    

def get_init_eligibility_carries(rng, model, input_shape):
  """Returns randomly initialized carries. In the default mode, they are all initialized as zeros arrays"""
  return model.initialize_eligibility_carry(rng, input_shape)

def get_init_error_grid(rng, model, input_shape):
   """Return initial error grid initialized as zeros"""
   return model.initialize_grid(rng=rng, input_shape=input_shape)

# Create a custom TrainState to include both params and other variable collections
class TrainStateEProp(TrainState):
  """ Personalized TrainState for e-prop with local connectivity """
  eligibility_params: Dict[str, Array]
  spatial_params: Dict[str, Array]
  init_eligibility_carries: Dict[str, Array]
  init_error_grid: Array
  
def create_train_state(rng:PRNGKey, learning_rate:float, model, input_shape:Tuple[int,...])->train_state.TrainState:
  """Create initial training state."""
  key1, key2, key3 = random.split(rng, 3)
  params, eligibility_params, spatial_params = get_initial_params(key1, model, input_shape)
  init_eligibility_carries = get_init_eligibility_carries(key2, model, input_shape)
  init_error_grid = get_init_error_grid(key3, model, input_shape)

  tx = optax.adam(learning_rate=learning_rate)

  state = TrainStateEProp.create(apply_fn=model.apply, params=params, tx=tx, 
                                  eligibility_params=eligibility_params,
                                  spatial_params = spatial_params,
                                  init_eligibility_carries=init_eligibility_carries,                                  
                                  init_error_grid=init_error_grid
                                  )
  return state

In [8]:
state_1 = create_train_state(random.key(0), learning_rate=0.01, model=model_1, input_shape=(8,100))


In [9]:
state_1.spatial_params["ALIFCell_0"]["M"]

Array([[0., 0., 1., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

In [10]:
state_1.eligibility_params['ReadOut_0']['feedback_weights']

Array([[ 0.        ],
       [ 0.        ],
       [ 0.        ],
       [-0.        ],
       [-0.        ],
       [-0.03255503],
       [-0.        ],
       [ 0.        ],
       [-0.        ],
       [-0.        ],
       [ 0.        ],
       [-0.        ],
       [ 0.05404093],
       [ 0.        ],
       [ 0.        ],
       [ 0.        ],
       [ 0.        ],
       [-0.        ],
       [-0.        ],
       [-0.1242374 ],
       [-0.11830758],
       [-0.        ],
       [ 0.03595912],
       [ 0.        ],
       [ 0.        ],
       [-0.        ],
       [ 0.        ],
       [-0.        ],
       [-0.        ],
       [ 0.        ],
       [-0.        ],
       [-0.        ],
       [-0.        ],
       [-0.        ],
       [-0.        ],
       [ 0.        ],
       [ 0.        ],
       [-0.        ],
       [-0.        ],
       [-0.        ],
       [ 0.        ],
       [ 0.        ],
       [ 0.        ],
       [-0.        ],
       [-0.        ],
       [ 0

In [11]:
state_1.params['ReadOut_0']['readout_weights']

Array([[-0.        ],
       [-0.        ],
       [ 0.        ],
       [ 0.        ],
       [ 0.        ],
       [-0.06558562],
       [ 0.        ],
       [-0.        ],
       [-0.        ],
       [ 0.        ],
       [-0.        ],
       [-0.        ],
       [-0.10665695],
       [ 0.        ],
       [ 0.        ],
       [-0.        ],
       [ 0.        ],
       [-0.        ],
       [ 0.        ],
       [ 0.01195987],
       [ 0.06727689],
       [ 0.        ],
       [-0.03522432],
       [-0.        ],
       [ 0.        ],
       [ 0.        ],
       [-0.        ],
       [ 0.        ],
       [-0.        ],
       [-0.        ],
       [-0.        ],
       [-0.        ],
       [ 0.        ],
       [-0.        ],
       [-0.        ],
       [ 0.        ],
       [-0.        ],
       [ 0.        ],
       [-0.        ],
       [-0.        ],
       [ 0.        ],
       [ 0.        ],
       [-0.        ],
       [-0.        ],
       [ 0.        ],
       [ 0

In [12]:
LS_avail = 1
c_reg =0
f_target = 10
optimization_loss_fn = optimization_loss
task = "regression"
local_connectivity = True
learning_rule = "e_prop_hardcoded"

In [13]:
logits_1, grads_1 = learning_rules.compute_grads(batch=batch, state=state_1,optimization_loss_fn=optimization_loss_fn,
                                                  LS_avail=LS_avail, local_connectivity=local_connectivity, 
                                                  f_target=f_target, c_reg=c_reg, task=task, learning_rule="BPTT")

In [14]:
grads_1

{'ALIFCell_0': {'input_weights': Array([[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]], dtype=float32),
  'recurrent_weights': Array([[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)},
 'ReadOut_0': {'readout_weights': Array([[0.0000000e+00],
         [0.0000000e+00],
         [0.0000000e+00],
         [0.0000000e+00],
         [0.0000000e+00],
         [6.9610369e-03],
         [0.0000000e+00],
         [0.0000000e+00],
         [0.0000000e+00],
         [0.0000000e+00],
         [0.0000000e+00],
         [0.0000000e+00],
         [1.6371243e-02],
         [0.0000000e+00],
         

In [15]:
logits_hard_1,hard_grads_1 = learning_rules.compute_grads(batch=batch, state=state_1,optimization_loss_fn=optimization_loss_fn,
                                                  LS_avail=LS_avail, local_connectivity=local_connectivity, 
                                                  f_target=f_target, c_reg=c_reg, learning_rule=learning_rule, task=task)
 

In [16]:
grads_1

{'ALIFCell_0': {'input_weights': Array([[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]], dtype=float32),
  'recurrent_weights': Array([[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)},
 'ReadOut_0': {'readout_weights': Array([[0.0000000e+00],
         [0.0000000e+00],
         [0.0000000e+00],
         [0.0000000e+00],
         [0.0000000e+00],
         [6.9610369e-03],
         [0.0000000e+00],
         [0.0000000e+00],
         [0.0000000e+00],
         [0.0000000e+00],
         [0.0000000e+00],
         [0.0000000e+00],
         [1.6371243e-02],
         [0.0000000e+00],
         

In [17]:
hard_grads_1

{'ALIFCell_0': {'input_weights': Array([[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]], dtype=float32),
  'recurrent_weights': Array([[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)},
 'ReadOut_0': {'readout_weights': Array([[0.0000000e+00],
         [0.0000000e+00],
         [0.0000000e+00],
         [0.0000000e+00],
         [0.0000000e+00],
         [6.9610360e-03],
         [0.0000000e+00],
         [0.0000000e+00],
         [0.0000000e+00],
         [0.0000000e+00],
         [0.0000000e+00],
         [0.0000000e+00],
         [1.6371248e-02],
         [0.0000000e+00],
         

In [18]:
read_out_1 = grads_1["ReadOut_0"]["readout_weights"]
read_out_1

Array([[0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [6.9610369e-03],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [1.6371243e-02],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [3.5518743e-02],
       [7.1204957e-03],
       [0.0000000e+00],
       [2.0489402e-02],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.000000

In [19]:
read_out_hard_1 = hard_grads_1["ReadOut_0"]["readout_weights"]
read_out_hard_1

Array([[0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [6.9610360e-03],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [1.6371248e-02],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [3.5518747e-02],
       [7.1204952e-03],
       [0.0000000e+00],
       [2.0489404e-02],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.0000000e+00],
       [0.000000

In [20]:
recurrent_1 = grads_1['ALIFCell_0']["recurrent_weights"]
mask = jnp.where(recurrent_1!=0.)
recurrent_1[mask]

Array([-1.70486502e-03,  1.14710128e-04,  1.68750118e-02, -1.66808942e-03,
        7.37458642e-04,  7.92724371e-04, -2.36279739e-04, -7.73877080e-04,
        1.73533670e-04, -4.51537408e-03, -3.34914662e-02,  7.30470783e-05,
        1.52425957e-03, -1.01388544e-02, -3.65044661e-02, -1.96429924e-03,
        2.04404145e-02, -1.49989908e-03, -1.37558244e-02, -1.97831895e-02,
        3.31572676e-03,  6.23207772e-03, -5.06458152e-03, -5.88215946e-04,
        2.79734639e-04,  1.16249174e-01, -3.79546769e-02,  5.58382971e-03,
        1.38015794e-05, -1.78749524e-02,  3.69736669e-03, -7.77735695e-05,
       -1.12660140e-01, -5.74160926e-03,  8.49212956e-05,  5.08182857e-05,
        6.33286033e-03,  2.56799249e-05,  6.97871903e-04, -1.05879232e-01,
        7.76810013e-03, -1.08338362e-02,  3.47073264e-02, -1.01482165e-05,
        2.94138499e-06,  1.15032319e-03, -1.29582841e-05,  3.93618131e-04,
       -1.14191836e-03,  6.25929842e-03, -1.68486213e-06, -2.08576256e-03,
        4.88504916e-02,  

In [21]:
recurrent_hard_1 = hard_grads_1['ALIFCell_0']["recurrent_weights"]
mask = jnp.where(recurrent_hard_1!=0.)
recurrent_hard_1[mask]

Array([ 2.97516887e-03,  7.33273628e-05,  1.13964984e-02, -7.87615019e-04,
       -2.65054498e-03, -7.90267903e-03, -1.12213093e-04,  3.71844159e-04,
       -2.83957110e-04,  7.77982967e-03, -1.59908719e-02, -7.21795310e-04,
       -1.54135060e-02,  9.84556694e-03,  1.78435445e-02, -9.27475630e-04,
       -3.55835259e-02, -7.57000002e-04,  6.65499223e-03, -9.34838131e-03,
       -3.41948047e-02, -3.13010719e-03,  8.54015816e-03,  1.24614438e-04,
       -2.80999317e-04, -2.76014558e-03, -2.01285958e-01,  4.01456654e-02,
       -2.89392890e-03,  9.88308784e-06, -8.98042880e-03, -1.40120508e-02,
        1.31135894e-04,  5.65511510e-02,  6.29937975e-03,  5.77645369e-05,
       -1.82761840e-04,  4.24964540e-03, -1.28652782e-05,  4.68725717e-04,
        5.23915626e-02, -2.83078905e-02,  2.05818992e-02, -1.74675770e-02,
       -5.02639114e-06,  2.10627695e-06,  7.88795063e-04,  2.23432944e-05,
        3.17582249e-04,  1.10888563e-03, -2.88096294e-02,  1.63612413e-06,
       -1.00499531e-03, -

In [22]:
is_correct = jnp.absolute(recurrent_hard_1-recurrent_1) < 1e-3

In [23]:
is_correct.sum()

Array(9948, dtype=int32)