In [1]:
import tasks
import models
import models_2
import learning_rules
import learning_utils
from jax import random, numpy as jnp
from optax import losses
import jax
from typing import (
  Any,
  Callable,
  Dict,
  List,
  Optional,
  Sequence,
  Tuple,
  Iterable  
 )
from flax.typing import (PRNGKey)
import optax
from flax.training import train_state, orbax_utils
Array = jnp.ndarray
TrainState = train_state.TrainState

In [2]:
# Both LSSN variants (one from `models`, one from `models_2`) are built from the
# same hyper-parameters, so the settings are written once and shared.
lssn_config = dict(n_ALIF=0, n_LIF=2, n_out=1, local_connectivity=False)
model_1 = models.LSSN(**lssn_config)
model_2 = models_2.LSSN(**lssn_config)

In [3]:
def optimization_loss(logits, labels, z, c_reg, f_target, trial_length):
  """Summed half squared error between predictions and regression targets.

  Args:
    logits: model predictions.
    labels: regression targets (what is usually called "targets"). A 2-D array
      gets a trailing singleton axis so it lines up with `logits`; per the
      original note, targets may be (n_batch, n_t) while predictions are
      (n_batch, n_t, n_out=1).
    z, c_reg, f_target, trial_length: accepted but not used by this loss
      (presumably kept so the signature matches what the caller passes;
      verify against learning_rules).

  Returns:
    The scalar 0.5 * squared error, summed (not averaged) over every element.
    The value is also emitted through `jax.debug.print`.
  """
  targets = jnp.expand_dims(labels, axis=-1) if labels.ndim == 2 else labels
  half_squared_err = 0.5 * losses.squared_error(targets=targets, predictions=logits)
  task_loss = jnp.sum(half_squared_err)  # a sum over all elements, not a mean
  jax.debug.print("loss{}",task_loss)
  return task_loss

# Alias under which the loss is handed to the learning-rule helpers.
optimization_loss_fn = optimization_loss

In [4]:
# Toy batch: four two-feature input patterns with one integer label each.
# `expand_dims(axis=1)` inserts a length-1 axis in position 1 (presumably the
# time axis, since every trial duration below is 1 — confirm against the model).
inputs = jnp.expand_dims(jnp.array([[1,1], [0,0], [1,0], [0,1]]), axis=1)
labels = jnp.expand_dims(jnp.array([1, 2, 0, 3]), axis=1)
trial_len = jnp.array([1,1,1,1])

# Jupyter only displays the last expression of a cell, so a bare `x.shape` in
# the middle of the cell shows nothing; check the expected shapes explicitly.
assert inputs.shape == (4, 1, 2), inputs.shape
assert labels.shape == (4, 1), labels.shape
assert trial_len.shape == (inputs.shape[0],), trial_len.shape

batch = {"input":inputs, "label":labels, "trial_duration":trial_len}

In [5]:
def get_initial_params(rng, model, input_shape):
  """Initialise the model and return three of its variable collections.

  The model is initialised with `model.init(rng, x)` where `x` is an all-ones
  array of shape `input_shape`.

  Returns:
    A 3-tuple with the 'params', 'eligibility params' and 'spatial params'
    collections, in that order.
  """
  variables = model.init(rng, jnp.ones(input_shape))
  collection_names = ('params', 'eligibility params', 'spatial params')
  return tuple(variables[name] for name in collection_names)
    

def get_init_eligibility_carries(rng, model, input_shape):
  """Build the initial eligibility carries by delegating to the model.

  `rng` and `input_shape` are forwarded positionally to
  `model.initialize_eligibility_carry`, whose result is returned unchanged.
  (Original note: in the default mode the carries are all-zero arrays; that
  is decided inside the model, not here.)
  """
  init_carries = model.initialize_eligibility_carry(rng, input_shape)
  return init_carries

def get_init_error_grid(rng, model, input_shape):
  """Build the initial error grid by delegating to the model.

  `rng` and `input_shape` are passed by keyword to `model.initialize_grid`,
  whose result is returned unchanged. (Original note: the grid is
  initialised as zeros; that is decided inside the model, not here.)
  """
  init_grid = model.initialize_grid(rng=rng, input_shape=input_shape)
  return init_grid

# Create a custom TrainState to include both params and other variable collections
class TrainStateEProp(TrainState):
  """TrainState subclass for e-prop with local connectivity.

  Adds four fields on top of the base TrainState. In this notebook they are
  filled by create_train_state from the model's variable collections and
  initialisers.
  """
  # 'eligibility params' collection returned by model.init
  eligibility_params: Dict[str, Array]
  # 'spatial params' collection returned by model.init
  spatial_params: Dict[str, Array]
  # value returned by model.initialize_eligibility_carry
  init_eligibility_carries: Dict[str, Array]
  # value returned by model.initialize_grid
  init_error_grid: Array
  
