<a href="https://colab.research.google.com/github/asantoangles/dcm_working_memory/blob/main/CCN_2025_Working_memory.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Digital brain models for working memory

## Setup

In [None]:
# @title Run this cell to import required modules

!pip install networkx --quiet

import numpy as np
import random
import scipy.signal as sig
import scipy.stats as stat
import matplotlib.pyplot as plt
import networkx as nx
import gdown
import ipywidgets as widgets
from IPython.display import Image
import pandas as pd
from google.colab import files, data_table
import math
from scipy.interpolate import interp1d
from scipy.integrate import solve_ivp
from tqdm import tqdm
from difflib import get_close_matches

# Render pandas DataFrames in Colab output cells using the interactive
# google.colab.data_table formatter.
data_table.enable_dataframe_formatter()

In [None]:
# @title Functions






def load_structural_data():
    """
    Download the structural connectivity data and load it into memory.

    Each file is fetched from Google Drive with ``gdown`` into the current
    working directory (overwriting any existing file of the same name) and
    then read from disk.

    Returns
    -------
    FLN : np.ndarray
        Array loaded from ``FLN.npy``.
    SLN : np.ndarray
        Array loaded from ``SLN.npy``.
    centers : np.ndarray
        Array loaded from ``centers.npy`` (presumably per-area coordinates,
        as used by ``plot_matrix`` -- confirm).
    area_description : pandas.DataFrame
        ``area_info.csv`` parsed with ``;`` as the delimiter.
    h : np.ndarray
        Hierarchy values loaded from ``hierarchy.npy``.
    """

    def _fetch(remote_url, local_file):
        # Download `remote_url` to `local_file` and return the local path.
        gdown.download(remote_url, local_file, quiet = True)
        return local_file

    FLN = np.load(_fetch('https://drive.google.com/uc?id=1J0V1Hig_JE2XhsPCpAV3UlYzRq8NkEl3', 'FLN.npy'))
    SLN = np.load(_fetch('https://drive.google.com/uc?id=19mBa4Yv2XsfwQw3n_0EeZX8as3CFO0Co', 'SLN.npy'))
    centers = np.load(_fetch('https://drive.google.com/uc?id=1amlbDasESV-IcunPRD3uWh8hjMYS-X7K', 'centers.npy'))
    area_description = pd.read_csv(
        _fetch('https://drive.google.com/uc?id=13WjXAnOYNFnVcFQh8aEiyHdsxp4f2PV-', 'area_info.csv'),
        delimiter=';',
    )
    h = np.load(_fetch('https://drive.google.com/uc?id=1SOaxTOrWxQeou8fUFFsHuq7RYaiuP1Ci', 'hierarchy.npy'))

    return FLN, SLN, centers, area_description, h










# Plotting functions for structural matrices
def plot_matrix(matrix, coordinates, node_sizes = 1, node_labels = None, node_label_color = 'w', node_colors = 0, node_cmap = 'gray', node_alpha = 0.5, threshold = 0.1, title = '', view = 'yz', vmin = None, vmax = None, dpi = 150):
    """
    Draw a directed connectivity matrix as a graph on a 2D anatomical projection.

    Only the strongest ``threshold`` fraction of matrix entries is drawn, with
    edge widths proportional to the retained weights.

    Parameters
    ----------
    matrix : np.ndarray
        N x N connectivity matrix. It is transposed before building the graph,
        so entry ``matrix[i, j]`` is drawn as the edge j -> i.
    coordinates : np.ndarray
        N x (>=3) node coordinates; ``view`` selects which two columns are plotted.
    node_sizes, node_colors : scalar or array-like
        Per-node size multipliers and colour values; scalars are broadcast to all
        N nodes. Caller-supplied arrays are not modified.
    node_labels : sequence of str, optional
        Text drawn at each node position.
    threshold : float
        Fraction of the strongest entries to keep (<= 0 keeps none, >= 1 keeps all).
    view : {'yz', 'xy', 'xz'}
        Projection plane.
    vmin, vmax : float, optional
        Colour limits. If either is missing, both are derived from ``node_colors``
        (symmetric around 0 when any colour value is negative).
    """
    view_dict = {'xy': (0, 1), 'xz': (0, 2),'yz': (1, 2)}
    idx1, idx2 = view_dict[view]
    n_nodes = matrix.shape[0]

    # Broadcast to one value per node into NEW float arrays. The previous
    # in-place `*=` mutated caller arrays and failed for integer arrays.
    node_sizes = np.asarray(node_sizes, dtype=float) * np.ones(n_nodes)
    node_colors = np.asarray(node_colors, dtype=float) * np.ones(n_nodes)

    pos = {idx: coordinates[idx, np.array([idx1, idx2])] for idx in range(n_nodes)}

    figsize = {'yz': (10, 6), 'xy': (5, 10), 'xz': (5, 6)}[view]
    _, ax = plt.subplots(1, 1, figsize = figsize, dpi = dpi)
    ax.set_title(title)

    # Removes the (1 - threshold) fraction of weakest connections. The index is
    # clamped so threshold <= 0 or > 1 no longer indexes out of range or wraps.
    flat = np.sort(matrix.flatten())
    cut = max(int((1 - threshold) * flat.size), 0)
    th = flat[cut] if cut < flat.size else np.inf
    aux = matrix.copy()
    aux[matrix < th] = 0
    G = nx.from_numpy_array(aux.T, create_using=nx.DiGraph)
    widths = [5 * data["weight"] for _, _, data in G.edges(data=True)]

    if vmax is None or vmin is None:
        if np.sum(node_colors < 0) > 0:
            vmin = -np.nanmax(np.abs(node_colors))
            vmax = np.nanmax(np.abs(node_colors))
        else:
            vmin = np.nanmin(node_colors)
            vmax = np.nanmax(node_colors)

    nx.draw_networkx_nodes(G, pos, node_size = 300 * node_sizes, node_color=node_colors, cmap = node_cmap,
                           vmin = vmin, vmax = vmax, ax = ax, alpha = node_alpha)
    nx.draw_networkx_edges(G, pos, node_size = 300 * node_sizes, ax = ax, arrows = True,
                           arrowstyle="-|>", arrowsize=10, edge_color='k', width=widths,
                           connectionstyle="arc3,rad=0.15")

    if node_labels is not None:
        for i, label in enumerate(node_labels):
            ax.text(coordinates[i, idx1], coordinates[i, idx2], label, ha = 'center', va = 'center', fontsize = 5, color = node_label_color)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    if view == 'yz':
        ax.set_xlabel('Posterior                                                                                                                                                                                          Anterior')
        ax.set_ylabel('Ventral                                                                                                      Dorsal')

    if view == 'xy':
        ax.set_xlabel('Lateral                                                                              Medial')
        ax.set_ylabel('Posterior                                                                                                                                                                                          Anterior')

    if view == 'xz':
        ax.set_xlabel('Lateral                                                                              Medial')
        ax.set_ylabel('Ventral                                                                                                      Dorsal')

    plt.tight_layout()
    plt.show()












def plot_matrix_EI(matrix, coordinates, node_sizes = 1, node_labels = None, node_label_color = 'w', node_colors = 0, node_cmap = 'gray', node_alpha = 0.5, threshold = 0.1, title = '', view = 'yz', vmin = None, vmax = None, dpi = 150):
    """
    Draw a signed (E/I) connectivity matrix as a graph on a 2D anatomical projection.

    The strongest ``threshold`` fraction of entries by absolute value is kept.
    Positive (excitation-dominated) entries are drawn in dark red; negative
    (inhibition-dominated) entries are drawn in teal. Widths are proportional
    to the magnitude of the weight.

    Parameters
    ----------
    matrix : np.ndarray
        N x N signed connectivity matrix; ``matrix[i, j]`` is drawn as the edge
        j -> i (the matrix is transposed before building the graph).
    coordinates : np.ndarray
        N x (>=3) node coordinates; ``view`` selects which two columns are plotted.
    node_sizes, node_colors : scalar or array-like
        Per-node size multipliers and colour values; scalars are broadcast to all
        N nodes. Caller-supplied arrays are not modified.
    threshold : float
        Fraction of the strongest |entries| to keep (<= 0 keeps none, >= 1 keeps all).
    view : {'yz', 'xy', 'xz'}
        Projection plane.
    vmin, vmax : float, optional
        Colour limits. If either is missing, both are derived from ``node_colors``.
    """
    view_dict = {'xy': (0, 1), 'xz': (0, 2),'yz': (1, 2)}
    idx1, idx2 = view_dict[view]
    n_nodes = matrix.shape[0]

    # New float arrays instead of in-place `*=` (which mutated caller arrays
    # and failed for integer arrays).
    node_sizes = np.asarray(node_sizes, dtype=float) * np.ones(n_nodes)
    node_colors = np.asarray(node_colors, dtype=float) * np.ones(n_nodes)

    pos = {idx: coordinates[idx, np.array([idx1, idx2])] for idx in range(n_nodes)}

    figsize = {'yz': (10, 6), 'xy': (5, 10), 'xz': (5, 6)}[view]
    _, ax = plt.subplots(1, 1, figsize = figsize, dpi = dpi)
    ax.set_title(title)

    # Removes the weakest connections by magnitude; index clamped to stay in range.
    flat = np.sort(np.abs(matrix).flatten())
    cut = max(int((1 - threshold) * flat.size), 0)
    th = flat[cut] if cut < flat.size else np.inf
    aux = matrix.copy()
    aux[np.abs(matrix) < th] = 0

    if vmax is None or vmin is None:
        if np.sum(node_colors < 0) > 0:
            vmin = -np.nanmax(np.abs(node_colors))
            vmax = np.nanmax(np.abs(node_colors))
        else:
            vmin = np.nanmin(node_colors)
            vmax = np.nanmax(node_colors)

    # Graph over all nodes, used only to place the node markers.
    G = nx.from_numpy_array(aux.T, create_using=nx.DiGraph)
    nx.draw_networkx_nodes(G, pos, node_size = 300 * node_sizes, node_color=node_colors, cmap = node_cmap,
                           vmin = vmin, vmax = vmax, ax = ax, alpha = node_alpha)

    # Plots excitation dominated edges (positive entries)
    G_E = nx.from_numpy_array(np.maximum(aux, 0).T, create_using=nx.DiGraph)
    widths_E = [15 * data["weight"] for _, _, data in G_E.edges(data=True)]
    nx.draw_networkx_edges(G_E, pos, node_size=300 * node_sizes, ax = ax, arrows = True,
                           arrowstyle="-|>", arrowsize=10, edge_color='darkred', width=widths_E,
                           connectionstyle="arc3,rad=0.15")

    # Plots inhibition dominated edges (negative entries). Build the graph from
    # magnitudes so line widths are positive rather than negative.
    G_I = nx.from_numpy_array(np.abs(np.minimum(aux, 0)).T, create_using=nx.DiGraph)
    widths_I = [15 * data["weight"] for _, _, data in G_I.edges(data=True)]
    nx.draw_networkx_edges(G_I, pos, node_size=300 * node_sizes, ax = ax, arrows = True,
                           arrowstyle="-|>", arrowsize=10, edge_color='teal', width=widths_I,
                           connectionstyle="arc3,rad=0.15")

    if node_labels is not None:
        for i, label in enumerate(node_labels):
            ax.text(coordinates[i, idx1], coordinates[i, idx2], label, ha = 'center', va = 'center', fontsize = 5, color = node_label_color)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    if view == 'yz':
        ax.set_xlabel('Posterior                                                                                                                                                                                          Anterior')
        ax.set_ylabel('Ventral                                                                                                      Dorsal')

    if view == 'xy':
        ax.set_xlabel('Lateral                                                                              Medial')
        ax.set_ylabel('Posterior                                                                                                                                                                                          Anterior')

    if view == 'xz':
        ax.set_xlabel('Lateral                                                                              Medial')
        ax.set_ylabel('Ventral                                                                                                      Dorsal')

    plt.tight_layout()
    plt.show()


def dxdt(t, x, tau_noise, sigma):
    """
    Right-hand side of an Ornstein-Uhlenbeck-style noise process, for ``solve_ivp``.

    Returns ``-x / tau_noise + sigma * xi / sqrt(tau_noise)``, where ``xi`` is a
    fresh standard-normal vector drawn from NumPy's global random state on every
    call. ``t`` is accepted for the ``solve_ivp`` signature but not used.

    NOTE(review): because the right-hand side is re-sampled on every evaluation,
    an adaptive ODE solver does not integrate a true SDE here; the result depends
    on the solver's internal step choices.

    Parameters
    ----------
    t : float
        Time (unused).
    x : np.ndarray
        1D array of length N with the current noise values.
    tau_noise : float
        Time constant of the noise.
    sigma : float
        Standard deviation of the noise term.

    Returns
    -------
    np.ndarray
        1D array of length N with dx/dt.
    """
    relaxation = -x / tau_noise
    fluctuation = sigma * np.random.normal(0, 1, len(x)) / np.sqrt(tau_noise)
    return relaxation + fluctuation


def generate_noise_function(N, t_eval, tau_noise, sigma):
    """
    Integrate N independent noise traces with ``solve_ivp`` and return a
    continuous-time interpolant of them.

    The traces start at zero at ``t_eval[0]`` and are driven by ``dxdt``.

    Parameters
    ----------
    N : int
        Number of brain areas (independent traces).
    t_eval : np.ndarray
        1D array of time points at which the solver output is stored.
    tau_noise : float
        Time constant of the noise.
    sigma : float
        Standard deviation of the noise term.

    Returns
    -------
    interp_func : callable
        Linear interpolant (``scipy.interpolate.interp1d``, axis 0) mapping a time t
        to a length-N noise vector; it extrapolates linearly outside ``t_eval``.
    """
    span = [t_eval[0], t_eval[-1]]
    solution = solve_ivp(
        dxdt,
        span,
        y0=np.zeros(N),
        t_eval=t_eval,
        args=(tau_noise, sigma),
    )
    traces = np.transpose(solution.y)  # rows = time points, columns = areas

    return interp1d(t_eval, traces, axis=0, kind="linear", fill_value="extrapolate")


def correct_spine_counts(spine_counts):
    """
    Apply fixed age-correction factors to a 30-element vector of spine counts.

    Entries 13, 14, 15, 16, 18 and 29 are scaled by 1.15 (labelled "5yo, 15%
    increase"), and entries 5, 7 and 20 by 1.30 ("10yo, 30% increase"). All other
    entries are unchanged.

    Parameters
    ----------
    spine_counts : np.ndarray
        Spine counts; must broadcast against a length-30 vector.

    Returns
    -------
    corrected_spine_counts : np.ndarray
        ``spine_counts`` multiplied element-wise by the age factors.
    """
    five_year_idx = [13, 14, 15, 16, 18, 29]  # 5yo, 15% increase
    ten_year_idx = [5, 7, 20]                 # 10yo, 30% increase

    age_factor = np.ones(30)
    age_factor[five_year_idx] *= 1.15
    age_factor[ten_year_idx] *= 1.30

    return spine_counts * age_factor

def load_and_preprocess_data(data_folder_path=None):
    """
    Download and preprocess the area list, SLN, FLN, hierarchy and spine-count data.

    Every file is fetched from Google Drive with ``gdown`` into the current
    working directory (overwriting any existing file of the same name) and
    read with ``pandas.read_excel``.

    Parameters
    ----------
    data_folder_path : str, optional
        Currently unused: files are always downloaded to the working directory.
        Kept only for backward compatibility.

    Returns
    -------
    dict
        'area_names', 'area_lobes', 'area_descriptions' : np.ndarray
            The 'Name', 'Lobe' and 'Description' columns of the area list.
        'SLN' : np.ndarray
            Matrix read from ``sln.xlsx``.
        'W' : np.ndarray
            Weight matrix W_ij = 1.2 * FLN_ij^0.3 / sum_j(1.2 * FLN_ij^0.3),
            so each row sums to 1.
        'hier_vals' : np.ndarray
            Spine counts divided by their maximum. Entries whose spine count is
            negative (missing) fall back to the original hierarchy values.
    """

    def _read_excel(export_url, local_file, header=None):
        # Download one spreadsheet and parse it into a DataFrame.
        gdown.download(export_url, local_file, quiet = True)
        return pd.read_excel(local_file, header=header)

    ## Area names ##
    areas = _read_excel('https://drive.google.com/uc?export=download&id=1vVc1GXFb-a-J5uiu8__5enbiQUZCMNhU', 'areaList.xlsx', header=0)
    area_names = areas['Name'].to_numpy()
    area_lobes = areas['Lobe'].to_numpy()
    area_descriptions = areas['Description'].to_numpy()

    ## SLN ##
    SLN = _read_excel('https://drive.google.com/uc?export=download&id=1I5iebwYaYi2F7Zo3oZd_9mtQmctGgmEc', 'sln.xlsx').to_numpy()

    ## FLN -> row-normalised weights ##
    FLN = _read_excel('https://drive.google.com/uc?export=download&id=1bBsSd-nWWb80fcGEB4L-yeDjgTx2OT3X', 'fln.xlsx').to_numpy()
    W = 1.2 * FLN ** 0.3
    W /= np.sum(W, axis=1, keepdims=True)

    ## Hierarchy values ##
    hier_vals1 = _read_excel('https://drive.google.com/uc?export=download&id=1zFV1BH8erMqVR9zQBHSP8hWslwwvcl5p', 'hierVals.xlsx').to_numpy().flatten() # original hierarchy values

    ## Spine counts ##
    spinec = _read_excel('https://drive.google.com/uc?export=download&id=1QhW4Y9yoOzLsjJnFcKm-KH_N9unbJnMS', 'spineCounts_ageCorrected.xlsx').to_numpy().flatten()

    hier_vals = spinec / np.max(spinec)
    missing = np.where(spinec < 0)
    hier_vals[missing] = hier_vals1[missing] # replace empty spine counts with original hierarchy values

    return {'area_names': area_names, 'area_lobes': area_lobes, 'area_descriptions': area_descriptions, 'SLN': SLN, 'W': W, 'hier_vals': hier_vals}


