# Decoding Social Intent from Neural Oscillations
## Exploratory Data Analysis

**COGS 118C — Signal Processing Course Project**

**Research Question:** Can we classify whether an animal is interacting socially vs. exploring alone based on the spectral features of its neural calcium signals?

---

### Goals of this notebook

1. **Acquire & load** data from the EDGE course notebooks (calcium traces, behavioral annotations)
2. **Inspect** data shapes, types, sampling rates, and quality
3. **Visualize** raw calcium signals and behavioral epoch structure
4. **Preliminary spectral analysis** — PSD, spectrograms, wavelets on sample signals
5. **Compare** spectral features between social vs. solo epochs
6. **Identify** data quality issues, confounds, and scope for the full project

---
## 1. Setup & Dependencies

In [None]:
# Install dependencies (uncomment if needed)
# !pip install numpy pandas matplotlib seaborn scipy h5py gdown PyWavelets scikit-learn  # PyWavelets provides `import pywt`

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import h5py
import json
import os
from pathlib import Path

# Signal processing
from scipy import signal
from scipy.signal import welch, butter, filtfilt, stft
import pywt

# Styling
sns.set_theme(style="whitegrid", context="notebook", font_scale=1.1)
plt.rcParams["figure.figsize"] = (14, 5)
plt.rcParams["figure.dpi"] = 100

# Reproducibility
np.random.seed(42)

print("All imports successful.")

---
## 2. Data Acquisition

Our data comes from the **EDGE** (Education in Data and Guided Exploration) course notebooks. We download the relevant Colab notebooks and extract the data they reference.

### Source notebooks

| Notebook | Content | Priority |
|----------|---------|----------|
| Finding social behaviors | Behavioral annotations (social vs. non-social) | High |
| Calcium demo | Raw calcium fluorescence traces | High |
| Demixed calcium | Clean, source-separated signals | High |
| Neural signals of social isolation | Calcium + social condition labels | High |
| Analyzing social isolation | Behavioral data for isolation study | Medium |
| social_bouts.00 | Compiled social bout timing data | Medium |

In [None]:
import gdown

DATA_DIR = Path("../data")
DATA_DIR.mkdir(exist_ok=True)

NB_DIR = DATA_DIR / "source_notebooks"
NB_DIR.mkdir(exist_ok=True)

# Google Drive file IDs extracted from the Colab URLs in data.md
NOTEBOOKS = {
    "finding_social_behaviors": "13HGdsrS4lYxZcfpr4snE-fKbKhgcSZbf",
    "calcium_demo": "1bhdkHCeHoOg2z0FTgnPlPfIasmxYhcqr",
    "motion_correction": "1hL9mE9kZ2nr_RX0W6N7-dEDReifEeowR",
    "demixed_calcium": "1CEZ13yr_5usvYLXFmTtTaepH7bVPeCU3",
    "analyzing_social_isolation": "1nqRrS3MS1ASBeJjJCLOUbxqvP5VSOW1j",
    "neural_signals_social_isolation": "1OHi6j34edcKM-X2gYmyHYytb9zCTPEIn",
}

# Also download the social_bouts data file
SOCIAL_BOUTS_ID = "1Lz5hya0W_sXpcIptcrnQzCOWoestkOmm"

print(f"Data directory: {DATA_DIR.resolve()}")
print(f"Notebook directory: {NB_DIR.resolve()}")

In [None]:
# Download notebooks from Google Drive
for name, file_id in NOTEBOOKS.items():
    output_path = NB_DIR / f"{name}.ipynb"
    if output_path.exists():
        print(f"  [skip] {name} already downloaded")
        continue
    try:
        url = f"https://drive.google.com/uc?id={file_id}"
        gdown.download(url, str(output_path), quiet=True)
        print(f"  [ok]   {name}")
    except Exception as e:
        print(f"  [FAIL] {name}: {e}")

# Download social_bouts data
bouts_path = DATA_DIR / "social_bouts.ipynb"
if not bouts_path.exists():
    try:
        gdown.download(
            f"https://drive.google.com/uc?id={SOCIAL_BOUTS_ID}",
            str(bouts_path), quiet=True
        )
        print(f"  [ok]   social_bouts")
    except Exception as e:
        print(f"  [FAIL] social_bouts: {e}")

print("\nDownloaded files:")
for f in sorted(DATA_DIR.rglob("*")):
    if f.is_file():
        size_kb = f.stat().st_size / 1024
        print(f"  {f.relative_to(DATA_DIR)}  ({size_kb:.1f} KB)")

### 2.1 Extract data URLs from downloaded notebooks

The EDGE notebooks typically load data from Google Drive or hosted URLs. Let's parse the notebooks to find those data sources.

In [None]:
import re

def extract_data_urls_from_notebook(nb_path):
    """Parse a .ipynb file and extract data-loading URLs and gdown calls."""
    urls = []
    try:
        with open(nb_path, "r") as f:
            nb = json.load(f)
        for cell in nb.get("cells", []):
            if cell["cell_type"] != "code":
                continue
            source = "".join(cell["source"])
            # Match URLs (http/https, Google Drive, raw GitHub, etc.)
            found = re.findall(r'["\']?(https?://[^\s"\'\)]+)["\']?', source)
            urls.extend(found)
            # Match gdown IDs
            gdown_ids = re.findall(r'gdown\.download\([^)]*["\']([a-zA-Z0-9_-]{20,})["\']', source)
            for gid in gdown_ids:
                urls.append(f"https://drive.google.com/uc?id={gid}")
    except (json.JSONDecodeError, KeyError) as e:
        print(f"  Could not parse {nb_path.name}: {e}")
    return urls

print("Data URLs found in source notebooks:\n")
all_data_urls = {}
for nb_file in sorted(NB_DIR.glob("*.ipynb")):
    urls = extract_data_urls_from_notebook(nb_file)
    if urls:
        all_data_urls[nb_file.stem] = urls
        print(f"  {nb_file.stem}:")
        for u in urls:
            print(f"    {u}")
    else:
        print(f"  {nb_file.stem}: (no URLs found)")
    print()

### 2.2 Download extracted data files

We attempt to download the actual data files referenced by the notebooks. These are typically `.h5`, `.npy`, `.csv`, or `.mat` files.

In [None]:
import urllib.request

RAW_DIR = DATA_DIR / "raw"
RAW_DIR.mkdir(exist_ok=True)

# Filter for actual data file URLs (not pip install, not colab, not docs).
# Repository links are skipped with endswith(".git") below: a bare ".git" substring
# would also match raw.githubusercontent.com, which we want to keep.
SKIP_PATTERNS = ["pip", "colab", "github.com/", "googleapis.com/auth", "readme"]
DATA_EXTENSIONS = [".h5", ".hdf5", ".npy", ".npz", ".csv", ".mat", ".pkl", ".parquet", ".zip", ".tar"]

data_urls_to_download = []
for nb_name, urls in all_data_urls.items():
    for url in urls:
        url_lower = url.lower()
        if any(skip in url_lower for skip in SKIP_PATTERNS) or url_lower.endswith(".git"):
            continue
        # Keep Drive download links and files with data extensions
        if "drive.google.com" in url_lower or any(url_lower.endswith(ext) for ext in DATA_EXTENSIONS):
            data_urls_to_download.append((nb_name, url))
        # Also keep raw content URLs (e.g. raw.githubusercontent.com)
        elif "raw." in url_lower or "download" in url_lower:
            data_urls_to_download.append((nb_name, url))

print(f"Found {len(data_urls_to_download)} potential data URLs to download:\n")
for nb_name, url in data_urls_to_download:
    print(f"  [{nb_name}] {url}")

In [None]:
# Download data files
downloaded_files = []

for nb_name, url in data_urls_to_download:
    # Determine filename from URL or use a generated name
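    if "drive.google.com" in url:
        # Use gdown for Google Drive; accept both "?id=<ID>" and "/file/d/<ID>/" links
        file_id = re.search(r'(?:[?&]id=|/d/)([a-zA-Z0-9_-]+)', url)
        if not file_id:
            print(f"  [skip] no Drive file ID in {url}")
            continue
        out_path = RAW_DIR / f"{nb_name}_drive_{file_id.group(1)[:8]}"
        if out_path.exists():
            print(f"  [skip] {out_path.name} already exists")
            downloaded_files.append(out_path)
            continue
        try:
            direct_url = f"https://drive.google.com/uc?id={file_id.group(1)}"
            result = gdown.download(direct_url, str(out_path), quiet=True)
            if result:
                downloaded_files.append(out_path)
                print(f"  [ok]   {out_path.name} ({out_path.stat().st_size / 1024:.1f} KB)")
            else:
                print(f"  [FAIL] {url}: gdown returned no file")
        except Exception as e:
            print(f"  [FAIL] {url}: {e}")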
    else:
        # Direct URL download
        fname = url.split("/")[-1].split("?")[0]
        if not fname or len(fname) > 100:
            fname = f"{nb_name}_data"
        out_path = RAW_DIR / fname
        if out_path.exists():
            print(f"  [skip] {fname} already exists")
            downloaded_files.append(out_path)
            continue
        try:
            urllib.request.urlretrieve(url, str(out_path))
            downloaded_files.append(out_path)
            print(f"  [ok]   {fname} ({out_path.stat().st_size / 1024:.1f} KB)")
        except Exception as e:
            print(f"  [FAIL] {fname}: {e}")

print(f"\nTotal files downloaded: {len(downloaded_files)}")

---
## 3. Data Loading & Inspection

Let's inspect what we actually have. We check file types, shapes, and structure for every downloaded data file.

In [None]:
def inspect_file(filepath):
    """Inspect a data file and report its contents."""
    fp = Path(filepath)
    size = fp.stat().st_size
    print(f"\n{'='*60}")
    print(f"File: {fp.name}  ({size / 1024:.1f} KB)")
    print(f"{'='*60}")

    # Try as HDF5
    try:
        with h5py.File(fp, "r") as f:
            print("Format: HDF5")
            def print_hdf5(name, obj):
                if isinstance(obj, h5py.Dataset):
                    print(f"  Dataset: {name}  shape={obj.shape}  dtype={obj.dtype}")
                elif isinstance(obj, h5py.Group):
                    print(f"  Group:   {name}/")
            f.visititems(print_hdf5)
        return "h5"
    except Exception:
        pass

    # Try as numpy
    try:
        data = np.load(fp, allow_pickle=True)
        if isinstance(data, np.lib.npyio.NpzFile):
            print("Format: NPZ (compressed numpy)")
            for key in data.files:
                print(f"  Array: {key}  shape={data[key].shape}  dtype={data[key].dtype}")
        else:
            print(f"Format: NPY  shape={data.shape}  dtype={data.dtype}")
        return "npy"
    except Exception:
        pass

    # Try as JSON / notebook. This runs before the CSV check so that text formats
    # such as .ipynb are not mislabelled if pandas manages to parse them as CSV.
    try:
        with open(fp, "r", encoding="utf-8") as f:
            content = json.load(f)
        if isinstance(content, dict) and "cells" in content:
            n_code = sum(1 for c in content["cells"] if c["cell_type"] == "code")
            n_md = sum(1 for c in content["cells"] if c["cell_type"] == "markdown")
            print(f"Format: Jupyter notebook  ({n_code} code cells, {n_md} markdown cells)")
            return "ipynb"
        keys = list(content.keys())[:10] if isinstance(content, dict) else type(content).__name__
        print(f"Format: JSON  (top level: {keys})")
        return "json"
    except Exception:
        pass

    # Try as CSV / pandas
    try:
        df = pd.read_csv(fp, nrows=5)
        print(f"Format: CSV  shape={df.shape}  columns={list(df.columns)}")
        print(df.head(3).to_string())
        return "csv"
    except Exception:
        pass

    # Fallback: show first bytes
    try:
        with open(fp, "rb") as f:
            header = f.read(200)
        print(f"Format: Unknown  (first bytes: {header[:50]})")
    except Exception:
        print("Format: Could not read")
    return "unknown"

# Inspect all downloaded data files
file_types = {}
for fp in sorted(RAW_DIR.iterdir()):
    if fp.is_file():
        ft = inspect_file(fp)
        file_types[fp.name] = ft

### 3.1 Load the primary data

Based on the inspection above, load the key datasets into memory. 

> **Note:** Adjust the cell below based on what file formats were actually downloaded. The code handles the most common formats from EDGE notebooks: HDF5, numpy arrays, and CSV files.

In [None]:
# ============================================================
# ADAPTIVE DATA LOADER
# Loads data based on whatever format was downloaded.
# Edit the paths below if your data lives somewhere else.
# ============================================================

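calcium_data = None      # Calcium traces; reoriented to (neurons x time) below
behavior_labels = None   # Behavioral state labels, (time,) or epoch-level
sampling_rate = None     # Hz; critical for spectral analysis

# Files saved without an extension (e.g. Google Drive downloads) are matched via
# the format that inspect_file() detected in Section 3, if that cell was run.
detected_types = globals().get("file_types", {})

def files_of_type(patterns, type_name):
    """Files matching any glob pattern, plus files whose detected format is type_name."""
    found = {p for pat in patterns for p in RAW_DIR.glob(pat)}
    found |= {RAW_DIR / name for name, ft in detected_types.items() if ft == type_name}
    return sorted(found)

# --- Attempt to load from HDF5 files ---
for h5f in files_of_type(["*.h5", "*.hdf5"], "h5"):
    print(f"\nLoading HDF5: {h5f.name}")
    with h5py.File(h5f, "r") as f:
        shapes = {}
        # Print all datasets and remember the numeric 2D ones
        def find_arrays(name, obj):
            if isinstance(obj, h5py.Dataset) and len(obj.shape) >= 1:
                print(f"  {name}: shape={obj.shape}, dtype={obj.dtype}")
                if len(obj.shape) == 2 and np.issubdtype(obj.dtype, np.number):
                    shapes[name] = obj.shape
        f.visititems(find_arrays)
        # Heuristic: the largest numeric 2D dataset is probably the calcium traces
        if shapes:
            best = max(shapes, key=lambda k: np.prod(shapes[k]))
            if calcium_data is None or np.prod(shapes[best]) > calcium_data.size:
                calcium_data = f[best][:]
                print(f"    -> Loaded '{best}' as calcium_data")

# --- Attempt to load from numpy files ---
for npf in files_of_type(["*.npy", "*.npz"], "npy"):
    print(f"\nLoading numpy: {npf.name}")
    data = np.load(npf, allow_pickle=True)
    if isinstance(data, np.lib.npyio.NpzFile):
        for key in data.files:
            arr = data[key]
            print(f"  {key}: shape={arr.shape}, dtype={arr.dtype}")
            # Heuristic: largest 2D array is probably calcium data
            if arr.ndim == 2 and (calcium_data is None or arr.size > calcium_data.size):
                calcium_data = arr
                print(f"    -> Loaded as calcium_data")
    else:
        print(f"  shape={data.shape}, dtype={data.dtype}")
        if data.ndim == 2 and (calcium_data is None or data.size > calcium_data.size):
            calcium_data = data
            print(f"  -> Loaded as calcium_data")

# --- Attempt to load from CSV files ---
for csvf in files_of_type(["*.csv"], "csv"):
    print(f"\nLoading CSV: {csvf.name}")
    df = pd.read_csv(csvf)
    print(f"  shape={df.shape}, columns={list(df.columns)[:10]}")

# Heuristic: recordings usually contain far more time points than neurons, so a
# tall array is assumed to be (time x neurons) and transposed. Verify on real data.
if calcium_data is not None and calcium_data.shape[0] > calcium_data.shape[1]:
    calcium_data = calcium_data.T
    print("\nTransposed calcium_data to (neurons x time)")

print("\n" + "="*60)
if calcium_data is not None:
    print(f"calcium_data loaded: shape={calcium_data.shape}")
else:
    print("WARNING: No calcium data loaded yet; see Section 3.2 for manual loading.")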

### 3.2 Manual data loading (if automatic extraction failed)

If the automatic download didn't find the data files, you can:

1. **Open the source Colab notebooks** listed in `data.md`
2. **Run the data-loading cells** in each notebook
3. **Download the data** to `../data/raw/` using Colab's file browser
4. **Update the paths below** and re-run

Alternatively, paste the data loading code from the Colab notebooks directly below:

In [None]:
# ============================================================
# MANUAL DATA LOADING
# Uncomment and edit the appropriate section for your data.
# ============================================================

# --- Option A: Load from HDF5 ---
# with h5py.File("../data/raw/YOUR_FILE.h5", "r") as f:
#     print(list(f.keys()))  # See what's inside
#     calcium_data = f["YOUR_DATASET_KEY"][:]  # Load into memory
#     # behavior_labels = f["YOUR_LABELS_KEY"][:]
#     # sampling_rate = f.attrs.get("sampling_rate", 20.0)  # Check attrs

# --- Option B: Load from numpy ---
# calcium_data = np.load("../data/raw/YOUR_FILE.npy")
# behavior_labels = np.load("../data/raw/YOUR_LABELS.npy")

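# --- Option C: Load from CSV ---
# df = pd.read_csv("../data/raw/YOUR_FILE.csv")
# time_axis = df.iloc[:, 0].to_numpy()                 # Assumes first col is time (s)
# calcium_data = df.iloc[:, 1:].to_numpy().T           # Transpose to (neurons x time)
# sampling_rate = 1.0 / np.median(np.diff(time_axis))  # Estimate fs from time stamps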

# --- Set sampling rate (CRITICAL — verify from your data source) ---
if sampling_rate is None:
    sampling_rate = 20.0  # Hz — typical for calcium imaging; CHANGE if different
    print(f"Using default sampling_rate = {sampling_rate} Hz")
    print("WARNING: Verify this matches your actual data!")

---
## 4. Calcium Signal Visualization

Visualize the raw calcium traces to check for:
- **Signal quality** — are there clear transients?
- **Photobleaching** — exponential decay in baseline?
- **Motion artifacts** — sudden jumps correlated across all neurons?
- **Dynamic range** — what's the ΔF/F amplitude?
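The motion-artifact check in the list above can be made quantitative. A minimal sketch, assuming traces shaped (neurons x time): it flags frames where more than a given fraction of neurons jump at the same moment, using a per-neuron robust z-score (median and MAD) of the frame-to-frame difference. The helper name `flag_population_jumps` and the thresholds `k=6.0` and `frac=0.5` are hypothetical, illustrative choices, not values from the source notebooks.

```python
import numpy as np

def flag_population_jumps(traces, k=6.0, frac=0.5):
    """Return frame indices where more than `frac` of neurons jump simultaneously.

    traces: array of shape (neurons, time).
    A "jump" is a frame-to-frame change whose robust z-score (median/MAD,
    computed per neuron) exceeds `k`. Both thresholds are assumptions.
    """
    diffs = np.diff(np.asarray(traces, dtype=float), axis=1)
    med = np.median(diffs, axis=1, keepdims=True)
    mad = np.median(np.abs(diffs - med), axis=1, keepdims=True)
    robust_z = np.abs(diffs - med) / (1.4826 * mad + 1e-12)
    jump_fraction = (robust_z > k).mean(axis=0)  # fraction of neurons jumping per frame
    # diffs[:, i] is the change from frame i to frame i + 1, so report i + 1
    return np.where(jump_fraction > frac)[0] + 1

# Usage on toy data: 20 noisy neurons that share a step at frame 500
rng = np.random.default_rng(0)
demo_traces = 0.01 * rng.standard_normal((20, 1000))
demo_traces[:, 500:] += 0.5
print(flag_population_jumps(demo_traces))
```

On real data this would be called as `flag_population_jumps(calcium_data)`; flagged frames are candidates for inspection or exclusion, not confirmed artifacts.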

In [None]:
if calcium_data is None:
    print("No calcium data loaded — generating synthetic data for demonstration.")
    print("Replace this with real data once available.\n")

    # ----- Synthetic calcium-like data for method demonstration -----
    sampling_rate = 20.0  # Hz
    n_neurons = 30
    duration = 600  # seconds (10 minutes)
    n_timepoints = int(duration * sampling_rate)
    t = np.arange(n_timepoints) / sampling_rate

    calcium_data = np.zeros((n_neurons, n_timepoints))

    # Create behavioral epochs: alternating social (1) and solo (0)
    behavior_labels = np.zeros(n_timepoints, dtype=int)
    epoch_starts = np.arange(0, n_timepoints, int(30 * sampling_rate))  # 30-sec epochs
    for i, start in enumerate(epoch_starts):
        end = min(start + int(30 * sampling_rate), n_timepoints)
        if i % 2 == 1:  # Odd epochs are "social"
            behavior_labels[start:end] = 1

    for n in range(n_neurons):
        # Baseline aperiodic noise: amplitude ∝ f^-0.8, so power ∝ f^-1.6 (1/f-like)
        freqs = np.fft.rfftfreq(n_timepoints, d=1/sampling_rate)
        freqs[0] = 1  # Avoid division by zero
        noise_spectrum = 1 / (freqs ** 0.8) * np.exp(1j * 2 * np.pi * np.random.rand(len(freqs)))
        baseline = np.fft.irfft(noise_spectrum, n=n_timepoints)
        baseline = baseline / np.std(baseline) * 0.1

        # Calcium transients (random spikes convolved with exponential decay)
        spike_rate = 0.3 + 0.2 * np.random.rand()  # Hz
        spikes = np.random.poisson(spike_rate / sampling_rate, n_timepoints)
        # Higher spike rate during social epochs for some neurons
        if np.random.rand() > 0.4:  # 60% of neurons are socially modulated
            social_mask = behavior_labels == 1
            spikes[social_mask] = np.random.poisson(
                (spike_rate * 1.8) / sampling_rate, social_mask.sum()
            )
        # Exponential decay kernel (GCaMP6f-like, ~150ms decay)
        tau = 0.15 * sampling_rate  # decay constant in samples
        kernel = np.exp(-np.arange(int(5 * tau)) / tau)
        kernel /= kernel.sum()
        transients = np.convolve(spikes.astype(float), kernel, mode="full")[:n_timepoints]

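        # Add a 5 Hz (theta-like) oscillation during social epochs only.
        # This builds a theta difference into the synthetic data by construction,
        # so a theta effect in the demo is not evidence about real recordings.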
        social_modulation = np.zeros(n_timepoints)
        social_modulation[behavior_labels == 1] = (
            0.05 * np.sin(2 * np.pi * 5 * t[behavior_labels == 1])  # ~5 Hz theta-like
        )

        calcium_data[n] = baseline + transients * 0.5 + social_modulation

    # Add mild photobleaching
    bleaching = np.exp(-t / 2000) * 0.03
    calcium_data += bleaching[np.newaxis, :]

    print(f"Synthetic data generated:")
    print(f"  calcium_data shape: {calcium_data.shape} (neurons x time)")
    print(f"  behavior_labels shape: {behavior_labels.shape}")
    print(f"  sampling_rate: {sampling_rate} Hz")
    print(f"  duration: {duration} sec")
    print(f"  social epochs: {behavior_labels.sum()} samples ({100*behavior_labels.mean():.1f}%)")
    print(f"  solo epochs:   {(1-behavior_labels).sum()} samples ({100*(1-behavior_labels.mean()):.1f}%)")

else:
    print(f"Using real data:")
    print(f"  calcium_data shape: {calcium_data.shape}")
    print(f"  sampling_rate: {sampling_rate} Hz")
    n_timepoints = calcium_data.shape[-1]  # Assumes neurons x time
    t = np.arange(n_timepoints) / sampling_rate
    duration = n_timepoints / sampling_rate
    print(f"  duration: {duration:.1f} sec")

In [None]:
# --- Plot raw calcium traces for a subset of neurons ---
n_show = min(8, calcium_data.shape[0])

fig, axes = plt.subplots(n_show, 1, figsize=(16, 2.2 * n_show), sharex=True)
if n_show == 1:
    axes = [axes]

for i, ax in enumerate(axes):
    ax.plot(t, calcium_data[i], linewidth=0.5, color="#2c3e50")
    ax.set_ylabel(f"Neuron {i}", fontsize=9)
    ax.tick_params(labelsize=8)

    # Shade social epochs
    if behavior_labels is not None:
        social_mask = behavior_labels == 1
        starts = np.where(np.diff(social_mask.astype(int)) == 1)[0] + 1  # first social sample
        ends = np.where(np.diff(social_mask.astype(int)) == -1)[0]
        if social_mask[0]:
            starts = np.concatenate([[0], starts])
        if social_mask[-1]:
            ends = np.concatenate([ends, [len(social_mask) - 1]])
        for s, e in zip(starts, ends):
            ax.axvspan(t[s], t[e], alpha=0.15, color="#e74c3c", label="Social" if (i == 0 and s == starts[0]) else None)

axes[-1].set_xlabel("Time (s)")
axes[0].set_title("Raw Calcium Traces (red shading = social epochs)", fontsize=12)
if behavior_labels is not None:
    axes[0].legend(loc="upper right", fontsize=9)
plt.tight_layout()
plt.show()

In [None]:
# --- Basic statistics ---
print("Calcium signal statistics (across all neurons):\n")
print(f"  Mean ΔF/F:     {calcium_data.mean():.4f}")
print(f"  Std  ΔF/F:     {calcium_data.std():.4f}")
print(f"  Min  ΔF/F:     {calcium_data.min():.4f}")
print(f"  Max  ΔF/F:     {calcium_data.max():.4f}")
print(f"  Median ΔF/F:   {np.median(calcium_data):.4f}")
print(f"\n  Nyquist freq:  {sampling_rate / 2:.1f} Hz")
print(f"  Theta band (4–7 Hz) {'fits below' if sampling_rate / 2 > 7 else 'exceeds'} the Nyquist limit")

# Per-neuron activity levels
neuron_stds = calcium_data.std(axis=1)
print(f"\nPer-neuron std range: [{neuron_stds.min():.4f}, {neuron_stds.max():.4f}]")
print(f"Potentially inactive neurons (std < 0.01): {(neuron_stds < 0.01).sum()}")
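The statistics above are labelled ΔF/F, which assumes the loaded traces are already normalized. If the source data turn out to be raw fluorescence, the conversion could look like the sketch below. It assumes a running low-percentile baseline F0 (the 8th percentile over a 30 s window); both numbers are illustrative choices, not taken from the EDGE notebooks, and `raw_to_dff` is a hypothetical helper name. Because F0 tracks slow changes, this also absorbs much of any photobleaching trend.

```python
import numpy as np
from scipy.ndimage import percentile_filter

def raw_to_dff(fluor, fs, window_s=30.0, percentile=8.0):
    """Convert raw fluorescence (neurons x time) to ΔF/F = (F - F0) / F0.

    F0 is a running percentile computed along the time axis only
    (filter footprint of 1 neuron x window_s seconds). F0 must be > 0.
    """
    fluor = np.asarray(fluor, dtype=float)
    win = max(1, int(round(window_s * fs)))
    f0 = percentile_filter(fluor, percentile=percentile, size=(1, win), mode="nearest")
    return (fluor - f0) / f0

# Usage: one neuron with baseline F = 100 and a 1 s transient to F = 150
demo_fs = 20.0
demo_raw = np.full((1, 1200), 100.0)
demo_raw[0, 600:620] = 150.0
demo_dff = raw_to_dff(demo_raw, demo_fs)
print(demo_dff.max(), demo_dff[0, :600].max())
```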

In [None]:
# --- Distribution of calcium signal values ---
fig, axes = plt.subplots(1, 3, figsize=(16, 4))

# Histogram of all values
axes[0].hist(calcium_data.ravel(), bins=100, color="#3498db", alpha=0.7, edgecolor="white")
axes[0].set_xlabel("ΔF/F")
axes[0].set_ylabel("Count")
axes[0].set_title("Distribution of all calcium values")

# Per-neuron standard deviations
axes[1].bar(range(len(neuron_stds)), np.sort(neuron_stds)[::-1], color="#2ecc71")
axes[1].set_xlabel("Neuron (sorted)")
axes[1].set_ylabel("Std of ΔF/F")
axes[1].set_title("Activity level per neuron")

# Correlation matrix (subset)
n_corr = min(20, calcium_data.shape[0])
corr = np.corrcoef(calcium_data[:n_corr])
im = axes[2].imshow(corr, cmap="RdBu_r", vmin=-1, vmax=1, aspect="auto")
axes[2].set_xlabel("Neuron")
axes[2].set_ylabel("Neuron")
axes[2].set_title(f"Pairwise correlation (first {n_corr} neurons)")
plt.colorbar(im, ax=axes[2], shrink=0.8)

plt.tight_layout()
plt.show()

### 4.1 Check for photobleaching

Photobleaching causes an exponential decay in fluorescence baseline over the recording. This injects power into the lowest frequency bins and **must be corrected** before spectral analysis.

In [None]:
# Check for photobleaching by looking at the mean trace over time
mean_trace = calcium_data.mean(axis=0)

fig, axes = plt.subplots(1, 2, figsize=(16, 4))

# Mean trace
axes[0].plot(t, mean_trace, linewidth=0.8, color="#2c3e50")
# Linear fit to detect trend
slope, intercept = np.polyfit(t, mean_trace, 1)
axes[0].plot(t, slope * t + intercept, "--r", linewidth=1.5, label=f"Linear fit (slope={slope:.2e})")
axes[0].set_xlabel("Time (s)")
axes[0].set_ylabel("Mean ΔF/F")
axes[0].set_title("Population mean trace — check for photobleaching")
axes[0].legend()

# Rolling mean to see slow trends
window_sec = 30  # 30-second rolling window
window_samples = int(window_sec * sampling_rate)
rolling_mean = pd.Series(mean_trace).rolling(window=window_samples, center=True).mean()
axes[1].plot(t, rolling_mean, linewidth=1.2, color="#e74c3c")
axes[1].set_xlabel("Time (s)")
axes[1].set_ylabel("30-sec rolling mean ΔF/F")
axes[1].set_title("Slow baseline drift")

plt.tight_layout()
plt.show()

if abs(slope) > 1e-5:
    print(f"Baseline drift detected (slope = {slope:.2e} ΔF/F per second).")
    print("Recommendation: Detrend before spectral analysis.")
else:
    print("No significant baseline drift detected.")
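The cell above detects drift but does not remove it. A minimal sketch of two correction options follows: a per-neuron linear detrend (`scipy.signal.detrend`) and a zero-phase Butterworth high-pass (`scipy.signal.butter` with `sosfiltfilt`). The 0.005 Hz cutoff and second-order filter are assumptions; the cutoff should stay below the lowest band to be analysed (0.01 Hz for the infraslow band), otherwise the filter removes the signal of interest. Second-order sections are used because such low normalized cutoffs can make the (b, a) transfer-function form numerically unreliable. `remove_slow_drift` is a hypothetical helper name.

```python
import numpy as np
from scipy.signal import butter, detrend, sosfiltfilt

def remove_slow_drift(traces, fs, method="linear", cutoff_hz=0.005, order=2):
    """Remove slow baseline drift (e.g. photobleaching) from (neurons x time) traces.

    method="linear":   subtract a least-squares line from each trace.
    method="highpass": zero-phase Butterworth high-pass at `cutoff_hz`
                       (an assumed value; keep it below the lowest band analysed).
    """
    traces = np.asarray(traces, dtype=float)
    if method == "linear":
        return detrend(traces, axis=-1, type="linear")
    if method == "highpass":
        # Second-order sections stay numerically stable at very low normalized cutoffs
        sos = butter(order, cutoff_hz, btype="highpass", fs=fs, output="sos")
        return sosfiltfilt(sos, traces, axis=-1)
    raise ValueError(f"Unknown method: {method!r}")

# Usage on a toy trace: exponential bleaching plus a 0.5 Hz oscillation
demo_fs = 20.0
demo_t = np.arange(int(600 * demo_fs)) / demo_fs
demo_osc = 0.01 * np.sin(2 * np.pi * 0.5 * demo_t)
demo_traces = np.vstack([0.03 * np.exp(-demo_t / 200) + demo_osc] * 2)

demo_lin = remove_slow_drift(demo_traces, demo_fs, method="linear")
demo_hp = remove_slow_drift(demo_traces, demo_fs, method="highpass")
print(demo_lin.shape, float(np.abs(demo_hp[0, 4000:8000] - demo_osc[4000:8000]).max()))
```

The linear option removes only a straight-line trend and leaves the curvature of an exponential bleach; the high-pass option removes curved drift but also any genuine activity below the cutoff. Applied to the notebook's data this would be `remove_slow_drift(calcium_data, sampling_rate, method=...)`.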

---
## 5. Behavioral Epoch Structure

Examine the distribution and structure of social vs. solo behavioral labels.

In [None]:
if behavior_labels is not None:
    # Epoch statistics
    social_frac = behavior_labels.mean()
    solo_frac = 1 - social_frac

    # Find epoch boundaries
    changes = np.diff(behavior_labels)
    epoch_boundaries = np.where(changes != 0)[0] + 1
    epoch_starts = np.concatenate([[0], epoch_boundaries])
    epoch_ends = np.concatenate([epoch_boundaries, [len(behavior_labels)]])
    epoch_labels = [behavior_labels[s] for s in epoch_starts]
    epoch_durations = (epoch_ends - epoch_starts) / sampling_rate  # in seconds

    social_durations = epoch_durations[np.array(epoch_labels) == 1]
    solo_durations = epoch_durations[np.array(epoch_labels) == 0]

    print(f"Behavioral epoch summary:")
    print(f"  Total epochs: {len(epoch_starts)}")
    print(f"  Social: {len(social_durations)} epochs, {social_frac*100:.1f}% of total time")
    print(f"  Solo:   {len(solo_durations)} epochs, {solo_frac*100:.1f}% of total time")
    print(f"  Class ratio (solo:social): {solo_frac/social_frac:.2f}:1")
    print(f"\n  Social epoch duration: mean={social_durations.mean():.1f}s, "
          f"std={social_durations.std():.1f}s, range=[{social_durations.min():.1f}, {social_durations.max():.1f}]s")
    print(f"  Solo epoch duration:   mean={solo_durations.mean():.1f}s, "
          f"std={solo_durations.std():.1f}s, range=[{solo_durations.min():.1f}, {solo_durations.max():.1f}]s")

    # Visualize
    fig, axes = plt.subplots(1, 3, figsize=(16, 4))

    # Class balance pie chart
    axes[0].pie([social_frac, solo_frac], labels=["Social", "Solo"],
                colors=["#e74c3c", "#3498db"], autopct="%1.1f%%", startangle=90)
    axes[0].set_title("Class balance")

    # Epoch duration distributions
    bins = np.linspace(0, max(epoch_durations.max(), 60), 20)
    axes[1].hist(social_durations, bins=bins, alpha=0.7, color="#e74c3c", label="Social")
    axes[1].hist(solo_durations, bins=bins, alpha=0.7, color="#3498db", label="Solo")
    axes[1].set_xlabel("Epoch duration (s)")
    axes[1].set_ylabel("Count")
    axes[1].set_title("Epoch duration distribution")
    axes[1].legend()

    # Timeline
    axes[2].fill_between(t, behavior_labels, step="pre", alpha=0.5, color="#e74c3c", label="Social")
    axes[2].set_xlabel("Time (s)")
    axes[2].set_ylabel("Label")
    axes[2].set_title("Behavioral state timeline")
    axes[2].set_yticks([0, 1])
    axes[2].set_yticklabels(["Solo", "Social"])

    plt.tight_layout()
    plt.show()
else:
    print("No behavioral labels loaded. Spectral analysis will proceed without epoch comparison.")
    print("To add labels, load them in Section 3.2 above.")
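Epoch durations matter for the spectral comparison: a feature computed over a window that straddles a social/solo transition mixes both states, and a 10 s Welch window cannot fit inside an epoch shorter than 10 s. One way to handle this, sketched below under the assumption of a binary per-sample label vector, is to keep only non-overlapping windows that lie entirely inside a single state. `pure_windows` is a hypothetical helper, and the 10 s window length is illustrative.

```python
import numpy as np

def pure_windows(labels, fs, win_s=10.0):
    """Return (start_sample, label) for non-overlapping windows of `win_s`
    seconds that lie entirely within one behavioral state."""
    labels = np.asarray(labels)
    win = int(round(win_s * fs))
    # Run boundaries: indices where the label changes, plus both ends
    change = np.flatnonzero(labels[1:] != labels[:-1]) + 1
    run_starts = np.concatenate([[0], change])
    run_ends = np.concatenate([change, [len(labels)]])
    windows = []
    for s, e in zip(run_starts, run_ends):
        # Tile each run with as many whole windows as fit inside it
        for w in range(s, e - win + 1, win):
            windows.append((int(w), int(labels[s])))
    return windows

# Usage at 20 Hz: 25 s solo, a 5 s social bout, then 30 s solo
demo_labels = np.concatenate([np.zeros(500), np.ones(100), np.zeros(600)]).astype(int)
print(pure_windows(demo_labels, fs=20.0, win_s=10.0))
```

The 5 s social bout yields no window at all, which is exactly the situation the epoch-duration histogram above helps to anticipate.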

---
## 6. Preliminary Spectral Analysis

This is the **core signal processing** section. We compute:

1. **Power Spectral Density (PSD)** via Welch's method — the primary tool
2. **Spectrogram** via STFT — time-frequency visualization
3. **Wavelet scalogram** via Morlet CWT — adaptive time-frequency resolution

### Key parameters
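- **Sampling rate**: determines the Nyquist frequency (highest observable frequency = fs/2)
- **Window length**: determines frequency resolution (Δf = fs / nperseg = 1 / window duration in seconds); resolving a band whose lower edge is f_lo requires a window of at least 1/f_lo seconds, e.g. 100 s for the infraslow band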
- **Frequency bands of interest**:
  - Infraslow: 0.01 – 0.1 Hz
  - Slow: 0.1 – 1 Hz
  - Delta: 1 – 4 Hz
  - Theta: 4 – 7 Hz (upper limit depends on Nyquist)
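Since Δf = 1 / T for a window of T seconds, the band edges above directly fix the shortest usable window, and hence how many Welch segments a recording provides. A small sketch of that arithmetic, assuming the band edges listed above (theta capped at 7 Hz) and the 10-minute length of the synthetic recording used in this notebook; `min_window_seconds` is a hypothetical helper.

```python
# Band edges as listed above (theta upper edge assumes Nyquist > 7 Hz)
example_bands = {
    "infraslow": (0.01, 0.1),
    "slow": (0.1, 1.0),
    "delta": (1.0, 4.0),
    "theta": (4.0, 7.0),
}

def min_window_seconds(bands):
    """Shortest window T (s) whose resolution 1/T is no coarser than each band's lower edge."""
    return {name: 1.0 / lo for name, (lo, _hi) in bands.items()}

recording_s = 600.0  # length of the synthetic recording used in this notebook
for name, win_s in min_window_seconds(example_bands).items():
    step_s = win_s / 2  # 50% overlap
    n_segments = int((recording_s - step_s) // step_s)
    print(f"{name:10s} needs a window >= {win_s:7.2f} s -> {n_segments} Welch segments at 50% overlap")
```

With only about 11 half-overlapping 100 s segments in a 10-minute recording, an infraslow power estimate will be far noisier than the delta or theta estimates.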

In [None]:
# --- Frequency bands for calcium imaging ---
FREQ_BANDS = {
    "infraslow": (0.01, 0.1),
    "slow":      (0.1, 1.0),
    "delta":     (1.0, 4.0),
    "theta":     (4.0, min(7.0, sampling_rate / 2 - 0.5)),
}

print(f"Sampling rate: {sampling_rate} Hz")
print(f"Nyquist frequency: {sampling_rate / 2} Hz")
print(f"\nFrequency bands:")
for name, (lo, hi) in FREQ_BANDS.items():
    print(f"  {name:12s}: {lo:.2f} – {hi:.2f} Hz")

### 6.1 Power Spectral Density (Welch's method)

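Welch's method splits the signal into overlapping, tapered segments, computes a periodogram for each, and averages them. Averaging lowers the variance of the estimate; the price is that frequency resolution is set by the segment length rather than by the full recording. It is a common choice for spectral analysis of calcium imaging data (Bhatt et al. 2013, Frontiers in Neural Circuits).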
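To make the averaging step concrete, the sketch below re-implements Welch's method for a 1-D signal using SciPy's defaults (periodic Hann window, per-segment mean removal, "density" scaling, one-sided spectrum) and checks it against `scipy.signal.welch`. It is for illustration only; the notebook itself uses `welch`. `welch_manual` is a hypothetical name.

```python
import numpy as np
from scipy.signal import get_window, welch

def welch_manual(x, fs, nperseg, noverlap=None):
    """Average of Hann-windowed periodograms over overlapping segments (one-sided PSD)."""
    x = np.asarray(x, dtype=float)
    if noverlap is None:
        noverlap = nperseg // 2
    step = nperseg - noverlap
    n_seg = (len(x) - noverlap) // step
    win = get_window("hann", nperseg)  # periodic Hann, as welch uses by default

    segs = np.stack([x[i * step : i * step + nperseg] for i in range(n_seg)])
    segs = segs - segs.mean(axis=1, keepdims=True)  # detrend="constant"
    pgrams = np.abs(np.fft.rfft(segs * win, axis=1)) ** 2
    pgrams /= fs * np.sum(win ** 2)  # scaling="density" -> units^2 / Hz

    # One-sided spectrum: double every bin except DC (and Nyquist for even nperseg)
    if nperseg % 2 == 0:
        pgrams[:, 1:-1] *= 2
    else:
        pgrams[:, 1:] *= 2
    # Averaging across segments is what reduces the variance
    return np.fft.rfftfreq(nperseg, d=1 / fs), pgrams.mean(axis=0)

# Usage: compare against scipy on a noisy 5 Hz sine
rng = np.random.default_rng(1)
demo_fs = 20.0
demo_x = np.sin(2 * np.pi * 5 * np.arange(4000) / demo_fs) + rng.standard_normal(4000)
f_m, p_m = welch_manual(demo_x, demo_fs, nperseg=200)
f_s, p_s = welch(demo_x, fs=demo_fs, nperseg=200)
print(np.allclose(f_m, f_s), np.allclose(p_m, p_s))
```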

In [None]:
# Welch PSD parameters
nperseg = int(10 * sampling_rate)   # 10-second windows
noverlap = nperseg // 2             # 50% overlap
nfft = max(1024, nperseg * 2)       # Zero-pad for smoother PSD

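freq_res = sampling_rate / nperseg  # set by the window length, not by nfft

print(f"Welch parameters:")
print(f"  Window length: {nperseg} samples ({nperseg/sampling_rate:.1f} s)")
print(f"  Overlap: {noverlap} samples ({noverlap/sampling_rate:.1f} s)")
print(f"  FFT points: {nfft} (zero-padding interpolates the PSD; it does not add resolution)")
print(f"  Frequency resolution: {freq_res:.3f} Hz")

# Bands whose lower edge falls below the resolution cannot be resolved with this window
for name, (lo, hi) in FREQ_BANDS.items():
    if lo < freq_res:
        print(f"  WARNING: '{name}' band starts at {lo} Hz, below the resolution; "
              f"a window of at least {1 / lo:.0f} s is needed")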

In [None]:
# Compute PSD for all neurons
n_neurons = calcium_data.shape[0]
f_welch, psd_all = welch(
    calcium_data,
    fs=sampling_rate,
    nperseg=nperseg,
    noverlap=noverlap,
    nfft=nfft,
    axis=1
)

print(f"PSD computed: {psd_all.shape} (neurons x frequency bins)")
print(f"Frequency range: {f_welch[1]:.4f} – {f_welch[-1]:.2f} Hz ({len(f_welch)} bins)")

# Plot: Mean PSD across all neurons
fig, axes = plt.subplots(1, 2, figsize=(16, 5))

# Linear scale
mean_psd = psd_all.mean(axis=0)
sem_psd = psd_all.std(axis=0) / np.sqrt(n_neurons)

axes[0].plot(f_welch, mean_psd, color="#2c3e50", linewidth=1.5)
axes[0].fill_between(f_welch, mean_psd - sem_psd, mean_psd + sem_psd, alpha=0.3, color="#3498db")
axes[0].set_xlabel("Frequency (Hz)")
axes[0].set_ylabel("Power spectral density")
axes[0].set_title("Mean PSD (linear scale)")

# Shade frequency bands
band_colors = {"infraslow": "#f39c12", "slow": "#2ecc71", "delta": "#3498db", "theta": "#9b59b6"}
for name, (lo, hi) in FREQ_BANDS.items():
    for ax in axes:
        ax.axvspan(lo, hi, alpha=0.1, color=band_colors[name], label=name)

# Log-log scale (reveals 1/f structure)
axes[1].loglog(f_welch[1:], mean_psd[1:], color="#2c3e50", linewidth=1.5)
axes[1].fill_between(f_welch[1:], (mean_psd - sem_psd)[1:], (mean_psd + sem_psd)[1:],
                      alpha=0.3, color="#3498db")
axes[1].set_xlabel("Frequency (Hz)")
axes[1].set_ylabel("PSD")
axes[1].set_title("Mean PSD (log-log scale — check for 1/f)")

axes[0].legend(fontsize=8, loc="upper right")
plt.tight_layout()
plt.show()
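Goal 5 compares spectral features between social and solo epochs; a natural scalar feature is the power integrated over each band of a Welch PSD. A minimal sketch, assuming trapezoidal integration with `scipy.integrate.trapezoid` (SciPy 1.6 or later) and a PSD laid out as (neurons x frequency) like `psd_all`. `band_powers` is a hypothetical helper; its relative option divides by the sum over the listed bands, which is one normalization among several.

```python
import numpy as np
from scipy.integrate import trapezoid
from scipy.signal import welch

def band_powers(f, psd, bands, relative=False):
    """Integrate a PSD (..., n_freqs) over each band; returns {band: array(...)}."""
    out = {}
    for name, (lo, hi) in bands.items():
        mask = (f >= lo) & (f <= hi)
        out[name] = trapezoid(psd[..., mask], f[mask], axis=-1)
    if relative:
        total = sum(out.values())  # sum over the listed bands only
        out = {name: p / total for name, p in out.items()}
    return out

# Usage: three noisy copies of a 5 Hz sine should put most power in theta
rng = np.random.default_rng(2)
demo_fs = 20.0
demo_sig = np.sin(2 * np.pi * 5 * np.arange(12000) / demo_fs) + 0.3 * rng.standard_normal((3, 12000))
demo_bands = {"slow": (0.1, 1.0), "delta": (1.0, 4.0), "theta": (4.0, 7.0)}
f_demo, psd_demo = welch(demo_sig, fs=demo_fs, nperseg=200, axis=1)
bp = band_powers(f_demo, psd_demo, demo_bands, relative=True)
print({k: v.round(3) for k, v in bp.items()})
```

On the notebook's data the equivalent call would be `band_powers(f_welch, psd_all, FREQ_BANDS)`, giving one value per neuron and band.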

In [None]:
# --- PSD heatmap across all neurons ---
fig, ax = plt.subplots(figsize=(14, 6))

# Limit to meaningful frequency range
freq_mask = f_welch <= 10  # Up to 10 Hz (or Nyquist)
psd_plot = 10 * np.log10(psd_all[:, freq_mask] + 1e-12)  # dB scale

im = ax.imshow(psd_plot, aspect="auto", cmap="viridis",
               extent=[f_welch[freq_mask][0], f_welch[freq_mask][-1], n_neurons, 0])
ax.set_xlabel("Frequency (Hz)")
ax.set_ylabel("Neuron")
ax.set_title("PSD heatmap across neurons (dB)")
plt.colorbar(im, label="Power (dB)")

plt.tight_layout()
plt.show()
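The log-log panel of the mean PSD is meant to reveal 1/f structure. One simple way to quantify it, offered here as an assumption rather than the project's chosen method, is a straight-line fit of log10(PSD) against log10(f); the negated slope is the aperiodic exponent. The 0.1 to 5 Hz fit range is illustrative, and `aperiodic_exponent` is a hypothetical helper.

```python
import numpy as np

def aperiodic_exponent(f, psd, fmin=0.1, fmax=5.0):
    """Fit log10(psd) = b - chi * log10(f) on [fmin, fmax]; return chi for each row."""
    f = np.asarray(f)
    psd = np.atleast_2d(psd)
    mask = (f >= fmin) & (f <= fmax) & (f > 0)
    log_f = np.log10(f[mask])
    log_p = np.log10(psd[:, mask])
    # np.polyfit fits each column of a 2-D y independently
    slopes, _intercepts = np.polyfit(log_f, log_p.T, deg=1)
    return -slopes

# Usage: an exact power law PSD = 3 / f**1.6 should give chi = 1.6
demo_f = np.linspace(0.0, 10.0, 501)
demo_psd = np.zeros_like(demo_f)
demo_psd[1:] = 3.0 / demo_f[1:] ** 1.6
print(aperiodic_exponent(demo_f, demo_psd))
```

Applied as `aperiodic_exponent(f_welch, psd_all)` it returns one exponent per neuron; oscillatory peaks inside the fit range will bias a plain line fit.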

### 6.2 Spectrogram (STFT)

The Short-Time Fourier Transform shows how spectral content **evolves over time**. This lets us visually check whether spectral features change when behavior switches between social and solo.

In [None]:
# Compute spectrogram for a representative neuron
# Pick the neuron with highest variance (most active)
active_neuron_idx = np.argmax(calcium_data.std(axis=1))
print(f"Most active neuron: #{active_neuron_idx} (std = {calcium_data[active_neuron_idx].std():.4f})")

# STFT parameters
stft_nperseg = int(5 * sampling_rate)  # 5-second window
f_stft, t_stft, Zxx = stft(
    calcium_data[active_neuron_idx],
    fs=sampling_rate,
    nperseg=stft_nperseg,
    noverlap=stft_nperseg // 2
)

fig, axes = plt.subplots(3, 1, figsize=(16, 10), gridspec_kw={"height_ratios": [1, 3, 1]})

# Top: raw trace
axes[0].plot(t, calcium_data[active_neuron_idx], linewidth=0.5, color="#2c3e50")
if behavior_labels is not None:
    social_mask = behavior_labels == 1
    starts = np.where(np.diff(social_mask.astype(int)) == 1)[0] + 1  # first social sample
    ends = np.where(np.diff(social_mask.astype(int)) == -1)[0]
    if social_mask[0]: starts = np.concatenate([[0], starts])
    if social_mask[-1]: ends = np.concatenate([ends, [len(social_mask) - 1]])
    for s, e in zip(starts, ends):
        axes[0].axvspan(t[s], t[e], alpha=0.2, color="#e74c3c")
axes[0].set_ylabel("ΔF/F")
axes[0].set_title(f"Neuron #{active_neuron_idx} — Calcium trace + Spectrogram")

# Middle: spectrogram
freq_limit = min(10, sampling_rate / 2)
freq_mask_stft = f_stft <= freq_limit
axes[1].pcolormesh(t_stft, f_stft[freq_mask_stft],
                   10 * np.log10(np.abs(Zxx[freq_mask_stft]) ** 2 + 1e-12),
                   shading="gouraud", cmap="magma")
axes[1].set_ylabel("Frequency (Hz)")
axes[1].set_title("Spectrogram (STFT, dB)")

# Overlay band boundaries
for name, (lo, hi) in FREQ_BANDS.items():
    if hi <= freq_limit:
        axes[1].axhline(lo, color="white", linewidth=0.5, alpha=0.5, linestyle="--")
        axes[1].axhline(hi, color="white", linewidth=0.5, alpha=0.5, linestyle="--")

# Bottom: behavioral labels
if behavior_labels is not None:
    axes[2].fill_between(t, behavior_labels, step="pre", alpha=0.7, color="#e74c3c")
    axes[2].set_ylabel("Social")
    axes[2].set_yticks([0, 1])
    axes[2].set_yticklabels(["Solo", "Social"])
axes[2].set_xlabel("Time (s)")

plt.tight_layout()
plt.show()
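The bout-boundary logic (rising and falling edges of the 0/1 label vector) appears in both plotting cells. Below is a minimal sketch of a reusable helper. It assumes the labels are a 1-D array of 0/1 values, and `find_bouts` is a hypothetical name. A rising edge at `np.diff` index `i` means the bout starts at sample `i + 1`. Padding both ends removes the need for the special cases at the first and last sample.

```python
import numpy as np

def find_bouts(labels):
    """Return (starts, ends) sample indices of each run of 1s; ends are inclusive.

    Hypothetical helper: assumes `labels` is a 1-D array of 0/1 values.
    """
    mask = np.asarray(labels).astype(bool)
    # Pad with False on both sides so bouts touching either edge still produce edges
    padded = np.concatenate([[False], mask, [False]]).astype(int)
    edges = np.diff(padded)
    starts = np.where(edges == 1)[0]       # first sample of each bout
    ends = np.where(edges == -1)[0] - 1    # last sample of each bout (inclusive)
    return starts, ends

demo_starts, demo_ends = find_bouts([0, 0, 1, 1, 0, 1])
print(demo_starts, demo_ends)
```

In the plotting cells this would be used as `for s, e in zip(*find_bouts(behavior_labels)): ax.axvspan(t[s], t[e], ...)`.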

### 6.3 Wavelet Scalogram (Morlet CWT)

The continuous wavelet transform (CWT) with a Morlet wavelet provides **adaptive time-frequency resolution**: finer frequency resolution at low frequencies (where most calcium-signal power lies) and finer time resolution at higher frequencies. This makes it well suited to non-stationary neural signals (Cohen 2019, NeuroImage).
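The trade-off can be made concrete with a short calculation. This sketch assumes PyWavelets' `cmorB-C` parametrisation, psi(t) proportional to exp(-t²/B)·exp(j2πCt). Under that assumption, analysing frequency f (Hz) uses scale s = f_c·f_s/f samples, with f_c ≈ C; this is the mapping used to compute `scales` in the cell below. The Gaussian envelope then has time standard deviation σ_t = C·sqrt(B/2)/f and spectral standard deviation σ_f = 1/(2π·σ_t). Both describe the amplitude envelope, not power. `morlet_resolution` is a hypothetical helper.

```python
import numpy as np

def morlet_resolution(freqs_hz, bandwidth=1.5, center=1.0):
    """Time (s) and frequency (Hz) std. devs of a cmor{bandwidth}-{center} envelope.

    Assumes psi(t) ~ exp(-t**2 / B) * exp(2j * pi * C * t) (PyWavelets convention);
    both widths refer to the amplitude envelope, not the power.
    """
    freqs_hz = np.asarray(freqs_hz, dtype=float)
    sigma_t = center * np.sqrt(bandwidth / 2) / freqs_hz   # shrinks as 1/f
    sigma_f = 1.0 / (2 * np.pi * sigma_t)                  # grows linearly with f
    return sigma_t, sigma_f

demo_freqs = [0.5, 1.0, 4.0, 8.0]
for freq, st, sf in zip(demo_freqs, *morlet_resolution(demo_freqs)):
    print(f"{freq:4.1f} Hz: sigma_t = {st:.3f} s, sigma_f = {sf:.3f} Hz")
```

For `cmor1.5-1.0`, the envelope at 1 Hz spans about ±0.87 s in time but only ±0.18 Hz in frequency. At 8 Hz it spans about ±0.11 s but ±1.47 Hz. The product σ_t·σ_f stays fixed at 1/(2π).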

In [None]:
# Morlet wavelet transform on the active neuron
# Target frequencies (Hz), capped below Nyquist; converted to CWT scales below
freq_range = np.linspace(0.5, min(9, sampling_rate / 2 - 0.5), 60)
wavelet_name = "cmor1.5-1.0"  # Complex Morlet wavelet (bandwidth=1.5, center_freq=1.0)

# Compute scales from desired frequencies
scales = pywt.central_frequency(wavelet_name) * sampling_rate / freq_range

# Use a shorter segment for computation (wavelets are expensive)
segment_duration = min(120, duration)  # First 120 seconds
segment_samples = int(segment_duration * sampling_rate)
sig_segment = calcium_data[active_neuron_idx, :segment_samples]
t_segment = t[:segment_samples]

print(f"Computing CWT on {segment_duration}s segment ({segment_samples} samples)...")
coeffs, freqs = pywt.cwt(sig_segment, scales, wavelet_name, sampling_period=1/sampling_rate)
power = np.abs(coeffs) ** 2
print(f"Scalogram shape: {power.shape} (frequencies x time)")

fig, axes = plt.subplots(2, 1, figsize=(16, 8), gridspec_kw={"height_ratios": [1, 3]})

# Top: calcium trace
axes[0].plot(t_segment, sig_segment, linewidth=0.5, color="#2c3e50")
if behavior_labels is not None:
    bl_seg = behavior_labels[:segment_samples]
    social_mask = bl_seg == 1
    starts = np.where(np.diff(social_mask.astype(int)) == 1)[0] + 1  # first sample of each social bout
    ends = np.where(np.diff(social_mask.astype(int)) == -1)[0]
    if social_mask[0]: starts = np.concatenate([[0], starts])
    if social_mask[-1]: ends = np.concatenate([ends, [len(social_mask) - 1]])
    for s, e in zip(starts, ends):
        axes[0].axvspan(t_segment[s], t_segment[e], alpha=0.2, color="#e74c3c")
axes[0].set_ylabel("ΔF/F")
axes[0].set_title(f"Neuron #{active_neuron_idx} — Morlet Wavelet Scalogram")

# Bottom: scalogram
axes[1].pcolormesh(t_segment, freqs, 10 * np.log10(power + 1e-12),
                   shading="gouraud", cmap="magma")
axes[1].set_ylabel("Frequency (Hz)")
axes[1].set_xlabel("Time (s)")
axes[1].set_title("Wavelet power (dB)")

# Mark frequency bands
for name, (lo, hi) in FREQ_BANDS.items():
    axes[1].axhline(lo, color="white", linewidth=0.5, alpha=0.5, linestyle="--")

plt.tight_layout()
plt.show()

---
## 7. Social vs. Solo: Initial Spectral Comparison

The key test: do spectral features **differ** between social and solo epochs? This is the foundation of our classification approach.

### Method
1. Tile each neuron's calcium trace with non-overlapping 5-second windows and keep only windows in which more than 90% of samples share one behavioral label
2. Compute spectral features per window
3. Compare distributions between social and solo conditions
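The windowing step can be written as a pure function over the label vector. This sketch assumes 0/1 labels sampled at the same rate as the calcium traces. `label_windows` is a hypothetical name, and the 90% purity threshold mirrors the one used in the feature-extraction cell.

```python
import numpy as np

def label_windows(labels, window_samples, purity=0.9):
    """Tile `labels` with non-overlapping windows; return (starts, window_labels).

    A window is labelled 1 (social) if more than `purity` of its samples are 1,
    0 (solo) if fewer than `1 - purity` are, and dropped otherwise (mixed).
    """
    labels = np.asarray(labels, dtype=float)
    n_windows = len(labels) // window_samples        # complete windows only
    starts = np.arange(n_windows) * window_samples
    frac = labels[: n_windows * window_samples].reshape(n_windows, window_samples).mean(axis=1)
    keep = (frac > purity) | (frac < 1 - purity)
    return starts[keep], (frac[keep] > purity).astype(int)

demo_labels = np.r_[np.zeros(10), np.ones(10), np.zeros(5), np.ones(5)]
print(label_windows(demo_labels, window_samples=10))
```

The third window is half social and half solo, so it is dropped. Only the first two windows survive, labelled solo and social.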

In [None]:
from scipy.integrate import trapezoid  # np.trapz is deprecated in NumPy 2.0


def compute_spectral_features(segment, fs, freq_bands):
    """Compute spectral features from a 1D signal segment.

    Returns a dict of features: absolute and relative band powers, total
    power, spectral entropy, peak frequency, spectral centroid, 90% spectral
    edge frequency, and the theta/delta power ratio.
    """
    # Welch PSD. For a segment no longer than 5 s, nperseg equals the segment
    # length, so this is a single Hann-windowed periodogram (no averaging).
    nperseg_local = min(len(segment), int(5 * fs))
    f, psd = welch(segment, fs=fs, nperseg=nperseg_local, noverlap=nperseg_local // 2)

    features = {}

    # Band powers (trapezoidal integration of the PSD over each band)
    for name, (lo, hi) in freq_bands.items():
        mask = (f >= lo) & (f < hi)
        if mask.sum() > 0:
            features[f"power_{name}"] = trapezoid(psd[mask], f[mask])
        else:
            features[f"power_{name}"] = 0.0

    # Total power (DC bin excluded)
    total_power = trapezoid(psd[f > 0], f[f > 0])
    features["total_power"] = total_power

    # Relative band powers (0 when the segment carries no power, so every
    # window yields the same set of columns)
    for name in freq_bands:
        features[f"relpower_{name}"] = (
            features[f"power_{name}"] / total_power if total_power > 0 else 0.0
        )

    # Spectral entropy (Shannon entropy, in bits, of the normalized PSD)
    psd_norm = psd[f > 0] / (psd[f > 0].sum() + 1e-12)
    psd_norm = psd_norm[psd_norm > 0]
    features["spectral_entropy"] = -np.sum(psd_norm * np.log2(psd_norm))

    # Peak frequency, excluding the DC bin like the other features
    features["peak_freq"] = f[f > 0][np.argmax(psd[f > 0])]

    # Spectral centroid
    if psd[f > 0].sum() > 0:
        features["spectral_centroid"] = np.sum(f[f > 0] * psd[f > 0]) / np.sum(psd[f > 0])
    else:
        features["spectral_centroid"] = 0.0

    # Spectral edge frequency (90% of power)
    cum_power = np.cumsum(psd[f > 0])
    if cum_power[-1] > 0:
        edge_idx = np.searchsorted(cum_power, 0.9 * cum_power[-1])
        features["spectral_edge_90"] = f[f > 0][min(edge_idx, len(f[f > 0]) - 1)]
    else:
        features["spectral_edge_90"] = 0.0

    # Band power ratios
    if features.get("power_delta", 0) > 0:
        features["theta_delta_ratio"] = features.get("power_theta", 0) / features["power_delta"]
    else:
        features["theta_delta_ratio"] = 0.0

    return features

print("Feature extraction function defined.")
print(f"Features per segment: {len(compute_spectral_features(calcium_data[0, :int(10*sampling_rate)], sampling_rate, FREQ_BANDS))}")
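The next cell is a self-contained check of two of these definitions on synthetic signals. Its 20 Hz sampling rate is an assumption for illustration, not the dataset's rate. A pure sinusoid should put the Welch peak at its own frequency. It should also give a much lower spectral entropy than white noise, whose power is spread over all bins. `spectral_entropy_and_peak` is a hypothetical helper that reproduces only those two features.

```python
import numpy as np
from scipy.signal import welch

def spectral_entropy_and_peak(x, fs):
    """Shannon entropy (bits) of the normalized Welch PSD and its peak frequency, DC excluded."""
    nperseg = min(len(x), int(5 * fs))
    f, psd = welch(x, fs=fs, nperseg=nperseg, noverlap=nperseg // 2)
    pos = f > 0
    p = psd[pos] / psd[pos].sum()
    p = p[p > 0]
    return -np.sum(p * np.log2(p)), f[pos][np.argmax(psd[pos])]

fs_demo = 20.0                                      # illustrative sampling rate (assumption)
t_demo = np.arange(int(10 * fs_demo)) / fs_demo     # 10 s of samples
demo_rng = np.random.default_rng(0)
demo_sine = np.sin(2 * np.pi * 2.0 * t_demo)        # 2 Hz oscillation
demo_noise = demo_rng.standard_normal(t_demo.size)  # broadband white noise

for name, x in [("sine", demo_sine), ("noise", demo_noise)]:
    h, peak = spectral_entropy_and_peak(x, fs_demo)
    print(f"{name:5s}: entropy = {h:.2f} bits, peak = {peak:.2f} Hz")
```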

In [None]:
if behavior_labels is not None:
    # Segment data into fixed-length windows within each behavioral state
    window_sec = 5.0  # 5-second windows
    window_samples = int(window_sec * sampling_rate)

    all_features = []  # One feature dict per (neuron, window)

    # Process each neuron
    for neuron_idx in range(n_neurons):
        trace = calcium_data[neuron_idx]

        # Walk through the signal in non-overlapping windows
        # (the + 1 keeps the final window when it ends exactly at the last sample)
        for start in range(0, len(trace) - window_samples + 1, window_samples):
            end = start + window_samples
            window_labels = behavior_labels[start:end]

            # Keep only windows dominated by one state (>90% of samples share a label)
            label_frac = window_labels.mean()
            if label_frac > 0.9:
                label = 1  # Social
            elif label_frac < 0.1:
                label = 0  # Solo
            else:
                continue  # Skip mixed windows

            feats = compute_spectral_features(trace[start:end], sampling_rate, FREQ_BANDS)
            feats["neuron"] = neuron_idx
            feats["label"] = label
            feats["start_time"] = start / sampling_rate
            all_features.append(feats)

    df_features = pd.DataFrame(all_features)
    print(f"Feature extraction complete:")
    print(f"  Total windows: {len(df_features)}")
    print(f"  Social windows: {(df_features['label'] == 1).sum()}")
    print(f"  Solo windows: {(df_features['label'] == 0).sum()}")
    print(f"  Features per window: {len([c for c in df_features.columns if c not in ['neuron', 'label', 'start_time']])}")
    print(f"\n{df_features.head()}")
else:
    print("No behavioral labels — cannot compare conditions.")

In [None]:
if behavior_labels is not None:
    # --- Compare key spectral features between social and solo ---
    feature_cols = [c for c in df_features.columns
                    if c not in ["neuron", "label", "start_time"]]

    # Select the most interpretable features for visualization
    plot_features = [
        "power_delta", "power_theta", "spectral_entropy",
        "spectral_centroid", "theta_delta_ratio", "total_power"
    ]
    plot_features = [f for f in plot_features if f in df_features.columns]

    n_plot = len(plot_features)
    fig, axes = plt.subplots(2, (n_plot + 1) // 2, figsize=(16, 8))
    axes = axes.ravel()

    for i, feat in enumerate(plot_features):
        social_vals = df_features.loc[df_features["label"] == 1, feat]
        solo_vals = df_features.loc[df_features["label"] == 0, feat]

        axes[i].hist(solo_vals, bins=30, alpha=0.6, color="#3498db", label="Solo", density=True)
        axes[i].hist(social_vals, bins=30, alpha=0.6, color="#e74c3c", label="Social", density=True)
        axes[i].set_title(feat, fontsize=10)
        axes[i].legend(fontsize=8)

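        # Quick statistical test (Mann-Whitney U). Windows from different neurons at the
        # same time are not independent, so treat these p-values as descriptive only.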
        from scipy.stats import mannwhitneyu
        stat, pval = mannwhitneyu(social_vals, solo_vals, alternative="two-sided")
        sig = "***" if pval < 0.001 else "**" if pval < 0.01 else "*" if pval < 0.05 else "ns"
        axes[i].set_xlabel(f"p={pval:.2e} {sig}", fontsize=9)

    # Hide unused subplots
    for j in range(n_plot, len(axes)):
        axes[j].set_visible(False)

    plt.suptitle("Spectral Features: Social vs. Solo", fontsize=14, y=1.02)
    plt.tight_layout()
    plt.show()
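Every time window contributes one row per neuron, so the histograms above pool rows that are not independent. This inflates the effective sample size behind the Mann-Whitney p-values. One alternative is to average each feature across neurons within a time window and then test the per-window means. The sketch below runs this on a toy table with the same `neuron` / `label` / `start_time` columns as `df_features`. `per_window_test` is a hypothetical helper, and it assumes that all neurons share one label per `start_time`.

```python
import numpy as np
import pandas as pd
from scipy.stats import mannwhitneyu

def per_window_test(df, feature):
    """Average `feature` across neurons per time window, then run Mann-Whitney U.

    Assumes columns 'start_time', 'label' (0/1) and `feature`, with every
    neuron sharing the same label at a given start_time.
    """
    per_window = df.groupby("start_time").agg(value=(feature, "mean"), label=("label", "first"))
    social = per_window.loc[per_window["label"] == 1, "value"]
    solo = per_window.loc[per_window["label"] == 0, "value"]
    _, pval = mannwhitneyu(social, solo, alternative="two-sided")
    return len(per_window), pval

# Toy table: 3 neurons x 8 windows; social windows carry more delta power
demo_rng = np.random.default_rng(1)
demo = pd.DataFrame({
    "start_time": np.repeat(np.arange(8) * 5.0, 3),
    "neuron": np.tile([0, 1, 2], 8),
    "label": np.repeat([0, 1] * 4, 3),
})
demo["power_delta"] = demo["label"] * 1.0 + demo_rng.normal(0, 0.1, len(demo))

n_windows, pval = per_window_test(demo, "power_delta")
print(f"{n_windows} windows tested instead of {len(demo)} rows, p = {pval:.3g}")
```

Adjacent windows remain autocorrelated, so this is still not a fully independent test. It does remove the most direct source of pseudo-replication.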

In [None]:
if behavior_labels is not None:
    # --- Average PSD: social vs solo ---
    # Compute PSD separately for social and solo segments
    social_psds = []
    solo_psds = []

    for neuron_idx in range(n_neurons):
        trace = calcium_data[neuron_idx]
        for start in range(0, len(trace) - window_samples + 1, window_samples):
            end = start + window_samples
            label_frac = behavior_labels[start:end].mean()
            if 0.1 <= label_frac <= 0.9:
                continue  # Mixed window: skip it before computing a PSD

            # One Hann-windowed periodogram per 5-s window (nperseg = window length)
            nperseg_local = min(window_samples, int(5 * sampling_rate))
            f_seg, psd_seg = welch(trace[start:end], fs=sampling_rate,
                                   nperseg=nperseg_local, noverlap=nperseg_local // 2)

            if label_frac > 0.9:
                social_psds.append(psd_seg)
            else:
                solo_psds.append(psd_seg)

    social_psds = np.array(social_psds)
    solo_psds = np.array(solo_psds)

    fig, axes = plt.subplots(1, 2, figsize=(16, 5))

    # Mean PSD comparison
    social_mean = social_psds.mean(axis=0)
    social_sem = social_psds.std(axis=0) / np.sqrt(len(social_psds))
    solo_mean = solo_psds.mean(axis=0)
    solo_sem = solo_psds.std(axis=0) / np.sqrt(len(solo_psds))

    axes[0].semilogy(f_seg, social_mean, color="#e74c3c", linewidth=1.5, label=f"Social (n={len(social_psds)})")
    axes[0].fill_between(f_seg, social_mean - social_sem, social_mean + social_sem, alpha=0.3, color="#e74c3c")
    axes[0].semilogy(f_seg, solo_mean, color="#3498db", linewidth=1.5, label=f"Solo (n={len(solo_psds)})")
    axes[0].fill_between(f_seg, solo_mean - solo_sem, solo_mean + solo_sem, alpha=0.3, color="#3498db")
    axes[0].set_xlabel("Frequency (Hz)")
    axes[0].set_ylabel("PSD (log scale)")
    axes[0].set_title("Mean PSD: Social vs. Solo")
    axes[0].legend()

    # Shade bands
    for name, (lo, hi) in FREQ_BANDS.items():
        axes[0].axvspan(lo, hi, alpha=0.08, color=band_colors[name])

    # Log ratio (social / solo)
    with np.errstate(divide="ignore", invalid="ignore"):
        log_ratio = np.log2(social_mean / (solo_mean + 1e-12))
    axes[1].plot(f_seg, log_ratio, color="#2c3e50", linewidth=1.5)
    axes[1].axhline(0, color="gray", linestyle="--", linewidth=0.8)
    axes[1].set_xlabel("Frequency (Hz)")
    axes[1].set_ylabel("log2(Social / Solo)")
    axes[1].set_title("PSD log-ratio (>0 = more power during social)")
    for name, (lo, hi) in FREQ_BANDS.items():
        axes[1].axvspan(lo, hi, alpha=0.08, color=band_colors[name])

    plt.tight_layout()
    plt.show()

    print(f"Social PSD windows: {len(social_psds)}")
    print(f"Solo PSD windows: {len(solo_psds)}")

### 7.1 Quick classification sanity check

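Fit simple classifiers to check whether spectral features carry **any** discriminative information. This is NOT the final analysis, only a sanity check.

**Important:** We use block cross-validation (`GroupKFold` over 30-second time blocks). Windows from the same block, including copies of one time window recorded on different neurons, therefore never appear in both the training and test sets. This limits leakage from temporal autocorrelation.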
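The key property is that every row in a time block lands in the same fold, including the copies of one time window recorded on different neurons. Below is a minimal numpy sketch of a blocked k-fold split. It assumes 5-s windows grouped into 30-s blocks, as in the next cell. `block_folds` is a hypothetical helper that assigns whole blocks to folds round-robin. It is similar in spirit to scikit-learn's `GroupKFold` but not identical, because `GroupKFold` also balances the number of samples per fold.

```python
import numpy as np

def block_folds(start_times, block_sec=30.0, n_splits=5):
    """Yield (train_idx, test_idx) pairs in which each time block sits wholly in one fold."""
    groups = (np.asarray(start_times) // block_sec).astype(int)
    fold_of_group = {g: i % n_splits for i, g in enumerate(np.unique(groups))}
    fold = np.array([fold_of_group[g] for g in groups])
    for k in range(n_splits):
        yield np.where(fold != k)[0], np.where(fold == k)[0]

# 4 neurons x 24 windows of 5 s -> 120 s of recording, i.e. 4 blocks of 30 s
demo_start = np.tile(np.arange(24) * 5.0, 4)
demo_groups = (demo_start // 30).astype(int)
for train_idx, test_idx in block_folds(demo_start, n_splits=4):
    # No time block appears on both sides of a split
    assert set(demo_groups[train_idx]).isdisjoint(demo_groups[test_idx])
print("No time block is shared between training and test folds.")
```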

In [None]:
if behavior_labels is not None:
    from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
    from sklearn.svm import SVC
    from sklearn.linear_model import LogisticRegression
    from sklearn.preprocessing import StandardScaler
    from sklearn.model_selection import GroupKFold
    from sklearn.metrics import accuracy_score, roc_auc_score, classification_report

    # Prepare feature matrix
    feature_cols_clean = [c for c in df_features.columns
                          if c not in ["neuron", "label", "start_time"]]
    X = df_features[feature_cols_clean].values
    y = df_features["label"].values

    # Groups for block CV: use time bins to prevent temporal leakage
    # Each group is a ~30-second block of time
    groups = (df_features["start_time"] // 30).astype(int).values

    # Standardize features
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    # Replace NaN/Inf with 0 (edge cases from very short segments)
    X_scaled = np.nan_to_num(X_scaled, nan=0.0, posinf=0.0, neginf=0.0)

    # Block cross-validation
    n_unique_groups = len(np.unique(groups))
    n_splits = min(5, n_unique_groups)
    gkf = GroupKFold(n_splits=n_splits)

    classifiers = {
        "LDA": LinearDiscriminantAnalysis(),
        "SVM (linear)": SVC(kernel="linear", probability=True),
        "Logistic Regression": LogisticRegression(max_iter=1000),
    }

    print(f"Classification sanity check ({n_splits}-fold GroupKFold CV)")
    print(f"  Samples: {len(y)} ({(y==1).sum()} social, {(y==0).sum()} solo)")
    print(f"  Features: {X_scaled.shape[1]}")
    print(f"  Groups: {n_unique_groups} time blocks")
    print(f"{'='*60}")

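    from sklearn.pipeline import make_pipeline

    cv_results = {}  # classifier name -> per-fold accuracies and AUCs
    for clf_name, clf in classifiers.items():
        fold_accs = []
        fold_aucs = []

        for train_idx, test_idx in gkf.split(X_scaled, y, groups):
            X_train, X_test = X_scaled[train_idx], X_scaled[test_idx]
            y_train, y_test = y[train_idx], y[test_idx]

            if len(np.unique(y_train)) < 2 or len(np.unique(y_test)) < 2:
                continue  # A single-class fold cannot be scored

            # Re-fit the scaler on the training fold only. Standardization is affine, so
            # up to edge cases (NaN entries, imputed above with the global mean, and
            # constant features) this matches standardizing the raw features with
            # training-fold statistics: test-fold mean/std no longer reach the model.
            model = make_pipeline(StandardScaler(), clf)
            model.fit(X_train, y_train)
            y_pred = model.predict(X_test)
            y_proba = model.predict_proba(X_test)[:, 1]

            fold_accs.append(accuracy_score(y_test, y_pred))
            fold_aucs.append(roc_auc_score(y_test, y_proba))

        cv_results[clf_name] = {"acc": fold_accs, "auc": fold_aucs}
        if fold_accs:
            print(f"\n  {clf_name}:")
            print(f"    Accuracy: {np.mean(fold_accs):.3f} +/- {np.std(fold_accs):.3f}")
            print(f"    AUC:      {np.mean(fold_aucs):.3f} +/- {np.std(fold_aucs):.3f}")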

In [None]:
if behavior_labels is not None:
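    # --- Permutation test for chance-level baseline ---
    from sklearn.pipeline import make_pipeline

    def cv_accuracy(labels):
        """Mean GroupKFold accuracy of a per-fold-scaled logistic regression."""
        accs = []
        for train_idx, test_idx in gkf.split(X_scaled, labels, groups):
            y_train, y_test = labels[train_idx], labels[test_idx]
            if len(np.unique(y_train)) < 2 or len(np.unique(y_test)) < 2:
                continue
            model = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
            model.fit(X_scaled[train_idx], y_train)
            accs.append(accuracy_score(y_test, model.predict(X_scaled[test_idx])))
        return np.mean(accs) if accs else np.nan

    # Real accuracy, computed with exactly the same procedure as the null
    real_acc = cv_accuracy(y)

    n_permutations = 200
    rng = np.random.default_rng(42)
    print(f"Running permutation test ({n_permutations} shuffles)...")
    # Caveat: i.i.d. shuffling breaks both the temporal autocorrelation of the labels
    # and the fact that all neurons share a label at a given time, so this null can be
    # anti-conservative (p-values too small).
    perm_accs = np.array([cv_accuracy(rng.permutation(y)) for _ in range(n_permutations)])
    perm_accs = perm_accs[~np.isnan(perm_accs)]

    # Add-one correction: a finite number of permutations cannot justify p = 0
    p_value = (np.sum(perm_accs >= real_acc) + 1) / (len(perm_accs) + 1)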

    fig, ax = plt.subplots(figsize=(10, 4))
    ax.hist(perm_accs, bins=30, color="#95a5a6", alpha=0.7, label="Permutation null")
    ax.axvline(real_acc, color="#e74c3c", linewidth=2, linestyle="--",
               label=f"Real accuracy = {real_acc:.3f}")
    ax.set_xlabel("Accuracy")
    ax.set_ylabel("Count")
    ax.set_title(f"Permutation test (p = {p_value:.4f})")
    ax.legend()
    plt.tight_layout()
    plt.show()

    print(f"\nPermutation test result:")
    print(f"  Real accuracy: {real_acc:.3f}")
    print(f"  Chance mean:   {perm_accs.mean():.3f} +/- {perm_accs.std():.3f}")
    print(f"  p-value:       {p_value:.4f}")
    if p_value < 0.05:
        print("  -> Accuracy exceeds the shuffled-label null at p < 0.05.")
    else:
        print("  -> Not significant at p < 0.05; may need more data or better features.")
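Shuffling window labels independently ignores the fact that behavioural bouts span many consecutive windows. A stricter null circularly shifts the whole label sequence by a random offset. This preserves bout lengths and class balance and breaks only their alignment with the neural data. The sketch below assumes time-ordered labels with one label per window. With several neurons, the same shift would be applied along the shared time axis. `circular_shift_null` is a hypothetical helper, and `score_fn` stands in for any scoring function, such as a re-run of the cross-validation above.

```python
import numpy as np

def circular_shift_null(labels, score_fn, n_shifts=200, min_shift=1, seed=0):
    """Scores of `score_fn` on circularly shifted copies of a time-ordered label vector."""
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    n = len(labels)
    # Offsets in [min_shift, n - min_shift] avoid near-identity shifts
    shifts = rng.integers(min_shift, n - min_shift + 1, size=n_shifts)
    return np.array([score_fn(np.roll(labels, s)) for s in shifts])

# Toy example: a fixed prediction that tracks unevenly sized true bouts
demo_true = np.repeat([0, 1, 0, 1], [30, 20, 35, 15])
demo_pred = demo_true.copy()
demo_null = circular_shift_null(demo_true, lambda lab: np.mean(lab == demo_pred), min_shift=10)
demo_real = np.mean(demo_true == demo_pred)
demo_p = (np.sum(demo_null >= demo_real) + 1) / (len(demo_null) + 1)
print(f"real = {demo_real:.2f}, null mean = {demo_null.mean():.2f}, p = {demo_p:.3f}")
```

Bout lengths were deliberately chosen unevenly. A perfectly periodic label sequence would reproduce itself under some shifts, and those shifts would contaminate the null.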

---
## 8. Summary & Next Steps

### What we learned from this EDA

| Question | Finding |
|----------|---------|
| Data format | *(fill in after running)* |
| Sampling rate | *(fill in)* Hz — Nyquist = *(fill in)* Hz |
| Photobleaching | Present / Not present — needs detrending? |
| Class balance | Social: *X*% vs Solo: *Y*% — balanced / imbalanced? |
| PSD structure | 1/f present? Peaks in specific bands? |
| Social vs Solo PSD | Visible differences in which bands? |
| Classification | Accuracy: *X*% (chance: *Y*%, p = *Z*) |

### Next steps for the full project

1. **Load real data** — replace synthetic data with actual calcium traces from EDGE notebooks
2. **Preprocessing pipeline** — implement detrending, motion artifact removal, z-scoring
3. **Feature engineering** — refine spectral features based on EDA findings
4. **Proper cross-validation** — implement leave-one-bout-out or leave-one-animal-out CV
5. **Movement confound control** — add movement speed as a covariate
6. **Wavelet-based features** — extract time-resolved spectral features from CWT
7. **Final report** — write up with publication-quality figures
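For step 2, here is a minimal preprocessing sketch. It assumes traces arrive as an `(n_neurons, n_samples)` array. Each neuron gets a low-order polynomial fit as a crude photobleaching correction, then z-scoring. The polynomial order, and the choice of polynomial detrending over an exponential bleaching model, are assumptions to revisit once the real data are loaded. `preprocess_traces` is a hypothetical name.

```python
import numpy as np

def preprocess_traces(traces, fs, poly_order=2):
    """Remove a per-neuron polynomial trend, then z-score each neuron."""
    traces = np.asarray(traces, dtype=float)
    t_sec = np.arange(traces.shape[1]) / fs
    out = np.empty_like(traces)
    for i, x in enumerate(traces):
        coeffs = np.polyfit(t_sec, x, poly_order)      # slow drift, e.g. bleaching
        resid = x - np.polyval(coeffs, t_sec)
        out[i] = (resid - resid.mean()) / (resid.std() + 1e-12)
    return out

# Toy trace: slow exponential decay (bleaching-like) plus a 1 Hz oscillation
fs_demo = 20.0
t_demo = np.arange(int(60 * fs_demo)) / fs_demo
demo_raw = 5 * np.exp(-t_demo / 120) + 0.2 * np.sin(2 * np.pi * 1.0 * t_demo)
demo_clean = preprocess_traces(demo_raw[None, :], fs_demo)
print(demo_clean.shape, f"mean = {demo_clean.mean():.2e}, std = {demo_clean.std():.3f}")
```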