# Spike-Timing-Dependent Plasticity (STDP)

## 🎯 Learning Objectives

By the end of this tutorial, you will be able to:

- Understand the biological basis of Spike-Timing-Dependent Plasticity (STDP).
- Implement the STDP learning rule in Python.
- Simulate a Leaky Integrate-and-Fire (LIF) neuron with plastic synapses.
- Demonstrate how STDP enables a neuron to learn repeated patterns in input spike trains.

## 📚 Prerequisites

- Basic understanding of the Leaky Integrate-and-Fire (LIF) neuron model.
- Familiarity with synaptic transmission.
- Basic Python and NumPy skills.

## Introduction

In the previous tutorials, we explored how neurons integrate inputs and generate spikes, and how synapses transmit signals. However, the strength of synapses is not fixed; it changes based on the activity of the pre- and post-synaptic neurons. This **synaptic plasticity** is the cellular basis of learning and memory.

**Hebbian learning** is often summarized as "Cells that fire together, wire together." **Spike-Timing-Dependent Plasticity (STDP)** is a biologically detailed form of Hebbian learning that depends on the precise timing of spikes:

- If the pre-synaptic neuron fires _before_ the post-synaptic neuron (causal relationship), the synapse is strengthened (**Long-Term Potentiation, LTP**).
- If the pre-synaptic neuron fires _after_ the post-synaptic neuron (acausal relationship), the synapse is weakened (**Long-Term Depression, LTD**).

In this tutorial, we will implement STDP and show how it allows a neuron to detect repeating patterns in a noisy input stream.


## Setup and Imports


In [None]:
import numpy as np
import plotly.graph_objects as go

from neuroai import plotting

# Set random seed for reproducibility
np.random.seed(42)

## 1. The STDP Learning Rule

The change in synaptic weight $\Delta w$ depends on the relative timing $\Delta t = t_{post} - t_{pre}$ between the post-synaptic spike and the pre-synaptic spike.

The standard STDP function is defined as:

$$
\Delta w(\Delta t) = \begin{cases}
A_+ e^{-\Delta t/\tau_+} & \text{if } \Delta t > 0 \quad (\text{Pre before Post} \rightarrow \text{LTP}) \\
-A_- e^{\Delta t/\tau_-} & \text{if } \Delta t < 0 \quad (\text{Post before Pre} \rightarrow \text{LTD})
\end{cases}
$$

Where:

- $A_+, A_-$ are the maximum amplitudes of potentiation and depression.
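- $\tau_+, \tau_-$ are the time constants of the STDP window.
- At $\Delta t = 0$ the formula does not define $\Delta w$. The implementation below returns $\Delta w = 0$ in that case.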
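Integrating the window over all $\Delta t$ gives a quick measure of its overall bias. For pre- and post-synaptic spike trains with no timing correlation, the expected weight drift is proportional to this area, $A_+\tau_+ - A_-\tau_-$. A negative area therefore means uncorrelated inputs are depressed on average. The sketch below checks the closed form numerically. The helper name `window_area` is ours, and the parameter values are assumed to match the `LIFNeuronSTDP` defaults defined later in this tutorial.

```python
import numpy as np


def window_area(A_plus, A_minus, tau_plus, tau_minus):
    """Closed-form integral of the exponential STDP window over all delta_t."""
    # LTP lobe: integral of A+ exp(-t/tau+) over t > 0 is A+ * tau+
    # LTD lobe: integral of -A- exp(t/tau-) over t < 0 is -A- * tau-
    return A_plus * tau_plus - A_minus * tau_minus


# Assumed parameter values (mirroring the LIFNeuronSTDP defaults)
A_plus, A_minus, tau_plus, tau_minus = 0.01, 0.012, 20.0, 20.0
closed = window_area(A_plus, A_minus, tau_plus, tau_minus)

# Numerical check: Riemann sum on a fine grid wide enough that the tails vanish
grid = np.linspace(-400.0, 400.0, 800_001)
step = grid[1] - grid[0]
dw_grid = np.where(
    grid > 0,
    A_plus * np.exp(-grid / tau_plus),
    -A_minus * np.exp(grid / tau_minus),
)
numerical = dw_grid.sum() * step

print(f"closed form: {closed:.4f}, numerical: {numerical:.4f}")
```

With these values both estimates come out at about -0.04, a slight net bias toward depression. Making $A_+ \tau_+$ exceed $A_- \tau_-$ flips the sign of the area.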

Let's visualize this function.


In [None]:
def stdp_window(delta_t, A_plus=0.01, A_minus=0.01, tau_plus=20, tau_minus=20):
    """
    Calculate the synaptic weight change based on the time difference.

    Args:
        delta_t (np.array): t_post - t_pre (ms)
        A_plus (float): Max potentiation
        A_minus (float): Max depression
        tau_plus (float): Time constant for LTP (ms)
        tau_minus (float): Time constant for LTD (ms)

    Returns:
        np.array: Weight change delta_w

    Examples:
        >>> delta_t = np.array([10, -10, 0])
        >>> dw = stdp_window(delta_t, A_plus=1.0, A_minus=1.0, tau_plus=10.0, tau_minus=10.0)
        >>> np.round(dw, 4)
        array([ 0.3679, -0.3679,  0.    ])
    """
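    # Cast to float so integer inputs (as in the doctest) don't truncate dw to 0
    delta_t = np.asarray(delta_t, dtype=float)
    dw = np.zeros_like(delta_t)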

    # LTP: Pre before Post (delta_t > 0)
    mask_ltp = delta_t > 0
    dw[mask_ltp] = A_plus * np.exp(-delta_t[mask_ltp] / tau_plus)

    # LTD: Post before Pre (delta_t < 0)
    mask_ltd = delta_t < 0
    dw[mask_ltd] = -A_minus * np.exp(delta_t[mask_ltd] / tau_minus)

    return dw


import doctest

doctest.testmod(verbose=True)

# Plotting
delta_t = np.linspace(-100, 100, 1000)
dw = stdp_window(delta_t)

plotting.plot_stdp_window(delta_t, dw)

## 2. Implementing STDP in a LIF Neuron

To simulate STDP efficiently, we use an **online** implementation with **synaptic traces**. Instead of iterating through all pairs of spikes, we maintain traces that decay over time:

- **Pre-synaptic trace ($x_i$)**: Increases when a pre-synaptic spike arrives, decays with $\tau_+$.
- **Post-synaptic trace ($y$)**: Increases when the post-synaptic neuron fires, decays with $\tau_-$.

**Update Rules:**

1. **On Pre-synaptic spike** (at synapse $i$):

   - Update trace: $x_i \leftarrow x_i + 1$
   - **LTD**: $w_i \leftarrow w_i - A_- \cdot y$ (Depress weight based on recent post-synaptic activity)

2. **On Post-synaptic spike**:
   - Update trace: $y \leftarrow y + 1$
   - **LTP**: For all synapses $i$, $w_i \leftarrow w_i + A_+ \cdot x_i$ (Potentiate weights based on recent pre-synaptic activity)
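For an isolated pre/post pair, the trace updates reproduce the exponential window exactly. At the moment of the post-synaptic spike, the pre-synaptic trace equals $e^{-\Delta t/\tau_+}$, so the LTP step is $A_+ e^{-\Delta t/\tau_+}$ (and symmetrically for LTD). The sketch below steps both traces on a 1 ms grid for a single pair and compares the result with the closed-form window. The function names are ours. It applies updates in the same order as `LIFNeuronSTDP.step`: decay first, then pre-synaptic updates, then post-synaptic updates.

```python
import numpy as np


def closed_form_window(delta_t, A_plus=0.01, A_minus=0.012, tau_plus=20.0, tau_minus=20.0):
    """Pair-based exponential window; delta_t = t_post - t_pre in ms."""
    if delta_t > 0:
        return A_plus * np.exp(-delta_t / tau_plus)
    if delta_t < 0:
        return -A_minus * np.exp(delta_t / tau_minus)
    return 0.0


def trace_stdp_pair(t_pre, t_post, A_plus=0.01, A_minus=0.012,
                    tau_plus=20.0, tau_minus=20.0, dt=1.0, n_steps=200):
    """Total weight change for one pre and one post spike, computed with traces.

    Spike times are step indices (equal to ms when dt = 1).
    """
    x, y, dw = 0.0, 0.0, 0.0
    for t in range(n_steps):
        # 1. Exponential decay of both traces over one time step
        x *= np.exp(-dt / tau_plus)
        y *= np.exp(-dt / tau_minus)
        # 2. Pre-synaptic spike: bump x, then depress by the post trace (LTD)
        if t == t_pre:
            x += 1.0
            dw -= A_minus * y
        # 3. Post-synaptic spike: bump y, then potentiate by the pre trace (LTP)
        if t == t_post:
            y += 1.0
            dw += A_plus * x
    return dw


for lag in (-30, -5, 5, 30):
    t_pre = 100
    print(lag, trace_stdp_pair(t_pre, t_pre + lag), closed_form_window(lag))
```

With several spikes, the traces simply add, which corresponds to summing the window over all pre/post pairs (all-to-all pairing). One subtlety is $\Delta t = 0$. Because pre-synaptic updates come first within a step, a simultaneous pair receives the full LTP step $A_+$ here, whereas `stdp_window` returns 0.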

Let's implement a simple LIF neuron with these plastic synapses.


In [None]:
class LIFNeuronSTDP:
    def __init__(
        self,
        n_inputs: int,
        tau_m: float = 20.0,
        v_rest: float = -65.0,
        v_reset: float = -65.0,
        v_thresh: float = -50.0,
        tau_plus: float = 20.0,
        tau_minus: float = 20.0,
        A_plus: float = 0.01,
        A_minus: float = 0.012,
        w_max: float = 1.0,
    ):
        # Neuron parameters
        self.tau_m = tau_m  # Membrane time constant (ms)
        self.v_rest = v_rest  # Resting potential (mV)
        self.v_reset = v_reset  # Reset potential (mV)
        self.v_thresh = v_thresh  # Threshold (mV)

        # STDP parameters
        self.tau_plus = tau_plus
        self.tau_minus = tau_minus
        self.A_plus = A_plus
        self.A_minus = A_minus
        self.w_max = w_max

        # State variables
        self.v = self.v_rest
        self.n_inputs = n_inputs
        self.weights = np.random.uniform(0.1, 0.5, n_inputs)  # Initial random weights

        # Traces
        self.x_trace = np.zeros(n_inputs)  # Pre-synaptic traces
        self.y_trace = 0.0  # Post-synaptic trace

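    def step(self, dt, input_spikes):
        """
        Simulate one time step.

        Args:
            dt (float): Time step (ms)
            input_spikes (np.array): Boolean array of size n_inputs, True if spike occurred

        Returns:
            tuple: (post_spike, v, weights), where post_spike (bool) is True if the
                neuron fired this step, v (float) is the membrane potential after the
                step (mV), and weights (np.array) is a copy of the current weights.
        """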
        # 1. Decay traces
        self.x_trace *= np.exp(-dt / self.tau_plus)
        self.y_trace *= np.exp(-dt / self.tau_minus)

        # 2. Handle Pre-synaptic spikes (LTD and Trace update)
        # Indices of neurons that spiked
        pre_spike_indices = np.where(input_spikes)[0]

        if len(pre_spike_indices) > 0:
            # Update pre-synaptic traces
            self.x_trace[pre_spike_indices] += 1.0

            # LTD: Weaken weights if post-synaptic neuron fired recently (high y_trace)
            # w = w - A_minus * y_trace
            self.weights[pre_spike_indices] -= self.A_minus * self.y_trace
            self.weights = np.clip(self.weights, 0, self.w_max)

        # 3. Update Membrane Potential (LIF dynamics)
        # I_syn = sum(w_i * spike_i) (simplified current injection)
        # In a real simulation, we might use conductance-based or current-based synapses with time constants.
        # Here we assume instantaneous current injection for simplicity.
        i_syn = np.sum(self.weights[input_spikes]) * 10.0  # Scaling factor for current

        # dV = (-(V - V_rest) + R*I) / tau_m * dt
        # Assuming R=1 for simplicity
        dv = (-(self.v - self.v_rest) + i_syn) / self.tau_m * dt
        self.v += dv

        # 4. Check for Post-synaptic spike
        post_spike = False
        if self.v >= self.v_thresh:
            post_spike = True
            self.v = self.v_reset

            # Update post-synaptic trace
            self.y_trace += 1.0

            # LTP: Strengthen weights if pre-synaptic neurons fired recently (high x_trace)
            # w = w + A_plus * x_trace
            self.weights += self.A_plus * self.x_trace
            self.weights = np.clip(self.weights, 0, self.w_max)

        return post_spike, self.v, self.weights.copy()
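
A useful back-of-the-envelope check on the `step` dynamics: with a constant drive $I$ (and $R = 1$, as in the code), the membrane relaxes toward $V_\infty = V_{rest} + I$ with time constant $\tau_m$. The neuron can therefore only reach threshold if $I > V_{thresh} - V_{rest} = 15$ mV. The sketch below treats the per-step `i_syn` as a constant mean drive. That is an assumption: real input is spiky, so fluctuations can also trigger spikes below this level. The mean weight of 0.3 (the midpoint of the initial uniform range) and the input rates are illustrative values taken from the experiment parameters, not simulation output.

```python
import numpy as np


def lif_steady_state(i_const, v_rest=-65.0):
    """Asymptotic membrane potential for a constant drive (R = 1, as in step())."""
    return v_rest + i_const


def time_to_threshold(i_const, tau_m=20.0, v_rest=-65.0, v_thresh=-50.0):
    """Continuous-time latency from rest to threshold for a constant drive.

    Solves v_rest + i_const * (1 - exp(-t / tau_m)) = v_thresh for t;
    returns inf if the steady state never reaches threshold.
    """
    gap = v_thresh - v_rest
    if i_const <= gap:
        return np.inf
    return tau_m * np.log(i_const / (i_const - gap))


# Illustrative mean drives (assumptions): mean weight 0.3, the x10 current
# scaling used in step(), 100 inputs, dt = 1 ms
mean_w, scale, n_in, dt_ms = 0.3, 10.0, 100, 1.0
for label, rate in [("background (10 Hz)", 0.01), ("pattern (80 Hz)", 0.08)]:
    i_mean = n_in * rate * dt_ms * mean_w * scale
    print(f"{label}: I = {i_mean:.1f}, V_inf = {lif_steady_state(i_mean):.1f} mV, "
          f"latency = {time_to_threshold(i_mean):.1f} ms")
```

Under these assumptions, the mean background drive settles about 3 mV above rest, well below threshold. The mean pattern drive (about 24) would cross threshold after roughly 20 ms, inside the 50 ms presentation.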

## 3. Experiment: Pattern Learning

We will now demonstrate how STDP allows a neuron to learn a repeating pattern hidden in noise.

**Setup:**

- **100 Input Neurons** connected to 1 Output Neuron.
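- **Input:** Outside the pattern, each input fires independently at 10 Hz (approximately Poisson noise).
- **Pattern:** Roughly every 250 ms (onset-to-onset intervals drawn uniformly from 200 to 299 ms), a specific "frozen" pattern of spikes lasting 50 ms is presented to the inputs.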
- **Goal:** The neuron should learn to recognize this pattern by strengthening the weights of synapses that participate in the pattern and weakening others.

We expect the neuron to eventually fire reliably when the pattern is presented.
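The input cell below draws each spike as an independent Bernoulli trial with probability `rate * dt` per 1 ms bin. Because a bin can hold at most one spike, this only approximates a Poisson process. The approximation is good when `rate * dt` is much smaller than 1 (here 0.01 for the background and 0.08 for the pattern). The sketch below is a quick sanity check that the empirical background rate matches the intended 10 Hz. It uses a separate generator so the tutorial's global seed is left untouched; the seed value is arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)  # independent generator; global np.random state unaffected
dt = 1.0                        # ms
rate = 0.01                     # spikes per ms (10 Hz)
n_steps, n_inputs = 10_000, 100

# One Bernoulli trial per (time bin, input)
spikes = rng.random((n_steps, n_inputs)) < rate * dt
empirical_hz = spikes.mean() / dt * 1000.0  # spikes per ms -> spikes per s
print(f"Empirical rate: {empirical_hz:.2f} Hz")
```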


In [None]:
# Simulation Parameters
dt = 1.0  # Time step (ms)
T = 10000  # Total duration (ms)
n_steps = int(T / dt)
n_inputs = 100
input_rate = 0.01  # Probability of spike per ms (10 Hz)
pattern_rate = 0.08  # Probability of spike in pattern (80 Hz)

# Generate Input Spikes
# 1. Background noise
input_spikes = np.random.rand(n_steps, n_inputs) < (input_rate * dt)

# 2. Create a frozen pattern (50 ms duration)
pattern_duration = 50
pattern_template = np.random.rand(pattern_duration, n_inputs) < (pattern_rate * dt)

# 3. Embed pattern at random intervals
pattern_times = []
current_time = 200
while current_time < n_steps - pattern_duration:
    input_spikes[current_time : current_time + pattern_duration] = pattern_template
    pattern_times.append(current_time)
    current_time += 200 + np.random.randint(0, 100)  # Next onset 200-299 ms later (~250 ms on average)

print(f"Pattern embedded {len(pattern_times)} times.")

plotting.plot_input_spike_train(input_spikes, pattern_times, pattern_duration, n_steps, dt)

In [None]:
# Initialize Neuron
neuron = LIFNeuronSTDP(n_inputs, A_plus=0.01, A_minus=0.012, w_max=0.5, v_thresh=-50)  # A_- slightly > A_+

# Recording arrays
v_rec = np.zeros(n_steps)
weights_rec = np.zeros((n_steps // 100, n_inputs))  # Record weights every 100 steps
output_spikes = []

# Run Simulation
print("Running simulation...")
for t in range(n_steps):
    spike, v, w = neuron.step(dt, input_spikes[t])

    v_rec[t] = v
    if spike:
        output_spikes.append(t)

    if t % 100 == 0:
        weights_rec[t // 100] = w

print(f"Simulation complete. Total output spikes: {len(output_spikes)}")

# Identify neurons that are active in the pattern
pattern_activity = np.sum(pattern_template, axis=0)
pattern_indices = np.where(pattern_activity > 0)[0]
non_pattern_indices = np.where(pattern_activity == 0)[0]

plotting.plot_stdp_simulation_results(
    v_rec,
    output_spikes,
    input_spikes,
    weights_rec,
    pattern_times,
    pattern_duration,
    pattern_indices,
    non_pattern_indices,
    dt,
)

## 4. 🎓 Exercises

Now it's your turn to explore the properties of STDP!

### Exercise 1: The Balance of Power

In the simulation above, $A_-$ is set slightly larger than $A_+$. This bias toward depression is often important for stability, keeping weights from running away to `w_max`.

**Task:**

1.  Modify the `LIFNeuronSTDP` initialization to set `A_plus` significantly larger than `A_minus` (e.g., `A_plus=0.015`, `A_minus=0.01`).
2.  Run the simulation again.
3.  Plot the weight evolution.

**Question:** What happens to the weights of the "noise" synapses (those not involved in the pattern)? Do they stay low?

### Exercise 2: Learning Speed

**Task:**

1.  Increase both `A_plus` and `A_minus` by a factor of 5.
2.  Run the simulation.

**Question:** Does the neuron learn the pattern faster? Is the final weight state stable, or does it fluctuate wildly?


In [None]:
# Exercise 1: Modify A_plus > A_minus

# TODO: Initialize neuron with new parameters (A_plus=0.015, A_minus=0.01)

# Record weights every 100 steps
weights_rec_ex1 = np.zeros((n_steps // 100, n_inputs))
# TODO: Run Simulation

# TODO: Plot the weight evolution

# Plot the weight evolution
plotting.plot_weight_evolution(
    weights_rec_ex1,
    pattern_indices,
    non_pattern_indices,
    n_steps,
    dt,
    title="Exercise 1: Weight Evolution (A+ > A-)",
)

In [None]:
# Exercise 2: High Learning Rate

# TODO: Initialize neuron with 5x learning rates

weights_rec_ex2 = np.zeros((n_steps // 100, n_inputs))
# TODO: Run simulation

# TODO: Plot the weight evolution

# Plot results
plotting.plot_weight_evolution(
    weights_rec_ex2,
    pattern_indices,
    non_pattern_indices,
    n_steps,
    dt,
    title="Exercise 2: Weight Evolution (High Learning Rate)",
)

### Exercise 3: Homeostasis via Synaptic Scaling

Hebbian learning can lead to "runaway" dynamics in which weights either saturate at the maximum or decay to zero. To prevent this, the brain uses **homeostatic mechanisms** to keep neuronal activity stable.

One such mechanism is **Synaptic Scaling**, in which all of a neuron's synapses are multiplicatively scaled up or down together. A common modeling simplification, used here, is to keep the total synaptic weight received by the neuron constant. This introduces competition: if one synapse gets stronger, others must get weaker.

**Task:**

1.  Modify the simulation loop to include a normalization step.
2.  After every weight update (or every $N$ steps), scale the weights so that their sum equals a target value $W_{target}$:
    $$ \mathbf{w} \leftarrow \mathbf{w} \frac{W_{target}}{\sum \mathbf{w}} $$
3.  Run the simulation with `A_plus` > `A_minus` (which caused instability in Exercise 1).
4.  Does synaptic scaling stabilize the learning?
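The normalization in step 2 is multiplicative. Every weight is multiplied by the same factor, so the ratios between synapses are preserved while the total is pinned to $W_{target}$. The sketch below is a minimal helper, assuming non-negative weights as in `LIFNeuronSTDP`. The name `scale_to_target` and the optional re-clipping to `w_max` are our choices, not part of the exercise statement.

```python
import numpy as np


def scale_to_target(weights, w_target, w_max=None):
    """Multiplicatively rescale non-negative weights so they sum to w_target.

    If w_max is given, weights are clipped to [0, w_max] afterwards, in which
    case the sum can end up below w_target.
    """
    total = weights.sum()
    if total <= 0:
        return weights.copy()  # all weights zero: nothing to rescale
    scaled = weights * (w_target / total)
    if w_max is not None:
        scaled = np.clip(scaled, 0.0, w_max)
    return scaled


w = np.array([0.1, 0.2, 0.3, 0.4])
w_scaled = scale_to_target(w, w_target=2.0)
print(w_scaled, w_scaled.sum())  # ratios between weights are unchanged
```

In the simulation loop, one option is `neuron.weights = scale_to_target(neuron.weights, W_target, neuron.w_max)` after each `neuron.step(...)`. A reasonable choice for `W_target` is the sum of the initial weights.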


In [None]:
# Exercise 3: Synaptic Scaling

# TODO: Initialize neuron with unstable parameters (A_plus > A_minus) and a target total weight

weights_rec_ex3 = np.zeros((n_steps // 100, n_inputs))
# TODO: Run Simulation with synaptic scaling step

# TODO: Plot results

# Plot results
plotting.plot_weight_evolution(
    weights_rec_ex3,
    pattern_indices,
    non_pattern_indices,
    n_steps,
    dt,
    title="Exercise 3: Weight Evolution (Synaptic Scaling)",
)

### Exercise 4: Anti-Hebbian Learning

In some neural circuits (e.g., cerebellum-like structures), the timing dependence is reversed:

- **Pre-before-Post** leads to **LTD**.
- **Post-before-Pre** leads to **LTP**.

This is known as **Anti-Hebbian learning**. It is thought to support predictive cancellation of expected sensory inputs, such as self-generated sensory signals.

**Task:**

1.  Modify the `stdp_window` function (or create a new one) to implement Anti-Hebbian learning.
2.  Plot the STDP window to verify the shape.


In [None]:
# Exercise 4: Anti-Hebbian STDP Window

# TODO: Implement anti-Hebbian window function
def anti_hebbian_window(delta_t, A_plus=0.01, A_minus=0.01, tau_plus=20, tau_minus=20): ...


# TODO: Plot the window

# Plotting
delta_t = np.linspace(-100, 100, 1000)
dw_anti = anti_hebbian_window(delta_t)

# Note: We keep this plot here as it's specific to the exercise
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=delta_t,
        y=dw_anti,
        mode="lines",
        name="Anti-Hebbian Window",
        line=dict(color="purple", width=2),
    )
)
fig.add_hline(y=0, line_dash="dash", line_color="gray", opacity=0.5)
fig.add_vline(x=0, line_dash="dash", line_color="gray", opacity=0.5)
fig.update_layout(title="Anti-Hebbian Learning Window", xaxis_title="Δt (ms)", yaxis_title="Δw")
fig.show()

### Exercise 5: Neuromodulated STDP (Global Signals)

In the brain, learning is often gated by global neuromodulatory signals (like dopamine, acetylcholine, or norepinephrine) that signal reward, novelty, or surprise. This is often modeled as a **three-factor learning rule**:

$$ \Delta w \propto M \cdot \text{STDP}(\Delta t) $$

Where $M$ is a global neuromodulatory signal. If $M=0$, no learning occurs. If $M>0$, learning is enabled or amplified.
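In the simplest reading of the three-factor rule, $M$ just multiplies whatever update the pair-based STDP rule would have produced. The sketch below applies that to a single pre/post pair using the exponential window from Section 1. The helper names and the example values of $M$ are illustrative assumptions. Many models instead store the STDP term in a slowly decaying eligibility trace so that a delayed $M$ can still act on it; this sketch omits that.

```python
import numpy as np


def stdp_pair(delta_t, A_plus=0.01, A_minus=0.012, tau_plus=20.0, tau_minus=20.0):
    """Exponential STDP window for one spike pair (delta_t = t_post - t_pre, ms)."""
    if delta_t > 0:
        return A_plus * np.exp(-delta_t / tau_plus)
    if delta_t < 0:
        return -A_minus * np.exp(delta_t / tau_minus)
    return 0.0


def gated_update(w, delta_t, M, w_max=1.0):
    """Three-factor rule: scale the STDP update by a global modulator M, then clip."""
    return float(np.clip(w + M * stdp_pair(delta_t), 0.0, w_max))


w0 = 0.3
for M in (0.0, 1.0, 2.0):
    # M = 0 freezes the weight; larger M amplifies the same STDP update
    print(M, gated_update(w0, delta_t=5.0, M=M))
```

In `LIFNeuronSTDP`, the analogous change would be to multiply both `self.A_minus * self.y_trace` and `self.A_plus * self.x_trace` by the current value of $M$.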

**Task:**

1.  Think about how you would modify the `LIFNeuronSTDP` class to include a global `reward` signal.
2.  Implement a simple simulation where learning is only enabled during specific "rewarded" windows (e.g., when the pattern is present).


In [None]:
# Exercise 5: Neuromodulated STDP

# TODO: Initialize neuron

# TODO: Define reward signal

weights_rec_gated = np.zeros((n_steps // 100, n_inputs))
# TODO: Run Simulation with gated learning

# TODO: Plot results

# Plot results
plotting.plot_weight_evolution(
    weights_rec_gated,
    pattern_indices,
    non_pattern_indices,
    n_steps,
    dt,
    title="Exercise 5: Gated STDP (Learning only during pattern)",
)

## Conclusion

In this tutorial, we implemented the STDP learning rule and applied it to a Leaky Integrate-and-Fire neuron. We observed that:

1.  **STDP modifies synaptic weights** based on the precise timing of pre- and post-synaptic spikes.
2.  **Pattern Learning**: The neuron learned to detect a repeating spatio-temporal pattern hidden in noise. Synapses corresponding to the pattern were strengthened (LTP), while others were weakened (LTD) or remained low.
3.  **Selectivity**: After learning, the neuron fires reliably when the pattern is presented, acting as a pattern detector.

This mechanism is thought to be a fundamental way the brain learns to recognize features and sequences in sensory input.