def get_success_rate(desc, WM_network, t_end, dt, I_ext_strengths, ts_ext_start, ts_ext_end, runs=50, threshold=10, monitor_areas_idx=range(12, 30), window=1.0):
    """
    Estimate the fraction of trials in which the WM network sustains activity.

    For each seed in ``range(runs)`` the network is reset and simulated. A trial
    counts as successful when more than half of the monitored areas have a mean
    r_A (state variable 0) above ``threshold`` over the final ``window`` seconds.

    Parameters
    ----------
    desc : str
        Description shown on the progress bar.
    WM_network : WorkingMemoryNetwork
        Working memory network; must provide ``reset(random_seed=...)`` and ``run(...)``.
    t_end : float
        End time of the simulation.
    dt : float
        Time step.
    I_ext_strengths : np.ndarray
        External input strengths.
    ts_ext_start : np.ndarray
        Start times of the external inputs.
    ts_ext_end : np.ndarray
        End times of the external inputs.
    runs : int, optional
        Number of runs (seeds 0 .. runs-1).
    threshold : float, optional
        Rate threshold for an area to count as active.
    monitor_areas_idx : array-like of int, optional
        Indices of areas to monitor and check for success.
    window : float, optional
        Length (in the same time units as ``dt``) of the final interval averaged
        per area. Default 1.0.

    Returns
    -------
    success_rate : float
        Fraction of successful runs.

    Raises
    ------
    ValueError
        If ``runs`` is not positive.
    """
    if runs <= 0:
        raise ValueError("runs must be a positive integer")

    monitor_areas_idx = np.asarray(monitor_areas_idx)
    # Number of trailing samples to average; round to avoid float truncation (e.g. 1/dt = 9999.999...).
    n_window = max(1, int(round(window / dt)))

    success_times = 0
    for seed in tqdm(range(runs), desc=desc):
        WM_network.reset(random_seed=seed)
        state_history = WM_network.run(t_end=t_end, dt=dt, I_ext_strengths=I_ext_strengths, ts_ext_start=ts_ext_start, ts_ext_end=ts_ext_end)
        # Slice the last `window` of time (not a single sample), then average over
        # time only, giving one mean r_A per monitored area.
        final_r_A = np.mean(state_history[-n_window:, monitor_areas_idx, 0], axis=0)
        if np.sum(final_r_A > threshold) > len(monitor_areas_idx) / 2:
            success_times += 1
    return success_times / runs

In [None]:
#@title Parameters
class WorkingMemoryParameters:
    """
    This class includes various parameters involved in the dynamics of a working memory model.
    Parameters are stored in a dictionary for flexibility and can be updated using the `update` method.

    Parameters are read as attributes (``params.G``). Every change goes through
    ``update``; plain attribute assignment to an existing parameter name is routed
    there too. This keeps the derived parameters ``ZETA`` and ``Z`` recomputed
    and never stale.
    """

    # Derived parameters: recomputed from the others, never set directly.
    _DEPENDENT_PARAMS = ("ZETA", "Z")

    def __init__(self):
        # Write straight into __dict__: __setattr__ below routes parameter names
        # through update(), which requires _params to exist already.
        self.__dict__["_params"] = {
            # I -> r
            "A": 135,           # Hz/nA
            "B": 54,            # Hz
            "D": 0.308,         # s
            "G_I": 4,
            "C_0": 177,         # Hz
            "C_1": 615,         # Hz/nA
            "R_0": 5.5,         # Hz
            "TAU_R": 2e-3,      # s

            # r -> S
            "TAU_N": 60e-3,     # s
            "TAU_G": 5e-3,      # s
            "GAMMA": 1.282,
            "GAMMA_I": 2.0,

            # S -> I
            "J_S": 0.3213,      # nA
            "J_C": 0.0107,      # nA
            "J_IE": 0.15,       # nA
            "J_EI": -0.31,      # nA
            "J_II": -0.12,      # nA
            "I_0A": 0.3294,     # nA
            "I_0B": 0.3294,     # nA
            "I_0C": 0.26,       # nA

            # Noise
            "TAU_NOISE": 2e-3,  # s
            "SIGMA_A": 0.01,    # nA
            "SIGMA_B": 0.01,    # nA

            # Gradient of J_s
            "J_MIN": 0.21,      # nA
            "J_MAX": 0.42,      # nA (< 0.4655)

            # Gradient of J_IE
            "J_0": 0.2112,      # nA
            "ZETA": None,

            # Inter-areal projections
            "G": 0.48,          # global coupling strength
            "Z": None,          # E-I balancing factor
            "ALPHA": 1          # relative strength of (feedback) inhibitory projections
        }

        # Initialize dependent parameters
        self._recalculate_dependent_parameters()

    def __getattr__(self, name):
        """
        Dynamically retrieve the value of a parameter.

        Raises AttributeError for unknown names, suggesting close matches.
        """
        # Safely access self._params using __dict__ to avoid recursion
        params = self.__dict__.get("_params", {})
        if name in params:
            return params[name]

        # Suggest similar parameter names
        similar_keys = get_close_matches(name, params.keys(), n=3, cutoff=0.5)
        if similar_keys:
            suggestion = f" Did you mean one of these? {similar_keys}"
        else:
            suggestion = f" Available parameters are: {list(params.keys())}"

        raise AttributeError(
            f"'WorkingMemoryParameters' object has no attribute '{name}'."
            f"{suggestion}"
        )

    def __setattr__(self, name, value):
        """
        Route assignments to existing parameter names (case-insensitive) through
        ``update`` so dependent parameters are recomputed; any other attribute
        is set normally.
        """
        params = self.__dict__.get("_params", {})
        if any(key.upper() == name.upper() for key in params):
            self.update(**{name: value})
        else:
            object.__setattr__(self, name, value)

    def update(self, **kwargs):
        """
        Update parameters of the model.

        Keys are matched case-insensitively. The update is all-or-nothing: if any
        key is rejected, or the new values make a dependent parameter undefined,
        no parameter is changed.

        Raises
        ------
        AttributeError
            If a key names a dependent parameter (ZETA, Z) or no parameter at all.
        ValueError
            Propagated from ``calc_zeta`` / ``calc_Z`` after the previous values
            have been restored.
        """
        canonical = {key.upper(): key for key in self._params}
        resolved = {}
        for key, value in kwargs.items():
            # Case-insensitive so that `zeta=...` is rejected just like `ZETA=...`
            # (it was previously accepted and then silently overwritten).
            if key.upper() in self._DEPENDENT_PARAMS:
                raise AttributeError(f"Cannot directly update dependent parameter '{key}'.")
            if key.upper() not in canonical:
                raise AttributeError(f"Invalid parameter '{key}'. Available parameters are: {list(self._params.keys())}")
            resolved[canonical[key.upper()]] = value

        previous = dict(self._params)
        self._params.update(resolved)
        try:
            self._recalculate_dependent_parameters()
        except ValueError:
            # Roll back so the object never holds an inconsistent parameter set.
            self._params.clear()
            self._params.update(previous)
            raise

    def _recalculate_dependent_parameters(self):
        """
        Recalculate dependent parameters after updating.
        """
        self._params["ZETA"] = self.calc_zeta()
        self._params["Z"] = self.calc_Z()

    def calc_zeta(self):
        """
        Compute the inhibitory-to-excitatory balancing factor (ZETA).

        Raises ValueError if the denominator G_I - J_II*TAU_G*GAMMA_I*C_1 is non-positive.
        """
        tau_g = self._params["TAU_G"]
        gamma_i = self._params["GAMMA_I"]
        c_1 = self._params["C_1"]
        j_ii = self._params["J_II"]
        g_i = self._params["G_I"]

        denominator = g_i - j_ii * tau_g * gamma_i * c_1
        if denominator <= 0:
            raise ValueError("Denominator in ZETA calculation is non-positive.")
        return (tau_g * gamma_i * c_1) / denominator

    def calc_Z(self):
        """
        Compute the E-I balancing factor (Z).

        Raises ValueError if the denominator C_1*TAU_G*GAMMA_I*J_II - G_I is zero.
        """
        tau_g = self._params["TAU_G"]
        gamma_i = self._params["GAMMA_I"]
        c_1 = self._params["C_1"]
        j_ei = self._params["J_EI"]
        j_ii = self._params["J_II"]
        g_i = self._params["G_I"]

        denominator = c_1 * tau_g * gamma_i * j_ii - g_i
        if denominator == 0:
            raise ValueError("Denominator in Z calculation is zero.")
        return (2 * c_1 * tau_g * gamma_i * j_ei) / denominator
params = WorkingMemoryParameters()

