In [None]:
import tasks
import models
import jax
import learning_rules
import learning_utils
from jax import random, numpy as jnp
from optax import losses
from typing import (
  Any,
  Callable,
  Dict,
  List,
  Optional,
  Sequence,
  Tuple,
  Iterable  
 )
from flax.typing import (PRNGKey)
import optax
from flax.training import train_state, orbax_utils
Array = jnp.ndarray
TrainState = train_state.TrainState

In [None]:
# Task seed and population sizes shared by the cells below.
seed_task = 20  # forwarded as `seed` to the task generator
n_ALIF = 50     # forwarded as `n_ALIF` to models.LSSN
n_LIF = 50      # forwarded as `n_LIF` to models.LSSN
n_rec = n_ALIF + n_LIF  # combined size of both populations

In [None]:
# Network under test, one hyperparameter per line so individual values are easy to diff.
model_1 = models.LSSN(
    n_ALIF=n_ALIF,
    n_LIF=n_LIF,
    n_out=2,
    thr=0.03,
    beta=0.018,
    tau_m=20,
    tau_out=20,
    k=0.00,
    connectivity_rec_layer="local",
    learning_rule="e_prop_autodiff",
    sparse_input=True,
    sparse_readout=True,
    refractory_period=5,
    gain=[2.0, 2.0, 2.0, 2.0, 2.0],
    gridshape=(10, 10),
    # sigma=0.003,  (left disabled, as in the original cell)
)


In [None]:
# Materialise every batch of the cue-accumulation task so batches can be indexed later.
task_batches = list(
    tasks.cue_accumulation_task(n_batches=32, batch_size=32, seed=seed_task)
)
# Alternative task, kept disabled:
# task_batches = list(tasks.pattern_generation(n_batches=1, batch_size=1, seed=seed_task, frequencies=[0.5, 1., 2., 3., 4.],
#                                      n_population=100, f_input=10, trial_dur=2000))

In [None]:
# Only the first generated batch is used by the gradient computations below.
batch = task_batches[0]

In [None]:
def optimization_loss(logits, labels, z, c_reg, f_target, trial_length):
  """Total loss minimised by the network: task loss plus firing-rate regularisation.

  Args:
    logits: non-normalised logits; the softmax is applied inside the cross entropy.
    labels: one-hot encoded targets.
    z: forwarded unchanged to learning_utils.compute_firing_rate.
    c_reg: weight of the firing-rate regulariser.
    f_target: target firing rate in Hz; it is divided by 1000 before use because the
      measured rate is expressed in spikes/ms.
    trial_length: forwarded unchanged to learning_utils.compute_firing_rate.

  Returns:
    task_loss + regularization_loss.
  """
  # Cross entropy averaged over axis 0 and then over every remaining axis.
  # NOTE(review): an earlier comment described this as "mean over batches and sum over
  # time", but the code takes the mean on all axes -- confirm which one is intended.
  cross_entropy = losses.softmax_cross_entropy(logits=logits, labels=labels)
  task_loss = jnp.mean(jnp.mean(cross_entropy, axis=0))

  # Hz -> spikes/ms, so the target matches the unit of the measured rate.
  target_rate = f_target / 1000
  av_f_rate = learning_utils.compute_firing_rate(z=z, trial_length=trial_length)

  # Squared deviation averaged over axis 0, summed over what remains.
  squared_deviation = jnp.square(av_f_rate - target_rate)
  regularization_loss = 0.5 * c_reg * jnp.sum(jnp.mean(squared_deviation, 0))

  return task_loss + regularization_loss

# def optimization_loss(logits, labels, z, c_reg, f_target, trial_length):
    
#   if labels.ndim==2: # calling labels what normally people call targets in regression tasks
#       labels = jnp.expand_dims(labels, axis=-1) # this is necessary because target labels might have only (n_batch, n_t) and predictions (n_batch, n_t, n_out=1)

