<a href="https://colab.research.google.com/github/KravitzLab/FED3Analyses/blob/PsygeneL0-to-L1-QC/PsygeneFEDL0_to_L1_QC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# This is a notebook for running FED3 quality control.
Created by Chantelle Murrell  
Files are renamed based on Mouse_ID in a key. Multiple files per FED and multiple FEDs per mouse ID are concatenated.   
This was created to get files from L0 to L1 for Psygene.  
<br>
<img src="https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcQqe_1a1j1bQaqhOxq0VvukPqfolLRUqOdl-g&s" width="300" />

Updated: 02.12.25






In [None]:
# @title Install and import libraries
!pip install -q pingouin
!pip -q install "git+https://github.com/earnestt1234/fed3.git"
!pip install -q ipywidgets

#%% import libraries and set plot parameters
import fed3
import fed3.plot as fplot
import pingouin as pg
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.ticker import ScalarFormatter
import pandas as pd
import seaborn as sns
import os
import io
import shutil
from collections import defaultdict
from google.colab import files
from datetime import datetime, timedelta
import zipfile
import tempfile
import warnings
import ipywidgets as widgets
import re
from IPython.display import display, clear_output
# Suppress every warning for the rest of the notebook. This keeps cell output
# tidy but can hide real problems, so comment it out when debugging.
warnings.filterwarnings('ignore')
print("Packages installed.")

# Default figure styling for all plots in this notebook.
plt.rcParams.update({
    'font.size': 12,
    'figure.autolayout': True,
    'figure.figsize': [6, 4],
    'figure.dpi': 100,
    'axes.spines.top': False,
    'axes.spines.right': False,
})

import statsmodels.api as sm
from statsmodels.formula.api import ols

In [None]:
# @title Import L0 files (Files straight off the FED sorted by session type)


# Accumulators filled by the loader loop below. The three lists are appended
# to together, so entries at the same position describe the same file.
feds, loaded_files, session_types = [], [], []

# Colab upload dialog; the result is iterated below as (file name, raw bytes) pairs.
uploaded = files.upload()

def _coerce_fed3_time_like(df: pd.DataFrame) -> pd.DataFrame:
    """
    Make FED3 dataframe safe for fed3.as_aligned by:
      - ensuring a DatetimeIndex
      - converting elapsed/time-like columns to timedeltas
    """
    d = df.copy()

    # 1) Ensure a DatetimeIndex
    if not isinstance(d.index, pd.DatetimeIndex):
        # Try common columns
        if "MM:DD:YYYY hh:mm:ss" in d.columns:
            dti = pd.to_datetime(d["MM:DD:YYYY hh:mm:ss"], format="%m:%d:%Y %H:%M:%S", errors="coerce")
        elif "DateTime" in d.columns:
            dti = pd.to_datetime(d["DateTime"], errors="coerce")
        elif {"Date","Time"}.issubset(d.columns):
            dti = pd.to_datetime(d["Date"].astype(str) + " " + d["Time"].astype(str), errors="coerce")
        elif "Timestamp" in d.columns:
            dti = pd.to_datetime(d["Timestamp"], errors="coerce")
        else:
            # Last resort: try to parse the index itself
            dti = pd.to_datetime(d.index, errors="coerce")

        if dti.notna().any():
            d.index = dti
        else:
            # If nothing parsed, leave index as-is; as_aligned may still work using elapsed
            pass

    # 2) Coerce elapsed/timedelta-like columns (names vary across firmwares/exports)
    elapsed_candidates = [
        "Elapsed", "Elapsed_Time", "Elapsed Time", "ElapsedTime",
        "Time_Since_Start", "Time Since Start"
    ]
    for col in elapsed_candidates:
        if col in d.columns:
            # If string-like, convert to timedelta (supports "HH:MM:SS" or seconds)
            if not np.issubdtype(d[col].dtype, np.timedelta64):
                # First try "HH:MM:SS" style
                td = pd.to_timedelta(d[col], errors="coerce")
                # If everything is NaT, try treating as seconds
                if td.isna().all():
                    with pd.option_context('mode.use_inf_as_na', True):
                        secs = pd.to_numeric(d[col], errors="coerce")
                    td = pd.to_timedelta(secs, unit="s", errors="coerce")
                d[col] = td

    # 3) Also coerce any column literally named "Time" to timedelta if it looks like HH:MM:SS
    if "Time" in d.columns and not np.issubdtype(d["Time"].dtype, np.timedelta64):
        td = pd.to_timedelta(d["Time"], errors="coerce")
        # don't overwrite if that destroys legitimate clock times; only set if it parsed well
        if td.notna().sum() >= max(5, int(0.5 * len(d))):
            d["Time"] = td

    return d

def extract_fed_and_date(filename):
    """
    Pull the FED id and file date out of a name that starts with 'FED###_######'.

    The 6-digit token is read as MMDDYY. If that is not a real calendar date
    (for example the month part is above 12) it is re-read as DDMMYY, since the
    matching step in this notebook accepts both layouts. Two-digit years below
    50 map to 20xx, the rest to 19xx.

    Returns (fed_id, 'YYYYMMDD', datetime), where fed_id is the 3-digit string
    as written in the name, or (None, None, None) when the name does not match
    the pattern or the token is not a valid date in either layout.
    """
    match = re.match(r"FED(\d{3})_(\d{6})", filename)
    if not match:
        return None, None, None
    fed_id, raw_date = match.groups()
    first, second, yy = int(raw_date[:2]), int(raw_date[2:4]), int(raw_date[4:])
    year = (2000 if yy < 50 else 1900) + yy
    # (month, day) candidates: MMDDYY first, DDMMYY only as a fallback
    for month, day in ((first, second), (second, first)):
        try:
            file_date = datetime(year, month, day)
        except ValueError:
            continue
        return fed_id, file_date.strftime("%Y%m%d"), file_date
    return None, None, None

# Walk every uploaded item. ZIP uploads are opened and each CSV member is loaded;
# bare CSV uploads are loaded directly; any other file type is ignored.
# Successful loads are appended to feds / loaded_files / session_types together.
for name, data in uploaded.items():
    if name.lower().endswith(".zip"):
        with zipfile.ZipFile(io.BytesIO(data)) as zip_file:
            for zip_info in zip_file.infolist():
                # Only CSV members are considered
                if not zip_info.filename.lower().endswith(".csv"):
                    continue

                file_data = zip_file.read(zip_info)
                # Members of 1 KiB or less are skipped silently
                # (presumably empty/header-only logs — confirm the threshold)
                if len(file_data) <= 1024:
                    continue

                # fed3.load is given a path, so the bytes are written to a temp file.
                # delete=False keeps the file after the with-block closes it;
                # it is removed in the finally clause below.
                with tempfile.NamedTemporaryFile(mode="w+b", suffix=".csv", delete=False) as tmp_file:
                    tmp_file.write(file_data)
                    tmp_path = tmp_file.name

                try:
                    # Quick pass to grab session type
                    # (first row of "Session_type", or "Unknown" if the column is absent)
                    raw_df = pd.read_csv(tmp_path)
                    session_type = raw_df["Session_type"].iloc[0] if "Session_type" in raw_df.columns else "Unknown"

                    # Load via fed3
                    df = fed3.load(tmp_path)
                    # Label the frame with the archive member name rather than the temp path
                    df.name = zip_info.filename
                    # Note: this replaces any attrs already on the loaded frame
                    df.attrs = {"Session_type": session_type}

                    # Coerce time-like fields
                    df = _coerce_fed3_time_like(df)

                    feds.append(df)
                    loaded_files.append(zip_info.filename)
                    session_types.append(session_type)

                except Exception as e:
                    # Best effort: report the failure and move on to the next member
                    print(f"Error loading {zip_info.filename}: {e}")
                finally:
                    # Always clean up the temp file, whether or not loading succeeded
                    os.remove(tmp_path)

    elif name.lower().endswith(".csv"):
        # Same size filter as for ZIP members
        if len(data) <= 1024:
            continue

        # Same temp-file round trip as above, for a directly uploaded CSV
        with tempfile.NamedTemporaryFile(mode="w+b", suffix=".csv", delete=False) as tmp_file:
            tmp_file.write(data)
            tmp_path = tmp_file.name

        try:
            # Session type from the first row, or "Unknown" if the column is absent
            raw_df = pd.read_csv(tmp_path)
            session_type = raw_df["Session_type"].iloc[0] if "Session_type" in raw_df.columns else "Unknown"

            df = fed3.load(tmp_path)
            df.name = name
            df.attrs = {"Session_type": session_type}

            df = _coerce_fed3_time_like(df)

            feds.append(df)
            loaded_files.append(name)
            session_types.append(session_type)

        except Exception as e:
            # Best effort: report the failure and continue with the next upload
            print(f"Error loading {name}: {e}")
        finally:
            os.remove(tmp_path)

# ---- Plot ----
# Quick visual check of everything that loaded: pellets over time, one line per file.
if not feds:
    print("No valid files loaded.")
else:
    # Prefer datetime alignment (time fields were coerced at load time);
    # if that raises, retry with fed3's default alignment. Both failures are
    # only reported, and the plot is attempted regardless.
    try:
        fed3.as_aligned(feds, alignment="datetime", inplace=True)
    except Exception as datetime_err:
        print(f"as_aligned(datetime) failed: {datetime_err}")
        try:
            fed3.as_aligned(feds, inplace=True)
        except Exception as default_err:
            print(f"as_aligned(default) also failed: {default_err}")

    plt.figure(figsize=(8, 4))
    fplot.line(feds, y='pellets')
    # The legend is removed right after it is created
    plt.legend().remove()
    plt.tight_layout()
    plt.show()