In [None]:
#@title Working memory class
class WorkingMemoryNetwork:
    def __init__(self, area_names, area_lobes, Y0, W, F, h, params, remove_FB=False, random_seed=None):
        """
        Initialize the network object for the working memory model.

        Parameters
        ----------
        area_names: list
            List of brain area names.
        area_lobes: list
            List of brain area lobes.
        Y0: np.ndarray
            N x 8 matrix of initial conditions. Each row is a brain area, each column is a variable vector (r_A, r_B, r_C, S_A, S_B, S_C, x_A, x_B).
        W: np.ndarray
            N x N matrix of connectivity strengths.
        F: np.ndarray
            N x N matrix of feedforward (FF) relative strengths.
        h: np.ndarray
            1D array of the positions of the brain areas in the T1w:T2w ranking.
        params: object
            Parameters object containing the parameters of the model.
        remove_FB: bool, optional
            Whether to remove feedback projections. Default is False.
        random_seed: int, optional
            Random seed for reproducibility. Default is None.
        """
        self.area_names = area_names.copy()
        self.area_lobes = area_lobes.copy()
        self.N = len(area_names) # number of brain areas

        # Check the dimensions of the input matrices
        assert len(area_lobes) == self.N, f"Length of area_lobes must be {self.N}!"
        assert Y0.shape == (self.N, 8), f"Y0 must be a {self.N} x 8 matrix!"
        assert W.shape == (self.N, self.N), f"W must be a {self.N} x {self.N} matrix!"
        assert F.shape == (self.N, self.N), f"F must be a {self.N} x {self.N} matrix!"
        assert self.N == h.shape[0], f"The length of h must be {self.N}!"
        assert np.min(W) >= 0, "Ensure all connections are non-negative!"
        assert np.all(np.diag(W) == 0), "Ensure no self-connections!"
        assert np.min(F) >= 0 and np.max(F) <= 1, "Ensure all feedforward relative strengths are between 0 and 1!"
        assert np.min(h) >= 0 and np.max(h) <= 1, "Ensure all hierarchical level values are between 0 and 1!"

        self.Y0 = Y0.copy() # store initial conditions
        self.Y = Y0.copy() # initialize current state
        self.W = W.copy() # connectivity matrix (FLN)
        self.F = F.copy() # FF relative strength matrix (SLN)
        self.h = h.copy() # hierarchical level
        self.params = params

        self.J_s = self.calc_J_s() # calculate gradient of J_s
        self.J_IE = self.calc_J_IE() # calculate gradient of J_IE

        # "SLN-driven modulation of FB projections between frontal areas is not too large, so that interactions between these areas are never strongly inhibitory.
        # In practice, such constraint is only necessary for projections from frontal areas to 8 l and 8 m.
        # We consider that the SLN-driven modulation of FB projections to 8 l and 8 m is never larger than 0.4."
        self.lambda_I = 1 - self.F # relative strength of (feedback) inhibitory projections
        self.frontal_indices = [i for i in range(self.N) if self.area_lobes[i] in ('Frontal', 'Prefrontal')]
        inhib_max = 0.4 # maximum inhibitory strength
        for i in range(2): # 8l and 8m are the first two frontal areas
            for source_area_idx in self.frontal_indices:
                x_idx = self.frontal_indices[i]
                y_idx = source_area_idx
                self.lambda_I[x_idx, y_idx] = min(self.lambda_I[x_idx, y_idx], inhib_max)

        self.W_E, self.W_I = self.calc_WE_WI()

        if remove_FB: # remove feedback projections
            self.W_E -= np.triu(self.W_E, k=1)
            self.W_I -= np.triu(self.W_I, k=1)

        self.state_histories = [] # list to store the state history
        self.time_histories = [] # list to store the time history
        self.monitor = [] # list to store the monitored variables

        if random_seed:
            self.random_seed = random_seed
            self.rng = np.random.default_rng(random_seed)
        else:
            self.random_seed = None
            self.rng = np.random.default_rng()

    def calc_WE_WI(self):
        """
        Returns
        -------
        W_E, W_I - strenghts of excitatory - excitatory and excitatory - inhibitory long range connection strength, respectively.
        """
        # Below, we denote SLN as self.F and (1 - SLN) as self.lambda_I
        # np.outer(self.J_s / np.max(self.J_s) and np.outer(self.J_IE / np.max(self.J_IE), np.ones(self.N))
        # are normalizing factors
        W_E = self.params.G * self.W * np.outer(self.J_s / np.max(self.J_s), np.ones(self.N)) * self.F
        W_I = self.params.ALPHA * self.params.G / self.params.Z * self.W * np.outer(self.J_IE / np.max(self.J_IE), np.ones(self.N)) * self.lambda_I

        return W_E, W_I

    def calc_J_s(self):
        """
        Calculate the synaptic strength J_s for each brain area.
        """
        return self.params.J_MIN + (self.params.J_MAX - self.params.J_MIN) * self.h

    def calc_J_IE(self):
        """
        Calculate the synaptic strength J_IE for each brain area.
        """
        return (self.params.J_0 - self.J_s - self.params.J_C) / (2 * self.params.J_EI * self.params.ZETA)

    def phi_E(self, I):
        """
        Transfer function for excitatory populations.
        """
        return (self.params.A * I - self.params.B) / (1 - np.exp(-self.params.D * (self.params.A * I - self.params.B)))

    def phi_I(self, I):
        """
        Transfer function for inhibitory populations.
        """
        return np.maximum(self.params.R_0 + (self.params.C_1 * I - self.params.C_0) / self.params.G_I, 0)

    def calc_dY(self, Y, dt, dt_over_taus, I_ext):
        """
        Calculate the differential of Y for this N x 8 system.
        Function will become a method to the working memory class.

        Parameters
        ----------
        Y: np.ndarray
            N x 8 matrix of variables (r_A, r_B, r_C, S_A, S_B, S_C, x_A, x_B).
        dt: float
            Time step size.
        dt_over_taus: np.ndarray
            Precomputed array of time step ratios (dt/tau_r, dt/tau_N, dt/tau_G, dt/tau_noise, sqrt(dt/tau_noise)).
        I_ext: np.ndarray
            N x 3 matrix of external inputs. Each row is a brain area, each column is an external input (I_ext_A, I_ext_B, I_ext_C).

        Returns
        -------
        dY: np.ndarray
            N x 8 matrix of differentials.
        """
        # Extract variables and precomputed time step ratios
        r_A, r_B, r_C = Y[:, 0], Y[:, 1], Y[:, 2]
        S_A, S_B, S_C = Y[:, 3], Y[:, 4], Y[:, 5]
        x_A, x_B = Y[:, 6], Y[:, 7]
        dt_over_tau_r, dt_over_tau_N, dt_over_tau_G, dt_over_tau_noise, sqrt_dt_over_tau_noise = dt_over_taus

        # Compute the inputs
        I_A = self.J_s * S_A + self.params.J_C * S_B + self.params.J_EI * S_C + self.params.I_0A + self.W_E @ S_A + x_A + I_ext[:, 0]
        I_B = self.params.J_C * S_A + self.J_s * S_B + self.params.J_EI * S_C + self.params.I_0B + self.W_E @ S_B + x_B + I_ext[:, 1]
        I_C = self.J_IE * S_A + self.J_IE * S_B + self.params.J_II * S_C + self.params.I_0C + self.W_I @ (S_A + S_B) + I_ext[:, 2]

        # Compute the differentials of firing rates
        dr_A = (-r_A + self.phi_E(I_A)) * dt_over_tau_r
        dr_B = (-r_B + self.phi_E(I_B)) * dt_over_tau_r
        dr_C = (-r_C + self.phi_I(I_C)) * dt_over_tau_r

        # Compute the differentials of synaptic conductances
        dS_A = -S_A * dt_over_tau_N + self.params.GAMMA * (1 - S_A) * r_A * dt
        dS_B = -S_B * dt_over_tau_N + self.params.GAMMA * (1 - S_B) * r_B * dt
        dS_C = -S_C * dt_over_tau_G + self.params.GAMMA_I * r_C * dt

        # Compute the differentials of noise terms (without instantaneous part)
        dx_A = -x_A * dt_over_tau_noise + self.params.SIGMA_A * sqrt_dt_over_tau_noise * self.rng.normal(loc=0, scale=1, size=(self.N,))
        dx_B = -x_B * dt_over_tau_noise + self.params.SIGMA_B * sqrt_dt_over_tau_noise * self.rng.normal(loc=0, scale=1, size=(self.N,))

        # Combine the differentials
        dY = np.zeros_like(Y)
        dY[:, 0] = dr_A
        dY[:, 1] = dr_B
        dY[:, 2] = dr_C
        dY[:, 3] = dS_A
        dY[:, 4] = dS_B
        dY[:, 5] = dS_C
        dY[:, 6] = dx_A
        dY[:, 7] = dx_B

        return dY

    def precompute_dt_over_taus(self, dt):
        """
        Precompute time step ratios to avoid recalculating during iterations.

        Parameters
        ----------
        dt: float
            Time step size.

        Returns
        -------
        dt_over_taus: np.ndarray
            Array of time step ratios and their square root for noise calculations.
        """
        taus = np.array([self.params.TAU_R, self.params.TAU_N, self.params.TAU_G, self.params.TAU_NOISE])
        remainders = taus % dt
        assert remainders.all() == 0, f"dt is suggested to be divisible by all time constants: {taus}"
        dt_over_taus = dt / taus
        return np.append(dt_over_taus, np.sqrt(dt_over_taus[-1]))  # Add sqrt(dt/tau_noise) for noise

    def run(self, t_end, dt, t_start=0, I_ext_strengths=None, ts_ext_start=None, ts_ext_end=None, lesion_areas=None):
        """
        Run the simulation (Euler-Maruyama integration) and record its histories.

        Parameters
        ----------
        t_end: float
            End time of the simulation.
        dt: float
            Time step.
        t_start: float, optional
            Start time of the simulation. Default is 0.
        I_ext_strengths: array-like, optional
            N_period x N x 3 array of external input strengths. N_period is the number of input periods. In each input period, the external input strength is an N x 3 matrix. If None, no external input.
        ts_ext_start: sequence of float, optional
            Start times of the external input periods. Ignored if I_ext_strengths is None.
        ts_ext_end: sequence of float, optional
            End times of the external input periods. Ignored if I_ext_strengths is None.
        lesion_areas: list, optional
            Names of brain areas to lesion. At every time point t > 0, columns 0-5 of each lesioned area's state are reset to zero after the update. If None, no lesion.

        Returns
        -------
        state_history: np.ndarray
            State history in this simulation (len(t_eval) x N x 8 array). This is a copy of what is stored in self.state_histories.

        Raises
        ------
        AssertionError
            If the external input specification is inconsistent.
        ValueError
            If a name in lesion_areas is not one of self.area_names.
        """
        # Generate time points
        t_eval = np.arange(t_start, t_end + dt, dt)

        # Validate all inputs before recording anything, so that a failed call
        # cannot leave time_histories and state_histories with different lengths.
        if I_ext_strengths is None:
            ts_ext_start, ts_ext_end = None, None
        else:
            I_ext_strengths = np.asarray(I_ext_strengths)
            assert I_ext_strengths.shape[0] == len(ts_ext_start) == len(ts_ext_end), "Number of external input periods should be consistent!"
            assert I_ext_strengths.shape[1:] == (self.N, 3), "In each external input period, I_ext_strength should be an N x 3 matrix!"
            for t_ext_start, t_ext_end in zip(ts_ext_start, ts_ext_end):
                assert t_ext_start < t_ext_end, "ts_ext_start must be less than ts_ext_end!"
                assert t_ext_start >= t_start and t_ext_end <= t_end, "External input time must be within the simulation time!"

        # Resolve lesioned area names to row indices once instead of at every step
        lesion_idx = []
        if lesion_areas:
            area_names = list(self.area_names)
            for lesion_area in lesion_areas:
                if lesion_area not in area_names:
                    raise ValueError(f"Unknown lesion area: {lesion_area!r}")
                lesion_idx.append(area_names.index(lesion_area))

        # Precompute time step ratios
        dt_over_taus = self.precompute_dt_over_taus(dt)

        # Euler-Maruyama integration
        Y = self.Y.copy()
        state_history = [Y.copy()]
        for t in t_eval[1:]:
            # Sum the strengths of every external input period active at time t
            I_ext = np.zeros((self.N, 3))
            if ts_ext_start is not None:
                for j, (t_ext_start, t_ext_end) in enumerate(zip(ts_ext_start, ts_ext_end)):
                    if t_ext_start <= t < t_ext_end:
                        I_ext += I_ext_strengths[j]

            # Update the state
            Y += self.calc_dY(Y, dt, dt_over_taus, I_ext)

            # Apply lesion: clamp the first six state variables of lesioned areas
            if lesion_idx and t > 0:
                Y[lesion_idx, :6] = 0

            state_history.append(Y.copy())

        state_history = np.array(state_history)

        # Record time and state together so the two history lists stay aligned
        self.time_histories.append(t_eval)
        self.state_histories.append(state_history)

        # Update the current state
        self.Y = Y

        # Return a separate array so callers cannot modify the stored history in place
        return state_history.copy()

    def get_state_histories(self, variables=('all',)):
        """
        Get the state histories of the network, optionally restricted to some variables.

        Parameters
        ----------
        variables: sequence of str or str
            Variables to return, any of 'r_A', 'r_B', 'r_C', 'S_A', 'S_B', 'S_C', 'x_A', 'x_B'.
            A single name may be passed as a plain string. If 'all' is included, the full
            histories are returned.

        Returns
        -------
        state_histories: list
            List of state matrices for all simulations (length N_sim). Each element is a N_t x N x N_var matrix.
            With 'all', the stored list itself is returned (not a copy).

        Raises
        ------
        KeyError
            If a variable name is not recognised.
        """
        # A bare string would otherwise be iterated character by character
        if isinstance(variables, str):
            variables = [variables]

        if 'all' in variables:
            return self.state_histories

        hash_map = {'r_A': 0, 'r_B': 1, 'r_C': 2, 'S_A': 3, 'S_B': 4, 'S_C': 5, 'x_A': 6, 'x_B': 7}
        try:
            indices = [hash_map[variable] for variable in variables]
        except KeyError as err:
            raise KeyError(f"Unknown variable {err.args[0]!r}; expected one of {list(hash_map)} or 'all'.") from None

        # Each state_history is (time_steps, N, 8); select the requested variables on the last axis
        return [state_history[:, :, indices] for state_history in self.state_histories]

    def merge_histories(self):
        """
        Merge the time histories and state histories of all simulations respectively.

        Each later simulation's time axis is shifted so that it continues from the end of the
        previous one. Its first sample is dropped, for both time and state, because it duplicates
        the previous simulation's last state.

        Returns
        -------
        merged_time: np.ndarray
            Merged time histories (length N_t_total). Always a new array.
        merged_state: np.ndarray
            Merged state histories (N_t_total x N x 8).

        Raises
        ------
        ValueError
            If no history is available, or the time and state history lists differ in length.
        """
        # Check if history is available
        if not self.time_histories or not self.state_histories:
            raise ValueError("No history to merge. Run a simulation first.")
        if len(self.time_histories) != len(self.state_histories):
            raise ValueError("time_histories and state_histories have different lengths.")

        # Collect the pieces and concatenate once: repeated concatenation in the loop is quadratic
        times = [self.time_histories[0]]
        states = [self.state_histories[0]]
        t_last = self.time_histories[0][-1]
        for time, state in zip(self.time_histories[1:], self.state_histories[1:]):
            shifted = time[1:] + t_last - time[0]
            times.append(shifted)
            states.append(state[1:])
            if shifted.size:  # a single-sample run adds no new time points
                t_last = shifted[-1]

        return np.concatenate(times), np.concatenate(states, axis=0)

    def clear_histories(self):
        """
        Discard every recorded simulation.

        state_histories, time_histories and monitor are each replaced by a fresh empty list.
        """
        for attribute in ('state_histories', 'time_histories', 'monitor'):
            setattr(self, attribute, [])

    def reset(self, Y0=None, random_seed=None):
        """
        Reset the network to a given initial state or the original initial state.

        Also clears all recorded histories and re-creates the random number generator.

        Parameters
        ----------
        Y0: np.ndarray, optional
            N x 8 matrix of initial conditions. Each row is a brain area, each column is a variable vector (r_A, r_B, r_C, S_A, S_B, S_C, x_A, x_B). If None, use the original initial state.
        random_seed: int, optional
            Random seed for reproducibility. If None, use the original random seed. 0 is a valid seed.
        """
        # Explicit None check: `Y0 or ...` raises for multi-element NumPy arrays
        if Y0 is not None:
            self.Y0 = Y0
        # Copy so the current state never aliases the stored initial condition
        self.Y = np.copy(self.Y0)
        self.clear_histories()
        # Explicit None check: a plain truthiness test would ignore seed 0
        if random_seed is not None:
            self.random_seed = random_seed
        self.rng = np.random.default_rng(self.random_seed)  # reset the random number generator

    def plot_all_areas(self, xlim, ylim, axes_flat=None, variables=('r_A', 'r_B'), colors=((.1, .6, .8), (.6, 0, .5)), title=None, save_path=None, downsample=10, legend=True, t_offset=4):
        """
        Plot the variables of all brain areas for the most recent simulation.

        Parameters
        ----------
        xlim: tuple
            Tuple of x-axis limits (start, end), in shifted time (see `t_offset`). When the axes
            already contain lines, the existing limits are widened to also cover `xlim`.
        ylim: tuple or None
            y-axis limits (bottom, top) for every panel except the first, which is fixed to (0, 65).
            If None, every panel keeps automatic limits derived from its data.
        axes_flat: sequence of matplotlib.axes.Axes, optional
            Flattened axes to plot on, one per brain area. If None, create a new 6-column figure.
            Caller-supplied axes beyond the number of areas are left untouched.
        variables: list or tuple of str
            Variables to plot. Must match variable names in state history.
        colors: list or tuple
            Colors for each variable. Must match the length of `variables`.
        title: str, optional
            Title of the plot.
        save_path: str, optional
            Path to save the plot. If None, do not save the plot.
        downsample: int, optional
            Downsampling factor for time series data to speed up plotting. Default is 10.
        legend: bool, optional
            Whether to show the legend in the first panel. Default is True.
        t_offset: float, optional
            Amount subtracted from the simulation time before plotting. Default is 4.

        Returns
        -------
        lines: list
            Line2D objects of the first panel, one per variable.
        """
        # Check variables and colors
        if not isinstance(variables, (list, tuple)) or not all(isinstance(v, str) for v in variables):
            raise ValueError("`variables` must be a list of strings representing variable names.")
        assert len(variables) == len(colors), "Length of variables and colors must be the same."

        # Data of the most recent simulation; the time axis is shifted by t_offset for display
        state_history = self.get_state_histories(list(variables))[-1]
        t_plot = self.time_histories[-1][::downsample] - t_offset

        # Prepare plotting on a 6-column grid
        rows = math.ceil(self.N / 6)
        if axes_flat is None:
            fig, axes = plt.subplots(rows, 6, figsize=(15, 2 * rows), dpi=100)
            fig.subplots_adjust(hspace=0.5, wspace=0.3)
            axes = axes.flatten()
            existing_plots = False
        else:
            axes = axes_flat
            # Work on the figure that owns the supplied axes, not on plt's "current" figure
            fig = axes[0].figure
            existing_plots = bool(axes[0].lines)  # overlaying on top of a previous plot

        lines = []
        for i, ax in enumerate(axes[:self.N]):  # one panel per brain area
            ymax = 0  # largest plotted value in this panel
            for var_idx, var in enumerate(variables):
                data = state_history[:, i, var_idx][::downsample]
                if i == 0:
                    # Only the first panel is labelled; its lines are returned for a shared legend
                    line = ax.plot(t_plot, data, label='$' + var + '$', color=colors[var_idx], linewidth=2)
                    lines.append(line[0])
                else:
                    ax.plot(t_plot, data, color=colors[var_idx], linewidth=2)
                ymax = max(ymax, data.max())
            if i == 0 and legend and lines:
                ax.legend(loc="upper right")

            ax.set_title(self.area_names[i], fontweight='bold')

            # Automatic limits; when overlaying, widen the current limits instead of replacing them
            auto_ylim = (-max(ymax * 0.2, 0.5), max(ymax * 1.3, 2))
            if existing_plots:
                x_lo, x_hi = ax.get_xlim()
                new_x_lo, new_x_hi = xlim
                ax.set_xlim(min(x_lo, new_x_lo), max(x_hi, new_x_hi))
                y_lo, y_hi = ax.get_ylim()
                ax.set_ylim(min(y_lo, auto_ylim[0]), max(y_hi, auto_ylim[1]))
            else:
                ax.set_xlim(*xlim)
                ax.set_ylim(*auto_ylim)

            if i % 6 == 0:  # First column: add ylabel
                ax.set_ylabel("Rate (sp/s)")
            if i >= (rows - 1) * 6:  # Last row: add xlabel
                ax.set_xlabel("Time (s)")

            # Fixed limits override the automatic ones; the first panel always uses (0, 65)
            if ylim is not None:
                if i < 1:
                    ax.set_ylim(0, 65)
                else:
                    ax.set_ylim(*ylim)

            ax.grid(linestyle=':', alpha=0.5)
            for spine in ['top', 'right']:
                ax.spines[spine].set_visible(False)
            ax.xaxis.set_ticks_position('bottom')
            ax.yaxis.set_ticks_position('left')

        # Remove unused grid cells (self.N not a multiple of 6), only from a figure created here
        if axes_flat is None:
            for ax in axes[self.N:]:
                fig.delaxes(ax)

        # Adjust layout and add optional title
        if title:
            fig.suptitle(title, fontweight='bold')
        fig.tight_layout()

        # Save the figure if a path is provided
        if save_path:
            fig.savefig(save_path, dpi=1200)

        if axes_flat is None:
            plt.show()

        return lines

    def plot_n_areas(self, xlim, ylim, target_area_names=('V1', 'MT', 'LIP', '24c', 'STPi', '9/46d'), axes_flat=None, variables=('r_A', 'r_B'), colors=((.1, .6, .8), (.6, 0, .5)), title=None, save_path=None, downsample=10, legend=True, t_offset=4):
        """
        Plot the variables of selected brain areas for the most recent simulation.

        Parameters
        ----------
        xlim: tuple
            Tuple of x-axis limits (start, end), in shifted time (see `t_offset`). When the axes
            already contain lines, the existing limits are widened to also cover `xlim`.
        ylim: tuple or None
            y-axis limits (bottom, top) for every panel except the first, which is fixed to (0, 65).
            If None, every panel keeps automatic limits derived from its data.
        target_area_names: sequence of str, optional
            Names of the brain areas to plot, one panel each. Uses 3 columns for up to 6 areas, else 6.
        axes_flat: sequence of matplotlib.axes.Axes, optional
            Flattened axes to plot on. If None, create a new figure. Caller-supplied axes beyond
            the number of target areas are left untouched.
        variables: list or tuple of str
            Variables to plot. Must match variable names in state history.
        colors: list or tuple
            Colors for each variable. Must match the length of `variables`.
        title: str, optional
            Title of the plot.
        save_path: str, optional
            Path to save the plot. If None, do not save the plot.
        downsample: int, optional
            Downsampling factor for time series data to speed up plotting. Default is 10.
        legend: bool, optional
            Whether to show the legend in the first panel. Default is True.
        t_offset: float, optional
            Amount subtracted from the simulation time before plotting. Default is 4.

        Returns
        -------
        lines: list
            Line2D objects of the first panel, one per variable.

        Raises
        ------
        ValueError
            If `variables` is malformed or a target area name is not in self.area_names.
        """
        # Check variables and colors
        if not isinstance(variables, (list, tuple)) or not all(isinstance(v, str) for v in variables):
            raise ValueError("`variables` must be a list of strings representing variable names.")
        assert len(variables) == len(colors), "Length of variables and colors must be the same."

        # Validate area names before creating any figure
        area_names = list(self.area_names)
        missing = [name for name in target_area_names if name not in area_names]
        if missing:
            raise ValueError(f"Unknown brain area name(s): {missing}")

        # Data of the most recent simulation; the time axis is shifted by t_offset for display
        N_area = len(target_area_names)
        state_history = self.get_state_histories(list(variables))[-1]
        t_plot = self.time_histories[-1][::downsample] - t_offset

        # Grid layout: 6 columns for many areas, 3 columns otherwise
        if N_area > 6:
            columns = 6
            rows = math.ceil(N_area / 6)
            figsize = (15, 2 * rows)
        else:
            columns = 3
            rows = math.ceil(N_area / 3)
            figsize = (8, 2.5 * rows)

        if axes_flat is None:
            fig, axes = plt.subplots(rows, columns, figsize=figsize, dpi=100)
            fig.subplots_adjust(hspace=0.5, wspace=0.3)
            axes = axes.flatten()
            existing_plots = False
        else:
            axes = axes_flat
            # Work on the figure that owns the supplied axes, not on plt's "current" figure
            fig = axes[0].figure
            existing_plots = bool(axes[0].lines)  # overlaying on top of a previous plot

        lines = []
        for i, ax in enumerate(axes[:N_area]):  # one panel per target area
            area_idx = area_names.index(target_area_names[i])
            ymax = 0  # largest plotted value in this panel
            for var_idx, var in enumerate(variables):
                data = state_history[:, area_idx, var_idx][::downsample]
                if i == 0:
                    # Only the first panel is labelled; its lines are returned for a shared legend
                    line = ax.plot(t_plot, data, label='$' + var + '$', color=colors[var_idx], linewidth=2)
                    lines.append(line[0])
                else:
                    ax.plot(t_plot, data, color=colors[var_idx], linewidth=2)
                ymax = max(ymax, data.max())
            if i == 0 and legend and lines:
                ax.legend(loc="upper right", fontsize=10)

            ax.set_title(target_area_names[i], fontsize=13, fontweight='bold')

            # Automatic limits; when overlaying, widen the current limits instead of replacing them
            auto_ylim = (-max(ymax * 0.2, 0.5), max(ymax * 1.3, 2))
            if existing_plots:
                x_lo, x_hi = ax.get_xlim()
                new_x_lo, new_x_hi = xlim
                ax.set_xlim(min(x_lo, new_x_lo), max(x_hi, new_x_hi))
                y_lo, y_hi = ax.get_ylim()
                ax.set_ylim(min(y_lo, auto_ylim[0]), max(y_hi, auto_ylim[1]))
            else:
                ax.set_xlim(*xlim)
                ax.set_ylim(*auto_ylim)

            # Fixed limits override the automatic ones; the first panel always uses (0, 65)
            if ylim is not None:
                if i < 1:
                    ax.set_ylim(0, 65)
                else:
                    ax.set_ylim(*ylim)

            if i % columns == 0:  # First column: add ylabel
                ax.set_ylabel("Rate (sp/s)", fontsize=12)
            if i >= (rows - 1) * columns:  # Last row: add xlabel
                ax.set_xlabel("Time (s)", fontsize=12)

        # Remove unused grid cells (N_area, not self.N, panels are used), only from a figure created here
        if axes_flat is None:
            for ax in axes[N_area:]:
                fig.delaxes(ax)

        # Adjust layout and add optional title
        if title:
            fig.suptitle(title, fontsize=18, fontweight='bold')
        fig.tight_layout()

        # Save the figure if a path is provided
        if save_path:
            fig.savefig(save_path, dpi=300)

        if axes_flat is None:
            plt.show()

        return lines

    def avg_firing_rates_last_two_seconds(self, variables=('r_A',)):
        """
        Average the requested variables over the last two seconds of the most recent simulation.

        Parameters
        ----------
        variables: list or tuple of str
            Variables to average. Must match variable names in state history. Default is ['r_A'].

        Returns
        -------
        np.ndarray or tuple of np.ndarray
            For a single variable, one array with one value per brain area.
            For several variables, a tuple with one such array per variable, in the requested order.

        Raises
        ------
        ValueError
            If `variables` is not a non-empty list of strings, or no samples lie in the window.
        """
        if isinstance(variables, tuple):
            variables = list(variables)
        if not isinstance(variables, list) or not variables or not all(isinstance(v, str) for v in variables):
            raise ValueError("`variables` must be a list of strings representing variable names.")

        # Extract time and state histories of the last simulation
        t = self.time_histories[-1]
        state_history = self.get_state_histories(variables)[-1]

        # Identify the indices for the last two seconds
        end_time = t[-1]
        start_time = end_time - 2
        indices = np.where((t >= start_time) & (t <= end_time))[0]

        if len(indices) == 0:
            raise ValueError("No data points found in the last two seconds of the simulation.")

        # Time-average each requested variable separately (axis 1 of the slice is the brain area)
        averages = tuple(np.mean(state_history[indices, :, k], axis=0) for k in range(len(variables)))

        return averages[0] if len(averages) == 1 else averages

