# White noise tutorial

This is a notebook which simulates a white noise experiment.

In brief, we're going to simulate the neural response to a one-dimensional visual stimulus (luminance with respect to the mean, or contrast) and then reconstruct the mechanism of the neural response using the techniques described in the handout by E.J. Chichilnisky.

The neural response is simulated as an inhomogeneous Poisson process whose rate parameter (called `nonlin_resp` below) is a linear function of the stimulus (called `lin_resp` below) passed through a static, single-valued non-linearity.
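
Written out, the model looks roughly as follows. The symbols are notation introduced here for illustration, not taken from the handout: $s$ is the stimulus, $k$ the linear filter, $N$ the static non-linearity, $\lambda$ the firing rate, and $\Delta t$ the width of a time bin.

$$
\lambda(t) = N\!\left(\sum_{\tau \ge 0} k(\tau)\, s(t - \tau)\right),
\qquad
\text{spike count in } [t, t + \Delta t) \sim \mathrm{Poisson}\big(\lambda(t)\,\Delta t\big)
$$

In the code that follows, the inner sum corresponds to `lin_resp` and $\lambda(t)\,\Delta t$ (with $\Delta t$ = 1 ms) to `nonlin_resp`.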

In [None]:
import numpy as np
from scipy import stats, optimize
import matplotlib.pyplot as plt

We're going to run the experiment for 600 seconds, or 10 minutes.  The method works better with longer times, but this works well enough for the tutorial.

In [None]:
duration = 600 * 1000  # Time is measured in msecs
stim_samp_time = 10  # 100 Hz monitor

Our "monitor" (the device which generates the stimulus) refreshes the display at 100 Hz, or once every 10 msecs.

At each frame, the intensity of the uniform screen changes: it is drawn randomly from a particular probability distribution. The exact nature of the probability distribution is not terribly important. However, it is important that 1) the distribution be symmetric about zero and that 2) the contrast on each frame is uncorrelated with the others (this is the definition of a "white" random process).

As an aside: For stimuli that vary in space and time, the joint distribution over all stimulus dimensions must be radially symmetric about the origin. For practical purposes, this means it should be Gaussian.

In [None]:
seed = sum(map(ord, "white noise"))
rng = np.random.default_rng(seed)

In [None]:
n_samples = duration // stim_samp_time
x = np.clip(rng.normal(0, 1 / 3, n_samples), -1, 1)
stimulus = np.repeat(x, stim_samp_time)
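
As a rough check on the two requirements above, the sketch below draws frame contrasts the same way and estimates their mean and their autocorrelation at a few lags. The seed, the reduced number of frames, and the helper `autocorr` are assumptions made here for illustration. For a white sequence, the autocorrelation should be close to zero at every nonzero lag.

```python
import numpy as np

rng_check = np.random.default_rng(0)  # hypothetical seed, for illustration only
n_frames = 20_000                     # fewer frames than the tutorial, for speed
frames = np.clip(rng_check.normal(0, 1 / 3, n_frames), -1, 1)


def autocorr(x, lag):
    """Sample autocorrelation of x at a positive integer lag."""
    x = x - x.mean()
    return np.dot(x[:-lag], x[lag:]) / np.dot(x, x)


frame_mean = frames.mean()
lag_corrs = np.array([autocorr(frames, lag) for lag in range(1, 6)])

print(f"mean contrast: {frame_mean:.4f}")
print("autocorrelation at lags 1-5 (frames):", np.round(lag_corrs, 3))
```

Note that this is a statement about frames. The 1 ms `stimulus` array repeats each frame value 10 times, so it is necessarily correlated at lags shorter than one frame.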

In [None]:
f, ax = plt.subplots(figsize=(8, 3))
ax.plot(stimulus[:1000])
ax.set(
    xlabel="Time (ms)",
    ylabel="Stimulus contrast",
    title="First 1 second of stimulus",
    xlim=(0, 1000),
    ylim=(-1, 1),
)
f.tight_layout()

---

## Linear response

Now compute the linear response using a 3-stage cascade of exponential lowpass filters. You could substitute other linear filters here if you wanted to, but this one is pretty simple. The third column of `lin_resp` is the linear response. This takes a couple of minutes to calculate.

In [None]:
lin_resp = np.zeros((duration, 3))
tau = 15
for i, s_i in enumerate(stimulus[:-1]):
    lin_resp[i + 1, 0] = lin_resp[i, 0] + (1 / tau) * (s_i - lin_resp[i, 0])
    lin_resp[i + 1, 1] = lin_resp[i, 1] + (1 / tau) * (lin_resp[i, 0] - lin_resp[i, 1])
    lin_resp[i + 1, 2] = lin_resp[i, 2] + (1 / tau) * (lin_resp[i, 1] - lin_resp[i, 2])

# Getting rid of the first- and second-order filtered signals, we only want the third one.
lin_resp = lin_resp[:, -1]
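
If the loop is too slow, each stage can be written as a standard recursive (IIR) filter. The sketch below is one possible alternative, not the method used in this notebook: it assumes `scipy.signal.lfilter` is available and checks, on a short made-up stimulus, that it reproduces the explicit loop. The function names and the demo stimulus are hypothetical.

```python
import numpy as np
from scipy import signal


def cascade_loop(stim, tau, n_stages=3):
    """Reference: the explicit update loop, generalized to n_stages."""
    resp = np.zeros((len(stim), n_stages))
    for i, s_i in enumerate(stim[:-1]):
        drive = s_i
        for k in range(n_stages):
            resp[i + 1, k] = resp[i, k] + (1 / tau) * (drive - resp[i, k])
            drive = resp[i, k]  # the next stage is driven by this stage's value at time i
    return resp[:, -1]


def cascade_lfilter(stim, tau, n_stages=3):
    """Same recursion per stage: y[n] = (1 - 1/tau) * y[n-1] + (1/tau) * x[n-1]."""
    b = [0.0, 1 / tau]
    a = [1.0, -(1 - 1 / tau)]
    out = np.asarray(stim, dtype=float)
    for _ in range(n_stages):
        out = signal.lfilter(b, a, out)
    return out


rng_demo = np.random.default_rng(0)  # hypothetical seed
stim_demo = np.repeat(rng_demo.normal(0, 1 / 3, 50), 10)  # 500 ms of made-up stimulus

resp_loop = cascade_loop(stim_demo, tau=15)
resp_fast = cascade_lfilter(stim_demo, tau=15)
print("max abs difference:", np.max(np.abs(resp_loop - resp_fast)))
```

Applied to the full `stimulus`, `cascade_lfilter(stimulus, tau)` should give the same values as the final `lin_resp`, up to floating-point error.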

The linear response is just a lowpass filtered version of the stimulus. The top panel of the figure shows the first second of the stimulus, the small middle panel shows the impulse response of the linear filter, and the third panel shows the first second of the linear response.

In [None]:
f, (ax_stim, ax_resp) = plt.subplots(2, figsize=(8, 7), sharex=True)

ax_stim.plot(stimulus[:1000])
ax_resp.plot(lin_resp[:1000])

ax_stim.set(
    title="Stimulus",
    xlim=(0, 1000),
    ylim=(-1, 1),
    
)

ax_resp.set(
    title="Linear response",
    xlabel="Time (ms)",
    ylim=(-.35, .35),
)

f.tight_layout(h_pad=10)

times = np.arange(200)
impulse_resp = times ** 2 * np.exp(-times / tau)

ax_kern = f.add_axes([.1, .41, .2, .2])
ax_kern.plot(times, impulse_resp, color="r")
ax_kern.text(100, 50, "Imp. Resp.")
ax_kern.set(xticks=[], yticks=[]);
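
The inset uses `times ** 2 * np.exp(-times / tau)`, which is the continuous-time shape of three cascaded exponential filters. The discrete recursion used above has a slightly different exact impulse response. The sketch below builds it by convolving the one-stage response with itself, compares it with a closed form derived here (a derivation for illustration, not from the handout), and measures how closely the continuous-time expression matches its shape.

```python
import numpy as np

tau = 15
pole = 1 - 1 / tau
n = np.arange(200)

# One stage: h1[0] = 0 and h1[n] = (1/tau) * pole**(n - 1) for n >= 1
h1 = np.where(n >= 1, (1 / tau) * pole ** np.maximum(n - 1, 0), 0.0)

# Three stages in series: convolve twice, keeping the first len(n) samples
# (truncation is exact for these samples because the filters are causal)
h3 = np.convolve(np.convolve(h1, h1)[:len(n)], h1)[:len(n)]

# Closed form for the cascade: (1/tau)^3 * C(n - 1, 2) * pole**(n - 3) for n >= 3
closed_form = np.where(
    n >= 3,
    (1 / tau) ** 3 * (n - 1) * (n - 2) / 2 * pole ** np.maximum(n - 3, 0),
    0.0,
)

# Continuous-time approximation used for the inset (arbitrary scale)
approx = n ** 2 * np.exp(-n / tau)
shape_corr = np.corrcoef(h3, approx)[0, 1]

print("matches closed form:", np.allclose(h3, closed_form))
print(f"correlation with t^2 exp(-t/tau): {shape_corr:.5f}")
```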

---

## Non-linear response

Under the simple non-linear (SNL) model, the firing rate of a neuron is a single-valued non-linear function of an underlying linear response. We can pick any such function we want, and, for this example, we've decided on the cumulative Gaussian function (the integral of a Gaussian probability density function). The way we've parameterized it, this function has three parameters: the slope (alpha), the point of inflection (-beta/alpha), and the upper asymptote (gamma). This function relates the probability of firing in each 1 ms time step to the linear response.

In [None]:
alpha = 25
beta = -2
gamma = .15

f, ax = plt.subplots()
r = np.linspace(-.25, .3)
ax.plot(r, gamma * stats.norm.cdf(alpha * r + beta), lw=5)
ax.set(
    xlim=(-.25, .3),
    ylim=(0, gamma),
    xlabel="Linear response",
    ylabel="P(firing)",
)
f.tight_layout()
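
To make the roles of the three parameters concrete, the sketch below wraps the non-linearity in a function (the name `cumulative_gaussian` is hypothetical) and evaluates it at the inflection point and far out on either side. It also computes the actual slope at the inflection point, which depends on `gamma` as well as `alpha`.

```python
import numpy as np
from scipy import stats


def cumulative_gaussian(lin, alpha, beta, gamma):
    """Firing probability gamma * Phi(alpha * lin + beta)."""
    return gamma * stats.norm.cdf(alpha * lin + beta)


alpha, beta, gamma = 25, -2, 0.15

inflection = -beta / alpha  # linear response at which the curve is steepest
p_at_inflection = cumulative_gaussian(inflection, alpha, beta, gamma)
p_far_left = cumulative_gaussian(-10.0, alpha, beta, gamma)
p_far_right = cumulative_gaussian(10.0, alpha, beta, gamma)

# Slope at the inflection point: d/dr [gamma * Phi(alpha * r + beta)] = gamma * alpha * phi(0)
slope_analytic = gamma * alpha * stats.norm.pdf(0)
h = 1e-6
slope_numeric = (
    cumulative_gaussian(inflection + h, alpha, beta, gamma)
    - cumulative_gaussian(inflection - h, alpha, beta, gamma)
) / (2 * h)

print(f"inflection point: {inflection}")
print(f"P(firing) at inflection: {p_at_inflection}")
print(f"slope at inflection: {slope_analytic:.4f}")
```

At the inflection point the firing probability is half of the upper asymptote, and the curve saturates at `gamma` for large linear responses.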

Here we're applying this non-linear transformation to the linear response of our simulated neuron.

In [None]:
nonlin_resp = gamma * stats.norm.cdf(alpha * lin_resp + beta)

We can use this non-linear response to simulate a Poisson-ish spike train...

In [None]:
xr = rng.binomial(1, nonlin_resp)

...and we can count up the number of spikes fired by the neuron in each 10 msec wide bin (each screen refresh).

In [None]:
spike_counts = np.sum(np.split(xr, duration // stim_samp_time), axis=1)
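
An equivalent way to count spikes per frame, shown here as a sketch on made-up data, is to reshape the 1 ms spike train into one row per frame and sum along each row. This avoids building a Python list of sub-arrays. The variable names and the constant firing probability below are assumptions for the demo.

```python
import numpy as np

rng_demo = np.random.default_rng(0)  # hypothetical seed
bin_width = 10                       # ms per frame, as in the tutorial
spikes_1ms = rng_demo.binomial(1, 0.05, size=1000)  # made-up 1 s spike train

# One row per 10 ms frame, then sum across each row
counts_reshape = spikes_1ms.reshape(-1, bin_width).sum(axis=1)

# The split-based version used in the cell above, for comparison
counts_split = np.sum(np.split(spikes_1ms, len(spikes_1ms) // bin_width), axis=1)

print("identical:", np.array_equal(counts_reshape, counts_split))
```

The reshape only works when the length of the spike train is an exact multiple of the bin width, which holds here because `duration` is a multiple of `stim_samp_time`.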

So far, we constructed a white noise stimulus, we linearly filtered it, put this linearly filtered signal through a non-linear function to calculate an underlying firing rate of the cell, and used this underlying firing rate to simulate spikes coming out of the cell. Here's the first second of each of these signals:

In [None]:
f, axes = plt.subplots(4, figsize=(8, 8), sharex=True)

axes[0].plot(lin_resp[:1000])
axes[1].plot(nonlin_resp[:1000])
axes[2].plot(xr[:1000])
axes[3].plot(np.arange(0, duration, stim_samp_time)[:100], spike_counts[:100])

axes[0].set(title="Linear response", ylim=(-.35, .35))
axes[1].set(title="Firing probability", ylim=(0, .2))
axes[2].set(title="# of spikes (1 ms bins)", ylim=(0, 2))
axes[3].set(title="# of spikes (10 ms bins)", ylim=(0, 5),
            xlabel="Time (ms)", xlim=(0, 1000))

f.tight_layout()
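
The spike train is only "Poisson-ish" because each 1 ms step is an independent Bernoulli draw, which can never produce more than one spike. The sketch below, using `scipy.stats.poisson`, shows how much probability a true Poisson count with the same mean would put on two or more spikes per step. The three example probabilities are chosen for illustration, with 0.15 being the value of `gamma`, the largest firing probability in this simulation.

```python
import numpy as np
from scipy import stats

# Per-millisecond firing probabilities (illustrative values; 0.15 equals gamma)
p = np.array([0.01, 0.05, 0.15])

# Poisson count with mean p: probability of 0, exactly 1, and 2 or more spikes
pois_zero = stats.poisson.pmf(0, p)
pois_one = stats.poisson.pmf(1, p)
pois_multi = stats.poisson.sf(1, p)  # P(count >= 2)

for p_i, multi_i in zip(p, pois_multi):
    print(f"p = {p_i:.2f}: Poisson P(2+ spikes) = {multi_i:.5f}")
```

Even at the upper asymptote, a Poisson process would emit multiple spikes in a step only about 1% of the time, so the Bernoulli approximation is reasonable at this time resolution.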

---

## Spike-triggered average

Now we compute the spike-triggered average stimulus. This is accomplished by taking the 300 milliseconds of stimulus immediately preceding each spike and adding them together. This sum is then divided by the total number of spikes fired over the course of the entire experiment to determine the average stimulus preceding a spike. This spike-triggered average is, in a sense, a template for what the neuron is "looking for" in the stimulus.

In [None]:
window_size = 300
a = np.zeros(window_size)
spike_times, = np.nonzero(xr)
for t in spike_times[spike_times >= window_size]:
    a += stimulus[t - window_size:t]
a /= xr.sum()

f, ax = plt.subplots()
ax.plot(np.arange(-window_size, 0), a)
f.tight_layout()
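
The same average can be computed without a Python loop. The sketch below is one possible vectorized version, assuming NumPy 1.20 or later for `sliding_window_view`. It is checked against a loop on made-up data. Function names and demo data are hypothetical, and both versions here divide by the number of spikes actually used, so their overall scale can differ slightly from the cell above.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def sta_loop(stim, spikes, window):
    """Average the `window` stimulus samples preceding each usable spike."""
    acc = np.zeros(window)
    times, = np.nonzero(spikes)
    usable = times[times >= window]  # spikes with a full window of history
    for t in usable:
        acc += stim[t - window:t]
    return acc / len(usable)


def sta_windows(stim, spikes, window):
    """Same computation using a strided view of all length-`window` segments."""
    windows = sliding_window_view(stim, window)  # windows[k] == stim[k:k + window]
    times, = np.nonzero(spikes)
    usable = times[times >= window]
    return windows[usable - window].mean(axis=0)


rng_demo = np.random.default_rng(0)  # hypothetical seed
stim_demo = rng_demo.normal(0, 1 / 3, 5000)
spikes_demo = rng_demo.binomial(1, 0.05, 5000)

sta_a = sta_loop(stim_demo, spikes_demo, window=300)
sta_b = sta_windows(stim_demo, spikes_demo, window=300)
print("max abs difference:", np.max(np.abs(sta_a - sta_b)))
```

The fancy-indexing step copies one window per spike, so memory use grows with the number of spikes times the window length.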

A powerful result from the theory of white noise analysis states that if the neuron satisfies the assumptions of the SNL model, and the stimulus is drawn from a distribution that is symmetric about the origin, then the spike-triggered average converges, as the duration of the experiment goes to infinity, to the (time-reversed) impulse response of the linear part of the SNL system (up to an arbitrary scale factor). To the extent that neurons are well modeled as SNL systems, this means that we can easily measure the linear component. Because this is a tutorial, we know exactly what filtering was done on the stimulus to get the linear response (`lin_resp`). Recall that it was a cascade of three first-order exponential filters. Below, we compare the spike-triggered average to the impulse response of the filter we used.

In [None]:
times = np.arange(window_size)
impulse_resp = times ** 2 * np.exp(-times / tau)

f, ax = plt.subplots()
ax.plot(impulse_resp / impulse_resp.max())
ax.plot(a[::-1])
f.tight_layout()

The impulse response of the filter and the (time-flipped) spike-triggered average differ, but only in scale: one is a scaled version of the other. This reflects an ambiguity between the magnitude of the linear response and the scale of the non-linear function (intuitively, we could get identical neural responses either by scaling the linear response or by scaling the subsequent non-linear function).

Here they are again, this time scaled to have the same energy:

In [None]:
f, ax = plt.subplots()
ax.plot(impulse_resp / np.sqrt((impulse_resp ** 2).sum()))
ax.plot(a[::-1] / np.sqrt((a ** 2).sum()))
f.tight_layout()
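
The scale ambiguity described above can be checked directly. In the sketch below, the linear response is multiplied by an arbitrary factor `c` and the slope of the non-linearity is divided by the same factor. The resulting firing probabilities are identical. The second half shows that normalizing to unit energy removes such a factor from a kernel. The stand-in linear response, the seed, and the helper `unit_energy` are assumptions for the demo.

```python
import numpy as np
from scipy import stats

alpha, beta, gamma = 25, -2, 0.15
rng_demo = np.random.default_rng(0)       # hypothetical seed
lin = rng_demo.normal(0, 0.1, size=1000)  # stand-in for a linear response

c = 3.0  # arbitrary rescaling of the linear stage
p_original = gamma * stats.norm.cdf(alpha * lin + beta)
p_rescaled = gamma * stats.norm.cdf((alpha / c) * (c * lin) + beta)
print("same firing probabilities:", np.allclose(p_original, p_rescaled))


def unit_energy(v):
    """Scale v so that the sum of its squared values is 1."""
    return v / np.sqrt(np.sum(v ** 2))


tau = 15
times = np.arange(300)
kernel = times ** 2 * np.exp(-times / tau)

# A kernel and a rescaled copy become identical after normalization
same_shape = np.allclose(unit_energy(kernel), unit_energy(c * kernel))
print("same shape after normalization:", same_shape)
```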

Now we can estimate the non-linear function which relates the linear response to the probability of firing a spike. This is accomplished by plotting the estimated linear response versus the spike count in each 10 ms bin. It turns out that a simple scatter diagram is pretty uninformative because most of the points overlap each other...

In [None]:
linear_est = np.zeros(duration)
for t in range(window_size, duration):
    linear_est[t] = a @ stimulus[t - window_size:t]
    
linear_est = np.mean(np.split(linear_est, duration // stim_samp_time), axis=1)
linear_est /= stim_samp_time

f, ax = plt.subplots()
ax.scatter(linear_est, spike_counts,  linewidth=1, fc="none", ec="C0")
f.tight_layout()
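
The loop that projects the stimulus onto the spike-triggered average visits every millisecond of the experiment, which is slow in pure Python. One possible alternative, sketched below on made-up data, is `np.correlate` in `"valid"` mode, which computes the same sliding dot product. The function names and demo arrays are hypothetical.

```python
import numpy as np


def project_loop(stim, kernel):
    """Dot product of `kernel` with the window of stimulus preceding each time t."""
    window = len(kernel)
    out = np.zeros(len(stim))
    for t in range(window, len(stim)):
        out[t] = kernel @ stim[t - window:t]
    return out


def project_correlate(stim, kernel):
    """Same result using a sliding dot product."""
    window = len(kernel)
    out = np.zeros(len(stim))
    # valid[k] = sum_j stim[k + j] * kernel[j], for k = 0 .. len(stim) - window
    valid = np.correlate(stim, kernel, mode="valid")
    out[window:] = valid[:-1]  # out[t] corresponds to valid[t - window]
    return out


rng_demo = np.random.default_rng(0)  # hypothetical seed
stim_demo = rng_demo.normal(0, 1 / 3, 2000)
kernel_demo = rng_demo.normal(0, 1, 300)  # stand-in for the spike-triggered average

est_loop = project_loop(stim_demo, kernel_demo)
est_fast = project_correlate(stim_demo, kernel_demo)
print("max abs difference:", np.max(np.abs(est_loop - est_fast)))
```

For much longer kernels, an FFT-based convolution may be faster, but for a 300-sample window the direct correlation should be adequate.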

We can clean this plot up by plotting the **average** number of spikes fired in response to similar linear responses.

In [None]:
# First we decide on linear response bins...
x = np.linspace(-.2, .3, 11)

# ... and then calculate the average
means = np.zeros_like(x)
serrs = np.zeros_like(x)
for i, x_i in enumerate(x):
    bins = (linear_est > (x_i - .025)) & (linear_est <= (x_i + .025))
    bin_counts = spike_counts[bins]
    means[i] = np.mean(bin_counts)
    serrs[i] = stats.sem(bin_counts)

f, ax = plt.subplots()
ax.errorbar(x, means, serrs, ls="", marker="o", ms=5, elinewidth=3, capsize=3)
ax.set(xlabel="Linear response component", ylabel="Mean spike count")
f.tight_layout()

The same averages can be computed by assigning each sample to a bin with `np.digitize` rather than building a boolean mask per bin (the lowest bin then also collects any values below its lower edge):

In [None]:
# First we decide on linear response bins...
x = np.linspace(-.2, .3, 11)

# ... and then calculate the average
means = np.zeros_like(x)
serrs = np.zeros_like(x)
bins = np.digitize(linear_est, x + .025)
for i, _ in enumerate(x):
    bin_counts = spike_counts[bins == i]
    means[i] = np.mean(bin_counts)
    serrs[i] = stats.sem(bin_counts)

f, ax = plt.subplots()
ax.errorbar(x, means, serrs, ls="", marker="o", ms=5, elinewidth=3, capsize=3)
ax.set(xlabel="Linear response component", ylabel="Mean spike count")
f.tight_layout()

Now we can superimpose the non-linear function that we actually used to determine the spike firing probabilities. The plot of linear response versus mean spike count should have the same shape as this function, but remember, there was an arbitrary scale factor relating these two quantities. Below, we estimate this scale factor using least-squares.

In [None]:
y_unscaled = stats.norm.cdf(alpha * x + beta)
fit = optimize.lsq_linear(y_unscaled[:, None], means)
grid = np.linspace(-.2, .35, 100)
y = fit.x * stats.norm.cdf(alpha * grid + beta)

f, ax = plt.subplots()
ax.errorbar(x, means, serrs, ls="", marker="o", ms=5, elinewidth=3, capsize=3)
ax.plot(grid, y)
ax.set(xlabel="Linear response component", ylabel="Mean spike count")
f.tight_layout()
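
Because only a single scale factor is being fitted, the least-squares solution also has a simple closed form: the projection of the measured means onto the unscaled curve. The sketch below checks this against `optimize.lsq_linear` using made-up "measured" means. The true scale, noise level, and seed are assumptions for the demo, not values from the simulation.

```python
import numpy as np
from scipy import optimize, stats

alpha, beta = 25, -2
centers = np.linspace(-.2, .3, 11)
template = stats.norm.cdf(alpha * centers + beta)  # unscaled non-linearity

# Made-up "measured" means: a scaled template plus a little noise
rng_demo = np.random.default_rng(0)  # hypothetical seed
true_scale = 1.2                     # hypothetical scale factor
fake_means = true_scale * template + rng_demo.normal(0, .01, centers.size)

# Closed form for a one-parameter least-squares fit: s = (y . m) / (y . y)
scale_closed = (template @ fake_means) / (template @ template)

# The solver used in the cell above, for comparison
scale_lsq = optimize.lsq_linear(template[:, None], fake_means).x[0]

print(f"closed form: {scale_closed:.4f}, lsq_linear: {scale_lsq:.4f}")
```

If any bin is empty, its mean is NaN and should be dropped before fitting, for example with a mask built from `np.isfinite(means)`.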