In [None]:
# @title Upload Key
# Single-file upload widget for the key spreadsheet (.xlsx).
uploader = widgets.FileUpload(
    accept='.xlsx',
    multiple=False,
    description='Upload Key'
)
display(uploader)

uploaded_key_path = None  # Global path to use later

def handle_upload(change):
    """
    Save the uploaded key to the working directory and remember its path.

    Registered below as an observer of the widget's 'value' trait. `change`
    is the notification passed by the widget; it is not used, the current
    widget value is read instead.

    Both shapes of FileUpload.value are handled:
      - ipywidgets 7.x: dict of {filename: {'content': bytes, ...}}
      - ipywidgets 8.x: tuple of entries with 'name' and 'content' keys

    On success the module-level `uploaded_key_path` is set to the saved file name.
    """
    global uploaded_key_path
    value = uploader.value
    if not value:
        return  # nothing uploaded (e.g. the widget was cleared)

    if isinstance(value, dict):
        # ipywidgets 7.x
        key_name, entry = next(iter(value.items()))
    else:
        # ipywidgets 8.x
        entry = value[0]
        key_name = entry['name']

    content = entry['content']
    with open(key_name, 'wb') as f:
        f.write(content)
    uploaded_key_path = key_name
    print(f"✅ Uploaded and saved key to: {key_name}")

uploader.observe(handle_upload, names='value')

In [None]:
# @title Identify multiple files per animal
# Matches each loaded file to a key row by FED number, using the key's FED_StartDate
# (within the tolerance set by `window` below) to disambiguate cohorts.
# Uses variables defined in earlier cells: loaded_files, uploaded_key_path

import re, pandas as pd
from datetime import timedelta

# --- helpers ---
# Uses: loaded_files, uploaded_key_path

def parse_fed3_cell(val):
    """
    Return the sorted, de-duplicated FED numbers listed in one key cell.

    Every run of digits in the cell is taken as one FED number, so any
    separator works: "3, 7", "3/7", "3 and 7", "FED003 + FED007", ...

    A trailing ".0" after a number is dropped first, so a cell that was read
    as a float (72.0) gives [72] rather than [0, 72].

    Returns [] for a missing value (NaN/None).
    """
    if pd.isna(val):
        return []
    # "72.0" -> "72"; "72.05" is left alone because the zeros are followed by a digit
    text = re.sub(r"(\d+)\.0+(?!\d)", r"\1", str(val))
    return sorted({int(tok) for tok in re.findall(r"\d+", text)})

# "FED" in any case, optional zero padding, then a FED number of up to three digits.
fed_in_name_pat = re.compile(r"(?i)fed0*([0-9]{1,3})")
# A 6-digit token that follows "_" and is followed by "_", "." or end of string.
date_token_pat  = re.compile(r"_(\d{6})(?:_|\.|$)")  # _DDMMYY_ or _MMDDYY_

def extract_fed_from_loaded_filename(fname: str):
    """Return the FED number found in a file name as an int, or None if there is none."""
    hit = fed_in_name_pat.search(str(fname).strip())
    if hit is None:
        return None
    return int(hit.group(1))

def parse_6digit_ddmmyy(token: str):
    """
    Read a 6-digit token as day-month-year, e.g. "040925" -> 2025-09-04.

    Unparseable tokens give NaT instead of raising.
    """
    options = {"format": "%d%m%y", "errors": "coerce", "dayfirst": True}
    return pd.to_datetime(token, **options)

def parse_6digit_mmddyy(token: str):
    """
    Read a 6-digit token as month-day-year, e.g. "040925" -> 2025-04-09.

    Unparseable tokens give NaT instead of raising.
    """
    options = {"format": "%m%d%y", "errors": "coerce", "dayfirst": False}
    return pd.to_datetime(token, **options)

def extract_date_candidates_from_filename(fname: str):
    """
    List the possible dates encoded by the 6-digit token in a file name.

    The token is taken from the underscore-delimited position when present,
    otherwise from the first run of six digits anywhere in the name. It is read
    both as DDMMYY and as MMDDYY; each reading that yields a real date is
    returned as a (label, Timestamp) pair with label "DMY" or "MDY". If both
    readings land on the same calendar day only the "DMY" pair is kept.

    Returns [] when the name has no 6-digit token or neither reading parses.
    """
    text = str(fname).strip()
    hit = date_token_pat.search(text)
    if hit is None:
        hit = re.search(r"(\d{6})", text)
    if hit is None:
        return []

    token = hit.group(1)
    readings = (
        ("DMY", parse_6digit_ddmmyy(token)),
        ("MDY", parse_6digit_mmddyy(token)),
    )

    # Keyed by calendar day so the first reading wins when both agree;
    # dict insertion order keeps DMY ahead of MDY.
    by_day = {}
    for label, stamp in readings:
        if pd.isna(stamp):
            continue
        stamp = pd.Timestamp(stamp)
        by_day.setdefault(stamp.date(), (label, stamp))
    return list(by_day.values())

# --- Load key ---
# Stop early with a clear message if the key upload cell has not produced a file yet.
if not globals().get("uploaded_key_path"):
    raise RuntimeError("No key file has been uploaded yet. Run the 'Upload Key' cell and choose the .xlsx key first.")

key = pd.read_excel(uploaded_key_path, sheet_name=0)
if "FED_StartDate" not in key.columns:
    raise ValueError("Key must have a 'FED_StartDate' column.")

# Animal column: exactly "Mouse_ID" if present, otherwise the first column whose
# name contains "mouse" (any case).
animal_col = "Mouse_ID" if "Mouse_ID" in key.columns else next(
    (c for c in key.columns if re.search(r"mouse", str(c), re.I)), None
)
if animal_col is None:
    raise ValueError("Key must have a 'Mouse_ID' column (or another column whose name contains 'mouse').")

# FED column: exactly "FED3" if present, otherwise the first column named like
# "fed3", "FED 3" or "fed_3" (any case, surrounding spaces allowed).
fed_col = "FED3" if "FED3" in key.columns else next(
    (c for c in key.columns if re.fullmatch(r"\s*fed\s*_?3\s*", str(c), flags=re.I)), None
)
if fed_col is None:
    raise ValueError("Key must have a 'FED3' column (e.g. 'FED3', 'FED 3' or 'fed_3').")

def normalize_animal(x):
    """
    Normalise one animal ID from the key to a clean string.

    - Missing or blank values become "Unknown".
    - Whole numbers, optionally written with a ".0" tail (as numeric
      spreadsheet cells often are), become their plain integer text:
      12.0 -> "12", "007" -> "7".
    - Anything else is returned stripped of surrounding whitespace.

    The integer part is converted directly (not through float), so long
    numeric IDs keep every digit.
    """
    if pd.isna(x):
        return "Unknown"
    text = str(x).strip()
    if not text:
        return "Unknown"
    whole_number = re.fullmatch(r"(\d+)(?:\.0+)?", text)
    if whole_number:
        return str(int(whole_number.group(1)))
    return text

# Work on a copy of the key with cleaned-up animal IDs.
key = key.copy()
key[animal_col] = key[animal_col].map(normalize_animal)

# FED_StartDate is parsed under both conventions (month-first and day-first);
# values that fail to parse become NaT. The matcher below tries both columns.
raw_start = key["FED_StartDate"]
for suffix, day_first in (("mdy", False), ("dmy", True)):
    key[f"FED_StartDate_{suffix}"] = pd.to_datetime(raw_start, errors="coerce", dayfirst=day_first)

# Keep each row's original position as "orig_idx", then expand rows that list
# several FEDs into one row per FED number ("FED3_int"). Rows whose FED cell
# yields no number are dropped.
key = key.reset_index().rename(columns={"index": "orig_idx"})
key["_FED3_list"] = key[fed_col].apply(parse_fed3_cell)

key_exploded = key.explode("_FED3_list", ignore_index=True)
key_exploded = key_exploded.dropna(subset=["_FED3_list"])
key_exploded = key_exploded.rename(columns={"_FED3_list": "FED3_int"})
key_exploded["FED3_int"] = key_exploded["FED3_int"].astype(int)

# --- matching helper ---
# Largest allowed distance between a file's date and the key's FED_StartDate.
window = timedelta(days=17)

def pick_best_key_row_for_any_filedate(fed_id: int, filedate_candidates):
    """
    Find the key row that best explains a file from FED `fed_id`.

    Every (label, date) pair in `filedate_candidates` is compared with every
    key row for that FED, using both the day-first and the month-first reading
    of the key's start date. Candidates are ranked by absolute distance between
    file date and start date; ties go to the later start date.

    Returns a tuple (best, nearest):
      - best: dict with keys orig_idx, Mouse_ID, FED_StartDate_used (normalised
        to midnight) and file_fmt_used, for the closest pairing that lies within
        the module-level `window`; None if no pairing is within the window.
      - nearest: the closest pairing regardless of the window (a row of the
        candidate table), intended for debugging output.
    (None, None) is returned when the FED is not in the key, when there are no
    file-date candidates, or when none of the key's start dates parsed.
    """
    fed_rows = key_exploded[key_exploded["FED3_int"] == int(fed_id)]
    if fed_rows.empty or not filedate_candidates:
        return None, None

    # One copy of the FED's key rows per start-date reading, stacked DMY then MDY.
    id_cols = ["orig_idx", animal_col, "FED3_int"]
    variants = []
    for fmt_label, date_col in (("DMY", "FED_StartDate_dmy"), ("MDY", "FED_StartDate_mdy")):
        variant = fed_rows[id_cols + [date_col]].rename(columns={date_col: "FED_StartDate_used"}).copy()
        variant["key_fmt"] = fmt_label
        variants.append(variant)
    key_cand = pd.concat(variants, ignore_index=True).dropna(subset=["FED_StartDate_used"])
    if key_cand.empty:
        return None, None

    # Pair every file-date candidate with every key candidate.
    pairings = []
    for ffmt, fdt in filedate_candidates:
        paired = key_cand.copy()
        paired["file_fmt"] = ffmt
        paired["file_date"] = fdt
        paired["abs_delta"] = (paired["file_date"] - paired["FED_StartDate_used"]).abs()
        pairings.append(paired)
    all_cand = pd.concat(pairings, ignore_index=True)

    def _closest(frame):
        # Smallest distance first; on ties, the later start date first.
        return frame.sort_values(["abs_delta", "FED_StartDate_used"], ascending=[True, False]).iloc[0]

    # For debug: the nearest even if outside window
    nearest = _closest(all_cand)

    within = all_cand[all_cand["abs_delta"] <= window]
    if within.empty:
        return None, nearest

    best = _closest(within)
    chosen = {
        "orig_idx": int(best["orig_idx"]),
        "Mouse_ID": best[animal_col],
        "FED_StartDate_used": pd.to_datetime(best["FED_StartDate_used"]).normalize(),
        "file_fmt_used": best["file_fmt"]
    }
    return chosen, nearest