## 1. Building a cortical circuit

### 1.1. Modelling single populations of neurons

#### Excitatory population
We will start by simulating a single population of excitatory neurons. The dynamics of a population of neurons can be described by a differential equation such as the following:
$$ \tau_R \frac{dr(t)}{dt} = -r(t) + Φ_E(I)  \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (1)$$
Where $r(t)$ is the firing rate of the population at time t, $τ_R$ is the time constant of the neural population and $I$ is the total synaptic input to the population, which can consist of background input, external stimuli, input from other populations and/or noise.


The **transfer function** $Φ(I)$ determines how a neural population responds to its input. For an excitatory population we use the following transfer function (Abbott and Chance, 2005):

$$ Φ_E(I) = \frac{aI - b}{1-e^{-d(aI-b)}} \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (2)$$

$a$, $b$ and $d$ are constant parameters that shape the transfer function.

In [None]:
def phi_E(I, model_params=None):
    """
    Excitatory transfer function (Abbott and Chance, 2005).

    Computes (a*I - b) / (1 - exp(-d*(a*I - b))). Where a*I - b == 0 the plain formula
    evaluates to 0/0 = nan, so its limit 1/d is returned there instead.

    Parameters
    ----------
    I : float or np.ndarray
        Total synaptic input.
    model_params : object, optional
        Object with attributes A, B and D. Defaults to the notebook-level `params`.

    Returns
    -------
    float or np.ndarray
        Firing rate, with the same shape as I.
    """
    p = params if model_params is None else model_params
    x = np.asarray(p.A * I - p.B, dtype=float)
    # -expm1(-d*x) equals 1 - exp(-d*x) but stays accurate when d*x is tiny
    denom = -np.expm1(-p.D * x)
    at_limit = denom == 0
    rate = np.where(at_limit, 1.0 / p.D, x / np.where(at_limit, 1.0, denom))
    return rate[()]  # 0-d arrays become NumPy scalars

def simulate_excitatory_population(I_ext=None, I_0=0.33, TAU_R=0.1, dt=0.5e-3, T=20):
    """
    Forward-Euler simulation of one isolated excitatory population:
    TAU_R dr/dt = -r + phi_E(I_0 + I_ext(t)).

    Parameters
    ----------
    I_ext : array-like, optional
        External input for each time step (at least int(T / dt) samples). None means no external input.
    I_0 : float
        Constant background input.
    TAU_R : float
        Population time constant (s).
    dt : float
        Integration step (s).
    T : float
        Simulated duration (s).

    Returns
    -------
    np.ndarray
        Firing rate at each of the int(T / dt) steps, starting from r = 0.

    Raises
    ------
    ValueError
        If I_ext has fewer than int(T / dt) samples.
    """
    time_steps = int(T / dt)
    I_ext = np.zeros(time_steps) if I_ext is None else np.asarray(I_ext)
    if len(I_ext) < time_steps:
        raise ValueError(f"I_ext must have at least {time_steps} samples, got {len(I_ext)}.")
    r = np.zeros(time_steps)
    for t in range(1, time_steps):
        I = I_0 + I_ext[t]
        dr = (-r[t-1] + phi_E(I)) * (dt / TAU_R)
        r[t] = r[t-1] + dr
    return r

#### Exercise 1.1

In [None]:
# @markdown Now, we simulate the dynamics of an excitatory population for 10 seconds. This population receives a time-varying external input (this could, for example, be a visual stimulus). The total input to the population is shown in the top plot.

# @markdown Play with the time constant parameter, what do you observe?

# @markdown Execute this cell to enable the widget!
params = WorkingMemoryParameters()


def exc_exploration(tau_R=0.02):
    """Simulate one excitatory population under a fixed stimulus protocol and plot input and rate.

    tau_R is the population time constant (s), controlled by the slider below.
    """
    background = 0.8
    dt = 0.5e-3
    T = 10
    n_steps = int(T / dt)

    # Stimulus protocol as (start s, end s, value); later windows overwrite earlier ones
    stimulus = np.zeros(n_steps)
    for start, stop, value in ((2, 4, 0.2), (2.5, 3.5, 0.4), (6.5, 7.5, -0.5), (7.5, 8, -0.8)):
        stimulus[int(start / dt):int(stop / dt)] = value

    rate = simulate_excitatory_population(I_ext=stimulus, I_0=background, TAU_R=tau_R, dt=dt, T=T)
    time_axis = np.linspace(0, T, n_steps)

    # Top panel: total input current
    plt.figure(figsize=(11, 2))
    plt.plot(time_axis, np.ones(n_steps) * background + stimulus, label="Total Input", color="green")
    plt.ylim((-0.3, 1.4))
    plt.xlabel("Time (s)")
    plt.ylabel("Input current (nA)")
    plt.legend()
    plt.show()

    # Bottom panel: firing-rate response
    plt.figure(figsize=(11, 2))
    plt.plot(time_axis, rate, label="r (Excitatory)")
    plt.xlabel("Time (s)")
    plt.ylabel("Firing rate (Hz)")
    plt.legend()
    plt.ylim((0, 155))
    plt.show()


tau_label = widgets.Label(value='τ_R (s):')
tau_slider = widgets.FloatSlider(value=0.02, min=0.01, max=1.0, step=0.01)
tau_widget = widgets.HBox([tau_label, tau_slider])
display(tau_widget)
widgets.interactive_output(exc_exploration, {'tau_R': tau_slider})

#### Inhibitory population
Not all the neurons in the brain are excitatory. To build a realistic model of the brain, we need to model inhibitory populations of neurons as well as excitatory ones. The dynamics of inhibitory populations follow the same ODE as those of excitatory ones, except that we use a different transfer function:
$$ \tau_R \frac{dr(t)}{dt} = -r(t) + Φ_I(I)  \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (3)$$
 $$ Φ_I(I) = [\frac{1}{g_I}(c_1I-c_0)+r_0]_+ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (4)$$
 where $g_I$, $c_0$, $c_1$ and $r_0$ are constant parameters, and the notation $[x]_+ = \max(x, 0)$ denotes rectification at zero.

In [None]:
def phi_I(I, model_params=None):
    """
    Threshold-linear inhibitory transfer function: [r_0 + (c_1*I - c_0) / g_I]_+.

    Parameters
    ----------
    I : float or np.ndarray
        Total synaptic input.
    model_params : object, optional
        Object with attributes R_0, C_1, C_0 and G_I. Defaults to the notebook-level `params`.

    Returns
    -------
    float or np.ndarray
        Non-negative firing rate, with the same shape as I.
    """
    p = params if model_params is None else model_params
    return np.maximum(p.R_0 + (p.C_1 * I - p.C_0) / p.G_I, 0)

def simulate_inhibitory_population(I_ext=None, I_0=0.33, TAU_R=0.1, dt=0.5e-3, T=20):
    """
    Forward-Euler simulation of one isolated inhibitory population:
    TAU_R dr/dt = -r + phi_I(I_0 + I_ext(t)).

    Parameters
    ----------
    I_ext : array-like, optional
        External input for each time step (at least int(T / dt) samples). None means no external input.
    I_0 : float
        Constant background input.
    TAU_R : float
        Population time constant (s).
    dt : float
        Integration step (s).
    T : float
        Simulated duration (s).

    Returns
    -------
    np.ndarray
        Firing rate at each of the int(T / dt) steps, starting from r = 0.

    Raises
    ------
    ValueError
        If I_ext has fewer than int(T / dt) samples.
    """
    time_steps = int(T / dt)
    I_ext = np.zeros(time_steps) if I_ext is None else np.asarray(I_ext)
    if len(I_ext) < time_steps:
        raise ValueError(f"I_ext must have at least {time_steps} samples, got {len(I_ext)}.")
    r = np.zeros(time_steps)
    for t in range(1, time_steps):
        I = I_0 + I_ext[t]
        dr = (-r[t-1] + phi_I(I)) * (dt / TAU_R)
        r[t] = r[t-1] + dr
    return r

#### Exercise 1.2

In [None]:
# @markdown In this exercise, we focus on the differences between excitatory and inhibitory transfer functions.

# @markdown Now, we simulate the dynamics of one excitatory and one inhibitory population for a 10-second trial. Both populations are isolated (they are not connected to each other) and they receive the same total input, as shown in the top panel. The total input consists of a constant background input and an external stimulus that is presented between t=4 and t=6 seconds.

# @markdown Change the background input and the stimulus intensity. Observe how the firing rate of each population reacts to the increased input.
# @markdown 1. Slowly increase the background input from 0. Which population starts to fire first?

# @markdown 2. Which population has a sharper increase in the firing rate? Which one is more active?

# @markdown 3. What happens after the stimulus is removed?

# @markdown Execute this cell to enable the widget!


params = WorkingMemoryParameters()


def inh_exploration(I_background=0.3, I_stimulus=0.2):
    """Compare isolated excitatory and inhibitory populations that receive the same input.

    I_background is the constant input; I_stimulus is added between t=4 s and t=6 s.
    """
    tau_R = 0.05
    dt = 0.5e-3
    T = 10
    n_steps = int(T / dt)
    stim_window = (4, 6)  # stimulus onset and offset (s)

    stimulus = np.zeros(n_steps)
    stimulus[int(stim_window[0] / dt):int(stim_window[1] / dt)] = I_stimulus

    # Both populations get exactly the same drive and time constant
    shared = dict(I_ext=stimulus, I_0=I_background, TAU_R=tau_R, dt=dt, T=T)
    r_exc = simulate_excitatory_population(**shared)
    r_inh = simulate_inhibitory_population(**shared)

    time_axis = np.linspace(0, T, n_steps)

    # Top panel: total input, with the stimulus window marked
    plt.figure(figsize=(11, 2))
    plt.plot(time_axis, np.ones(n_steps) * I_background + stimulus, label="Total Input", color="green")
    plt.ylim((-0.6, 1.6))
    plt.xlabel("Time (s)")
    plt.ylabel("Input current (nA)")
    for edge in stim_window:
        plt.vlines(x=edge, ymin=-0.55, ymax=1.55, linestyles='dashed', color="black")
    plt.legend()
    plt.show()

    # Bottom panel: firing rates of both populations
    plt.figure(figsize=(11, 3))
    plt.plot(time_axis, r_exc, label="r (Excitatory)")
    plt.plot(time_axis, r_inh, label="r (Inhibitory)", color="grey")
    for edge in stim_window:
        plt.vlines(x=edge, ymin=0.05, ymax=149, linestyles='dashed', color="black")
    plt.xlabel("Time (s)")
    plt.ylabel("Firing rate (Hz)")
    plt.legend()
    plt.ylim((0, 155))
    plt.show()


_ = widgets.interact(inh_exploration, I_background=(0.0, 1, 0.01), I_stimulus=(-0.6, 0.6, 0.1))

### 1.2. Connecting populations with synaptic dynamics

After learning to model individual populations of neurons, we are now ready to start connecting them together to form simple circuits.

