
# SPADE (Spatio-Temporal Pattern Detection) — Step-by-step Notebook

This notebook re-creates the SPADE pipeline (binning → transactions → closed frequent patterns → significance via surrogate-based *pattern spectrum* → pattern-set reduction).
It assumes your spikes are provided as a NumPy array of shape `(2, n)` where the first row contains spike times (in **milliseconds**) and the second row contains integer neuron IDs in `[0, N-1]`.
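
As a minimal sketch of the expected input layout, the snippet below packs per-neuron spike-time lists into a `(2, n)` array sorted by time. The dictionary `spike_times_by_neuron`, its values, and the name `spikes_demo` are made up for illustration and are not used elsewhere in the notebook.

```python
import numpy as np

# Hypothetical per-neuron spike times in milliseconds (illustrative values only)
spike_times_by_neuron = {0: [12.5, 40.0], 1: [3.2], 2: [15.0, 15.8, 90.1]}

# Flatten to parallel arrays of times and neuron IDs
times = np.concatenate([np.asarray(t, dtype=float) for t in spike_times_by_neuron.values()])
ids = np.concatenate([np.full(len(t), i, dtype=float) for i, t in spike_times_by_neuron.items()])

# Sort spikes by time and stack: row 0 = times_ms, row 1 = neuron IDs
order = np.argsort(times, kind="stable")
spikes_demo = np.vstack([times[order], ids[order]])

print(spikes_demo.shape)  # (2, 6)
```

Neuron IDs are stored as floats because both rows share one array; the binning step casts them back to integers.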


In [1]:

# === Imports ===
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from dataclasses import dataclass
from typing import List, Dict, Tuple, Set, Iterable, Optional
import math
np.set_printoptions(edgeitems=5, linewidth=120, suppress=True)


In [3]:

# (Optional) Synthetic data generator for quick testing.
# If you already have `spikes` as a (2, n) array, skip this cell.
def make_synthetic_spikes(N=20, duration_ms=5000, rate_hz=5.0, seed=0,
                          embed_pattern=True, pattern_size=4, pattern_lag_ms=(0,3,7,12),
                          repetitions=8, base_time_ms=500, rep_stride_ms=400):
    rng = np.random.default_rng(seed)
    # Homogeneous Poisson spikes per neuron
    spikes = []
    for i in range(N):
        # Expected number of spikes:
        lam = rate_hz * (duration_ms / 1000.0)
        n_i = rng.poisson(lam)
        times = rng.uniform(0, duration_ms, size=n_i)
        for t in times:
            spikes.append((t, i))
    # Embed a repeating pattern (same relative lags across a subset of neurons)
    if embed_pattern:
        neurons = rng.choice(N, size=pattern_size, replace=False)
        for r in range(repetitions):
            t0 = base_time_ms + r * rep_stride_ms
            for j, lag in enumerate(pattern_lag_ms[:pattern_size]):
                t = t0 + lag
                if 0 <= t < duration_ms:
                    spikes.append((t, int(neurons[j])))
    spikes = np.array(spikes, dtype=float)
    # Sort by time
    order = np.argsort(spikes[:,0])
    spikes = spikes[order]
    # Shape to (2, n): row0 = times_ms, row1 = neuron_ids
    spikes_out = np.vstack([spikes[:,0], spikes[:,1].astype(int)]).astype(float)
    return spikes_out

# Example: comment out if you already have your own `spikes`
spikes = make_synthetic_spikes()
spikes.shape


(2, 529)

In [4]:

# === Parameters (edit these) ===
# If you already have `spikes` defined as (2, n) [times_ms, neuron_ids], leave it as is.
# Otherwise, run the generator cell above or load your data here.

# Required: bin size and window length for pattern mining
dt_ms = 1.0                 # raster bin width (ms)
window_bins = 20            # K: number of bins in each sliding window
min_support = 2             # minimum number of window occurrences for a pattern to be considered frequent
max_itemset_size = None        # optional cap on pattern size to keep mining tractable (None for no cap)

# Surrogate / significance settings
n_surrogates = 20           # number of dither surrogates for pattern-spectrum (keep small for demo; raise for real use)
dither_halfwidth_bins = 1   # J: +/- jitter range in bins (e.g., 5 bins = +/- 5 ms if dt_ms=1)
alpha = 0.05                # FDR target for PSF; used also as per-test threshold in PSR demo
psr_h = 0                   # SPADE paper uses h=0
psr_k = 2                   # SPADE paper uses k=2

# If you generated synthetic data above, ensure `spikes` exists.
try:
    spikes  # noqa: F821
except NameError:
    print("Note: `spikes` is not defined. Either run the synthetic generator cell or load your data into a (2, n) array named `spikes`.")


In [5]:

def bin_spikes(spikes_2xn: np.ndarray, dt_ms: float, t_start_ms: Optional[float]=None, t_end_ms: Optional[float]=None):
    """
    Convert (2, n) spikes [times_ms; neuron_id] to a binary raster of shape (N, B).
    - Enforces at most one spike per neuron per bin via clipping to {0,1}.
    Returns: raster (bool array), bin_edges_ms (length B+1), neuron_count N
    """
    assert spikes_2xn.shape[0] == 2, "Input must be (2, n): [times_ms; neuron_ids]"
    times_ms = spikes_2xn[0]
    neuron_ids = spikes_2xn[1].astype(int)
    N = int(neuron_ids.max()) + 1 if neuron_ids.size > 0 else 0
    if t_start_ms is None:
        t_start_ms = 0.0 if times_ms.size == 0 else float(np.floor(times_ms.min()))
    if t_end_ms is None:
        t_end_ms = float(np.ceil(times_ms.max())) if times_ms.size > 0 else 0.0
    if t_end_ms < t_start_ms:
        t_end_ms = t_start_ms
    B = int(np.ceil((t_end_ms - t_start_ms) / dt_ms))
    bin_edges_ms = t_start_ms + np.arange(B+1) * dt_ms
    raster = np.zeros((N, B), dtype=bool)
    if times_ms.size == 0 or N == 0 or B == 0:
        return raster, bin_edges_ms, N
    # Bin indices (floor, so spikes before t_start_ms get negative indices and are masked out below)
    bin_idx = np.floor((times_ms - t_start_ms) / dt_ms).astype(int)
    mask = (bin_idx >= 0) & (bin_idx < B) & (neuron_ids >= 0) & (neuron_ids < N)
    bin_idx = bin_idx[mask]
    neuron_ids = neuron_ids[mask]
    # Set raster; clip multi-hits to 1
    raster[neuron_ids, bin_idx] = True
    return raster, bin_edges_ms, N

# Example usage (if you have `spikes`):
# raster, bin_edges_ms, N = bin_spikes(spikes, dt_ms)
# raster.shape, N
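
To make the binning rule concrete, here is a small pure-Python sketch for a single neuron. It assumes bin indices are computed with a floor, so that spikes before the start of the window receive a negative index and are dropped; the helper `bin_index` and the numbers are illustrative only.

```python
import math

def bin_index(t_ms, t_start_ms, dt_ms):
    """Bin index of a spike time; negative if the spike precedes t_start_ms."""
    return math.floor((t_ms - t_start_ms) / dt_ms)

dt_ms, t_start_ms, n_bins = 2.0, 10.0, 5      # bins cover [10, 20) ms
times_ms = [9.5, 10.0, 11.9, 12.0, 19.9, 20.0]

occupied = set()
for t in times_ms:
    b = bin_index(t, t_start_ms, dt_ms)
    if 0 <= b < n_bins:                       # drop spikes outside the window
        occupied.add(b)                       # a set clips repeated hits to one

print(sorted(occupied))  # [0, 1, 4]
```

The spikes at 10.0 ms and 11.9 ms fall into the same bin and count once, which mirrors the clipping to `{0,1}` in the raster.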


In [6]:

def make_transactions(raster: np.ndarray, window_bins: int) -> Tuple[List[Set[int]], int]:
    """
    Build transactions from raster using a sliding window of `window_bins` bins.
    Each transaction is a set of integer-encoded attributes a = i*K + kappa where
    i = neuron index, kappa = local time within the window [0..K-1].
    Returns:
      - transactions: list of sets
      - K: the window_bins (redundant, but helpful for decoding)
    """
    N, B = raster.shape
    K = int(window_bins)
    if K <= 0:
        return [], K
    transactions: List[Set[int]] = []
    # For each window start tau in [0 .. B-K]
    for tau in range(0, B - K + 1):
        # Collect (i, kappa) where spike present
        attrs = set()
        # Find indices where any spike in window
        # Vectorized: find all (i,b) spikes, then map to kappa=b-tau in [0..K-1]
        # But to keep it readable:
        window = raster[:, tau:tau+K]
        is_i, is_kappa = np.where(window)
        for i, kappa in zip(is_i, is_kappa):
            attr = int(i * K + kappa)
            attrs.add(attr)
        transactions.append(attrs)
    return transactions, K

def decode_attr(attr: int, K: int) -> Tuple[int, int]:
    return (attr // K, attr % K)

# Example (after binning):
# tx, K = make_transactions(raster, window_bins)
# len(tx), K, len(tx[0]) if tx else None
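
The attribute encoding `a = i*K + kappa` can be checked by hand. The sketch below (helper names `encode` and `decode` are illustrative) also shows why one spike contributes to `K` consecutive windows, each time with a different offset `kappa`; this is why a single repeating pattern shows up as several time-shifted copies in the mined output.

```python
K = 20  # window length in bins, matching `window_bins` above

def encode(i, kappa, K):
    """Map (neuron index, bin offset in window) to a single integer attribute."""
    return i * K + kappa

def decode(attr, K):
    """Inverse of encode: returns (neuron index, bin offset in window)."""
    return divmod(attr, K)

attr = encode(6, 7, K)
print(attr)             # 127
print(decode(attr, K))  # (6, 7)

# A spike at absolute bin b lies in every window starting at tau = b-K+1 .. b
# (ignoring raster edges), with offset kappa = b - tau.
b = 30
appearances = [(tau, b - tau) for tau in range(b - K + 1, b + 1)]
print(len(appearances), appearances[0], appearances[-1])  # 20 (11, 19) (30, 0)
```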


In [7]:

@dataclass(frozen=True)
class Pattern:
    items: frozenset  # of int attributes
    extent: np.ndarray  # window indices (occurrence windows)
    size: int           # z = |items|
    support: int        # c = |extent|

def _build_inverted_index(transactions: List[Set[int]]) -> Dict[int, np.ndarray]:
    """Return dict: item -> sorted array of transaction indices where item occurs."""
    item_to_idxs: Dict[int, List[int]] = {}
    for t_idx, trans in enumerate(transactions):
        for a in trans:
            item_to_idxs.setdefault(a, []).append(t_idx)
    return {a: np.array(sorted(v), dtype=int) for a, v in item_to_idxs.items()}

def _intersect_extents(extents: Iterable[np.ndarray]) -> np.ndarray:
    extents = list(extents)
    if not extents:
        return np.array([], dtype=int)
    # Start with the smallest to intersect efficiently
    extents.sort(key=lambda x: x.size)
    res = extents[0]
    for e in extents[1:]:
        res = np.intersect1d(res, e, assume_unique=True)
        if res.size == 0:
            break
    return res

def apriori_closed_patterns(transactions: List[Set[int]], min_support: int, max_size: Optional[int]=None) -> List[Pattern]:
    """
    Simple Apriori miner that returns only **closed** frequent itemsets as Pattern objects.
    - Uses extent intersections to compute supports efficiently.
    - After mining all sizes, filters non-closed sets (those with a strict superset with same support).
    """
    if not transactions:
        return []
    item_to_idxs = _build_inverted_index(transactions)
    # L1 frequent singletons
    Lk = []
    for a, idxs in item_to_idxs.items():
        if idxs.size >= min_support:
            Lk.append((frozenset([a]), idxs))
    Lk = sorted(Lk, key=lambda x: (len(x[0]), -x[1].size))
    all_levels = [Lk]
    k = 1
    # Generate higher-order candidates
    while Lk:
        if max_size is not None and k >= max_size:
            break
        # Join step: unify pairs sharing k-1 items prefix
        Ck1 = []
        Lk_sorted = sorted(Lk, key=lambda x: tuple(sorted(x[0])))
        for i in range(len(Lk_sorted)):
            for j in range(i+1, len(Lk_sorted)):
                A, extA = Lk_sorted[i]
                B, extB = Lk_sorted[j]
                A_sorted = tuple(sorted(A))
                B_sorted = tuple(sorted(B))
                if A_sorted[:k-1] != B_sorted[:k-1]:
                    break
                candidate = A.union(B)
                if len(candidate) != k+1:
                    continue
                # No explicit Apriori subset pruning is needed: the extent of A ∪ B is exactly
                # the intersection of the extents of A and B, so the support check below is exact.
                ext = _intersect_extents([extA, extB])
                if ext.size >= min_support:
                    Ck1.append((candidate, ext))
        # Deduplicate candidates
        seen = {}
        for items, ext in Ck1:
            key = tuple(sorted(items))
            if key not in seen:
                seen[key] = (items, ext)
        Lk = sorted(seen.values(), key=lambda x: (len(x[0]), -x[1].size))
        if Lk:
            all_levels.append(Lk)
        k += 1

    # Flatten all frequent itemsets
    all_freq = []
    for level in all_levels:
        for items, ext in level:
            all_freq.append((items, ext))

    # Filter to closed: remove any itemset that has a strict superset with the same support
    # Build support -> list of itemsets
    sup_to_sets: Dict[int, List[frozenset]] = {}
    for items, ext in all_freq:
        sup_to_sets.setdefault(ext.size, []).append(items)

    closed = []
    for items, ext in all_freq:
        is_closed = True
        for sup_items in sup_to_sets[ext.size]:
            if items < sup_items:  # strict subset
                is_closed = False
                break
        if is_closed:
            closed.append(Pattern(frozenset(items), ext, len(items), ext.size))
    # Sort by (size desc, support desc)
    closed.sort(key=lambda p: (-p.size, -p.support, tuple(sorted(p.items))))
    return closed

# Example (after transactions):
# patterns = apriori_closed_patterns(tx, min_support, max_itemset_size)
# len(patterns), patterns[0].size, patterns[0].support if patterns else None
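
To illustrate what "closed" means, the following brute-force sketch enumerates all frequent itemsets of a tiny, made-up transaction list and keeps only those without a strict superset of equal support. It is meant as a readable reference for small inputs, not as a replacement for the miner above.

```python
from itertools import combinations

# Toy transactions (illustrative only)
transactions = [{1, 2, 3}, {1, 2, 3}, {1, 2}, {4}]
min_support = 2

def support(itemset):
    """Number of transactions that contain every item of `itemset`."""
    return sum(1 for t in transactions if itemset <= t)

# Enumerate every frequent itemset by brute force
items = sorted(set().union(*transactions))
frequent = {}
for size in range(1, len(items) + 1):
    for combo in combinations(items, size):
        s = support(set(combo))
        if s >= min_support:
            frequent[frozenset(combo)] = s

# Closed = no strict superset with the same support
closed = {
    a: s for a, s in frequent.items()
    if not any(a < b and s == sb for b, sb in frequent.items())
}
print(sorted((sorted(a), s) for a, s in closed.items()))
# [([1, 2], 3), ([1, 2, 3], 2)]
```

Here `{1}` and `{2}` are frequent but not closed, because `{1, 2}` occurs in exactly the same three transactions.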


In [8]:

def decode_pattern_items(pattern: Pattern, K: int) -> List[Tuple[int, int]]:
    return [decode_attr(a, K) for a in sorted(pattern.items)]

def pattern_signature(pattern: Pattern) -> Tuple[int, int]:
    return (pattern.size, pattern.support)

def summarize_patterns(patterns: List[Pattern], K: int, max_show: int=10):
    print(f"Found {len(patterns)} closed frequent patterns.")
    for idx, p in enumerate(patterns[:max_show], 1):
        pairs = decode_pattern_items(p, K)
        print(f"{idx:2d}) size z={p.size:2d}, support c={p.support:3d}, items (i,kappa)={pairs}")


In [9]:

def plot_raster(raster: np.ndarray, dt_ms: float, t_start_ms: float=0.0,
                pattern: Optional[Pattern]=None, K: Optional[int]=None, highlight_extent: Optional[np.ndarray]=None):
    """
    Simple raster plot (one figure). Optionally highlights the occurrence windows of a given pattern.
    - highlight_extent: array of window start indices (tau indices) where the pattern occurs.
    """
    N, B = raster.shape
    fig, ax = plt.subplots(figsize=(10, 4))
    # Scatter spikes
    is_i, is_b = np.where(raster)
    ax.scatter(is_b * dt_ms + t_start_ms, is_i, s=4)
    # Highlight occurrence windows (if provided)
    if pattern is not None and K is not None and highlight_extent is not None and highlight_extent.size > 0:
        for tau in highlight_extent:
            x0 = (tau * dt_ms) + t_start_ms
            x1 = x0 + K * dt_ms
            ax.axvspan(x0, x1, alpha=0.15)
    ax.set_xlabel("Time (ms)")
    ax.set_ylabel("Neuron ID")
    ax.set_title("Raster")
    plt.show()


In [10]:

def dither_raster(raster: np.ndarray, halfwidth_bins: int, rng: Optional[np.random.Generator]=None) -> np.ndarray:
    """
    Spike dithering on raster: shift each spike independently by a uniform integer in [-J, +J] bins.
    Clips at edges and enforces boolean spikes (collisions are fine -> remain True).
    """
    if rng is None:
        rng = np.random.default_rng()
    N, B = raster.shape
    if halfwidth_bins <= 0:
        return raster.copy()
    is_i, is_b = np.where(raster)
    shifts = rng.integers(-halfwidth_bins, halfwidth_bins+1, size=is_i.size)
    new_b = np.clip(is_b + shifts, 0, B-1)
    sur = np.zeros_like(raster, dtype=bool)
    sur[is_i, new_b] = True
    return sur
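
Two side effects of dithering on a binary raster are worth keeping in mind: spikes that land in the same bin merge, and spikes near the edges are clipped into the first or last bin. The pure-Python sketch below shows both for one neuron; `dither_bins` is an illustrative helper, not part of the pipeline.

```python
import random

def dither_bins(spike_bins, halfwidth, n_bins, rng):
    """Shift each occupied bin by a uniform integer in [-halfwidth, +halfwidth], clipped to the raster."""
    shifted = set()
    for b in spike_bins:
        new_b = b + rng.randint(-halfwidth, halfwidth)  # randint includes both endpoints
        shifted.add(min(max(new_b, 0), n_bins - 1))     # clip at the raster edges
    return sorted(shifted)                              # the set merges collisions

rng = random.Random(0)
original = [0, 5, 6, 99]
surrogate = dither_bins(original, halfwidth=2, n_bins=100, rng=rng)
print(len(original), len(surrogate))
```

Because of merging, a surrogate can contain fewer spikes than the original raster, so surrogate firing rates are only approximately preserved.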


In [11]:

def pattern_spectrum_pvals(raster: np.ndarray, window_bins: int, min_support: int, max_itemset_size: Optional[int],
                           n_surrogates: int, dither_halfwidth_bins: int, rng: Optional[np.random.Generator]=None) -> Dict[Tuple[int,int], float]:
    """
    Compute a SPADE-like pattern spectrum:
      p(z,c) = fraction of surrogates that produce at least one closed frequent pattern with signature (z,c).
    Uses the simple Apriori miner above for demo purposes.
    Returns a dict mapping (z,c) -> p-value, smoothed as (count + 1) / (n_surrogates + 1) to avoid zeros.
    """
    if rng is None:
        rng = np.random.default_rng()
    K = window_bins
    B = raster.shape[1]
    # Collect signatures across surrogates
    sig_hits: Dict[Tuple[int,int], int] = {}
    for _ in tqdm(range(n_surrogates), desc="Surrogates"):
        sur = dither_raster(raster, dither_halfwidth_bins, rng)
        tx, _ = make_transactions(sur, K)
        if not tx:
            continue
        pats = apriori_closed_patterns(tx, min_support, max_itemset_size)
        seen = set(pattern_signature(p) for p in pats)
        for sig in seen:
            sig_hits[sig] = sig_hits.get(sig, 0) + 1
    # Convert to p-values with +1 smoothing
    pvals = {}
    for sig, count in sig_hits.items():
        pvals[sig] = (count + 1) / (n_surrogates + 1)
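    # Signatures never seen in any surrogate are absent from the dict. Downstream lookups use
    # p_spectrum.get(sig, 1.0), so such signatures are treated as non-significant; the smoothed
    # empirical estimate for them would instead be 1 / (n_surrogates + 1).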
    return pvals
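
A short worked calculation shows how the smoothing interacts with the number of surrogates. The helpers `smoothed_p` and `lookup_p` are hypothetical; in particular, `lookup_p` sketches one possible alternative default for signatures that never appear in a surrogate, and is not what the pipeline in this notebook does.

```python
def smoothed_p(count, n_surrogates):
    """Smoothed empirical p-value: (count + 1) / (n_surrogates + 1)."""
    return (count + 1) / (n_surrogates + 1)

n_surrogates = 20

# Any signature stored in the spectrum was seen in at least one surrogate
min_stored_p = smoothed_p(1, n_surrogates)
print(round(min_stored_p, 4))   # 0.0952

# Smoothed estimate for a signature seen in no surrogate at all
unseen_p = smoothed_p(0, n_surrogates)
print(round(unseen_p, 4))       # 0.0476

def lookup_p(p_spectrum, sig, n_surrogates):
    """Hypothetical lookup: unseen signatures get 1/(n_surrogates+1) instead of 1.0."""
    return p_spectrum.get(sig, smoothed_p(0, n_surrogates))

print(round(lookup_p({(2, 2): 1.0}, (4, 8), n_surrogates), 4))  # 0.0476
```

With 20 surrogates, every stored p-value is at least about 0.095, and unseen signatures default to 1.0 on lookup, so no signature can pass `alpha = 0.05`. This is consistent with the empty PSF result in the example run; more surrogates and a different default for unseen signatures would both be needed to change it.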


In [12]:

def benjamini_hochberg(pvals: Dict[Tuple[int,int], float], alpha: float) -> Set[Tuple[int,int]]:
    """Return the set of signatures passing BH-FDR at level alpha."""
    if not pvals:
        return set()
    items = sorted(pvals.items(), key=lambda kv: kv[1])
    m = len(items)
    # Step-up rule: find the largest rank k with p_(k) <= alpha * k / m
    k_star = 0
    for k, (_, p) in enumerate(items, start=1):
        if p <= alpha * k / m:
            k_star = k
    # Accept the k_star signatures with the smallest p-values
    return {sig for sig, _ in items[:k_star]}

def psf_filter(patterns: List[Pattern], p_spectrum: Dict[Tuple[int,int], float], alpha: float) -> List[Pattern]:
    """Keep patterns whose signature is significant under BH-FDR based on the surrogate spectrum."""
    # Build p-values only for signatures present in original patterns
    orig_sigs = {pattern_signature(p) for p in patterns}
    pvals = {sig: p_spectrum.get(sig, 1.0) for sig in orig_sigs}
    sig_keep = benjamini_hochberg(pvals, alpha)
    kept = [p for p in patterns if pattern_signature(p) in sig_keep]
    return kept
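
The step-up nature of Benjamini-Hochberg is easiest to see on a toy input. In the sketch below (illustrative p-values, standalone helper `bh_reject`), the second-smallest p-value misses its own threshold but is still accepted, because a larger p-value further down the ranking meets its threshold.

```python
def bh_reject(pvals, alpha):
    """Return the keys accepted by the Benjamini-Hochberg step-up procedure."""
    ranked = sorted(pvals.items(), key=lambda kv: kv[1])
    m = len(ranked)
    k_star = 0
    for k, (_, p) in enumerate(ranked, start=1):
        if p <= alpha * k / m:      # thresholds: 0.0125, 0.025, 0.0375, 0.05
            k_star = k
    return {key for key, _ in ranked[:k_star]}

# Toy (z, c) signatures with made-up p-values
toy = {(3, 8): 0.01, (2, 5): 0.03, (4, 8): 0.035, (2, 2): 0.60}
print(sorted(bh_reject(toy, alpha=0.05)))  # [(2, 5), (3, 8), (4, 8)]
```

Rank 2 (`p = 0.03`) exceeds its threshold of 0.025, but rank 3 (`p = 0.035`) is below 0.0375, so all three of the smallest p-values are accepted.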


In [13]:

def psr_reduction(patterns: List[Pattern], p_spectrum: Dict[Tuple[int,int], float], alpha: float,
                  h: int=0, k: int=2) -> List[Pattern]:
    """
    Reduce overlapping patterns via SPADE-like conditional tests.
    For each pair (A, B) with B ⊂ A:
      - test A|B with signature (z_A - z_B + h, c_A)
      - test B|A with signature (z_B, c_B - c_A + k)
    Keep those that remain significant (p <= alpha). If both fail, keep the one with larger z*c (tie-break on z).
    Note: This is a single-pass greedy reduction for demonstration.
    """
    keep = set(range(len(patterns)))
    def p_of(sig): return p_spectrum.get(sig, 1.0)
    for i in range(len(patterns)):
        if i not in keep: 
            continue
        Ai = patterns[i]
        for j in range(len(patterns)):
            if i == j or j not in keep:
                continue
            Bj = patterns[j]
            # Only compare if Bj is a strict subset of Ai (over items)
            if not Bj.items < Ai.items:
                continue
            zA, cA = Ai.size, Ai.support
            zB, cB = Bj.size, Bj.support
            sig_A_given_B = (max(1, zA - zB + h), cA)
            sig_B_given_A = (zB, max(1, cB - cA + k))
            pA = p_of(sig_A_given_B)
            pB = p_of(sig_B_given_A)
            A_sig = pA <= alpha
            B_sig = pB <= alpha
            if A_sig and not B_sig:
                # Keep A, drop B
                keep.discard(j)
            elif B_sig and not A_sig:
                # Keep B, drop A
                keep.discard(i)
                break
            elif (not A_sig) and (not B_sig):
                # Keep the one with larger z*c
                scoreA = zA * cA
                scoreB = zB * cB
                if scoreA > scoreB or (scoreA == scoreB and zA >= zB):
                    keep.discard(j)
                else:
                    keep.discard(i)
                    break
            # else: both significant -> keep both
    return [patterns[idx] for idx in sorted(list(keep))]
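
The conditional signatures used in the reduction are simple arithmetic on sizes and supports. The sketch below reproduces that arithmetic, including the clamping to a minimum of 1 used in the function above, for a made-up pair of patterns; `conditional_signatures` is an illustrative helper.

```python
def conditional_signatures(zA, cA, zB, cB, h=0, k=2):
    """Signatures looked up when testing superset A against subset B."""
    sig_A_given_B = (max(1, zA - zB + h), cA)   # extra spikes of A, at A's support
    sig_B_given_A = (zB, max(1, cB - cA + k))   # B's size, at its excess support
    return sig_A_given_B, sig_B_given_A

# A has 4 spikes and occurs 8 times; its subset B has 3 spikes and occurs 10 times
print(conditional_signatures(zA=4, cA=8, zB=3, cB=10))  # ((1, 8), (3, 4))
```

With `h=0` and `k=2`, A is judged on a single extra spike occurring 8 times, and B on a 3-spike pattern occurring 4 times; each resulting signature is then looked up in the pattern spectrum.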


In [14]:

def run_spade_pipeline(spikes: np.ndarray, dt_ms: float, window_bins: int,
                       min_support: int, max_itemset_size: Optional[int],
                       n_surrogates: int, dither_halfwidth_bins: int,
                       alpha: float, psr_h: int=0, psr_k: int=2,
                       t_start_ms: Optional[float]=None, t_end_ms: Optional[float]=None,
                       verbose: bool=True):
    # 1) Rasterize
    raster, bin_edges_ms, N = bin_spikes(spikes, dt_ms, t_start_ms, t_end_ms)
    if verbose:
        print(f"Raster: N={N} neurons, B={raster.shape[1]} bins, dt={dt_ms} ms, window_bins={window_bins}")
    # 2) Transactions
    tx, K = make_transactions(raster, window_bins)
    if verbose:
        print(f"Transactions: {len(tx)} windows; average items/window ≈ {np.mean([len(t) for t in tx]) if tx else 0:.2f}")
    # 3) Mine closed frequent patterns
    patterns = apriori_closed_patterns(tx, min_support, max_itemset_size)
    if verbose:
        summarize_patterns(patterns, K, max_show=10)
    # 4) Pattern spectrum from surrogates
    p_spectrum = pattern_spectrum_pvals(raster, window_bins, min_support, max_itemset_size,
                                        n_surrogates, dither_halfwidth_bins)
    # 5) PSF filtering
    psf_kept = psf_filter(patterns, p_spectrum, alpha)
    if verbose:
        print(f"PSF-kept patterns: {len(psf_kept)}")
        summarize_patterns(psf_kept, K, max_show=10)
    # 6) PSR reduction
    final_patterns = psr_reduction(psf_kept, p_spectrum, alpha, h=psr_h, k=psr_k)
    if verbose:
        print(f"Final patterns after PSR: {len(final_patterns)}")
        summarize_patterns(final_patterns, K, max_show=10)
    return {
        "raster": raster,
        "bin_edges_ms": bin_edges_ms,
        "transactions": tx,
        "K": K,
        "patterns_all": patterns,
        "p_spectrum": p_spectrum,
        "patterns_psf": psf_kept,
        "patterns_final": final_patterns,
    }


In [15]:

# === Example usage ===
# (This regenerates the synthetic data; comment out the next line to run on your own `spikes`.)
spikes = make_synthetic_spikes()

results = run_spade_pipeline(
    spikes=spikes,
    dt_ms=dt_ms,
    window_bins=window_bins,
    min_support=min_support,
    max_itemset_size=max_itemset_size,
    n_surrogates=n_surrogates,
    dither_halfwidth_bins=dither_halfwidth_bins,
    alpha=alpha,
    psr_h=psr_h,
    psr_k=psr_k,
    verbose=True
)
raster = results["raster"]
K = results["K"]
final_patterns = results["patterns_final"]
if final_patterns:
    # Visualize raster with highlighted windows for the top pattern
    plot_raster(raster, dt_ms, pattern=final_patterns[0], K=K, highlight_extent=final_patterns[0].extent)


Raster: N=20 neurons, B=4997 bins, dt=1.0 ms, window_bins=20
Transactions: 4978 windows; average items/window ≈ 2.11
Found 991 closed frequent patterns.
 1) size z= 4, support c=  8, items (i,kappa)=[(6, 7), (7, 3), (10, 0), (14, 12)]
 2) size z= 4, support c=  8, items (i,kappa)=[(6, 8), (7, 4), (10, 1), (14, 13)]
 3) size z= 4, support c=  8, items (i,kappa)=[(6, 9), (7, 5), (10, 2), (14, 14)]
 4) size z= 4, support c=  8, items (i,kappa)=[(6, 10), (7, 6), (10, 3), (14, 15)]
 5) size z= 4, support c=  8, items (i,kappa)=[(6, 11), (7, 7), (10, 4), (14, 16)]
 6) size z= 4, support c=  8, items (i,kappa)=[(6, 12), (7, 8), (10, 5), (14, 17)]
 7) size z= 4, support c=  8, items (i,kappa)=[(6, 13), (7, 9), (10, 6), (14, 18)]
 8) size z= 4, support c=  8, items (i,kappa)=[(6, 14), (7, 10), (10, 7), (14, 19)]
 9) size z= 3, support c=  8, items (i,kappa)=[(6, 4), (7, 0), (14, 9)]
10) size z= 3, support c=  8, items (i,kappa)=[(6, 5), (7, 1), (14, 10)]


Surrogates: 100%|██████████████████████████████████████████████████████████████████████| 20/20 [00:13<00:00,  1.43it/s]

PSF-kept patterns: 0
Found 0 closed frequent patterns.
Final patterns after PSR: 0
Found 0 closed frequent patterns.






**Notes & caveats**

- This notebook uses a simple **Apriori** miner for clarity. SPADE uses efficient closed itemset mining (e.g., FP-growth / Eclat variants). For large data, consider replacing `apriori_closed_patterns` with a faster miner.
- The surrogate-based **pattern spectrum** here uses dithering on the raster (`±J` bins). Tune `dither_halfwidth_bins` to your physiological hypothesis (e.g., ±5 ms to disrupt millisecond precision).
- The **PSR** step is implemented as a single greedy pass; SPADE describes reciprocal conditional testing. For publication-grade analysis, implement full iterative reduction.
- All parameters (`dt_ms`, `window_bins`, `min_support`, `n_surrogates`, `dither_halfwidth_bins`, `alpha`, `psr_h`, `psr_k`) should be adapted to your dataset and compute budget.