# --- Extract FED and both file-date candidates from filenames ---
# One row per loaded file name; files whose name carries no FED number are dropped,
# as are repeated file names.
files_df = pd.DataFrame({"file": [str(f).strip() for f in loaded_files]})
files_df["FED3"] = files_df["file"].map(extract_fed_from_loaded_filename)
files_df = files_df.dropna(subset=["FED3"]).astype({"FED3": int}).drop_duplicates(subset=["file"])
# Each cell holds a list of (label, Timestamp) pairs; it may be empty
files_df["file_date_candidates"] = files_df["file"].map(extract_date_candidates_from_filename)

# --- Map each file ---
# matches: one dict per file that was paired with a key row
# unmatched: one dict per file that was not, with the nearest key date (if any) for diagnosis
matches = []
unmatched = []
for _, r in files_df.iterrows():
    fed_id = int(r["FED3"])
    cands  = r["file_date_candidates"]
    best, nearest = pick_best_key_row_for_any_filedate(fed_id, cands)
    if best is None:
        # For debugging, report the closer of the two filename parses to the nearest key date
        if nearest is None:
            # No comparison was possible at all (FED not in key, no date in the name,
            # or no parseable start date in the key)
            unmatched.append({
                "file": r["file"], "FED3": fed_id, "file_fmt": None,
                "file_date": None, "nearest_key_date": None, "delta_days": None
            })
        else:
            # A comparison was possible but the closest key date is outside the window
            unmatched.append({
                "file": r["file"], "FED3": fed_id,
                "file_fmt": str(nearest["file_fmt"]),
                "file_date": pd.to_datetime(nearest["file_date"]),
                "nearest_key_date": pd.to_datetime(nearest["FED_StartDate_used"]),
                "delta_days": int(nearest["abs_delta"].days)
            })
        continue

    matches.append({
        "file": r["file"],
        "FED3": fed_id,
        "file_fmt_used": best["file_fmt_used"],
        # The file date under the reading (DMY/MDY) that produced the match
        "file_date": next(dt for fmt, dt in cands if fmt == best["file_fmt_used"]),
        "Mouse_ID": best["Mouse_ID"],
        "FED_StartDate": best["FED_StartDate_used"]
    })

# Both frames are empty (no columns) when their source list is empty
matched_df = pd.DataFrame(matches)
unmatched_df = pd.DataFrame(unmatched)

# --- Summaries ---
# The tolerance shown in the messages is taken from `window`, so the text always
# states the tolerance that the matching actually used.
print(f"\nFiles grouped by Mouse_ID and matched FED_StartDate (±{window.days} days):")
if matched_df.empty:
    print("  None")
else:
    # One row per (Mouse_ID, FED_StartDate): file count, file names, FED numbers
    # and which filename date reading(s) produced the matches.
    grouped = (
        matched_df.groupby(["Mouse_ID", "FED_StartDate"], as_index=False)
                  .agg(n_files=("file", "nunique"),
                       files=("file", lambda xs: sorted(set(xs))),
                       feds=("FED3", lambda xs: sorted(set(int(x) for x in xs))),
                       file_formats=("file_fmt_used", lambda xs: sorted(set(xs))))
                  .sort_values(["Mouse_ID", "FED_StartDate"])
                  .reset_index(drop=True)
    )
    for _, r in grouped.iterrows():
        adate = pd.to_datetime(r["FED_StartDate"]).strftime("%Y-%m-%d")
        print(f"  {r['Mouse_ID']} @ {adate} -> {r['n_files']} file(s); FED(s): {r['feds']}; file fmt(s): {r['file_formats']}")
        for f in r["files"]:
            print(f"     - {f}")

# Animals that appear in >1 FED_StartDate group
print("\nAnimals appearing in multiple FED_StartDate groups (possible cross-cohort reuse):")
if matched_df.empty:
    print("  None")
else:
    counts_by_animal = matched_df.groupby("Mouse_ID")["FED_StartDate"].nunique().reset_index(name="n_start_dates")
    multi_groups = counts_by_animal[counts_by_animal["n_start_dates"] > 1]["Mouse_ID"].tolist()
    if not multi_groups:
        print("  None")
    else:
        # key=str so IDs of mixed types can still be ordered
        for aid in sorted(multi_groups, key=str):
            sub = matched_df[matched_df["Mouse_ID"] == aid]
            for sd, sub2 in sub.groupby("FED_StartDate"):
                adate = pd.to_datetime(sd).strftime("%Y-%m-%d")
                files_list = sorted(set(sub2["file"].tolist()))
                feds_list  = sorted(set(int(x) for x in sub2["FED3"].tolist()))
                print(f"  {aid} @ {adate} -> {len(files_list)} file(s); FED(s): {feds_list}")
                for f in files_list:
                    print(f"     - {f}")

# Unmatched files (show nearest key date to help diagnose)
if not unmatched_df.empty:
    print(f"\nUnmatched files (no key FED_StartDate within ±{window.days} days for that FED). Nearest key date shown for debugging:")
    for _, r in unmatched_df.iterrows():
        # Missing values are shown as "NaT" / "NA" / "?" instead of failing to format
        fdate = "NaT" if pd.isna(r["file_date"]) else pd.to_datetime(r["file_date"]).strftime("%Y-%m-%d")
        kdate = "NaT" if pd.isna(r.get("nearest_key_date")) else pd.to_datetime(r["nearest_key_date"]).strftime("%Y-%m-%d")
        delta = "NA" if pd.isna(r.get("delta_days")) else int(r["delta_days"])
        fmt   = r.get("file_fmt") or "?"
        print(f"  FED{str(int(r['FED3'])).zfill(3)} file({fmt}) @ {fdate}; nearest key date: {kdate} (Δ={delta}d): {r['file']}")

In [None]:
# @title Concatenate files to one per Mouse_ID
# ---------- filename/date helpers ----------
# The concatenation below is driven entirely by matched_df from the matching cell.
if 'matched_df' not in globals() or matched_df.empty:
    raise RuntimeError("matched_df is empty or undefined. Run the matching cell first.")

import os, io, re, zipfile, pandas as pd
from datetime import datetime

# Keep only the columns concatenation needs, and only rows where all three are present.
_concat_cols = ['file', 'Mouse_ID', 'FED_StartDate']
mapped = matched_df.loc[:, _concat_cols].dropna(subset=_concat_cols).copy()
animal_col = 'Mouse_ID'  # keep consistent

# ---------- filename/date helpers ----------
# "FED" (any case) at a word boundary, optional zero padding, a FED number of up
# to three digits, an underscore, then a 6-digit date token.
fed_date_pat = re.compile(r"(?i)^.*?\bFED0*(\d{1,3})_(\d{6})")

def parse_fed_and_date_from_name(fname: str):
    """
    Return (fed_int, raw_date_token, datetime_obj) from names like 'FED072_080425_00.CSV'.

    The 6-digit token is read as MMDDYY. If that is not a real calendar date
    (for example the month part is above 12) it is re-read as DDMMYY, because
    the matching step in this notebook accepts files named either way. The
    token itself is returned unchanged so it can be reused when naming output
    files. Two-digit years below 50 map to 20xx, the rest to 19xx.

    Returns (None, None, None) when the name does not contain the pattern or
    the token is not a valid date in either layout.
    """
    m = fed_date_pat.search(str(fname).strip())
    if not m:
        return None, None, None
    fed = int(m.group(1))
    raw = m.group(2)
    first, second, yy = int(raw[:2]), int(raw[2:4]), int(raw[4:])
    year = (2000 if yy < 50 else 1900) + yy
    # (month, day) candidates: MMDDYY first, DDMMYY only as a fallback
    for month, day in ((first, second), (second, first)):
        try:
            return fed, raw, datetime(year, month, day)
        except ValueError:
            continue
    return None, None, None

# ---------- discover zips from multiple sources ----------
# Names of uploaded items that are ZIP archives (their bytes stay in `uploaded`).
_uploaded_now = globals().get('uploaded')
if isinstance(_uploaded_now, dict):
    zip_bytes_names = [k for k in _uploaded_now if k.lower().endswith(".zip")]
else:
    zip_bytes_names = []

# Entries of loaded_files that name a ZIP archive present on disk.
# Any error while scanning leaves the list empty.
zip_file_paths = []
if 'loaded_files' in globals():
    try:
        zip_file_paths = list(filter(
            lambda p: str(p).lower().endswith(".zip") and os.path.exists(str(p)),
            loaded_files,
        ))
    except Exception:
        pass

