# Simulation Case-Study: Neuroscience

This notebook implements and analyses models of individual neuron and cortical network activity. Implementation, simulation, and analysis are kept together in a single self-contained notebook. The two simulations conducted for this case-study are described below. The biophysical model is taken from Izhikevich (2003), _Simple Model of Spiking Neurons_, IEEE Transactions on Neural Networks, 14(6), pp. 1569-1572.

1. **Neuron Spiking Behaviour** - Modelling the membrane potential and spiking behaviour of typical neuron classes.
2. **Cortex Synchronisation** - Modelling a randomly connected network of neurons and the synchronisation of their firing that emerges over time.

In [None]:
import enum
from dataclasses import dataclass, field

import tqdm
import numpy as np
import pandas as pd
import mplcyberpunk
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
plt.style.use(["seaborn-v0_8-notebook", "cyberpunk"])
plt.rcParams["figure.figsize"] = (20, 10)

## Simulation 1: Neuron Spiking Behaviour

The goal of this simulation is to use Izhikevich's published parameter sets to reproduce the behaviours of different neuron classes: Regular Spiking, Intrinsically Bursting, Chattering, Fast Spiking, Low-Threshold Spiking, Thalamo-Cortical, and Resonator. Each class is represented by a constant, immutable neuron definition so that its parameters stay fixed throughout the simulation. The model uses the following parameters and variables:

- $a$: Time-scale of the recovery variable $\hat{u}$.
- $b$: Sensitivity of the recovery variable to sub-threshold fluctuations of the membrane potential.
- $c$: After-spike reset value of the membrane potential.
- $d$: After-spike increment of the recovery variable.
- $I$: Post-synaptic input current.
- $\hat{v}$: Simulated membrane potential (mV).
- $\hat{u}$: Simulated recovery variable.

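Between spikes, both variables follow first-order differential equations, $\frac{d\hat{v}}{dt} = 0.04\hat{v}^2 + 5\hat{v} + 140 - \hat{u} + I$ and $\frac{d\hat{u}}{dt} = a(b\hat{v} - \hat{u})$, which are integrated with the forward Euler method using time step $\Delta t$. When the membrane potential has reached the $30$ mV threshold, the next step applies the after-spike reset instead:

$$
\begin{align}
    \hat{v}_{t+\Delta t} &= 
        \begin{cases}
            c & \hat{v}_t \geq 30 \\
            \hat{v}_t + \Delta t \left( 0.04\hat{v}_t^2 + 5\hat{v}_t + 140 - \hat{u}_t + I \right) & \text{otherwise}
        \end{cases} \\
    \hat{u}_{t+\Delta t} &= 
        \begin{cases}
            \hat{u}_t + d & \hat{v}_t \geq 30 \\
            \hat{u}_t + \Delta t \, a \left( b\hat{v}_t - \hat{u}_t \right) & \text{otherwise}
        \end{cases} 
\end{align}
$$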
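A minimal sketch of one update step for a single neuron, assuming forward-Euler integration with step `dt` (in ms) and the 30 mV threshold used throughout; `izhikevich_step` is a hypothetical helper written for illustration, not a function defined elsewhere in this notebook.

```python
def izhikevich_step(v: float, u: float, I: float, a: float, b: float,
                    c: float, d: float, dt: float = 0.1,
                    threshold: float = 30.0) -> tuple[float, float]:
    """Advance (v, u) by one forward-Euler step, or apply the after-spike reset."""
    if v >= threshold:
        # The previous step reached the threshold: reset v to c and increment u by d
        return c, u + d
    dv = 0.04 * v**2 + 5 * v + 140 - u + I  # dv/dt
    du = a * (b * v - u)                     # du/dt
    return v + dt * dv, u + dt * du


# Regular Spiking parameters; v = -70 mV with u = b * v and no input current
print(izhikevich_step(-70.0, -14.0, 0.0, a=0.02, b=0.2, c=-65.0, d=8.0))
# Above threshold the step resets instead of integrating
print(izhikevich_step(35.0, -10.0, 10.0, a=0.02, b=0.2, c=-65.0, d=8.0))
```

In the first call both right-hand sides vanish ($196 - 350 + 140 + 14 + 0 = 0$), so the state is unchanged; the second call returns $(c, \hat{u} + d) = (-65, -2)$.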

The figure below shows the expected behaviour of each neuron class. For reference, the input current is stepped in increments of $I=10$, each class starts from its own resting membrane potential, denoted $r$, and the stimulation periods differ between classes.

<br>

<center>
<img src="./imgs/neuron-spiking.png" alt="expected neuron spiking behaviour" width="80%"/>
</center>
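The resting potential $r$ is tied to $b$: with $I=0$, setting $\frac{d\hat{u}}{dt}=0$ gives $\hat{u}=b\hat{v}$, and substituting into $\frac{d\hat{v}}{dt}=0$ gives $0.04\hat{v}^2 + (5-b)\hat{v} + 140 = 0$. The sketch below solves this quadratic with a hypothetical helper, `resting_potentials`. For $b=0.2$ the roots are $-70$ mV and $-50$ mV, and the lower one matches the $r=-70$ used for the Regular Spiking, Intrinsically Bursting and Fast Spiking neurons.

```python
import math


def resting_potentials(b: float) -> tuple[float, float]:
    """Equilibria of v at I = 0, i.e. roots of 0.04 v^2 + (5 - b) v + 140 = 0."""
    p = 5.0 - b
    disc = p**2 - 4 * 0.04 * 140  # (5 - b)^2 - 22.4
    if disc < 0:
        raise ValueError(f"no equilibrium at I = 0 for b = {b}")
    root = math.sqrt(disc)
    # Lower root: resting potential (stable for the parameter sets in this notebook)
    # Upper root: unstable equilibrium that acts as a firing threshold
    return (-p - root) / 0.08, (-p + root) / 0.08


for b in (0.20, 0.25, 0.26):
    rest, upper = resting_potentials(b)
    print(f"b = {b:.2f}: rest = {rest:.1f} mV, unstable equilibrium = {upper:.1f} mV")
```

Larger $b$ raises the resting potential, to roughly $-64.4$ mV for $b=0.25$ and $-62.5$ mV for $b=0.26$.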



In [None]:
class NeuronName(enum.Enum):
    RS = "Regular Spiking"
    IB = "Intrinsically Bursting"
    CH = "Chattering"
    FS = "Fast Spiking"
    LTS = "Low-Threshold Spiking"
    TC1 = "Thalamo-Cortical (Normal)"
    TC2 = "Thalamo-Cortical (Burst)"
    RZ = "Resonator"
    
class NeuronType(enum.Enum):
    RS = "Excitatory"
    IB = "Excitatory"
    CH = "Excitatory"
    FS = "Inhibitory"
    LTS = "Inhibitory"
    TC1 = "Other"
    TC2 = "Other"
    RZ = "Other"

@dataclass(frozen=True, slots=True, eq=True)
class Neuron:
    """
    Neurons are constant abstractions for the neuroscience case-study. The different
    parameters are associated with the individual behaviours of each kind which impacts
    simulated membrane potential and recovery variable changes over time.
    """

    name: NeuronName = field(metadata={"desc": "Name of the neuron."})
    dtype: NeuronType = field(metadata={"desc": "Type of the neuron."})
    a: float = field(metadata={"desc": "Time-scale for the neuron recovery variable."})
    b: float = field(metadata={"desc": "Sensitivity for the neuron recovery variable."})
    c: float = field(metadata={"desc": "Reset value for the membrane potential."})
    d: float = field(metadata={"desc": "After-spike increment of the recovery variable."})
    r: float = field(metadata={"desc": "Resting value for the membrane potential."})

simulation_total_time: int = 1000
simulation_time_step: float = 0.1
steps: int = int(simulation_total_time / simulation_time_step)
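One subtlety of `NeuronType` as defined above: several members share a value, and Python's `enum` treats a later member with a duplicate value as an alias of the first. The hypothetical, trimmed-down `ExampleType` below shows the effect without shadowing the notebook's class. The notebook only reads `.value`, so its results are unaffected, but code that compares members or reads `.name` would see `RS` where `IB` was written.

```python
import enum


class ExampleType(enum.Enum):
    RS = "Excitatory"
    IB = "Excitatory"   # duplicate value: IB becomes an alias of RS
    FS = "Inhibitory"


print(ExampleType.IB is ExampleType.RS)  # aliases resolve to the first member
print(ExampleType.IB.name)               # name of the canonical member
print(list(ExampleType))                 # iteration skips aliases
```

Decorating such a class with `@enum.unique` turns this silent aliasing into a `ValueError` at definition time, which is a reasonable choice if one member per neuron class is intended.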

In [None]:
class NeuronSpikingSimulation:
    NEURON_INITIAL_RECOVERY_VARIABLE: float = 0.0
    NEURON_THRESHOLD: float = 30.0

    @staticmethod
    def dvdt(v: float, u: float, I: float) -> float:
        """
        Update function for a time-step delta of the neuron membrane potential based on
        the previous value. 
        
        Args:
            v (float): Previously updated membrane potential.
            u (float): Previously updated recovery variable.
            I (float): Post-synaptic input current.

        Returns:
            float: New membrane potential.
        """
        return (0.04 * v**2) + (5 * v) + 140 - u + I

    @staticmethod
    def dudt(neuron: Neuron, v: float, u: float) -> float:
        """
        Update function for a time-step delta of the neuron recovery variable based on
        the previous value. 

        Args:
            neuron (Neuron): Neuron container for necessary constant parameters.
            v (float): Previously updated membrane potential.
            u (float): Previously updated recovery variable.

        Returns:
            float: Delta time-step recovery from Izhikevich's formula.
        """
        return neuron.a * (neuron.b * v - u)

    @staticmethod
    def simulate(
        neuron: Neuron, 
        I: np.ndarray, 
        simulation_total_time: int = simulation_total_time, 
        simulation_time_step: float = simulation_time_step,
    ) -> pd.DataFrame:
        """
        Simulate the changing membrane potential and recovery variable for an individual
        neuron and given time-steps/input currents. 

        Args:
            neuron (Neuron): Parameterised neuron for simulating spiking behaviour.
            I (np.ndarray): Array of time-step input currents to simulate.
            simulation_total_time (int, optional): Milliseconds time. Defaults to 1000.
            simulation_time_step (float, optional): Millisecond step. Defaults to 0.1.
        Returns:
            pd.DataFrame: DataFrame of simulation results. Rows correspond to time step
                with columns for simulation time, membrane potential, and recovery.
        """

        # Calculate number of simulation steps
        steps = int(simulation_total_time / simulation_time_step)
        assert I.size == steps

        # Setup containers for recording time-step values
        T = np.zeros(steps, dtype=np.float32)
        v_arr = np.zeros(steps, dtype=np.float32)
        u_arr = np.zeros(steps, dtype=np.float32)
        v_arr[0] = neuron.r
        u_arr[0] = neuron.r * neuron.b

        # Cycle through the simulation steps and perform updates
        desc = f"Neuron Spiking Simulation ({neuron.name.value})"
        for index in tqdm.tqdm(range(1, steps), desc=desc):
            # Obtain previous variable values
            v_prev = v_arr[index - 1]
            u_prev = u_arr[index - 1]

            # Apply the after-spike reset if the previous step reached the threshold
            if v_prev >= NeuronSpikingSimulation.NEURON_THRESHOLD:
                v_arr[index] = neuron.c
                u_arr[index] = u_prev + neuron.d

            # Otherwise integrate both variables with a forward-Euler step
            else:
                dv = NeuronSpikingSimulation.dvdt(v_prev, u_prev, I[index])
                du = NeuronSpikingSimulation.dudt(neuron, v_prev, u_prev)
                v_arr[index] = v_prev + simulation_time_step * dv
                u_arr[index] = u_prev + simulation_time_step * du

            # Update the time array
            T[index] = T[index - 1] + simulation_time_step

        return pd.DataFrame(
            {
                "Time": T, 
                "Potential": v_arr, 
                "Recovery": u_arr, 
                "Current": I
            }
        )
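The traces are plotted below but never quantified. A minimal sketch for counting spikes from the `Time` and `Potential` columns returned by `NeuronSpikingSimulation.simulate`, assuming a spike is registered whenever the potential rises to the 30 mV threshold; `spike_times` and `firing_rate_hz` are hypothetical helpers, not part of the original analysis.

```python
import numpy as np


def spike_times(time: np.ndarray, potential: np.ndarray,
                threshold: float = 30.0) -> np.ndarray:
    """Times at which the potential rises to or above the threshold."""
    above = np.asarray(potential) >= threshold
    # A spike starts where a sample is above threshold and the previous one is not
    onsets = np.flatnonzero(above[1:] & ~above[:-1]) + 1
    if above.size and above[0]:
        onsets = np.concatenate([[0], onsets])
    return np.asarray(time)[onsets]


def firing_rate_hz(spikes: np.ndarray, duration_ms: float) -> float:
    """Mean firing rate in Hz for a spike train observed over duration_ms."""
    return len(spikes) / (duration_ms / 1000.0)
```

For one neuron's rows of the aggregated results, e.g. `df = res[res.Neuron == "Regular Spiking"]`, calling `spike_times(df["Time"].to_numpy(), df["Potential"].to_numpy())` lists its spike times, and passing those to `firing_rate_hz` with `1000` gives its mean rate over the run.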

In [None]:
neurons: list[Neuron] = [
    Neuron(NeuronName.RS,   NeuronType.RS,  a=0.02, b=0.20, c=-65, d=8.00, r=-70.0),
    Neuron(NeuronName.IB,   NeuronType.IB,  a=0.02, b=0.20, c=-55, d=4.00, r=-70.0),
    Neuron(NeuronName.CH,   NeuronType.CH,  a=0.02, b=0.20, c=-50, d=2.00, r=-50.0),
    Neuron(NeuronName.FS,   NeuronType.FS,  a=0.10, b=0.20, c=-65, d=2.00, r=-70.0),
    Neuron(NeuronName.LTS,  NeuronType.LTS, a=0.02, b=0.25, c=-65, d=2.00, r=-65.0),
    Neuron(NeuronName.TC1,  NeuronType.TC1, a=0.02, b=0.25, c=-65, d=0.05, r=-63.0),
    Neuron(NeuronName.TC2,  NeuronType.TC2, a=0.02, b=0.25, c=-65, d=0.05, r=-87.0),
    Neuron(NeuronName.RZ,   NeuronType.RZ,  a=0.10, b=0.26, c=-65, d=2.00, r=-65.0),
]

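I: int = 10   # Base input current step
D: int = 250  # Stimulus onset in simulation steps (250 * 0.1 ms = 25 ms)

def Iarray(cur: list[tuple[float, int]]) -> np.ndarray:
    """Build a piecewise-constant current array from (value, n_steps) pairs."""
    return np.concatenate([np.repeat(n, l) for n, l in cur])
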
activations: list[np.ndarray] = [
    Iarray([(0, D), (I, steps - D)]), 
    Iarray([(0, D), (I, steps - D)]), 
    Iarray([(0, D), (I, steps - D)]), 
    Iarray([(0, D), (I, steps - D)]), 
    Iarray([(0, D), (I, steps - D)]), 
    Iarray([(0, D), (I * .5, steps - D)]), 
    Iarray([(-I, D), (0, steps - D)]), 
    Iarray([(0, 300), (I*.5,  500), (I * 1, 40),  (I*.5, steps - 840)]), 
]

# Simulate each neuron with its input current and aggregate the results
frames = []
for neuron, current in zip(neurons, activations):
    df = NeuronSpikingSimulation.simulate(neuron, current)
    df["Neuron"] = neuron.name.value
    df["Type"] = neuron.dtype.value
    frames.append(df)
res = pd.concat(frames, axis=0, ignore_index=True)

# Plot the output spiking behaviour on a single plot
fig, axes = plt.subplots(nrows=2, ncols=4, sharex=True, sharey=True)
for i in range(len(neurons)):
    # Define the axes of interest for the given neuron
    row, col = i // 4, i % 4
    axes1 = axes[(row, col)]
    axes2 = axes1.twinx()
    
    # Plot the membrane potential and input current for this neuron
    df = res[res.Neuron == neurons[i].name.value]
    sns.lineplot(ax=axes1, data=df, x="Time", y="Potential", linewidth=.25)
    sns.lineplot(ax=axes2, data=df, x="Time", y="Current",   linewidth=.50, color="w")
    axes2.plot([0, 1000], [0, 0], linestyle="--", linewidth=.50, color="w")
    axes1.set_xlim((0, 200))
    axes1.set_ylim((-100, 70))
    axes2.set_ylim((-10, 175))
    axes1.set_title(neurons[i].name.value)
    
    # Clear per-axis labels (shared figure labels are added below) and show the
    # current tick labels only on the last column
    axes1.set_ylabel("")
    axes1.set_xlabel("")
    axes2.set_ylabel("")
    if col != 3:
        axes2.tick_params(axis="y", labelright=False)
fig.text(0.500, 0.04, "Elapsed Time (ms)", ha="center", rotation=0)
fig.text(0.090, 0.50, "Membrane Potential (mV)", va="center", rotation=90)
fig.text(0.925, 0.50, "Post-Synaptic Input Current", va="center", rotation=270)
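Each entry of `activations` is a piecewise-constant current built from `(value, n_steps)` pairs, and the segment lengths must add up to `steps` or the `assert` in `simulate` fails. The sketch below mirrors that construction under a hypothetical name, `piecewise_current`, checks the total length up front, and also reports when each segment begins in milliseconds, assuming the 0.1 ms time step; for example, the standard onset of `D = 250` steps corresponds to 25 ms.

```python
import numpy as np


def piecewise_current(segments: list[tuple[float, int]], steps: int,
                      dt: float = 0.1) -> tuple[np.ndarray, np.ndarray]:
    """Piecewise-constant current from (value, n_steps) pairs, plus segment onsets in ms."""
    lengths = np.array([n for _, n in segments])
    if lengths.sum() != steps:
        raise ValueError(f"segments cover {lengths.sum()} steps, expected {steps}")
    current = np.concatenate([np.repeat(float(value), n) for value, n in segments])
    # Each segment starts where the previous ones end, converted from steps to ms
    onsets_ms = np.concatenate([[0], np.cumsum(lengths)[:-1]]) * dt
    return current, onsets_ms


n_steps = 10_000  # 1000 ms at 0.1 ms per step, as in the notebook
step_current, step_onsets = piecewise_current([(0, 250), (10, n_steps - 250)], n_steps)
# Resonator protocol: rest, weak input, a brief stronger pulse, then weak input again
rz_current, rz_onsets = piecewise_current(
    [(0, 300), (5, 500), (10, 40), (5, n_steps - 840)], n_steps)
print(step_onsets, rz_onsets)
```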

## Simulation 2: Cortex Synchronisation

The goal of this simulation is to model a network of excitatory and inhibitory neurons and show how their spiking synchronises over time. Synchronisation appears as rhythmic, roughly sinusoidal fluctuations in the density of firings. The behaviour of the simulation depends on several constants, which are set as follows for this experiment, where $\mathbb{U}[0,1]$ is a uniform random draw made once per neuron and shared by all of that neuron's parameters:

- $N_{excitatory} = 800$
- $N_{inhibitory} = 200$
- $a_{excitatory} = 0.02$
- $a_{inhibitory} = 0.02 + 0.08 \cdot \mathbb{U}[0,1]$
- $b_{excitatory} = 0.2$
- $b_{inhibitory} = 0.25 - 0.05 \cdot \mathbb{U}[0,1]$
- $c_{excitatory} = -65 + 15 \cdot \mathbb{U}[0,1]^2$
- $c_{inhibitory} = -65$
- $d_{excitatory} = 8 - 6 \cdot \mathbb{U}[0,1]^2$
- $d_{inhibitory} = 2$
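These ranges interpolate between the classes of Simulation 1: an excitatory draw of 0 gives Regular Spiking values ($c=-65$, $d=8$) and a draw of 1 gives Chattering values ($c=-50$, $d=2$), while an inhibitory draw of 0 gives Low-Threshold Spiking ($a=0.02$, $b=0.25$) and a draw of 1 gives Fast Spiking ($a=0.1$, $b=0.2$). Squaring the excitatory draw biases that population towards Regular Spiking cells. A minimal sketch of the sampling, assuming NumPy's seeded `default_rng` for reproducibility (the notebook itself uses the global `np.random` state); `sample_parameters` is a hypothetical helper.

```python
from typing import Optional

import numpy as np


def sample_parameters(Ne: int = 800, Ni: int = 200, seed: Optional[int] = None):
    """Per-neuron (a, b, c, d): Ne excitatory neurons followed by Ni inhibitory ones."""
    rng = np.random.default_rng(seed)
    re, ri = rng.random(Ne), rng.random(Ni)  # one U[0, 1) draw per neuron
    a = np.concatenate([np.full(Ne, 0.02), 0.02 + 0.08 * ri])
    b = np.concatenate([np.full(Ne, 0.2), 0.25 - 0.05 * ri])
    c = np.concatenate([-65 + 15 * re**2, np.full(Ni, -65.0)])
    d = np.concatenate([8 - 6 * re**2, np.full(Ni, 2.0)])
    return a, b, c, d


a, b, c, d = sample_parameters(seed=0)
print(c[:800].min(), c[:800].max(), d[:800].min(), d[:800].max())
```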


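The same differential equations for $\hat{v}$ and $\hat{u}$ as in the previous simulation are used. Every neuron starts at $\hat{v}=-65$ and $\hat{u}=-65b$. On each time step a neuron receives random thalamic input drawn from a standard normal distribution, scaled by 5 for excitatory and 2 for inhibitory neurons, plus the sum of the synaptic weights from every neuron that fired on that step; synaptic weights are drawn independently per synapse as $0.5 \cdot \mathbb{U}[0,1]$ for excitatory and $-\mathbb{U}[0,1]$ for inhibitory sources. The firing threshold remains $30$ mV. The simulation advances in 1 ms steps; within each step $\hat{v}$ is updated with two 0.5 ms sub-steps for numerical stability, while $\hat{u}$ takes a single 1 ms step.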
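A vectorised sketch of one 1 ms network update as described above, generalised to a configurable number of sub-steps for $\hat{v}$. The helper name `network_update` and the `n_substeps` parameter are illustrative assumptions; with `n_substeps=2` it reproduces the two half-step lines in the simulation code. The recovery variable still takes one full step, using the already-updated $\hat{v}$.

```python
import numpy as np


def network_update(v: np.ndarray, u: np.ndarray, I: np.ndarray,
                   a: np.ndarray, b: np.ndarray, n_substeps: int = 2,
                   dt_ms: float = 1.0) -> tuple[np.ndarray, np.ndarray]:
    """One dt_ms update: v in n_substeps forward-Euler sub-steps, then u in one step."""
    h = dt_ms / n_substeps
    for _ in range(n_substeps):
        # u and I are held fixed across the sub-steps, as in the notebook
        v = v + h * (0.04 * v**2 + 5 * v + 140 - u + I)
    u = u + dt_ms * a * (b * v - u)  # uses the already-updated v
    return v, u
```

With `n_substeps=1` this becomes a plain 1 ms Euler step; smaller sub-steps reduce the error of the explicit update where the quadratic term makes $\hat{v}$ change fastest, as it approaches a spike.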

In [None]:
class CortexSynchronisationSimulation:
    @staticmethod
    def simulate(Ne: int = 800, Ni: int = 200, ms_total: int = 1000) -> pd.DataFrame:
        """
        Simulate a randomly connected network of excitatory and inhibitory neurons to
        show how their firing synchronises over time. A fixed 1ms time-step is used.

        Args:
            Ne (int, optional): Number of excitatory neurons. Defaults to 800.
            Ni (int, optional): Number of inhibitory neurons. Defaults to 200.
            ms_total (int, optional): Simulation time in milliseconds. Defaults to 1000.

        Returns:
            pd.DataFrame: One row per spike, holding the spike time and the index of
                the neuron that fired.
        """
        # Draw one uniform random value per excitatory and per inhibitory neuron
        re = np.random.random(Ne)
        ri = np.random.random(Ni)

        # Initialise the heterogeneous (a, b, c, d) parameters for each group
        a = np.concatenate([0.02 * np.ones(Ne), 0.02 + 0.08 * ri])
        b = np.concatenate([0.2  * np.ones(Ne), 0.25 - 0.05 * ri])
        c = np.concatenate([-65 + 15 * re ** 2, -65 * np.ones(Ni)])
        d = np.concatenate([8 - 6 * re ** 2, 2 * np.ones(Ni)])

        # Synaptic weights: column j holds the weights from neuron j onto every neuron,
        # positive for excitatory sources and negative for inhibitory sources
        S = np.concatenate([0.5 * np.random.random((Ne + Ni, Ne)),
                            -1 * np.random.random((Ne + Ni, Ni))], axis=1)

        # Initialise the membrane potential and recovery variable
        v = -65 * np.ones(Ne + Ni)
        u = b * v

        # Simulate for the given time period, recording (time, neuron) pairs for spikes
        firings = []
        for t in range(ms_total):
            # Random thalamic input current for this time-step
            I = np.concatenate([5 * np.random.normal(size=Ne),
                                2 * np.random.normal(size=Ni)])

            # Record neurons at or above threshold and apply the after-spike reset
            fired = np.where(v >= 30)[0]
            firings.append(np.stack([np.full_like(fired, t), fired], axis=1))
            v[fired] = c[fired]
            u[fired] = u[fired] + d[fired]

            # Add synaptic input from the fired neurons, then update v in two 0.5ms
            # sub-steps (for numerical stability) and u in a single 1ms step
            I = I + S[:, fired].sum(axis=1)
            v = v + 0.5 * (0.04 * v ** 2 + 5 * v + 140 - u + I)
            v = v + 0.5 * (0.04 * v ** 2 + 5 * v + 140 - u + I)
            u = u + a * (b * v - u)
        return pd.DataFrame(np.concatenate(firings, axis=0),
                            columns=["Time (ms)", "Neuron Number"])

In [None]:
# Plot the spike raster with the population firing density along the time axis
res = CortexSynchronisationSimulation.simulate()
g = sns.jointplot(data=res, 
                  x="Time (ms)", 
                  y="Neuron Number", 
                  s=.75, 
                  marginal_kws=dict(bins=1000, fill=True), 
                  legend=False)
# Dashed line separating excitatory (below 800) from inhibitory neurons
g.ax_joint.axhline(800, linestyle="--")
g.ax_marg_y.remove()
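The marginal histogram, with one bin per millisecond, is the population spike count over time, and its oscillations are the synchronisation this simulation is meant to show. A minimal sketch for summarising that rhythm with a single number, assuming the peak of the spectrum of the mean-subtracted per-ms spike count (computed with NumPy's FFT) is an adequate summary; `dominant_frequency` is a hypothetical helper, not part of the original analysis.

```python
import numpy as np


def dominant_frequency(spike_times_ms, ms_total: int = 1000):
    """Population spike count per 1 ms bin and the frequency (Hz) of its strongest oscillation."""
    counts = np.bincount(np.asarray(spike_times_ms, dtype=int), minlength=ms_total)[:ms_total]
    spectrum = np.abs(np.fft.rfft(counts - counts.mean()))
    freqs = np.fft.rfftfreq(ms_total, d=1e-3)  # 1 ms bins = 1e-3 s sample spacing
    peak = 1 + np.argmax(spectrum[1:])  # skip the zero-frequency bin
    return counts, freqs[peak]
```

Applied to the simulated firings, `counts, f = dominant_frequency(res["Time (ms)"])` gives the per-ms spike count and its peak frequency in Hz; because the network is drawn at random, the value varies between runs.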