In [1]:
import tasks
import models
import models_2
import learning_rules
import learning_utils
from jax import random, numpy as jnp
from optax import losses
from typing import (
  Any,
  Callable,
  Dict,
  List,
  Optional,
  Sequence,
  Tuple,
  Iterable  
 )
from flax.typing import (PRNGKey)
import optax
from flax.training import train_state, orbax_utils
Array = jnp.ndarray
TrainState = train_state.TrainState

In [2]:
seed_task = 2
n_ALIF=0
n_LIF=100
n_rec= n_ALIF + n_LIF

In [3]:
model_1 = models.LSSN(n_ALIF=n_ALIF, n_LIF=n_LIF, n_out=1, local_connectivity=False, learning_rule="e_prop_autodiff")
model_2 = models_2.LSSN(n_ALIF=n_ALIF, n_LIF=n_LIF, n_out=1, local_connectivity=False)

In [4]:
task_batches = list(tasks.pattern_generation(n_batches=64, batch_size=8, seed=seed_task, frequencies=[0.5, 1., 2., 3., 4.],
                                             weights=[0.2,0.2,0.2,0.2,0.2], n_population=100,
                                             f_input=10, trial_dur=200))


In [5]:
batch = task_batches[0]

In [6]:
def optimization_loss(logits, labels, z, c_reg, f_target, trial_length):
    
  if labels.ndim==2: # calling labels what normally people call targets in regression tasks
      labels = jnp.expand_dims(labels, axis=-1) # this is necessary because target labels might have only (n_batch, n_t) and predictions (n_batch, n_t, n_out=1)

  task_loss = jnp.sum(0.5 *losses.squared_error(targets=labels, predictions=logits)) # sum over batches and time
    
  av_f_rate = learning_utils.compute_firing_rate(z=z, trial_length=trial_length)
  f_target = f_target / 1000 # f_target is given in Hz, bu av_f_rate is spikes/ms --> Bellec 2020 used the f_reg also in spikes/ms
  regularization_loss = 0.5 * c_reg * jnp.sum(jnp.square(av_f_rate - f_target))
  return task_loss + regularization_loss
 

In [7]:
def get_initial_params(rng, model, input_shape):
  """Returns randomly initialized parameters, eligibility parameters and connectivity mask."""
  dummy_x = jnp.ones(input_shape)
  variables = model.init(rng, dummy_x)
  return variables['params'], variables['eligibility params'], variables['spatial params']
    

def get_init_eligibility_carries(rng, model, input_shape):
  """Returns randomly initialized carries. In the default mode, they are all initialized as zeros arrays"""
  return model.initialize_eligibility_carry(rng, input_shape)

def get_init_error_grid(rng, model, input_shape):
   """Return initial error grid initialized as zeros"""
   return model.initialize_grid(rng=rng, input_shape=input_shape)

# Create a custom TrainState to include both params and other variable collections
class TrainStateEProp(TrainState):
  """ Personalized TrainState for e-prop with local connectivity """
  eligibility_params: Dict[str, Array]
  spatial_params: Dict[str, Array]
  init_eligibility_carries: Dict[str, Array]
  init_error_grid: Array
  
def create_train_state(rng:PRNGKey, learning_rate:float, model, input_shape:Tuple[int,...])->train_state.TrainState:
  """Create initial training state."""
  key1, key2, key3 = random.split(rng, 3)
  params, eligibility_params, spatial_params = get_initial_params(key1, model, input_shape)
  init_eligibility_carries = get_init_eligibility_carries(key2, model, input_shape)
  init_error_grid = get_init_error_grid(key3, model, input_shape)

  tx = optax.adam(learning_rate=learning_rate)

  state = TrainStateEProp.create(apply_fn=model.apply, params=params, tx=tx, 
                                  eligibility_params=eligibility_params,
                                  spatial_params = spatial_params,
                                  init_eligibility_carries=init_eligibility_carries,                                  
                                  init_error_grid=init_error_grid
                                  )
  return state

