# Lateral Inhibition: Modeling Edge Detection with Spiking Neural Networks

## 🎯 Learning Objectives

By the end of this tutorial, you will be able to:

1. **Understand** the biological principle of lateral inhibition and its role in sensory processing
2. **Implement** a spiking neural network (SNN) with lateral inhibition using `snntorch`
3. **Observe** edge enhancement and the Mach band illusion in neural responses
4. **Connect** biological lateral inhibition to convolutional neural networks (CNNs)
5. **Explain** the efficient coding hypothesis and how lateral inhibition compresses sensory information
6. **Explore** Hopfield networks as an advanced application of recurrent connections

## 📚 Prerequisites

**Knowledge:**

- Basic Python programming and PyTorch
- Understanding of neurons and spiking (see Tutorial 13: STDP Learning)
- Familiarity with NumPy and basic plotting

**Installation:**

This tutorial requires the `snntorch` library. Make sure to install it by running:

```bash
uv sync --extra snn
```

This will install `snntorch` and its dependencies.

## Introduction

Lateral inhibition is a fundamental principle in neurobiology where an excited neuron reduces the activity of its neighbors. This mechanism increases the contrast and sharpness of sensory responses, enabling the brain to detect edges and boundaries effectively. In the visual system, this is famously observed in retinal ganglion cells and the Mach band illusion.

In this tutorial, we will build a simple network of LIF neurons to model this phenomenon.


## Setup and Imports


In [None]:
import numpy as np
import plotly.graph_objects as go
import snntorch as snn
import torch
import torch.nn as nn

from neuroai import plotting

# Set random seed for reproducibility
np.random.seed(42)
torch.manual_seed(42)

## 1. Theory: Lateral Inhibition

Lateral inhibition is the capacity of an excited neuron to reduce the activity of its neighbors. It limits the lateral spread of excitation from an active neuron to the cells around it, which sharpens the contrast between strongly and weakly stimulated regions and makes boundaries in the stimulus easier to detect.

Mathematically, we can model the input current $I_i$ to neuron $i$ as:

<!-- prettier-ignore -->
$$ I_i = X_i - \sum_{j \neq i} w_{ij} Y_j $$

where:

- $X_i$ is the excitatory input from the receptor (or previous layer).
- $Y_j$ is the output (spike) of neighbor neuron $j$.
- $w_{ij}$ is the inhibitory weight from neuron $j$ to neuron $i$.
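To make the formula concrete, here is a minimal dependency-free sketch that evaluates $I_i$ for five neurons. The helper name `lateral_current` and the strength of 0.5 are illustrative assumptions, not part of the tutorial's network. Here the weights are stored as positive numbers and subtracted, exactly as in the equation; storing them as negative numbers and adding them is equivalent.

```python
def lateral_current(x, y, w):
    """Return I_i = X_i - sum_{j != i} w[i][j] * Y_j for every neuron i."""
    n = len(x)
    currents = []
    for i in range(n):
        inhibition = sum(w[i][j] * y[j] for j in range(n) if j != i)
        currents.append(x[i] - inhibition)
    return currents


# Five neurons; each one is inhibited only by its immediate neighbors.
n = 5
strength = 0.5  # illustrative value (assumption)
w = [[strength if abs(i - j) == 1 else 0.0 for j in range(n)] for i in range(n)]

x = [1.0, 1.0, 1.0, 1.0, 1.0]  # uniform excitatory input X
y = [0, 1, 0, 0, 0]            # only neuron 1 spiked on the previous step

currents = lateral_current(x, y, w)
print(currents)  # [0.5, 1.0, 0.5, 1.0, 1.0]
```

Only the two neighbors of the spiking neuron lose current; the spiking neuron itself and the distant neurons keep their full input.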

### The "Mexican Hat" Interaction

Commonly, the interaction profile is modeled as a "Mexican Hat" function (or Difference of Gaussians):

1.  **Short-range excitation**: Neurons might excite their immediate neighbors (though in our simple model, we only have self-excitation/input).
2.  **Medium-range inhibition**: Neurons inhibit their neighbors.
3.  **Long-range silence**: Neurons far away do not interact.

In our simplified 1D model, we will implement **local inhibition**, where a neuron inhibits only its immediate left and right neighbors.
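For comparison with the local rule, the sketch below evaluates a 1D Difference of Gaussians at a few distances. The function name and the parameters (`sigma_center`, `sigma_surround`, `k`) are illustrative assumptions chosen so that the three regimes are visible; they are not used elsewhere in this tutorial.

```python
import math


def difference_of_gaussians(d, sigma_center=1.0, sigma_surround=3.0, k=0.5):
    """Narrow positive Gaussian minus a wider, weaker one (illustrative parameters)."""
    center = math.exp(-d**2 / (2 * sigma_center**2))
    surround = k * math.exp(-d**2 / (2 * sigma_surround**2))
    return center - surround


# Sample the interaction profile at increasing distances from the neuron.
profile = {d: difference_of_gaussians(d) for d in range(0, 13, 2)}
for d, value in profile.items():
    print(f"distance {d:2d}: {value:+.4f}")
```

With these values the profile is positive at distance 0, negative at intermediate distances, and close to zero far away. The local-inhibition model used in this tutorial keeps only the negative part, restricted to distance 1.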


In [None]:
# Define simulation parameters
num_steps = 100
num_neurons = 30

# Create a 1D edge input (step function)
# Light on the left (high intensity), Dark on the right (low intensity)
input_signal = torch.zeros(num_steps, num_neurons)
input_signal[:, : num_neurons // 2] = 0.8  # High intensity
input_signal[:, num_neurons // 2 :] = 0.2  # Low intensity

# Plot the input profile (spatial)
plotting.plot_input_profile(input_signal, num_neurons)

## 2. Building the Network

We will construct a Spiking Neural Network where neurons are arranged in a 1D line.
Each neuron receives:

1. **Direct Input**: From the sensory signal (identity connection).
2. **Lateral Inhibition**: Inhibitory input from its immediate left and right neighbors.

We will use `snntorch.Leaky` neurons. The lateral connections will be implemented as a recurrent weight matrix.
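Before wiring neurons together, it can help to see what a single leaky integrate-and-fire unit does with a constant input. The following is a simplified pure-Python sketch, not snnTorch's implementation: it assumes a threshold of 1.0 and a reset that subtracts the threshold on the step after a spike, which I believe matches the `snn.Leaky` defaults. The function name `simulate_lif` is hypothetical.

```python
def simulate_lif(inputs, beta=0.9, threshold=1.0):
    """Simplified LIF neuron: leak, integrate, spike, subtract-reset (assumed defaults)."""
    mem = 0.0
    spikes, trace = [], []
    for current in inputs:
        # If the membrane crossed threshold on the previous step, subtract the threshold.
        reset = threshold if mem > threshold else 0.0
        mem = beta * mem + current - reset
        spikes.append(1 if mem > threshold else 0)
        trace.append(mem)
    return spikes, trace


# Constant inputs matching the bright (0.8) and dark (0.2) sides of the edge.
bright, _ = simulate_lif([0.8] * 100)
dim, _ = simulate_lif([0.2] * 100)
print("bright spikes:", sum(bright), "dim spikes:", sum(dim))
```

Without any lateral input, the brighter input simply produces more spikes. The network below adds a negative term to `current` whenever a neighbor spiked on the previous step.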


In [None]:
class LateralInhibitionNet(nn.Module):
    """
    A Spiking Neural Network with lateral inhibition.
    Neurons are arranged in a 1D line and inhibit their immediate neighbors.
    """

    def __init__(self, num_neurons, inhibition_strength=0.5):
        super().__init__()
        self.num_neurons = num_neurons

        # Initialize LIF neuron
        # beta is the decay rate of the membrane potential
        self.lif = snn.Leaky(beta=0.9)

        # Define Lateral Inhibition Weights
        # Row i holds the weights onto neuron i (lateral_input = W @ spk), so
        # W[i, i-1] = -w and W[i, i+1] = -w mean neuron i is inhibited by both neighbors
        self.w_lat = torch.zeros(num_neurons, num_neurons)

        for i in range(num_neurons):
            if i > 0:
                self.w_lat[i, i - 1] = -inhibition_strength  # Inhibited by left neighbor
            if i < num_neurons - 1:
                self.w_lat[i, i + 1] = -inhibition_strength  # Inhibited by right neighbor

    def forward(self, x):
        """
        Forward pass of the network.

        Args:
            x (torch.Tensor): Input tensor of shape (num_steps, num_neurons)

        Returns:
            spk_rec (torch.Tensor): Recorded spikes of shape (num_steps, num_neurons)
            mem_rec (torch.Tensor): Recorded membrane potentials
        """
        # x shape: (num_steps, num_neurons)
        num_steps = x.shape[0]

        # Initialize hidden state (membrane potential)
        mem = self.lif.init_leaky()

        # Record the spike train and membrane potential
        spk_rec = []
        mem_rec = []

        # Initialize previous spike for recurrence (at t=0, no previous spikes)
        spk = torch.zeros(self.num_neurons)

        for step in range(num_steps):
            # Current input is External Input + Lateral Inhibition
            # x[step] is shape (num_neurons,)

            # Lateral input: W_lat @ spk_prev
            # If neighbor spiked last step, I get inhibited this step
            lateral_input = torch.matmul(self.w_lat, spk)

            # Total current
            current_input = x[step] + lateral_input

            # Run LIF neuron
            spk, mem = self.lif(current_input, mem)

            spk_rec.append(spk)
            mem_rec.append(mem)

        return torch.stack(spk_rec), torch.stack(mem_rec)


# Instantiate the network
# We use a relatively high inhibition strength to see the effect clearly
net = LateralInhibitionNet(num_neurons=num_neurons, inhibition_strength=5.0)

# Visualize the weight matrix
plotting.plot_weight_matrix(net.w_lat)

## 3. Simulation and Results

Now we run the simulation. We will record the spikes and calculate the firing rate for each neuron.
We expect to see **edge enhancement**:

- The neuron on the bright side of the edge should have a higher firing rate than the interior bright neurons (because it receives less inhibition from the dark side).
- The neuron on the dark side of the edge should have a lower firing rate than the interior dark neurons (because it receives more inhibition from the bright side).


In [None]:
# Run simulation
spk_rec, mem_rec = net(input_signal)

# Calculate firing rate (spikes per step)
firing_rate = spk_rec.float().mean(dim=0)

# --- Plotting with Plotly (Dual Axis) ---
plotting.plot_simulation_results(input_signal, firing_rate, num_neurons)

# --- Raster Plot ---
plotting.plot_raster(spk_rec)

### 🧠 Let's think about it!

Look at the "Output Firing Rate" graph above.

1.  **Bright Side (Left)**: Why is the firing rate of the neuron just before the edge (index 14) _higher_ than the neurons further left (e.g., index 10)?
2.  **Dark Side (Right)**: Why is the firing rate of the neuron just after the edge (index 15) _lower_ than the neurons further right (e.g., index 20)?

<details>
<summary>Click to reveal explanation</summary>

1.  **Bright Side Enhancement**: The neuron at index 14 is stimulated by bright light. Its left neighbor (13) is also bright (high inhibition), but its right neighbor (15) is dark (low inhibition). Therefore, it receives _less total inhibition_ than the interior bright neurons (who are surrounded by bright neighbors on both sides). Less inhibition = higher firing rate.
2.  **Dark Side Suppression**: The neuron at index 15 is stimulated by dim light. Its right neighbor (16) is dim (low inhibition), but its left neighbor (14) is bright (high inhibition). Therefore, it receives _more total inhibition_ than the interior dark neurons. More inhibition = lower firing rate.

This "push-pull" effect exaggerates the difference at the boundary, making the edge "pop" out.

</details>


## 4. The Link to Deep Learning: Convolutions

You might have noticed that the operation performed by our lateral inhibition network looks suspiciously similar to a fundamental operation in Deep Learning: **Convolution**.

In a Convolutional Neural Network (CNN), a "kernel" or "filter" slides across the input.
If we define a kernel $K = [-w, 1, -w]$, and convolve it with our input $X$, we get:

<!-- prettier-ignore -->
$$ (X * K)_i = -w \cdot X_{i-1} + 1 \cdot X_i - w \cdot X_{i+1} $$

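This is, to a first approximation, what our lateral inhibition circuit computes. (The match is not exact: the SNN inhibits each neuron using its neighbors' output spikes $Y_j$ from the previous time step, whereas the convolution acts directly on the inputs $X_j$.)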

- The **center** of the kernel ($1$) represents the direct excitatory input.
- The **surround** of the kernel ($-w$) represents the lateral inhibitory input.
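The kernel arithmetic can be checked by hand on the tutorial's step input (0.8 on the left half, 0.2 on the right half, 30 neurons). This is a minimal pure-Python sketch; the helper `center_surround` and the choice `w = 0.5` are illustrative assumptions, and zero padding is used at both ends.

```python
def center_surround(x, w):
    """Apply the kernel [-w, 1, -w] to a 1D list, with zero padding at both ends."""
    padded = [0.0] + list(x) + [0.0]
    return [
        -w * padded[i - 1] + padded[i] - w * padded[i + 1]
        for i in range(1, len(padded) - 1)
    ]


x = [0.8] * 15 + [0.2] * 15      # step edge between index 14 and index 15
out = center_surround(x, w=0.5)  # w = 0.5 is an illustrative choice

for i in (10, 14, 15, 20):
    print(i, round(out[i], 3))
# 10 0.0
# 14 0.3
# 15 -0.3
# 20 0.0
```

The uniform regions cancel to zero, while the two neurons adjacent to the edge are pushed in opposite directions. The first and last elements of `out` are also non-zero, but only because of the zero padding.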

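This center-surround kernel is a discrete relative of the **Difference of Gaussians (DoG)** and **Laplacian** filters, both standard edge-detection filters in computer vision. For example, with $w = 0.5$ the kernel $[-0.5, 1, -0.5]$ is a scaled, sign-flipped version of the discrete Laplacian $[1, -2, 1]$.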

### 🧪 Experiment: Comparing SNN to CNN

Let's verify this by using a standard PyTorch `Conv1d` layer with fixed weights to replicate the behavior of our biological SNN.


In [None]:
# Define a Conv1d layer
# 1 input channel, 1 output channel, kernel size 3, padding 1 (to keep same size)
conv_layer = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3, padding=1, bias=False)

# Manually set the weights to mimic our lateral inhibition
# Weights: [Left Neighbor, Self, Right Neighbor]
# Note: In our SNN, 'Self' was the direct input (weight 1.0) and neighbors were inhibitory.
# We need to normalize or scale this to match the firing rate scale, but the shape should be identical.
inhibition_w = 0.5
conv_weights = torch.tensor([[-inhibition_w, 1.0, -inhibition_w]])
conv_layer.weight.data = conv_weights.view(1, 1, 3)

# Prepare input for Conv1d (Batch, Channel, Length)
input_tensor = input_signal[0].unsqueeze(0).unsqueeze(0)  # Take one time step

# Run Convolution
with torch.no_grad():
    conv_output = conv_layer(input_tensor).squeeze()

# Plot comparison
plotting.plot_cnn_comparison(firing_rate, conv_output, num_neurons)

### 🧠 Let's think about it!

If a simple Convolutional Layer can do the same thing as our complex Spiking Neural Network, why use spikes?

<details>
<summary>Click to reveal</summary>

1.  **Energy Efficiency**: In biology (and neuromorphic hardware), spikes are sparse. If there is no edge (uniform input), the lateral inhibition might silence the neurons completely. No spikes = no energy consumed. A standard CNN consumes energy for every multiplication, regardless of the output.
2.  **Time**: SNNs operate in time. The "edge detection" can happen very quickly (first spike latency), or integrate over time to improve signal-to-noise ratio.
3.  **Plasticity**: In biology, these weights aren't just fixed; they can adapt locally using rules like STDP (Spike-Timing-Dependent Plasticity).
</details>

### Efficient Coding Hypothesis

This tutorial demonstrates a core principle of the **Efficient Coding Hypothesis** (Horace Barlow, 1961).
The goal of the sensory system is to represent the environment as efficiently as possible.

- Natural scenes have high **redundancy** (neighboring pixels are usually similar).
- Transmitting the raw pixel values is wasteful.
- Transmitting only the **changes** (edges) removes this redundancy and compresses the information.

Lateral inhibition is the brain's built-in compression algorithm!
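The compression argument can be illustrated with a toy calculation. The sketch below encodes a piecewise-constant signal as neighbor-to-neighbor differences and compares the empirical entropy of the two representations. The integer signal and the helper `entropy_bits` are illustrative assumptions; real natural-image statistics are far richer than this example.

```python
import math
from collections import Counter
from itertools import accumulate


def entropy_bits(values):
    """Empirical Shannon entropy (bits per symbol) of a sequence of values."""
    counts = Counter(values)
    total = len(values)
    return -sum((c / total) * math.log2(c / total) for c in counts.values())


# A redundant "scene": bright on the left, dark on the right (arbitrary integer levels).
signal = [8] * 15 + [2] * 15

# Keep the first value, then transmit only the change from one position to the next.
diffs = [signal[0]] + [b - a for a, b in zip(signal, signal[1:])]
nonzero = sum(1 for d in diffs if d != 0)

print("non-zero values to transmit:", nonzero)
print("entropy raw  :", round(entropy_bits(signal), 3), "bits/symbol")
print("entropy diffs:", round(entropy_bits(diffs), 3), "bits/symbol")

# The encoding is lossless: a running sum recovers the original signal.
reconstructed = list(accumulate(diffs))
```

Only the starting level and the edge carry non-zero values, so the difference code has lower entropy than the raw signal while still allowing exact reconstruction.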


## 5. 🎓 Exercises

### Exercise 1: Wider Inhibition Kernel

**Task**

The current model only inhibits the immediate neighbors (distance = 1).
Modify the `LateralInhibitionNet` class (or create a new one) to implement a wider inhibition kernel.

**Requirements:**

1.  Create a new class `WideInhibitionNet`.
2.  In `__init__`, set up the weights such that:
    - Neighbors at distance 1 (i-1, i+1) have inhibition strength `w1`.
    - Neighbors at distance 2 (i-2, i+2) have inhibition strength `w2`.
3.  Run the simulation with `w1 = 5.0` and `w2 = 2.0`.
4.  Plot the results and compare with the previous simple model.

**Hints:**

- You will need to add checks in your loop to ensure you don't access indices outside `[0, num_neurons-1]`.
- For example, `if i > 1: self.w_lat[i, i-2] = -w2`.


In [None]:
# Your code here


class WideInhibitionNet(nn.Module):
    def __init__(self, num_neurons, w1=5.0, w2=2.0):
        super().__init__()
        self.num_neurons = num_neurons
        self.lif = snn.Leaky(beta=0.9)
        self.w_lat = torch.zeros(num_neurons, num_neurons)

        # TODO: Implement the weight initialization loop
        for i in range(num_neurons):
            pass  # Replace with your code

    def forward(self, x):
        # You can copy the forward method from the previous class
        # or inherit and reuse it
        pass


# Instantiate and run
# wide_net = WideInhibitionNet(...)
# spk_wide, _ = wide_net(input_signal)
# ... Plotting code ...


# Test the solution
def test_wide_inhibition():
    net = WideInhibitionNet(num_neurons=10, w1=0.5, w2=0.2)
    w = net.w_lat
    assert w[2, 1] == -0.5, f"Weight at distance 1 should be -0.5, got {w[2, 1]}"
    assert w[2, 3] == -0.5, f"Weight at distance 1 should be -0.5, got {w[2, 3]}"
    assert w[2, 0] == -0.2, f"Weight at distance 2 should be -0.2, got {w[2, 0]}"
    assert w[2, 4] == -0.2, f"Weight at distance 2 should be -0.2, got {w[2, 4]}"
    print("✅ Exercise 1 Solution Passed!")


test_wide_inhibition()

In [None]:
# Visualize the results of Exercise 1

# Instantiate the wide inhibition network
wide_net = WideInhibitionNet(num_neurons=num_neurons, w1=5.0, w2=2.0)

# Visualize the weight matrix to see the wider inhibition pattern
plotting.plot_weight_matrix(wide_net.w_lat, title="Wide Inhibition Weight Matrix (Distance 1 & 2)")

# Run simulation with wide inhibition network
spk_wide, _ = wide_net(input_signal)
firing_rate_wide = spk_wide.float().mean(dim=0)

# Compare firing rates: Original vs Wide Inhibition
fig = go.Figure()

# Original network (narrow inhibition)
fig.add_trace(
    go.Scatter(
        x=np.arange(num_neurons),
        y=firing_rate.detach().numpy(),
        mode="lines+markers",
        name="Narrow Inhibition (distance=1)",
        line=dict(color="blue"),
    )
)

# Wide inhibition network
fig.add_trace(
    go.Scatter(
        x=np.arange(num_neurons),
        y=firing_rate_wide.detach().numpy(),
        mode="lines+markers",
        name="Wide Inhibition (distance=1,2)",
        line=dict(color="red"),
    )
)

# Add vertical line at edge boundary
fig.add_vline(x=num_neurons // 2 - 0.5, line_dash="dash", line_color="gray", opacity=0.5)

fig.update_layout(
    title="Comparison: Narrow vs Wide Lateral Inhibition",
    xaxis_title="Neuron Index",
    yaxis_title="Firing Rate",
    hovermode="x unified",
)

fig.show()

# Also plot the raster plot for wide inhibition
plotting.plot_raster(spk_wide)

### Exercise 2: The Effect of Inhibition Strength

**Task**

Investigate how the strength of lateral inhibition affects the edge detection capability.
Run the simulation with the original `LateralInhibitionNet` using three different inhibition strengths:

1.  `inhibition_strength = 0.0` (No inhibition)
2.  `inhibition_strength = 2.0` (Moderate inhibition)
3.  `inhibition_strength = 10.0` (Strong inhibition)

**Questions:**

- What happens to the firing rates when inhibition is zero?
- Does the "Mach band" effect appear in all cases?
- What happens when inhibition is very strong? Does it suppress the signal too much?


In [None]:
# Your code for Exercise 2 here


### Exercise 3: Robustness to Noise

**Task**

Real-world sensory signals are rarely clean; they are often corrupted by noise.

1.  Create a noisy version of the input signal by adding Gaussian noise.
    ```python
    noise = torch.randn_like(input_signal) * 0.1
    noisy_input = input_signal + noise
    noisy_input = torch.clamp(noisy_input, 0, 1) # Keep values between 0 and 1
    ```
2.  Feed this `noisy_input` into the `LateralInhibitionNet` (use `inhibition_strength=5.0`).
3.  Plot the input intensity vs. output firing rate.

**Questions:**

- Does the network still detect the edge despite the noise?
- Does lateral inhibition amplify or suppress the noise in the uniform regions?


In [None]:
# Your code for Exercise 3 here


## 6. Advanced Topic: From Inhibition to Memory (Hopfield Networks)

We have seen how **fixed inhibitory** recurrent connections can perform useful computations like edge detection.
But what happens if the recurrent connections are **learnable** instead, with weights that can be excitatory or inhibitory?

This leads us to the concept of **Attractor Networks** and **Associative Memory**. The most famous example is the **Hopfield Network** (John Hopfield, 1982).

### Theory: Energy Landscapes and Hebbian Learning

A Hopfield network is a fully connected recurrent network (with symmetric weights) that acts as a content-addressable memory.
Instead of mapping input $X \to Y$, it maps an initial state $S_{init} \to S_{stable}$, where $S_{stable}$ is a stored memory pattern.

The network minimizes an "Energy Function":

<!-- prettier-ignore -->
$$ E = -\frac{1}{2} \sum_{i,j} w_{ij} s_i s_j $$

If we update neurons asynchronously to align with their local field, the energy $E$ is guaranteed to decrease (or stay the same) at every update until the network reaches a local minimum (an attractor).
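The monotonic decrease can be checked numerically. The sketch below is a minimal pure-Python illustration under stated assumptions: a six-neuron network, an illustrative symmetric weight matrix with zero diagonal built from a single ±1 pattern, and hypothetical helpers `energy` and `async_update`. It records the energy after every single-neuron update.

```python
import random


def energy(w, s):
    """E = -1/2 * sum_{i,j} w[i][j] * s[i] * s[j]."""
    n = len(s)
    return -0.5 * sum(w[i][j] * s[i] * s[j] for i in range(n) for j in range(n))


def async_update(w, s, i):
    """Align neuron i with its local field h_i; leave it unchanged if h_i == 0."""
    h = sum(w[i][j] * s[j] for j in range(len(s)))
    if h != 0:
        s[i] = 1 if h > 0 else -1


# Illustrative symmetric weights with zero diagonal, built from one +/-1 pattern.
pattern = [1, -1, 1, -1, 1, -1]
n = len(pattern)
w = [[0.0 if i == j else pattern[i] * pattern[j] / n for j in range(n)] for i in range(n)]

rng = random.Random(0)
state = [rng.choice([-1, 1]) for _ in range(n)]
energies = [energy(w, state)]

for _ in range(5):                      # five sweeps over the network
    for i in rng.sample(range(n), n):   # visit neurons in random order, one at a time
        async_update(w, state, i)
        energies.append(energy(w, state))

print("energy never increased:", all(b <= a + 1e-9 for a, b in zip(energies, energies[1:])))
print("final state:", state)
```

The guarantee relies on the weights being symmetric with no self-connections and on updating one neuron at a time; updating all neurons in parallel does not carry the same guarantee.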

**Hebbian Learning Rule**:
To store patterns $\xi^{\mu}$, we set the weights proportional to the correlation between neurons:

<!-- prettier-ignore -->
$$ w_{ij} = \frac{1}{N} \sum_{\mu=1}^{P} \xi_i^{\mu} \xi_j^{\mu} $$

"Neurons that fire together, wire together."
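As a small worked example of the rule, the sketch below builds the weight matrix for two hand-picked patterns on four neurons. The helper `hebbian_weights` and the patterns are illustrative assumptions; self-connections are set to zero, which is a common convention for Hopfield networks.

```python
def hebbian_weights(patterns):
    """w[i][j] = (1/N) * sum over patterns of xi[i] * xi[j], with w[i][i] = 0."""
    n = len(patterns[0])
    w = [[0.0] * n for _ in range(n)]
    for xi in patterns:
        for i in range(n):
            for j in range(n):
                if i != j:
                    w[i][j] += xi[i] * xi[j] / n
    return w


# Two illustrative +/-1 patterns on N = 4 neurons.
patterns = [
    [1, 1, -1, -1],
    [1, -1, 1, -1],
]
w = hebbian_weights(patterns)
for row in w:
    print(row)
```

Neurons 0 and 3 disagree in both patterns, so their connection is negative (-0.5), and the same holds for neurons 1 and 2. Every other pair agrees in one pattern and disagrees in the other, so those contributions cancel to zero. The resulting matrix is symmetric.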

### Implementation

Let's implement a simple binary Hopfield network to store and recover images.


In [None]:
class HopfieldNetwork:
    def __init__(self, num_neurons):
        self.num_neurons = num_neurons
        self.weights = torch.zeros(num_neurons, num_neurons)

    def train(self, patterns):
        """
        Train using Hebbian rule.
        patterns: (P, N) tensor of -1 and +1
        """
        P, N = patterns.shape
        # W = (1/N) * X.T @ X
        self.weights = (1.0 / N) * torch.matmul(patterns.T, patterns)

        # Remove self-connections (diagonal = 0)
        self.weights.fill_diagonal_(0)

    def predict(self, state, steps=10):
        """
        Synchronous update dynamics (all neurons are updated at once).
        """
        state = state.clone()
        for _ in range(steps):
            # Compute local field: h = W @ s
            h = torch.matmul(self.weights, state)

            # Update rule: s = sign(h)
            # Note: a true asynchronous Hopfield network updates one neuron at a time.
            # Here every neuron is updated in parallel for simplicity (Little model).
            new_state = torch.sign(h)

            # Handle zero case: torch.sign(0) is 0, so keep the previous state there
            new_state[h == 0] = state[h == 0]
            state = new_state

        return state


# Create simple patterns (10x10 images flattened to 100 neurons)
N_sqrt = 10
N = N_sqrt * N_sqrt

# Pattern 1: A Cross
p1 = -torch.ones(N_sqrt, N_sqrt)
p1[N_sqrt // 2, :] = 1
p1[:, N_sqrt // 2] = 1
p1 = p1.flatten()

# Pattern 2: A Square
p2 = -torch.ones(N_sqrt, N_sqrt)
p2[2:-2, 2:-2] = 1
p2 = p2.flatten()

patterns = torch.stack([p1, p2])

# Train
hopfield = HopfieldNetwork(N)
hopfield.train(patterns)

# Visualize Patterns
plotting.plot_hopfield_patterns(p1, p2, N_sqrt)

### Visualizing the Hopfield Connectivity

The "knowledge" of the Hopfield network is stored in its weight matrix.
Unlike the lateral inhibition network where weights were fixed and local (only neighbors), here the weights are **global** (fully connected) and **learned** from the data.

Let's visualize the weight matrix. You might notice some structure that reflects the stored patterns.


In [None]:
# Visualize the learned weights
plotting.plot_weight_matrix(
    hopfield.weights, title="Hopfield Network Weight Matrix (Hebbian Learning)"
)

### Visualizing the Network Graph

To better understand the "all-to-all" connectivity, let's visualize a small Hopfield network as a graph.

- **Nodes** represent neurons.
- **Edges** represent the learned weights between them.
- **Red edges** are inhibitory (negative weights).
- **Blue edges** are excitatory (positive weights).


In [None]:
# Visualize the Network Graph
plotting.plot_hopfield_topology(HopfieldNetwork)

### Associative Memory Demo

Now we will:

1.  Take a stored pattern.
2.  Corrupt it with noise (flip some bits).
3.  Let the network evolve.
4.  See if it recovers the original pattern.
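Step 4 can be quantified rather than judged by eye. One common measure is the overlap $m = \frac{1}{N} \sum_i \xi_i s_i$ between a stored pattern $\xi$ and a network state $s$. The sketch below uses a hypothetical helper `overlap` and a made-up ten-element pattern; neither is part of the tutorial's `plotting` module.

```python
def overlap(pattern, state):
    """m = (1/N) * sum_i pattern_i * state_i; +1 means identical, -1 means inverted."""
    return sum(p * s for p, s in zip(pattern, state)) / len(pattern)


stored = [1, -1, 1, 1, -1, -1, 1, -1, 1, 1]  # illustrative pattern

# Corrupt the pattern by flipping three arbitrary bits.
corrupted = list(stored)
for i in (0, 4, 7):
    corrupted[i] *= -1

print(overlap(stored, stored))     # 1.0
print(overlap(stored, corrupted))  # 0.4
```

Applied to the demo, an overlap close to 1 between the recovered state and the original pattern would indicate successful recall, while a value close to -1 would indicate that the network settled into the inverted pattern.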


In [None]:
# Corrupt Pattern 1
noise_level = 0.3
corrupted_p1 = p1.clone()
mask = torch.rand(N) < noise_level
corrupted_p1[mask] *= -1  # Flip bits

# Recover
recovered_p1 = hopfield.predict(corrupted_p1, steps=5)

# Plot
plotting.plot_hopfield_recall(p1, corrupted_p1, recovered_p1, N_sqrt, noise_level)

### 📚 Further Reading on Hopfield Networks

- **Original Paper**: Hopfield, J. J. (1982). "Neural networks and physical systems with emergent collective computational abilities". _PNAS_.
- **Modern Hopfield Networks**: Recent work (e.g., "Dense Associative Memory" by Krotov & Hopfield, 2016) has shown that changing the energy function can drastically increase storage capacity, linking Hopfield Networks to **Transformers** (Attention mechanisms).
- **Scholarpedia**: [Hopfield Network](http://www.scholarpedia.org/article/Hopfield_network)


## Conclusion

In this tutorial, we explored the power of recurrent connections in neural networks:

1.  **Lateral Inhibition**: We saw how fixed inhibitory connections create **edge detection** and contrast enhancement (Mach bands), a principle used in the retina and modeled by Convolutions in Deep Learning.
2.  **Hopfield Networks**: We saw how learnable recurrent connections (both excitatory and inhibitory) create **associative memory**, allowing networks to store and recover patterns from noise.

These two examples demonstrate how the _structure_ of connections (topology) determines the _function_ of the network.