# ---------- robust CSV reader ----------
def read_csv_from_sources(filename: str) -> pd.DataFrame:
    """
    Locate a CSV by name and return it as a DataFrame.

    Sources are tried in this order, and the first hit wins:
      1) a file at that path on disk
      2) the raw bytes held in `uploaded` under the full name, then the basename
      3) a member of any uploaded ZIP (names in `zip_bytes_names`)
      4) a member of any on-disk ZIP (paths in `zip_file_paths`)
    Inside an archive the exact member name is tried first, then any member
    with the same basename.

    Errors while reading from disk (step 1) or from an on-disk ZIP (step 4)
    are swallowed and the search continues; errors in steps 2-3 propagate.

    Raises FileNotFoundError when no source contains the file.
    """
    wanted = str(filename)
    wanted_base = os.path.basename(wanted)

    def _parse(raw_bytes):
        # CSV payloads are decoded as UTF-8 text before parsing
        return pd.read_csv(io.StringIO(raw_bytes.decode("utf-8")))

    def _from_archive(archive):
        # Returns the parsed member, or None when the archive does not hold it.
        try:
            return _parse(archive.read(wanted))
        except KeyError:
            pass
        try:
            for member in archive.namelist():
                if os.path.basename(member) == wanted_base:
                    return _parse(archive.read(member))
        except KeyError:
            pass
        return None

    # 1) Direct filesystem
    try:
        if os.path.exists(wanted):
            return pd.read_csv(wanted)
    except Exception:
        pass

    if 'uploaded' in globals() and isinstance(uploaded, dict):
        # 2) uploaded[...] direct
        for candidate in (wanted, wanted_base):
            if candidate in uploaded:
                return _parse(uploaded[candidate])

        # 3) inside uploaded ZIP bytes
        for zname in zip_bytes_names:
            with zipfile.ZipFile(io.BytesIO(uploaded[zname])) as zf:
                found = _from_archive(zf)
            if found is not None:
                return found

    # 4) inside on-disk ZIPs
    for zpath in zip_file_paths:
        try:
            with zipfile.ZipFile(zpath, 'r') as zf:
                found = _from_archive(zf)
        except Exception:
            continue
        if found is not None:
            return found

    raise FileNotFoundError(f"Could not locate '{filename}' in filesystem, uploaded files, or known ZIPs.")

# ---------- concatenation logic ----------
def concat_one_group(file_list):
    """
    Concatenate the files of one Mouse_ID x FED_StartDate group into a single DataFrame.

    Files are ordered by the first timestamp each contains. The cumulative
    counters (Left_Poke_Count, Right_Poke_Count, Pellet_Count) of every later
    file are shifted by the running maximum reached so far, so the counts keep
    increasing across file boundaries. A warning is printed when a file starts
    more than 24 hours after the previous one ends.

    Each file is read once and the parsed frame is reused for both the ordering
    and the concatenation.

    Parameters:
        file_list: names of the files to combine (resolved via read_csv_from_sources).

    Returns:
        (combined, ordered): the combined DataFrame, with attrs["Concatenated"]
        set to True, and the file names in the order they were appended.

    Raises:
        ValueError: if file_list is empty, a file lacks the timestamp column or
            has no valid timestamps, a required count column is missing, or the
            files do not share the same Session_type.
    """
    ts_col = 'MM:DD:YYYY hh:mm:ss'
    count_cols = ['Left_Poke_Count', 'Right_Poke_Count', 'Pellet_Count']

    if not file_list:
        raise ValueError("concat_one_group needs at least one file.")

    # Read and validate every file once, remembering its first timestamp
    loaded = []
    for fn in file_list:
        df = read_csv_from_sources(fn)
        if ts_col not in df.columns:
            raise ValueError(f"{fn} missing 'MM:DD:YYYY hh:mm:ss' column.")
        df[ts_col] = pd.to_datetime(df[ts_col])
        if df[ts_col].isna().all():
            raise ValueError(f"{fn} has no valid timestamps in 'MM:DD:YYYY hh:mm:ss'.")
        loaded.append((df[ts_col].iloc[0], fn, df))

    loaded.sort(key=lambda item: item[0])  # earliest first
    ordered = [fn for _, fn, _ in loaded]
    frames = [df for _, _, df in loaded]

    # Initialize with the first file
    df_combined = frames[0]
    for c in count_cols:
        if c not in df_combined.columns:
            raise ValueError(f"{ordered[0]} missing required column '{c}'.")

    # Running maxima of the cumulative counters
    offsets = {c: df_combined[c].max() for c in count_cols}

    base_session = df_combined['Session_type'].iloc[0] if 'Session_type' in df_combined.columns else None

    prev_fn = ordered[0]
    for fn, df_new in zip(ordered[1:], frames[1:]):
        # Session type consistency (if present)
        if base_session is not None and 'Session_type' in df_new.columns:
            if df_new['Session_type'].iloc[0] != base_session:
                raise ValueError(
                    f"Session type mismatch: '{ordered[0]}' is '{base_session}' "
                    f"but '{fn}' is '{df_new['Session_type'].iloc[0]}'"
                )

        # Warn if big timestamp gap; the gap is measured from the end of the
        # data combined so far, i.e. the end of the previous file
        gap = df_new[ts_col].iloc[0] - df_combined[ts_col].iloc[-1]
        if gap > pd.Timedelta(hours=24):
            print(f"Warning: >24h gap between files: {prev_fn} -> {fn}")

        # Offset cumulative counts
        for c in count_cols:
            if c not in df_new.columns:
                raise ValueError(f"{fn} missing required column '{c}'.")
            df_new[c] = df_new[c] + offsets[c]

        # Append
        df_combined = pd.concat([df_combined, df_new], ignore_index=True)

        # Update trackers
        offsets = {c: df_combined[c].max() for c in count_cols}
        prev_fn = fn

    df_combined.attrs["Concatenated"] = True
    return df_combined, ordered

# ---------- drive the concatenation per (Mouse_ID, FED_StartDate) ----------
# Each entry of `outputs` is (mouse id as str, start date, output file name, source files in order).
outputs = []
group_cols = [animal_col, 'FED_StartDate']

for (mouse_id, fed_start_date), sub in mapped.groupby(group_cols):
    files_for_group = sorted(set(sub['file'].tolist()))
    if len(files_for_group) < 2:
        continue  # nothing to concatenate for a single-file group

    combined, ordered_files = concat_one_group(files_for_group)

    # Name the output after the earliest source file (FED number + date token, suffix _00);
    # if that name cannot be parsed, fall back to "<mouse id>_<start date>.csv".
    # NOTE(review): the primary name can coincide with the earliest source file's own
    # name — confirm that later cells which look files up by name handle this.
    fed_first, date_str, _ = parse_fed_and_date_from_name(ordered_files[0])
    if fed_first is None or date_str is None:
        sd_str = pd.to_datetime(fed_start_date).strftime("%Y%m%d")
        out_name = f"{mouse_id}_{sd_str}.csv"
    else:
        out_name = f"FED{fed_first:03d}_{date_str}_00.CSV"

    # Written to the current working directory
    combined.to_csv(out_name, index=False)
    record = (str(mouse_id), pd.to_datetime(fed_start_date), out_name, ordered_files)
    outputs.append(record)

# Summary
print("Wrote concatenated files (groups with >1 file only):")
for mouse_id, fed_start_date, out_name, ordered_files in outputs:
    sd = pd.to_datetime(fed_start_date).strftime("%Y-%m-%d")
    print(f"  {mouse_id} @ {sd} -> {out_name}  (from {len(ordered_files)} file(s))")


In [None]:
# @title QC Check (concat-only for concatenated animals; originals otherwise)


import os, re, io
import pandas as pd
from datetime import timedelta, datetime

# ---- Gather concatenation info if available ----
# Falls back to an empty list when the concatenation cell has not been run.
outputs = globals().get("outputs", [])  # [(animal_id, start_date, out_name, ordered_files)] or old 3-tuple shape

# Make sure the lists this cell appends to exist (missing or None -> fresh list)
if globals().get("loaded_files") is None:
    loaded_files = []
if globals().get("feds") is None:
    feds = []

# ---- Make concatenated outputs visible to QC (append to loaded_files/feds if missing) ----
# fed3 is the preferred loader; when it cannot be imported the name is set to None
# and the pandas-based fallback loader is used instead.
try:
    import fed3
except Exception:
    fed3 = None

def _load_minimal_dataframe(path: str) -> pd.DataFrame:
    df = pd.read_csv(path)
    if "MM:DD:YYYY hh:mm:ss" in df.columns:
        df["MM:DD:YYYY hh:mm:ss"] = pd.to_datetime(df["MM:DD:YYYY hh:mm:ss"])
        df = df.set_index("MM:DD:YYYY hh:mm:ss")
    # carry Session_type into attrs if present
    if "Session_type" in df.columns and not df.empty:
        try:
            df.attrs["Session_type"] = str(df["Session_type"].iloc[0])
        except Exception:
            pass
    return df

# Register every concatenated CSV from `outputs` in the shared loaded_files/feds
# lists so the QC below can address it by the same global index as any upload.
for rec in outputs:
    # Accept both the current 4-tuple and the older 3-tuple record shape
    if len(rec) == 4:
        _aid, _start_date, out_name, _src = rec
    elif len(rec) == 3:
        _aid, out_name, _src = rec
    else:
        raise ValueError(f"Unexpected outputs tuple shape: {len(rec)}: {rec}")

    # Add to loaded_files if not already present (by exact string)
    if out_name not in loaded_files:
        loaded_files.append(out_name)
        # Prefer fed3.load; fall back to the pandas loader if fed3 is missing
        # or if fed3.load raises for this file.
        try:
            df_loaded = fed3.load(out_name) if fed3 is not None else _load_minimal_dataframe(out_name)
        except Exception:
            df_loaded = _load_minimal_dataframe(out_name)
        # Tag as concatenated (the CSV on disk does not keep attrs):
        try:
            df_loaded.attrs["Concatenated"] = True
        except Exception:
            pass
        # Appended right after loaded_files.append so both lists stay index-aligned
        feds.append(df_loaded)