When populations are connected, we must consider how the presynaptic population’s firing rate influences the postsynaptic population through synaptic dynamics. Synaptic transmission is not instantaneous — it involves the release of neurotransmitters, their binding to receptors, and the opening of ion channels. This results in a change in conductance in the postsynaptic neuron, which evolves over time. (Although the conductance change occurs on the postsynaptic side, in population models we associate the gating variable $S$ with the presynaptic population, since it is directly driven by that population's firing rate $r(t)$.)

To capture these interactions, we introduce a new dynamic variable to represent synaptic conductance:

* Excitatory populations activate NMDA receptors on the target population, producing an NMDA conductance variable $S^{N}$

* Inhibitory populations activate GABA receptors, giving rise to a GABA conductance variable $S^{G}$.

These conductances evolve over time according to their own differential equations, and they shape the synaptic current received by the postsynaptic population.

If population A (excitatory) projects to population B, then:
* $S_A^N$ represents the NMDA conductance generated by A's firing.
* The input current to B includes a term like: $$I_B \sim J_{BA} S_A^N$$
Where $J_{BA}$ is the synaptic weight from A to B.

#### NMDA dynamics

The following equation models the dynamics of the NMDA conductance variable activated by excitatory cells:

$$ \frac{dS^N(t)}{dt} = -\frac{S^N}{τ_N} + γ(1-S^N)r  \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ $$
NMDA receptor-mediated currents are slow and long-lasting. This is why they are important for temporal integration and working memory models.


#### GABA dynamics

The following equation describes the dynamics of the GABA conductance variable:
$$ \frac{dS^G(t)}{dt} = -\frac{S^G}{τ_G} + γ_Ir  \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ $$
GABA is fast-acting and responsible for rapid inhibition.


#### Exercise 1.3

In [None]:
# @markdown Execute this cell to enable the widget!

# @markdown 1) Vary the time constants for NMDA and GABA within their valid range. Look at how each of the conductance variables change as a function of the firing rate of the population they belong to.

# @markdown 2) What are the differences between NMDA and GABA conductance dynamics?
params = WorkingMemoryParameters()


def NMDA_conductance(params, I_ext=None, dt=0.5e-3, T=20, TAU_N=0.06):
    """
    Simulate an excitatory population together with the NMDA gating variable it drives.

    Forward Euler on
        TAU_R dr/dt = -r + phi_E(I_0A + I_ext(t))
        dS/dt       = -S / TAU_N + GAMMA (1 - S) r
    with TAU_R, I_0A and GAMMA read from `params`. The time constant comes from the TAU_N
    argument, not from params.TAU_N.

    Parameters
    ----------
    params : object
        Parameter set providing I_0A, TAU_R and GAMMA.
    I_ext : array-like, optional
        External input for each time step (at least int(T / dt) samples). None means no external input.
    dt : float
        Integration step (s).
    T : float
        Simulated duration (s).
    TAU_N : float
        NMDA time constant (s).

    Returns
    -------
    (np.ndarray, np.ndarray)
        Firing rate r_A and NMDA gating variable S_A at each of the int(T / dt) steps.

    Raises
    ------
    ValueError
        If I_ext has fewer than int(T / dt) samples.
    """
    time_steps = int(T / dt)
    I_ext = np.zeros(time_steps) if I_ext is None else np.asarray(I_ext)
    if len(I_ext) < time_steps:
        raise ValueError(f"I_ext must have at least {time_steps} samples, got {len(I_ext)}.")
    S_A = np.zeros(time_steps)
    r_A = np.zeros(time_steps)
    for t in range(1, time_steps):
        I_A = params.I_0A + I_ext[t]
        dr_A = (-r_A[t-1] + phi_E(I_A)) * (dt / params.TAU_R)
        r_A[t] = r_A[t-1] + dr_A
        # The gating update uses the just-updated rate r_A[t], not r_A[t-1]
        S_A[t] = S_A[t-1] + dt * (-S_A[t-1] / TAU_N +
                                  params.GAMMA * (1 - S_A[t-1]) * r_A[t])
    return r_A, S_A
def GABA_conductance(params, I_ext=None, dt=0.5e-3, T=20, TAU_G=0.005):
    """
    Simulate an inhibitory population together with the GABA gating variable it drives.

    Forward Euler on
        TAU_R dr/dt = -r + phi_I(I_0A + I_ext(t))
        dS/dt       = -S / TAU_G + GAMMA_I r
    with TAU_R, I_0A and GAMMA_I read from `params`.

    NOTE(review): the background input is params.I_0A, the same as in NMDA_conductance,
    presumably so that both toy populations see identical input. Confirm this is intended
    rather than I_0C.

    Parameters
    ----------
    params : object
        Parameter set providing I_0A, TAU_R and GAMMA_I.
    I_ext : array-like, optional
        External input for each time step (at least int(T / dt) samples). None means no external input.
    dt : float
        Integration step (s).
    T : float
        Simulated duration (s).
    TAU_G : float
        GABA time constant (s).

    Returns
    -------
    (np.ndarray, np.ndarray)
        Firing rate r and GABA gating variable S at each of the int(T / dt) steps.

    Raises
    ------
    ValueError
        If I_ext has fewer than int(T / dt) samples.
    """
    time_steps = int(T / dt)
    I_ext = np.zeros(time_steps) if I_ext is None else np.asarray(I_ext)
    if len(I_ext) < time_steps:
        raise ValueError(f"I_ext must have at least {time_steps} samples, got {len(I_ext)}.")
    S = np.zeros(time_steps)
    r = np.zeros(time_steps)
    for t in range(1, time_steps):
        I = params.I_0A + I_ext[t]
        dr = (-r[t-1] + phi_I(I)) * (dt / params.TAU_R)
        r[t] = r[t-1] + dr
        # The gating update uses the just-updated rate r[t], not r[t-1]
        S[t] = S[t-1] + dt * (-S[t-1] / TAU_G +
                              params.GAMMA_I * r[t])
    return r, S
def F_exploration(TAU_N=0.06, TAU_G=0.005):
    """Drive the NMDA and GABA toy models with a staircase of input pulses and plot the results.

    TAU_N and TAU_G are the NMDA and GABA time constants (s), controlled by the sliders.
    """
    dt = 0.5e-3
    T = 8
    n_steps = int(T / dt)

    # Four input pulses of increasing amplitude: (start s, end s, amplitude)
    drive = np.zeros(n_steps)
    for onset, offset, amplitude in ((1, 2, 0.1), (2.5, 3, 0.2), (4, 5, 0.5), (6, 7, 1)):
        drive[int(onset / dt):int(offset / dt)] = amplitude

    r_exc, s_nmda = NMDA_conductance(params, I_ext=drive, dt=dt, T=T, TAU_N=TAU_N)
    r_inh, s_gaba = GABA_conductance(params, I_ext=drive, dt=dt, T=T, TAU_G=TAU_G)

    time_axis = np.linspace(0, T, n_steps)

    def show_panel(*curves):
        # One short figure per panel; each curve is (values, keyword arguments for plt.plot)
        plt.figure(figsize=(11, 1.5))
        for values, style in curves:
            plt.plot(time_axis, values, **style)
        plt.legend()
        plt.show()

    show_panel((drive, {'label': "Input", 'color': "green"}))
    show_panel((r_exc, {'label': "r (Excitatory)"}),
               (r_inh, {'label': "r (Inhibitory)", 'color': "grey"}))
    show_panel((s_nmda, {'label': "s (NMDA)"}),
               (s_gaba, {'label': "s (GABA)", 'color': "grey"}))


_ = widgets.interact(F_exploration, TAU_N=(0.05, 0.25, 0.01), TAU_G=(0.005, 0.020, 0.001))

#### Optional: AMPA dynamics
Once you learn to model conductances, you can always extend your model to include more neurotransmitter dynamics. You can for example introduce $S^{AMPA}$ for excitatory populations to build a model that captures both the slow dynamics of NMDA as well as the fast dynamics of AMPA.

#### Bonus Exercise
AMPA has similar dynamics to GABA, but its effect on the postsynaptic neuron is excitatory. Write an equation describing the dynamics of $S^{AMPA}$.

### 1.3. Local circuit model
Now that we’ve built an understanding of how neural populations behave and how they can be connected via synaptic conductances, we can take the next step: constructing a simple circuit that captures key dynamics of a cortical brain area.

We consider a minimal yet powerful architecture composed of two excitatory populations (A and B) and one inhibitory population (C). The excitatory populations are selective — each is tuned to a different input or stimulus feature — while the inhibitory population provides non-selective inhibition to both excitatory groups.

This setup reflects a common organizational motif in the cortex, where pyramidal neurons form local subnetworks that are selectively responsive to particular inputs, and interneurons mediate competitive interactions between them. While real cortical networks contain many more such subpopulations, this reduced model retains the essential ingredients needed to study higher cognitive functions.

The total synaptic input to any of the populations can be calculated using these equations:
$$
I_A = J_S S_A +J_C S_B + J_{EI} S_C + I_{ext_A} + I_{0A}
$$
$$
I_B = J_C S_A +J_S S_B + J_{EI} S_C + I_{ext_B} + I_{0B}
$$
$$
I_C = J_{IE} S_A + J_{IE} S_B + J_{II} S_C + I_{ext_C} + I_{0C}
$$



where the synaptic conductances are calculated as follows:
$$ \frac{dS_A(t)}{dt} = -\frac{S_A}{τ_N} + γ(1-S_A)r_A  \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ $$
$$ \frac{dS_B(t)}{dt} = -\frac{S_B}{τ_N} + γ(1-S_B)r_B  \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ $$
$$ \frac{dS_C(t)}{dt} = -\frac{S_C}{τ_G} + γ_Ir_c  \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ $$
The firing rate of each population is then obtained by solving:
$$ \tau_R \frac{dr_A(t)}{dt} = -r_A(t) + Φ_E(I_A)  \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ $$
$$ \tau_R \frac{dr_B(t)}{dt} = -r_B(t) + Φ_E(I_B)  \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ $$
$$ \tau_R \frac{dr_C(t)}{dt} = -r_C(t) + Φ_I(I_C)  \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ $$


In [None]:
#@title Implementation of the local circuit
#@markdown Please run this cell before continuing to the exercises.
params = WorkingMemoryParameters()
def _external_drive(I_ext, time_steps, name):
    """Return an external-input trace usable for `time_steps` integration steps.

    `None` means "no external input" and is turned into an all-zero trace.
    Anything else is converted to a 1-D float array that must contain at least
    `time_steps` samples, so a too-short input fails loudly instead of raising
    an IndexError half-way through the simulation.
    """
    if I_ext is None:
        return np.zeros(time_steps)
    drive = np.asarray(I_ext, dtype=float)
    if drive.ndim != 1 or drive.shape[0] < time_steps:
        raise ValueError(
            f"{name} must be a 1-D array with at least {time_steps} samples "
            f"(got shape {drive.shape})"
        )
    return drive


def simulate_neural_mass_dynamics(params, I_ext_A=None, I_ext_B=None, I_ext_C=None, dt=0.5e-3, T=20):
    """Simulate the local circuit (excitatory A and B, inhibitory C) with Euler steps.

    Args:
      params: parameter object exposing J_S, J_C, J_EI, J_IE, J_II, I_0A, I_0B,
        I_0C, TAU_R, TAU_N, TAU_G, GAMMA and GAMMA_I (e.g. WorkingMemoryParameters).
      I_ext_A, I_ext_B, I_ext_C: external input to each population, one value per
        integration step (at least int(T / dt) samples). None means no input.
      dt: integration step (s).
      T: total simulated time (s).

    Returns:
      (r_A, r_B, r_C): firing-rate traces with int(T / dt) samples each.
      All rates and conductances start at 0.

    Raises:
      ValueError: if an external input is not 1-D or is shorter than int(T / dt).
    """
    time_steps = int(T / dt)
    drive_A = _external_drive(I_ext_A, time_steps, "I_ext_A")
    drive_B = _external_drive(I_ext_B, time_steps, "I_ext_B")
    drive_C = _external_drive(I_ext_C, time_steps, "I_ext_C")

    # Synaptic conductances and firing rates; all start at rest (0).
    S_A = np.zeros(time_steps)
    S_B = np.zeros(time_steps)
    S_C = np.zeros(time_steps)
    r_A = np.zeros(time_steps)
    r_B = np.zeros(time_steps)
    r_C = np.zeros(time_steps)

    rate_step = dt / params.TAU_R  # loop-invariant factor of the rate update

    for t in range(1, time_steps):
        # Total synaptic input to each population, from the previous conductances
        I_A = (params.J_S * S_A[t-1] + params.J_C * S_B[t-1] +
               params.J_EI * S_C[t-1] + params.I_0A + drive_A[t])
        I_B = (params.J_C * S_A[t-1] + params.J_S * S_B[t-1] +
               params.J_EI * S_C[t-1] + params.I_0B) + drive_B[t]
        I_C = (params.J_IE * S_A[t-1] + params.J_IE * S_B[t-1] +
               params.J_II * S_C[t-1] + params.I_0C) + drive_C[t]

        # Firing rates relax towards the transfer-function output
        r_A[t] = r_A[t-1] + (-r_A[t-1] + phi_E(I_A)) * rate_step
        r_B[t] = r_B[t-1] + (-r_B[t-1] + phi_E(I_B)) * rate_step
        r_C[t] = r_C[t-1] + (-r_C[t-1] + phi_I(I_C)) * rate_step

        # Conductances are driven by the rates just computed for step t
        S_A[t] = S_A[t-1] + dt * (-S_A[t-1] / params.TAU_N +
                                  params.GAMMA * (1 - S_A[t-1]) * r_A[t])
        S_B[t] = S_B[t-1] + dt * (-S_B[t-1] / params.TAU_N +
                                  params.GAMMA * (1 - S_B[t-1]) * r_B[t])
        S_C[t] = S_C[t-1] + dt * (-S_C[t-1] / params.TAU_G +
                                  params.GAMMA_I * r_C[t])
    return r_A, r_B, r_C

#### Exercise 1.4

In this exercise, we simulate a simple 10-second trial. Between t=4 and t=6, a stimulus is presented that matches the selectivity of population A. To model this, we apply an external input exclusively to population A during this time window.


*   Use the slider to apply a brief external input to the excitatory population A.
*   Observe how **both populations** respond **during** and **after** the stimulus.
* Gradually increase $J_S$. At what point do you notice a qualitative change in the network? At what exact value of $J_S$ does this change occur?




Note: When adjusting $J_S$, the inhibitory coupling $J_{IE}$ is updated accordingly to preserve the balance between excitation and inhibition in the network.

In [None]:
# @markdown Execute this cell to enable the widgets!



def F_exploration(stimulus_A = 0.2,J_s = 0.32):
  """Simulate a 10 s trial with a 4-6 s stimulus to population A and plot all rates.

  Args:
    stimulus_A (float): amplitude of the external input to A during the stimulus.
    J_s (float): self-excitation J_S; J_IE is re-derived to keep E/I balance.
  """
  params = WorkingMemoryParameters()
  dt = 0.5e-3
  T = 10
  time_steps = int(T / dt)
  stim_start, stim_end = 4, 6  # stimulus window (s)

  I_ext_A = np.zeros(time_steps)
  I_ext_A[int(stim_start/dt):int(stim_end/dt)] = stimulus_A

  params.update(J_S=J_s)
  # Update the inhibitory coupling so excitation and inhibition stay balanced
  params.update(J_IE=(params.J_0 - params.J_S - params.J_C) / (2 * params.J_EI * params.ZETA))

  r_A, r_B, r_C = simulate_neural_mass_dynamics(params, I_ext_A=I_ext_A, I_ext_B=np.zeros(time_steps), I_ext_C=np.zeros(time_steps), dt=dt, T=T)

  # Sample k of the simulation corresponds to time k*dt
  # (np.linspace(0, T, time_steps) would space samples by T/(time_steps-1)).
  time = np.arange(time_steps) * dt

  plt.figure(figsize=(15, 8))
  plt.plot(time, r_A, linewidth=2.0, label="r_A (Excitatory A)")
  plt.plot(time, r_B, linewidth=2.0, label="r_B (Excitatory B)", color="purple")
  plt.plot(time, r_C, linewidth=1.2, label="r_C (Inhibitory C)", color="grey")
  plt.xlabel("Time (s)")
  plt.ylabel("Firing rate")
  plt.legend()
  plt.show()

_ = widgets.interact(F_exploration, stimulus_A=(0.0, 0.2, 0.001), J_s = (0.1, 1, 0.01))

#### Attractor dynamics
As you may have noticed, this simple circuit creates competition between the two selective excitatory populations through shared inhibition. When one population becomes active, it suppresses the other via the inhibitory population. This mutual inhibition can give rise to winner-take-all dynamics, where only one population remains active at a time.

Such dynamics are a hallmark of attractor networks—systems that tend to settle into one of multiple stable activity patterns (or attractors), depending on initial conditions or inputs. These stable states can represent decisions, percepts, or memories.

#### Exercise 1.5
In the simulation below, a stimulus is presented at t = 3s that selectively excites population A. Later, at t = 7s, a distractor is presented to population B. You can adjust the strength of both the stimulus and the distractor to see how the network responds.

As you explore, consider:

* What happens if the initial stimulus is strong, but the distractor is weak?

* At what point does the distractor become strong enough to disrupt the ongoing activity?

* How does the network behave after the inputs are removed—does it return to baseline, or remain in an active state?

* Does the firing rate in the active state depend on the intensity of the stimulus or the distractor?

These questions will help you uncover how attractor dynamics enable the network to exhibit decision-making and working memory-like behavior.

In [None]:
# @markdown Execute this cell to enable the widget!
params = WorkingMemoryParameters()
def F_exploration(I_stimulus = 0.2,I_distractor = 0.1):
  """Simulate a memory trial with a stimulus to A (3-4 s) and a distractor to B (7-8 s).

  Note: this updates the module-level `params` in place (J_S and J_IE).

  Args:
    I_stimulus (float): amplitude of the input to population A.
    I_distractor (float): amplitude of the input to population B.
  """
  params.update(J_S=0.47)
  # Update the inhibitory coupling so excitation and inhibition stay balanced
  params.update(J_IE=(params.J_0 - params.J_S - params.J_C) / (2 * params.J_EI * params.ZETA))

  dt = 0.5e-3
  T = 10
  time_steps = int(T / dt)
  stim_start, stim_end = 3, 4                # stimulus to population A (s)
  distractor_start, distractor_end = 7, 8    # distractor to population B (s)

  _I_ext_A = np.zeros(time_steps)
  _I_ext_A[int(stim_start/dt):int(stim_end/dt)] = I_stimulus
  _I_ext_B = np.zeros(time_steps)
  _I_ext_B[int(distractor_start/dt):int(distractor_end/dt)] = I_distractor
  _I_ext_C = np.zeros(time_steps)

  r_A, r_B, r_C = simulate_neural_mass_dynamics(params, I_ext_A=_I_ext_A, I_ext_B=_I_ext_B, I_ext_C=_I_ext_C, dt=dt, T=T)

  # Sample k of the simulation corresponds to time k*dt
  time = np.arange(time_steps) * dt

  plt.figure(figsize=(15, 8))
  plt.plot(time, r_A, label="r_A (Excitatory A)")
  plt.plot(time, r_B, label="r_B (Excitatory B)", color="purple")
  # The inhibitory trace is hidden to focus on the two selective populations:
  #plt.plot(time, r_C, label="r_C (Inhibitory C)", color="grey")
  plt.xlabel("Time (s)")
  plt.ylabel("Firing rate")
  plt.legend()
  plt.show()

_ = widgets.interact(F_exploration, I_stimulus=(0.0, 1, 0.01), I_distractor=(0.0, 0.2, 0.005))


## 2. Structural Connectivity

### Background: the Connectome

Most of the connections made by cortical neurons are with their nearest neighbors (within millimeters of the soma). However, pyramidal neurons can project to distant areas of the cortex through myelinated axons, which bundle together in white-matter fibers (https://nyaspubs.onlinelibrary.wiley.com/doi/full/10.1111/nyas.12271). The network of white-matter tracts is what network neuroscientists call **the connectome**. The connectome can be used to connect neural-masses, such as the ones developed in the first part of the tutorial, and is, therefore, one of the key ingredients in building **large-scale brain models**.

In humans, white-matter tracts are normally measured noninvasively with diffusion-weighted magnetic resonance imaging, which estimates white-matter tracts from the direction of water diffusion. However, this estimation is susceptible to measurement noise, under-estimates white-matter tracts at a greater depth within the brain and does not provide information about the directionality of axons within white-matter tracts.

Conversely, in non-human primates, invasive methods can be used, allowing for a more precise estimation of long-range cortical connections. One such method is **retrograde tracing**, whereby a tracer is injected in a cortical area and travels from the synapses in the area towards the soma of the presynaptic neurons. Therefore, by injecting in a given area and observing where the tracer ends up, this technique indicates which neurons (sources) project to the injected area (target). For more information, please consult https://nyaspubs.onlinelibrary.wiley.com/doi/full/10.1111/nyas.12271.

In this tutorial, to define cortical areas, we use the M132 brain parcellation, detailed in https://doi.org/10.1093/cercor/bhs270, with 91 areas in the left hemisphere defined by histological characteristics. A plot of this parcellation can be visualized below (adapted from https://academic.oup.com/cercor/article/24/1/17/272931).

&nbsp;
![Representation of the M132 parcellation of the macaque brain with specific colors for each area](https://drive.google.com/uc?id=10_73Ru-CYnrFGTHiIAQZktW8auG0QySl "M132 Parcellation of the Macaque Brain")

A good way to approximate the strength of connectivity from area $j$ to area $i$ with tract tracing data is to count the number of labeled neurons in area $j$ after injection in area $i$.

This quantity can then be divided by the total number of neurons projecting to area $i$ to obtain the Fraction of Labeled Neurons (FLN), which indicates the strength of the projections from area $j$ to $i$. Note that, since this metric is based on tract tracing data, which has a clear source and target, the resulting FLN matrix is directed (i.e. the connection from $i$ to $j$ is not necessarily the same as $j$ to $i$).

In [None]:
# @markdown Run the following cell to load FLN data from injections in 30 areas of the macaque connectome, together with a description of the 30 areas.

# @markdown **Pick one of the areas now and pay attention to its features when completing the next exercises!**


FLN, SLN, area_centers, area_info, h = load_structural_data()

area_labels = area_info["Name"].values

# Pairwise Euclidean distance between area centres, computed by broadcasting:
# the (N, N, 3) array holds area_centers[i] - area_centers[j] at [i, j], and its
# norm over the last axis gives a symmetric (N, N) matrix with a zero diagonal.
distances = np.linalg.norm(area_centers[:, np.newaxis, :] - area_centers[np.newaxis, :, :], axis=-1)

area_info

In [None]:
# @markdown Let's take a look at the FLN matrix. Run this cell to plot it in matrix form.

# log10 is undefined for absent (zero) connections: mark them as NaN so that
# imshow leaves them blank without emitting divide-by-zero warnings.
log_FLN = np.full(FLN.shape, np.nan)
has_link = FLN > 0
log_FLN[has_link] = np.log10(FLN[has_link])

plt.figure(figsize = (15, 6))
for panel, (matrix, label) in enumerate([(FLN, 'FLN'), (log_FLN, 'log10(FLN)')], start = 1):
    plt.subplot(1, 2, panel)
    plt.imshow(matrix, cmap = 'inferno')
    c = plt.colorbar()
    c.set_label(label)
    plt.xticks(np.arange(FLN.shape[0]), area_labels, rotation = 45, ha = 'right', fontsize = 7)
    plt.yticks(np.arange(FLN.shape[0]), area_labels, fontsize = 7)
    plt.xlabel('From')  # columns are source areas
    plt.ylabel('To')    # rows are target (injected) areas

plt.tight_layout()
plt.show()

### Exercise 2.1. Visualizing the connectome

In [None]:
# @markdown Let's now plot the node's positions in the brain together with their connectivity. The thickness of the lines represents the weight of the connection.

# @markdown With the slider, you can regulate what percentage of strongest connections is shown. For example, if the threshold is 10, only the 10% strongest connections are shown. Try to move it around and see how it changes the visualization (it takes some time to update).

# @markdown 1. Are there any interesting features that you notice already from this plot and the matrix representation? Are there major changes in the overall structure of the connectome when applying a more stringent threshold?
def threshold_connectivity(Threshold = 100):
  """Draw the connectome keeping only the strongest `Threshold` percent of links."""
  fraction_shown = Threshold / 100  # slider is in percent; plot_matrix gets a fraction
  plot_matrix(
      matrix = FLN,
      coordinates = area_centers,
      node_labels = area_labels,
      threshold = fraction_shown,
  )

_ = widgets.interact(threshold_connectivity, Threshold=(2, 100, 2))

Even though the matrix is relatively dense (about 65% of possible connections exist), strong connectivity is quite sparse, with a large majority of connections having a very low weight.

In [None]:
# @markdown Let's plot the distribution of non-zero FLN values to better visualize this "sparsity".

connections = FLN.flatten()
connections = connections[connections != 0]

plt.figure(figsize = (8, 3), dpi = 150)

# Left: linear bins, raw counts per bin.
plt.subplot(1, 2, 1)
counts, edges = np.histogram(connections, bins = np.arange(0, 1.01, 0.02))
centers = 0.5 * (edges[1:] + edges[:-1])
plt.plot(centers, counts, color = 'k', linewidth = 1)
plt.ylabel('# Connections')
plt.xlabel('FLN')

# Right: log-log view. Linear bins would lump all weak connections (which span
# several orders of magnitude) into the first bin, so use log-spaced bins and
# normalise by bin width (density) to keep a power law straight on log-log axes.
plt.subplot(1, 2, 2)
log_edges = np.logspace(np.log10(connections.min()), np.log10(connections.max()), 30)
density, log_edges = np.histogram(connections, bins = log_edges, density = True)
log_centers = np.sqrt(log_edges[1:] * log_edges[:-1])  # geometric bin centres
nonempty = density > 0  # empty bins cannot be drawn on a log axis
plt.plot(log_centers[nonempty], density[nonempty], color = 'k', linewidth = 1)
plt.ylabel('Probability density')
plt.xlabel('FLN')
plt.yscale('log')
plt.xscale('log')

plt.tight_layout()
plt.show()

You might notice that, when plotting the distribution of non-zero connection weights in a log-log plot, it seems to follow a linear relationship. This type of scaling is called a **power-law** and it can be found ubiquitously across different brains (https://doi.org/10.1103/PhysRevResearch.7.013134) and other complex systems (https://www.science.org/doi/full/10.1126/science.284.5420.1677).

### Exercise 2.2. Understanding node degrees


Let's start our exploration of the connectome by calculating one of the simplest features of nodes in a network: their **degree**. In unweighted networks (where connections are binary - either present or absent), the degree of a node is the number of connections, or links, it establishes with the network. This concept can be adapted to weighted networks by simply computing the sum of connection weights linked to a node.

Furthermore, since we are dealing with a directed network (W), there are two ways of computing weighted node degrees:

- **in-degree:** sum of the weights of incoming connections to a node (the number of incoming connections in the unweighted case)

$$
\text{in-degree}_i = \sum_j W_{ij}
$$

- **out-degree:** sum of the weights of outgoing connections from a node (the number of outgoing connections in the unweighted case)

$$
\text{out-degree}_j = \sum_i W_{ij}
$$

In [None]:
# @markdown Let's now compare weighted in- and out- degrees (i.e. sum of incoming and out-going connections) in our network. Execute this cell to get an interactive plot of in- and out- degrees.


# @markdown 1. Are nodes generally balanced in terms of in- and out- degrees? Do you think this can influence the role of cortical areas in the network?

def compute_degree(W, degree_type = 'In', weighted = True):
    """
    Computes the in-degree (incoming connections) or the out-degree (outgoing
    connections) of every node in a directed network.

    Rows of W are targets and columns are sources (W[i, j] is the connection
    from j to i), so in-degrees are row sums and out-degrees are column sums.

    Args:
      W (array): connectivity matrix
      degree_type (str): type of degree to be computed. Options: ['In', 'Out']
      weighted (bool): if True (default), sum the connection weights;
        otherwise count the non-zero connections

    Returns:
      degree (array): node degree, one value per node

    Raises:
      ValueError: if degree_type is neither 'In' nor 'Out'
    """
    ###################################################################################################################################################
    # Exercise 1: compute the weighted degree with np.sum(...)
    # Hint: the np.sum function has an argument axis that allows for the values of a matrix to be summed along a given axis.
    ###################################################################################################################################################
    # Exercise 2: compute the unweighted degree (number of connections) with np.sum(...)
    # Hint: try to think of a simple operation that tells you if there is a connection or not. In python, booleans (False/True) will be interpreted as integers (0/1) if used in mathematical operations
    ###################################################################################################################################################
    # raise NotImplementedError("Please complete coding exercise")
    ###################################################################################################################################################

    if degree_type == 'In':
      axis = 1  # sum over sources j -> total input of target i
    elif degree_type == 'Out':
      axis = 0  # sum over targets i -> total output of source j
    else:
      raise ValueError(f"degree_type must be 'In' or 'Out', got {degree_type!r}")

    W = np.asarray(W)
    if weighted:
      return np.sum(W, axis = axis)
    # Booleans are summed as 0/1, so this counts the existing connections
    return np.sum(W != 0, axis = axis).astype(float)

def plot_degree(Degree_Type = ['In', 'Out']):
  """Draw the thresholded connectome with node size proportional to the chosen degree.

  The list default is the ipywidgets convention for rendering a dropdown.
  """
  node_degree = compute_degree(FLN, degree_type = Degree_Type)
  plot_options = dict(matrix = FLN, coordinates = area_centers,
                      node_labels = area_labels, threshold = 0.1)
  plot_matrix(node_sizes = 4 * node_degree / np.nanmax(node_degree), **plot_options)

_ = widgets.interact(plot_degree)

As you can see by comparing the plots of in- and out-degrees, the strength of incoming and outgoing connections is not always balanced. Therefore, by measuring the difference between in- and out-degrees, it is possible to visualize whether nodes are mostly information "senders" (out-degree > in-degree) or "receivers" (out-degree < in-degree). For more information on the functional role of senders and receivers, and on how to infer them in the undirected human connectome, refer to https://www.nature.com/articles/s41467-019-12201-w.

In [None]:
# @markdown Run this cell to plot the difference between in- and out- degrees across our network. In the brain network plot, node colors represent the normalized difference between in- and out- degrees.

# @markdown 1. Is your chosen area a sender or a receiver? Do you think senders and receivers play a different role in brain function?


in_degree = compute_degree(FLN, degree_type = 'In')
out_degree = compute_degree(FLN, degree_type = 'Out')

# In [-1, 1]: negative -> more outgoing weight (sender), positive -> receiver
in_out_ratio = (in_degree - out_degree)/(in_degree + out_degree)

order = np.argsort(in_out_ratio)
sorted_ratio = in_out_ratio[order]
n_areas = len(in_out_ratio)

plt.figure(figsize = (15.04, 3), dpi = 100)
plt.plot(sorted_ratio, 'ok')
plt.xticks(np.arange(n_areas), [area_labels[n] for n in order], rotation = 45, ha = 'right')
plt.axhline(y = 0, color = 'k', alpha = 0.2)

# Boundary between senders and receivers: just before the first positive
# ratio, or after the last area when no area is a receiver.
positive = np.flatnonzero(sorted_ratio > 0)
boundary = positive[0] - 0.5 if positive.size else n_areas - 0.5
plt.fill_between([-10, boundary], [-10, -10], [10, 10], color = 'teal', alpha = 0.2)
plt.fill_between([boundary, 1000], [-10, -10], [10, 10], color = 'darkred', alpha = 0.2)
plt.text(0.5 * boundary, 0.5, 'Senders', ha = 'center', fontsize = 15)
plt.text(0.5 * (boundary + n_areas), -0.5, 'Receivers', ha = 'center', fontsize = 15)

plt.xlim([-0.5, n_areas - 0.5])
plt.ylim([-1, 1])
plt.ylabel('In Degree - Out Degree\n(Normalized)')
plt.tight_layout()
plt.show()

plot_matrix(matrix = FLN, coordinates = area_centers, node_colors = in_out_ratio, node_cmap = 'RdBu_r', node_label_color = 'k', node_labels = area_labels, threshold=0.1)

#### **Bonus Exercise - Difference between weighted and unweighted degree**

In [None]:
# @markdown Run the following cell to visualize the in-degree of nodes. In this case, the size of nodes will represent their in-degree. The widget allows you to choose between the weighted and unweighted version.

# @markdown 1. Visualize the weighted and unweighted in-degrees of our areas. Do you notice any differences?

def compute_in_degree(W, weighted = True):
    """
    Computes the in-degree (incoming connections) for all nodes in a network.

    Rows of W are targets and columns are sources, so the in-degree of node i
    is obtained by reducing row i.

    Args:
      W (array): connectivity matrix
      weighted (bool): if True, accounts for the weight of links when computing in-degrees. Otherwise, it just counts the number of incoming connections

    Returns:
      in_degree (array): in-degree
    """
    ###################################################################################################################################################
    # Exercise 1: compute the weighted in-degree with np.sum(...)
    # Hint: the np.sum function has an argument axis that allows for the values of a matrix to be summed along a given axis.
    ###################################################################################################################################################
    # Exercise 2: compute the unweighted in-degree (number of incoming connections) with np.sum(...)
    # Hint: try to think of a simple operation that tells you if there is a connection or not. In python, booleans (False/True) will be interpreted as integers (0/1) if used in mathematical operations
    ###################################################################################################################################################
    # raise NotImplementedError("Please complete coding exercise")
    ###################################################################################################################################################

    W = np.asarray(W)
    if weighted:
      in_degree = np.sum(W, axis = 1)
    else:
      # Count the links present in the matrix that was passed in (booleans sum as 0/1)
      in_degree = np.sum(W != 0, axis = 1).astype(float)

    return in_degree


def plot_degree(Weighted = False):
  """Draw the thresholded connectome with node size proportional to in-degree.

  `Weighted` switches between summed weights (True) and link counts (False).
  """
  node_in_degree = compute_in_degree(FLN, weighted = Weighted)
  plot_options = dict(matrix = FLN, coordinates = area_centers,
                      node_labels = area_labels, threshold = 0.1)
  plot_matrix(node_sizes = 4 * node_in_degree / np.nanmax(node_in_degree), **plot_options)

_ = widgets.interact(plot_degree)


### Background: Modular Organization in the Connectome

We have now looked at simple features of network nodes which clarify a bit their role within the network through the strength of their connections. However, such complex networks have characteristics that are beyond how strong their connections are.

One such characteristic is **modularity** - the tendency of a network to organize in sub-networks of nodes that are strongly connected with each other and more weakly connected to the rest of the network. There is extensive evidence that structural (and functional) brain networks follow a modular architecture, which is thought to minimize the cost of wiring and allow for module specialization, which is not only advantageous for behavior, but also to confine the spread of damage in the network (https://www.nature.com/articles/s41583-019-0177-6).

&nbsp;
&nbsp;

There are several ways to detect modules in networks (e.g. K-means clustering of connectivity). Here, we use a built-in function of the *networkx* package in python made for detecting community structure in networks that can be weighted and directed, such as ours: *networkx.community.louvain_communities*. This method relies on the following measure of the **modularity** ($Q$) of a network:

$$
Q = \frac{1}{m} \sum_{i,j}\left[ W_{ij} - \frac{\text{d}_i\text{d}_j}{m} \right] \delta(\text{module}_i, \text{module}_j)
$$


where $W_{ij}$ is the connectivity matrix, $d_i$ is the degree of node $i$, $m$ is $\sum_{i,j}W_{ij}$ and $\delta(\text{module}_i, \text{module}_j)$ is the Kronecker delta, which is 1 if $i$ and $j$ belong to the same module and 0 otherwise. For more details, consult https://iopscience.iop.org/article/10.1088/1742-5468/2008/10/P10008.

In short, this method measures **how strongly nodes within a module are connected to each other in comparison to their connectivity to the rest of the network**. While this specific formula works for undirected networks, it can be adapted to the use-case of directed networks such as ours (see *networkx* documentation).

&nbsp;
&nbsp;

To detect network communities, *networkx.community.louvain_communities* starts with every node in its own community. Each node is then tentatively moved into the community of each of its neighbours, and the move that yields the largest increase in $Q$ is kept (provided it increases $Q$ at all). This local-moving phase is repeated over all nodes until no single move improves $Q$. The network is then aggregated, so that each node represents a community found in the previous step, and the procedure is re-run on this smaller network. These two phases are repeated until no substantial gain in modularity can be obtained by reorganizing modules, yielding a partition with (locally) maximal modularity.

In [None]:
# @markdown Let's now look at the communities detected in our FLN network by this method. Execute this cell to run the Louvain community detection algorithm in our network and visualize the results.

# @markdown 1. You might notice that areas belonging to the same module (which is determined solely by their connectivity patterns) tend to also be clustered spatially in the brain, can you think about why it might be the case?


# FLN[i, j] is the projection from area j to area i (rows are 'To', columns are
# 'From' in the matrix plot). from_numpy_array treats A[u, v] as the weight of
# the edge u -> v (per the networkx docs), so the transpose gives edges that
# point from source area to target area.
G = nx.from_numpy_array(FLN.T, create_using=nx.DiGraph)
# Louvain is randomised; a fixed seed makes the detected modules reproducible.
mod_res = nx.community.louvain_communities(G, seed = 42)

# Module index of every area (stored as floats; also used as node colours below).
modules = np.zeros(len(area_labels))
for i, m in enumerate(mod_res):
    for n in m:
        modules[n] = i

# Number of detected modules.
N_modules = int(np.max(modules))+1

# Print the member areas of each module.
for i in range(N_modules):
    print(f'Module {i}')
    print([area_labels[n] for n in range(len(area_labels)) if modules[n] == i])
    print('')



plot_matrix(FLN, area_centers, node_colors = modules, node_cmap = 'jet', node_labels=area_labels, node_label_color='k')

In [None]:
# @markdown Let's explore a possible explanation for the spatial clustering of nodes that are strongly connected. Run this cell to generate a plot of the FLN weight between two areas vs the Euclidean distance between their centers.

# Only existing connections are analysed (log10 of a zero weight is undefined).
connected = FLN != 0
link_distance = distances[connected]
link_weight = FLN[connected]

plt.figure(figsize = (10, 3), dpi = 150)

# Left: linear FLN axis, correlation with the raw weights.
ax = plt.subplot(1, 2, 1)
plt.scatter(link_distance, link_weight, alpha = 0.2)
r, p = stat.pearsonr(link_distance, link_weight)
# Place the annotation in axes coordinates so it stays inside the panel
# whatever the range of the data.
ax.text(0.95, 0.95, f"Pearson's r = {r:.3f}\np = {p:.3g}", transform = ax.transAxes, ha = 'right', va = 'top')
plt.xlabel('Euclidean Distance (mm)')
plt.ylabel('FLN')

# Right: log FLN axis, correlation with log10 of the weights.
ax = plt.subplot(1, 2, 2)
plt.scatter(link_distance, link_weight, alpha = 0.2)
r, p = stat.pearsonr(link_distance, np.log10(link_weight))
ax.text(0.95, 0.95, f"Pearson's r = {r:.3f}\np = {p:.3g}", transform = ax.transAxes, ha = 'right', va = 'top')
plt.xlabel('Euclidean Distance (mm)')
plt.ylabel('FLN')
plt.yscale('log')

plt.tight_layout()
plt.show()

While their relationship might not be obvious when looking at a normal plot, the picture is clearer when plotting the log of FLN instead. The roughly linear trend on the semi-log plot suggests an exponential decay of connection strength with the distance between two areas. This exponential relation is a well known feature of brain organization, particularly at larger scales (https://www.cell.com/neuron/fulltext/S0896-6273(13)00660-0). Since areas that are close together in space tend to connect more strongly, the modules defined by connectivity will also tend to include areas that are close together in the brain.

Interestingly, there are some unusually strong long-range connections beyond this exponential decay rule that are thought to be quite influential in large-scale brain dynamics (https://doi.org/10.1073/pnas.2415102122).

### Exercise 2.3. Visualizing and understanding participation coefficients



Now that we have divided our network into modules, let's explore an interesting metric to quantify the "importance" of a node within the network: the **participation coefficient**, which quantifies how "uniformly-distributed" the links of a node are across different modules. A node with a high participation coefficient can be seen as a brain area that streamlines communication between different modules, while another with low participation coefficient will mostly interact with areas belonging to the same module.

Here's how the participation coefficient of node $i$ ($P_i$) can be calculated:

$$
P_i = 1 - \sum_{m} \left( \frac{\text{degree}_{i,m}}{\text{degree}_i} \right)^2
$$

where $\text{degree}_{i,m}$ is the sum of connections between node $i$ and nodes belonging to module $m$.

In [None]:
# @markdown Run this cell to visualize the participation coefficient of nodes. For simplicity, the results you will see relate to an undirected version of FLN, where link $ij$ is the average of links $ij$ and $ji$ in the original FLN Matrix.

# @markdown In the plot, colors represent the different modules and the size of nodes represents their participation coefficient.

# @markdown 1. Do you notice anything different when visualizing participation coefficients vs node degrees? Which one of the metrics do you think is more useful to study how different nodes shape the behavior of a network?


def compute_participation_coefficient(W, modules):
    """
    Computes the participation coefficient of a network with defined modules

        P_i = 1 - sum_m (degree_{i,m} / degree_i)^2

    where degree_{i,m} is the summed weight of row i restricted to the columns
    of nodes in module m, and degree_i is the full row sum (the in-degree).
    Nodes without connections get P_i = 0.

    Args:
      W (array): connectivity matrix
      modules (array): array with the module associated with each node (any labels)

    Returns:
      pc (array): participation coefficient
    """
    ###################################################################################################################################################
    # Exercise 1: compute degree_mod, the weight of node i's connections to module m
    # Hint: you can use the variable modules as mask to know which module a node belongs to. For example, modules == 2 returns an array that is True for all nodes belonging to module 2.
    ###################################################################################################################################################
    # raise NotImplementedError("Please complete coding exercise")
    ###################################################################################################################################################

    W = np.asarray(W, dtype = float)
    modules = np.asarray(modules)

    # Row sums (in-degrees). For the symmetric matrix used below, in- and
    # out-degrees coincide, so the choice does not matter.
    degree = W.sum(axis = 1)

    squared_fraction_sum = np.zeros(W.shape[0])
    # Isolated nodes would divide 0 by 0; they are reset to 0 below.
    with np.errstate(divide = 'ignore', invalid = 'ignore'):
        for m in np.unique(modules):  # every module label actually present
            degree_mod = W[:, modules == m].sum(axis = 1)
            squared_fraction_sum += (degree_mod / degree) ** 2

    pc = 1 - squared_fraction_sum
    pc[degree == 0] = 0.0
    return pc


# Undirected version of FLN: each link is the mean of the two directions, so
# in- and out-degrees need not be handled separately.
FLN_symmetric = 0.5 * (FLN + FLN.T)

def plot_deg_or_pc(Node_Size = ['Degree', 'Participation Coefficient']):
  """Draw the module-coloured connectome, sizing nodes by degree or participation coefficient.

  The list default is the ipywidgets convention for rendering a dropdown;
  any value other than the two option strings draws nothing.
  """
  participation = compute_participation_coefficient(FLN_symmetric, modules)
  strength = compute_degree(FLN_symmetric)
  size_by_option = {'Degree': strength, 'Participation Coefficient': participation}

  chosen = size_by_option.get(Node_Size) if isinstance(Node_Size, str) else None
  if chosen is None:
      return
  plot_matrix(FLN, area_centers, node_colors = modules, node_sizes = 4*chosen/np.nanmax(chosen), node_cmap = 'jet', node_labels=area_labels, node_label_color='k')


_ = widgets.interact(plot_deg_or_pc)

### Background: Hierarchical Organization of Long-Range Connectivity

There is one feature of the cortex that we haven't accounted for so far: its **laminar architecture**.

The mammalian cortex is organized in layers (typically 6), with specific patterns of interconnectivity that define the basis of the canonical cortical microcircuit (see https://ieeexplore.ieee.org/abstract/document/6796535 for a seminal paper on the canonical microcircuit of the cortex). A cytoarchitectonic representation of the layered structure of the human cortex can be seen in the image below (from https://neurology.mhmedical.com/book.aspx?bookID=3024). Layers can also be grouped by their relative position to layer IV:

- **Superficial or supragranular**: layers I, II and III
- **Granular**: layer IV
- **Deep or infragranular**: layers V and VI

&nbsp;
![Cytoarchitectonic view of the laminar structure of the cortex across different areas](https://drive.google.com/uc?id=13iFzIK3Xd7mDp3QAV6ZicEgyuTisSyaw "Layers of the cortical microcircuit in different areas")


Why is this laminar structure relevant when talking about the connectome? When looking at the organization of white-matter tracts, while accounting for the laminar structure of the cortex, a **"hierarchical"** pattern becomes apparent:

Projections from sensory to association areas, or **feedforward projections**, tend to originate from **superficial layers** and target excitatory neurons in layer IV, which relay the signal mostly to local superficial layers. Conversely, projections in the opposite direction, or **feedback projections**, originate in **deep layers** and have more diffuse targets, reaching excitatory and inhibitory neurons across superficial and deep layers (see https://doi.org/10.1093/cercor/1.1.1-a and https://doi.org/10.1126/sciadv.1601335 for more details).

&nbsp;
![Simplified view of laminar connectivity patterns in the cortex](https://drive.google.com/uc?id=1FQzp5v4lOX5ZQeUC4PqV2FdKI8_brcDA " Simplified view of laminar connectivity patterns in the cortex")


Fortunately, these patterns can be measured with **retrograde tracing**! If a tracer is injected in V1, for example, we can count the number of neurons labeled in the superficial layers of area V4 and divide it by the total number of labeled neurons found in area V4. This quantity is called the fraction of **supragranular labeled neurons**, or **SLN**, and it tells us about the relative position of these two areas in the feedforward and feedback flows of information.

As explained previously, connections from sensory to higher-order areas originate mostly from superficial layers (high SLN), while connections in the opposite direction originate mostly from deep layers (low SLN). This asymmetry can be used to determine a "hierarchical" organization of cortical areas. In short, the method used in https://www.cell.com/neuron/fulltext/S0896-6273(15)00765-5 optimizes the hierarchical value of each area so that the resulting hierarchy best predicts the empirically obtained SLN values. With the resulting hierarchy values, one can get a better idea of where different areas of the cortex stand in the flow of information from early sensory to higher-order association areas.



In [None]:
# @markdown Run this cell to visualize the hierarchy across the brain.

# Node color encodes each area's hierarchy value h.
plot_matrix(
    FLN,
    area_centers,
    node_colors=h,
    node_cmap='inferno',
    node_labels=area_labels,
    node_label_color='k',
)

### Exercise 2.4: Understanding Counterstream Inhibition

Our model, however, does not have a laminar structure, so it cannot directly represent the hierarchical laminar organization of long-range connectivity measured through SLN.

Nevertheless, we can approximate some of the functional consequences of this organization.

A purely feedforward connection is considered to have an excitatory effect, since it stems from excitatory neurons in superficial layers and targets excitatory neurons. Conversely, feedback connections have a more diffuse pattern of connectivity and target inhibitory as well as excitatory neurons. Therefore, our model is built on the hypothesis that feedback connections have a net inhibitory effect on the target population and thus mostly target interneurons (consult https://elifesciences.org/articles/72136 for more details on how this organization contributes to the stabilization of global dynamics and the propagation of information across the hierarchy).

Because connections are generally not purely feedforward or feedback, we use the following formulas to compute the strength of long-range projections reaching either excitatory ($W_{ij}^E$) or inhibitory ($W_{ij}^I$) neurons in the target population:

&nbsp;
&nbsp;


$$
W_{ij}^E = SLN_{ij} * FLN_{ij}
$$

$$
W_{ij}^I = (1 - SLN_{ij}) * FLN_{ij}
$$

In [None]:
# @markdown Run this cell to visualize the excitatory-targeting and inhibitory-targeting components of long-range connectivity.

# @markdown 1. Try to visualize first E->E and then E->I projections. Do they seem generally balanced?

# @markdown 2. Choose the option "Excitatory - Inhibitory", which will show you the difference between E->E and E->I connections. Do you notice some directionality in certain groups of areas?


def plot_counterstream(Connection_Type = 'Excitatory'):
    """Plot one component of long-range connectivity split by target type.

    Args:
      Connection_Type (str): one of
        'Excitatory'              -> W^E = SLN * FLN
        'Inhibitory'              -> W^I = (1 - SLN) * FLN, passed with a minus sign
        'Excitatory - Inhibitory' -> W^E - W^I = (2 * SLN - 1) * FLN
        Any other value draws nothing.
    """
    if Connection_Type == 'Excitatory':
        signed_matrix = FLN * SLN
    elif Connection_Type == 'Inhibitory':
        # Negated so the inhibitory-targeting weights are drawn with a negative sign.
        signed_matrix = -(FLN * (1 - SLN))
    elif Connection_Type == 'Excitatory - Inhibitory':
        signed_matrix = FLN * (2 * SLN - 1)
    else:
        return

    plot_matrix_EI(matrix=signed_matrix, coordinates=area_centers,
                   node_labels=area_labels, threshold=0.1)

_ = widgets.interact(plot_counterstream, Connection_Type = ['Excitatory', 'Inhibitory', 'Excitatory - Inhibitory'])


In [None]:
# @markdown Let's look now at just a subset of nodes, which will make it easier to understand how the effect of counterstream inhibition relates to the hierarchical organization of the cortex.

# @markdown 1. Pay attention to the hierarchical value of the different areas (represented by the color of nodes: brighter = higher hierarchical value). Can you find a pattern of E-I connectivity between areas with different hierarchical values?

# @markdown 2. Can you tell the directionality of hierarchical interactions from the flow of excitation and inhibition?


def plot_EI_areas(visualize_areas = ['V1', 'V2', 'V4', 'MT']):
    """Plot the net E-I long-range connectivity restricted to a subset of areas.

    Node colors show each area's hierarchy value (0 to 1, 'inferno' colormap).

    Args:
      visualize_areas (sequence of str): area labels to include; at least 2.
    """
    if len(visualize_areas) <= 1:
        print('Please select at least 2 areas!')
        return

    label_list = list(area_labels)
    selected = np.array([label_list.index(name) for name in visualize_areas])

    # Net drive W^E - W^I, restricted to the selected rows and columns.
    net_EI = FLN * (2 * SLN - 1)
    matrix_aux = net_EI[np.ix_(selected, selected)]
    centers_aux = area_centers[selected, :]
    h_aux = h[selected].squeeze()

    plot_matrix_EI(matrix_aux, centers_aux, node_sizes=2, node_colors=h_aux,
                   vmin=0, vmax=1, node_cmap='inferno',
                   node_labels=visualize_areas, threshold=1,
                   node_label_color='k')


_ = widgets.interact(plot_EI_areas,
                     visualize_areas = widgets.SelectMultiple(value = ['V1', 'V2', 'V4', 'MT'],
                                                              description = 'Cortical Areas',
                                                              options = area_labels))

## 3. Large-scale Model of the Macaque Cortex

Please run the next cell to load data.

In [None]:
#@title Load data
# Load the model inputs through load_and_preprocess_data (defined elsewhere in
# the notebook); the result is indexed by string keys below.
data = load_and_preprocess_data()
area_names = data['area_names']
area_lobes = data['area_lobes']
area_descriptions = data['area_descriptions']
F = data['SLN'] # Supragranular labelled neurons matrix
W = data['W'] # FLN based structural connectivity matrix
# NOTE: this rebinds the notebook-global `h` that the plotting cells above also
# use; if those cells are re-run after this one they will see these values.
h = data['hier_vals'] # Hierarchy score of each area in the same order as area_names

We now have all the necessary ingredients to implement a large-scale model of the macaque cortex:

- local dynamic model
- anatomical structural connectivity
- hierarchical relationships
- counter-stream inhibitory bias

In the next sections, we are going to combine these elements and observe the large-scale dynamics.


### 3.1 Synaptic strength, local and long-range inputs to each node

In the large-scale model, each area has a distinct value of incoming synaptic strength, $J_S$, given by:
$$
J_S(i) = J_{MIN} \;+\; (J_{MAX} - J_{MIN}) h_i
$$

As you've seen in the first section of the tutorial, the bifurcation point of an isolated area is at $J_S$ = 0.4655 nA for our set of parameter values. We set $J_{MAX}$ below that value, implying that all areas in the network are monostable in isolation. In this situation, any sustained activity displayed by the model will be a consequence of a global, cooperative effect due to inter-areal interactions.

We compute the incoming synaptic strength (both local and long-range) of a given area as a linear function of the dendritic spine count values observed in anatomical studies, with age-related corrections when necessary. When spine count data is not available for a given area, we instead use its position in the anatomical hierarchy, which is highly correlated with spine counts, as a proxy. After this process, the large-scale network displays a gradient of local and long-range recurrent strength, with sensory areas showing weak and association areas showing strong local connectivity. We denote the local and long-range strength value of a given area $i$ in this gradient as $h_i$, normalized between zero (bottom of the gradient, area V1) and one.

Here is what the gradient of $J_S$ looks like for our model:

In [None]:
#@title 3.1.1 Gradient of J_s
# Plot gradient of J_s across all cortical areas.
# Note that the areas on the x-axis are ordered by anatomical hierarchy, but
# most J_S values are extracted using spine counts, with anatomical hierarchy
# used for areas where we could not obtain this information.
# J_S(i) = J_MIN + (J_MAX - J_MIN) * h_i, evaluated with the default
# WorkingMemoryParameters; the simulations below override J_MAX via update().
WM_params_plot = WorkingMemoryParameters()
J_S = WM_params_plot.J_MIN + (WM_params_plot.J_MAX - WM_params_plot.J_MIN) * h
fig, ax = plt.subplots()
ax.scatter(area_names, J_S)
ax.tick_params("x", rotation=70)
ax.set_xlabel('Cortical Area')
ax.set_ylabel('Local Synaptic Strength, J_S')
plt.show()

The total input current $I^x_i$ to each population $i \in \{A, B, C\}$ of a given node $x$ is given by:

\begin{equation}
I_A^{x} = J_S S_A^{x} \;+\; J_C S_B^{x} \;+\; J_{EI} S_C^{x} \;+\; I_{ext_A} \;+\; I_{\mathrm{net}}^{A, x} \;+\; x_{A}(t)
\end{equation}

\begin{equation}
I_B^{x} = J_C S_A^{x} \;+\; J_S S_B^{x} \;+\; J_{EI} S_C^{x} \;+\; I_{ext_B} \;+\; I_{\mathrm{net}}^{B, x} \;+\; x_{B}(t)
\end{equation}

\begin{equation}
I_C^{x} = J_{IE} S_A^{x} \;+\; J_{IE} S_B^{x} \;+\; J_{II} S_C^{x} \;+\; I_{ext_C} \;+\; I_{\mathrm{net}}^{C, x} \;+\; x_{C}(t)
\end{equation}

In these equations, $J_S$, $J_C$ are the self- and cross-coupling between excitatory populations. Likewise, $J_{EI}$ , $J_{IE}$ and $J_{II}$ are the coupling from the inhibitory populations to any of the excitatory ones, the coupling from any of the excitatory populations to the inhibitory one, and the self-coupling strength of the inhibitory population, respectively. The parameters $I_{ext_i}$ with $i$ = A, B, C are background inputs to each population. Finally, the term $I_{\mathrm{net}}^{i}$ denotes the
long-range input coming from other areas in the network,
and the term $x_i(t)$ with $i$ = A, B, C represents noise (for details, see https://elifesciences.org/articles/72136).

The long-range input to each node $x$ from all other nodes $y$ is described by the following equations:

\begin{equation}
I_{\mathrm{net}}^{A, x}
= G \sum_{y} W^{xy}\,SLN^{xy}\,S^y_{A}
\end{equation}

\begin{equation}
I_{\mathrm{net}}^{B, x}
= G \sum_{y} W^{xy}\,SLN^{xy}\,S^y_{B}
\end{equation}

\begin{equation}
I_{\mathrm{net}}^{C, x}
= \frac{\alpha G}{Z} \sum_{y} W^{xy}\,(1 - SLN^{xy})\,(S^y_{A} + S^y_{B})
\end{equation}

where A and B are the excitatory populations and C is the inhibitory one.

It is important to note that long-range projections are influenced by two controllable parameters: the global coupling, $G$, and the strength of the feedback inhibition, $\alpha$. Later on in the tutorial, we will observe in more detail how altering the values of these parameters influences the large-scale activity. For now, take a few moments to understand how $G$ and $\alpha$ influence the long-range connection strengths of the network.

In [None]:
#@title 3.1.2 Effect of global coupling and alpha on the balance between long-range excitation and inhibition
def plot_G_alpha(G=0.48, alpha=1):
    """Plot the net long-range drive G * FLN * (SLN - alpha * (1 - SLN)).

    Positive entries mean the excitatory-targeting part (SLN) dominates a
    connection; negative entries mean the alpha-weighted inhibitory-targeting
    part (1 - SLN) dominates. Unlike I_net^C in the text, no 1/Z factor is
    applied here.

    Args:
      G (float): global coupling strength.
      alpha (float): relative strength of the inhibitory-targeting component.
    """
    net_balance = SLN - alpha * (1 -SLN)
    plot_matrix_EI(matrix=G * FLN * net_balance, coordinates=area_centers,
                   node_labels=area_labels, threshold=0.1)


_ = widgets.interact(plot_G_alpha,  G=(0.2, 1.8, 0.01), alpha = (0.4, 1.5, 0.1))

We are ready to simulate the large-scale network. Below the activity traces, a connectome plot highlights in yellow the areas that exhibit high persistent activity.

#### 3.1.3 Simple working memory task

In [None]:
#@title Simulation
# Seeded generator for the random initial rates. `outside_rng` is a notebook
# global that later cells keep drawing from, so its state advances each time
# one of those cells runs.
random_seed = 123
outside_rng = np.random.default_rng(random_seed)
# Initial condition and parameters
N = len(area_names)
# 8 state variables per area; only the first 3 columns (the rates r, per the
# inline comment) get random values, the rest start at 0.
Y0 = np.zeros((N, 8))
Y0[:, 0:3] = 5 * (1 + np.tanh(outside_rng.normal(0, 2, (N, 3)))) # r values are initialized randomly between 0 and 10
WM_params = WorkingMemoryParameters()
WM_params.update(G=0.48, J_MAX=0.42, ALPHA=1.0)
# Generate the network
# NOTE(review): unlike the later simulation cells, no random_seed is passed to
# WorkingMemoryNetwork here - check the constructor's default.
WM_network = WorkingMemoryNetwork(area_names=area_names, area_lobes=area_lobes, Y0=Y0, W=W, F=F, h=h, params=WM_params)

# Run the simulation
t_end = 8
dt = 0.5e-3
# External input: only row 0 (the first entry of area_names; presumably V1,
# since areas are ordered by hierarchy) and column 0 (presumably population A)
# receive a current of 0.3.
I_ext_strength = np.zeros((N, 3))
I_ext_strength[0, 0] = 0.3
# A single stimulus (hence the one-element arrays) applied from t=4 to t=4.5.
state_history = WM_network.run(t_end=t_end, dt=dt, I_ext_strengths=np.array([I_ext_strength]), ts_ext_start=np.array([4]), ts_ext_end=np.array([4.5]))
avg_firing_rates_A = WM_network.avg_firing_rates_last_two_seconds()
# An area counts as showing persistent activity (1) if its average rate over
# the last two seconds exceeds 10.
persistent = [1 if fr > 10 else 0 for fr in avg_firing_rates_A]

# Plot the results
# xlim of (-4, t_end-4) suggests the plotted time axis is shifted so that the
# stimulus onset is at 0 - confirm in plot_n_areas.
_ = WM_network.plot_n_areas(xlim=(-4, t_end-4), ylim=(0, 50), title='Simple Working Memory Task')
plot_matrix(FLN, area_centers, node_colors = persistent, node_cmap = 'inferno', node_labels=area_labels, node_label_color='k')

### 3.2 Excitation-inhibition Balance

In this section, you will explore the effect of the global coupling, $G$, and the relative inhibitory bias, $\alpha$, on the activity traces across the nodes of the network. The sliders start at the baseline values presented in the paper: $G$ = 0.48 and $\alpha$ = 1. While observing the dynamics, you can explore the following questions:
1. By how much can we increase / decrease the global coupling and still observe meaningful dynamics? What happens to stability beyond that threshold?
2. Why is a counter‐stream inhibitory bias ($\alpha$) important? What if you set $\alpha$ to 0, do you see runaway excitation or pathological suppression?

<em>Bonus</em>:
Can you think of other methods or mechanisms that would achieve a similar balance between excitation and inhibition?

In [None]:
#@title 3.2.1 Global coupling and feedback inhibition exploration
import ipywidgets as widgets

def G_alpha_exploration(G=0.48, alpha=1):
    """Run the working-memory task for a given G and alpha and plot the results.

    The initial rates are drawn from a fresh generator seeded with the
    notebook's `random_seed` on every call. Moving a slider therefore changes
    only G / alpha, not the random initial condition, so runs are comparable.

    Args:
      G (float): global coupling strength.
      alpha (float): strength of the counter-stream (feedback) inhibition.
    """
    init_rng = np.random.default_rng(random_seed)

    # Initial condition and parameters
    N = len(area_names)
    Y0 = np.zeros((N, 8))
    Y0[:, 0:3] = 5 * (1 + np.tanh(init_rng.normal(0, 2, (N, 3)))) # r values are initialized randomly between 0 and 10
    WM_params = WorkingMemoryParameters()

    WM_params.update(G=G, ALPHA=alpha, J_MAX=0.42)

    WM_network = WorkingMemoryNetwork(area_names=area_names, area_lobes=area_lobes, Y0=Y0, W=W, F=F, h=h, params=WM_params, random_seed=8)
    # Run the simulation: one stimulus to row 0, column 0 from t=4 to t=4.5
    t_end = 8
    dt = 0.5e-3
    I_ext_strength = np.zeros((N, 3))
    I_ext_strength[0, 0] = 0.3
    WM_network.run(t_end=t_end, dt=dt, I_ext_strengths=np.array([I_ext_strength]),
                   ts_ext_start=np.array([4]), ts_ext_end=np.array([4.5]))
    _ = WM_network.plot_all_areas(xlim=(-4, t_end - 4), ylim=(0, 50))
    avg_firing_rates_A = WM_network.avg_firing_rates_last_two_seconds()
    persistent = [1 if fr > 10 else 0 for fr in avg_firing_rates_A]
    plot_matrix(FLN, area_centers, node_colors = persistent, node_cmap = 'inferno', node_labels=area_labels, node_label_color='k')

# alpha starts at 0 so that the question "what if you set alpha to 0?" can
# actually be explored with the slider (the previous minimum was 0.4).
_ = widgets.interact(G_alpha_exploration, G=(0.2, 0.8, 0.01), alpha = (0.0, 1.5, 0.1))


### 3.3 Inactivating cortical areas and effects on dynamics

Lesions have been an important part of scientific discovery in neuroscience, allowing us to perform causal experiments and explore the role of individual brain regions in various settings. Optogenetics allows experimentalists to temporarily inhibit (or excite) specific parts of the neural tissue, in a reversible manner, making it a very useful technique to investigate the effects that specific brain areas could have on dynamics or performance during cognitive tasks.

In this exercise you'll play the role of a virtual experimentalist. By “inactivating” different cortical areas in our large-scale working memory model, you'll discover which nodes are most critical for maintaining a memory trace and why. Think about the following question: if you could silence one brain region at a time, how would you decide which to pick? After thinking about this question for a minute or two, please continue to 'Exploratory study'.

<details>
<summary>Exploratory study</summary>

- **Random lesions:** Turn off 1-5 areas at random. How many does it take before the network fails?
- **Hub lesions:** Turn off the areas with the strongest structural connections. Does performance break down faster than with random lesions?
- **Hierarchy lesions:** Turn off the top areas in the hierarchy. How does that compare?
- **Excitability effects:** Try two different values of the network's global excitation strength (G) or local excitation strength (J_MAX). Does that make the network more or less resistant to lesions?
- **Your own idea:** Think of another way to pick areas (for example, shortest path length, clustering) and test it.

</details>


Questions to think about:

- Which lesion strategy causes the steepest drop in neural activity across the network?
- About how many areas need to be silenced before the network can no longer hold the memory?
- What do your results tell you about how this model balances excitation and inhibition?

<details>
<summary>Hints</summary>

- Think about what you learnt in Section 2.
- For hub lesions, rank areas by the sum of their incoming and outgoing weights in the connectivity matrix.
- For hierarchy lesions, use the hierarchy score provided by the model.
- To test excitability, pick one “low” and one “high” value for the main global coupling parameter.
</details>

To choose which areas to inactivate, run the following code block and select one or more areas from the list. You can select multiple areas by holding Ctrl (Cmd on macOS) while clicking on each area you wish to inactivate.

In [None]:
#@title Select which areas to inactivate
# Multi-select list of all areas. The simulation cell below reads the current
# selection through `inactivated_areas.value`.
inactivated_areas = widgets.SelectMultiple(description='Areas to inactivate:', options=area_names)
# Bare expression on the last line so the notebook renders the widget.
inactivated_areas

After selecting which areas to inactivate, run the following code block to observe the activity traces across the network. Blue traces represent the case in which inactivations were applied, while grey traces indicate the baseline experiment (no inactivations). You will also see a connectome plot which shows which areas you lesioned (in red), which exhibit high persistent activity (in yellow) and which exhibit low activity (standard grey).

Feel free to return to the cell above, select different sets of areas to inactivate and repeat the experiment. Below this code block, you will find a helper cell that plots the mean firing rate of each area during the delay period.

In [None]:
#@title Working memory simulation with inactivations
# Run the same working-memory task twice on one network - once intact, once
# with the areas selected in `inactivated_areas` silenced - and overlay the
# traces (grey = intact, blue = lesioned).
# Initial condition and parameters
N = len(area_names)
Y0 = np.zeros((N, 8))
Y0[:, 0:3] = 5 * (1 + np.tanh(outside_rng.normal(0, 2, (N, 3))))  # r values are initialized randomly between 0 and 10
# J_MAX=0.26 is lower than the 0.42 used in the previous simulations.
distributed_WM_params = WorkingMemoryParameters()
distributed_WM_params.update(J_MAX=0.26, G=0.48)

# Generate the network
distributed_WM_network = WorkingMemoryNetwork(area_names=area_names, area_lobes=area_lobes, Y0=Y0, W=W, F=F, h=h, params=distributed_WM_params, random_seed=8)

# Prepare the simulation
# A single stimulus of 0.3 to row 0 / column 0, applied from t=4 to t=4.5.
t_end = 15
dt = 0.5e-3
I_ext_strength = np.zeros((N, 3))
I_ext_strength[0, 0] = 0.3

# Prepare the visualization
# Grid of 6 columns with one subplot per area. target_area_names is not used
# by the code below; it only serves the alternative described in its comment.
target_area_names = ['V1', 'MT', 'LIP', '24c', 'STPi', '9/46d'] # If using plot_n_areas, these areas will be plotted if in rows we change area_names to target_area_names
rows = math.ceil(len(area_names) / 6)
fig, axes = plt.subplots(rows, 6, figsize=(10, 1.5 * rows), dpi=300)
fig.subplots_adjust(hspace=0.5, wspace=0.3)
axes = axes.flatten() # Flatten for easy iteration

# No lesion
# reset() presumably restores the network to its initial state so that both
# runs start identically - confirm in WorkingMemoryNetwork.reset.
distributed_WM_network.reset()
state_history_no_lesion = distributed_WM_network.run(t_end=t_end, dt=dt, I_ext_strengths=np.array([I_ext_strength]), ts_ext_start=np.array([4]), ts_ext_end=np.array([4.5]))
lines1 = distributed_WM_network.plot_all_areas(xlim=(-4, t_end-4), ylim=(0, 50), axes_flat=axes, variables=['r_A'], colors=[[.6, .6, .6]], legend=False)
avg_firing_rates_A_no_lesion = distributed_WM_network.avg_firing_rates_last_two_seconds()
# Lesion areas
distributed_WM_network.reset()
state_history_lesion = distributed_WM_network.run(t_end=t_end, dt=dt, I_ext_strengths=np.array([I_ext_strength]), ts_ext_start=np.array([4]), ts_ext_end=np.array([4.5]), lesion_areas=inactivated_areas.value)
lines2 = distributed_WM_network.plot_all_areas(xlim=(-4, t_end-4), ylim=(0, 50), axes_flat=axes, variables=['r_A'], colors=[[.1, .6, .8]], legend=False)
avg_firing_rates_A_lesion = distributed_WM_network.avg_firing_rates_last_two_seconds()
# Node color codes on the vmin=0..vmax=10 scale used below:
# 10 = persistent activity (> 10), 5 = lesioned area, 0 = low activity.
persistent = [10 if fr > 10 else 0 for fr in avg_firing_rates_A_lesion]
# NOTE(review): indices come from area_names while plot_matrix labels nodes
# with area_labels - this assumes both lists share the same ordering; confirm.
for lesion_area in inactivated_areas.value:
    area_idx = np.where(np.array(area_names) == lesion_area)[0][0]
    persistent[area_idx] = 5
# Set the legend
# One combined legend, attached to the rightmost axis of every row
# (flattened indices 5, 11, 17, ...).
lines = lines1 + lines2
labels1 = [line.get_label()+' (no lesion)' for line in lines1]
labels2 = [line.get_label()+f' (silenced {inactivated_areas.value})' for line in lines2]
labels = labels1 + labels2
for ax_ids in range(5, len(axes), 6):
    axes[ax_ids].legend(handles=lines, labels=labels, loc='upper left', fontsize='small', bbox_to_anchor=(1.04, 1))
plt.show()

plot_matrix(FLN, area_centers, node_colors=persistent, node_cmap='inferno', node_labels=area_labels, node_label_color='k', vmin=0, vmax=10)

#### Helper plot to visualize mean delay-activity across network

Run the following code block to visualize the delay-activity firing rate across all areas of the network in baseline vs. inactivation conditions.

In [None]:
#@title Plot delay firing rates
fig, ax = plt.subplots(figsize=(6, 3), dpi=200)

# One marker per area: grey = intact network, blue = selected areas silenced.
baseline_rates = avg_firing_rates_A_no_lesion
lesion_rates = avg_firing_rates_A_lesion
ax.plot(range(len(baseline_rates)), baseline_rates, 'o', color=[.6, .6, .6], markersize=5, label='r_A_no_lesion')
ax.plot(range(len(lesion_rates)), lesion_rates, 'o', color=[.1, .6, .8], markersize=5, label=f'r_A_lesioned: {inactivated_areas.value}')

ax.set_xticks(np.arange(len(baseline_rates)), labels=area_names, rotation=70, fontsize=8)
ax.set_xlabel("Areas sorted by hierarchy (low -> high)", fontsize=8)
ax.set_ylabel("Firing Rate (Hz)", fontsize=8)
ax.set_title("Average firing rates in the last 2s of the simulation", fontsize=10, fontweight='bold')
ax.legend(fontsize=7)
plt.tight_layout()

# Minimal frame: drop the top/right spines and keep ticks on the remaining sides.
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('left')
ax.grid(linestyle=':', alpha=0.5)