#   task_loss = 0.5 * jnp.sum(jnp.mean(losses.squared_error(targets=labels, predictions=logits),axis=0))# sum over batches and time --> usually one takes the average, but it is biologically implausible that updates are averaged across batches, so sum
#   #task_loss = 0.5 * jnp.mean(losses.squared_error(targets=labels, predictions=logits))
#   av_f_rate = learning_utils.compute_firing_rate(z=z, trial_length=trial_length)
#   f_target = f_target / 1000 # f_target is given in Hz, but av_f_rate is in spikes/ms --> Bellec 2020 also used f_reg in spikes/ms
#   regularization_loss = 0.5 * c_reg * jnp.sum(jnp.mean(jnp.square(av_f_rate - f_target),0)) # average over batches
#   return task_loss + regularization_loss
 

In [None]:
def get_initial_params(rng, model, input_shape):
  """Initialise `model` on an all-ones dummy input and split its variable collections.

  Args:
    rng: key handed to model.init.
    model: object exposing init(rng, x).
    input_shape: shape of the dummy input built with jnp.ones.

  Returns:
    Tuple with the 'params', 'eligibility params' and 'spatial params' collections,
    in that order.
  """
  collections = model.init(rng, jnp.ones(input_shape))
  wanted = ('params', 'eligibility params', 'spatial params')
  return tuple(collections[name] for name in wanted)
    

def get_init_eligibility_carries(rng, model, input_shape):
  """Ask the model for its initial eligibility carries.

  Thin wrapper around model.initialize_eligibility_carry(rng, input_shape). According to
  the original author's note, the default mode initialises every carry as a zeros array.
  """
  init_carries = model.initialize_eligibility_carry(rng, input_shape)
  return init_carries

def get_init_error_grid(rng, model, input_shape):
  """Ask the model for its initial error grid.

  Thin wrapper around model.initialize_grid(rng=..., input_shape=...). According to the
  original author's note, the grid is initialised as zeros.
  """
  init_grid = model.initialize_grid(rng=rng, input_shape=input_shape)
  return init_grid

# Create a custom TrainState to include both params and other variable collections
class TrainStateEProp(TrainState):
  """ Personalized TrainState for e-prop with local connectivity.

  Extends flax's TrainState with the extra collections and initial values that
  create_train_state passes to TrainStateEProp.create.
  """
  # 'eligibility params' collection returned by model.init (see get_initial_params)
  eligibility_params: Dict[str, Array]
  # 'spatial params' collection returned by model.init (see get_initial_params)
  spatial_params: Dict[str, Array]
  # value returned by model.initialize_eligibility_carry (see get_init_eligibility_carries)
  init_eligibility_carries: Dict[str, Array]
  # value returned by model.initialize_grid (see get_init_error_grid)
  init_error_grid: Array
  
def create_train_state(rng:PRNGKey, learning_rate:float, model, input_shape:Tuple[int,...])->train_state.TrainState:
  """Build the initial TrainStateEProp for `model`.

  Args:
    rng: key that is split three ways: parameters, eligibility carries, error grid.
    learning_rate: learning rate of the Adam optimiser.
    model: model providing init, apply, initialize_eligibility_carry and initialize_grid.
    input_shape: shape of the dummy input used for all three initialisations.

  Returns:
    A TrainStateEProp holding the parameters, the Adam optimiser, the eligibility and
    spatial parameter collections, the initial eligibility carries and the initial
    error grid.
  """
  # Independent sub-keys, one per initialisation step.
  params_key, carries_key, grid_key = random.split(rng, 3)

  params, eligibility_params, spatial_params = get_initial_params(params_key, model, input_shape)
  carries = get_init_eligibility_carries(carries_key, model, input_shape)
  error_grid = get_init_error_grid(grid_key, model, input_shape)

  optimizer = optax.adam(learning_rate=learning_rate)

  return TrainStateEProp.create(
      apply_fn=model.apply,
      params=params,
      tx=optimizer,
      eligibility_params=eligibility_params,
      spatial_params=spatial_params,
      init_eligibility_carries=carries,
      init_error_grid=error_grid,
  )

In [None]:
# Initial train state for model_1. The leading 32 of input_shape equals the batch_size
# used when generating task_batches (second entry: presumably the number of input
# channels -- TODO confirm).
state_1 = create_train_state(
    random.key(0),
    learning_rate=0.01,
    model=model_1,
    input_shape=(32, 40),
)