# ---- Map animal -> list of concatenated output paths (usually 1 each)
concat_by_animal = {}
for rec in outputs:
    if len(rec) == 4:
        animal_id, _start_date, out_name, _src = rec
    elif len(rec) == 3:
        animal_id, out_name, _src = rec
    else:
        raise ValueError(f"Unexpected outputs tuple shape: {len(rec)}: {rec}")
    # Keys are stringified so they compare equal to str(animal_id) used later
    concat_by_animal.setdefault(str(animal_id), []).append(out_name)

# ---- Determine mapping source for originals
# `mapped` (file -> animal table) normally comes from an earlier cell. If it is
# missing, rebuild it from `matched_df`; if that is unavailable too, stop here.
if "mapped" not in globals():
    if "matched_df" in globals() and isinstance(matched_df, pd.DataFrame) and not matched_df.empty:
        animal_col = "Animal_ID"
        mapped = matched_df[["file", animal_col]].copy()
    else:
        raise RuntimeError("Neither 'mapped' nor a usable 'matched_df' is available for QC.")
elif "animal_col" not in globals():
    # Best guess if not set: use "Animal_ID" when present, else the first column
    animal_col = "Animal_ID" if "Animal_ID" in mapped.columns else mapped.columns[0]

# ---- Build manifest: (animal_id, identifier, is_concat) ----
#  - Animals present in concat_by_animal: include ONLY those concatenated CSV(s)
#  - All other animals: include ALL originals
manifest_for_qc = []

# 1) concatenated targets
for animal_id, out_names in concat_by_animal.items():
    for out_name in out_names:
        manifest_for_qc.append((animal_id, out_name, True))

# 2) originals for animals that were NOT concatenated
for animal_id, sub in mapped.groupby(animal_col):
    # concat_by_animal is keyed by str(animal_id), so compare as strings
    if str(animal_id) in concat_by_animal:
        continue  # exclude originals for concatenated animals
    files_for_animal = sorted(set(sub["file"].tolist()))
    for fn in files_for_animal:
        manifest_for_qc.append((str(animal_id), fn, False))

# Run QC over the files named in manifest_for_qc.
# Results (all keyed by GLOBAL index into feds/loaded_files):
#   passed_indices  - files with no QC reason
#   flagged_indices - files with at least one QC reason
#   flagged_reasons - {index: " | "-joined reason text}
if not manifest_for_qc:
    print("No concatenations and no original files found for QC.")
    # Still define empty outputs for downstream robustness
    passed_indices = []
    flagged_indices = []
    flagged_reasons = {}
else:
    # ---- Deduplicate manifest entries while preserving concat/orig identity ----
    # The loop flag is called `concat_flag` (not `is_concat`) so that running
    # this cell does not rebind the `is_concat` helper used by the review cell.
    seen = set()
    unique_manifest = []
    for animal_id, ident, concat_flag in manifest_for_qc:
        key = ("concat", os.path.abspath(ident)) if concat_flag else ("orig", ident)
        if key in seen:
            continue
        seen.add(key)
        unique_manifest.append((animal_id, ident, concat_flag))

    # =========================
    # Map manifest -> GLOBAL indices
    # =========================
    # We will QC the subset of GLOBAL feds/loaded_files whose basenames match the manifest.
    # This keeps passed/flagged indices compatible with downstream rename logic.

    # Basenames we expect to QC this run
    names_to_include = set(os.path.basename(ident) for _aid, ident, _c in unique_manifest)

    # Build selected_idx = indices into GLOBAL arrays (feds, loaded_files).
    # Only the first loaded file with a given basename is selected.
    selected_idx = []
    seen_names = set()
    for i, fn in enumerate(loaded_files):
        base = os.path.basename(fn)
        if base in names_to_include and base not in seen_names:
            selected_idx.append(i)
            seen_names.add(base)

    if not selected_idx:
        print("QC: No matching files from manifest were found in loaded_files; nothing to check.")
        # Define empty outputs for downstream robustness
        passed_indices = []
        flagged_indices = []
        flagged_reasons = {}
        print("QC complete (this run only): 0 passed, 0 flagged.")
    else:
        # ---- Build run views from the selected GLOBAL indices (do not change order)
        feds_run  = [feds[i] for i in selected_idx]
        files_run = [loaded_files[i] for i in selected_idx]

        # Map run-local index -> GLOBAL index (used to key reasons globally)
        run_to_global = {i_run: selected_idx[i_run] for i_run in range(len(selected_idx))}

        # ===== QC checks (reasons keyed by GLOBAL index) =====
        duration_threshold = timedelta(hours=12)
        # Hours shown in messages are derived from the threshold so the text
        # can never disagree with the value actually being enforced.
        threshold_hours = duration_threshold.total_seconds() / 3600.0
        required_days = {"Bandit100": 2.5, "Bandit80": 1.5, "FR1": 0.8, "PR1": 1.5}
        PELLET_THRESHOLD = 75  # minimum estimated pellets/day; change if needed

        passed_indices = []            # GLOBAL indices
        flagged_indices = []           # GLOBAL indices
        flagged_reasons = {}           # {GLOBAL index: reason string}

        def extract_file_date(filename):
            """Parse the MMDDYY stamp of a 'FED###_MMDDYY...' name into a datetime.

            Two-digit years below 50 map to 20YY, the rest to 19YY.
            Returns None when the name does not follow the pattern or when the
            six digits are not a valid calendar date.
            """
            match = re.match(r"FED\d{3}_(\d{6})", filename)
            if not match:
                return None
            stamp = match.group(1)
            mm, dd, yy = stamp[:2], stamp[2:4], stamp[4:]
            yyyy = "20" + yy if int(yy) < 50 else "19" + yy
            try:
                return datetime.strptime(f"{yyyy}-{mm}-{dd}", "%Y-%m-%d")
            except ValueError:
                # e.g. month 13: treat like an unparseable name instead of aborting QC
                return None

        file_dates = [extract_file_date(os.path.basename(f)) for f in files_run]

        # Group run-local indices by file date (cohort); files without a
        # parseable date share the None cohort.
        cohort_map = {}
        for i_run, date in enumerate(file_dates):
            cohort_map.setdefault(date, []).append(i_run)

        for cohort_date, cohort_i_runs in cohort_map.items():
            # Compute cohort medians in RUN space
            start_times = []
            end_times   = []
            for i_run in cohort_i_runs:
                df_tmp = feds_run[i_run]
                # Ensure datetime index
                df_idx = pd.to_datetime(df_tmp.index)
                if len(df_idx) == 0:
                    continue  # empty file: flagged below, excluded from the medians
                start_times.append(df_idx[0])
                end_times.append(df_idx[-1])

            median_start = pd.Series(start_times).median() if start_times else pd.NaT
            median_end   = pd.Series(end_times).median() if end_times else pd.NaT

            for i_run in cohort_i_runs:
                # Work on a copy so the global frame's index is left untouched
                df = feds_run[i_run].copy()
                df.index = pd.to_datetime(df.index)

                i_global = run_to_global[i_run]  # GLOBAL index for all bookkeeping
                device = str(df['Device_Number'].iloc[0]).zfill(3) if 'Device_Number' in df.columns and not df.empty else "UNK"

                # A file without rows cannot be checked any further
                if df.empty:
                    flagged_indices.append(i_global)
                    flagged_reasons[i_global] = f"Device {device}: File contains no rows."
                    continue

                # Prefer DataFrame attrs; fall back to column if needed
                session_type = str(df.attrs.get("Session_type",
                                    df['Session_type'].iloc[0] if 'Session_type' in df.columns and not df.empty else "Unknown")).strip().replace(" ", "_")
                min_days = float(required_days.get(session_type, 1))  # default 1 day

                reason_list = []

                # Start/End deviation from cohort medians
                if abs(df.index[0] - median_start) > duration_threshold:
                    reason_list.append(
                        f"Device {device}: Start time {df.index[0]} is >{threshold_hours:g} hrs from cohort median {median_start}"
                    )
                if abs(df.index[-1] - median_end) > duration_threshold:
                    reason_list.append(
                        f"Device {device}: End time {df.index[-1]} is >{threshold_hours:g} hrs from cohort median {median_end}"
                    )

                # --- Duration checks (ENFORCE required_days) ---
                duration_hours = (df.index.max() - df.index.min()).total_seconds() / 3600.0
                min_hours = 24.0 * min_days
                if not pd.isna(duration_hours) and duration_hours < min_hours:
                    reason_list.append(
                        f"Device {device}: Duration {duration_hours:.1f} h < required {min_hours:.1f} h for {session_type}."
                    )

                # === Estimated pellets/day (no noon bins) ===
                if 'Pellet_Count' not in df.columns or df['Pellet_Count'].dropna().empty:
                    reason_list.append(f"Device {device}: Missing/empty Pellet_Count.")
                    flagged_indices.append(i_global)
                    flagged_reasons[i_global] = " | ".join(reason_list)
                    continue

                # Ensure numeric and aligned to time index
                pc = pd.to_numeric(df['Pellet_Count'], errors='coerce').dropna()
                if pc.empty or len(df.index) < 2:
                    reason_list.append(f"Device {device}: Not enough data to estimate pellets/day.")
                    flagged_indices.append(i_global)
                    flagged_reasons[i_global] = " | ".join(reason_list)
                    continue

                # Duration guard already computed (duration_hours)
                if duration_hours <= 0:
                    reason_list.append(f"Device {device}: Non-positive duration.")
                    flagged_indices.append(i_global)
                    flagged_reasons[i_global] = " | ".join(reason_list)
                    continue

                # Pellet events = sum of positive increments (robust to counter resets)
                diffs = pc.diff().fillna(0)
                pellet_events = float(diffs.clip(lower=0).sum())

                # Fallback if no positive increments were found but the counter did not decrease
                if pellet_events == 0 and pc.iloc[-1] >= pc.iloc[0]:
                    pellet_events = float(pc.iloc[-1] - pc.iloc[0])

                # Estimated pellets per day
                est_pellets_per_day = (pellet_events / duration_hours) * 24.0
                # NOTE: stored on the local copy only; the frame in `feds` is not modified.
                df.attrs['Estimated_Pellets_Per_Day'] = est_pellets_per_day

                # Threshold check
                if pd.isna(est_pellets_per_day):
                    reason_list.append(f"Device {device}: Could not compute estimated pellets/day.")
                elif est_pellets_per_day < PELLET_THRESHOLD:
                    reason_list.append(
                        f"Device {device}: Estimated {est_pellets_per_day:.1f} pellets/day < {PELLET_THRESHOLD}."
                    )

                # Flatline check: Pellet_Count unchanged for longer than duration_threshold.
                # Rows with a count of 0 are ignored; consecutive equal counts form one run.
                df_change = df[df['Pellet_Count'] > 0].copy()
                df_change['Change'] = df_change['Pellet_Count'].ne(df_change['Pellet_Count'].shift()).cumsum()
                for _, group in df_change.groupby('Change'):
                    if (group.index[-1] - group.index[0]) > duration_threshold:
                        value = group['Pellet_Count'].iloc[0]
                        reason_list.append(
                            f"Device {device}: 'Pellet_Count' stayed at {value} for >{threshold_hours:g} hrs starting {group.index[0]}"
                        )
                        break  # one flatline reason per file is enough

                if reason_list:
                    flagged_indices.append(i_global)
                    flagged_reasons[i_global] = " | ".join(reason_list)
                else:
                    passed_indices.append(i_global)

        print(f"QC complete (this run only): {len(passed_indices)} passed, {len(flagged_indices)} flagged.")

