# Binning diagnostic: is $\Delta t$ appropriate?

Two quick checks on whether the time bin size is appropriate for the data:

1. **Fraction of spike-free bins**: if most time bins contain zero spikes across all neurons, the Poisson likelihood map carries little positional information and the decoder relies almost entirely on the prior/smoother.

2. **Localization length**: using Skaggs spatial information per neuron (ideally shuffle-corrected; the estimate computed below is uncorrected), we estimate the spatial precision the neural population can sustain given the animal's speed. Because it works with firing rates in Hz rather than counts per bin, this check is essentially independent of the bin size.

In [87]:
import jax.numpy as jnp
import numpy as np

from simpl.environment import Environment
from simpl.kde import gaussian_kernel, kde
from simpl.utils import load_datafile, prepare_data

## Load data

In [97]:
raw = load_datafile("gridcelldata.npz")
N_NEURONS = 1
Y = raw["Y"][:, :N_NEURONS]
Xb = raw["Xb"]
time = raw["time"]

data = prepare_data(Y=Y, Xb=Xb, time=time)

DX = 0.02
env = Environment(X=data.Xb.values, pad=0.0, bin_size=DX)

# Estimate speed from behaviour
dt = float(time[1] - time[0])
dX = np.diff(data.Xb.values, axis=0)
speed = np.linalg.norm(dX, axis=1) / dt
median_speed = float(np.median(speed))
print(f"Median speed from behaviour: {median_speed:.3f} m/s")

Created a 2D cuboid environment with dimensions ['x', 'y'] and discretised shape (50, 50)
Environment limits are ((np.float64(0.0), np.float64(0.0)), (np.float64(1.0), np.float64(1.0)))
The coords of each dimension are stored in self.coords_dict and a list of combined ['x', 'y'] coords for all bins is stored in self.discretised_coords
Median speed from behaviour: 0.086 m/s


## Check 1: Fraction of spike-free bins

The simplest possible check. If a time bin has zero spikes across all neurons, the Poisson log-likelihood map reduces to $-\sum_n F_n(x)$ — peaked where neurons fire *least*, which carries little positional information.
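To see why, write the Poisson log-likelihood of counts $y_n$ given expected counts $F_n(x)$ per time bin as $\sum_n [y_n \log F_n(x) - F_n(x) - \log y_n!]$. With every $y_n = 0$, only $-\sum_n F_n(x)$ survives. A minimal NumPy/SciPy check on a hypothetical random rate map follows; `poisson_loglik_map` is an illustrative helper, not part of `simpl`.

```python
import numpy as np
from scipy.special import gammaln


def poisson_loglik_map(y, F):
    """Poisson log-likelihood of one time bin's spike counts at every position bin.

    y: (N,) spike counts; F: (N, X) expected counts per time bin for each
    neuron at each position bin. Returns the (X,) values of
    sum_n [y_n log F_n(x) - F_n(x) - log y_n!].
    """
    y = np.asarray(y, dtype=float)[:, None]
    return np.sum(y * np.log(F) - F - gammaln(y + 1), axis=0)


rng = np.random.default_rng(0)
F = rng.uniform(0.01, 0.5, size=(3, 5))  # hypothetical: 3 neurons, 5 position bins
silent = np.zeros(3)  # no neuron spikes in this time bin

ll_silent = poisson_loglik_map(silent, F)
# The log-likelihood reduces to -sum_n F_n(x) ...
print(np.allclose(ll_silent, -F.sum(axis=0)))
# ... so it peaks where the summed expected count is lowest.
print(np.argmax(ll_silent) == np.argmin(F.sum(axis=0)))
```

A single spike changes this: with $y = (1, 0, 0)$ the map becomes $\log F_0(x) - \sum_n F_n(x)$, which now rewards positions where neuron 0 fires.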

In [98]:
spikes = np.array(data.Y.values)  # (T, N_neurons)
dt = float(time[1] - time[0])

total_spikes_per_bin = spikes.sum(axis=1)
frac_silent = float(np.mean(total_spikes_per_bin == 0))

print(f"dt = {dt} s, N_neurons = {spikes.shape[1]}")
print(f"Mean spikes per time bin: {total_spikes_per_bin.mean():.2f}")
print(f"Fraction of spike-free bins: {frac_silent:.1%}")
print()
if frac_silent < 0.5:
    print("-> Healthy: most bins contain at least one spike.")
elif frac_silent < 0.8:
    print("-> Warning: majority of bins are silent. The Kalman smoother is doing heavy lifting.")
else:
    print("-> Problem: almost all bins are silent. Consider coarsening dt or adding neurons.")

dt = 0.1 s, N_neurons = 1
Mean spikes per time bin: 0.07
Fraction of spike-free bins: 92.7%

-> Problem: almost all bins are silent. Consider coarsening dt or adding neurons.
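The silent fraction can be tied to $\Delta t$ directly. Assuming, as a simplification, that the population fires as a homogeneous Poisson process with total rate $\Lambda$ (spikes/s summed over neurons), a bin is silent with probability $e^{-\Lambda \Delta t}$, so reaching a target silent fraction $f$ needs $\Delta t = -\ln f / \Lambda$. Under a Poisson model, spatial tuning makes silent bins more likely than this (by Jensen's inequality), so treat the result as a lower bound on the bin size needed. A sketch with hypothetical helper names, using the rounded numbers printed above as inputs:

```python
import numpy as np


def silent_fraction(total_rate_hz, dt):
    """Expected fraction of spike-free bins for a homogeneous Poisson
    population with total firing rate `total_rate_hz` (summed over neurons)."""
    return np.exp(-total_rate_hz * dt)


def dt_for_silent_fraction(total_rate_hz, target):
    """Bin size (s) at which the expected spike-free fraction equals `target`."""
    return -np.log(target) / total_rate_hz


# Illustrative inputs from the printed output above (rounded):
# about 0.07 spikes per 0.1 s bin, i.e. roughly 0.7 Hz summed over neurons.
total_rate = 0.07 / 0.1

print(f"Predicted silent fraction at dt=0.1 s: {silent_fraction(total_rate, 0.1):.1%}")
print(f"dt for 50% silent bins: {dt_for_silent_fraction(total_rate, 0.5):.2f} s")
print(f"dt for 80% silent bins: {dt_for_silent_fraction(total_rate, 0.8):.2f} s")
```

With these rounded inputs the prediction at $\Delta t = 0.1$ s is about 93%, in line with the observed 92.7%, and getting below half silent bins would need $\Delta t$ of roughly 1 s for this single-neuron example.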


## Check 2: Can information keep up with the animal?

Not all spikes are equal — a spike from a spatially tuned neuron constrains position far more than one from a flat-firing neuron. We quantify this with **Skaggs spatial information** (bits/spike):

$$I_n = \sum_x \frac{r_n(x)}{\bar{r}_n} \log_2 \frac{r_n(x)}{\bar{r}_n} \, P(x)$$

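where $r_n(x)$ is the firing rate of neuron $n$ at position $x$, $\bar{r}_n = \sum_x r_n(x) P(x)$ is its mean firing rate, and $P(x)$ is the occupancy probability (the fraction of time spent at $x$). Treating neurons as independent, so that their information adds, the population **information rate** is: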

$$\dot{\mathcal{I}} = \sum_n \bar{r}_n \cdot I_n \qquad \text{(bits/s)}$$
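As a sanity check on these two formulas: a neuron that fires at a constant rate inside a fraction $f$ of a uniformly occupied environment and is silent elsewhere carries exactly $\log_2(1/f)$ bits/spike, while a flat-firing neuron carries 0. A minimal NumPy sketch with hypothetical rate maps follows; `skaggs_info` is an illustrative helper, separate from the notebook's `compute_spatial_info`.

```python
import numpy as np


def skaggs_info(rate_map, occupancy):
    """Skaggs spatial information (bits/spike) per neuron.

    rate_map: (N, X) firing rates in Hz; occupancy: (X,) probabilities summing to 1.
    Returns (info, mean_rate), both of shape (N,).
    """
    mean_rate = rate_map @ occupancy  # occupancy-weighted mean rate per neuron
    ratio = rate_map / mean_rate[:, None]
    # Zero-rate bins contribute nothing (the limit of p*log p as p -> 0).
    log_ratio = np.log2(np.where(ratio > 0, ratio, 1.0))
    return (ratio * log_ratio) @ occupancy, mean_rate


n_bins = 100
occupancy = np.full(n_bins, 1.0 / n_bins)  # hypothetical uniform occupancy

# Hypothetical cells: flat 5 Hz firing, and 20 Hz firing in 10% of bins (f = 0.1).
flat = np.full(n_bins, 5.0)
field = np.zeros(n_bins)
field[:10] = 20.0
rates = np.stack([flat, field])

info, mean_rate = skaggs_info(rates, occupancy)
info_rate = float(np.sum(mean_rate * info))  # bits/s, neurons treated as independent
print(info)       # about [0, log2(10) = 3.32] bits/spike
print(info_rate)  # about 2 Hz * 3.32 bits/spike = 6.64 bits/s
```

The flat cell fires more spikes but contributes nothing to $\dot{\mathcal{I}}$; only the tuned cell's mean rate (2 Hz) times its information per spike counts.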

### Deriving the localization length

Over a time window $\tau$, the decoder accumulates $\dot{\mathcal{I}} \cdot \tau$ bits. In a $D$-dimensional environment of side $L$, resolving position to precision $\ell$ requires distinguishing $(L/\ell)^D$ possible locations, which costs $D \log_2(L/\ell)$ bits. The decoder can localise to precision $\ell$ if:

$$\dot{\mathcal{I}} \cdot \tau \geq D \log_2 \frac{L}{\ell}$$

The natural timescale is self-consistent: $\tau = \ell / v$, the time for the animal (moving at speed $v$) to cross the resolution element itself. Longer than this and the animal has moved; shorter and we're asking for more precision than the window supports. Substituting:

$$\dot{\mathcal{I}} \cdot \frac{\ell}{v} = D \log_2 \frac{L}{\ell}$$

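Because $\ell$ appears both linearly and inside the logarithm, this balance has no elementary closed form. Setting $u = L/\ell$ and $\kappa = \dot{\mathcal{I}} \, L \ln 2 / (D \, v)$ reduces it to $u \ln u = \kappa$, whose solution is $\ln u = W(\kappa)$, with $W$ the Lambert W function. For a simple closed-form diagnostic, the logarithmic factor $\log_2(L/\ell)$ is replaced by the constant $\ln 2$, giving the characteristic scale:

$$\ell^* = \frac{D \, v \, \ln 2}{\dot{\mathcal{I}}}$$

(This is a heuristic scale rather than an exact solution. Whenever $L/\ell > 2^{\ln 2} \approx 1.62$, and in particular whenever $\ell \ll L$, we have $\log_2(L/\ell) > \ln 2$, so the exact localization length is larger than $\ell^*$ and the diagnostic errs on the optimistic side.)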
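The balance $\dot{\mathcal{I}} \, \ell / v = D \log_2(L/\ell)$ can also be solved exactly with the Lambert W function and compared with the closed-form scale $\ell^* = D v \ln 2 / \dot{\mathcal{I}}$. A minimal sketch, assuming SciPy is available (`scipy.special.lambertw` is not used elsewhere in this notebook); the helper names and input values are hypothetical.

```python
import numpy as np
from scipy.special import lambertw


def localization_length_exact(info_rate, v, L, D):
    """Solve info_rate * ell / v = D * log2(L / ell) for ell.

    With u = L / ell and kappa = info_rate * L * ln2 / (D * v), the balance
    becomes u * ln(u) = kappa, so ln(u) = W(kappa) on the principal branch.
    """
    kappa = info_rate * L * np.log(2) / (D * v)
    w = lambertw(kappa, k=0).real  # lambertw returns a complex value
    return L * np.exp(-w)


def localization_length_heuristic(info_rate, v, D):
    """Closed-form scale ell* = D * v * ln2 / info_rate."""
    return D * v * np.log(2) / info_rate


# Hypothetical inputs: 2D arena of side 1 m, animal moving at 0.1 m/s.
L, D, v = 1.0, 2, 0.1
for info_rate in [0.5, 5.0, 50.0]:  # bits/s
    ell = localization_length_exact(info_rate, v, L, D)
    residual = info_rate * ell / v - D * np.log2(L / ell)  # should be ~0
    print(f"{info_rate:5.1f} bits/s: exact {ell:.3f} m, "
          f"heuristic {localization_length_heuristic(info_rate, v, D):.3f} m, "
          f"residual {residual:.1e}")
```

The residual column confirms that the exact value satisfies the balance; comparing the two length columns shows how far the closed-form scale departs from it as the information rate grows and the log factor gets larger.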

### The localization speed

Since $\ell^* \propto v$, faster movement means coarser resolution. There is a critical speed $v^*$ at which the localization length reaches the environment size ($\ell^* = L$), i.e. the decoder can barely localise at all:

$$L = \frac{D \, v^* \, \ln 2}{\dot{\mathcal{I}}} \quad \Longrightarrow \quad v^* = \frac{\dot{\mathcal{I}} \, L}{D \, \ln 2}$$

The diagnostic ratio is:

$$\frac{v^*}{v} = \frac{L}{\ell^*}$$

- $v^*/v \gg 1$: information arrives much faster than the animal moves — comfortable margin for decoding.
- $v^*/v \gtrsim 1$: marginal — the decoder can just keep up.
- $v^*/v < 1$: the animal outpaces the information — the decoder cannot localise.
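The ratio identity follows from the formulas for $v^*$ and $\ell^*$: divide $v^*$ by $v$ and recognise $\ell^*$ in the denominator.

$$\frac{v^*}{v} = \frac{\dot{\mathcal{I}} \, L}{D \, v \, \ln 2} = \frac{L}{D \, v \, \ln 2 / \dot{\mathcal{I}}} = \frac{L}{\ell^*}$$

For example, $v^*/v = 1.4$ in an arena of side $L = 1$ m corresponds to $\ell^* = L / 1.4 \approx 0.71$ m.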

In [99]:
# Estimate place fields from behavioural trajectory (epoch-0 M-step)
bins = jnp.array(env.flattened_discretised_coords)
X = jnp.array(data.Xb.values)
spikes_jnp = jnp.array(data.Y.values)

F, PX = kde(
    bins=bins,
    trajectory=X,
    spikes=spikes_jnp,
    kernel=gaussian_kernel,
    kernel_bandwidth=0.02,
    return_position_density=True,
)


def compute_spatial_info(r, r_mean, PX):
    """Skaggs spatial information (bits/spike) per neuron."""
    eps = 1e-10
    ratio = r / (r_mean[:, None] + eps)
    return jnp.clip(jnp.sum(ratio * jnp.log2(ratio + eps) * PX[None, :], axis=1), 0, None)


# Spatial information per neuron
r = F / dt  # firing rates in Hz
r_mean = jnp.sum(r * PX[None, :], axis=1)
spatial_info = compute_spatial_info(r, r_mean, PX)

# Information rate (bits/s)
info_rate = float(jnp.sum(r_mean * spatial_info))

# Localization speed and ratio
D = env.D
L = float(np.mean([env.lims[1][d] - env.lims[0][d] for d in range(D)]))
loc_speed = info_rate * L / (D * np.log(2))
ratio = loc_speed / median_speed

print(f"Info rate:             {info_rate:.1f} bits/s")
print(f"Median animal speed:   {median_speed:.3f} m/s")
print(f"Localization speed:    {loc_speed:.3f} m/s")
print(f"v*/v:                  {ratio:.1f}")
print()
if ratio > 1:
    print(f"-> Information can keep up with the animal ({ratio:.1f}x faster).")
else:
    print(f"-> Warning: the animal outpaces the information ({ratio:.1f}x). Consider coarsening dt or adding neurons.")

Info rate:             0.2 bits/s
Median animal speed:   0.086 m/s
Localization speed:    0.123 m/s
v*/v:                  1.4

-> Information can keep up with the animal (1.4x faster).
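Item 2 at the top refers to shuffle correction, while the cell above uses the raw Skaggs estimate. Raw Skaggs information is never negative and is biased upward at finite spike counts, so even an untuned neuron scores above zero. A minimal sketch of one common correction, circular-shift shuffling, follows. It is self-contained and uses a simple histogram rate map on hypothetical synthetic data rather than `simpl`'s `kde`. `histogram_rate_map`, `skaggs`, and `shuffle_corrected_info` are illustrative helpers, not part of `simpl`.

```python
import numpy as np


def histogram_rate_map(pos, spikes, dt, n_bins):
    """Histogram rate map (Hz) and occupancy probability on the unit square."""
    idx = np.minimum((pos * n_bins).astype(int), n_bins - 1)
    flat = idx[:, 0] * n_bins + idx[:, 1]
    occ = np.bincount(flat, minlength=n_bins**2).astype(float)
    spk = np.bincount(flat, weights=spikes, minlength=n_bins**2)
    rates = np.divide(spk, occ * dt, out=np.zeros_like(spk), where=occ > 0)
    return rates, occ / occ.sum()


def skaggs(rates, occupancy):
    """Skaggs information (bits/spike) for a single neuron."""
    mean_rate = rates @ occupancy
    if mean_rate <= 0:
        return 0.0
    ratio = rates / mean_rate
    log_ratio = np.log2(np.where(ratio > 0, ratio, 1.0))  # 0 * log(0) taken as 0
    return float(np.sum(occupancy * ratio * log_ratio))


def shuffle_corrected_info(pos, spikes, dt, n_bins=10, n_shuffles=100, seed=0):
    """Return (raw info, raw info minus mean info of circularly shifted spikes)."""
    rng = np.random.default_rng(seed)
    T = len(spikes)
    min_shift = T // 10  # skip small shifts that would keep spikes near their true positions
    raw = skaggs(*histogram_rate_map(pos, spikes, dt, n_bins))
    shifts = rng.integers(min_shift, T - min_shift, size=n_shuffles)
    null = [skaggs(*histogram_rate_map(pos, np.roll(spikes, s), dt, n_bins)) for s in shifts]
    return raw, raw - float(np.mean(null))


# Hypothetical synthetic data: uniform positions, one flat and one place-tuned cell.
rng = np.random.default_rng(1)
dt, T = 0.1, 20_000
pos = rng.uniform(0.0, 1.0, size=(T, 2))
d2 = np.sum((pos - 0.5) ** 2, axis=1)
flat_spikes = rng.poisson(1.0 * dt, size=T)
place_spikes = rng.poisson((0.1 + 5.0 * np.exp(-d2 / (2 * 0.1**2))) * dt)

raw_flat, corr_flat = shuffle_corrected_info(pos, flat_spikes, dt)
raw_place, corr_place = shuffle_corrected_info(pos, place_spikes, dt)
print(f"flat cell:  raw {raw_flat:.3f}, corrected {corr_flat:.3f} bits/spike")
print(f"place cell: raw {raw_place:.3f}, corrected {corr_place:.3f} bits/spike")
```

The flat cell's raw value is pure sampling bias, which the shuffle subtracts; the tuned cell keeps most of its information. Plugging corrected values into $\dot{\mathcal{I}} = \sum_n \bar{r}_n I_n$ would make $v^*$ more conservative.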
