<a href="https://colab.research.google.com/github/8erberg/spatially-embedded-RNN/blob/main/seRNN_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# seRNN Demo: How to spatially-embed an RNN

In this notebook we provide a run-through of the basic training routine behind spatially embedded recurrent neural networks (seRNNs) and apply structural metrics to an example seRNN. We hope that giving researchers early access to these tools will make our implementation easier to understand and allow researchers to adapt the model to their needs before full publication of the project. We will add the additional analysis scripts for generative models, decoding, structure-function clustering, mixed selectivity and energy usage in the coming weeks.


For further details on methods, see our current preprint: https://www.biorxiv.org/content/10.1101/2022.11.17.516914v1

In [1]:
%pip install -q 'tensorflow==2.3.0'
%pip install -q 'numpy==1.18.5'

import numpy as np
import tensorflow as tf
import pandas as pd
from tensorflow.keras.utils import to_categorical

In [2]:
from tensorflow.keras.regularizers import Regularizer
from tensorflow.python.keras import backend
from mpl_toolkits.mplot3d import Axes3D
import scipy.spatial.distance
import matplotlib.pyplot as plt
import os
import tensorflow.keras as keras

In [3]:
np.random.seed(18229)
tf.random.set_seed(94892)

## Task / Dataset

### Dataset generator


The class below generates datasets representing our maze-like task, either as NumPy arrays or as TensorFlow datasets.

In [4]:
class mazeGeneratorI():
    '''
    Objects of the mazeGeneratorI class can create numpy and tf datasets of the first choice of the maze task.
    Task structure:
        Goal presentation, followed by delay period, followed by choice options.
    Response:
        One response required from agent at end of episode. Direction (Left, Up, Right, Down) of first step.
    Encoding:
        Both observations and labels are OneHot encoded.
    Usage:
        The only two functions a user should need to access are "construct_numpy_data" and "construct_tf_data"
    Options:
        Both data construction methods have an option to shuffle the labels of data.
        The numpy data construction method can also return the maze identifiers.
    '''
    def __init__(self, goal_presentation_steps, delay_steps, choices_presentation_steps):
        self.version = 'v1.2.0'
        
        # Import variables defining episode
        self.goal_presentation_steps = goal_presentation_steps
        self.delay_steps = delay_steps
        self.choices_presentation_steps = choices_presentation_steps

        # Construct mazes dataframe
        ## Add encoded versions of the goal / choices presentations and the next step response
        self.mazesdf = self.import_maze_dic()
        self.mazesdf['Goal_Presentation'] = self.mazesdf['goal'].map({
            7:np.concatenate((np.array([1,0,0,0]),np.repeat(0,4))),
            9:np.concatenate((np.array([0,1,0,0]),np.repeat(0,4))),
            17:np.concatenate((np.array([0,0,1,0]),np.repeat(0,4))),
            19:np.concatenate((np.array([0,0,0,1]),np.repeat(0,4)))})
        self.mazesdf['Choices_Presentation']=self.mazesdf['ChoicesCategory'].map(lambda x: self.encode_choices(x=x))
        self.mazesdf['Step_Encoded']=self.mazesdf['NextFPmap'].map(lambda x: self.encode_next_step(x=x))

    def construct_numpy_data(self, number_of_problems, return_maze_identifiers = False, np_shuffle_data = False):
        # Create a new column which hold the vector for each problem
        self.mazesdf['Problem_Vec']=self.mazesdf.apply(lambda x: self.create_problem_observation(row= x,goal_presentation_steps= self.goal_presentation_steps,delay_steps= self.delay_steps,choices_presentation_steps= self.choices_presentation_steps), axis=1)
        # Set a random order of maze problems for the current session
        self.mazes_order = np.random.randint(0,8,number_of_problems)

        # Create vectors, holding observations and labels
        session_observation =np.array([])
        session_labels = np.array([])
        for i in self.mazes_order:
            session_observation = np.append(session_observation,self.mazesdf.iloc[i]['Problem_Vec'])
            session_labels = np.append(session_labels,self.mazesdf.iloc[i]['Step_Encoded'])

        # Reshape vectors to fit network observation and response space
        session_length = self.goal_presentation_steps + self.delay_steps + self.choices_presentation_steps
        session_observation = np.reshape(session_observation, (-1,session_length,8)).astype('float32')
        session_labels = np.reshape(session_labels, (-1,4)).astype('float32')

        # If np_shuffle_data == 'Labels', the order of labels is shuffled to randomise correct answers
        if np_shuffle_data == 'Labels':
          shuffle_generator = np.random.default_rng(38446)
          shuffle_generator.shuffle(session_labels,axis=0)

        # If return_maze_identifiers == 'IDs', return the array with maze IDs alongside the regular returns (observations, labels)
        if return_maze_identifiers == 'IDs':
          return session_observation, session_labels, self.mazes_order

        return session_observation, session_labels

    def construct_tf_data(self, number_of_problems, batch_size, tf_shuffle_data = False):
        # Create dataset as described by numpy dataset function and transform it into a TF dataset
        npds, np_labels = self.construct_numpy_data(number_of_problems=number_of_problems, np_shuffle_data = tf_shuffle_data)
        tfdf = tf.data.Dataset.from_tensor_slices((npds, np_labels))
        tfdf = tfdf.batch(batch_size)
        return tfdf

    def reset_construction_params(self, goal_presentation_steps, delay_steps, choices_presentation_steps):
        self.goal_presentation_steps = goal_presentation_steps
        self.delay_steps = delay_steps
        self.choices_presentation_steps = choices_presentation_steps

    def import_maze_dic(self, mazeDic=None):
        if mazeDic is None:
            # Set up dataframe with first choices of maze task
            ## The dictionary was generated using MazeMetadata.py (v1.0.0) and the following call:
            ### mazes.loc[(mazes['Nsteps']==2)&(mazes['ChoiceNo']=='ChoiceI')][['goal','ChoicesCategory','NextFPmap']].reset_index(drop=True).to_dict()
            self.mazesDic = {'goal': {0: 9, 1: 9, 2: 19, 3: 17, 4: 17, 5: 7, 6: 19, 7: 7},
            'ChoicesCategory': {0: 'ul',
            1: 'rd',
            2: 'ld',
            3: 'rd',
            4: 'ul',
            5: 'ur',
            6: 'lr',
            7: 'lr'},
            'NextFPmap': {0: 'u', 1: 'r', 2: 'd', 3: 'd', 4: 'l', 5: 'u', 6: 'r', 7: 'l'}}
        else:
            # Use the dictionary passed in by the caller (parameter is named mazeDic)
            self.mazesDic = mazeDic
        
        # Create and return dataframe
        return pd.DataFrame(self.mazesDic)

    def encode_choices(self, x):
        # Helper function to create the observation vector for choice periods
        choices_sec = np.repeat(0,4)
        choicesEncoding = pd.Series(list(x))
        choicesEncoding = choicesEncoding.map({'l':1,'u':2,'r':3,'d':4})
        for encodedChoice in choicesEncoding:
            choices_sec[encodedChoice-1]=1
        return np.concatenate((np.repeat(0,4),choices_sec))

    def encode_next_step(self, x):
        # Helper function to change the response / action to a OneHot encoded vector
        step_sec = np.repeat(0,4)
        stepEncoding = pd.Series(list(x))
        stepEncoding = stepEncoding.map({'l':1,'u':2,'r':3,'d':4})
        for encodedStep in stepEncoding:
            step_sec[encodedStep-1]=1
        return step_sec

    def create_problem_observation(self, row, goal_presentation_steps, delay_steps, choices_presentation_steps):
        # Helper function to create one vector describing the entire outline of one maze problem (Goal presentation, Delay Period, and Choices Presentation)
        goal_vec = np.tile(row['Goal_Presentation'], goal_presentation_steps)
        delay_vec = np.tile(np.repeat(0,8), delay_steps)
        choices_vec = np.tile(row['Choices_Presentation'], choices_presentation_steps)
        problem_vec = np.concatenate((goal_vec,delay_vec,choices_vec))
        return problem_vec

    def __repr__(self):
        return '\n'.join([
            f'Maze DataSet Generator',
            f'Goal Presentation Steps: {self.goal_presentation_steps}',
            f'Delay Steps: {self.delay_steps}',
            f'Choices Presentation Steps: {self.choices_presentation_steps}'])
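
The layout of a single episode can be illustrated without pandas. The following is a minimal sketch, not the class itself: `build_episode` and the three constants are hypothetical names, and the step counts are the ones used for training in this notebook. Channels 0-3 carry the goal and channels 4-7 carry the available choices, in the order left, up, right, down.

```python
import numpy as np

# Hypothetical constants; values mirror the training setup in this notebook.
GOAL_STEPS, DELAY_STEPS, CHOICE_STEPS = 20, 10, 20

def build_episode(goal_index, choices):
    """Return a (timesteps, 8) observation for one maze problem.

    goal_index: 0..3, position of the active goal channel.
    choices: string of directions drawn from 'l', 'u', 'r', 'd', e.g. 'ul'.
    """
    goal = np.zeros(8)
    goal[goal_index] = 1  # channels 0-3 carry the goal

    choice = np.zeros(8)
    for direction in choices:
        choice[4 + 'lurd'.index(direction)] = 1  # channels 4-7 carry the options

    # Goal presentation, then an all-zero delay, then the choice presentation
    return np.vstack([
        np.tile(goal, (GOAL_STEPS, 1)),
        np.zeros((DELAY_STEPS, 8)),
        np.tile(choice, (CHOICE_STEPS, 1)),
    ])

episode = build_episode(goal_index=1, choices='ul')
print(episode.shape)
```

Each row is one time step, so the network sees the goal first, then silence, and must hold the goal in memory until the options appear.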


### Generate datasets for training

In [5]:
# This constructor might run for around a minute to generate the dataset
gen = mazeGeneratorI(goal_presentation_steps=20,delay_steps=10,choices_presentation_steps=20)
tfdf = gen.construct_tf_data(number_of_problems=5120, batch_size=128)
tfdf_test = gen.construct_tf_data(number_of_problems=2560, batch_size=128)
tfdf_val = gen.construct_tf_data(number_of_problems=2560, batch_size=128)
print(tfdf)

<BatchDataset shapes: ((None, 50, 8), (None, 4)), types: (tf.float32, tf.float32)>


In [6]:
# Show example of dataset
example_data = next(iter(tfdf))
#print(example_data)

## seRNN

### Regulariser

In this section we define two regularisers:
1. A regulariser for Euclidean embedding
2. A subclass which adds the communicability term (this is the one we use in seRNNs)
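
Read directly off the code of the two classes defined in this section, the penalties can be summarised as follows. The symbols are notation introduced here for convenience, not taken from the preprint.

```latex
\begin{align}
\mathcal{L}_{\mathrm{SE1}} &= \gamma \sum_{i,j} \lvert W_{ij} \rvert \, D_{ij} \\
\mathcal{L}_{\mathrm{SE1\_sWc}} &= \gamma \sum_{i,j} \lvert W_{ij} \rvert \, C_{ij}^{\,c} \, D_{ij},
\qquad
C = \exp\!\left(S^{-1/2} \, \lvert W \rvert \, S^{-1/2}\right) \text{ with its diagonal set to zero},
\qquad
S = \operatorname{diag}\Big(\sum_{j} \lvert W_{ij} \rvert\Big)
\end{align}
```

Here $\gamma$ is `se1`, $W$ is the recurrent kernel, $D$ is the pairwise distance matrix raised elementwise to `distance_power`, $c$ is `comms_factor`, and $\exp$ is the matrix exponential.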

In [7]:
class SE1(Regularizer):
  """A regulariser for spatially embedded RNNs.
  Applies L1 regularisation to recurrent kernel of
  RNN which is weighted by the distance of units
  in predefined 3D space.
  Calculation:
      se1 * sum[distance_matrix o recurrent_kernel]
  Attributes:
      se1: Float; Weighting of SE1 regularisation term.
      distance_tensor: TF tensor / matrix with cost per
      connection in weight matrix of network.
  """

  def __init__(self, se1=0.01, neuron_num = 100, network_structure = (5,5,4), coordinates_list = None, distance_power = 1, distance_metric = 'euclidean'):  
    self.version = 'v1.1.0'
    self.distance_power = distance_power
    
    # Set SE1 regularisation strength to default of 0.01 if no value given
    se1 = 0.01 if se1 is None else se1
    self._check_penalty_number(se1)

    # Transform regularisation strength to TF's standard float format 
    self.se1 = backend.cast_to_floatx(se1)

    # Set up tensor with distance matrix
    ## Set up neurons per dimension
    nx = np.arange(network_structure[0])
    ny = np.arange(network_structure[1])
    nz = np.arange(network_structure[2])

    ## Set up coordinate grid
    [x,y,z] = np.meshgrid(nx,ny,nz)
    self.coordinates = [x.ravel(),y.ravel(),z.ravel()]

    ## Override coordinate grid if one is provided in init
    if coordinates_list is not None:
      self.coordinates = coordinates_list

    ## Check neuron number / number of coordinates
    if (len(self.coordinates[0])==neuron_num)&(len(self.coordinates[1])==neuron_num)&(len(self.coordinates[2])==neuron_num):
      pass
    else:
      raise ValueError('Network / coordinate structure does not match the number of neurons.')

    ## Calculate the euclidean distance matrix
    euclidean_vector = scipy.spatial.distance.pdist(np.transpose(self.coordinates), metric=distance_metric)
    euclidean = scipy.spatial.distance.squareform(euclidean_vector**self.distance_power)
    self.distance_matrix = euclidean.astype('float32')
    self.spatial_cost_matrix = self.distance_matrix

    ## Add minimal cost for recurrent self connection (on diagonal)
    #diag_dist = np.diag(np.repeat(0.1,100)).astype('float32')
    #self.distance_matrix = self.distance_matrix + diag_dist

    ## Create tensor from distance matrix
    self.distance_tensor =  tf.convert_to_tensor(self.distance_matrix)


  def __call__(self, x):
    # Add calculation of loss here.
    # L1 for reference: self.l1 * math_ops.reduce_sum(math_ops.abs(x))
    abs_weight_matrix = tf.math.abs(x)

    #se1_loss = self.se1 * tf.math.multiply(abs_weight_matrix, self.distance_tensor)
    #se1_loss = tf.math.reduce_sum(abs_weight_matrix)
    se1_loss = self.se1 * tf.math.reduce_sum(tf.math.multiply(abs_weight_matrix, self.distance_tensor))
    
    return se1_loss

  def _check_penalty_number(self, x):
    """check penalty number availability, raise ValueError if failed."""
    if not isinstance(x, (float, int)):
      raise ValueError(('Value: {} is not a valid regularization penalty number, '
                        'expected an int or float value').format(x))

  def visualise_distance_matrix(self):
    plt.imshow(self.distance_matrix)
    plt.colorbar()
    plt.show()

  def visualise_neuron_structure(self):
    fig = plt.figure()
    ax = Axes3D(fig)
    ax.scatter(self.coordinates[0],self.coordinates[1],self.coordinates[2],c='b',marker='.')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')
    plt.show()

  def get_config(self):
    return {'se1': float(self.se1)}
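
The distance-weighted L1 penalty can be reproduced in plain NumPy, which makes it easy to check by hand. This is a sketch under stated assumptions: `se1_penalty` is a hypothetical name, broadcasting stands in for SciPy's `pdist` and `squareform`, and the weights are random rather than trained.

```python
import numpy as np

# Place 100 units on a 5 x 5 x 4 integer grid, as in the default network_structure.
x, y, z = np.meshgrid(np.arange(5), np.arange(5), np.arange(4))
coords = np.stack([x.ravel(), y.ravel(), z.ravel()], axis=1)  # shape (100, 3)

# Pairwise Euclidean distances via broadcasting (stand-in for pdist/squareform).
diff = coords[:, None, :] - coords[None, :, :]
distance = np.sqrt((diff ** 2).sum(axis=-1)).astype('float32')

def se1_penalty(weights, distance, strength=0.01):
    """Distance-weighted L1 penalty: strength * sum(|W| * D)."""
    return strength * np.sum(np.abs(weights) * distance)

rng = np.random.default_rng(0)  # arbitrary seed for the illustration
weights = rng.normal(size=(100, 100))
print(se1_penalty(weights, distance))
```

Because the diagonal of the distance matrix is zero, self-connections are not penalised, while a connection between opposite corners of the grid costs the most per unit of weight.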

In [8]:
class SE1_sWc(SE1):
    '''
    Version of SE1 regulariser which combines the spatial and communicability parts in loss function.
    Additional comms_factor scales the communicability matrix.
    The communicability term used here is unbiased weighted communicability:
    Crofts, J. J., & Higham, D. J. (2009). A weighted communicability measure applied to complex brain networks. Journal of the Royal Society Interface, 6(33), 411-414.
    '''
    def __init__(self, se1=0.01, comms_factor = 1, neuron_num = 100, network_structure = (5,5,4), coordinates_list = None, distance_power = 1, distance_metric = 'euclidean'):
      SE1.__init__(self, se1, neuron_num , network_structure , coordinates_list, distance_power , distance_metric)
      self.comms_factor = comms_factor

    def __call__(self, x):
      # Take absolute of weights
      abs_weight_matrix = tf.math.abs(x)

      # Calculate weighted communicability (see reference in docstring)
      stepI = tf.math.reduce_sum(abs_weight_matrix, axis=1)
      stepII = tf.math.pow(stepI, -0.5)
      stepIII = tf.linalg.diag(stepII)
      stepIV = tf.linalg.expm(stepIII@abs_weight_matrix@stepIII)
      comms_matrix = tf.linalg.set_diag(stepIV, tf.zeros(stepIV.shape[0:-1]))

      # Multiply absolute weights with communicability weights
      comms_matrix = comms_matrix**self.comms_factor
      comms_weight_matrix = tf.math.multiply(abs_weight_matrix, comms_matrix)

      # Multiply comms weights matrix with distance tensor, sum all entries, scale by se1, and return as loss
      se1_loss = self.se1 * tf.math.reduce_sum(tf.math.multiply(comms_weight_matrix , self.distance_tensor))
      
      return se1_loss
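
To see what the communicability term does, the same steps can be followed in NumPy on a tiny network. This is an illustrative sketch only: the function names are hypothetical, and the matrix exponential is approximated with a truncated Taylor series (an assumption made to keep the example dependency-free), whereas the class above calls `tf.linalg.expm`.

```python
import numpy as np

def expm_taylor(matrix, terms=30):
    """Approximate the matrix exponential with a truncated Taylor series."""
    result = np.eye(matrix.shape[0])
    term = np.eye(matrix.shape[0])
    for k in range(1, terms):
        term = term @ matrix / k  # term now holds matrix^k / k!
        result = result + term
    return result

def weighted_communicability(weights):
    """expm(S^-1/2 |W| S^-1/2) with the diagonal zeroed; S holds row sums of |W|."""
    abs_w = np.abs(weights)
    scale = np.diag(abs_w.sum(axis=1) ** -0.5)
    comms = expm_taylor(scale @ abs_w @ scale)
    np.fill_diagonal(comms, 0.0)
    return comms

def se1_swc_penalty(weights, distance, strength=0.01, comms_factor=1):
    """strength * sum(|W| * C**comms_factor * D)."""
    comms = weighted_communicability(weights) ** comms_factor
    return strength * np.sum(np.abs(weights) * comms * distance)

# Two units connected in both directions, two distance units apart.
w = np.array([[0.0, -1.0], [1.0, 0.0]])
d = np.array([[0.0, 2.0], [2.0, 0.0]])
print(weighted_communicability(w))
print(se1_swc_penalty(w, d, strength=0.5))
```

For this two-unit case the normalised matrix equals the absolute weight matrix, so the off-diagonal communicability is sinh(1). Note that a unit with no outgoing weight would give a zero row sum and a division by zero in the scaling step.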

### Model training helper functions

Here we define a callback to give us easy access to weight matrices after training.

In [9]:
class RNNWeightMatrixHistoryI(keras.callbacks.Callback):
    '''
    Saves the RNN_Weight_Matrix to the training history before
    the start of training and after finishing each epoch.

    The network model needs to be built manually before calling fit() method
    for this callback to work.
    '''
    def __init__(self, RNN_layer_number = 0):
        super(RNNWeightMatrixHistoryI, self).__init__()
        self.RNN_layer_number = RNN_layer_number

    def on_train_begin(self, logs=None):
        # Create key for RNN_Weight_Matrix in history
        self.model.history.history['RNN_Weight_Matrix'] = []
        #print("Created key for RNN_Weight_Matrix in history.")

        # Save matrix before start of training
        self.model.history.history['RNN_Weight_Matrix'].append(self.model.layers[self.RNN_layer_number].get_weights()[1])
        #print("Saved RNN_Weight_Matrix to history.")

    def on_epoch_end(self, epoch, logs=None):
        # Save RNN_Weight_Matrix to history
        self.model.history.history['RNN_Weight_Matrix'].append(self.model.layers[self.RNN_layer_number].get_weights()[1])
        #print("Saved RNN_Weight_Matrix to history.")
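
The bookkeeping of this callback determines how the saved matrices are indexed later. A framework-free mock, with the hypothetical names `train` and `snapshots`, shows the resulting list length:

```python
# Framework-free mock of the callback's bookkeeping; no Keras involved.
def train(epochs, get_weights):
    snapshots = [get_weights()]          # saved once before training starts
    for _ in range(epochs):
        # ... one epoch of training would happen here ...
        snapshots.append(get_weights())  # saved after every epoch
    return snapshots

snapshots = train(epochs=10, get_weights=lambda: 'weights')
print(len(snapshots))
```

With 10 epochs the list holds 11 entries: index 0 is the matrix at initialisation and index 10 is the matrix after the final epoch.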

### Model definition

In [10]:
# Example regularisation strength from set of networks used in preprint
regu_strength = 0.3
print(regu_strength)

0.3


In [11]:
tf.keras.backend.clear_session()
regu = SE1_sWc(se1=regu_strength)
coord = regu.coordinates
cost = regu.spatial_cost_matrix

## Assemble network
tf_model = tf.keras.models.Sequential([
    tf.keras.layers.GaussianNoise(stddev=0.05),
    tf.keras.layers.SimpleRNN(100, activation='relu',recurrent_initializer='orthogonal', return_sequences=False, recurrent_regularizer= regu),
    tf.keras.layers.Dense(4, activation='softmax')
])
tf_model.build(input_shape=example_data[0].shape)

### Model training

In [12]:
tf_model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

In [13]:
history = tf_model.fit(tfdf, epochs=10,validation_data=tfdf_test,
                       callbacks=[RNNWeightMatrixHistoryI(RNN_layer_number=1)]
                       )

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


### Structural analysis

In [14]:
%pip install bctpy
import bct 


In [15]:
# Extract example absolute weight matrix from trained network
history_dic = history.history
example_weight_matrix = np.abs(history_dic['RNN_Weight_Matrix'][10])

In [16]:
# Binarise network before structural analysis
binary_weight_matrix = example_weight_matrix.copy()
thresh = np.quantile(example_weight_matrix, q=0.9)
matrix_mask = example_weight_matrix > thresh
binary_weight_matrix[matrix_mask] = 1
binary_weight_matrix[~matrix_mask] = 0
#binary_weight_matrix
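
The effect of the quantile threshold is easiest to see on a stand-in matrix. This sketch assumes random weights in place of the trained ones and uses an arbitrary seed:

```python
import numpy as np

rng = np.random.default_rng(0)  # arbitrary seed for the illustration
weights = np.abs(rng.normal(size=(100, 100)))  # stand-in for a trained weight matrix

# Keep the strongest 10% of connections, as in the cell above.
thresh = np.quantile(weights, q=0.9)
binary = (weights > thresh).astype(int)

print(binary.sum(), binary.mean())
```

With 10,000 distinct entries, exactly 1,000 connections survive, so the binarised network has a density of 0.1 regardless of the scale of the weights.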

#### Modularity

In [17]:
## Note that the example network chosen here shows a relatively high modularity value
_, q_stat = bct.modularity_und(binary_weight_matrix, gamma=1)
print(q_stat)

0.20619601
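
For intuition about the Q statistic, the following sketch evaluates Newman's modularity for a fixed partition of a toy graph. It is not a replacement for `bct.modularity_und`, which also searches for the partition; `modularity_q` is a hypothetical helper that assumes a symmetric binary adjacency matrix.

```python
import numpy as np

def modularity_q(adjacency, communities, gamma=1.0):
    """Newman modularity of a fixed partition of an undirected binary graph."""
    adjacency = np.asarray(adjacency, dtype=float)
    communities = np.asarray(communities)
    degree = adjacency.sum(axis=1)
    two_m = degree.sum()  # twice the number of edges
    expected = gamma * np.outer(degree, degree) / two_m
    same = communities[:, None] == communities[None, :]
    # Sum observed-minus-expected edges over pairs in the same community
    return np.sum((adjacency - expected)[same]) / two_m

# Two disconnected triangles, each treated as its own community.
triangle = np.ones((3, 3)) - np.eye(3)
graph = np.block([[triangle, np.zeros((3, 3))], [np.zeros((3, 3)), triangle]])
print(modularity_q(graph, [0, 0, 0, 1, 1, 1]))
```

The two-triangle partition gives Q = 0.5, while placing all six nodes in one community gives Q = 0, which is why higher values indicate a more clearly separable community structure.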


#### Small-worldness

In [18]:
# Empirical clustering and global efficiency (bct.efficiency_bin returns global efficiency, stored in pth)
A = binary_weight_matrix
clu = np.mean(bct.clustering_coef_bu(A))
pth = bct.efficiency_bin(A)
# Run nperm null models
nperm = 1000
cluperm = np.zeros((nperm,1))
pthperm = np.zeros((nperm,1))
for perm in range(nperm):
    Wperm = np.random.rand(100,100)
    # Make it into a matrix
    Wperm = np.matrix(Wperm)
    # Make symmetrical
    Wperm = Wperm+Wperm.T
    Wperm = np.divide(Wperm,2)
    # Binarise
    threshold, upper, lower = .7,1,0
    Aperm = np.where(Wperm>threshold,upper,lower)
    # Take null model
    cluperm[perm] = np.mean(bct.clustering_coef_bu(Aperm))
    pthperm[perm] = bct.efficiency_bin(Aperm)
# Take the average of the nulls
clunull = np.mean(cluperm)
pthnull = np.mean(pthperm)
# Compute the small worldness
smw = np.divide(np.divide(clu,clunull),np.divide(pth,pthnull))

print(smw)

4.3051189602119315
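
The two ingredients of this statistic can be sketched in NumPy for small graphs. The helpers below are hypothetical and assume a symmetric binary adjacency matrix without self-loops; `small_world_ratio` simply mirrors the final expression in the cell above.

```python
import numpy as np

def mean_clustering(adjacency):
    """Mean binary clustering coefficient of an undirected graph without self-loops."""
    a = np.asarray(adjacency, dtype=float)
    degree = a.sum(axis=1)
    triangles = np.diag(a @ a @ a) / 2.0      # triangles through each node
    possible = degree * (degree - 1) / 2.0    # neighbour pairs of each node
    # Nodes with fewer than two neighbours get a coefficient of 0
    coef = np.divide(triangles, possible, out=np.zeros_like(triangles), where=possible > 0)
    return coef.mean()

def small_world_ratio(clu, clu_null, eff, eff_null):
    """(clu / clu_null) / (eff / eff_null), matching the expression in the cell above."""
    return (clu / clu_null) / (eff / eff_null)

triangle = np.ones((3, 3)) - np.eye(3)
path = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]])
print(mean_clustering(triangle), mean_clustering(path))
print(small_world_ratio(0.5, 0.25, 0.4, 0.4))
```

A fully connected triangle has clustering 1 and a three-node path has clustering 0. A network with twice the clustering of its null models and the same efficiency yields a ratio of 2.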
