# Goal

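For each probe insertion (pid) with at least one brain region containing at least 30 (>= 30) good neurons: train a linear (ridge) decoder of whisker motion energy, compute train/test statistics, and save them to Google Drive.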

After loading, each pid's data is held in a dictionary of the following form (as returned by `load_session_data`):
```
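{
    "pid": pid,
    "eid": eid,
    "n_neurons": int,                      # number of good units
    "regions": dict,                       # {region acronym: number of good units}
    "spike_matrix": np.ndarray,            # (neurons × timebins), uint16 spike counts
    "whisker_motion_raw": np.ndarray,      # motion energy resampled to bin centers
    "whisker_motion_clean": np.ndarray,    # log-transformed, z-scored motion (model target)
    "transform_params": dict,              # parameters to invert the cleaning
    "cluster_regions": np.ndarray,         # region acronym per neuron
    "times": np.ndarray,                   # bin centers (s)
}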
```

| Dataset                                  | Purpose                                                                        |
| ---------------------------------------- | ------------------------------------------------------------------------------ |
| `whisker_motion_raw`                     | raw, unmodified motion energy for visualization or inverse transform           |
| `whisker_motion_clean`                   | log-transformed, z-scored motion target for modeling                           |
| `transform_params`                       | to reconstruct predictions to physical scale                                   |
| `X_train`, `X_test`, `y_train`, `y_test` | modeling datasets                                                              |
| `meta`                                   | dictionary with fields like `pid`, `n_neurons`, `regions`, `lags`, `cut`, etc. |


# Setup

In [None]:
# Install the IBL data-access client (ONE) and the IBL analysis library.
! pip install ONE-api
! pip install ibllib

from one.api import ONE
# Configure ONE for the public OpenAlyx server and connect.
ONE.setup(base_url='https://openalyx.internationalbrainlab.org', silent=True)
one = ONE(password='international')

# NOTE(review): this re-imports ONE and replaces `one` with a default-constructed
# client; it appears to rely on the configuration saved by ONE.setup above — confirm.
from one.api import ONE
one = ONE()

Collecting ONE-api
  Downloading one_api-3.4.1-py3-none-any.whl.metadata (4.2 kB)
Collecting iblutil>=1.14.0 (from ONE-api)
  Downloading iblutil-1.20.0-py3-none-any.whl.metadata (1.6 kB)
Collecting boto3 (from ONE-api)
  Downloading boto3-1.40.73-py3-none-any.whl.metadata (6.8 kB)
Collecting colorlog>=6.0.0 (from iblutil>=1.14.0->ONE-api)
  Downloading colorlog-6.10.1-py3-none-any.whl.metadata (11 kB)
Collecting botocore<1.41.0,>=1.40.73 (from boto3->ONE-api)
  Downloading botocore-1.40.73-py3-none-any.whl.metadata (5.9 kB)
Collecting jmespath<2.0.0,>=0.7.1 (from boto3->ONE-api)
  Downloading jmespath-1.0.1-py3-none-any.whl.metadata (7.6 kB)
Collecting s3transfer<0.15.0,>=0.14.0 (from boto3->ONE-api)
  Downloading s3transfer-0.14.0-py3-none-any.whl.metadata (1.7 kB)
Downloading one_api-3.4.1-py3-none-any.whl (1.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading iblutil-1.20.0-py3-none-any.whl (43

In [None]:
import os
import gc
import json
import numpy as np
import pandas as pd
from sklearn.linear_model import Ridge
from pathlib import Path
from sklearn.metrics import r2_score

from scipy.ndimage import gaussian_filter1d
from sklearn.preprocessing import StandardScaler

from collections import Counter

from one.api import ONE
from brainbox.io.one import SessionLoader, SpikeSortingLoader
import matplotlib.pyplot as plt
from iblatlas.atlas import AllenAtlas
import numpy as np
import traceback

In [None]:
# Mount Google Drive at /content/drive so the /content/drive/MyDrive/... output
# paths used later in the notebook are writable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# ---------- Config ----------
# NOTE(review): these constants are not referenced by the functions visible in
# this notebook (they use their own defaults); keep the two in sync.

# Ridge regularization strength.
ALPHA = 0.5
# Held-out fraction for the 80/20 temporal train/test split.
TEST_FRAC = 0.2
# Lags (in bins) stacked as decoder features: 0, 1, ..., 29.
LAGS = np.arange(30)

# An insertion is kept only if at least one region has this many neurons.
MIN_NEURONS_PER_REGION = 30
# Sanity floor on the total neuron count; implied by the per-region rule.
MIN_TOTAL_NEURONS = MIN_NEURONS_PER_REGION

RANDOM_STATE = 0

OUTPUT_DIR = "/content/drive/MyDrive/S25/Langone/Breathing/experiments/decoding/"

In [None]:
# Final summary CSV (single giant table) plus the eid/pid candidate list,
# both stored directly under OUTPUT_DIR.
SUMMARY_CSV, CANDIDATES_CSV = (
    os.path.join(OUTPUT_DIR, csv_name)
    for csv_name in ("ibl_neural_decoding_summary.csv", "ibl_eid_pid_candidates.csv")
)

In [None]:
# Column order of the per-pid summary table.
#   pid, n_neurons                -> identifiers / total neurons used in decoding
#   regions                       -> JSON {region_acronym: n_neurons, ...}
#   {train,test}_{R2,R}_all       -> whole-population metrics
#   {train,test}_region_metrics   -> JSON {region: {"R2": 0.123, "R": 0.456}, ...}
SUMMARY_COLUMNS = [
    "pid", "n_neurons", "regions",
    "train_R2_all", "train_R_all", "test_R2_all", "test_R_all",
    "train_region_metrics", "test_region_metrics",
]

# helper functions

## utilities

In [None]:
# ============================================================================
# UTILITIES
# ============================================================================

def print_skip_messages(captured_output):
    """Echo only the lines of `captured_output` tagged [SKIP] or [WARN]."""
    tags = ('[SKIP]', '[WARN]')
    flagged = [
        ln for ln in captured_output.split('\n')
        if any(tag in ln for tag in tags)
    ]
    for ln in flagged:
        print(ln)

def round5(x):
    """Return `x` rounded to 5 decimals as a Python float, or None for None/NaN.

    Inputs that np.isnan cannot test (TypeError/ValueError) skip the NaN check
    and go straight to rounding.
    """
    if x is None:
        return None
    try:
        missing = bool(np.isnan(x))
    except (TypeError, ValueError):
        missing = False
    return None if missing else float(np.round(x, 5))

def to_json(obj):
    """Serialize `obj` as compact, key-sorted JSON (one CSV cell per dict)."""
    compact = (",", ":")
    return json.dumps(obj, sort_keys=True, separators=compact)

def get_all_pids(one, project=None):
    """
    Return a sorted, de-duplicated list of pids that have spike sorting.

    An insertion qualifies when its ``spikes.times`` dataset exists.

    Parameters
    ----------
    one : ONE
        Connected ONE client.
    project : str, optional
        Restrict the search to one IBL project, e.g.
        'ibl_neuropixel_brainwide_01'. When None, no project filter is sent.

    Returns
    -------
    list
        Sorted unique pids.
    """
    search_kwargs = {'dataset': 'spikes.times'}
    if project is not None:
        search_kwargs['project'] = project

    try:
        pids = one.search_insertions(**search_kwargs)
    except AttributeError:
        # Fallback: query Alyx directly with the SAME filters, so that
        # `project=None` is never sent as a literal query parameter.
        insertions = one.alyx.rest('insertions', 'list', **search_kwargs)
        pids = [ins['id'] for ins in insertions]

    unique_pids = sorted(set(pids))
    # Report the de-duplicated count, i.e. the length of what is returned.
    print(f"[INFO] Found {len(unique_pids)} pids with spikes.times")
    return unique_pids


## data loading

In [None]:
# ============================================================================
# DATA LOADING
# ============================================================================

def load_session_data(pid, one, ba, min_neurons=30):
    """
    Load binned spike counts and whisker motion energy for one probe insertion.

    Keeps clusters with ``clusters['label'] == 1`` ("good" units) and requires
    at least ``min_neurons`` of them in the single most populated region. Spikes
    are binned at the median left-camera frame interval, and the left-camera
    whisker motion energy is resampled onto the same bin centers.

    Parameters
    ----------
    pid : probe insertion id.
    one : ONE client.
    ba : atlas passed to SpikeSortingLoader.
    min_neurons : minimum number of good units required in the top region.

    Returns
    -------
    dict or None
        Keys:
        - 'pid', 'eid'
        - 'n_neurons': number of good units
        - 'regions': {acronym: count}
        - 'spike_matrix': uint16, neurons × bins, counts clipped to 65535
        - 'whisker_motion_raw', 'whisker_motion_clean', 'transform_params'
        - 'cluster_regions'
        - 'times': bin centers

        Returns None if the insertion is skipped or any step fails; a
        [SKIP] or [WARN] line is printed in every such case.
    """
    try:
        print(f"Processing pid: {pid}")

        # Get experiment ID (pid2eid may return a tuple/list; first item is the eid)
        eid_info = one.pid2eid(pid)
        eid = eid_info[0] if isinstance(eid_info, (tuple, list)) else eid_info

        # Load spike sorting
        ssl = SpikeSortingLoader(one=one, pid=pid, atlas=ba)
        spikes, clusters, channels = ssl.load_spike_sorting()
        clusters = ssl.merge_clusters(spikes, clusters, channels)

        # Filter good clusters.
        # NOTE(review): these are row indices into `clusters`; they are compared
        # against spikes['clusters'] below, i.e. assumed equal to cluster ids — confirm.
        good_clusters = np.where(clusters['label'] == 1)[0]
        if len(good_clusters) == 0:
            print(f"[SKIP] pid={pid}: no good clusters.", flush=True)
            return None

        cluster_regions = clusters['acronym'][good_clusters]
        region_counts = dict(Counter(cluster_regions))
        total_units = len(good_clusters)
        top_region, top_count = max(region_counts.items(), key=lambda kv: kv[1])

        # Keep the insertion only if some region has >= min_neurons good units.
        if top_count < min_neurons:
            print(f"[SKIP] pid={pid} n_neurons={total_units} max_region={top_region} ({top_count})", flush=True)
            return None

        # Load whisker motion
        try:
            sl = SessionLoader(one=one, eid=eid)
            sl.load_motion_energy(views=['left'])
            whisker = sl.motion_energy['leftCamera']
            whisker_times = np.asarray(whisker['times'])
            whisker_trace = np.asarray(whisker['whiskerMotionEnergy'])

            # Drop non-finite timestamps, then duplicated timestamps (first kept),
            # so the time axis is strictly increasing.
            mask = np.isfinite(whisker_times)
            whisker_times, whisker_trace = whisker_times[mask], whisker_trace[mask]
            whisker_times, unique_idx = np.unique(whisker_times, return_index=True)
            whisker_trace = whisker_trace[unique_idx]
        except Exception as e:
            print(f"[WARN] pid={pid}: whisker motion load failed - {e}", flush=True)
            return None

        # Bin spikes. Bin width = median camera frame interval (about one bin per
        # frame). An empty or single-frame trace yields a NaN median and is rejected.
        bin_size = np.median(np.diff(whisker_times))
        if not np.isfinite(bin_size) or bin_size <= 0:
            print(f"[WARN] pid={pid}: invalid bin size.", flush=True)
            return None

        bin_edges = np.arange(whisker_times[0], whisker_times[-1] + bin_size, bin_size)
        bin_centers = bin_edges[:-1] + np.diff(bin_edges) / 2
        n_bins = len(bin_edges) - 1

        # One full pass over all spikes per good cluster (O(n_good × n_spikes));
        # spikes outside [bin_edges[0], bin_edges[-1]] are not counted.
        spike_matrix = np.zeros((len(good_clusters), n_bins), dtype=np.uint16)
        for i, cid in enumerate(good_clusters):
            spike_times = spikes['times'][spikes['clusters'] == cid]
            if len(spike_times) > 0:
                counts = np.histogram(spike_times, bins=bin_edges)[0]
                spike_matrix[i] = np.clip(counts, 0, 65535).astype(np.uint16)

        # Resample and align whisker motion. resample_to_bins (preprocessing cell)
        # returns one value per bin center, so the trim below is a safety guard.
        whisker_sync = resample_to_bins(whisker_times, whisker_trace, bin_centers)
        n_frames = min(spike_matrix.shape[1], len(whisker_sync))
        spike_matrix = spike_matrix[:, :n_frames]
        whisker_sync = whisker_sync[:n_frames]
        times = bin_centers[:n_frames]

        # Clean motion (clip, log, z-score) and keep the params to invert it later
        whisker_clean, transform_params = clean_motion_energy(whisker_sync.copy(), log_transform=True)

        print(f"✓ pid={pid} neurons={spike_matrix.shape[0]} frames={n_frames}", flush=True)

        return {
            'pid': pid,
            'eid': eid,
            'n_neurons': total_units,
            'regions': region_counts,
            'spike_matrix': spike_matrix,
            'whisker_motion_raw': whisker_sync,
            'whisker_motion_clean': whisker_clean,
            'transform_params': transform_params,
            'cluster_regions': cluster_regions,
            'times': times,
        }

    except Exception as e:
        # Catch-all so one bad insertion does not stop a batch run.
        traceback.print_exc()
        print(f"[WARN] pid={pid}: {type(e).__name__} — {e}", flush=True)
        return None

## preprocessing

In [None]:
# ============================================================================
# PREPROCESSING - With lightweight progress statements
# ============================================================================

from scipy.ndimage import gaussian_filter1d
import numpy as np
from pathlib import Path

# --- Helper functions (no output) ---

def resample_to_bins(signal_times, signal_values, bin_centers):
    """
    Linearly interpolate a continuous signal onto bin centers.

    Timestamps are kept in float64. Session clocks reach thousands of seconds,
    and at that magnitude float32 resolution is sub-millisecond at best. That
    would jitter sample times or even merge neighbouring samples.

    Before interpolating:
    - non-finite samples are dropped;
    - duplicate timestamps are collapsed (first occurrence kept).

    Outside the sampled range, np.interp holds the edge values.

    Returns
    -------
    np.ndarray (float32), one value per bin center; all-NaN if no finite samples.
    """
    signal_times = np.asarray(signal_times, dtype=np.float64)
    signal_values = np.asarray(signal_values, dtype=np.float32)
    bin_centers = np.asarray(bin_centers, dtype=np.float64)

    mask = np.isfinite(signal_times) & np.isfinite(signal_values)
    signal_times, signal_values = signal_times[mask], signal_values[mask]

    if len(signal_times) == 0:
        return np.full(len(bin_centers), np.nan, dtype=np.float32)

    # np.interp requires increasing x; np.unique sorts and de-duplicates.
    signal_times, unique_idx = np.unique(signal_times, return_index=True)
    signal_values = signal_values[unique_idx]

    return np.interp(bin_centers, signal_times, signal_values).astype(np.float32)


def clean_motion_energy(motion, zmax=3, pct=99.5, log_transform=True):
    """
    Clean and normalize whisker motion for modeling.

    Steps:
    1. Clip values above the `pct` percentile.
    2. Optionally apply log(x + 1).
    3. Z-score.
    4. Replace |z| > zmax, and any NaN, with 0.

    The input is never modified: a float32 copy is made first, because every
    step below works in place.

    Parameters
    ----------
    motion : 1-D array-like of motion energy.
    zmax : z-score magnitude above which samples are zeroed.
    pct : upper clipping percentile (NaNs ignored).
    log_transform : apply log(x + 1) before z-scoring.

    Returns
    -------
    (cleaned float32 array, transform_params dict for inverse_transform_motion)
    """
    motion = np.array(motion, dtype=np.float32, copy=True)

    hi = np.nanpercentile(motion, pct)
    np.clip(motion, None, hi, out=motion)

    if log_transform:
        motion += 1.0
        np.log(motion, out=motion)

    mean, std = np.nanmean(motion), np.nanstd(motion)
    # A constant (or all-NaN) trace has no spread; dividing by ~0 would blow up.
    if not np.isfinite(std) or std <= 1e-6:
        std = np.float32(1.0)
    motion -= mean
    motion /= std
    motion[np.abs(motion) > zmax] = np.nan
    np.nan_to_num(motion, copy=False, nan=0.0)

    return motion, {
        'log_transform': log_transform,
        'mean': mean,
        'std': std,
        'log_epsilon': 1.0
    }


def inverse_transform_motion(y_pred_clean, transform_params):
    """
    Map cleaned-space predictions back to the raw motion-energy scale.

    Undoes the z-score (multiply by std, add mean). When
    ``transform_params['log_transform']`` is set, it then also undoes the log:
    exp(.) minus log_epsilon.
    """
    scaled = np.asarray(y_pred_clean, dtype=np.float32)
    logged = scaled * transform_params['std'] + transform_params['mean']

    if not transform_params['log_transform']:
        return logged
    return np.exp(logged) - transform_params['log_epsilon']


def add_lags(X_TxF, lags=None):
    """
    Stack lagged copies of X as features (in RAM).

    Layout is lag-major: columns ``i*F:(i+1)*F`` hold all F features shifted
    back by ``lags[i]``. The first ``maxlag`` rows are dropped so that every
    output row has full history.

    Parameters
    ----------
    X_TxF : (T, F) array, time-major.
    lags : iterable of non-negative ints; None/empty returns X unchanged.

    Returns
    -------
    (X_lagged, maxlag). X_lagged has shape (max(T - maxlag, 0), F * len(lags)).

    Raises
    ------
    ValueError if any lag is negative.
    """
    if lags is None or len(lags) == 0:
        return X_TxF, 0

    lags = [int(L) for L in lags]
    if min(lags) < 0:
        raise ValueError(f"lags must be non-negative, got {lags}")
    maxlag = max(lags)
    T, F = X_TxF.shape
    # Too-short inputs give an empty result instead of a negative-size error.
    T_out = max(T - maxlag, 0)

    X_lagged = np.zeros((T_out, F * len(lags)), dtype=X_TxF.dtype)
    for i, L in enumerate(lags):
        start_idx = maxlag - L
        X_lagged[:, i * F:(i + 1) * F] = X_TxF[start_idx:start_idx + T_out]

    return X_lagged, maxlag


# ============================================================================
# Core preprocessing with disk-backed lag builder
# ============================================================================

def add_lags_memmap(X_TxF, lags, out_dir=Path("/content/mmap_cache")):
    """
    Build lag-stacked features directly into a disk-backed float32 memmap.

    Layout matches add_lags:
    - column block i (``i*F:(i+1)*F``) holds all F features shifted back by
      ``lags[i]``;
    - the first ``maxlag`` rows are dropped.

    Rows are assembled in RAM one chunk at a time (~64M elements, ~256 MB of
    float32) and written sequentially. Peak extra RAM is therefore one chunk,
    not the full lagged matrix.

    Parameters
    ----------
    X_TxF : (T, F) array, time-major.
    lags : iterable of non-negative ints.
    out_dir : Path or str; directory for the backing file (created if missing).

    Returns
    -------
    (np.memmap opened read-only, maxlag, out_path)
        The caller owns ``out_path`` and is responsible for deleting it.

    Raises
    ------
    ValueError for empty/negative lags or when T <= maxlag.
    """
    import uuid

    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    lags = np.asarray(lags, dtype=int)
    if lags.size == 0:
        raise ValueError("lags must contain at least one value")
    if (lags < 0).any():
        raise ValueError(f"lags must be non-negative, got {lags.tolist()}")
    maxlag = int(lags.max())
    T, F = X_TxF.shape
    T_out = T - maxlag
    if T_out <= 0:
        raise ValueError(f"need more than maxlag={maxlag} time steps, got T={T}")
    n_lags = len(lags)
    n_cols = F * n_lags

    # uuid avoids collisions with the global numpy RNG (which users may seed).
    out_path = out_dir / f"lags_{uuid.uuid4().hex}.mmap"
    X_lagged_mm = np.memmap(out_path, dtype=np.float32, mode='w+', shape=(T_out, n_cols))

    rows_per_chunk = max(1, 64_000_000 // max(1, n_cols))
    for s in range(0, T_out, rows_per_chunk):
        e = min(T_out, s + rows_per_chunk)
        block = np.empty((e - s, n_cols), dtype=np.float32)
        for i, L in enumerate(lags):
            start = maxlag - int(L)
            block[:, i * F:(i + 1) * F] = X_TxF[start + s:start + e]
        # Contiguous row range -> sequential disk write.
        X_lagged_mm[s:e] = block

    X_lagged_mm.flush()
    del X_lagged_mm

    # Re-open read-only so downstream code cannot modify the file by accident.
    X_lagged_mm = np.memmap(out_path, dtype=np.float32, mode='r', shape=(T_out, n_cols))
    return X_lagged_mm, maxlag, out_path


def preprocess_session(sess, lags=None, smooth_sigma=1.0,
                       sqrt_counts=True, test_frac=0.2, normalization='standard',
                       verbose=False):
    """
    Turn one loaded session into lagged, normalized train/test matrices.

    Steps:
    1. sqrt-transform and Gaussian-smooth spike counts along time.
    2. Build lag-stacked features (disk-backed, via add_lags_memmap).
    3. Align targets to the lag cutoff and drop non-finite targets.
    4. Split temporally: the first ``1 - test_frac`` of samples train, the rest test.
    5. Normalize features.

    Parameters
    ----------
    sess : dict from load_session_data. Reads 'pid', 'spike_matrix',
        'whisker_motion_clean', 'whisker_motion_raw', 'transform_params',
        'n_neurons', 'regions' and 'cluster_regions'. It is not modified.
    lags : list of int; defaults to list(range(30)) (lags 0..29).
    smooth_sigma : Gaussian sigma in bins; 0/None disables smoothing.
    sqrt_counts : apply sqrt to spike counts.
    test_frac : fraction of post-cut samples held out at the end.
    normalization : how features are scaled.
        - 'standard': per-feature z-score using train statistics.
        - 'layer': per-row z-score.
        - anything else: features left as is.
    verbose : unused; kept for interface compatibility.

    Returns
    -------
    X_train, X_test, y_train, y_test, y_train_raw, y_test_raw, meta
        ``meta['lag_path']`` is the backing memmap file; the caller deletes it.
    """
    if lags is None:
        # A fresh list per call instead of a shared mutable default.
        lags = list(range(30))
    pid = sess["pid"]

    # Targets: cleaned (model target) and raw (for raw-scale metrics).
    m_clean = np.asarray(sess["whisker_motion_clean"], dtype=np.float32)
    m_raw = np.asarray(sess["whisker_motion_raw"], dtype=np.float32)
    transform_params = sess["transform_params"]

    # Always copy. The in-place sqrt below must never modify sess["spike_matrix"];
    # astype(copy=False) would alias it when it is already float32.
    S = np.array(sess["spike_matrix"], dtype=np.float32, copy=True)

    if sqrt_counts:
        np.sqrt(S, out=S)

    if smooth_sigma and smooth_sigma > 0:
        S = gaussian_filter1d(S, sigma=smooth_sigma, axis=1)

    # Time-major (T, F) view, then lag-stack into a disk-backed memmap.
    X_raw = S.T
    del S
    X, cut, lag_path = add_lags_memmap(X_raw, lags)
    del X_raw

    # Row k of X corresponds to time bin cut + k.
    y_clean = m_clean[cut:cut + len(X)]
    y_raw = m_raw[cut:cut + len(X)]

    # Drop samples whose target is not finite.
    mask = np.isfinite(y_clean)
    if not mask.all():
        X = X[mask]
        y_clean = y_clean[mask]
        y_raw = y_raw[mask]

    # Temporal split: no shuffling, test is the tail of the session.
    split = int((1 - test_frac) * len(y_clean))
    X_train_mm, X_test_mm = X[:split], X[split:]
    y_train, y_test = y_clean[:split], y_clean[split:]
    y_train_raw, y_test_raw = y_raw[:split], y_raw[split:]
    del X, y_clean, y_raw

    # Copy to RAM: the memmap is read-only and normalization works in place.
    X_train = np.array(X_train_mm, dtype=np.float32, copy=True)
    X_test = np.array(X_test_mm, dtype=np.float32, copy=True)
    del X_train_mm, X_test_mm

    if normalization == 'standard':
        # Statistics from the training split only, applied to both splits.
        mu = X_train.mean(axis=0, dtype=np.float32)
        var = ((X_train - mu) ** 2).mean(axis=0, dtype=np.float32)
        std = np.sqrt(var, dtype=np.float32)
        std[std < 1e-6] = 1.0
        X_train -= mu
        X_train /= std
        X_test -= mu
        X_test /= std
        scale_params = {"mean": mu, "std": std}
    elif normalization == 'layer':
        def norm_layer_inplace(X):
            m = X.mean(axis=1, keepdims=True, dtype=np.float32)
            s2 = ((X - m) ** 2).mean(axis=1, keepdims=True, dtype=np.float32)
            s = np.sqrt(s2, dtype=np.float32)
            s[s < 1e-6] = 1.0
            X -= m
            X /= s
        norm_layer_inplace(X_train)
        norm_layer_inplace(X_test)
        scale_params = None
    else:
        scale_params = None

    meta = {
        'pid': pid,
        'n_neurons': sess["n_neurons"],
        'regions': sess["regions"],
        'cluster_regions': sess["cluster_regions"],
        'lags': lags,
        'cut': cut,
        'transform_params': transform_params,
        'scale_params': scale_params,
        'lag_path': str(lag_path),  # For cleanup after training
    }

    print(f"[PREP] {pid} | ✓ train={X_train.shape} test={X_test.shape}")

    return X_train, X_test, y_train, y_test, y_train_raw, y_test_raw, meta

## training and eval

In [None]:
def compute_region_metrics(y_true, y_pred, X_features, cluster_regions, lags):
    """
    Build a boolean feature mask for every brain region.

    add_lags / add_lags_memmap lay features out lag-major: column ``i*F + j``
    is neuron ``j`` at lag ``lags[i]`` (F = number of neurons). A region's
    feature mask is therefore its per-neuron mask repeated once per lag.
    (The previous neuron-major indexing selected the wrong columns.)

    Parameters
    ----------
    y_true, y_pred : unused; kept for interface compatibility.
    X_features : 2-D array with ``F * len(lags)`` columns.
    cluster_regions : region label per neuron (length F).
    lags : lags used to build X_features.

    Returns
    -------
    dict: region -> {'feature_mask': bool array of length X_features.shape[1]}

    Raises
    ------
    ValueError if the column count does not equal ``len(cluster_regions) * len(lags)``.
    """
    cluster_regions = np.asarray(cluster_regions)
    n_lags = len(lags)
    n_features = X_features.shape[1]
    if cluster_regions.size * n_lags != n_features:
        raise ValueError(
            f"expected {cluster_regions.size} neurons x {n_lags} lags = "
            f"{cluster_regions.size * n_lags} features, got {n_features}"
        )

    return {
        region: {'feature_mask': np.tile(cluster_regions == region, n_lags)}
        for region in np.unique(cluster_regions)
    }


def compute_region_predictions(model, X_features, region_metrics, chunk=20000):
    """
    Predict with the full model's weights restricted to one region at a time.

    This is mathematically the same as zeroing every feature outside the region
    and applying the full weight vector. It is computed as
    ``X[:, mask] @ w[mask]`` in row chunks, so no full copy of X_features is
    made per region. X_features may be a memmap.

    Parameters
    ----------
    model : dict with "coef_" (weights, one per feature column).
    X_features : (n_samples, n_features) array.
    region_metrics : dict region -> {'feature_mask': bool array (n_features,)}.
    chunk : rows processed per step.

    Returns
    -------
    dict region -> prediction array (n_samples,)
    """
    w = np.asarray(model["coef_"])
    n_rows = X_features.shape[0]
    region_predictions = {}

    for region, info in region_metrics.items():
        mask = np.asarray(info['feature_mask'], dtype=bool)
        w_region = w[mask]
        # range(0, max(n_rows, 1)) yields one (empty) piece when there are no rows.
        pieces = [
            np.asarray(X_features[s:s + chunk])[:, mask] @ w_region
            for s in range(0, max(n_rows, 1), chunk)
        ]
        region_predictions[region] = np.concatenate(pieces)

    return region_predictions

from scipy.stats import pearsonr

def _ridge_normal_equations(X, y, alpha, chunk):
    """Solve (XᵀX + αI) w = Xᵀy by accumulating XᵀX and Xᵀy over row chunks.

    Per-chunk products run in float32; running sums are kept in float64. This
    is more accurate than a float32 accumulator and needs one F×F buffer
    instead of a float32 one plus a float64 copy.
    """
    from scipy import linalg as la

    n_samples, n_features = X.shape
    G = np.zeros((n_features, n_features), dtype=np.float64)
    b = np.zeros(n_features, dtype=np.float64)
    for s in range(0, n_samples, chunk):
        Xc = np.asarray(X[s:s + chunk], dtype=np.float32)
        yc = np.asarray(y[s:s + chunk], dtype=np.float32)
        G += Xc.T @ Xc
        b += Xc.T @ yc
    # Add alpha on the diagonal without materializing an F×F identity.
    G[np.diag_indices_from(G)] += alpha
    w = la.solve(G, b, assume_a='pos', overwrite_a=True, overwrite_b=True)
    return w.astype(np.float32)


def _predict_chunked(X, w, chunk=20000):
    """Return X @ w as float32, computed in row chunks (X may be a memmap)."""
    y_pred = np.empty(X.shape[0], dtype=np.float32)
    for s in range(0, X.shape[0], chunk):
        y_pred[s:s + chunk] = X[s:s + chunk] @ w
    return y_pred


def _region_scores(model_dict, X, region_info, y_true_raw, transform_params):
    """Raw-scale R² and Pearson r (rounded) of each region-restricted prediction."""
    scores = {}
    for region, y_pred in compute_region_predictions(model_dict, X, region_info).items():
        y_pred_raw = inverse_transform_motion(y_pred, transform_params)
        scores[region] = {
            "R2": round5(r2_score(y_true_raw, y_pred_raw)),
            "R": round5(pearsonr(y_true_raw, y_pred_raw)[0]),
        }
    return scores


def train_and_evaluate_session(
    sess, meta, X_train, X_test, y_train, y_test,
    y_train_raw, y_test_raw, alpha=0.5,
    chunk=20000, verbose=False
):
    """
    Fit ridge regression (no intercept) via chunked normal equations and score it.

    Solves (XᵀX + αI) w = Xᵀy. Peak RAM is O(n_features²), not
    O(n_samples·n_features), and X_train / X_test may be np.memmap.
    Predictions are mapped back to raw motion units with
    inverse_transform_motion. R² and Pearson r are then computed for the full
    model and for each region (weights restricted to that region's features).

    Parameters
    ----------
    sess : dict with 'pid' and 'times'.
    meta : dict from preprocess_session. Uses 'transform_params',
        'cluster_regions', 'lags', 'cut', 'n_neurons' and 'regions'.
    X_train, X_test : (n_samples, n_features) arrays.
    y_train : cleaned targets, shape (n,) or (n, 1).
    y_test : unused (metrics use raw-scale targets); kept for interface.
    y_train_raw, y_test_raw : raw-scale targets.
    alpha : ridge penalty.
    chunk : rows per chunk for accumulation and prediction.
    verbose : print progress.

    Returns
    -------
    (model, summary_row, predictions)
        - model: {"coef_": w, "alpha": alpha}
        - summary_row: keys match SUMMARY_COLUMNS
        - predictions: one dict per aligned time bin
    """
    pid = sess["pid"]
    transform_params = meta['transform_params']

    # process_pid reshapes targets to (n, 1); flatten so Xᵀy is 1-D.
    y_train = np.asarray(y_train).reshape(-1)
    y_train_raw = np.asarray(y_train_raw).reshape(-1)
    y_test_raw = np.asarray(y_test_raw).reshape(-1)

    n_samples, n_features = X_train.shape
    if verbose:
        print(f"[TRAIN] pid={pid} | {n_samples}×{n_features} | α={alpha}")

    w = _ridge_normal_equations(X_train, y_train, alpha, chunk)

    y_train_pred = _predict_chunked(X_train, w, chunk)
    y_test_pred = _predict_chunked(X_test, w, chunk)

    y_train_pred_raw = inverse_transform_motion(y_train_pred, transform_params)
    y_test_pred_raw = inverse_transform_motion(y_test_pred, transform_params)

    # --- Overall metrics (raw scale) ---
    train_R2 = r2_score(y_train_raw, y_train_pred_raw)
    train_R = pearsonr(y_train_raw, y_train_pred_raw)[0]
    test_R2 = r2_score(y_test_raw, y_test_pred_raw)
    test_R = pearsonr(y_test_raw, y_test_pred_raw)[0]

    if verbose:
        print(f"[EVAL] pid={pid} | Train R²={train_R2:.3f} | Test R²={test_R2:.3f}")

    # --- Per-region metrics ---
    region_info = compute_region_metrics(y_train_raw, y_train_pred_raw, X_train,
                                         meta['cluster_regions'], meta['lags'])
    model_dict = {"coef_": w}
    train_region_metrics = _region_scores(model_dict, X_train, region_info,
                                          y_train_raw, transform_params)
    test_region_metrics = _region_scores(model_dict, X_test, region_info,
                                         y_test_raw, transform_params)

    # --- Timeline predictions for downstream analysis ---
    # Assumes no samples were dropped after the lag cut (targets are finite).
    cut = meta['cut']
    n_train = len(y_train_raw)
    times_aligned = sess['times'][cut:cut + n_train + len(y_test_raw)]
    segments = (
        ("train", times_aligned[:n_train], y_train_raw, y_train_pred_raw),
        ("test", times_aligned[n_train:], y_test_raw, y_test_pred_raw),
    )
    predictions = [
        {"pid": str(pid), "time": round5(t), "y_true": round5(yt),
         "y_pred": round5(yp), "split": split}
        for split, ts, ys, ps in segments
        for t, yt, yp in zip(ts, ys, ps)
    ]

    # --- Summary row (for CSV) ---
    summary_row = {
        "pid": str(pid),
        "n_neurons": meta['n_neurons'],
        "regions": to_json(meta['regions']),
        "train_R2_all": round5(train_R2),
        "train_R_all": round5(train_R),
        "test_R2_all": round5(test_R2),
        "test_R_all": round5(test_R),
        "train_region_metrics": to_json(train_region_metrics),
        "test_region_metrics": to_json(test_region_metrics),
    }

    model = {"coef_": w, "alpha": alpha}

    return model, summary_row, predictions


## file IO

In [None]:
# ============================================================================
# FILE I/O
# ============================================================================

def init_summary_csv(path):
    """Create a header-only summary CSV at `path`, unless a file already exists there."""
    if Path(path).exists():
        return
    header = [
        "pid", "n_neurons", "regions",
        "train_R2_all", "train_R_all", "test_R2_all", "test_R_all",
        "train_region_metrics", "test_region_metrics",
    ]
    pd.DataFrame(columns=header).to_csv(path, index=False)


def append_to_summary_csv(path, row_dict):
    """Append `row_dict` to the CSV at `path` as a single headerless row."""
    single_row = pd.DataFrame([row_dict])
    single_row.to_csv(path, mode='a', header=False, index=False)


def save_predictions_batch(predictions_list, output_file):
    """Append prediction records to a parquet file.

    The whole file is read and rewritten on each call.
    """
    batch = pd.DataFrame(predictions_list)
    if not Path(output_file).exists():
        batch.to_parquet(output_file, index=False)
        return
    combined = pd.concat([pd.read_parquet(output_file), batch], ignore_index=True)
    combined.to_parquet(output_file, index=False)


def load_predictions_for_pid(pid, predictions_file):
    """Return the rows of `predictions_file` whose pid equals str(pid), without the pid column."""
    frame = pd.read_parquet(predictions_file)
    selected = frame.loc[frame['pid'] == str(pid)]
    return selected.drop(columns=['pid'])

## viz

In [None]:
# ============================================================================
# VISUALIZATION
# ============================================================================

def plot_prediction_window(y_true, y_pred, times, window_start_sec=10,
                           window_duration_sec=10, split_name="test", pid=None):
    """
    Plot predictions vs ground truth over a short time window and print window stats.

    The window is ``[times[0] + window_start_sec,
    times[0] + window_start_sec + window_duration_sec]`` (both ends inclusive).

    Parameters
    ----------
    y_true, y_pred, times : equal-length 1-D sequences (lists or arrays).
    window_start_sec, window_duration_sec : window offset and length in seconds.
    split_name : label used in the title (e.g. "train" / "test").
    pid : optional insertion id shown in the title.
    """
    # Coerce to arrays so list inputs support comparison and boolean masking.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    times = np.asarray(times)
    if times.size == 0:
        print("[WARN] No data in window (empty time axis)")
        return

    start_time = times[0] + window_start_sec
    end_time = start_time + window_duration_sec
    mask = (times >= start_time) & (times <= end_time)
    n_points = int(mask.sum())

    if n_points == 0:
        print(f"[WARN] No data in window [{start_time:.1f}, {end_time:.1f}]")
        return

    fig, ax = plt.subplots(figsize=(12, 4))
    ax.plot(times[mask], y_true[mask], 'k-', linewidth=1.5, label='Ground Truth', alpha=0.7)
    ax.plot(times[mask], y_pred[mask], 'r-', linewidth=1.5, label='Predicted', alpha=0.7)
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Whisker Motion Energy')
    ax.legend()
    ax.grid(True, alpha=0.3)

    title = f"{split_name.capitalize()}: {window_start_sec}s-{window_start_sec+window_duration_sec}s"
    if pid:
        title = f"PID: {pid}\n{title}"
    ax.set_title(title)
    plt.tight_layout()
    plt.show()

    # Correlation is undefined for fewer than two points.
    corr = np.corrcoef(y_true[mask], y_pred[mask])[0, 1] if n_points >= 2 else np.nan
    print(f"\nWindow stats: Correlation={corr:.5f}, "
          f"True range=[{y_true[mask].min():.3f}, {y_true[mask].max():.3f}], "
          f"Pred range=[{y_pred[mask].min():.3f}, {y_pred[mask].max():.3f}]")

In [None]:
def reconstruct_timeline(
    sess,
    y_train,
    y_test,
    y_train_pred,
    y_test_pred,
    lags=(0,1,2,3,4,5),
    motion_key="whisker_motion_clean"
):
    """
    Attach session timestamps to train/test targets and predictions.

    Mirrors preprocess_session's alignment:
    - drop the first ``max(lags)`` bins;
    - drop bins whose target (``sess[motion_key]``) is non-finite;
    - of the remaining timestamps, the first ``len(y_train)`` are train and the
      next ``len(y_test)`` are test.

    NOTE(review): the default ``lags`` (0..5) differs from preprocess_session's
    default (0..29). Pass the lags actually used (e.g. ``meta['lags']``), or the
    timeline will be shifted.

    Parameters
    ----------
    lags : list/tuple/ndarray of ints, or None (treated as no lag).

    Returns
    -------
    dict with 'times_full', 'times_train', 'times_test', the four y arrays,
    'split_time' (last train time, else first test time, else None) and
    't_end_test' (None when there is no test data).
    """
    # np.asarray handles ndarray lags, where `if lags` would be ambiguous.
    lag_arr = np.asarray([] if lags is None else lags, dtype=int).ravel()
    maxlag = int(lag_arr.max()) if lag_arr.size else 0

    times_full = np.asarray(sess["times"])
    y_full = np.asarray(sess[motion_key], dtype=float)

    # Use a common length so the finite-target mask always matches the times.
    n = min(len(times_full), len(y_full))
    times_aligned = times_full[maxlag:n]
    y_aligned = y_full[maxlag:n]
    times_aligned = times_aligned[np.isfinite(y_aligned)]

    n_train = len(y_train)
    times_train = times_aligned[:n_train]
    times_test = times_aligned[n_train:n_train + len(y_test)]

    if len(times_train) > 0:
        split_time = times_train[-1]
    elif len(times_test) > 0:
        split_time = times_test[0]
    else:
        split_time = None

    return {
        'times_full': times_full,
        'times_train': times_train,
        'times_test': times_test,
        'y_train': y_train,
        'y_test': y_test,
        'y_train_pred': y_train_pred,
        'y_test_pred': y_test_pred,
        'split_time': split_time,
        't_end_test': times_test[-1] if len(times_test) > 0 else None
    }

# main loop

In [None]:
# Every insertion with spike sorting, with no project filter.
all_pids = get_all_pids(one)

[INFO] Found 1179 pids with spikes.times


In [None]:
# Re-create the ONE client and load the Allen atlas. `ba` is later passed to
# load_session_data, which hands it to SpikeSortingLoader(atlas=ba).
one = ONE()
ba = AllenAtlas()

Downloading: /root/Downloads/ONE/openalyx.internationalbrainlab.org/histology/ATLAS/Needles/Allen/average_template_25.nrrd Bytes: 32998960


100%|██████████| 31.470260620117188/31.470260620117188 [00:01<00:00, 17.34it/s]


Downloading: /root/Downloads/ONE/openalyx.internationalbrainlab.org/histology/ATLAS/Needles/Allen/annotation_25.nrrd Bytes: 4035363


100%|██████████| 3.848422050476074/3.848422050476074 [00:00<00:00,  4.34it/s]


In [None]:
# ============================================================================
# MAIN PIPELINE - Cleaned
# ============================================================================

import warnings
import json
import gc
import io
import contextlib
import uuid
from pathlib import Path
from tqdm import tqdm
import pandas as pd
import numpy as np

# Silence warnings process-wide for the batch run.
warnings.filterwarnings('ignore')

# --- Configuration ---
# Rows are appended to these local CSVs first and later moved to Drive
# (see sync_to_drive_and_cleanup).
# NOTE(review): SYNC_INTERVAL is presumably the number of pids between syncs;
# its use is not in this cell — confirm.
LOCAL_SUMMARY_CSV, LOCAL_SKIPPED_CSV = "decoding_summary1.csv", "skipped1.csv"
SYNC_INTERVAL = 25

# Persistent output folder on Drive, and the local scratch folder for memmaps.
DRIVE_DIR = Path("/content/drive/MyDrive/fullrun/")
MMAP_DIR = Path("/content/mmap_cache")
for _folder in (DRIVE_DIR, MMAP_DIR):
    _folder.mkdir(parents=True, exist_ok=True)
del _folder

DRIVE_SUMMARY_CSV = DRIVE_DIR / "decoding_summary1.csv"
DRIVE_SKIPPED_CSV = DRIVE_DIR / "skipped1.csv"

# --- Utilities ---

def init_outputs_if_missing():
    """Create the Drive summary and skipped-PID CSVs (header row only) if absent."""
    headers_by_file = {
        DRIVE_SUMMARY_CSV: [
            "pid", "n_neurons", "regions", "train_R2_all", "train_R_all",
            "test_R2_all", "test_R_all", "train_region_metrics", "test_region_metrics",
        ],
        DRIVE_SKIPPED_CSV: ["pid", "reason"],
    }
    for csv_path, columns in headers_by_file.items():
        if csv_path.exists():
            continue
        pd.DataFrame(columns=columns).to_csv(csv_path, index=False)


def append_to_csv(csv_path, row_dict):
    """Append one row to `csv_path`.

    A header is written only when the file is missing or empty.
    """
    target = Path(csv_path)
    has_content = target.exists() and target.stat().st_size > 0
    row_frame = pd.DataFrame([row_dict])
    row_frame.to_csv(csv_path, mode="a", header=not has_content, index=False)


def append_to_skipped_log(pid, reason):
    """Record a skipped `pid` and its `reason` in the local skipped-PID CSV."""
    entry = {"pid": str(pid), "reason": reason}
    append_to_csv(LOCAL_SKIPPED_CSV, entry)


def _flush_local_csv(local_path, drive_path):
    """
    Append the rows of local CSV `local_path` to `drive_path`, then delete the local file.

    Missing, empty, or header-only local files are left untouched. The header
    is written only when `drive_path` is missing or empty. A Drive file that was
    removed mid-run is therefore recreated with column names, not headerless rows.
    """
    local = Path(local_path)
    if not (local.exists() and local.stat().st_size > 0):
        return
    local_df = pd.read_csv(local)
    if local_df.empty:
        return
    drive = Path(drive_path)
    drive_has_content = drive.exists() and drive.stat().st_size > 0
    local_df.to_csv(drive, mode="a", header=not drive_has_content, index=False)
    local.unlink()


def sync_to_drive_and_cleanup():
    """
    Move accumulated local summary/skipped rows to their Drive CSVs.

    Returns True on success. Returns False after printing a [SYNC ERROR] line
    if any step fails; in that case the local files stay in place.
    """
    try:
        _flush_local_csv(LOCAL_SUMMARY_CSV, DRIVE_SUMMARY_CSV)
        _flush_local_csv(LOCAL_SKIPPED_CSV, DRIVE_SKIPPED_CSV)
        gc.collect()
        return True
    except Exception as e:
        # Best-effort: a failed sync must not abort the batch run.
        print(f"[SYNC ERROR] {e}")
        return False


# --- Memmap Helpers ---

def _as_float32_array(x):
    """Return `x` as a float32 ndarray, copying only when a conversion is needed."""
    if isinstance(x, pd.DataFrame):
        return x.to_numpy(dtype=np.float32, copy=False)
    arr = np.asarray(x)
    return arr if arr.dtype == np.float32 else arr.astype(np.float32, copy=False)


def _to_memmap(arr: np.ndarray, shape, fname: Path):
    """
    Copy `arr` into a new float32 memmap at `fname` and reopen it read-only.

    Rows are written in chunks of about 64M elements (~256 MB of float32). The
    chunk height comes from the number of elements per row. This works for 1-D
    shapes (previously an IndexError) and for zero-width rows (previously a
    ZeroDivisionError).

    Parameters
    ----------
    arr : array to copy; must be sliceable along axis 0 to `shape`.
    shape : int or tuple, shape of the memmap.
    fname : destination file path (overwritten).

    Returns
    -------
    np.memmap opened with mode 'r'.
    """
    shape = (int(shape),) if np.isscalar(shape) else tuple(int(d) for d in shape)
    row_elems = int(np.prod(shape[1:], dtype=np.int64))
    rows_per_chunk = max(1, 64_000_000 // max(1, row_elems))

    mm = np.memmap(fname, dtype=np.float32, mode='w+', shape=shape)
    for start in range(0, shape[0], rows_per_chunk):
        stop = min(shape[0], start + rows_per_chunk)
        mm[start:stop] = arr[start:stop]
    mm.flush()
    del mm

    # Re-open read-only to prevent accidental writes.
    return np.memmap(fname, dtype=np.float32, mode='r', shape=shape)


def _safe_unlink(path: "Path | str"):
    """Best-effort delete of the file at ``path`` (a ``Path`` or ``str``).

    A missing file is not an error (``missing_ok=True``, which also avoids the
    race of a separate ``exists()`` check). Any other failure is reported as a
    warning rather than raised, so cleanup can never abort the pipeline — but it
    is no longer silently swallowed.
    """
    try:
        Path(path).unlink(missing_ok=True)
    except Exception as e:
        print(f"[WARN] Could not remove {path}: {e}")


# --- Core Processing ---

def process_pid(pid, one, ba):
    """
    Process a single PID: load → preprocess → train → evaluate.

    stdout/stderr printed while ``load_session_data`` runs are captured. When it
    returns ``None``, the first captured ``[SKIP]``/``[WARN]`` line (or
    "Unknown reason") is logged as the skip reason. Otherwise the train/test
    arrays are downcast to float32, spilled to read-only memmaps under
    ``MMAP_DIR``, and passed to ``train_and_evaluate_session`` (``alpha=5``).
    The resulting summary row is appended to ``LOCAL_SUMMARY_CSV``.

    Temporary files (the four memmaps and ``meta['lag_path']`` if set) are removed
    in a ``finally`` block. This means they are deleted when training raises too
    (e.g. a singular-matrix ``LinAlgError``); otherwise every failed PID would
    leave multi-GB memmaps on local disk.

    Returns:
        (summary_row, skipped_flag): ``(summary_row, False)`` on success,
        ``(None, True)`` when the PID was skipped or failed. In both of the latter
        cases a row is appended via ``append_to_skipped_log``.
    """
    import traceback

    # Every temporary file created for this PID; removed in ``finally`` on
    # success, skip, and failure alike.
    tmp_files = []
    try:
        fbuf = io.StringIO()
        with contextlib.redirect_stdout(fbuf), contextlib.redirect_stderr(fbuf):
            sess = load_session_data(pid, one, ba)
        captured = fbuf.getvalue()

        if sess is None:
            flagged = [line for line in captured.splitlines()
                       if '[SKIP]' in line or '[WARN]' in line]
            reason = flagged[0] if flagged else "Unknown reason"
            append_to_skipped_log(pid, reason)
            return None, True

        X_train, X_test, y_train, y_test, ytr_raw, yte_raw, meta = preprocess_session(sess)

        # File referenced by meta['lag_path'] (presumably written by
        # preprocess_session); scheduled for deletion.
        lag_path = meta.get('lag_path')
        if lag_path:
            tmp_files.append(Path(lag_path))

        # --- Downcast & memmap to cap RAM ---
        X_train = _as_float32_array(X_train)
        X_test = _as_float32_array(X_test)
        y_train = _as_float32_array(y_train).reshape(-1, 1)
        y_test = _as_float32_array(y_test).reshape(-1, 1)

        # Unique file names per PID. Register them for cleanup *before* writing,
        # so a partially written memmap (e.g. disk full) is removed as well.
        uid = uuid.uuid4().hex[:8]
        fXtr = MMAP_DIR / f"Xtr_{uid}.mmap"
        fXte = MMAP_DIR / f"Xte_{uid}.mmap"
        fytr = MMAP_DIR / f"ytr_{uid}.mmap"
        fyte = MMAP_DIR / f"yte_{uid}.mmap"
        tmp_files.extend([fXtr, fXte, fytr, fyte])

        # Persist to disk-backed arrays
        X_train_mm = _to_memmap(X_train, X_train.shape, fXtr)
        X_test_mm = _to_memmap(X_test, X_test.shape, fXte)
        y_train_mm = _to_memmap(y_train, y_train.shape, fytr)
        y_test_mm = _to_memmap(y_test, y_test.shape, fyte)

        # Free original dense arrays ASAP
        del X_train, X_test, y_train, y_test
        gc.collect()

        # --- Training on memmap arrays (read-only views) ---
        model, summary_row, _ = train_and_evaluate_session(
            sess, meta,
            X_train_mm, X_test_mm,
            y_train_mm.ravel(), y_test_mm.ravel(),
            ytr_raw, yte_raw,
            alpha=5,
            verbose=False
        )

        append_to_csv(LOCAL_SUMMARY_CSV, summary_row)

        # Drop heavy objects before the memmap files are removed in ``finally``
        del sess, model, meta, ytr_raw, yte_raw, X_train_mm, X_test_mm, y_train_mm, y_test_mm
        gc.collect()

        return summary_row, False

    except Exception as e:
        append_to_skipped_log(pid, f"FAILED: {type(e).__name__} — {e}")
        traceback.print_exc()
        return None, True

    finally:
        for tmp_path in tmp_files:
            _safe_unlink(tmp_path)


# --- Main Pipeline ---

def _print_r2_stats(title, df):
    """Print count, mean ± sd, median and best ``test_R2_all`` over the rows of ``df``.

    Non-numeric and non-finite R² values (NaN, ±inf) are dropped. The median is
    reported alongside the mean because a single diverged fit (a huge negative
    R²) can dominate the mean. ``pid`` is read from the best row when that
    column exists.
    """
    if df.empty or "test_R2_all" not in df.columns:
        print(f"\n{title}: no R² values available")
        return
    r2 = pd.to_numeric(df["test_R2_all"], errors="coerce")
    r2 = r2[np.isfinite(r2)]
    if r2.empty:
        print(f"\n{title}: no finite R² values")
        return
    best = r2.idxmax()
    best_pid = df.at[best, "pid"] if "pid" in df.columns else "<unknown pid>"
    print(f"\n{title} (n={len(r2)}):")
    print(f"  Mean R²:   {r2.mean():.5f} ± {r2.std():.5f}")
    print(f"  Median R²: {r2.median():.5f}")
    print(f"  Best: {best_pid} (R²={r2.loc[best]:.5f})")


def main(all_pids, one, ba):
    """Main driver for the full decoding pipeline.

    Runs ``process_pid`` on every PID in ``all_pids`` with a tqdm progress bar.
    Local CSVs are synced to Drive every ``SYNC_INTERVAL`` PIDs and once more at
    the end. The final sync runs in ``finally``, so it also happens if the loop is
    interrupted (e.g. KeyboardInterrupt when a Colab cell is stopped).

    The printed performance summary separates this run's results from the
    cumulative Drive CSV. The Drive CSV is append-only across runs, so it is
    de-duplicated by ``pid``, keeping the latest row.
    """
    init_outputs_if_missing()
    skipped_count, trained_count = 0, 0
    run_rows = []  # summary rows produced by THIS run only

    pbar = tqdm(all_pids, desc="Processing PIDs", unit="pid")
    try:
        for idx, pid in enumerate(pbar, 1):
            summary_row, skipped = process_pid(pid, one, ba)
            gc.collect()

            if skipped or summary_row is None:
                skipped_count += 1
                pbar.set_postfix_str(f"SKIPPED | {trained_count}✓ {skipped_count}✗")
            else:
                trained_count += 1
                run_rows.append(summary_row)
                r2 = summary_row.get('test_R2_all', np.nan)
                pbar.set_postfix_str(f"R²={r2:.3f} | {trained_count}✓ {skipped_count}✗")

            # Periodic sync to Drive
            if idx % SYNC_INTERVAL == 0:
                synced = sync_to_drive_and_cleanup()
                msg = "Synced" if synced else "Sync failed"
                pbar.write(f"[SYNC] {msg} after {idx} PIDs")
    finally:
        pbar.close()
        # Final sync (also on interruption, so finished rows are not stranded locally)
        sync_ok = sync_to_drive_and_cleanup()
        print("\n[SYNC] Final sync " + ("completed" if sync_ok else "failed"))

    # Summary statistics
    total = trained_count + skipped_count
    print("\n" + "="*80)
    print("=== Pipeline Summary ===")
    print("="*80)
    print(f"Total PIDs: {total}")
    print(f"  ✓ Trained: {trained_count}")
    print(f"  ✗ Skipped: {skipped_count}")

    _print_r2_stats("Performance Summary (this run)", pd.DataFrame(run_rows))

    # Cumulative metrics from the Drive CSV (includes earlier runs)
    try:
        if DRIVE_SUMMARY_CSV.exists() and DRIVE_SUMMARY_CSV.stat().st_size > 0:
            df_all = pd.read_csv(DRIVE_SUMMARY_CSV)
            if "pid" in df_all.columns:
                df_all = df_all.drop_duplicates(subset="pid", keep="last")
            _print_r2_stats("Cumulative (Drive CSV, latest row per PID)", df_all)
    except Exception as e:
        print(f"[WARN] Could not load summary stats: {e}")

    print("\nOutputs:")
    print(f"  - Summary CSV: {DRIVE_SUMMARY_CSV}")
    print(f"  - Skipped CSV: {DRIVE_SKIPPED_CSV}")
    if sync_ok:
        print("  - Local disk: CLEAN (synced to Drive)")
    else:
        print("  - Local disk: NOT CLEAN (final sync failed; local CSVs may remain)")
    print("="*80)

# run

In [None]:
import scipy.linalg as la

In [None]:
# Run the decoding pipeline over every PID in `all_pids`; per-PID summary and
# skipped rows are written locally and appended to the Drive CSVs on each sync.
main(all_pids, one, ba)

Processing PIDs:   6%|▋         | 9/144 [04:16<1:08:05, 30.27s/pid, SKIPPED | 0✓ 9✗]



(384, 2),	localCoordinates
(192, 3),	mlapdv
(192,),	brainLocationIds_ccf_2017
(384,),	rawInd
Processing PIDs:  10%|▉         | 14/144 [06:20<55:12, 25.48s/pid, SKIPPED | 0✓ 14✗]



(384, 2),	localCoordinates
(192, 3),	mlapdv
(192,),	brainLocationIds_ccf_2017
(384,),	rawInd
Processing PIDs:  17%|█▋        | 25/144 [11:07<50:30, 25.47s/pid, SKIPPED | 0✓ 25✗]

[SYNC] Synced after 25 PIDs


(384,),	rawInd
(192,),	brainLocationIds_ccf_2017
(384, 2),	localCoordinates
(192, 3),	mlapdv
Processing PIDs:  27%|██▋       | 39/144 [17:38<46:37, 26.65s/pid, SKIPPED | 0✓ 39✗]



(384,),	rawInd
(192,),	brainLocationIds_ccf_2017
(384, 2),	localCoordinates
(192, 3),	mlapdv
Processing PIDs:  35%|███▍      | 50/144 [22:32<40:54, 26.11s/pid, SKIPPED | 0✓ 50✗]

[SYNC] Synced after 50 PIDs


(192,),	brainLocationIds_ccf_2017
(192, 3),	mlapdv
(384,),	rawInd
(384, 2),	localCoordinates
(384, 2),	localCoordinates
(192, 3),	mlapdv
(192,),	brainLocationIds_ccf_2017
(384,),	rawInd
(192, 3),	mlapdv
(384, 2),	localCoordinates
(384,),	rawInd
(192,),	brainLocationIds_ccf_2017
Processing PIDs:  42%|████▏     | 61/144 [27:23<36:49, 26.62s/pid, SKIPPED | 0✓ 61✗]

[PREP] e940541b-c564-46cf-99c8-f2207cfdb79c | ✓ train=(217292, 3870) test=(54323, 3870)


Traceback (most recent call last):
  File "/tmp/ipython-input-167243327.py", line 162, in process_pid
    model, summary_row, _ = train_and_evaluate_session(
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipython-input-3644626716.py", line 65, in train_and_evaluate_session
    w = la.solve(G.astype(np.float64), b.astype(np.float64), assume_a='pos').astype(np.float32)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/scipy/_lib/_util.py", line 1233, in wrapper
    return f(*arrays, *other_args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/scipy/linalg/_basic.py", line 341, in solve
    _solve_check(n, info)
  File "/usr/local/lib/python3.12/dist-packages/scipy/linalg/_basic.py", line 43, in _solve_check
    raise LinAlgError('Matrix is singular.')
numpy.linalg.LinAlgError: Matrix is singular.
(192,),	brainLocationIds_ccf_2017
(384, 2)



Processing PIDs:  52%|█████▏    | 75/144 [34:35<30:38, 26.65s/pid, SKIPPED | 0✓ 75✗]

[SYNC] Synced after 75 PIDs


(521,),	channels
(521,),	depths
(748, 16),	metrics
(384,),	rawInd
(192, 3),	mlapdv
(384, 2),	localCoordinates
(192,),	brainLocationIds_ccf_2017
(384, 2),	localCoordinates
(192,),	brainLocationIds_ccf_2017
(384,),	rawInd
(192, 3),	mlapdv
Processing PIDs:  69%|██████▉   | 100/144 [46:59<19:34, 26.69s/pid, SKIPPED | 0✓ 100✗]

[SYNC] Synced after 100 PIDs


(192,),	brainLocationIds_ccf_2017
(384,),	rawInd
(384, 2),	localCoordinates
(192, 3),	mlapdv
(192,),	brainLocationIds_ccf_2017
(384,),	rawInd
(384, 2),	localCoordinates
(192, 3),	mlapdv
Processing PIDs:  87%|████████▋ | 125/144 [58:43<09:06, 28.77s/pid, SKIPPED | 0✓ 125✗]

[SYNC] Synced after 125 PIDs


(192, 3),	mlapdv
(192,),	brainLocationIds_ccf_2017
(384,),	rawInd
(384, 2),	localCoordinates
Processing PIDs: 100%|██████████| 144/144 [1:07:41<00:00, 28.20s/pid, SKIPPED | 0✓ 144✗]


[SYNC] Final sync completed

=== Pipeline Summary ===
Total PIDs: 144
  ✓ Trained: 0
  ✗ Skipped: 144

Performance Summary:
  Mean R²: -10863527243515472199548928.00000 ± 209246312195206743018438656.00000
  Best: 784c5282-2749-48ca-b211-100d0b24e29b (R²=0.75861)

Outputs:
  - Summary CSV: /content/drive/MyDrive/fullrun/decoding_summary1.csv
  - Skipped CSV: /content/drive/MyDrive/fullrun/skipped1.csv
  - Local disk: CLEAN (synced to Drive)