def create_train_state(rng:PRNGKey, learning_rate:float, model, input_shape:Tuple[int,...])->train_state.TrainState:
  """Assemble the initial TrainStateEProp for `model`.

  Args:
    rng: PRNG key; split into three sub-keys (variable init, eligibility
      carries, error grid).
    learning_rate: learning rate handed to the Adam optimiser.
    model: module providing `init`, `apply`, `initialize_eligibility_carry`
      and `initialize_grid`.
    input_shape: shape of the dummy input used by the initialisers.

  Returns:
    A TrainStateEProp holding the params, the Adam optimiser, the eligibility
    and spatial collections, and the initial carries and error grid.
  """
  init_key, carry_key, grid_key = random.split(rng, 3)

  params, eligibility_params, spatial_params = get_initial_params(init_key, model, input_shape)
  carries = get_init_eligibility_carries(carry_key, model, input_shape)
  error_grid = get_init_error_grid(grid_key, model, input_shape)

  optimizer = optax.adam(learning_rate=learning_rate)

  return TrainStateEProp.create(
      apply_fn=model.apply,
      params=params,
      tx=optimizer,
      eligibility_params=eligibility_params,
      spatial_params=spatial_params,
      init_eligibility_carries=carries,
      init_error_grid=error_grid,
  )

In [6]:
# One train state per model, created with the same PRNG key, learning rate and
# input shape so that only the model implementation differs.
state_1, state_2 = (
    create_train_state(random.key(0), learning_rate=0.01, model=lssn, input_shape=(4,2))
    for lssn in (model_1, model_2)
)

In [7]:
# Replace the randomly initialised weights of state_1 with hand-picked values.
# A new params tree is built and swapped in with `.replace(...)` rather than
# assigning into the dict held by the TrainState: the state is a frozen
# dataclass, and item assignment would also fail outright if params were a
# FrozenDict. Entries not listed here are carried over unchanged.
hand_set_params = {
    **state_1.params,
    "ALIFCell_0": {
        **state_1.params["ALIFCell_0"],
        "input_weights": jnp.array([[1.,1.], [1.,1.]]),
        "recurrent_weights": jnp.array([[1.,0.], [0.,1.]]),
    },
    "ReadOut_0": {
        **state_1.params["ReadOut_0"],
        "readout_weights": jnp.array([[1.], [1.]]),
    },
}
state_1 = state_1.replace(params=hand_set_params)
state_1.params

{'ALIFCell_0': {'input_weights': Array([[1., 1.],
         [1., 1.]], dtype=float32),
  'recurrent_weights': Array([[1., 0.],
         [0., 1.]], dtype=float32)},
 'ReadOut_0': {'readout_weights': Array([[1.],
         [1.]], dtype=float32)}}

In [8]:
# Display the 'spatial params' collection that create_train_state stored on state_1.
state_1.spatial_params

{'ALIFCell_0': {'M': Array([[1., 1.],
         [1., 1.]], dtype=float32),
  'cells_loc': Array([[2, 3],
         [6, 0]], dtype=int32),
  'diff_K': Array([[[[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]]]], dtype=float32)}}

In [9]:
# Settings shared by both gradient computations in the cells below.
optimization_loss_fn = optimization_loss  # this loss does not use c_reg / f_target
LS_avail = 1   # forwarded as-is to learning_rules; meaning not visible here — TODO document
c_reg = 0      # presumably a regularisation coefficient — confirm
f_target = 10  # presumably a target firing rate — confirm

# Extra options passed only to learning_rules.compute_grads.
task = "regression"
learning_rule = "e_prop_hardcoded"
# NOTE(review): both models above were built with local_connectivity=False — confirm True is intended here.
local_connectivity = True

In [10]:
# Outputs and gradients from learning_rules.autodiff_grads for state_1 on the toy batch.
logits_1, grads_1 = learning_rules.autodiff_grads(
    batch=batch,
    state=state_1,
    optimization_loss_fn=optimization_loss_fn,
    LS_avail=LS_avail,
    c_reg=c_reg,
    f_target=f_target,
)

loss5.0


In [11]:
# Display the outputs returned by learning_rules.autodiff_grads.
logits_1

Array([[[2.]],

       [[0.]],

       [[2.]],

       [[2.]]], dtype=float32)

In [12]:
# Display the gradients returned by learning_rules.autodiff_grads
# (to be compared with the compute_grads result further down).
grads_1

{'ALIFCell_0': {'input_weights': Array([[ 0.33333337,  0.33333337],
         [-0.16666669, -0.16666669]], dtype=float32),
  'recurrent_weights': Array([[0., 0.],
         [0., 0.]], dtype=float32)},
 'ReadOut_0': {'readout_weights': Array([[2.],
         [2.]], dtype=float32)}}

In [13]:
# Same state and batch, this time through learning_rules.compute_grads with the
# rule selected by `learning_rule`.
logits_hard_1, hard_grads_1 = learning_rules.compute_grads(
    batch=batch,
    state=state_1,
    optimization_loss_fn=optimization_loss_fn,
    LS_avail=LS_avail,
    local_connectivity=local_connectivity,
    f_target=f_target,
    c_reg=c_reg,
    learning_rule=learning_rule,
    task=task,
)

y_batch: [[[2.]]

 [[0.]]

 [[2.]]

 [[2.]]]
crop_trace: [[[1. 1.]
  [0. 0.]
  [1. 1.]
  [1. 1.]]]
err: [[[ 1.]]

 [[-2.]]

 [[ 2.]]

 [[-1.]]]


In [14]:
# Display the gradients returned by learning_rules.compute_grads.
# NOTE(review): in the recorded outputs the readout_weights gradients match
# grads_1, but the input_weights gradients are all zero here while grads_1 has
# non-zero entries — worth investigating.
hard_grads_1

{'ALIFCell_0': {'input_weights': Array([[0., 0.],
         [0., 0.]], dtype=float32),
  'recurrent_weights': Array([[0., 0.],
         [0., 0.]], dtype=float32)},
 'ReadOut_0': {'readout_weights': Array([[2.],
         [2.]], dtype=float32)}}