In [None]:
# Number of non-zero recurrent weights divided by 100
# (100 presumably stands for n_ALIF + n_LIF, i.e. connections per neuron -- TODO confirm).
(state_1.params["ALIFCell_0"]["recurrent_weights"] != 0).sum() / 100

In [None]:
# Settings shared by the compute_grads calls below.
LS_avail = 50
c_reg = 0.1
f_target = 10  # in Hz; optimization_loss divides it by 1000 before use
task = "classification"
learning_rule = "e_prop_hardcoded"  # rule used by the second (hard-coded) gradient call
optimization_loss_fn = optimization_loss

In [None]:
# Reference gradients: learning rule "e_prop_autodiff", no shuffling.
logits_1, grads_1 = learning_rules.compute_grads(
    batch=batch,
    state=state_1,
    optimization_loss_fn=optimization_loss_fn,
    LS_avail=LS_avail,
    f_target=f_target,
    c_reg=c_reg,
    task=task,
    learning_rule="e_prop_autodiff",
    shuffle=False,
    key=random.key(0),
)

In [None]:
# Display the full nested gradient structure from the "e_prop_autodiff" call (large output).
grads_1

In [None]:
# Gradients from the rule selected in `learning_rule` ("e_prop_hardcoded"), same batch,
# state and key as the autodiff call.
# NOTE(review): shuffle=True here, while the autodiff call uses shuffle=False -- confirm
# the two results are meant to be directly comparable.
logits_hard_1, hard_grads_1 = learning_rules.compute_grads(
    batch=batch,
    state=state_1,
    optimization_loss_fn=optimization_loss_fn,
    LS_avail=LS_avail,
    f_target=f_target,
    c_reg=c_reg,
    learning_rule=learning_rule,
    task=task,
    shuffle=True,
    key=random.key(0),
)
 

In [None]:
# Display the autodiff gradients again, for side-by-side reading with hard_grads_1 below.
grads_1

In [None]:
# Display the full nested gradient structure from the hard-coded rule (large output).
hard_grads_1

In [None]:
# Readout-weight gradient from the autodiff call.
read_out_1 = grads_1["ReadOut_0"]["readout_weights"]
read_out_1

In [None]:
# Readout-weight gradient from the hard-coded rule, for comparison with read_out_1.
read_out_hard_1 = hard_grads_1["ReadOut_0"]["readout_weights"]
read_out_hard_1

In [None]:
# NOTE(review): despite the `recurrent_` prefix this reads the *input_weights* gradient
# of ALIFCell_0 (autodiff call) -- confirm whether "recurrent_weights" was intended.
recurrent_1 = grads_1['ALIFCell_0']["input_weights"]
# Indices of the non-zero entries; reused by a later cell to compare both gradient sets.
mask = jnp.where(recurrent_1!=0.)
# Not the last expression of the cell, so this value is computed but never displayed.
recurrent_1[mask]
# Only the shape is shown as the cell output.
recurrent_1.shape

In [None]:
# NOTE(review): despite the `recurrent_` prefix this reads the *input_weights* gradient
# of ALIFCell_0 (hard-coded rule) -- confirm whether "recurrent_weights" was intended.
recurrent_hard_1 = hard_grads_1['ALIFCell_0']["input_weights"]
# Indices of the non-zero entries of the hard-coded gradient.
hardcoded_mask = jnp.where(recurrent_hard_1!=0.)
# Not the last expression of the cell, so this value is computed but never displayed.
recurrent_hard_1[hardcoded_mask]
# Largest absolute gradient entry; this is the cell output.
jnp.max(jnp.abs(recurrent_hard_1))

In [None]:
# Element-wise agreement (tolerance 1e-5) between the two gradients, evaluated only at the
# positions in `mask`, i.e. where the autodiff gradient is non-zero.
is_correct = jnp.abs(recurrent_hard_1[mask] - recurrent_1[mask]) < 1e-5

In [None]:
# Number of compared entries that agree within the tolerance.
jnp.sum(is_correct)

