# Estrus-Linked Locomotor Plateau Detection  
## Morph2REP (Study 1001, Version 2025v3.3)

This notebook analyzes estrus-associated locomotor structure in female strain 664 mice
from the Morph2REP study (Study 1001, Version 2025v3.3).

Replicate 1:
- Dates: 2025-01-07 to 2025-01-22
- Cages: 4917–4925

Replicate 2:
- Dates: 2025-01-22 to 2025-02-04
- Cages: 4926–4934

---

## Scientific Motivation

Prior work (Smarr et al.) suggests estrus does not increase behavioral noise,
but instead introduces structured changes in locomotor activity,
specifically a prolonged dark-phase plateau of near-peak activity.

This notebook tests whether estrus-linked structure can be detected
*without ground-truth staging*, using only behavioral time series.

---

## Core Analysis Strategy

1. Extract per-animal locomotion (60s resolution)
2. Define an animal-specific plateau threshold (e.g., 85th percentile)
3. Compute plateau duration per day
4. Rank days by plateau duration
5. Test whether top-k days differ from the rest using permutation testing

---

## Primary Outputs

- Per-animal plateau duration
- Per-cage rankings
- Per-replicate rankings
- Top-k vs rest permutation tests
- Statistical significance of estrus-like structure

---

All timestamps are aligned to the 12h:12h light-dark cycle (lights on 6 AM–6 PM EST; dark phase 6 PM–6 AM EST).


## 1. Data Loading & Configuration

Load Morph2REP locomotor data from animal_activity_db or
animal_aggs_short_id at 60-second resolution.

Define:
- Replicate cages
- Time window
- Light-dark boundaries
- Percentile threshold for plateau detection


## 2. Extract Per-Animal Locomotion (60s Resolution)

Filter:
- name == 'animal_bouts.locomotion'
- resolution == 60

Compute:
- Day index relative to study start
- Dark-phase mask (6PM–6AM EST)
- Per-animal daily locomotion vectors


In [None]:
import duckdb
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# ============================================================
# CONFIG
# ============================================================
# Root of the hive-partitioned parquet export (cage_id=.../date=.../...).
S3_BASE = "s3://jax-envision-public-data/study_1001/2025v3.3/tabular"

# Vehicle cages and analysis window per replicate.
# `start` is 06:00 (lights on), so minute 0 of every folded day is lights-on.
# The (cage_id, start, end) triple is also the ACTIVITY_CACHE key, so these
# strings must match exactly wherever the cache is read.
VEHICLE_CAGES = {
    "Rep1": {
        "cages": [4918, 4922, 4923],
        "start": "2025-01-10 06:00:00",
        "end":   "2025-01-22 06:00:00",
    },
    "Rep2": {
        "cages": [4928, 4929, 4934],
        "start": "2025-01-25 06:00:00",
        "end":   "2025-02-04 06:00:00",
    },
}

MINUTES_PER_DAY = 1440
DAYS_PER_CYCLE = 4
MINUTES_PER_CYCLE = DAYS_PER_CYCLE * MINUTES_PER_DAY

# Dark phase: CT12–24 → minutes 720–1440
# (minutes are counted from the window `start`, i.e. from lights-on).
# NOTE(review): later cells rebind DARK_START / DARK_END to the Smarr
# convention (0 / 720); functions that read these globals at call time
# will silently follow whichever cell ran last.
DARK_START = 720
DARK_END = 1440

# ============================================================
# ROBUST PLATEAU THRESHOLD (fixes outlier-driven max)
# Use an upper-quantile threshold on the MEDIAN profile itself.
# Example: 85th percentile of the full 4-day median profile.
# ============================================================
PLATEAU_Q = 85  # try 80–90; 85 is a good default

# ============================================================
# GLOBAL CACHE (avoid re-loading from S3)
# ============================================================
# Maps (cage_id, start, end) -> DataFrame(time, animal_id, value).
# Re-running this cell rebinds the name to an empty dict and drops all cached data.
ACTIVITY_CACHE = {}

# ============================================================
# LOAD DATA
# ============================================================
def load_activity(cage_id, start, end):
    """Return 60 s locomotion rows for one cage, querying S3 at most once.

    Parameters
    ----------
    cage_id : int
        Cage partition to read (``cage_id=<cage_id>`` under ``S3_BASE``).
    start, end : str
        Timestamp strings interpolated into the SQL ``BETWEEN`` clause, so
        both endpoints are inclusive.

    Returns
    -------
    pandas.DataFrame
        Columns ``time`` (datetime64), ``animal_id`` and ``value``, ordered
        by time. The same object is stored in ``ACTIVITY_CACHE`` and handed
        back on later calls, so callers should copy before mutating it.
    """
    key = (cage_id, start, end)
    if key in ACTIVITY_CACHE:
        return ACTIVITY_CACHE[key]

    query = f"""
    SELECT
        time,
        animal_id,
        value
    FROM read_parquet(
        '{S3_BASE}/cage_id={cage_id}/date=*/animal_activity_db.parquet'
    )
    WHERE
        name = 'animal_bouts.locomotion'
        AND resolution = 60
        AND time BETWEEN TIMESTAMP '{start}' AND TIMESTAMP '{end}'
    ORDER BY time
    """

    conn = duckdb.connect()
    try:
        conn.execute("INSTALL httpfs; LOAD httpfs;")
        conn.execute("SET s3_region='us-east-1';")
        df = conn.execute(query).fetchdf()
    finally:
        # Always release the connection, even when the extension install
        # or the S3 read fails; nothing is cached in that case.
        conn.close()

    df["time"] = pd.to_datetime(df["time"])
    ACTIVITY_CACHE[key] = df
    return df

# ============================================================
# BUILD MEDIAN 4-DAY PROFILE (PER REPLICATE)
#   - For each cage: average across animals per timestamp
#   - Fold into 4-day cycle: minute_in_cycle
#   - Take mean per minute_in_cycle (across all included cycles)
#   - Take median across cages
# ============================================================
def build_median_4day_profile(cfg):
    """Fold each cage's locomotion into a 4-day cycle and take the cage median.

    Parameters
    ----------
    cfg : dict
        Needs ``"cages"`` (iterable of cage ids) plus ``"start"`` and
        ``"end"`` timestamp strings; ``start`` is minute 0 of the fold.

    Returns
    -------
    numpy.ndarray
        1-D array of length ``MINUTES_PER_CYCLE``. Element ``m`` is the
        median across cages of the mean activity at minute ``m`` of the
        cycle; minutes with no data in any cage are NaN.
    """
    baseline = pd.to_datetime(cfg["start"])
    # Callers slice the result by position, so every cage profile is pinned
    # to the full 0..MINUTES_PER_CYCLE-1 axis; a missing minute becomes NaN
    # instead of shifting all later samples.
    full_index = pd.RangeIndex(MINUTES_PER_CYCLE, name="minute_in_cycle")
    cage_profiles = []

    for cage_id in cfg["cages"]:
        df = load_activity(cage_id, cfg["start"], cfg["end"])

        # mean across animals per timestamp
        g = (
            df.groupby("time")["value"]
            .mean()
            .reset_index()
        )
        g["minute"] = ((g["time"] - baseline).dt.total_seconds() / 60).astype(int)

        # keep complete 4-day cycles only
        g = g[g["minute"] >= 0]
        max_full_cycles = (g["minute"].max() + 1) // MINUTES_PER_CYCLE
        g = g[g["minute"] < max_full_cycles * MINUTES_PER_CYCLE]

        # Group on a derived key rather than writing a new column onto the
        # filtered slice (avoids pandas' SettingWithCopy path).
        folded = g.groupby(g["minute"] % MINUTES_PER_CYCLE)["value"].mean()
        cage_profiles.append(folded.reindex(full_index))

    median_profile = pd.concat(cage_profiles, axis=1).median(axis=1).values
    return median_profile

# ============================================================
# OPTIONAL: Plateau RUN on a 4-day median profile
# (longest contiguous run above threshold, with gap tolerance)
# ============================================================
def longest_run_with_gap(mask: np.ndarray, gap: int = 1) -> int:
    """Length of the longest run of truthy samples, bridging short gaps.

    A run starts and ends on a truthy sample. Up to ``gap`` consecutive
    falsy samples inside a run are bridged and counted in its length;
    a longer gap ends the run. Leading and trailing falsy samples are
    never counted, so an all-falsy (or empty) mask yields 0.

    Parameters
    ----------
    mask : iterable of bool
        Per-sample "above threshold" flags.
    gap : int
        Maximum number of consecutive falsy samples tolerated inside a run.

    Returns
    -------
    int
        Longest run length in samples.
    """
    best = cur = pending = 0
    for v in mask:
        if v:
            # Only now do the buffered gap samples become part of the run.
            cur += pending + 1
            pending = 0
            best = max(best, cur)
        elif cur and pending < gap:
            pending += 1
        else:
            cur = pending = 0
    return best

# ============================================================
# PLATEAU STATISTICS ON MEDIAN PROFILE (ROBUST THRESHOLD)
#   - Threshold = percentile(whole 4-day median profile, PLATEAU_Q)
#   - Compute per-cycle-day dark-phase plateau duration (minutes)
#   - Also compute plateau_run (longest contiguous run, minutes)
# ============================================================
def plateau_stats_on_median(profile, label, plateau_q=PLATEAU_Q, gap=1):
    """Score each cycle day's dark phase against a percentile threshold.

    The threshold is the ``plateau_q``-th percentile (NaN-ignoring) of the
    whole ``profile``. For every cycle day the dark-phase window
    ``[DARK_START, DARK_END)`` is compared with it, giving the number of
    samples at or above threshold and the longest gap-tolerant run.

    Prints the ranking table and returns ``(table, threshold)``; the table
    is ordered by run length, then duration, both descending.
    """
    thresh = float(np.nanpercentile(profile, plateau_q))

    def _score_day(day_idx):
        # Offset of this cycle day inside the folded profile.
        offset = day_idx * MINUTES_PER_DAY
        is_plateau = profile[offset + DARK_START:offset + DARK_END] >= thresh
        return {
            "cycle_day": day_idx + 1,
            "plateau_duration_min": int(np.sum(is_plateau)),
            "plateau_run_min": int(longest_run_with_gap(is_plateau, gap)),
        }

    out = pd.DataFrame([_score_day(day_idx) for day_idx in range(DAYS_PER_CYCLE)])

    # Single ordering: longest run first, total duration breaks ties.
    out = out.sort_values(
        by=["plateau_run_min", "plateau_duration_min"],
        ascending=False,
    ).reset_index(drop=True)
    out["rank_by_run_then_duration"] = np.arange(1, len(out) + 1)

    print(f"\n=== Plateau ranking for {label} ===")
    print(f"Threshold = {plateau_q}th percentile of 4-day median profile = {thresh:.4f}")
    print(out.to_string(index=False))

    return out, thresh

# ============================================================
# PLOTTING
# ============================================================
def plot_median_profile(profile, title, thresh, plateau_q=PLATEAU_Q):
    """Draw a folded profile with its plateau threshold and dark-phase bands.

    ``profile`` is plotted against hours (sample index / 60), ``thresh`` as a
    dashed red line labelled with ``plateau_q``, and hours 12–24 of each
    cycle day are shaded as the dark phase.
    """
    hours = np.arange(len(profile)) / 60.0
    threshold_label = f"{plateau_q}th percentile threshold"

    plt.figure(figsize=(14, 4))
    plt.plot(hours, profile, linewidth=1.5)
    plt.axhline(thresh, color="red", linestyle="--", label=threshold_label)

    # Dark phase shading (CT12–24 each day)
    for day_start_h in range(0, DAYS_PER_CYCLE * 24, 24):
        plt.axvspan(day_start_h + 12, day_start_h + 24, color="black", alpha=0.10)

    plt.xlabel("Time (hours across 4-day cycle)")
    plt.ylabel("Locomotion (fraction of time)")
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.show()

# ============================================================
# RUN FOR REP1 + REP2
# ============================================================
# Build one folded 4-day median profile per replicate. Each call goes through
# load_activity, so the first run queries S3 and later runs hit ACTIVITY_CACHE.
rep1_median_4day = build_median_4day_profile(VEHICLE_CAGES["Rep1"])
rep2_median_4day = build_median_4day_profile(VEHICLE_CAGES["Rep2"])

# Rank cycle days by dark-phase plateau (longest run, then duration) and keep
# the threshold so the plots below draw the same line.
rank_rep1, thresh_rep1 = plateau_stats_on_median(rep1_median_4day, "JAX Rep1 median 4-day profile", plateau_q=PLATEAU_Q, gap=1)
rank_rep2, thresh_rep2 = plateau_stats_on_median(rep2_median_4day, "JAX Rep2 median 4-day profile", plateau_q=PLATEAU_Q, gap=1)

plot_median_profile(rep1_median_4day, "JAX Rep1 – Median 4-Day Profile (Vehicle)", thresh_rep1, plateau_q=PLATEAU_Q)
plot_median_profile(rep2_median_4day, "JAX Rep2 – Median 4-Day Profile (Vehicle)", thresh_rep2, plateau_q=PLATEAU_Q)


## 3. Plateau Definition, Daily Computation, and Ranking

### Step 1: Define Animal-Specific Plateau Threshold

For each animal:

- Extract all dark-phase locomotion values across the analysis window
- Compute the q-th percentile (e.g., 85th percentile)
- Define this value as that animal’s plateau threshold

Rationale:
Using a relative (within-animal) threshold controls for baseline
activity differences across animals and focuses on near-peak behavior.

---

### Step 2: Compute Daily Plateau Duration

For each animal and day:

Plateau duration =
Number of dark-phase minutes where locomotion ≥ threshold

This produces:

- Per-animal daily plateau duration
- A structured table indexed by:
  replicate → cage → animal → day

---

### Step 3: Rank Days by Estrus-Like Structure

Within cage or replicate:

- Compute mean plateau duration per day (across animals)
- Rank days in descending order

Rank 1 (longest mean plateau duration) → strongest estrus-like plateau structure  
Top-k days → candidate estrus days

This ranking forms the basis for the permutation testing step.


In [None]:
# ============================================================
# PLOTTING
# ============================================================
def plot_median_profile(profile, title, thresh, plateau_q=PLATEAU_Q):
    """Draw a folded profile with its plateau threshold and dark-phase bands.

    Same layout as the earlier definition of this name but with a thinner
    trace (linewidth 0.5): profile versus hours, dashed red threshold line,
    hours 12–24 of each cycle day shaded as dark phase.
    """
    hours = np.arange(len(profile)) / 60.0
    threshold_label = f"{plateau_q}th percentile threshold"

    plt.figure(figsize=(14, 4))
    plt.plot(hours, profile, linewidth=0.5)
    plt.axhline(thresh, color="red", linestyle="--", label=threshold_label)

    # Dark phase shading (CT12–24 each day)
    for day_start_h in range(0, DAYS_PER_CYCLE * 24, 24):
        plt.axvspan(day_start_h + 12, day_start_h + 24, color="black", alpha=0.10)

    plt.xlabel("Time (hours across 4-day cycle)")
    plt.ylabel("Locomotion (fraction of time)")
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.show()

# ============================================================
# RUN FOR REP1 + REP2
# ============================================================
# Same pipeline as the run section of the config cell, re-executed so the
# figures use the thinner-line plot_median_profile defined just above.
# Requires build_median_4day_profile and the (profile, label, plateau_q, gap)
# version of plateau_stats_on_median to be the ones currently defined.
rep1_median_4day = build_median_4day_profile(VEHICLE_CAGES["Rep1"])
rep2_median_4day = build_median_4day_profile(VEHICLE_CAGES["Rep2"])

rank_rep1, thresh_rep1 = plateau_stats_on_median(rep1_median_4day, "JAX Rep1 median 4-day profile", plateau_q=PLATEAU_Q, gap=1)
rank_rep2, thresh_rep2 = plateau_stats_on_median(rep2_median_4day, "JAX Rep2 median 4-day profile", plateau_q=PLATEAU_Q, gap=1)

plot_median_profile(rep1_median_4day, "JAX Rep1 – Median 4-Day Profile (Vehicle)", thresh_rep1, plateau_q=PLATEAU_Q)
plot_median_profile(rep2_median_4day, "JAX Rep2 – Median 4-Day Profile (Vehicle)", thresh_rep2, plateau_q=PLATEAU_Q)


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ============================================================
# CONFIG
# ============================================================
MINUTES_PER_DAY = 1440
DAYS_PER_CYCLE = 4
MINUTES_PER_CYCLE = DAYS_PER_CYCLE * MINUTES_PER_DAY

# Dark phase in Smarr convention: first 12h of each day
# NOTE(review): this rebinds the notebook-global DARK_START / DARK_END that the
# JAX config cell set to 720 / 1440; any function reading them at call time
# uses the Smarr window from here on.
DARK_START = 0
DARK_END = 720

# Plateau definition: fraction of global max of the median profile
PLATEAU_FRAC = 0.80   # 80% of max is robust and matches "near-peak plateau"

# ============================================================
# LOAD SMARR FEMALE DATA (ALIGNED)
# ============================================================
# Path is relative to the notebook's working directory.
df_smarr = pd.read_csv("mice data.xlsx - FemAct (1).csv")

# One column per female, selected by a case-insensitive "fem" prefix.
female_cols = [c for c in df_smarr.columns if str(c).lower().startswith("fem")]
if len(female_cols) == 0:
    raise ValueError("No female columns found (expected columns starting with 'fem').")

female_data = df_smarr[female_cols].to_numpy().T  # (n_females, n_timepoints)
n_females, n_timepoints = female_data.shape

# assumes one row per minute — TODO confirm against the CSV
print(f"Loaded Smarr data: {n_females} females, {n_timepoints} minutes")

# Only keep complete 4-day cycles
n_cycles = n_timepoints // MINUTES_PER_CYCLE
if n_cycles < 1:
    raise ValueError("Not enough data for a complete 4-day cycle.")

# Drop the trailing partial cycle so the reshape below is exact.
female_data = female_data[:, :n_cycles * MINUTES_PER_CYCLE]

# ============================================================
# BUILD MEDIAN 4-DAY PROFILE (ALIGNED)
# ============================================================
reshaped = female_data.reshape(
    n_females, n_cycles, MINUTES_PER_CYCLE
)  # (females, cycles, minutes)

# Median across cycles within each female, then across females.
# NOTE(review): np.median propagates NaN — confirm the CSV has no missing samples.
median_per_female = np.median(reshaped, axis=1)       # (females, minutes)
smarr_median_4day = np.median(median_per_female, axis=0)  # (minutes,)

# ============================================================
# MAXIMAL DEALIGNMENT (NEGATIVE CONTROL)
# ============================================================
# Female i is circularly advanced by i whole days, so consecutive females
# sit on different days of the 4-day cycle.
misaligned = np.empty_like(female_data)
for i in range(n_females):
    shift = (i * MINUTES_PER_DAY) % female_data.shape[1]
    misaligned[i] = np.roll(female_data[i], -shift)

reshaped_mis = misaligned.reshape(
    n_females, n_cycles, MINUTES_PER_CYCLE
)
median_per_female_mis = np.median(reshaped_mis, axis=1)
smarr_misaligned_median_4day = np.median(median_per_female_mis, axis=0)

# ============================================================
# PLATEAU STATISTICS ON MEDIAN 4-DAY PROFILE
# ============================================================
def plateau_stats_on_median(profile, label):
    """
    Rank cycle days by dark-phase plateau duration on a folded median profile.

    profile: 1D array, length = 4 * 1440
    The threshold is fixed at PLATEAU_FRAC * max(profile); for each cycle
    day the number of dark-phase samples (DARK_START..DARK_END) at or above
    it is counted. Prints the table and returns it sorted by duration,
    longest first, with a 1-based "rank" column.

    NOTE: this reuses the name of the percentile-based function defined in
    an earlier cell but takes only (profile, label) and returns the table
    alone, not a (table, threshold) pair.
    """
    thresh = PLATEAU_FRAC * np.max(profile)

    per_day = []
    for day_idx in range(DAYS_PER_CYCLE):
        offset = day_idx * MINUTES_PER_DAY
        window = profile[offset + DARK_START:offset + DARK_END]
        minutes_above = (window >= thresh).sum()  # minutes
        per_day.append({
            "cycle_day": day_idx + 1,
            "plateau_duration_min": minutes_above,
        })

    ranking = pd.DataFrame(per_day)
    ranking = ranking.sort_values("plateau_duration_min", ascending=False)
    ranking["rank"] = np.arange(1, len(ranking) + 1)

    print(f"\n=== Plateau ranking for {label} ===")
    print(ranking.to_string(index=False))

    return ranking

# ============================================================
# RUN POSITIVE + NEGATIVE CONTROLS (SMARR)
# ============================================================
rank_smarr_aligned = plateau_stats_on_median(
    smarr_median_4day, "Smarr ALIGNED (positive control)"
)

rank_smarr_misaligned = plateau_stats_on_median(
    smarr_misaligned_median_4day, "Smarr MISALIGNED (negative control)"
)

# ============================================================
# OPTIONAL: RUN ON JAX REPLICATES IF MEDIANS ARE PROVIDED
# ============================================================
# If you already computed these elsewhere, just define:
#   rep1_median_4day
#   rep2_median_4day
#
# Each should be a 1D array of length 5760 (4 * 1440)
#
# The JAX profiles start at lights-ON (06:00), so their dark phase is the
# SECOND 12 h of each day, whereas this cell's DARK_START/DARK_END select the
# FIRST 12 h (Smarr convention). Advancing the profile by 12 h puts the JAX
# dark phase where plateau_stats_on_median looks for it; the threshold
# (a fraction of the global max) is unaffected by the circular shift.
JAX_LIGHTS_OFF_OFFSET_MIN = 720

if "rep1_median_4day" in globals():
    plateau_stats_on_median(
        np.roll(np.asarray(rep1_median_4day), -JAX_LIGHTS_OFF_OFFSET_MIN),
        "JAX Rep1 median 4-day profile",
    )
else:
    print("\n[INFO] rep1_median_4day not found — skipping Rep1.")

if "rep2_median_4day" in globals():
    plateau_stats_on_median(
        np.roll(np.asarray(rep2_median_4day), -JAX_LIGHTS_OFF_OFFSET_MIN),
        "JAX Rep2 median 4-day profile",
    )
else:
    print("[INFO] rep2_median_4day not found — skipping Rep2.")

# ============================================================
# VISUALIZATION: MEDIAN PROFILES WITH PLATEAU THRESHOLD
# ============================================================
def plot_median_profile(profile, title):
    """Plot a folded median profile with the PLATEAU_FRAC-of-max threshold.

    The threshold line is recomputed here as PLATEAU_FRAC * max(profile);
    the first 12 h of each cycle day are shaded as the dark phase.
    """
    hours = np.arange(len(profile)) / 60.0
    cutoff = PLATEAU_FRAC * np.max(profile)
    cutoff_label = f"{int(PLATEAU_FRAC*100)}% of max"

    plt.figure(figsize=(14, 4))
    plt.plot(hours, profile, linewidth=2)
    plt.axhline(cutoff, color="red", linestyle="--", label=cutoff_label)

    # Dark phase bars
    for day_start_h in range(0, DAYS_PER_CYCLE * 24, 24):
        plt.axvspan(day_start_h, day_start_h + 12, color="black", alpha=0.1)

    plt.xlabel("Time (hours across 4-day cycle)")
    plt.ylabel("Activity")
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.show()

# Positive control first, then the day-shifted negative control.
plot_median_profile(smarr_median_4day, "Smarr ALIGNED — Median 4-Day Profile")
plot_median_profile(smarr_misaligned_median_4day, "Smarr MISALIGNED — Median 4-Day Profile")


In [None]:
# ============================================================
# BUILD MEDIAN 4-DAY PROFILE (PER REPLICATE)
#   - For each cage: mean across animals per timestamp
#   - Fold into 4-day cycle: minute_in_cycle
#   - Mean per minute_in_cycle across all included cycles
#   - Median across cages
# ============================================================
def build_median_4day_profile(cfg):
    """Fold each cage's locomotion into a 4-day cycle and take the cage median.

    Parameters
    ----------
    cfg : dict
        Needs ``"cages"`` (iterable of cage ids) plus ``"start"`` and
        ``"end"`` timestamp strings; ``start`` is minute 0 of the fold.

    Returns
    -------
    numpy.ndarray
        1-D array of length ``MINUTES_PER_CYCLE``. Element ``m`` is the
        median across cages of the mean activity at minute ``m`` of the
        cycle; minutes with no data in any cage are NaN.
    """
    baseline = pd.to_datetime(cfg["start"])
    # Callers slice the result by position, so every cage profile is pinned
    # to the full 0..MINUTES_PER_CYCLE-1 axis; a missing minute becomes NaN
    # instead of shifting all later samples.
    full_index = pd.RangeIndex(MINUTES_PER_CYCLE, name="minute_in_cycle")
    cage_profiles = []

    for cage_id in cfg["cages"]:
        df = load_activity(cage_id, cfg["start"], cfg["end"])

        # mean across animals per timestamp
        g = (
            df.groupby("time")["value"]
            .mean()
            .reset_index()
        )
        g["minute"] = ((g["time"] - baseline).dt.total_seconds() / 60).astype(int)

        # keep complete 4-day cycles only
        g = g[g["minute"] >= 0]
        max_full_cycles = (g["minute"].max() + 1) // MINUTES_PER_CYCLE
        g = g[g["minute"] < max_full_cycles * MINUTES_PER_CYCLE]

        # Group on a derived key rather than writing a new column onto the
        # filtered slice (avoids pandas' SettingWithCopy path).
        folded = g.groupby(g["minute"] % MINUTES_PER_CYCLE)["value"].mean()
        cage_profiles.append(folded.reindex(full_index))

    median_profile = pd.concat(cage_profiles, axis=1).median(axis=1).values
    return median_profile

# ============================================================
# PLATEAU DURATION STATS ON MEDIAN PROFILE (ROBUST THRESHOLD)
#   - Threshold = PLATEAU_Q percentile of the 4-day median profile
#   - For each cycle day, compute minutes in dark phase above threshold
# ============================================================
def plateau_duration_stats_on_median(profile, label, plateau_q=PLATEAU_Q,
                                     dark_start=720, dark_end=1440):
    """Rank cycle days by dark-phase minutes at or above a percentile threshold.

    Parameters
    ----------
    profile : numpy.ndarray
        Folded 4-day median profile, one sample per minute.
    label : str
        Name printed in the report header.
    plateau_q : float
        Percentile (NaN-ignoring) of the whole profile used as threshold.
    dark_start, dark_end : int
        Half-open dark-phase window in minutes from the start of each cycle
        day. The defaults (720, 1440) match profiles whose minute 0 is
        lights-on, as built by ``build_median_4day_profile``. They are
        explicit parameters rather than the notebook-global
        DARK_START / DARK_END because other cells rebind those globals to
        the Smarr convention (0, 720), which would make this function score
        the light phase.

    Returns
    -------
    (pandas.DataFrame, float)
        Table with ``cycle_day``, ``plateau_duration_min`` and ``rank``
        (longest duration first), and the threshold value.
    """
    thresh = float(np.nanpercentile(profile, plateau_q))

    results = []
    for day in range(DAYS_PER_CYCLE):
        start = day * MINUTES_PER_DAY + dark_start
        end = day * MINUTES_PER_DAY + dark_end

        dark = profile[start:end]
        plateau_duration = int(np.sum(dark >= thresh))

        results.append({
            "cycle_day": day + 1,
            "plateau_duration_min": plateau_duration
        })

    out = pd.DataFrame(results).sort_values("plateau_duration_min", ascending=False).reset_index(drop=True)
    out["rank"] = np.arange(1, len(out) + 1)

    print(f"\n=== Plateau duration ranking for {label} ===")
    print(f"Threshold = {plateau_q}th percentile of 4-day median profile = {thresh:.4f}")
    print(out.to_string(index=False))

    return out, thresh

# ============================================================
# PLOTTING
# ============================================================
def plot_median_profile(profile, title, thresh, plateau_q=PLATEAU_Q):
    """Draw a folded profile with its plateau threshold and dark-phase bands.

    ``profile`` is plotted against hours (sample index / 60) with a thin
    trace, ``thresh`` as a dashed red line labelled with ``plateau_q``, and
    hours 12–24 of each cycle day are shaded as the dark phase.
    """
    hours = np.arange(len(profile)) / 60.0
    threshold_label = f"{plateau_q}th percentile threshold"

    plt.figure(figsize=(14, 4))
    plt.plot(hours, profile, linewidth=0.5)
    plt.axhline(thresh, color="red", linestyle="--", label=threshold_label)

    # Dark phase shading (CT12–24 each day)
    for day_start_h in range(0, DAYS_PER_CYCLE * 24, 24):
        plt.axvspan(day_start_h + 12, day_start_h + 24, color="black", alpha=0.10)

    plt.xlabel("Time (hours across 4-day cycle)")
    plt.ylabel("Locomotion (fraction of time)")
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.show()

# ============================================================
# RUN FOR REP1 + REP2
# ============================================================
# Rebuild the per-replicate folded profiles (served from ACTIVITY_CACHE when
# the same cage/window was loaded before).
rep1_median_4day = build_median_4day_profile(VEHICLE_CAGES["Rep1"])
rep2_median_4day = build_median_4day_profile(VEHICLE_CAGES["Rep2"])

# Duration-only ranking; rank_rep1/rank_rep2 and the thresholds overwrite the
# values produced by the run-length ranking in earlier cells.
rank_rep1, thresh_rep1 = plateau_duration_stats_on_median(rep1_median_4day, "JAX Rep1 median 4-day profile", plateau_q=PLATEAU_Q)
rank_rep2, thresh_rep2 = plateau_duration_stats_on_median(rep2_median_4day, "JAX Rep2 median 4-day profile", plateau_q=PLATEAU_Q)

plot_median_profile(rep1_median_4day, "JAX Rep1 – Median 4-Day Profile (Vehicle)", thresh_rep1, plateau_q=PLATEAU_Q)
plot_median_profile(rep2_median_4day, "JAX Rep2 – Median 4-Day Profile (Vehicle)", thresh_rep2, plateau_q=PLATEAU_Q)







In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ============================================================
# SMARR: ALIGNED vs MAXIMALLY DEALIGNED
# Plateau duration on MEDIAN 4-day profile using PERCENTILE threshold
# ============================================================

# ----------------------------
# CONFIG
# ----------------------------
MINUTES_PER_DAY = 1440
DAYS_PER_CYCLE = 4
MINUTES_PER_CYCLE = DAYS_PER_CYCLE * MINUTES_PER_DAY

# Smarr convention in your scripts: t=0 is lights OFF, so dark is first 12h
# NOTE(review): rebinds the notebook-global DARK_START / DARK_END (the JAX
# config cell uses 720 / 1440).
DARK_START = 0
DARK_END = 720

# Robust threshold: percentile of the full 4-day median profile
PLATEAU_Q = 85  # try 80–90; 85 is a good default

# ----------------------------
# LOAD SMARR FEMALE DATA
# ----------------------------
# Path is relative to the notebook's working directory.
df_smarr = pd.read_csv("mice data.xlsx - FemAct (1).csv")

# One column per female, selected by a case-insensitive "fem" prefix.
female_cols = [c for c in df_smarr.columns if str(c).lower().startswith("fem")]
if len(female_cols) == 0:
    raise ValueError("No female columns found (expected columns starting with 'fem').")

female_data = df_smarr[female_cols].to_numpy().T  # (n_females, n_timepoints)
n_females, n_timepoints = female_data.shape
# assumes one row per minute — TODO confirm against the CSV
print(f"Loaded Smarr: n_females={n_females}, n_timepoints={n_timepoints} minutes")

# keep complete 4-day cycles only
n_cycles = n_timepoints // MINUTES_PER_CYCLE
if n_cycles < 1:
    raise ValueError("Not enough data for a complete 4-day cycle.")

# Drop the trailing partial cycle so the reshape below is exact.
female_data = female_data[:, :n_cycles * MINUTES_PER_CYCLE]
print(f"Using complete cycles: n_cycles={n_cycles} (total minutes={female_data.shape[1]})")

# ----------------------------
# BUILD MEDIAN 4-DAY PROFILE (ALIGNED)
# ----------------------------
# Median across cycles within each female, then across females.
# NOTE(review): np.median propagates NaN — confirm the CSV has no missing samples.
reshaped = female_data.reshape(n_females, n_cycles, MINUTES_PER_CYCLE)
median_per_female = np.median(reshaped, axis=1)          # (n_females, 5760)
smarr_median_4day = np.median(median_per_female, axis=0) # (5760,)

# ----------------------------
# MAXIMAL DEALIGNMENT (Smarr-style)
# advance each subsequent female by +1 day (circular)
# ----------------------------
misaligned = np.empty_like(female_data)
for i in range(n_females):
    # Wrap the shift so it stays within the retained series length.
    shift = (i * MINUTES_PER_DAY) % female_data.shape[1]
    misaligned[i] = np.roll(female_data[i], -shift)

reshaped_mis = misaligned.reshape(n_females, n_cycles, MINUTES_PER_CYCLE)
median_per_female_mis = np.median(reshaped_mis, axis=1)
smarr_misaligned_median_4day = np.median(median_per_female_mis, axis=0)

# ----------------------------
# PLATEAU DURATION ON MEDIAN PROFILE (percentile threshold)
# ----------------------------
def plateau_duration_rank_on_median(profile, label, q=PLATEAU_Q):
    """Rank cycle days by dark-phase samples at or above the q-th percentile.

    The threshold is the NaN-ignoring ``q``-th percentile of the whole
    profile; the dark window of each cycle day is DARK_START..DARK_END.
    Prints the table and returns ``(table, threshold)`` with the table
    sorted by duration, longest first.
    """
    thresh = float(np.nanpercentile(profile, q))

    day_numbers = list(range(1, DAYS_PER_CYCLE + 1))
    durations = []
    for day_number in day_numbers:
        offset = (day_number - 1) * MINUTES_PER_DAY
        window = profile[offset + DARK_START:offset + DARK_END]
        durations.append(int(np.sum(window >= thresh)))

    out = pd.DataFrame({"cycle_day": day_numbers, "plateau_duration_min": durations})
    out = out.sort_values("plateau_duration_min", ascending=False).reset_index(drop=True)
    out["rank"] = np.arange(1, len(out) + 1)

    print(f"\n=== {label} ===")
    print(f"Threshold = {q}th percentile of 4-day median profile = {thresh:.4f}")
    print(out.to_string(index=False))
    return out, thresh

# Positive control (cycles aligned) vs negative control (females shifted by
# whole days); the thresholds are kept for the plots below.
rank_aligned, thr_aligned = plateau_duration_rank_on_median(
    smarr_median_4day, f"Smarr ALIGNED (positive control) – plateau duration", q=PLATEAU_Q
)
rank_mis, thr_mis = plateau_duration_rank_on_median(
    smarr_misaligned_median_4day, f"Smarr MISALIGNED (negative control) – plateau duration", q=PLATEAU_Q
)

# ----------------------------
# PLOTS: median profiles + thresholds
# ----------------------------
def plot_profile(profile, title, thresh, q=PLATEAU_Q):
    """Plot a folded Smarr profile with its percentile threshold.

    The first 12 h of each cycle day are shaded as the dark phase
    (t=0 is lights-off in this convention).
    """
    hours = np.arange(len(profile)) / 60.0
    threshold_label = f"{q}th percentile threshold"

    plt.figure(figsize=(14, 4))
    plt.plot(hours, profile, linewidth=0.5)
    plt.axhline(thresh, color="red", linestyle="--", label=threshold_label)

    # dark shading: first 12h each day in this convention
    for day_start_h in range(0, DAYS_PER_CYCLE * 24, 24):
        plt.axvspan(day_start_h, day_start_h + 12, color="black", alpha=0.10)

    plt.xlabel("Time (hours across 4-day cycle)")
    plt.ylabel("LA (counts)")
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.show()

# Each profile is drawn with its own threshold from the ranking step above.
plot_profile(smarr_median_4day, "Smarr ALIGNED — Median 4-Day Profile", thr_aligned, q=PLATEAU_Q)
plot_profile(smarr_misaligned_median_4day, "Smarr MISALIGNED — Median 4-Day Profile", thr_mis, q=PLATEAU_Q)

# ----------------------------
# OPTIONAL: quick robustness sweep over percentiles
# ----------------------------
def robustness_sweep(profile_aligned, profile_mis, qs=(75, 80, 85, 90, 95)):
    """Summarise the top-ranked cycle day at several percentile thresholds.

    For every ``q`` in ``qs`` both profiles are ranked with
    ``plateau_duration_rank_on_median`` (which prints its own table) and the
    first row of each ranking is recorded. Returns one summary row per ``q``.
    """
    def _top_row(profile, label, q):
        # First row of the ranking = longest plateau duration.
        ranking, _ = plateau_duration_rank_on_median(profile, label, q=q)
        best = ranking.iloc[0]
        return int(best["cycle_day"]), int(best["plateau_duration_min"])

    summary = []
    for q in qs:
        aligned_day, aligned_minutes = _top_row(profile_aligned, f"[Sweep] ALIGNED q={q}", q)
        mis_day, mis_minutes = _top_row(profile_mis, f"[Sweep] MISALIGNED q={q}", q)
        summary.append({
            "q": q,
            "aligned_top_day": aligned_day,
            "aligned_top_duration": aligned_minutes,
            "mis_top_day": mis_day,
            "mis_top_duration": mis_minutes,
        })
    return pd.DataFrame(summary)

# Uncomment if you want the sweep table:
# sweep_df = robustness_sweep(smarr_median_4day, smarr_misaligned_median_4day, qs=(75,80,85,90,95))
# print("\nRobustness sweep summary:")
# print(sweep_df.to_string(index=False))


In [None]:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ============================================================
# PLATEAU DURATION ON ALL DAYS
#   - Smarr aligned (all days)
#   - Smarr maximally dealigned (all days)
#   - JAX Rep1 vehicle (all days)  [ASSUMES already loaded in ACTIVITY_CACHE]
#   - JAX Rep2 vehicle (all days)  [ASSUMES already loaded in ACTIVITY_CACHE]
#
# Plateau duration definition (consistent everywhere):
#   For each animal and day:
#     plateau_duration = # of dark-phase minutes with activity >= animal-specific threshold
#   Threshold per animal:
#     threshold = qth percentile of THAT ANIMAL'S dark-phase values across the whole window
#
# Rankings:
#   For each dataset:
#     rank days by mean plateau_duration across animals (descending)
# ============================================================

# ----------------------------
# CONFIG
# ----------------------------
Q_THRESH = 85  # percentile for per-animal threshold (robust; try 80–90)

# Smarr convention in your earlier scripts: dark is first 12h of each day (minutes 0–720)
SMARR_MIN_PER_DAY = 1440
SMARR_DARK_START = 0
SMARR_DARK_END = 720

# JAX convention in your earlier vehicle plots: baseline start = lights-on (6am),
# so dark is 12–24h after baseline each day (minutes 720–1440)
JAX_MIN_PER_DAY = 1440
JAX_DARK_START = 720
JAX_DARK_END = 1440

# For JAX, we will ONLY use these cages (vehicle) and baseline windows.
# (Data should already be loaded in ACTIVITY_CACHE by your earlier loader.)
# The cache is looked up by the exact (cage_id, start, end) triple, so the
# cage ids and timestamp strings below must match the ones used when loading.
VEHICLE_CAGES = {
    "Rep1": {
        "cages": [4918, 4922, 4923],
        "start": "2025-01-10 06:00:00",
        "end":   "2025-01-22 06:00:00",
    },
    "Rep2": {
        "cages": [4928, 4929, 4934],
        "start": "2025-01-25 06:00:00",
        "end":   "2025-02-04 06:00:00",
    },
}

# ----------------------------
# HELPERS
# ----------------------------
def _rank_days(df_daily: pd.DataFrame, metric: str = "plateau_duration_min") -> pd.DataFrame:
    """
    Summarise `metric` per day across animals and rank days by the mean.

    df_daily needs a "day" column and the `metric` column (other columns
    such as animal_id / cage_id / rep are ignored).
    Returns one row per day with columns day, mean, median, sd, n, rank;
    rows are ordered by mean descending and rank 1 is the largest mean.
    """
    per_day = df_daily.groupby("day")[metric]
    summary = (
        per_day.agg(["mean", "median", "std", "count"])
        .rename(columns={"std": "sd", "count": "n"})
        .reset_index()
    )
    ranked = summary.sort_values("mean", ascending=False).reset_index(drop=True)
    return ranked.assign(rank=np.arange(1, len(ranked) + 1))

def _plot_day_ranking(rank_df: pd.DataFrame, title: str):
    """Plot mean plateau duration per day with a ±1 SD band.

    Parameters
    ----------
    rank_df : pandas.DataFrame
        Needs columns "day", "mean" and "sd" (e.g. the output of
        ``_rank_days``). Row order does not matter.
    title : str
        Figure title.
    """
    # A ranking table is ordered by mean, not by day; drawing it as-is makes
    # the line and the SD band jump back and forth along the x axis.
    by_day = rank_df.sort_values("day")

    x = by_day["day"].to_numpy()
    y = by_day["mean"].to_numpy()
    # SD is NaN when a day has a single animal; draw a zero-width band there.
    s = by_day["sd"].fillna(0.0).to_numpy()

    plt.figure(figsize=(10, 3.5))
    plt.plot(x, y, linewidth=2, label="Mean across animals")
    plt.fill_between(x, y - s, y + s, alpha=0.2, label="±1 SD")
    plt.xticks(x.astype(int))
    plt.xlabel("Day")
    plt.ylabel("Plateau duration (min)")
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.show()

def _heatmap_plateau(df_daily: pd.DataFrame, title: str, zscore_within_animal: bool = True):
    """
    Heatmap: rows animals, cols days, values plateau_duration_min

    Parameters
    ----------
    df_daily : pandas.DataFrame
        Needs columns "animal_id", "day" and "plateau_duration_min".
    title : str
        Figure title; " (z-score within animal)" is appended when z-scoring.
    zscore_within_animal : bool
        If True, each animal's row is centred and scaled by its own SD.
        Animals with zero SD are divided by NaN and become all-NaN rows.

    Rows are ordered by the day on which each animal peaks; animals with no
    finite value are placed last.
    """
    mat = df_daily.pivot_table(index="animal_id", columns="day", values="plateau_duration_min", aggfunc="mean")
    if zscore_within_animal:
        mat = mat.sub(mat.mean(axis=1), axis=0)
        mat = mat.div(mat.std(axis=1).replace(0, np.nan), axis=0)

    # sort animals by peak day for readability
    # idxmax on an all-NaN row is deprecated/raises in recent pandas, so the
    # peak day is computed only for rows that have data.
    has_data = mat.notna().any(axis=1)
    with_data = mat.loc[has_data]
    ordered = list(with_data.idxmax(axis=1).sort_values().index) if len(with_data) else []
    mat = mat.loc[ordered + list(mat.index[~has_data.to_numpy()])]

    plt.figure(figsize=(12, 5))
    im = plt.imshow(mat.values, aspect="auto", interpolation="nearest")
    plt.xlabel("Day")
    plt.ylabel("Animal (sorted by peak day)")
    plt.title(title + (" (z-score within animal)" if zscore_within_animal else ""))
    plt.xticks(np.arange(mat.shape[1]), mat.columns.astype(int))
    plt.yticks([])
    cbar = plt.colorbar(im)
    cbar.set_label("Plateau duration (z)" if zscore_within_animal else "Plateau duration (min)")
    plt.tight_layout()
    plt.show()

# ----------------------------
# SMARR: compute per-female per-day plateau_duration (aligned or misaligned)
# ----------------------------
def smarr_daily_plateau(df_smarr: pd.DataFrame, dealign: bool = False, q: int = Q_THRESH) -> pd.DataFrame:
    """
    Per-female, per-day count of dark-phase samples at or above that female's
    own threshold.

    df_smarr must contain fem* columns (minute-resolution LA counts).
    Only whole days are used. With dealign=True, female i is circularly
    advanced by i days before scoring. The threshold for a female is the
    q-th percentile (NaN-ignoring) of all her dark-phase samples.
    Returns df_daily with columns: animal_id, day, plateau_duration_min
    (animal_id is the 0-based column position, day is 1-based).
    """
    fem_columns = [col for col in df_smarr.columns if str(col).lower().startswith("fem")]
    if len(fem_columns) == 0:
        raise ValueError("Smarr: No female columns found (expected columns starting with 'fem').")

    activity = df_smarr[fem_columns].to_numpy().T  # (n_females, n_timepoints)
    n_fem, n_tp = activity.shape
    n_days = n_tp // SMARR_MIN_PER_DAY
    if n_days < 1:
        raise ValueError("Smarr: Not enough data for at least 1 full day.")

    # Trim the trailing partial day.
    activity = activity[:, :n_days * SMARR_MIN_PER_DAY]
    n_tp = activity.shape[1]

    if dealign:
        activity = np.stack([
            np.roll(series, -((idx * SMARR_MIN_PER_DAY) % n_tp))
            for idx, series in enumerate(activity)
        ])

    # (n_females, n_days, dark-phase minutes): first 12h of each day
    dark = activity.reshape(n_fem, n_days, SMARR_MIN_PER_DAY)[:, :, SMARR_DARK_START:SMARR_DARK_END]

    # One threshold per female, pooled over the dark phase of ALL days.
    thresholds = [float(np.nanpercentile(dark[idx].ravel(), q)) for idx in range(n_fem)]

    records = [
        {
            "animal_id": idx,
            "day": day_idx + 1,
            "plateau_duration_min": int(np.sum(dark[idx, day_idx] >= thresholds[idx])),  # minutes
        }
        for idx in range(n_fem)
        for day_idx in range(n_days)
    ]
    return pd.DataFrame(records)

# ----------------------------
# JAX: build per-animal per-day plateau_duration from already-loaded ACTIVITY_CACHE
# ----------------------------
def _get_jax_loaded_df_for_rep(rep_label: str, cfg: dict) -> pd.DataFrame:
    """
    Assemble one replicate's activity table from the global ACTIVITY_CACHE.

    Nothing is loaded here: each (cage_id, cfg["start"], cfg["end"]) key must
    already be in the cache. Cached frames are copied, tagged with "cage_id"
    and "rep" columns, stacked, and their "time" column parsed to datetimes.

    Raises ValueError when the cache does not exist or a cage key is missing.
    Returns an empty frame with the expected columns when cfg["cages"] is empty.
    """
    if "ACTIVITY_CACHE" not in globals():
        raise ValueError("JAX: ACTIVITY_CACHE not found. Run your JAX loading cell first so data are cached.")

    def _tagged_cage_frame(cage_id):
        # Look up one cage window and label a copy; the cached frame is left untouched.
        key = (cage_id, cfg["start"], cfg["end"])
        if key not in ACTIVITY_CACHE:
            raise ValueError(
                f"JAX: Missing cached data for key={key}. "
                "Load it first with your load_activity(...) caching code."
            )
        return ACTIVITY_CACHE[key].assign(cage_id=cage_id, rep=rep_label)

    frames = [_tagged_cage_frame(cage_id) for cage_id in cfg["cages"]]
    if not frames:
        return pd.DataFrame(columns=["time", "animal_id", "value", "cage_id", "rep"])

    combined = pd.concat(frames, ignore_index=True)
    combined["time"] = pd.to_datetime(combined["time"])
    return combined

def jax_daily_plateau(rep_label: str, cfg: dict, q: int = Q_THRESH) -> pd.DataFrame:
    """
    Per-animal, per-day dark-phase plateau duration from the cached JAX data.

    Day number (1-based) and minute-of-day are measured from cfg["start"].
    Each animal's threshold is the q-th percentile (NaNs ignored) of its
    dark-phase values over the whole window; the plateau duration of an
    animal-day is the number of dark-phase rows at or above that threshold
    (equal to minutes only if there is exactly one row per minute).

    Returns a DataFrame with columns animal_id, day, plateau_duration_min, rep.
    Raises ValueError when the cache yields no rows for this replicate.
    """
    cached = _get_jax_loaded_df_for_rep(rep_label, cfg)
    if cached.empty:
        raise ValueError(f"JAX: No cached data found for {rep_label}.")

    # Whole minutes elapsed since the configured start of the replicate.
    elapsed = (cached["time"] - pd.to_datetime(cfg["start"])).dt.total_seconds() / 60.0
    elapsed = elapsed.astype(int)
    day_no = elapsed // JAX_MIN_PER_DAY + 1
    minute_in_day = elapsed % JAX_MIN_PER_DAY

    # Keep only rows inside the dark window of their day.
    in_dark = (minute_in_day >= JAX_DARK_START) & (minute_in_day < JAX_DARK_END)
    dark = cached.loc[in_dark, ["animal_id", "value"]].assign(day=day_no[in_dark])

    # One threshold per animal, looked up for every one of its rows.
    cutoffs = dark.groupby("animal_id")["value"].apply(lambda s: float(np.nanpercentile(s.to_numpy(), q)))
    above = dark["value"] >= dark["animal_id"].map(cutoffs)

    daily = (
        dark.assign(above=above)
        .groupby(["animal_id", "day"], as_index=False)["above"]
        .sum()
        .rename(columns={"above": "plateau_duration_min"})
    )
    daily["rep"] = rep_label
    return daily

# ============================================================
# RUN: SMARR (aligned + misaligned)
# ============================================================
# Load the Smarr activity table; smarr_daily_plateau() reads its fem* columns.
df_smarr = pd.read_csv("mice data.xlsx - FemAct (1).csv")

# Score the same recording twice: as recorded, and with each female
# circularly shifted by a different number of whole days (dealign=True).
smarr_aligned_daily = smarr_daily_plateau(df_smarr, dealign=False, q=Q_THRESH)
smarr_misaligned_daily = smarr_daily_plateau(df_smarr, dealign=True, q=Q_THRESH)

# Day-level summary tables (the _rank_days helper is defined in another cell).
smarr_aligned_rank = _rank_days(smarr_aligned_daily)
smarr_misaligned_rank = _rank_days(smarr_misaligned_daily)

print("\n=== Smarr ALIGNED (all days) — day rankings ===")
print(smarr_aligned_rank.to_string(index=False))

print("\n=== Smarr MISALIGNED (all days) — day rankings ===")
print(smarr_misaligned_rank.to_string(index=False))


# Animal-by-day heatmaps (the _heatmap_plateau helper is defined in another cell).
_heatmap_plateau(smarr_aligned_daily, f"Smarr ALIGNED — Per-female plateau duration (q={Q_THRESH})", zscore_within_animal=True)
_heatmap_plateau(smarr_misaligned_daily, f"Smarr MISALIGNED — Per-female plateau duration (q={Q_THRESH})", zscore_within_animal=True)

# ============================================================
# RUN: JAX Rep1 + Rep2 (ASSUMES already loaded/cached)
# ============================================================
# jax_daily_plateau() raises ValueError if a cage window is not in ACTIVITY_CACHE.
rep1_daily = jax_daily_plateau("Rep1", VEHICLE_CAGES["Rep1"], q=Q_THRESH)
rep2_daily = jax_daily_plateau("Rep2", VEHICLE_CAGES["Rep2"], q=Q_THRESH)

rep1_rank = _rank_days(rep1_daily)
rep2_rank = _rank_days(rep2_daily)

print("\n=== JAX Rep1 (all days) — day rankings ===")
print(rep1_rank.to_string(index=False))

print("\n=== JAX Rep2 (all days) — day rankings ===")
print(rep2_rank.to_string(index=False))


_heatmap_plateau(rep1_daily, f"JAX Rep1 (Vehicle) — Per-animal plateau duration (q={Q_THRESH})", zscore_within_animal=True)
_heatmap_plateau(rep2_daily, f"JAX Rep2 (Vehicle) — Per-animal plateau duration (q={Q_THRESH})", zscore_within_animal=True)



## 4. Population-Level Permutation Test

### Objective

Test whether the highest plateau days (top-k) show significantly
greater locomotor plateau duration than the remaining days
at the population level.

This evaluates whether estrus-like structure is detectable
across the entire replicate, not just within individual cages.

---

### Test Statistic

1. For each day:
   - Compute mean plateau duration across all animals
     (within replicate)

2. Rank days by mean plateau duration (descending)

3. Define:
   Top-k days = k highest-ranked days  
   Rest days = remaining days  

4. Compute observed statistic:

T_obs = mean(top-k days) − mean(rest days)

---

### Permutation Procedure

To generate the null distribution:

For each permutation:

1. Shuffle day labels within each animal
   (preserves animal-level structure and variance)

2. Recompute:
   - Daily mean plateau duration
   - Day rankings
   - T_perm

3. Repeat N times (e.g., 10,000)

---

### Empirical p-value

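p = (1 + number of permutations where T_perm ≥ T_obs) / (N + 1)

The +1 terms count the observed labelling as one of the permutations,
so the reported p-value can never be exactly 0.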

Small p-value → Evidence that top-k days are not random  
Large p-value → No detectable population-level structure

---

### Interpretation

Significant result:
→ Structured estrus-linked plateau pattern detectable at population scale

Non-significant result:
→ Plateau structure may be:
   - Distributed across multiple days
   - Cage-specific
   - Weak relative to baseline variability
   - Masked by replicate timing differences


In [None]:
import pandas as pd

# Sanity check: how many females and how many days does the Smarr file hold?
df_smarr = pd.read_csv("mice data.xlsx - FemAct (1).csv")

# One activity column per female, matched by the "fem" prefix (case-insensitive).
female_cols = [c for c in df_smarr.columns if str(c).lower().startswith("fem")]
# Transpose so rows = females and columns = time steps.
female_data = df_smarr[female_cols].to_numpy().T

n_females, n_timepoints = female_data.shape
# True division: a trailing partial day shows up as a fractional day count.
# 1440 = minutes per day, which assumes one row per minute — TODO confirm.
n_days = n_timepoints / 1440

print(f"Females: {n_females}")
print(f"Timepoints (minutes): {n_timepoints}")
print(f"Days: {n_days}")
print(f"Complete days: {n_timepoints // 1440}")
print(f"Remainder minutes: {n_timepoints % 1440}")


In [None]:
# Day numbers treated as estrus for every animal.
# NOTE(review): presumably taken from the Smarr staging — confirm the source.
estrus_days = {1, 5}

# Per-animal mean plateau duration on estrus vs non-estrus days
# (rows = animals, columns = False / True).
obs = (
    smarr_aligned_daily
    .assign(is_estrus=lambda d: d["day"].isin(estrus_days))
    .groupby(["animal_id", "is_estrus"])["plateau_duration_min"]
    .mean()
    .unstack()
)

# Observed statistic: estrus-minus-non-estrus difference, averaged over animals.
T_obs = (obs[True] - obs[False]).mean()
rng = np.random.default_rng(0)
n_perm = 10000
T_perm = np.zeros(n_perm)

for i in range(n_perm):
    # Null draw: shuffle the day labels independently within each animal,
    # leaving the plateau durations where they are.
    perm = smarr_aligned_daily.copy()
    perm["day"] = (
        perm.groupby("animal_id")["day"]
        .transform(lambda x: rng.permutation(x.values))
    )

    # Recompute the same statistic on the relabelled data.
    tmp = (
        perm.assign(is_estrus=lambda d: d["day"].isin(estrus_days))
        .groupby(["animal_id", "is_estrus"])["plateau_duration_min"]
        .mean()
        .unstack()
    )

    T_perm[i] = (tmp[True] - tmp[False]).mean()
# One-sided p-value; the +1 terms count the observed labelling as a permutation.
p_value = (np.sum(T_perm >= T_obs) + 1) / (n_perm + 1)
print("Observed T:", T_obs)
print("Permutation p-value:", p_value)


In [None]:
import numpy as np

# ------------------------------------------------------------
# PERMUTATION TEST: Smarr MISALIGNED
# ------------------------------------------------------------
# Assumes you already have:
#   smarr_misaligned_daily
# with columns: animal_id, day, plateau_duration_min
# ------------------------------------------------------------

# Same day numbers as the aligned test, applied to the day-shifted data.
# NOTE(review): presumably intended as a negative control — confirm.
estrus_days = {1, 5}
rng = np.random.default_rng(0)
n_perm = 10000

# ----- observed statistic -----
# Per-animal mean plateau duration on "estrus" vs other days
# (rows = animals, columns = False / True).
obs = (
    smarr_misaligned_daily
    .assign(is_estrus=lambda d: d["day"].isin(estrus_days))
    .groupby(["animal_id", "is_estrus"])["plateau_duration_min"]
    .mean()
    .unstack()
)

# Difference between the two columns, averaged over animals.
T_obs_mis = (obs[True] - obs[False]).mean()

# ----- permutation distribution -----
T_perm_mis = np.zeros(n_perm)

for i in range(n_perm):
    # Null draw: shuffle the day labels independently within each animal.
    perm = smarr_misaligned_daily.copy()
    perm["day"] = (
        perm.groupby("animal_id")["day"]
        .transform(lambda x: rng.permutation(x.values))
    )

    # Recompute the same statistic on the relabelled data.
    tmp = (
        perm.assign(is_estrus=lambda d: d["day"].isin(estrus_days))
        .groupby(["animal_id", "is_estrus"])["plateau_duration_min"]
        .mean()
        .unstack()
    )

    T_perm_mis[i] = (tmp[True] - tmp[False]).mean()

# ----- p-value -----
# One-sided; the +1 terms count the observed labelling as a permutation.
p_value_mis = (np.sum(T_perm_mis >= T_obs_mis) + 1) / (n_perm + 1)

print("Smarr MISALIGNED:")
print("Observed T:", T_obs_mis)
print("Permutation p-value:", p_value_mis)


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ============================================================
# PERMUTATION TESTS FOR JAX REP1 + REP2 (NO GROUND-TRUTH ESTRUS)
#
# Assumes you already computed:
#   rep1_daily, rep2_daily
# each with columns: animal_id, day, plateau_duration_min
#
# Null: day labels are exchangeable WITHIN each animal.
# Permutation: shuffle day labels within each animal.
#
# Test statistic options:
#   1) "max_mean_day" (recommended): max over days of (mean plateau_duration across animals)
#   2) "topk_minus_rest": mean(top-k day means) - mean(rest day means)
# ============================================================

# Defaults for permutation_test_jax: the seed handed to np.random.default_rng
# and the number of within-animal permutations drawn.
RNG_SEED = 0
N_PERM = 10000

def _obs_day_means(df_daily: pd.DataFrame) -> pd.Series:
    """Return the mean plateau duration per day (over all rows of that day), ordered by day."""
    per_day = df_daily["plateau_duration_min"].groupby(df_daily["day"]).mean()
    return per_day.sort_index()

def _stat_max_mean_day(day_means: pd.Series) -> float:
    """Test statistic: the largest day-level mean, returned as a plain float."""
    peak = day_means.max()
    return float(peak)

def _stat_topk_minus_rest(day_means: pd.Series, k: int = 1) -> float:
    """
    Test statistic: mean of the k largest day means minus the mean of the rest.

    Parameters
    ----------
    day_means : one value per day.
    k : number of top days; must be at least 1.

    Returns
    -------
    The difference as a float, or NaN when there are not more than k days
    (no "rest" group exists).

    Raises
    ------
    ValueError : k is smaller than 1. A non-positive k would otherwise be
        silently misread by the negative slices below.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")
    vals = day_means.to_numpy(dtype=float)
    if len(vals) <= k:
        return float("nan")
    ordered = np.sort(vals)  # ascending, so the top-k values sit at the end
    return float(ordered[-k:].mean() - ordered[:-k].mean())

def permutation_test_jax(
    df_daily: pd.DataFrame,
    stat: str = "max_mean_day",
    k: int = 1,
    n_perm: int = N_PERM,
    seed: int = RNG_SEED,
    plot: bool = True,
    title_prefix: str = "",
):
    """
    Within-animal permutation test for day-level structure in plateau duration.

    The null hypothesis is that day labels are exchangeable within each
    animal. Each permutation shuffles every animal's day labels (durations
    stay fixed), recomputes the day means and re-evaluates the statistic.

    Parameters
    ----------
    df_daily : rows with columns animal_id, day, plateau_duration_min.
    stat : "max_mean_day" (largest day mean) or "topk_minus_rest"
        (mean of the k largest day means minus the mean of the others).
    k : number of top days; only used by "topk_minus_rest".
    n_perm : number of permutations.
    seed : seed for np.random.default_rng.
    plot : when True, draw the null histogram and the observed day means.
    title_prefix : text prepended to the printed header and plot titles.

    Returns
    -------
    dict with keys "T_obs", "p_value", "T_perm" (array of length n_perm) and
    "day_means_obs". "p_value" is the one-sided (count + 1) / (n_perm + 1)
    estimate, or NaN when the observed statistic itself is NaN.

    Raises
    ------
    ValueError : unknown `stat`, or `df_daily` has no rows.
    """
    if stat == "max_mean_day":
        compute_stat = _stat_max_mean_day
    elif stat == "topk_minus_rest":
        def compute_stat(day_means: pd.Series) -> float:
            return _stat_topk_minus_rest(day_means, k=k)
    else:
        raise ValueError("stat must be 'max_mean_day' or 'topk_minus_rest'")
    if df_daily.empty:
        raise ValueError("df_daily is empty; nothing to permute.")

    rng = np.random.default_rng(seed)

    # observed
    day_means_obs = _obs_day_means(df_daily)
    T_obs = compute_stat(day_means_obs)

    # Split by animal once: only the day labels change between permutations,
    # so the durations can be laid out a single time, in animal order.
    day_arrays = []
    value_arrays = []
    for _, g in df_daily.groupby("animal_id", sort=False):
        day_arrays.append(g["day"].to_numpy())
        value_arrays.append(g["plateau_duration_min"].to_numpy())
    values = pd.Series(np.concatenate(value_arrays))

    # permutation distribution
    T_perm = np.empty(n_perm, dtype=float)
    for i in range(n_perm):
        # One shuffle per animal, drawn in animal order.
        perm_days = np.concatenate([rng.permutation(days) for days in day_arrays])
        T_perm[i] = compute_stat(values.groupby(perm_days).mean())

    # one-sided p-value (large = more "estrus-like" structure)
    if np.isnan(T_obs):
        # Every comparison against NaN is False, which would otherwise report
        # the smallest possible p-value for an undefined statistic.
        p_value = float("nan")
    else:
        p_value = (np.sum(T_perm >= T_obs) + 1) / (n_perm + 1)

    print("\n" + "="*70)
    print(f"{title_prefix} Permutation test")
    print(f"Statistic: {stat}" + (f" (k={k})" if stat == "topk_minus_rest" else ""))
    print(f"Observed T: {T_obs:.6f}")
    print(f"Permutation p-value: {p_value:.6g}")
    print("="*70)

    # optional plots
    if plot:
        # Null distribution with observed marked; non-finite values are left
        # out because the histogram cannot bin them.
        finite_perm = T_perm[np.isfinite(T_perm)]
        plt.figure(figsize=(8, 3.5))
        if finite_perm.size:
            plt.hist(finite_perm, bins=40)
        if np.isfinite(T_obs):
            plt.axvline(T_obs, linestyle="--", linewidth=2)
        plt.title(f"{title_prefix} null distribution ({stat})")
        plt.xlabel("Permutation statistic")
        plt.ylabel("Count")
        plt.tight_layout()
        plt.show()

        # Observed day means bar/line
        plt.figure(figsize=(10, 3.5))
        plt.plot(day_means_obs.index.to_numpy(), day_means_obs.to_numpy(), linewidth=2)
        top_day = int(day_means_obs.idxmax())
        plt.axvline(top_day, linestyle="--", linewidth=1.5, alpha=0.8, label=f"Top day = {top_day}")
        plt.title(f"{title_prefix} observed day means (plateau duration)")
        plt.xlabel("Day")
        plt.ylabel("Mean plateau duration (min)")
        plt.legend()
        plt.tight_layout()
        plt.show()

    return {
        "T_obs": T_obs,
        "p_value": p_value,
        "T_perm": T_perm,
        "day_means_obs": day_means_obs,
    }

# ============================================================
# RUN: REP1 + REP2
# ============================================================
# Recommended primary test: max_mean_day (controls "cherry-picking" top day via permutation max)
# Both replicates use the same seed and permutation count; each call prints a
# summary, shows two figures (plot=True) and returns a result dict.
res_rep1 = permutation_test_jax(
    rep1_daily,
    stat="max_mean_day",
    n_perm=N_PERM,
    seed=RNG_SEED,
    plot=True,
    title_prefix="JAX Rep1",
)

res_rep2 = permutation_test_jax(
    rep2_daily,
    stat="max_mean_day",
    n_perm=N_PERM,
    seed=RNG_SEED,
    plot=True,
    title_prefix="JAX Rep2",
)

# Optional: also test "top-2 days vs rest" (often matches biology: two estrus-like peaks in long windows)
# res_rep1_k2 = permutation_test_jax(rep1_daily, stat="topk_minus_rest", k=2, n_perm=N_PERM, seed=RNG_SEED, plot=True, title_prefix="JAX Rep1")
# res_rep2_k2 = permutation_test_jax(rep2_daily, stat="topk_minus_rest", k=2, n_perm=N_PERM, seed=RNG_SEED, plot=True, title_prefix="JAX Rep2")


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ============================================================
# PLATEAU DURATION ON ALL DAYS (boxplots by day)
# ============================================================

# Percentile (0-100) used as each animal's plateau threshold.
Q_THRESH = 85

# Samples per day and the [start, end) dark-phase window, expressed as
# offsets from the start of each day.
SMARR_MIN_PER_DAY = 1440
SMARR_DARK_START, SMARR_DARK_END = 0, 720          # Smarr: dark first 12h

JAX_MIN_PER_DAY = 1440
JAX_DARK_START, JAX_DARK_END = 720, 1440           # JAX: dark second 12h (baseline start = lights-on)

# Cages and analysis window per replicate. jax_daily_plateau() measures day
# number and minute-of-day from "start", so "start" has to be a lights-on
# timestamp for the JAX dark window above to be correct.
# NOTE(review): assumes the cached timestamps use the same clock/timezone as
# these strings — confirm against the loader.
VEHICLE_CAGES = {
    "Rep1": {"cages": [4918, 4922, 4923], "start": "2025-01-10 06:00:00", "end": "2025-01-22 06:00:00"},
    "Rep2": {"cages": [4928, 4929, 4934], "start": "2025-01-25 06:00:00", "end": "2025-02-04 06:00:00"},
}

# ----------------------------
# HELPERS
# ----------------------------
def _rank_days(df_daily: pd.DataFrame, metric: str = "plateau_duration_min") -> pd.DataFrame:
    """
    Summarise `metric` per day and order the days from highest to lowest mean.

    Returns one row per day with columns day, mean, median, sd, n and rank,
    where rank 1 is the day with the largest mean.
    """
    summary = df_daily.groupby("day")[metric].agg(["mean", "median", "std", "count"])
    summary = summary.rename(columns={"std": "sd", "count": "n"}).reset_index()

    ranked = summary.sort_values("mean", ascending=False).reset_index(drop=True)
    ranked["rank"] = np.arange(len(ranked)) + 1
    return ranked

def _boxplot_by_day(df_daily: pd.DataFrame, title: str, metric: str = "plateau_duration_min"):
    """
    Draw one box per day showing the distribution of `metric` across rows
    (one value per animal-day is expected).

    Parameters
    ----------
    df_daily : table with a "day" column (castable to int) and `metric`.
    title : figure title.
    metric : column to plot; missing values are dropped per day.

    The figure is shown with plt.show(); nothing is returned.
    """
    day_col = df_daily["day"].astype(int)
    days = np.sort(day_col.unique())
    data = [df_daily.loc[day_col == day, metric].dropna().to_numpy() for day in days]

    plt.figure(figsize=(10, 3.8))
    plt.boxplot(
        data,
        showfliers=False,     # hide extreme points (cleaner); set True if you want
        whis=1.5
    )
    # Boxes are drawn at x = 1..N. The ticks are labelled explicitly rather
    # than through boxplot's `labels=` keyword, which Matplotlib 3.9 deprecated
    # in favour of `tick_labels=`; this form works on both sides of the rename.
    plt.xticks(np.arange(1, len(days) + 1), days)
    plt.xlabel("Day")
    plt.ylabel("Plateau duration (min)")
    plt.title(title)
    plt.tight_layout()
    plt.show()

def _heatmap_plateau(df_daily: pd.DataFrame, title: str, zscore_within_animal: bool = True):
    """
    Show an animal-by-day heatmap of plateau duration, rows ordered by each
    animal's peak day.

    Parameters
    ----------
    df_daily : table with columns animal_id, day, plateau_duration_min.
    title : figure title (a z-score suffix is appended when applicable).
    zscore_within_animal : when True, each row is centred on its own mean and
        divided by its own standard deviation; rows with zero spread become NaN.

    The figure is shown with plt.show(); nothing is returned.
    """
    mat = df_daily.pivot_table(index="animal_id", columns="day", values="plateau_duration_min", aggfunc="mean")

    # Locate each animal's peak day on the raw values. Z-scoring rescales a
    # row without reordering it, so the peak is the same, but a row with zero
    # (or undefined) spread turns into all-NaN after z-scoring and idxmax on
    # an all-NaN row is deprecated in pandas.
    peak_day = mat.idxmax(axis=1)

    if zscore_within_animal:
        mat = mat.sub(mat.mean(axis=1), axis=0)
        mat = mat.div(mat.std(axis=1).replace(0, np.nan), axis=0)

    mat = mat.loc[peak_day.sort_values().index]

    plt.figure(figsize=(12, 5))
    im = plt.imshow(mat.values, aspect="auto", interpolation="nearest")
    plt.xlabel("Day")
    plt.ylabel("Animal (sorted by peak day)")
    plt.title(title + (" (z-score within animal)" if zscore_within_animal else ""))
    plt.xticks(np.arange(mat.shape[1]), mat.columns.astype(int))
    plt.yticks([])
    cbar = plt.colorbar(im)
    cbar.set_label("Plateau duration (z)" if zscore_within_animal else "Plateau duration (min)")
    plt.tight_layout()
    plt.show()

# ----------------------------
# SMARR: per-female per-day plateau_duration
# ----------------------------
def smarr_daily_plateau(df_smarr: pd.DataFrame, dealign: bool = False, q: int = Q_THRESH) -> pd.DataFrame:
    """
    Score every female's dark-phase plateau duration for each complete day.

    A female's threshold is the q-th percentile (NaNs ignored) of her
    dark-phase samples pooled over all complete days; a day's plateau duration
    counts the dark-phase samples at or above it.

    Returns a DataFrame with columns animal_id (0-based position among the
    fem* columns), day (1-based) and plateau_duration_min. Raises ValueError
    when no fem* column exists or less than one full day of data is present.
    """
    fem_columns = [name for name in df_smarr.columns if str(name).lower().startswith("fem")]
    if not fem_columns:
        raise ValueError("Smarr: No female columns found (expected columns starting with 'fem').")

    activity = df_smarr[fem_columns].to_numpy().T  # rows = females, columns = samples
    n_fem = activity.shape[0]
    n_days = activity.shape[1] // SMARR_MIN_PER_DAY
    if n_days < 1:
        raise ValueError("Smarr: Not enough data for at least 1 full day.")

    # Drop any trailing partial day so the series splits into whole days.
    n_tp = n_days * SMARR_MIN_PER_DAY
    activity = activity[:, :n_tp]

    # "maximally misaligned": roll female i by i days (mod total length)
    if dealign:
        shifted = np.empty_like(activity)
        for idx, series in enumerate(activity):
            shifted[idx] = np.roll(series, -((idx * SMARR_MIN_PER_DAY) % n_tp))
        activity = shifted

    # View as (female, day, sample-of-day) and keep only the dark-phase window.
    dark = activity.reshape(n_fem, n_days, SMARR_MIN_PER_DAY)[:, :, SMARR_DARK_START:SMARR_DARK_END]

    records = []
    for idx in range(n_fem):
        # One threshold per female, pooled over all of her dark-phase samples.
        cutoff = np.float64(np.nanpercentile(dark[idx].ravel(), q))
        samples_above = (dark[idx] >= cutoff).sum(axis=1)
        records.extend(
            {"animal_id": idx, "day": day_no, "plateau_duration_min": int(n_above)}
            for day_no, n_above in enumerate(samples_above, start=1)
        )

    return pd.DataFrame(records)

# ----------------------------
# JAX: pull from ACTIVITY_CACHE and compute per-animal per-day plateau_duration
# ----------------------------
def _get_jax_loaded_df_for_rep(rep_label: str, cfg: dict) -> pd.DataFrame:
    """
    Assemble one replicate's activity table from the global ACTIVITY_CACHE.

    Each (cage_id, cfg["start"], cfg["end"]) key must already be cached. The
    cached frames are copied, tagged with "cage_id" and "rep", stacked, and
    their "time" column parsed to datetimes. With no cages configured, an
    empty frame carrying the expected columns is returned.

    Raises ValueError when the cache does not exist or a cage key is missing.
    """
    if "ACTIVITY_CACHE" not in globals():
        raise ValueError("JAX: ACTIVITY_CACHE not found. Run your JAX loading cell first so data are cached.")

    def _tagged_cage_frame(cage_id):
        # Look up one cage window and label a copy; the cached frame is left untouched.
        key = (cage_id, cfg["start"], cfg["end"])
        if key not in ACTIVITY_CACHE:
            raise ValueError(f"JAX: Missing cached data for key={key}. Load it first with your loader.")
        return ACTIVITY_CACHE[key].assign(cage_id=cage_id, rep=rep_label)

    frames = [_tagged_cage_frame(cage_id) for cage_id in cfg["cages"]]
    if frames:
        combined = pd.concat(frames, ignore_index=True)
    else:
        combined = pd.DataFrame(columns=["time","animal_id","value","cage_id","rep"])
    combined["time"] = pd.to_datetime(combined["time"])
    return combined

def jax_daily_plateau(rep_label: str, cfg: dict, q: int = Q_THRESH) -> pd.DataFrame:
    """
    Per-animal, per-day dark-phase plateau duration from the cached JAX data.

    Day number (1-based) and minute-of-day are measured from cfg["start"].
    Each animal's threshold is the q-th percentile (NaNs ignored) of its
    dark-phase values over the whole window; an animal-day's plateau duration
    is the number of dark-phase rows at or above that threshold (equal to
    minutes only if there is exactly one row per minute).

    Returns a DataFrame with columns animal_id, day, plateau_duration_min, rep.
    Raises ValueError when the cache yields no rows for this replicate.
    """
    cached = _get_jax_loaded_df_for_rep(rep_label, cfg)
    if cached.empty:
        raise ValueError(f"JAX: No cached data found for {rep_label}.")

    # Whole minutes elapsed since the configured start of the replicate.
    elapsed = (cached["time"] - pd.to_datetime(cfg["start"])).dt.total_seconds() / 60.0
    elapsed = elapsed.astype(int)
    day_no = elapsed // JAX_MIN_PER_DAY + 1
    minute_in_day = elapsed % JAX_MIN_PER_DAY

    # Keep only rows inside the dark window of their day.
    in_dark = (minute_in_day >= JAX_DARK_START) & (minute_in_day < JAX_DARK_END)
    dark = cached.loc[in_dark, ["animal_id", "value"]].assign(day=day_no[in_dark])

    # One threshold per animal, looked up for every one of its rows.
    cutoffs = dark.groupby("animal_id")["value"].apply(lambda s: float(np.nanpercentile(s.to_numpy(), q)))
    above = dark["value"] >= dark["animal_id"].map(cutoffs)

    daily = (
        dark.assign(above=above)
        .groupby(["animal_id", "day"], as_index=False)["above"]
        .sum()
        .rename(columns={"above": "plateau_duration_min"})
    )
    daily["rep"] = rep_label
    return daily

# ============================================================
# RUN: SMARR (aligned + misaligned)
# ============================================================
# Load the Smarr activity table; smarr_daily_plateau() reads its fem* columns.
df_smarr = pd.read_csv("mice data.xlsx - FemAct (1).csv")

# Score the same recording twice: as recorded, and with each female
# circularly shifted by a different number of whole days (dealign=True).
smarr_aligned_daily = smarr_daily_plateau(df_smarr, dealign=False, q=Q_THRESH)
smarr_misaligned_daily = smarr_daily_plateau(df_smarr, dealign=True, q=Q_THRESH)

# Day-level summary tables, ordered by mean plateau duration.
print("\n=== Smarr ALIGNED rankings ===")
print(_rank_days(smarr_aligned_daily).to_string(index=False))
print("\n=== Smarr MISALIGNED rankings ===")
print(_rank_days(smarr_misaligned_daily).to_string(index=False))

# Distribution across females for each day.
_boxplot_by_day(smarr_aligned_daily, f"Smarr ALIGNED — Plateau duration by day (boxplots, q={Q_THRESH})")
_boxplot_by_day(smarr_misaligned_daily, f"Smarr MISALIGNED — Plateau duration by day (boxplots, q={Q_THRESH})")

# Female-by-day heatmaps, z-scored within each female.
_heatmap_plateau(smarr_aligned_daily, f"Smarr ALIGNED — Per-female plateau duration (q={Q_THRESH})", zscore_within_animal=True)
_heatmap_plateau(smarr_misaligned_daily, f"Smarr MISALIGNED — Per-female plateau duration (q={Q_THRESH})", zscore_within_animal=True)

# ============================================================
# RUN: JAX Rep1 + Rep2 (ASSUMES already loaded/cached in ACTIVITY_CACHE)
# ============================================================
# jax_daily_plateau() raises ValueError if a cage window is not in ACTIVITY_CACHE.
rep1_daily = jax_daily_plateau("Rep1", VEHICLE_CAGES["Rep1"], q=Q_THRESH)
rep2_daily = jax_daily_plateau("Rep2", VEHICLE_CAGES["Rep2"], q=Q_THRESH)

print("\n=== JAX Rep1 rankings ===")
print(_rank_days(rep1_daily).to_string(index=False))
print("\n=== JAX Rep2 rankings ===")
print(_rank_days(rep2_daily).to_string(index=False))

_boxplot_by_day(rep1_daily, f"JAX Rep1 (Vehicle) — Plateau duration by day (boxplots, q={Q_THRESH})")
_boxplot_by_day(rep2_daily, f"JAX Rep2 (Vehicle) — Plateau duration by day (boxplots, q={Q_THRESH})")

_heatmap_plateau(rep1_daily, f"JAX Rep1 (Vehicle) — Per-animal plateau duration (q={Q_THRESH})", zscore_within_animal=True)
_heatmap_plateau(rep2_daily, f"JAX Rep2 (Vehicle) — Per-animal plateau duration (q={Q_THRESH})", zscore_within_animal=True)


## 5. Top-k vs Rest Permutation Test

Test statistic:
T = mean(top-k days) − mean(other days)

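Permutation procedure:
- Shuffle day labels within each animal (plateau durations stay fixed)
- Recompute the day means, re-select the top-k days, and recompute T
- Repeat N times
- Compute the one-sided empirical p-value: (1 + #{T_perm ≥ T_obs}) / (N + 1)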

Purpose:
Determine whether top-ranked days significantly deviate from baseline structure.


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ============================================================
# TOP-3 DAYS VS REST PERMUTATION TEST (JAX Rep1 + Rep2)
#
# Requires:
#   rep1_daily, rep2_daily
# each with columns:
#   animal_id, day, plateau_duration_min
#
# Null: day labels are exchangeable WITHIN each animal.
# Permutation: shuffle day labels within each animal (plateau durations fixed).
#
# Statistic:
#   Let m_d = mean plateau_duration_min across animals on day d.
#   T = mean(top 3 of {m_d}) - mean(rest of {m_d})
# ============================================================

# Settings passed to permutation_test_top3_vs_rest in the run section below:
# RNG seed, number of permutations, and how many top days form the "top" group.
RNG_SEED = 0
N_PERM = 10000
K = 3

def day_mean_vector(df_daily: pd.DataFrame) -> pd.Series:
    """Mean plateau duration per day over all rows of that day, ordered by day."""
    durations_by_day = df_daily.groupby("day")["plateau_duration_min"]
    day_means = durations_by_day.agg("mean")
    return day_means.sort_index()

def topk_vs_rest_stat(day_means: pd.Series, k: int = 3) -> float:
    """
    Mean of the k largest day means minus the mean of the remaining day means.

    Parameters
    ----------
    day_means : one value per day.
    k : number of top days; must be at least 1.

    Returns
    -------
    The difference as a float, or NaN when there are not more than k days
    (no "rest" group exists).

    Raises
    ------
    ValueError : k is smaller than 1. A non-positive k would otherwise be
        silently misread by the negative slices below.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")
    ordered = np.sort(day_means.to_numpy(dtype=float))  # ascending
    if ordered.size <= k:
        return float("nan")
    return float(ordered[-k:].mean() - ordered[:-k].mean())

def permute_days_within_animal(df_daily: pd.DataFrame, rng: np.random.Generator) -> pd.DataFrame:
    """
    Return a copy of `df_daily` whose "day" labels are shuffled independently
    within each animal; all other columns keep their values.

    Animals appear in order of first occurrence, one `rng.permutation` draw
    per animal, and the result has a fresh 0..n-1 index.
    """
    shuffled = [
        animal_rows.assign(day=rng.permutation(animal_rows["day"].to_numpy()))
        for _, animal_rows in df_daily.groupby("animal_id", sort=False)
    ]
    return pd.concat(shuffled, ignore_index=True)

def permutation_test_top3_vs_rest(
    df_daily: pd.DataFrame,
    label: str,
    k: int = 3,
    n_perm: int = 10000,
    seed: int = 0,
    plot: bool = True,
):
    """
    Permutation test of "top-k day means minus the rest" for one replicate.

    The null hypothesis is that day labels are exchangeable within each
    animal. Every permutation shuffles the labels within animal, recomputes
    the day means and re-selects the top-k days before recomputing T.

    Parameters
    ----------
    df_daily : rows with columns animal_id, day, plateau_duration_min.
    label : text used in the printed summary and plot titles.
    k : number of top days.
    n_perm : number of permutations.
    seed : seed for np.random.default_rng.
    plot : when True, draw the null histogram and the observed day means.

    Returns
    -------
    dict with keys "label", "k", "T_obs", "p_value", "T_perm",
    "day_means_obs" and "topk_days_obs". "p_value" is the one-sided
    (count + 1) / (n_perm + 1) estimate, or NaN when T_obs is NaN
    (for example when there are not more than k days).
    """
    rng = np.random.default_rng(seed)

    # Observed
    dm_obs = day_mean_vector(df_daily)
    T_obs = topk_vs_rest_stat(dm_obs, k=k)

    # Permutations
    T_perm = np.empty(n_perm, dtype=float)
    for i in range(n_perm):
        perm_df = permute_days_within_animal(df_daily, rng)
        dm = day_mean_vector(perm_df)
        T_perm[i] = topk_vs_rest_stat(dm, k=k)

    # One-sided p-value (large = more "structured" / more top-heavy)
    if np.isnan(T_obs):
        # Every comparison against NaN is False, which would otherwise report
        # the smallest possible p-value for an undefined statistic.
        p_val = float("nan")
    else:
        p_val = (np.sum(T_perm >= T_obs) + 1) / (n_perm + 1)

    # Print summary + which days are top-k in observed data
    top3_days = dm_obs.sort_values(ascending=False).head(k)
    print("\n" + "=" * 72)
    print(f"{label}: Top-{k} days vs rest permutation test")
    print(f"Observed statistic T = mean(top-{k} day means) - mean(rest day means)")
    print(f"Observed top-{k} days (by day mean plateau duration):")
    for d, v in top3_days.items():
        print(f"  Day {int(d)}: mean plateau duration = {v:.3f} min")
    print(f"Observed T: {T_obs:.6f}")
    print(f"Permutation p-value: {p_val:.6g}")
    print("=" * 72)

    if plot:
        # Null distribution; non-finite values are left out because the
        # histogram cannot bin them.
        finite_perm = T_perm[np.isfinite(T_perm)]
        plt.figure(figsize=(8, 3.6))
        if finite_perm.size:
            plt.hist(finite_perm, bins=40)
        if np.isfinite(T_obs):
            plt.axvline(T_obs, linestyle="--", linewidth=2, label="Observed T")
            plt.legend()
        plt.title(f"{label}: Null distribution (Top-{k} vs rest)")
        plt.xlabel("Permutation statistic T")
        plt.ylabel("Count")
        plt.tight_layout()
        plt.show()

        # Observed day means with top-k highlighted
        plt.figure(figsize=(10, 3.6))
        x = dm_obs.index.to_numpy()
        y = dm_obs.to_numpy()
        plt.plot(x, y, linewidth=2, label="Day mean plateau duration")
        for d in top3_days.index:
            plt.axvline(int(d), linestyle="--", linewidth=1.3, alpha=0.8)
        plt.title(f"{label}: Observed day means (Top-{k} highlighted)")
        plt.xlabel("Day")
        plt.ylabel("Mean plateau duration (min)")
        plt.legend()
        plt.tight_layout()
        plt.show()

    return {
        "label": label,
        "k": k,
        "T_obs": T_obs,
        "p_value": p_val,
        "T_perm": T_perm,
        "day_means_obs": dm_obs,
        "topk_days_obs": top3_days,
    }

# ============================================================
# RUN ON REP1 + REP2
# ============================================================
# Both replicates use the same K, permutation count and seed; each call prints
# a summary, shows two figures (plot=True) and returns a result dict.
res_rep1_top3 = permutation_test_top3_vs_rest(
    rep1_daily, label="JAX Rep1", k=K, n_perm=N_PERM, seed=RNG_SEED, plot=True
)

res_rep2_top3 = permutation_test_top3_vs_rest(
    rep2_daily, label="JAX Rep2", k=K, n_perm=N_PERM, seed=RNG_SEED, plot=True
)