In [8]:
state_1 = create_train_state(random.key(0), learning_rate=0.01, model=model_1, input_shape=(8,100))
state_2 = create_train_state(random.key(0), learning_rate=0.01, model=model_2, input_shape=(8,100))

In [9]:
LS_avail = 1
c_reg =0
f_target = 10
optimization_loss_fn = optimization_loss
task = "regression"
local_connectivity = True
learning_rule = "e_prop_hardcoded"

In [10]:
logits_1, grads_1 = learning_rules.compute_grads(batch=batch, state=state_1,optimization_loss_fn=optimization_loss_fn,
                                                  LS_avail=LS_avail, local_connectivity=local_connectivity, 
                                                  f_target=f_target, c_reg=c_reg, task=task, learning_rule="BPTT")

In [11]:
logits_2, grads_2 = learning_rules.compute_grads(batch=batch, state=state_2,optimization_loss_fn=optimization_loss_fn,
                                                  LS_avail=LS_avail, local_connectivity=local_connectivity, 
                                                  f_target=f_target, c_reg=c_reg, task=task, learning_rule="BPTT")

In [12]:
grads_1

{'ALIFCell_0': {'input_weights': Array([[-0.05457361, -0.02295625,  0.01982296, ..., -0.0252021 ,
          -0.01912443,  0.04754121],
         [-0.0181329 , -0.01743844,  0.00800684, ..., -0.01261519,
          -0.01492508,  0.01584369],
         [-0.01866506, -0.04014814,  0.00088814, ..., -0.02177969,
          -0.0427818 ,  0.01979146],
         ...,
         [-0.02389489, -0.00630582,  0.0022009 , ..., -0.01223486,
          -0.00280007,  0.01192845],
         [-0.1142406 , -0.07078892,  0.01671964, ..., -0.0599874 ,
          -0.04771532,  0.08168091],
         [-0.06596739, -0.06327423,  0.00832309, ..., -0.05491588,
          -0.0384869 ,  0.05293014]], dtype=float32),
  'recurrent_weights': Array([[-0.0000000e+00, -4.0713986e-03,  1.3851686e-03, ...,
          -3.8676830e-03, -1.0424457e-03,  4.0551680e-03],
         [-3.5295784e-04, -0.0000000e+00,  6.1064311e-05, ...,
          -1.9661392e-04, -2.2969380e-05,  8.9647889e-05],
         [ 0.0000000e+00,  0.0000000e+00,  0.0000

In [13]:
logits_hard_1,hard_grads_1 = learning_rules.compute_grads(batch=batch, state=state_1,optimization_loss_fn=optimization_loss_fn,
                                                  LS_avail=LS_avail, local_connectivity=local_connectivity, 
                                                  f_target=f_target, c_reg=c_reg, learning_rule=learning_rule, task=task)
 

y_batch: [[[ 0.        ]
  [ 0.        ]
  [ 0.        ]
  ...
  [-0.08415048]
  [ 0.03300702]
  [ 0.03139725]]

 [[ 0.        ]
  [ 0.        ]
  [ 0.        ]
  ...
  [-0.15962838]
  [-0.15184322]
  [-0.14443775]]

 [[ 0.        ]
  [ 0.        ]
  [ 0.        ]
  ...
  [-0.0518639 ]
  [-0.04933447]
  [-0.0469284 ]]

 ...

 [[ 0.        ]
  [ 0.        ]
  [ 0.        ]
  ...
  [-0.08191347]
  [-0.07791851]
  [-0.03565005]]

 [[ 0.        ]
  [ 0.        ]
  [ 0.        ]
  ...
  [ 0.13617755]
  [ 0.18770657]
  [ 0.17855202]]

 [[ 0.        ]
  [ 0.        ]
  [ 0.        ]
  ...
  [ 0.11717977]
  [ 0.08088971]
  [ 0.07694467]]]
crop_trace: [[[0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 1.1109028e-02
   0.0000000e+00 2.2486849e-04 3.5084430e-02 0.0000000e+00 7.6884493e-02
   0.0000000e+00 0.0000000e+00 3.1745709e-02 0.0000000e+00 3.0275658e-03
   0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
   0.0000000e+00 1.5764458e-02 0.0000000e+00 0.0000000e+0

In [14]:
logits_hard_2, hard_grads_2= learning_rules.compute_grads(batch=batch, state=state_1,optimization_loss_fn=optimization_loss_fn,
                                                  LS_avail=LS_avail, local_connectivity=local_connectivity, 
                                                  f_target=f_target, c_reg=c_reg, learning_rule=learning_rule, task=task)

y_batch: [[[ 0.        ]
  [ 0.        ]
  [ 0.        ]
  ...
  [-0.08415048]
  [ 0.03300702]
  [ 0.03139725]]

 [[ 0.        ]
  [ 0.        ]
  [ 0.        ]
  ...
  [-0.15962838]
  [-0.15184322]
  [-0.14443775]]

 [[ 0.        ]
  [ 0.        ]
  [ 0.        ]
  ...
  [-0.0518639 ]
  [-0.04933447]
  [-0.0469284 ]]

 ...

 [[ 0.        ]
  [ 0.        ]
  [ 0.        ]
  ...
  [-0.08191347]
  [-0.07791851]
  [-0.03565005]]

 [[ 0.        ]
  [ 0.        ]
  [ 0.        ]
  ...
  [ 0.13617755]
  [ 0.18770657]
  [ 0.17855202]]

 [[ 0.        ]
  [ 0.        ]
  [ 0.        ]
  ...
  [ 0.11717977]
  [ 0.08088971]
  [ 0.07694467]]]
crop_trace: [[[0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 1.1109028e-02
   0.0000000e+00 2.2486849e-04 3.5084430e-02 0.0000000e+00 7.6884493e-02
   0.0000000e+00 0.0000000e+00 3.1745709e-02 0.0000000e+00 3.0275658e-03
   0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
   0.0000000e+00 1.5764458e-02 0.0000000e+00 0.0000000e+0

In [15]:
grads_1

{'ALIFCell_0': {'input_weights': Array([[-0.05457361, -0.02295625,  0.01982296, ..., -0.0252021 ,
          -0.01912443,  0.04754121],
         [-0.0181329 , -0.01743844,  0.00800684, ..., -0.01261519,
          -0.01492508,  0.01584369],
         [-0.01866506, -0.04014814,  0.00088814, ..., -0.02177969,
          -0.0427818 ,  0.01979146],
         ...,
         [-0.02389489, -0.00630582,  0.0022009 , ..., -0.01223486,
          -0.00280007,  0.01192845],
         [-0.1142406 , -0.07078892,  0.01671964, ..., -0.0599874 ,
          -0.04771532,  0.08168091],
         [-0.06596739, -0.06327423,  0.00832309, ..., -0.05491588,
          -0.0384869 ,  0.05293014]], dtype=float32),
  'recurrent_weights': Array([[-0.0000000e+00, -4.0713986e-03,  1.3851686e-03, ...,
          -3.8676830e-03, -1.0424457e-03,  4.0551680e-03],
         [-3.5295784e-04, -0.0000000e+00,  6.1064311e-05, ...,
          -1.9661392e-04, -2.2969380e-05,  8.9647889e-05],
         [ 0.0000000e+00,  0.0000000e+00,  0.0000

In [16]:
hard_grads_1

{'ALIFCell_0': {'input_weights': Array([[-0.05454997, -0.02295536,  0.01982296, ..., -0.02479655,
          -0.01912443,  0.04695164],
         [-0.01812583, -0.01744035,  0.00800684, ..., -0.01260292,
          -0.01492508,  0.01573021],
         [-0.01859048, -0.04014815,  0.00088814, ..., -0.01938568,
          -0.04277477,  0.01913013],
         ...,
         [-0.02373615, -0.00630582,  0.0022009 , ..., -0.0121943 ,
          -0.00279247,  0.01123303],
         [-0.11419986, -0.07078759,  0.01671964, ..., -0.05826627,
          -0.04770755,  0.07626332],
         [-0.06591284, -0.06327423,  0.00832309, ..., -0.05466373,
          -0.03847982,  0.05009513]], dtype=float32),
  'recurrent_weights': Array([[-0.0000000e+00, -3.9139558e-03,  1.3550217e-03, ...,
          -3.7019735e-03, -1.0142594e-03,  3.7940568e-03],
         [-3.3198181e-04, -0.0000000e+00,  5.8751186e-05, ...,
          -1.8729626e-04, -2.1849150e-05,  8.0821366e-05],
         [ 0.0000000e+00,  0.0000000e+00,  0.0000

In [17]:
read_out_1 = grads_1["ReadOut_0"]["readout_weights"]
read_out_1

Array([[1.10505950e-02],
       [2.61046953e-04],
       [0.00000000e+00],
       [6.93848848e-01],
       [2.26825848e-01],
       [8.12870543e-03],
       [1.88892853e-04],
       [6.11202866e-02],
       [0.00000000e+00],
       [4.27747220e-02],
       [0.00000000e+00],
       [2.30479818e-02],
       [1.45858005e-02],
       [0.00000000e+00],
       [7.90748596e-02],
       [2.19487399e-02],
       [0.00000000e+00],
       [0.00000000e+00],
       [1.54686831e-02],
       [3.15828025e-02],
       [4.50852700e-03],
       [5.70087992e-02],
       [3.17198336e-02],
       [1.93239242e-01],
       [6.83328956e-02],
       [3.12756658e-01],
       [4.89911102e-02],
       [0.00000000e+00],
       [8.70194985e-04],
       [4.01734978e-01],
       [0.00000000e+00],
       [0.00000000e+00],
       [0.00000000e+00],
       [5.07401109e-01],
       [0.00000000e+00],
       [1.21768378e-01],
       [0.00000000e+00],
       [7.75570050e-02],
       [1.32872832e+00],
       [1.12203993e-02],


In [18]:
read_out_hard_1 = hard_grads_1["ReadOut_0"]["readout_weights"]
read_out_hard_1

Array([[1.10505959e-02],
       [2.61046924e-04],
       [0.00000000e+00],
       [6.93848729e-01],
       [2.26825833e-01],
       [8.12870823e-03],
       [1.88892896e-04],
       [6.11202754e-02],
       [0.00000000e+00],
       [4.27747183e-02],
       [0.00000000e+00],
       [2.30479855e-02],
       [1.45858042e-02],
       [0.00000000e+00],
       [7.90748522e-02],
       [2.19487455e-02],
       [0.00000000e+00],
       [0.00000000e+00],
       [1.54686859e-02],
       [3.15828025e-02],
       [4.50852886e-03],
       [5.70088029e-02],
       [3.17198299e-02],
       [1.93239242e-01],
       [6.83328882e-02],
       [3.12756658e-01],
       [4.89911065e-02],
       [0.00000000e+00],
       [8.70195101e-04],
       [4.01735008e-01],
       [0.00000000e+00],
       [0.00000000e+00],
       [0.00000000e+00],
       [5.07401109e-01],
       [0.00000000e+00],
       [1.21768378e-01],
       [0.00000000e+00],
       [7.75570050e-02],
       [1.32872844e+00],
       [1.12204002e-02],


In [19]:
read_out_hard_2 = hard_grads_2["ReadOut_0"]["readout_weights"]
read_out_hard_2

Array([[1.10505959e-02],
       [2.61046924e-04],
       [0.00000000e+00],
       [6.93848729e-01],
       [2.26825833e-01],
       [8.12870823e-03],
       [1.88892896e-04],
       [6.11202754e-02],
       [0.00000000e+00],
       [4.27747183e-02],
       [0.00000000e+00],
       [2.30479855e-02],
       [1.45858042e-02],
       [0.00000000e+00],
       [7.90748522e-02],
       [2.19487455e-02],
       [0.00000000e+00],
       [0.00000000e+00],
       [1.54686859e-02],
       [3.15828025e-02],
       [4.50852886e-03],
       [5.70088029e-02],
       [3.17198299e-02],
       [1.93239242e-01],
       [6.83328882e-02],
       [3.12756658e-01],
       [4.89911065e-02],
       [0.00000000e+00],
       [8.70195101e-04],
       [4.01735008e-01],
       [0.00000000e+00],
       [0.00000000e+00],
       [0.00000000e+00],
       [5.07401109e-01],
       [0.00000000e+00],
       [1.21768378e-01],
       [0.00000000e+00],
       [7.75570050e-02],
       [1.32872844e+00],
       [1.12204002e-02],


In [20]:
read_out_2 = grads_2["ReadOut_0"]["readout_weights"]
read_out_2

Array([[1.10505950e-02],
       [2.61046953e-04],
       [0.00000000e+00],
       [6.93848848e-01],
       [2.26825848e-01],
       [8.12870543e-03],
       [1.88892853e-04],
       [6.11202866e-02],
       [0.00000000e+00],
       [4.27747220e-02],
       [0.00000000e+00],
       [2.30479818e-02],
       [1.45858005e-02],
       [0.00000000e+00],
       [7.90748596e-02],
       [2.19487399e-02],
       [0.00000000e+00],
       [0.00000000e+00],
       [1.54686831e-02],
       [3.15828025e-02],
       [4.50852700e-03],
       [5.70087992e-02],
       [3.17198336e-02],
       [1.93239242e-01],
       [6.83328956e-02],
       [3.12756658e-01],
       [4.89911102e-02],
       [0.00000000e+00],
       [8.70194985e-04],
       [4.01734978e-01],
       [0.00000000e+00],
       [0.00000000e+00],
       [0.00000000e+00],
       [5.07401109e-01],
       [0.00000000e+00],
       [1.21768378e-01],
       [0.00000000e+00],
       [7.75570050e-02],
       [1.32872832e+00],
       [1.12203993e-02],


In [21]:
recurrent_1 = grads_1['ALIFCell_0']["recurrent_weights"]
mask = jnp.where(recurrent_1!=0.)
recurrent_1[mask]

Array([-0.0040714 ,  0.00138517,  0.00829898, ...,  0.14877799,
       -0.04856402, -0.01517615], dtype=float32)

In [22]:
recurrent_2 = grads_2['ALIFCell_0']["recurrent_weights"]
mask = jnp.where(recurrent_2!=0.)
recurrent_2[mask]

Array([-0.00061192,  0.00044177,  0.00094082, ...,  0.05536162,
       -0.02411162, -0.01077862], dtype=float32)

In [23]:
recurrent_hard_1 = hard_grads_1['ALIFCell_0']["recurrent_weights"]
mask = jnp.where(recurrent_hard_1!=0.)
recurrent_hard_1[mask]

Array([-0.00391396,  0.00135502,  0.0079028 , ...,  0.14373933,
       -0.04792067, -0.01469209], dtype=float32)

In [24]:
recurrent_hard_2 = hard_grads_2['ALIFCell_0']["recurrent_weights"]
mask = jnp.where(recurrent_hard_2==0.)
recurrent_hard_2[mask].shape

(2479,)

In [25]:
is_correct = jnp.absolute(recurrent_hard_1-recurrent_1) < 1e-3

In [26]:
is_correct.sum()

Array(8848, dtype=int32)