In [1]:
import tasks
import models

import learning_rules
import learning_utils
from jax import random, numpy as jnp
from optax import losses
from typing import (
  Any,
  Callable,
  Dict,
  List,
  Optional,
  Sequence,
  Tuple,
  Iterable  
 )
from flax.typing import (PRNGKey)
import optax
from flax.training import train_state, orbax_utils
Array = jnp.ndarray
TrainState = train_state.TrainState

In [2]:
# Seed handed to `tasks.pattern_generation`.
seed_task = 2

# Neuron counts handed to `models.LSSN`; `n_rec` is the sum of both.
n_ALIF = 0
n_LIF = 100
n_rec = n_LIF + n_ALIF

In [3]:
# Network under test: LSSN with local connectivity and the "e_prop_autodiff" rule.
model_1 = models.LSSN(
    n_ALIF=n_ALIF,
    n_LIF=n_LIF,
    n_out=1,
    thr=0.01,
    local_connectivity=True,
    learning_rule="e_prop_autodiff",
    sparse_connectivity=False,
)


In [4]:
# Materialize every batch produced by the pattern-generation task.
task_batches = list(
    tasks.pattern_generation(
        n_batches=64,
        batch_size=8,
        seed=seed_task,
        frequencies=[0.5, 1., 2., 3., 4.],
        weights=[0.2, 0.2, 0.2, 0.2, 0.2],
        n_population=100,
        f_input=10,
        trial_dur=200,
    )
)


In [5]:
# Only the first batch is used; it is the input to both `compute_grads` calls below.
batch = task_batches[0]

In [6]:
def optimization_loss(logits, labels, z, c_reg, f_target, trial_length):
  """Sum of the half squared-error task loss and a firing-rate regularization term.

  Args:
    logits: model predictions, compared element-wise against `labels`.
    labels: regression targets. A 2-D array gets a trailing singleton axis
      appended first, since targets may come as (n_batch, n_t) while the
      predictions are (n_batch, n_t, n_out=1).
    z: forwarded to `learning_utils.compute_firing_rate`
      (presumably the recorded spikes -- confirm against that helper).
    c_reg: coefficient of the regularization term.
    f_target: target firing rate in Hz; rescaled to spikes/ms internally.
    trial_length: forwarded to `learning_utils.compute_firing_rate`.

  Returns:
    Task loss plus regularization loss.
  """
  # "labels" is what regression code normally calls targets.
  targets = jnp.expand_dims(labels, axis=-1) if labels.ndim == 2 else labels

  # Half squared error, summed over batches and time.
  half_squared_error = 0.5 * losses.squared_error(targets=targets, predictions=logits)
  task_loss = jnp.sum(half_squared_error)

  # f_target is given in Hz, but the measured rate is in spikes/ms
  # (Bellec 2020 also expressed the regularization target in spikes/ms).
  target_rate = f_target / 1000
  measured_rate = learning_utils.compute_firing_rate(z=z, trial_length=trial_length)
  rate_penalty = 0.5 * c_reg * jnp.sum(jnp.square(measured_rate - target_rate))

  return task_loss + rate_penalty
 

In [7]:
def get_initial_params(rng, model, input_shape):
  """Initialize `model` on an all-ones input and return three of its variable collections.

  Returns:
    Tuple of the 'params', 'eligibility params' and 'spatial params' entries
    (in that order) of the variables produced by `model.init`.
  """
  init_variables = model.init(rng, jnp.ones(input_shape))
  wanted_collections = ('params', 'eligibility params', 'spatial params')
  return tuple(init_variables[name] for name in wanted_collections)
    

def get_init_eligibility_carries(rng, model, input_shape):
  """Ask `model` for its initial eligibility carries.

  Thin wrapper around `model.initialize_eligibility_carry(rng, input_shape)`.
  According to the original note, the default mode yields all-zero arrays.
  """
  initial_carries = model.initialize_eligibility_carry(rng, input_shape)
  return initial_carries

def get_init_error_grid(rng, model, input_shape):
  """Ask `model` for its initial error grid.

  Thin wrapper around `model.initialize_grid`, called with keyword arguments.
  According to the original note, the grid starts out as zeros.
  """
  initial_grid = model.initialize_grid(rng=rng, input_shape=input_shape)
  return initial_grid

# Create a custom TrainState to include both params and other variable collections
class TrainStateEProp(TrainState):
  """TrainState carrying the extra collections used for e-prop with local connectivity.

  The additional fields are filled in by `create_train_state`.
  """
  # 'eligibility params' collection returned by `model.init`.
  eligibility_params: Dict[str, Array]
  # 'spatial params' collection returned by `model.init`.
  spatial_params: Dict[str, Array]
  # Value returned by `model.initialize_eligibility_carry`.
  init_eligibility_carries: Dict[str, Array]
  # Value returned by `model.initialize_grid`.
  init_error_grid: Array
  