In [None]:
# @title Manual Review Slider (UI only)
from collections import defaultdict
# ===== Review prefilter (dedupe concatenations by FED/date) =====
# Matches names that start with "FED" + 3 digits + "_" + 6 digits (case-insensitive).
six_pat = re.compile(r"^FED(\d{3})_(\d{6})", re.IGNORECASE)

def key_from_filename(fname: str):
    """Return the (FED number, MMDDYY string) grouping key for a file name.

    The pattern is applied to the basename, so a name that carries a directory
    prefix yields the same key as the bare file name (consistent with the QC
    cell and infer_fed_from_df, which also work on basenames).

    Args:
        fname: File name or path; converted with str() and stripped first.

    Returns:
        (int, str) tuple, e.g. "FED072_010224_00.CSV" -> (72, "010224"),
        or None when the basename does not start with the FED###_###### pattern.
    """
    base = os.path.basename(str(fname).strip())
    m = six_pat.match(base)
    if m is None:
        return None
    # int() drops the zero padding (072 -> 72); the date stays a 6-char string
    return (int(m.group(1)), m.group(2))

def infer_fed_from_df(df, fname):
    """Return the FED number as an int, or None if it cannot be determined.

    The first 'Device_Number' value of a non-empty DataFrame is tried first;
    otherwise the first "FED" + 3 digits found anywhere in the basename of
    `fname` (case-insensitive) is used.
    """
    usable = (
        isinstance(df, pd.DataFrame)
        and not df.empty
        and 'Device_Number' in df.columns
    )
    if usable:
        try:
            return int(df['Device_Number'].iloc[0])
        except Exception:
            pass  # value not convertible: fall through to the file name

    base_name = os.path.basename(str(fname))
    found = re.search(r"FED(\d{3})", base_name, re.IGNORECASE)
    if found is None:
        return None
    try:
        return int(found.group(1))
    except Exception:
        return None

def infer_mmddyy_from_df(df):
    """Return a date for `df` formatted as MMDDYY, or None if none is found.

    Columns "Date_Time", "Datetime", "Time", "Timestamp" are tried in that
    order (first parseable value wins), then the attrs "Session_Date", "Date",
    "Start_Time". Any parsing error simply moves on to the next candidate.
    """
    # 1) datetime-like columns
    for column in ("Date_Time", "Datetime", "Time", "Timestamp"):
        if column not in df.columns:
            continue
        try:
            valid = pd.to_datetime(df[column], errors="coerce").dropna()
            if len(valid) > 0:
                first_day = pd.to_datetime(valid.iloc[0]).date()
                return pd.to_datetime(first_day).strftime("%m%d%y")
        except Exception:
            continue

    # 2) metadata stored in attrs
    attrs = getattr(df, "attrs", {})
    for attr_name in ("Session_Date", "Date", "Start_Time"):
        raw = attrs.get(attr_name)
        if raw is None:
            continue
        try:
            day = pd.to_datetime(raw, errors="coerce").date()
            if pd.notna(day):
                return pd.to_datetime(day).strftime("%m%d%y")
        except Exception:
            continue

    return None

def robust_key_for_index(i):
    """Return the (FED, MMDDYY) grouping key for GLOBAL index `i`, or None.

    The file name is tried first; if it does not follow the FED###_MMDDYY
    pattern, the key is inferred from the DataFrame (and file name) instead.
    None means the file cannot be grouped.
    """
    fname = loaded_files[i]

    # 1) key straight from the file name
    name_key = key_from_filename(fname)
    if name_key is not None:
        return name_key

    # 2) fallbacks from the DataFrame
    frame = feds[i]
    fed = infer_fed_from_df(frame, fname)
    date_str = infer_mmddyy_from_df(frame)
    if fed is None or date_str is None:
        return None  # ungroupable
    return (fed, date_str)

def is_concat(idx):
    """Tell whether the frame at GLOBAL index `idx` is marked as concatenated.

    Reads attrs["Concatenated"]; string values count as true when they are one
    of the accepted spellings (case-insensitive), anything else uses bool().
    """
    flag = feds[idx].attrs.get("Concatenated", False)
    if not isinstance(flag, str):
        return bool(flag)
    truthy_spellings = {"1", "true", "yes", "y", "concat", "concatenated"}
    return flag.strip().lower() in truthy_spellings

# Bucket the flagged GLOBAL indices by their (FED, MMDDYY) key; indices
# without a key are collected separately because they cannot be grouped.
flag_groups = defaultdict(list)
nonstandard_flagged = []

for i in flagged_indices:  # i are GLOBAL
    k = robust_key_for_index(i)
    if k is not None:
        flag_groups[k].append(i)
    else:
        nonstandard_flagged.append(i)

keep_indices = set()
suppressed_indices = set()

# Rule: a group that contains concatenated file(s) keeps only those and
# suppresses its originals; a group without any keeps all of its files.
for k, idxs in flag_groups.items():
    concats = [i for i in idxs if is_concat(i)]
    if not concats:
        keep_indices.update(idxs)
        continue
    keep_indices.update(concats)
    suppressed_indices.update(i for i in idxs if i not in concats)

# Files with nonstandard names are always shown in the UI (no group to dedupe against)
keep_indices.update(nonstandard_flagged)

# Ordering for the review UI: concatenated files first (cosmetic), then by file name.
def sort_key(i):
    """Return the sort tuple (0 for concatenated else 1, file name) for GLOBAL index `i`."""
    rank = 1
    if is_concat(i):
        rank = 0
    return (rank, loaded_files[i])
# Flagged files that survive the prefilter, in review order
flagged_indices_review = sorted(
    filter(keep_indices.__contains__, flagged_indices),
    key=sort_key,
)

# Prefilter summary
print(f"\n[Review Prefilter] Flagged total            : {len(flagged_indices)}")
print(f"[Review Prefilter] Nonstandard-name flagged : {len(nonstandard_flagged)}")
print(f"[Review Prefilter] Suppressed originals     : {len(suppressed_indices)}")
if suppressed_indices:
    print("  Suppressed (concatenated counterpart exists for same FED/MMDDYY):")
    # listed alphabetically by file name
    for i in sorted(suppressed_indices, key=loaded_files.__getitem__):
        print(f"    - {loaded_files[i]}  -> kept concat present")
print(f"[Review Prefilter] Flagged shown in UI      : {len(flagged_indices_review)}")


# ===== Manual review UI (reasons fetched by GLOBAL index) =====
# Steps through flagged_indices_review one file at a time; every click on
# Include/Exclude records the decision and advances to the next file.
# Decision containers (all hold GLOBAL indices); they are reset on every run of this cell:
approved_indices = []            # files the reviewer chose to include
override_indices = []            # same files, recorded as manual overrides of a QC flag
manually_rejected_indices = []   # files the reviewer chose to exclude
inclusion_map = {}        # {GLOBAL index: True/False}
decision_log = {}         # {GLOBAL index: human-readable decision line}

if not flagged_indices_review:
    print("All files passed QC — no review needed.")