In [None]:
# Total number of compared entries, to read against the count of agreeing entries.
is_correct.size

In [None]:
# Largest absolute difference between the two gradients over all entries (not only `mask`).
jnp.abs(recurrent_hard_1 - recurrent_1).max()

In [None]:
# Number of non-zero input weights of ALIFCell_0.
(state_1.params["ALIFCell_0"]["input_weights"] != 0.).sum()

In [None]:
import jax
import jax.numpy as jnp

# Sizes of the synthetic example.
# NOTE(review): n_rec is rebound here (400) and replaces the network-size n_rec defined
# near the top of the notebook; cells that need the network size must run before this one.
n_b = 10
n_in = 100
n_rec = 400

# Example data standing in for per-batch gradients, shape (n_b, n_in, n_rec).
gradients = jax.random.normal(jax.random.PRNGKey(0), (n_b, n_in, n_rec))

# Step 1: flatten the last two dimensions -> shape (n_b, n_in * n_rec).
flattened_grads = gradients.reshape((n_b, -1))

# Step 2: normalise every row to unit length.
norms = jnp.linalg.norm(flattened_grads, axis=1, keepdims=True)
# Guard against all-zero rows (same guard as the two-vector cell further down); without it
# a zero gradient would produce NaNs across its whole row and column of the result.
norms = jnp.where(norms == 0, 1e-10, norms)
normalized_grads = flattened_grads / norms

# Pairwise cosine similarity between batches, shape (n_b, n_b).
similarity = jnp.dot(normalized_grads, normalized_grads.T)


In [None]:
# Shape of the pairwise similarity matrix computed in the previous cell.
similarity.shape

In [None]:
import jax
import jax.numpy as jnp

# Two example gradient matrices, each of shape (n_in, n_rec), drawn with different keys.
# n_in and n_rec come from the previous example cell.
grad_a = jax.random.normal(jax.random.PRNGKey(0), (n_in, n_rec))
grad_b = jax.random.normal(jax.random.PRNGKey(1), (n_in, n_rec))

# Step 1: unroll each matrix into one vector of length n_in * n_rec.
flattened_a = jnp.reshape(grad_a, (-1,))
flattened_b = jnp.reshape(grad_b, (-1,))

# Step 2: Euclidean norms (scalars), with exact zeros replaced by 1e-10 so the division
# below can never be a division by zero.
norm_a = jnp.linalg.norm(flattened_a)
norm_a = jnp.where(norm_a == 0, 1e-10, norm_a)
norm_b = jnp.linalg.norm(flattened_b)
norm_b = jnp.where(norm_b == 0, 1e-10, norm_b)

# Unit-length versions of both vectors.
normalized_a = flattened_a / norm_a
normalized_b = flattened_b / norm_b

# Cosine similarity = dot product of the unit vectors, clipped to [-1, 1] so rounding
# error cannot push the value outside the domain of arccos.
cosine_similarity = (normalized_a * normalized_b).sum()
cosine_similarity = jnp.clip(cosine_similarity, -1.0, 1.0)

# Step 3: angle between the two gradients, in radians and (optionally) in degrees.
angle_in_radians = jnp.arccos(cosine_similarity)
angle_in_degrees = jnp.degrees(angle_in_radians)



In [None]:
# Display the value currently bound to `cosine_similarity`.
cosine_similarity

In [None]:
# Cross-check the manual result against optax's implementation.
# The function is reached through the `losses` module imported at the top of the notebook
# instead of `from optax.losses import cosine_similarity`: that import would rebind the
# name `cosine_similarity`, replacing the scalar computed earlier with a function and
# making the earlier display cell show something different when re-run.
losses.cosine_similarity(flattened_a, flattened_b)

In [None]:
import plots

layer_names = ["Input layer", "Recurrent layer", "Readout layer"]

# Save relative to the kernel's working directory instead of a user-specific absolute
# Windows path, so the cell also runs on other machines and operating systems.
weights_plot_path = "weights.png"
plots.plot_LSNN_weights(state_1, layer_names=layer_names, save_path=weights_plot_path)