def create_train_state(rng:PRNGKey, learning_rate:float, model, input_shape:Tuple[int,...])->train_state.TrainState:
  """Build the initial `TrainStateEProp` for `model`.

  Args:
    rng: PRNG key; split three ways (parameters, eligibility carries, error grid).
    learning_rate: learning rate of the Adam optimizer.
    model: model providing `init`, `apply`, `initialize_eligibility_carry`
      and `initialize_grid`.
    input_shape: shape of the dummy input used for initialization.

  Returns:
    A `TrainStateEProp` holding the parameters, the Adam optimizer and the
    extra e-prop collections.
  """
  params_key, carries_key, grid_key = random.split(rng, 3)

  params, eligibility_params, spatial_params = get_initial_params(params_key, model, input_shape)
  eligibility_carries = get_init_eligibility_carries(carries_key, model, input_shape)
  error_grid = get_init_error_grid(grid_key, model, input_shape)

  return TrainStateEProp.create(
      apply_fn=model.apply,
      params=params,
      tx=optax.adam(learning_rate=learning_rate),
      eligibility_params=eligibility_params,
      spatial_params=spatial_params,
      init_eligibility_carries=eligibility_carries,
      init_error_grid=error_grid,
  )

In [8]:
# Initial train state for model_1, built from a fixed PRNG key.
state_1 = create_train_state(
    rng=random.key(0),
    learning_rate=0.01,
    model=model_1,
    input_shape=(8, 100),
)


In [None]:
# Inspect the "M" entry of the ALIF cell's spatial params
# (presumably the connectivity mask -- confirm in `models.LSSN`).
state_1.spatial_params["ALIFCell_0"]["M"]

In [None]:
# Inspect the readout layer's feedback weights stored in the eligibility params.
state_1.eligibility_params['ReadOut_0']['feedback_weights']

In [None]:
# Inspect the readout weights stored in the trainable params.
state_1.params['ReadOut_0']['readout_weights']

In [12]:
# Settings shared by both `learning_rules.compute_grads` calls below.
LS_avail = 1
local_connectivity = True
task = "regression"

# Loss configuration: with c_reg = 0 the regularization term is multiplied by zero.
optimization_loss_fn = optimization_loss
c_reg = 0
f_target = 10  # Hz; `optimization_loss` rescales it to spikes/ms

# Rule used for the second (hard-coded) gradient computation.
learning_rule = "e_prop_hardcoded"

In [13]:
# Reference run: logits and gradients with the "e_prop_autodiff" learning rule.
logits_1, grads_1 = learning_rules.compute_grads(
    batch=batch,
    state=state_1,
    optimization_loss_fn=optimization_loss_fn,
    LS_avail=LS_avail,
    local_connectivity=local_connectivity,
    f_target=f_target,
    c_reg=c_reg,
    task=task,
    learning_rule="e_prop_autodiff",
)

In [None]:
# Display the gradients returned by the "e_prop_autodiff" run.
grads_1

In [15]:
# Same batch and state, but with the rule stored in `learning_rule` ("e_prop_hardcoded").
logits_hard_1, hard_grads_1 = learning_rules.compute_grads(
    batch=batch,
    state=state_1,
    optimization_loss_fn=optimization_loss_fn,
    LS_avail=LS_avail,
    local_connectivity=local_connectivity,
    f_target=f_target,
    c_reg=c_reg,
    learning_rule=learning_rule,
    task=task,
)
 

In [None]:
# Autodiff gradients once more, for side-by-side reading with `hard_grads_1` in the next cell.
grads_1

In [None]:
# Display the gradients returned by the hard-coded e-prop run.
hard_grads_1

In [None]:
# Readout-weight gradient from the autodiff run.
read_out_1 = grads_1["ReadOut_0"]["readout_weights"]
read_out_1

In [None]:
# Readout-weight gradient from the hard-coded run, to compare against `read_out_1`.
read_out_hard_1 = hard_grads_1["ReadOut_0"]["readout_weights"]
read_out_hard_1

In [None]:
# ALIF-cell "input_weights" gradient from the autodiff run.
# NOTE(review): the variable is named `recurrent_*` but reads the "input_weights"
# entry -- confirm this is the intended key.
recurrent_1 = grads_1['ALIFCell_0']["input_weights"]

# Indices of the non-zero entries, then the entries themselves.
mask = jnp.nonzero(recurrent_1 != 0.)
recurrent_1[mask]

In [None]:
# ALIF-cell "input_weights" gradient from the hard-coded run.
# NOTE(review): the variable is named `recurrent_*` but reads the "input_weights"
# entry -- confirm this is the intended key.
recurrent_hard_1 = hard_grads_1['ALIFCell_0']["input_weights"]

# Overwrites `mask` with the non-zero positions of the hard-coded gradient;
# the comparison cells below reuse this `mask`.
mask = jnp.nonzero(recurrent_hard_1 != 0.)
recurrent_hard_1[mask]

In [22]:
# Per-entry agreement (absolute tolerance 1e-3) between the two gradients at the
# positions in `mask`, as last assigned above.
is_correct = jnp.abs(recurrent_hard_1[mask] - recurrent_1[mask]) < 1e-3

In [None]:
# Number of compared entries that agree within the tolerance.
is_correct.sum()

In [None]:
# Total number of compared entries; should match the sum above when all entries agree.
is_correct.size

In [None]:
# Worst-case deviation between the hard-coded and autodiff gradients at the
# positions in `mask`. The absolute value is taken so that large negative
# differences are reported too, consistent with the tolerance check on
# `is_correct` (which also uses the absolute difference).
jnp.max(jnp.abs(recurrent_hard_1[mask] - recurrent_1[mask]))