else:
    # Separate output areas so plot, buttons and log can be refreshed independently
    plot_output = widgets.Output()
    button_output = widgets.Output()
    decision_log_output = widgets.Output()

    # Mutable holders so the nested callbacks can update state without `global`
    current_idx = {"value": 0}          # position within flagged_indices_review
    review_complete = {"done": False}   # set once the last file has been decided

    def log_decision(idx_global, decision):
        """Store the decision text for `idx_global` and redraw the whole review log."""
        fname = loaded_files[idx_global]
        fed_id = str(feds[idx_global]["Device_Number"].iloc[0]).zfill(3)
        status = "Included" if decision else "Excluded"
        decision_log[idx_global] = f"{status}: {fname} (FED{fed_id})"
        with decision_log_output:
            clear_output(wait=True)
            for line in decision_log.values():
                print(line)

    def show_current_file():
        """Plot the file at the current position and render its reasons and buttons."""
        idx_global = flagged_indices_review[current_idx["value"]]
        df = feds[idx_global]
        fname = loaded_files[idx_global]

        with plot_output:
            clear_output(wait=True)
            # inplace=True: the alignment is applied to the frame stored in `feds`
            fed3.as_aligned([df], alignment="datetime", inplace=True)
            plt.figure(figsize=(8, 4))
            # Pellet line plot from the fed3 plotting helpers
            fplot.line([df], y='pellets')
            plt.title(fname)
            plt.tight_layout()
            # Drop the legend if one was created
            legend = plt.legend()
            if legend:
                legend.remove()
            plt.show()

        with button_output:
            clear_output(wait=True)
            print(f"File {current_idx['value'] + 1} of {len(flagged_indices_review)}: {fname}")
            # Reasons are looked up by GLOBAL index, matching how QC stored them
            print("Reason(s):", flagged_reasons.get(idx_global, "(no reason found)"))
            # None = undecided, True = included, False = excluded
            status = inclusion_map.get(idx_global, None)
            print("Current decision:", "Included" if status else "Excluded" if status is False else "Undecided")

            include_btn = widgets.Button(description="Include", button_style='success')
            exclude_btn = widgets.Button(description="Exclude", button_style='danger')

            def handle_decision(decision):
                """Record include (True) / exclude (False) for the shown file, then advance."""
                if review_complete["done"]:
                    return

                # Remove any earlier decision for this file before recording the new one
                for lst in (approved_indices, override_indices, manually_rejected_indices):
                    if idx_global in lst:
                        lst.remove(idx_global)

                inclusion_map[idx_global] = decision
                if decision:
                    # An included flagged file is both approved and an override
                    approved_indices.append(idx_global)
                    override_indices.append(idx_global)
                else:
                    manually_rejected_indices.append(idx_global)

                log_decision(idx_global, decision)

                if current_idx["value"] < len(flagged_indices_review) - 1:
                    current_idx["value"] += 1
                    show_current_file()
                else:
                    # Last file decided: lock further input and clear the widgets
                    review_complete["done"] = True
                    with plot_output:
                        clear_output(wait=True)
                        print("Review complete. All flagged files have been reviewed.")
                    with button_output:
                        clear_output(wait=True)

            include_btn.on_click(lambda _: handle_decision(True))
            exclude_btn.on_click(lambda _: handle_decision(False))
            display(widgets.HBox([include_btn, exclude_btn]))

    display(plot_output, button_output, widgets.Label("Review Log:"), decision_log_output)
    show_current_file()

In [None]:
# @title Rename Files, and Update Key with Exclusions and Flag Reasons
import os, re, zipfile
import pandas as pd
from datetime import timedelta
from google.colab import files as gcolab_files  # for files.download
from openpyxl.styles import PatternFill
from openpyxl import load_workbook

# ---- Bridge QC → Rename: treat QC passes as approved, apply overrides, drop manual rejects ----
def _to_int_set(x):
    try:
        return {int(i) for i in x}
    except NameError:
        return set()
    except Exception:
        return {int(i) for i in list(x)}

# Number of loaded frames; used below to discard out-of-range indices
n = len(feds)
qc_pass_set      = _to_int_set(passed_indices)              # passed automatic QC
manual_approve   = _to_int_set(approved_indices)            # included during manual review
overrides_set    = _to_int_set(override_indices)            # flagged but manually included
manual_rejects   = _to_int_set(manually_rejected_indices)   # excluded during manual review

# Final approval = (QC passes + manual includes) minus manual rejects.
# Note: this rebinds approved_indices to the combined, sorted list.
combined = (qc_pass_set | manual_approve | overrides_set) - manual_rejects
approved_indices = sorted(i for i in combined if 0 <= i < n)

