<a href="https://colab.research.google.com/github/MANicholson/Modelling-and-Simulation-of-a-CMOS-Synapse-Implementing-Two-phase-Plasticity/blob/main/Differential_equation_based_code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Imports and installs**

In [None]:
!pip install equinox

# Imports
import jax
import random
import jax.numpy as jnp
import matplotlib.pyplot as plt
import equinox as eqx

Collecting equinox
  Downloading equinox-0.11.4-py3-none-any.whl (175 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m175.2/175.2 kB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
Collecting jaxtyping>=0.2.20 (from equinox)
  Downloading jaxtyping-0.2.30-py3-none-any.whl (41 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.9/41.9 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
Collecting typeguard==2.13.3 (from jaxtyping>=0.2.20->equinox)
  Downloading typeguard-2.13.3-py3-none-any.whl (17 kB)
Installing collected packages: typeguard, jaxtyping, equinox
Successfully installed equinox-0.11.4 jaxtyping-0.2.30 typeguard-2.13.3


# **Differential equations**

In [None]:
class full_synapse_diff(eqx.Module):
    """Two-phase (early/late) plasticity model of one synapse, forward-Euler integrated.

    Calling the module on a pair of input trains integrates, with step ``dt``:

    * the calcium trace ``c``: decays with time constant ``tau_c`` and is
      incremented by ``c_pre`` / ``c_post`` times the pre-/postsynaptic input;
    * the early-phase weight ``h``: relaxes towards ``h_0``, grows towards
      ``h_max`` while ``c > theta_p`` and shrinks while ``c > theta_d``;
    * the late-phase variable ``z``: moves towards ``z_max`` (``z_min``) while
      ``h`` is above (below) ``h_0`` by more than ``theta_tag``, gated by a
      protein availability that switches to ``alpha`` once
      ``|h - h_0| > theta_pro``;

    and combines them into the total weight ``w = h + h_init * z``.
    """

    # Calcium constants --------------------------------------------------------
    c_pre: float  # calcium increment per unit of presynaptic input
    c_post: float  # calcium increment per unit of postsynaptic input
    tau_c: float  # calcium decay time constant
    dt: float  # Euler step size (also the spacing of the returned time axis)
    c_0: float  # initial calcium value
    # Early-phase constants ----------------------------------------------------
    tau_h: float
    h_max: float  # value h is potentiated towards
    gamma_p: float  # potentiation rate
    gamma_d: float  # depression rate
    theta_p: float  # calcium threshold for potentiation
    theta_d: float  # calcium threshold for depression
    h_0: float  # resting value h relaxes towards
    h_init: float  # initial h; may be None, in which case h_0 is used
    # Late-phase constants -----------------------------------------------------
    tau_z: float
    z_max: float  # upper target of z
    z_min: float  # lower target of z
    theta_tag: float  # |h - h_0| margin required to move z
    alpha: float  # protein availability once synthesis is triggered
    theta_pro: float  # |h - h_0| threshold that triggers protein synthesis
    z_0: float  # initial late-phase value

    def __init__(self, c_pre, c_post, tau_c, dt, c_0, tau_h, h_max, gamma_p, gamma_d, theta_p, theta_d, h_0, h_init, tau_z, z_max, z_min, theta_tag, alpha, theta_pro, z_0):
        # Calcium constants ----------------------------------------------------
        self.c_pre = c_pre
        self.c_post = c_post
        self.tau_c = tau_c
        self.dt = dt
        self.c_0 = c_0
        # Early-phase constants ------------------------------------------------
        self.tau_h = tau_h
        self.h_max = h_max
        self.gamma_p = gamma_p
        self.gamma_d = gamma_d
        self.theta_p = theta_p
        self.theta_d = theta_d
        self.h_0 = h_0
        self.h_init = h_init
        # Late-phase constants -------------------------------------------------
        self.tau_z = tau_z
        self.z_max = z_max
        self.z_min = z_min
        self.theta_tag = theta_tag
        self.alpha = alpha
        self.theta_pro = theta_pro
        self.z_0 = z_0

    def __call__(self, input_):
        """Simulate the synapse for a sequence of inputs.

        Args:
            input_: Tuple ``(pre, post)`` of equal-length 1-D arrays. At step
                ``k`` the calcium trace is increased by
                ``pre[k] * c_pre + post[k] * c_post`` (0/1 entries mark spikes).

        Returns:
            Tuple ``(c_ji, h_ji, z_ji, w_ji, time_array)``: the calcium,
            early-phase, late-phase and total-weight traces, each of shape
            ``(T, 1)`` for ``T`` input steps, and ``time_array = arange(T) * dt``.
            Each trace holds the state *after* each step; the initial values
            are not included.
        """

        # Calcium part ---------------------------------------------------------
        c_ji_0 = jnp.zeros((1,)) + self.c_0  # Initial state, shape (1,)

        def differential_eq_calcium(c_ji, input_):
            t_spike_pre, t_spike_post = input_

            # Exponential decay plus instantaneous jumps proportional to the inputs.
            dc_dt = -c_ji / self.tau_c
            c_pre_increase = t_spike_pre * self.c_pre
            c_post_increase = t_spike_post * self.c_post

            c_ji_new = c_ji + dc_dt * self.dt + c_pre_increase + c_post_increase

            return c_ji_new, c_ji_new

        _, c_ji = jax.lax.scan(differential_eq_calcium, c_ji_0, input_)

        time_array = jnp.arange(len(c_ji)) * self.dt

        # Early-phase part -----------------------------------------------------
        # eqx.Module instances are frozen after __init__, so the None -> h_0
        # fallback is resolved into a local instead of assigning self.h_init
        # (which would raise at call time).
        h_init = self.h_0 if self.h_init is None else self.h_init
        h_ji_0 = jnp.zeros((1,)) + h_init

        def differential_eq_early_phase(h_ji, c_ji):
            # Early-phase weight decay towards h_0 (0.1 is a hard-coded
            # relaxation factor):
            decay_term = (1 / self.tau_h) * 0.1 * (self.h_0 - h_ji)

            # Early-phase LTP (active while calcium exceeds theta_p)
            ltp_term = (1 / self.tau_h) * self.gamma_p * (self.h_max - h_ji) * (c_ji > self.theta_p)

            # Early-phase LTD (active while calcium exceeds theta_d)
            ltd_term = (1 / self.tau_h) * self.gamma_d * h_ji * (c_ji > self.theta_d)

            dh_dt = decay_term + ltp_term - ltd_term

            # Forward-Euler update
            h_ji_new = h_ji + dh_dt * self.dt

            return h_ji_new, h_ji_new

        _, h_ji = jax.lax.scan(differential_eq_early_phase, h_ji_0, c_ji)

        # Late-phase part ------------------------------------------------------
        z_ji_0 = jnp.zeros((1,)) + self.z_0
        # Protein availability starts at 0. A jnp scalar (rather than a Python
        # float) keeps the scan carry dtype fixed across steps.
        p_i_0 = jnp.zeros(())

        def differential_eq_late_phase(carry, h_ji):
            z_ji, p_i = carry

            # Early-phase change of neuron i
            epsilon_hi = jnp.abs(h_ji - self.h_0)

            # Protein synthesis latches on: once the early-phase change exceeds
            # theta_pro, p_i becomes alpha and stays there (no decay is modelled).
            p_i = jnp.where(jnp.squeeze(epsilon_hi > self.theta_pro), self.alpha, p_i)

            # Potentiation: z -> z_max while h exceeds h_0 by more than theta_tag
            pot_term = p_i * (self.z_max - z_ji) * ((h_ji - self.h_0 - self.theta_tag) > 0)

            # Depression: z -> z_min while h is below h_0 by more than theta_tag
            dep_term = p_i * (z_ji - self.z_min) * ((self.h_0 - h_ji - self.theta_tag) > 0)

            dz_dt = (1 / self.tau_z) * (pot_term - dep_term)

            # Forward-Euler update
            z_ji_new = z_ji + dz_dt * self.dt

            return (z_ji_new, p_i), z_ji_new

        # Only h drives the late phase; the calcium trace is not needed here.
        _, z_ji = jax.lax.scan(differential_eq_late_phase, (z_ji_0, p_i_0), h_ji)

        # Total synaptic weight ------------------------------------------------
        # w = h + h_init * z, element-wise over time. There is no recurrence,
        # so no scan is needed. Result shape (T, 1).
        # NOTE(review): the late-phase term is scaled by the *initial* early-phase
        # weight h_init rather than h_0; identical when h_init == h_0 — confirm.
        w_ji = h_ji + h_ji_0 * z_ji

        return c_ji, h_ji, z_ji, w_ji, time_array


# **Testing**

## Defining constants (for the differential equation code)

In [None]:
# Constants (units in brackets where known). "ALTERED FROM x" marks values that
# differ from the original source value x.

# tc_delay = 0.0188  # Delay of postsynaptic calcium influx after presynaptic spike [s]
c_pre = 0.6  # Presynaptic calcium contribution [in vivo adjusted]
c_post = 0.1655  # Postsynaptic calcium contribution [in vivo adjusted]
tau_c = 0.0488  # Calcium time constant [s]
tau_h = 688.4  # Early-phase time constant [s] (original value, not altered)
tau_p = 60*60  # Protein time constant [s] (not passed to full_synapse_diff)
tau_z = 60*60  # Late-phase time constant [s]
# nu_th = 40  # Firing rate threshold for LTP induction [Hz]
gamma_p = 1645.6  # Potentiation rate (original value, not altered)
gamma_d = 313.1  # Depression rate (original value, not altered)
theta_p = 0.3  # ALTERED FROM 3  # Calcium threshold for potentiation
theta_d = 0.2  # ALTERED FROM 1.2  # Calcium threshold for depression
# sigma_pl = 2.90436 * 10**(-3)  # Standard deviation for plasticity fluctuations [V]
alpha = 1.0  # Protein synthesis rate
theta_pro = 0.0023  # Protein synthesis threshold [V] (commented-out alternative: 2.10037 * 10**(-3))
theta_tag = 0.640149 * 10**(-4)  # Tagging threshold [V]
h_0 = 4.20075 * 10**(-3)  # Median initial excitatory→excitatory coupling strength [V]
h_init = h_0  # Initial early-phase weight [V]
z_0 = 0.1  # Initial late-phase variable
# Combining Jorge's thesis and Jannik's paper

h_max = 10 * 10**(-3)  # The maximum value of the early-phase variable [V]
z_max = 1  # The maximum value of the late-phase variable
z_min = - 0.5  # The minimum value of the late-phase variable

# beta = 4.6675*10**(-3)

## Test input

### Test input generator

In [None]:
import jax
import jax.numpy as jnp
import jax.random as random

def generate_binary_array(length, ratio, key):
    """
    Generate a shuffled array of 0s and 1s containing a given fraction of 1s.

    Args:
        length (int): Length of the array (must be >= 0).
        ratio (float): Fraction of entries that are 1, in [0, 1]. Despite the
            name, this is ones / length, not ones / zeros.
        key (jax.random.PRNGKey): Random key for JAX's random number generator.

    Returns:
        jax.numpy.ndarray: Integer array of shape (length,) containing
        ``int(length * ratio)`` ones (after removing floating-point noise),
        in random order.

    Raises:
        ValueError: If ``length`` is negative or ``ratio`` is outside [0, 1].
    """
    if length < 0:
        raise ValueError(f"length must be non-negative, got {length}")
    if not 0.0 <= ratio <= 1.0:
        # ratio > 1 previously produced an array longer than `length`.
        raise ValueError(f"ratio must be in [0, 1], got {ratio}")

    # Round away float noise before truncating, so e.g. 100 * 0.29
    # (= 28.999999999999996) yields 29 ones rather than 28.
    num_ones = int(round(length * ratio, 9))

    # First num_ones entries are 1, the rest 0, in the default integer dtype
    # (also correct for length == 0, where a list literal would give float32).
    binary_array = (jnp.arange(length) < num_ones).astype(int)

    # Shuffle the array to randomize the distribution of 1s and 0s
    return random.permutation(key, binary_array)

# Example usage
# key, length and ratio are only consumed by the commented-out
# generate_binary_array(...) calls below; the spike trains actually used are
# the hand-written 0/1 lists (one entry per coarse time bin).
key = random.PRNGKey(0)
length = 100  # Length of the array
ratio = 0.1   # Fraction of entries that are 1 (ones / length)

pre_spikes =  [0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0] #generate_binary_array(length, ratio, key)

post_spikes = [0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0] #generate_binary_array(length, ratio, key)



In [None]:
# Pad each spike train with n_pad trailing silent bins (0s).
# NOTE: list.extend mutates the lists in place, so re-running this cell pads
# them again; re-run the cell that defines the lists first.
n_pad = 1  # Number of silent bins appended to each spike train
post_spikes.extend([0] * n_pad)
pre_spikes.extend([0] * n_pad)

print(post_spikes)
print(pre_spikes)
print(f"Length of post_spikes: {len(post_spikes)}")

# Coarse time axes: one point per spike bin.
t_array = jnp.linspace(0, 10, len(pre_spikes))  # spans 0-10 s
dt = 0.01  # Step [s] of the fine grid t (also passed to the model as dt)

t_array_c = jnp.linspace(0, 50, len(pre_spikes))  # spans 0-50 s
dt_c = 0.001  # Step [s] of the fine grid t_c


[0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0]
[0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]
Length of post_spikes: 41


In [None]:
# Resample the coarse spike trains onto a fine time grid t with step dt, so
# every spike lands on a multiple of dt.
t_max = t_array[-1]
t = jnp.arange(0, t_max + dt, dt)

# For every coarse time point, the index of the nearest fine-grid point (ties
# go to the lower index, as with jnp.argmin). Computed once and shared by both
# trains instead of re-searching the whole grid per spike in two Python loops.
nearest_idx = jnp.argmin(jnp.abs(t[None, :] - t_array[:, None]), axis=1)

# Scatter the 0/1 spike values onto an all-zero fine grid (cast to the grid's
# float dtype explicitly).
# NOTE: assumes distinct coarse points map to distinct fine indices (true while
# the coarse spacing exceeds dt); with collisions, JAX's .at[].set does not
# guarantee which value wins.
t_spike_pre_new = jnp.zeros_like(t).at[nearest_idx].set(
    jnp.asarray(pre_spikes, dtype=t.dtype)
)
t_spike_post_new = jnp.zeros_like(t).at[nearest_idx].set(
    jnp.asarray(post_spikes, dtype=t.dtype)
)

In [None]:
# Resample the coarse spike trains onto the finer time grid t_c with step dt_c,
# so every spike lands on a multiple of dt_c.
t_max_c = t_array_c[-1]
t_c = jnp.arange(0, t_max_c + dt_c, dt_c)

# Nearest fine-grid index for every coarse time point, computed once for both
# trains (a len(t_array_c) x len(t_c) distance matrix, a few MB here) instead
# of re-searching the whole grid per spike in two Python loops.
nearest_idx_c = jnp.argmin(jnp.abs(t_c[None, :] - t_array_c[:, None]), axis=1)

# Scatter the 0/1 spike values onto an all-zero fine grid (cast to the grid's
# float dtype explicitly).
# NOTE: assumes distinct coarse points map to distinct fine indices (true while
# the coarse spacing exceeds dt_c); with collisions, JAX's .at[].set does not
# guarantee which value wins.
t_spike_pre_new_c = jnp.zeros_like(t_c).at[nearest_idx_c].set(
    jnp.asarray(pre_spikes, dtype=t_c.dtype)
)
t_spike_post_new_c = jnp.zeros_like(t_c).at[nearest_idx_c].set(
    jnp.asarray(post_spikes, dtype=t_c.dtype)
)

In [None]:
# Build the synapse model with keyword arguments (robust to argument order) and
# run it on the spike trains resampled onto the dt grid.
synapseModel = full_synapse_diff(
    c_pre=c_pre,
    c_post=c_post,
    tau_c=tau_c,
    dt=dt,
    c_0=0,  # Initial calcium value
    tau_h=tau_h,
    h_max=h_max,
    gamma_p=gamma_p,
    gamma_d=gamma_d,
    theta_p=theta_p,
    theta_d=theta_d,
    h_0=h_0,
    h_init=h_init,
    tau_z=tau_z,
    z_max=z_max,
    z_min=z_min,
    theta_tag=theta_tag,
    alpha=alpha,
    theta_pro=theta_pro,
    z_0=z_0,
)

c_ji_eqx, h_ji_eqx, z_ji_eqx, w_ji_eqx, t_eqx = synapseModel(
    (jnp.array(t_spike_pre_new), jnp.array(t_spike_post_new))
)