def process_key_and_rename(key_file, approved_indices, feds, loaded_files,
                           manually_rejected_indices, override_indices, flagged_reasons):
    """Rename approved FED files via the key, zip them, and write QC outcomes into the key.

    Relies on `extract_fed_and_date` (defined outside this function), which is
    unpacked here as (fed_id, date_string, file_date).

    Args:
        key_file: Path of the Excel key. Read with pandas and re-opened with
            openpyxl. Must contain a 'FED_StartDate' column; the 'FED3' column
            is parsed for device numbers.
        approved_indices: Indices into `feds`/`loaded_files` to rename and export.
        feds: List of DataFrames; attrs['Session_type'] and the 'Device_Number'
            column are read.
        loaded_files: File names, index-aligned with `feds`.
        manually_rejected_indices: Indices rejected in manual review.
        override_indices: Indices that were flagged but manually included.
        flagged_reasons: {index: reason text} from QC.

    Returns:
        dict with keys 'no_file_log', 'flagged_log', 'renamed_files',
        'zip_name', 'key_modified', 'sessions'.

    Raises:
        ValueError: If the key has no 'FED_StartDate' column.

    Side effects:
        Writes CSVs into 'renamed_output/', may write a zip and an updated
        .xlsx key in the working directory, and triggers Colab downloads.
    """

    # --- Load key (do NOT coerce FED3) ---
    key_df = pd.read_excel(key_file)
    if "FED_StartDate" not in key_df.columns:
        raise ValueError("Key must have a 'FED_StartDate' column.")
    key_df["FED_StartDate"] = pd.to_datetime(key_df["FED_StartDate"], errors="coerce")

    # --- Helpers ---
    def parse_fed3_cell(val):
        # Turn a FED3 cell such as "12 and 14", "12/14" or "12;14" into a sorted
        # list of unique ints; NaN yields an empty list.
        if pd.isna(val):
            return []
        s = str(val)
        s = re.sub(r"\b(?:and)\b", ",", s, flags=re.I)
        s = re.sub(r"[+,/&]", ",", s)
        s = re.sub(r"[;|]", ",", s)
        parts = [p.strip() for p in s.split(",") if p.strip()]
        return sorted({int(m) for p in parts for m in re.findall(r"\d+", p)})

    def z3(x):  # zero-pad int to 3 digits
        return str(int(x)).zfill(3)

    def safe_gene_id(row):
        # Gene_ID is numeric; return zero-padded 3-digit string ("000" if missing)
        return f"{int(row['Gene_ID']):03d}" if 'Gene_ID' in row and pd.notna(row['Gene_ID']) else "000"

    def normalize_session(st):
        # None/blank -> "Unknown"; spaces become underscores
        st = str(st if st is not None else "Unknown").strip()
        return st.replace(" ", "_") if st else "Unknown"

    # --- Explode key: one FED per row, keep pointer to original row ---
    # orig_idx is the 0-based row position in the key as read by pandas.
    key_df = key_df.reset_index().rename(columns={"index": "orig_idx"})
    key_df["_FED3_list"] = key_df["FED3"].apply(parse_fed3_cell)
    key_exploded = (
        key_df.explode("_FED3_list", ignore_index=True)
              .dropna(subset=["_FED3_list"])
              .rename(columns={"_FED3_list": "FED3_int"})
    )
    key_exploded["FED3_int"] = key_exploded["FED3_int"].astype(int)

    # --- ±14d window matching helper (handles multi-cohort FEDs) ---
    window = timedelta(days=14)

    def pick_best_key_row(fed_id: int, file_date) -> pd.Series | None:
        """
        Find the nearest FED_StartDate for this FED within ±window.
        - Prefer smallest absolute time delta.
        - Break ties by preferring the later FED_StartDate (assume newer cohort).
        Returns a single row (Series) from key_exploded, or None if no match.
        """
        if pd.isna(file_date):
            return None
        file_date_ts = pd.to_datetime(file_date, errors="coerce")
        if pd.isna(file_date_ts):
            return None

        sub = key_exploded[key_exploded["FED3_int"] == int(fed_id)].copy()
        if sub.empty:
            return None
        sub["abs_delta"] = (file_date_ts - sub["FED_StartDate"]).abs()
        sub = sub[sub["abs_delta"] <= window]
        if sub.empty:
            return None
        return sub.sort_values(["abs_delta", "FED_StartDate"], ascending=[True, False]).iloc[0]

    # --- Outputs / logs ---
    output_dir = "renamed_output"
    os.makedirs(output_dir, exist_ok=True)
    renamed_files = []
    files_not_in_key = []
    no_file_log = []
    flagged_log = []

    key_modified = False
    # Values from the last renamed file; used to name the zip and the updated key
    last_gene_name = last_gene_id = last_session_type = None

    # --- Styling targets for the Excel output ---
    # Each entry: (orig_idx, col_name, tag) where tag in {'reject','include','no_file'}
    cell_styles = []

    # --- Rename and collect matched data (approved only) ---
    match_lookup = {}
    for idx in approved_indices:
        df = feds[idx]
        original_name = loaded_files[idx]

        # Get FED id and file_date from the source filename
        fed_id, _date_str_unused, file_date = extract_fed_and_date(original_name)
        if fed_id is None or file_date is None:
            continue  # name could not be parsed: file is skipped without a log entry

        # Find the matching key row within ±14-day window for this FED
        cand = pick_best_key_row(fed_id, file_date)
        if cand is None:
            files_not_in_key.append((original_name, z3(fed_id)))
            continue

        orig_idx = cand["orig_idx"]
        row_orig = key_df.loc[key_df["orig_idx"] == orig_idx].iloc[0]

        # --- Build components for the new filename ---
        match_lookup[idx] = row_orig
        mouse_id   = row_orig.get("Mouse_ID", "Unknown")
        gene_name  = str(row_orig.get("Gene", "NA"))
        gene_id    = safe_gene_id(row_orig)

        # Normalize the animal label: blank -> "Unknown", "12.0" -> "12"
        raw_mouse = ("" if pd.isna(mouse_id) else str(mouse_id)).strip()
        if raw_mouse == "":
            animal = "Unknown"
        elif re.fullmatch(r"\d+(\.0+)?", raw_mouse):
            animal = str(int(float(raw_mouse)))  # normalize 12.0 -> 12
        else:
            animal = raw_mouse

        session_type = normalize_session(df.attrs.get("Session_type", "Unknown"))

        # Use a clean YYYYMMDD derived from file_date
        date_str = pd.to_datetime(file_date).strftime("%Y%m%d")
        new_name = f"{animal}_{session_type}_{date_str}.csv"

        # Overwrite-if-exists (NO suffixes)
        candidate = os.path.join(output_dir, new_name)
        df.reset_index().to_csv(candidate, index=False)
        renamed_files.append(candidate)

        last_gene_name, last_gene_id, last_session_type = gene_name, gene_id, session_type

    # --- Zip renamed files (optional, keep if you still want the zip) ---
    zip_name = None
    if renamed_files:
        zip_name = f"{last_gene_name}_{last_gene_id}_{last_session_type}_L1.zip"
        with zipfile.ZipFile(zip_name, "w") as zipf:
            for file_path in renamed_files:
                zipf.write(file_path, arcname=os.path.basename(file_path))
        gcolab_files.download(zip_name)

    # --- Sessions seen in this batch ---
    session_for_idx = {}
    session_types_present = set()
    for i, df_i in enumerate(feds):
        st = normalize_session(getattr(df_i, "attrs", {}).get("Session_type", "Unknown"))
        session_for_idx[i] = st
        session_types_present.add(st)
    if not session_types_present:
        session_types_present = {"Unknown"}

    # --- PRESENT animals per session (any uploaded file: approved, rejected, or override) ---
    # Values are sets of key row pointers (orig_idx) that have at least one file.
    present_by_session = {st: set() for st in session_types_present}
    for i, df_i in enumerate(feds):
        original_name_i = os.path.basename(str(loaded_files[i]))
        fed_id_i, _, file_date_i = extract_fed_and_date(original_name_i)
        if fed_id_i is None or file_date_i is None:
            continue
        cand = pick_best_key_row(fed_id_i, file_date_i)
        if cand is None:
            continue
        st = normalize_session(session_for_idx.get(i, "Unknown"))
        try:
            present_by_session.setdefault(st, set()).add(int(cand["orig_idx"]))
        except Exception:
            pass

    # --- EXPECTED animals per session = all animals in the key ---
    all_animals = set(int(x) for x in key_df["orig_idx"].tolist())
    expected_by_session = {st: set(all_animals) for st in session_types_present}

    # --- NO FILE = expected - present (per session). Write "No file" and mark red.
    for st in sorted(session_types_present):
        target_col = f"{st}_EX"
        no_file_rows = expected_by_session[st] - present_by_session.get(st, set())
        for orig_idx in sorted(no_file_rows):
            key_df.loc[key_df["orig_idx"] == orig_idx, target_col] = "No file"  # Capital N
            r0 = key_df.loc[key_df["orig_idx"] == orig_idx].iloc[0]
            no_file_log.append((r0.get("FED3"), r0.get("Mouse_ID")))
            cell_styles.append((orig_idx, target_col, "no_file"))
            key_modified = True

    # --- Flagged reasons for rejected or overridden files (write back to original key row) ---
    for idx in set(list(override_indices) + list(manually_rejected_indices)):
        df = feds[idx]
        if 'Device_Number' not in df.columns or df.empty:
            continue
        try:
            fed_int = int(df['Device_Number'].iloc[0])
        except Exception:
            continue
        fed_str = z3(fed_int)
        session_type = normalize_session(df.attrs.get("Session_type", "Unknown"))
        reason = flagged_reasons.get(idx, "Flagged without reason")
        session_col = f"{session_type}_EX"

        # The key row is matched from the file name, not from Device_Number
        original_name = loaded_files[idx]
        fed_id2, _, file_date2 = extract_fed_and_date(original_name)
        if fed_id2 is None or file_date2 is None:
            continue

        cand = pick_best_key_row(fed_id2, file_date2)
        if cand is not None:
            orig_idx = cand["orig_idx"]
            key_df.loc[key_df["orig_idx"] == orig_idx, session_col] = str(reason)

            # Style tag: rejected = red, manually included (override) = orange
            if idx in manually_rejected_indices:
                cell_styles.append((orig_idx, session_col, "reject"))
            elif idx in override_indices:
                cell_styles.append((orig_idx, session_col, "include"))

            included = idx in approved_indices
            flagged_log.append((fed_str,
                                key_df.loc[key_df["orig_idx"] == orig_idx, "Mouse_ID"].iloc[0],
                                reason,
                                included))
            key_modified = True

    # --- Minimal prints so you can verify ---
    print(f"[No-file marks] {len(no_file_log)} rows:", no_file_log[:10])
    print("[Sessions present]:", sorted(session_types_present))

    # --- Save a SINGLE updated key (styled overlay onto the original sheet) ---
    styled_path = f"{last_gene_name}_{last_gene_id}_key_updated.xlsx" if last_gene_name and last_gene_id else "updated_key.xlsx"

    if key_modified or renamed_files:
        wb = load_workbook(key_file)
        # pick the first sheet with FED_StartDate in header row
        ws = None
        for name in wb.sheetnames:
            cand_ws = wb[name]
            headers = [c.value for c in cand_ws[1]]
            if headers and "FED_StartDate" in headers:
                ws = cand_ws
                break
        if ws is None:
            ws = wb.active

        # header map: header text (as str) -> 1-based column number
        header_to_col = {}
        for col_idx in range(1, ws.max_column + 1):
            header_to_col[str(ws.cell(row=1, column=col_idx).value)] = col_idx

        def ensure_column(col_name: str) -> int:
            # Return the column number for col_name, appending a new header if missing
            if col_name in header_to_col and header_to_col[col_name]:
                return header_to_col[col_name]
            new_idx = ws.max_column + 1
            ws.cell(row=1, column=new_idx).value = col_name
            header_to_col[col_name] = new_idx
            return new_idx

        # fills
        red_fill    = PatternFill(fill_type="solid", start_color="FFFFC7CE", end_color="FFFFC7CE")
        orange_fill = PatternFill(fill_type="solid", start_color="FFFFE5B2", end_color="FFFFE5B2")

        # 1) FORCE-WRITE all *_EX values from key_df to the sheet
        # NOTE(review): sheet rows are addressed as orig_idx + 2, which assumes the
        # chosen sheet is the one pandas read and has a single header row -- confirm.
        ex_cols = [c for c in key_df.columns if isinstance(c, str) and c.endswith("_EX")]
        for col_name in ex_cols:
            col_idx = ensure_column(col_name)
            for i in range(len(key_df)):
                excel_row = int(key_df.loc[i, "orig_idx"]) + 2  # header is row 1
                val = key_df.loc[i, col_name] if col_name in key_df.columns else None
                if pd.isna(val):
                    continue
                ws.cell(row=excel_row, column=col_idx).value = str(val)

        # 2) Apply fills: unconditional RED for "No file" + tags
        #    First, paint "No file" occurrences across all *_EX columns:
        for col_name in ex_cols:
            col_idx = ensure_column(col_name)
            for i in range(len(key_df)):
                excel_row = int(key_df.loc[i, "orig_idx"]) + 2
                cell = ws.cell(row=excel_row, column=col_idx)
                v = cell.value
                if isinstance(v, str) and v.strip().lower() == "no file":
                    cell.fill = red_fill  # always paint red for 'No file'

        #    Then, apply explicit styles from cell_styles for reject/include
        #    ('no_file' tags need no action here; they were painted above)
        for orig_idx, col_name, tag in cell_styles:
            col_idx = ensure_column(col_name)
            excel_row = int(orig_idx) + 2
            cell = ws.cell(row=excel_row, column=col_idx)
            if tag == "reject":
                cell.fill = red_fill
            elif tag == "include":
                cell.fill = orange_fill

        wb.save(styled_path)
        gcolab_files.download(styled_path)

    # --- Final logs ---
    if flagged_log:
        print("\n[Flagged entries]")
        for fed_str, mouse_id, reason, included in flagged_log:
            print(f" - FED{fed_str} (Mouse_ID: {mouse_id}) -> {'Included' if included else 'Rejected'}: {reason}")

    if renamed_files and zip_name:
        print(f"\n[Summary] Renamed {len(renamed_files)} file(s); zipped to {zip_name}.")
    else:
        print(f"\n[Summary] Renamed {len(renamed_files)} file(s); no zip created.")

    return {
        "no_file_log": no_file_log,
        "flagged_log": flagged_log,
        "renamed_files": renamed_files,
        "zip_name": zip_name,
        "key_modified": key_modified,
        "sessions": sorted(session_types_present),
    }
# Run the rename/key-update step with the QC and manual-review results.
# `uploaded_key_path` (path of the Excel key) must already be defined by an earlier cell.
res = process_key_and_rename(
    key_file=uploaded_key_path,
    approved_indices=approved_indices,
    feds=feds,
    loaded_files=loaded_files,
    manually_rejected_indices=manually_rejected_indices,
    override_indices=override_indices,
    flagged_reasons=flagged_reasons
)

# One-line recap of what the step produced
print(f"[RESULT] no_file_log: {len(res['no_file_log'])}, flagged_log: {len(res['flagged_log'])}, "
      f"renamed: {len(res['renamed_files'])}, sessions: {res['sessions']}")